SyterKit 0.4.0.x
SyterKit is a bare-metal framework
Loading...
Searching...
No Matches
tinymaix.h
Go to the documentation of this file.
1/* Copyright 2022 Sipeed Technology Co., Ltd. All Rights Reserved.
2Licensed under the Apache License, Version 2.0 (the "License");
3you may not use this file except in compliance with the License.
4You may obtain a copy of the License at
5 http://www.apache.org/licenses/LICENSE-2.0
6Unless required by applicable law or agreed to in writing, software
7distributed under the License is distributed on an "AS IS" BASIS,
8WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9See the License for the specific language governing permissions and
10limitations under the License.
11==============================================================================*/
12
13#ifndef __TINYMAIX_H
14#define __TINYMAIX_H
15
16#include <stdint.h>
17
18#include <stdlib.h>
19#include <string.h>
20
21#define TM_MDL_INT8 0
22#define TM_MDL_INT16 1
23#define TM_MDL_FP32 2
24#define TM_MDL_FP16 3
25#define TM_MDL_FP8_143 4//experimental
26#define TM_MDL_FP8_152 5//experimental
27#include "tm_port.h"
28
29
// Magic value for the model binary header; its bytes, from least to most
// significant, are the ASCII characters 'M', 'A', 'I', 'X'.
30#define TM_MDL_MAGIC 0x5849414d//mdl magic sign
// Alignment granularity in bytes used by TM_ALIGN.
31#define TM_ALIGN_SIZE (8) //8 byte align
// Converts `addr` to size_t and rounds it up to the next multiple of
// TM_ALIGN_SIZE; a value that is already a multiple is returned unchanged.
// The result is a size_t, not a pointer.
32#define TM_ALIGN(addr) ((((size_t) (addr)) + (TM_ALIGN_SIZE - 1)) / TM_ALIGN_SIZE * TM_ALIGN_SIZE)
// Pointer to channel `ch` of the element at row `y`, column `x` of `mat`,
// computed as data + (y * w + x) * c + ch. No bounds checking is performed,
// and `mat` is expanded three times, so pass an expression without side effects.
33#define TM_MATP(mat, y, x, ch) ((mat)->data + ((y) * (mat)->w + (x)) * (mat)->c + (ch))
34//HWC layout: the channel index varies fastest, then x (width), then y (height)
// Element types selected at compile time by TM_MDL_TYPE:
//   mtype_t   - activation ("mat") element
//   wtype_t   - weight element
//   btype_t   - bias element
//   sumtype_t - accumulator
//   zptype_t  - zero point
// UINT2INT_SHIFT is defined only for the two integer model types.
// NOTE(review): UINT2INT_SHIFT is presumably consumed by the TMPP_UINT2INT
// preprocessing path; its use is not visible in this header - confirm.
// An unrecognised TM_MDL_TYPE stops the build with #error.
35#if TM_MDL_TYPE == TM_MDL_INT8
36typedef int8_t mtype_t; //mat data type
37typedef int8_t wtype_t; //weight data type
38typedef int32_t btype_t; //bias data type
39typedef int32_t sumtype_t;//sum data type
40typedef int32_t zptype_t; //zeropoint data type
41#define UINT2INT_SHIFT (0)
42#elif TM_MDL_TYPE == TM_MDL_INT16
43typedef int16_t mtype_t; //mat data type
44typedef int16_t wtype_t; //weight data type
45typedef int32_t btype_t; //bias data type
46typedef int32_t sumtype_t;//sum data type
47typedef int32_t zptype_t; //zeropoint data type
48#define UINT2INT_SHIFT (8)
49#elif TM_MDL_TYPE == TM_MDL_FP32
50typedef float mtype_t; //mat data type
51typedef float wtype_t; //weight data type
52typedef float btype_t; //bias data type
53typedef float sumtype_t;//sum data type
54typedef float zptype_t; //zeropoint data type
55#elif TM_MDL_TYPE == TM_MDL_FP16
// FP16 builds are rejected unless TM_ARCH is TM_ARCH_RV64V; float16_t is
// expected to come from <riscv_vector.h>.
56#if TM_ARCH != TM_ARCH_RV64V
57#error "only support RV64V's float16!"
58#endif
59#include <riscv_vector.h>
60typedef float16_t mtype_t; //mat data type
61typedef float16_t wtype_t; //weight data type
62typedef float16_t btype_t; //bias data type
63typedef float16_t sumtype_t;//sum data type
64typedef float16_t zptype_t; //zeropoint data type
65#elif (TM_MDL_TYPE == TM_MDL_FP8_143) || (TM_MDL_TYPE == TM_MDL_FP8_152)
// FP8 builds are rejected unless TM_ARCH is TM_ARCH_CPU. Values are stored as
// raw uint8_t; accumulation and zero points use float.
66#if TM_ARCH != TM_ARCH_CPU
67#error "only support CPU simulation now!"
68#endif
69typedef uint8_t mtype_t;//mat data type
70typedef uint8_t wtype_t;//weight data type
71typedef uint8_t btype_t;//bias data type
72typedef float sumtype_t;//sum data type
73typedef float zptype_t; //zeropoint data type
74#else
75#error "Not support this MDL_TYPE!"
76#endif
77
// Parameters of the two experimental 8-bit float formats; defined only when
// TM_MDL_TYPE selects one of them.
// NOTE(review): SCNT/ECNT/MCNT look like the sign/exponent/mantissa bit counts
// (they sum to 8 in both variants and match the _143/_152 suffixes) and BIAS
// like the exponent bias - confirm against tm_fp8to32()/tm_fp32to8().
78#if TM_MDL_TYPE == TM_MDL_FP8_143
79#define TM_FP8_SCNT (1)
80#define TM_FP8_ECNT (4)
81#define TM_FP8_MCNT (3)
82#define TM_FP8_BIAS (9)
83#elif TM_MDL_TYPE == TM_MDL_FP8_152
84#define TM_FP8_SCNT (1)
85#define TM_FP8_ECNT (5)
86#define TM_FP8_MCNT (2)
87#define TM_FP8_BIAS (15)
88#endif
89
90typedef float sctype_t;
91#define TM_FASTSCALE_SHIFT (8)
92
93
106
117
118typedef enum {
122
132
133
134typedef enum {
136 TMPP_FP2INT = 1, //user own fp buf -> int input buf
137 TMPP_UINT2INT = 2, //int8: cvt in place; int16: can't cvt in place
138 TMPP_UINT2FP01 = 3, // u8/255.0
139 TMPP_UINT2FPN11 = 4,// (u8-128)/128
140 TMPP_UINT2DTYPE = 5,//uint8 to fp16,fp8
142} tm_pp_t;
143
144
145//mdlbin in flash
146typedef struct {
147 uint32_t magic; //"MAIX"
148 uint8_t mdl_type; //0 int8, 1 int16, 2 fp32,
149 uint8_t out_deq; //0 don't dequant out; 1 dequant out
150 uint16_t input_cnt; //only support 1 yet
151 uint16_t output_cnt;//only support 1 yet
153 uint32_t buf_size; //main buf size for middle result = pingpong+keep
154 uint32_t sub_size; //pingpong buf size;
155 uint16_t in_dims[4];//0:dims; 1:dim0; 2:dim1; 3:dim2
156 uint16_t out_dims[4];
157 uint8_t reserve[28]; //reserve for future
158 uint8_t layers_body[0];//oft 64 here
160
161//mdl meta data in ram
162typedef struct {
164 void *cb; //Layer callback
165 uint8_t *buf; //main buf addr
166 uint8_t *subbuf; //sub buf addr
167 uint16_t main_alloc;//is main buf alloc or static
168 uint16_t layer_i; //current layer index
169 uint8_t *layer_body;//current layer body addr
170} tm_mdl_t;
171
172//dims==3, hwc
173//dims==2, 1wc
174//dims==1, 11c
175typedef struct {
180 union {
182 float *dataf;
183 };
184} tm_mat_t;
185
186
// Common layer header. Every layer-specific struct in this file embeds it as
// its first member `h`.
// The 48-byte figure holds when sctype_t and zptype_t are both 4 bytes wide:
// 2 + 2 + 4 + 4 + 4 + 8 + 8 + 4 * 4 = 48.
// NOTE(review): with TM_MDL_FP16 zptype_t is float16_t, so the size and
// padding may differ - confirm against the model converter.
187typedef struct { //48byte
188 uint16_t type; //layer type (presumably a tm_layer_type_t value - confirm)
189 uint16_t is_out; //is output
190 uint32_t size; //size of this layer record, already rounded to 8-byte alignment
191 uint32_t in_oft; //input offset in main buf (added to mdl->buf by TML_GET_INPUT)
192 uint32_t out_oft; //output offset in main buf (added to mdl->buf by TML_GET_OUTPUT)
193 uint16_t in_dims[4];//[0]: number of dims; [1]: dim0; [2]: dim1; [3]: dim2
194 uint16_t out_dims[4];//same encoding as in_dims
195 //the following fields are not used in fp32 mode
196 sctype_t in_s; //input scale
197 zptype_t in_zp; //input zeropoint
198 sctype_t out_s; //output scale
199 zptype_t out_zp;//output zeropoint
200 //note: real = scale*(q-zeropoint)
201} tml_head_t;
202
203typedef struct {
205
210
213 uint16_t act;//0 none, 1 relu, 2 relu1, 3 relu6, 4 tanh, 5 sign_bit
214
215 uint8_t pad[4];//top,bottom,left,right
216
217 uint32_t depth_mul;//depth_multiplier: if conv2d,=0; else: >=1
218 uint32_t reserve; //for 8byte align
219
220 uint32_t ws_oft;//weight scale oft from this layer start
221 //skip bias scale: bias_scale = weight_scale*in_scale
222 uint32_t w_oft; //weight oft from this layer start
223 uint32_t b_oft; //bias oft from this layer start
224 //note: bias[c] = bias[c] + (-out_zp)*sum(w[c*chi*maxk:(c+1)*chi*maxk])
225 // fused in advance (when convert model)
226} tml_conv2d_dw_t; //compatible with conv2d and dwconv2d
227
228typedef struct {
230} tml_gap_t;
231
232typedef struct {
234
235 uint32_t ws_oft; //weight scale oft from this layer start
236 uint32_t w_oft; //weight oft from this layer start
237 uint32_t b_oft; //bias oft from this layer start
238 uint32_t reserve;//for 8byte align
239} tml_fc_t;
240
241typedef struct {
244
245typedef struct {
248
249typedef struct {
251
256
259 uint16_t act;//0 none, 1 relu, 2 relu1, 3 relu6, 4 tanh, 5 sign_bit
260
261 uint8_t pad[4];//top,bottom,left,right
262
263
264 uint32_t ws_oft;//weight scale oft from this layer start
265 //skip bias scale: bias_scale = weight_scale*in_scale
266 uint32_t w_oft; //weight oft from this layer start
267 uint32_t b_oft; //bias oft from this layer start
268 //note: bias[c] = bias[c] + (-out_zp)*sum(w[c*chi*maxk:(c+1)*chi*maxk])
269 // fused in advance (when convert model)
271
272typedef struct {
275 sctype_t in_s1; //input scale,
276 zptype_t in_zp1; //input zeropoint
278} tml_add_t;
279
280
281
282typedef tm_err_t (*tml_stat_t)(tml_head_t *layer, tm_mat_t *in, tm_mat_t *out);
283typedef tm_err_t (*tm_cb_t)(tm_mdl_t *mdl, tml_head_t *lh);
284
285
286
289
// Model lifecycle API. Each function except tm_unload reports its outcome as
// a tm_err_t.
// NOTE(review): parameter roles below are inferred from names and types only -
// confirm against the definitions in tm_model.c.
//   mdl     - model state kept in RAM (tm_mdl_t)
//   bin     - presumably the model binary image (tm_mdlbin_t layout)
//   buf     - presumably a caller-supplied main buffer; tm_mdl_t has a
//             `main_alloc` flag, so the buffer may also be allocated internally
//   cb      - per-layer callback of type tm_cb_t
//   pp_type - preprocessing mode, one of the tm_pp_t values
//   in/out  - input and output matrices
290tm_err_t tm_load(tm_mdl_t *mdl, const uint8_t *bin, uint8_t *buf, tm_cb_t cb, tm_mat_t *in);//load model
291void tm_unload(tm_mdl_t *mdl); //remove model
292tm_err_t tm_preprocess(tm_mdl_t *mdl, tm_pp_t pp_type, tm_mat_t *in, tm_mat_t *out); //preprocess input data
293tm_err_t tm_run(tm_mdl_t *mdl, tm_mat_t *in, tm_mat_t *out); //run model
294
295
296
// Layer operators. Each takes input and output matrices plus the quantisation
// parameters (scale `*_s`, zero point `*_zp`) and returns a tm_err_t.
// NOTE(review): the short parameter names are not explained in this header;
// kw/kh, sx/sy and dx/dy presumably correspond to the kernel_w/kernel_h,
// stride_w/stride_h and dilation_w/dilation_h fields of the conv layer
// struct, dmul to depth_mul, and ws to the per-layer weight scales - confirm
// against tm_layers.c.
297tm_err_t tml_conv2d_dwconv2d(tm_mat_t *in, tm_mat_t *out, wtype_t *w, btype_t *b, int kw, int kh, int sx, int sy, int dx, int dy, int act, int pad_top, int pad_bottom,
298 int pad_left, int pad_right, int dmul, sctype_t *ws, sctype_t in_s, zptype_t in_zp, sctype_t out_s, zptype_t out_zp);
299tm_err_t tml_gap(tm_mat_t *in, tm_mat_t *out, sctype_t in_s, zptype_t in_zp, sctype_t out_s, zptype_t out_zp);
300tm_err_t tml_fc(tm_mat_t *in, tm_mat_t *out, wtype_t *w, btype_t *b, sctype_t *ws, sctype_t in_s, zptype_t in_zp, sctype_t out_s, zptype_t out_zp);
301tm_err_t tml_softmax(tm_mat_t *in, tm_mat_t *out, sctype_t in_s, zptype_t in_zp, sctype_t out_s, zptype_t out_zp);
302tm_err_t tml_reshape(tm_mat_t *in, tm_mat_t *out, sctype_t in_s, zptype_t in_zp, sctype_t out_s, zptype_t out_zp);
303tm_err_t tml_add(tm_mat_t *in0, tm_mat_t *in1, tm_mat_t *out, sctype_t in_s0, zptype_t in_zp0, sctype_t in_s1, zptype_t in_zp1, sctype_t out_s, zptype_t out_zp);
304
305
306#if TM_ENABLE_STAT
307tm_err_t tm_stat(tm_mdlbin_t *mdl);//stat model
308#endif
309
310
313
314
315
// Typed pointers to a layer's input / output region: the layer header's byte
// offset is added to the model's main buffer and the result is viewed as
// mtype_t elements.
#define TML_GET_INPUT(mdl, lh) ((mtype_t *) ((mdl)->buf + (lh)->in_oft))
#define TML_GET_OUTPUT(mdl, lh) ((mtype_t *) ((mdl)->buf + (lh)->out_oft))
#if (TM_MDL_TYPE == TM_MDL_INT8) || (TM_MDL_TYPE == TM_MDL_INT16)
// Integer models: real = scale * (q - zeropoint).
// TML_DEQUANT uses the layer header's output scale and zero point.
#define TML_DEQUANT(lh, x) (((sumtype_t) (x) - ((lh)->out_zp)) * ((lh)->out_s))
// TM_DEQUANT uses an explicit scale and zero point.
#define TM_DEQUANT(i8, s, zp) (((sumtype_t) (i8) - (zp)) * (s))
// Inverse mapping: q = real / scale + zeropoint, converted to mtype_t by a
// plain cast (no rounding or saturation is applied here).
// `zp` is parenthesised so that an argument containing a lower-precedence
// operator (for example a conditional expression) is added as a unit.
#define TM_QUANT(fp32, s, zp) ((mtype_t) ((fp32) / (s) + (zp)))
#elif (TM_MDL_TYPE == TM_MDL_FP8_143) || (TM_MDL_TYPE == TM_MDL_FP8_152)
// FP8 models: decode through tm_fp8to32(); TM_DEQUANT and TM_QUANT are not
// provided for this model type.
#define TML_DEQUANT(lh, x) (tm_fp8to32(x))
#else//FP32,FP16
// Floating-point models: values are already real-valued, so the scale and
// zero point arguments are ignored.
#define TML_DEQUANT(lh, x) ((float) (x))
#define TM_DEQUANT(x, s, zp) (x)
#define TM_QUANT(x, s, zp) (x)
#endif
330
331
332#if TM_LOCAL_MATH
333//http://www.machinedlearnings.com/2011/06/fast-approximate-logarithm-exponential.html
// Fast approximate e^x for single precision.
//
// Computes 2^p with p = x * log2(e), using the rational approximation from
// the article linked above: the exponent/mantissa bit pattern of the result
// is built directly as an integer and reinterpreted as a float through a union.
//
// x: exponent argument.
// Returns an approximation of exp(x). With the corrections below, the result
// agrees with exp() to a relative error of roughly 1e-4 for both negative
// and positive arguments.
//
// Fixes relative to the previous version:
//  - the negative-argument offset was derived from a variable that was always
//    zero, so the fractional part `z` fell in (-1, 0] instead of [0, 1) for
//    x < 0 and the result could be off by about 20%;
//  - p is clamped to -126 so the value converted to uint32_t is never
//    negative (that conversion is undefined behaviour in C); very negative
//    inputs now return approximately 2^-126.
// Large positive inputs are not range-checked, as before.
static inline float _exp(float x) {
    float p = 1.442695040f * x; // x * log2(e)
    if (p < -126.0f) {
        p = -126.0f;
    }
    // (int) truncates toward zero, so add 1 for negative p to bring the
    // fractional part back into the range the approximation expects.
    float offset = (p < 0.0f) ? 1.0f : 0.0f;
    int w = (int) p;
    float z = p - (float) w + offset;
    union {
        uint32_t i;
        float f;
    } v = {.i = (uint32_t) ((1 << 23) * (p + 121.2740838f + 27.7280233f / (4.84252568f - z) - 1.49012907f * z))};
    return v.f;
}
346#define tm_exp _exp//maybe some arch have exp acceleration, use macro in arch_xxx.h to reload it
347#else
348#define tm_exp exp
349#endif
350
351#endif
u32_t uint32_t
Definition stdint.h:13
s16_t int16_t
Definition stdint.h:9
s8_t int8_t
Definition stdint.h:6
s32_t int32_t
Definition stdint.h:12
u8_t uint8_t
Definition stdint.h:7
u16_t uint16_t
Definition stdint.h:10
Definition tinymaix.h:175
uint16_t dims
Definition tinymaix.h:176
uint16_t h
Definition tinymaix.h:177
mtype_t * data
Definition tinymaix.h:181
float * dataf
Definition tinymaix.h:182
uint16_t c
Definition tinymaix.h:179
uint16_t w
Definition tinymaix.h:178
Definition tinymaix.h:162
void * cb
Definition tinymaix.h:164
uint8_t * layer_body
Definition tinymaix.h:169
uint16_t main_alloc
Definition tinymaix.h:167
uint16_t layer_i
Definition tinymaix.h:168
uint8_t * buf
Definition tinymaix.h:165
uint8_t * subbuf
Definition tinymaix.h:166
tm_mdlbin_t * b
Definition tinymaix.h:163
Definition tinymaix.h:146
uint16_t input_cnt
Definition tinymaix.h:150
uint32_t magic
Definition tinymaix.h:147
uint16_t layer_cnt
Definition tinymaix.h:152
uint32_t buf_size
Definition tinymaix.h:153
uint16_t output_cnt
Definition tinymaix.h:151
uint32_t sub_size
Definition tinymaix.h:154
uint8_t out_deq
Definition tinymaix.h:149
uint8_t mdl_type
Definition tinymaix.h:148
Definition tinymaix.h:272
tml_head_t h
Definition tinymaix.h:273
sctype_t in_s1
Definition tinymaix.h:275
zptype_t in_zp1
Definition tinymaix.h:276
uint32_t reserve
Definition tinymaix.h:277
uint32_t in_oft1
Definition tinymaix.h:274
Definition tinymaix.h:203
tml_head_t h
Definition tinymaix.h:204
uint8_t kernel_w
Definition tinymaix.h:206
uint16_t act
Definition tinymaix.h:213
uint8_t kernel_h
Definition tinymaix.h:207
uint32_t b_oft
Definition tinymaix.h:223
uint32_t ws_oft
Definition tinymaix.h:220
uint8_t stride_w
Definition tinymaix.h:208
uint32_t w_oft
Definition tinymaix.h:222
uint32_t depth_mul
Definition tinymaix.h:217
uint8_t dilation_h
Definition tinymaix.h:212
uint8_t stride_h
Definition tinymaix.h:209
uint32_t reserve
Definition tinymaix.h:218
uint8_t dilation_w
Definition tinymaix.h:211
Definition tinymaix.h:249
uint8_t stride_h
Definition tinymaix.h:255
uint16_t act
Definition tinymaix.h:259
uint8_t dilation_w
Definition tinymaix.h:257
uint32_t b_oft
Definition tinymaix.h:267
uint8_t kernel_h
Definition tinymaix.h:253
uint8_t stride_w
Definition tinymaix.h:254
uint8_t dilation_h
Definition tinymaix.h:258
uint32_t ws_oft
Definition tinymaix.h:264
tml_head_t h
Definition tinymaix.h:250
uint32_t w_oft
Definition tinymaix.h:266
uint8_t kernel_w
Definition tinymaix.h:252
Definition tinymaix.h:232
uint32_t w_oft
Definition tinymaix.h:236
uint32_t b_oft
Definition tinymaix.h:237
tml_head_t h
Definition tinymaix.h:233
uint32_t ws_oft
Definition tinymaix.h:235
uint32_t reserve
Definition tinymaix.h:238
Definition tinymaix.h:228
tml_head_t h
Definition tinymaix.h:229
Definition tinymaix.h:187
uint32_t in_oft
Definition tinymaix.h:191
zptype_t out_zp
Definition tinymaix.h:199
sctype_t out_s
Definition tinymaix.h:198
uint16_t type
Definition tinymaix.h:188
uint16_t is_out
Definition tinymaix.h:189
sctype_t in_s
Definition tinymaix.h:196
zptype_t in_zp
Definition tinymaix.h:197
uint32_t out_oft
Definition tinymaix.h:192
uint32_t size
Definition tinymaix.h:190
Definition tinymaix.h:245
tml_head_t h
Definition tinymaix.h:246
Definition tinymaix.h:241
tml_head_t h
Definition tinymaix.h:242
tm_err_t tm_load(tm_mdl_t *mdl, const uint8_t *bin, uint8_t *buf, tm_cb_t cb, tm_mat_t *in)
Definition tm_model.c:17
tm_err_t tm_stat(tm_mdlbin_t *mdl)
Definition tm_stat.c:45
tm_err_t(* tml_stat_t)(tml_head_t *layer, tm_mat_t *in, tm_mat_t *out)
Definition tinymaix.h:282
tm_act_type_t
Definition tinymaix.h:123
@ TM_ACT_NONE
Definition tinymaix.h:124
@ TM_ACT_MAXCNT
Definition tinymaix.h:130
@ TM_ACT_RELU
Definition tinymaix.h:125
@ TM_ACT_RELU1
Definition tinymaix.h:126
@ TM_ACT_TANH
Definition tinymaix.h:128
@ TM_ACT_SIGNBIT
Definition tinymaix.h:129
@ TM_ACT_RELU6
Definition tinymaix.h:127
tm_err_t tml_gap(tm_mat_t *in, tm_mat_t *out, sctype_t in_s, zptype_t in_zp, sctype_t out_s, zptype_t out_zp)
Definition tm_layers.c:233
tm_err_t tml_reshape(tm_mat_t *in, tm_mat_t *out, sctype_t in_s, zptype_t in_zp, sctype_t out_s, zptype_t out_zp)
Definition tm_layers.c:305
uint8_t TM_WEAK tm_fp32to8(float fp32)
int32_t btype_t
Definition tinymaix.h:38
void tm_unload(tm_mdl_t *mdl)
Definition tm_model.c:48
float sctype_t
Definition tinymaix.h:90
tm_err_t
Definition tinymaix.h:94
@ TM_ERR_KSIZE
Definition tinymaix.h:104
@ TM_ERR_TODO
Definition tinymaix.h:102
@ TM_OK
Definition tinymaix.h:95
@ TM_ERR
Definition tinymaix.h:96
@ TM_ERR_MDLTYPE
Definition tinymaix.h:103
@ TM_ERR_LAYERTYPE
Definition tinymaix.h:100
@ TM_ERR_DIMS
Definition tinymaix.h:101
@ TM_ERR_UNSUPPORT
Definition tinymaix.h:98
@ TM_ERR_OOM
Definition tinymaix.h:99
@ TM_ERR_MAGIC
Definition tinymaix.h:97
tm_err_t tm_run(tm_mdl_t *mdl, tm_mat_t *in, tm_mat_t *out)
Definition tm_model.c:86
tm_err_t tml_fc(tm_mat_t *in, tm_mat_t *out, wtype_t *w, btype_t *b, sctype_t *ws, sctype_t in_s, zptype_t in_zp, sctype_t out_s, zptype_t out_zp)
Definition tm_layers.c:256
int8_t wtype_t
Definition tinymaix.h:37
tm_err_t tm_preprocess(tm_mdl_t *mdl, tm_pp_t pp_type, tm_mat_t *in, tm_mat_t *out)
Definition tm_model.c:55
int32_t zptype_t
Definition tinymaix.h:40
tm_err_t tml_add(tm_mat_t *in0, tm_mat_t *in1, tm_mat_t *out, sctype_t in_s0, zptype_t in_zp0, sctype_t in_s1, zptype_t in_zp1, sctype_t out_s, zptype_t out_zp)
Definition tm_layers.c:311
int8_t mtype_t
Definition tinymaix.h:36
int32_t sumtype_t
Definition tinymaix.h:39
tm_layer_type_t
Definition tinymaix.h:107
@ TML_CONV2D
Definition tinymaix.h:108
@ TML_DWCONV2D
Definition tinymaix.h:113
@ TML_FC
Definition tinymaix.h:110
@ TML_SOFTMAX
Definition tinymaix.h:111
@ TML_ADD
Definition tinymaix.h:114
@ TML_GAP
Definition tinymaix.h:109
@ TML_MAXCNT
Definition tinymaix.h:115
@ TML_RESHAPE
Definition tinymaix.h:112
float TM_WEAK tm_fp8to32(uint8_t fp8)
tm_err_t tml_conv2d_dwconv2d(tm_mat_t *in, tm_mat_t *out, wtype_t *w, btype_t *b, int kw, int kh, int sx, int sy, int dx, int dy, int act, int pad_top, int pad_bottom, int pad_left, int pad_right, int dmul, sctype_t *ws, sctype_t in_s, zptype_t in_zp, sctype_t out_s, zptype_t out_zp)
Definition tm_layers.c:68
tm_err_t tml_softmax(tm_mat_t *in, tm_mat_t *out, sctype_t in_s, zptype_t in_zp, sctype_t out_s, zptype_t out_zp)
Definition tm_layers.c:273
tm_err_t(* tm_cb_t)(tm_mdl_t *mdl, tml_head_t *lh)
Definition tinymaix.h:283
tm_pad_type_t
Definition tinymaix.h:118
@ TM_PAD_VALID
Definition tinymaix.h:119
@ TM_PAD_SAME
Definition tinymaix.h:120
tm_pp_t
Definition tinymaix.h:134
@ TMPP_UINT2FP01
Definition tinymaix.h:138
@ TMPP_UINT2INT
Definition tinymaix.h:137
@ TMPP_FP2INT
Definition tinymaix.h:136
@ TMPP_NONE
Definition tinymaix.h:135
@ TMPP_UINT2DTYPE
Definition tinymaix.h:140
@ TMPP_MAXCNT
Definition tinymaix.h:141
@ TMPP_UINT2FPN11
Definition tinymaix.h:139
static float _exp(float x)
Definition tinymaix.h:334
#define TM_WEAK
Definition tm_port.h:44