SyterKit 0.4.0.x
SyterKit is a bare-metal framework
Loading...
Searching...
No Matches
arch_cpu.h
Go to the documentation of this file.
1/* Copyright 2022 Sipeed Technology Co., Ltd. All Rights Reserved.
2Licensed under the Apache License, Version 2.0 (the "License");
3you may not use this file except in compliance with the License.
4You may obtain a copy of the License at
5 http://www.apache.org/licenses/LICENSE-2.0
6Unless required by applicable law or agreed to in writing, software
7distributed under the License is distributed on an "AS IS" BASIS,
8WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9See the License for the specific language governing permissions and
10limitations under the License.
11==============================================================================*/
12
13#include "stdlib.h"
14#include "stdint.h"
15#include "tinymaix.h"
16
17#if (TM_MDL_TYPE != TM_MDL_FP8_143) && (TM_MDL_TYPE != TM_MDL_FP8_152)
18//sum = SUM(Ai*Bi)
19TM_INLINE void tm_dot_prod(mtype_t *sptr, mtype_t *kptr, uint32_t size, sumtype_t *result) {
20 sumtype_t sum = 0;
21 uint32_t i = 0;
22 uint32_t cnt = (size >> 3) << 3;//8
23 for (; i + 8 - 1 < cnt;) {
24 sum += sptr[i] * kptr[i];
25 i++;
26 sum += sptr[i] * kptr[i];
27 i++;
28 sum += sptr[i] * kptr[i];
29 i++;
30 sum += sptr[i] * kptr[i];
31 i++;
32 sum += sptr[i] * kptr[i];
33 i++;
34 sum += sptr[i] * kptr[i];
35 i++;
36 sum += sptr[i] * kptr[i];
37 i++;
38 sum += sptr[i] * kptr[i];
39 i++;
40 }
41 for (; i < size; i++) { sum += sptr[i] * kptr[i]; }
42 *result = sum;
43 return;
44}
45
46TM_INLINE void tm_dot_prod_pack2(mtype_t *sptr, mtype_t *kptr, uint32_t size, sumtype_t *result) {
47 sumtype_t sum0 = 0;
48 sumtype_t sum1 = 0;
49 mtype_t *kptr0 = kptr;
50 mtype_t *kptr1 = kptr + size;
51
52 uint32_t i = 0;
53 uint32_t cnt = (size >> 3) << 3;//8
54 for (; i + 8 - 1 < cnt;) {
55 sum0 += sptr[i] * kptr0[i];
56 sum1 += sptr[i] * kptr1[i];
57 i++;
58 sum0 += sptr[i] * kptr0[i];
59 sum1 += sptr[i] * kptr1[i];
60 i++;
61 sum0 += sptr[i] * kptr0[i];
62 sum1 += sptr[i] * kptr1[i];
63 i++;
64 sum0 += sptr[i] * kptr0[i];
65 sum1 += sptr[i] * kptr1[i];
66 i++;
67 sum0 += sptr[i] * kptr0[i];
68 sum1 += sptr[i] * kptr1[i];
69 i++;
70 sum0 += sptr[i] * kptr0[i];
71 sum1 += sptr[i] * kptr1[i];
72 i++;
73 sum0 += sptr[i] * kptr0[i];
74 sum1 += sptr[i] * kptr1[i];
75 i++;
76 sum0 += sptr[i] * kptr0[i];
77 sum1 += sptr[i] * kptr1[i];
78 i++;
79 }
80 for (; i < size; i++) {
81 sum0 += sptr[i] * kptr0[i];
82 sum1 += sptr[i] * kptr1[i];
83 }
84
85 result[0] = sum0;
86 result[1] = sum1;
87 return;
88}
89
91 *result = sptr[k_oft[0]] * kptr[0] + sptr[k_oft[1]] * kptr[1] + sptr[k_oft[2]] * kptr[2] + sptr[k_oft[3]] * kptr[3] + sptr[k_oft[4]] * kptr[4] + sptr[k_oft[5]] * kptr[5] +
92 sptr[k_oft[6]] * kptr[6] + sptr[k_oft[7]] * kptr[7] + sptr[k_oft[8]] * kptr[8];
93 return;
94}
95
97 *result = sptr[0] * kptr[0] + sptr[1] * kptr[1] + sptr[2] * kptr[2] + sptr[3] * kptr[3] + sptr[4] * kptr[4] + sptr[5] * kptr[5] + sptr[6] * kptr[6] + sptr[7] * kptr[7] +
98 sptr[8] * kptr[8];
99 return;
100}
101
102
103#else
104
105#define SUMSCALE 1.0
106
107TM_INLINE void tm_dot_prod(mtype_t *sptr, mtype_t *kptr, uint32_t size, sumtype_t *result) {
108 sumtype_t sum = 0;
109 for (int i = 0; i < size; i++) {
110 float _s = tm_fp8to32(sptr[i]);
111 float _k = tm_fp8to32(kptr[i]);
112 sum += _s * _k;
113 //printf("%.3f*%.3f+",_s,_k);
114 }
115 //printf("\r\n");
116 *result = sum;
117 return;
118}
119
120TM_INLINE void tm_postprocess_sum(sumtype_t sum, btype_t b, int act, mtype_t *outp, sctype_t scale, sctype_t out_s, zptype_t out_zp) {//printf("sum=%.6f,", sum);
121 sum += tm_fp8to32(b); //printf("%.6f,", sum);
122 switch (act) { //activation func
123 case TM_ACT_RELU:
124 sum = sum > 0 ? sum : 0;
125 break;
126 case TM_ACT_RELU6:
127 sum = sum > 0 ? sum : 0;
128 sum = sum > 6 ? 6 : sum;
129 break;
130 default:
131 break;
132 }
133 //printf("%.6f,", sum);
134 *outp = tm_fp32to8(sum);
135 //printf(" %02x,%.6f\r\n", *outp, tm_fp8to32(*outp));
136 return;
137}
138
139#endif
140
141#if (TM_MDL_TYPE == TM_MDL_FP32) || (TM_MDL_TYPE == TM_MDL_FP16)
142
143TM_INLINE void tm_postprocess_sum(int n, sumtype_t *sums, btype_t *bs, int act, mtype_t *outp, sctype_t *scales, sctype_t out_s, zptype_t out_zp) {
144 for (int i = 0; i < n; i++) {
145 sumtype_t sum = sums[i];
146 sum += bs[i];
147 switch (act) {//activation func
148 case TM_ACT_RELU:
149 case TM_ACT_RELU6://treat relu6 as relu in float mode //speed up
150 sum = sum > 0 ? sum : 0;
151 break;
152 // sum = sum>0?sum:0;
153 // sum = sum>6?6:sum;
154 // break;
155 default:
156 break;
157 }
158 outp[i] = (mtype_t) sum;
159 }
160 return;
161}
162
163#elif (TM_MDL_TYPE == TM_MDL_INT8) || (TM_MDL_TYPE == TM_MDL_INT16)
164
165#if !TM_FASTSCALE
166TM_INLINE void tm_postprocess_sum(int n, sumtype_t *sums, btype_t *bs, int act, mtype_t *outp, sctype_t *scales, sctype_t out_s_inv, zptype_t out_zp)
167#else
168TM_INLINE void tm_postprocess_sum(int n, sumtype_t *sums, btype_t *bs, int act, mtype_t *outp, int32_t *scales, int32_t out_s, zptype_t out_zp)
169#endif
170{
171 for (int i = 0; i < n; i++) {
172 sumtype_t sum = sums[i];
173 sum += bs[i];
174#if !TM_FASTSCALE
175 float sumf = sum * scales[i];
176#else
177 sumtype_t sumf = (sum << TM_FASTSCALE_SHIFT) / scales[i];
178#endif
179 switch (act) {//activation func
180 case TM_ACT_RELU:
181 sumf = sumf > 0 ? sumf : 0;
182 break;
183 case TM_ACT_RELU6:
184 sumf = sumf > 0 ? sumf : 0;
185#if (!TM_FASTSCALE)
186 sumf = sumf > 6 ? 6 : sumf;
187#else
188 sumf = sumf > (6 << TM_FASTSCALE_SHIFT) ? (6 << TM_FASTSCALE_SHIFT) : sumf;
189#endif
190 break;
191 default:
192 break;
193 }
194#if !TM_FASTSCALE
195 outp[i] = (mtype_t) (sumf * out_s_inv + out_zp);//(mtype_t)((int)(sumf/out_s) + out_zp) //(mtype_t)((int)(sumf/out_s +0.5) + out_zp)
196#else
197 outp[i] = (mtype_t) (((sumf * out_s) >> (TM_FASTSCALE_SHIFT + TM_FASTSCALE_SHIFT)) + out_zp);
198#endif
199 }
200 return;
201}
202#endif
TM_INLINE void tm_dot_prod(mtype_t *sptr, mtype_t *kptr, uint32_t size, sumtype_t *result)
Definition arch_cpu.h:19
TM_INLINE void tm_dot_prod_pack2(mtype_t *sptr, mtype_t *kptr, uint32_t size, sumtype_t *result)
Definition arch_cpu.h:46
TM_INLINE void tm_postprocess_sum(int n, sumtype_t *sums, btype_t *bs, int act, mtype_t *outp, sctype_t *scales, sctype_t out_s_inv, zptype_t out_zp)
Definition arch_cpu.h:166
TM_INLINE void tm_dot_prod_3x3x1(mtype_t *sptr, mtype_t *kptr, sumtype_t *result)
Definition arch_cpu.h:96
TM_INLINE void tm_dot_prod_gap_3x3x1(mtype_t *sptr, mtype_t *kptr, uint32_t *k_oft, sumtype_t *result)
Definition arch_cpu.h:90
u32_t uint32_t
Definition stdint.h:13
s32_t int32_t
Definition stdint.h:12
@ TM_ACT_RELU
Definition tinymaix.h:125
@ TM_ACT_RELU6
Definition tinymaix.h:127
uint8_t TM_WEAK tm_fp32to8(float fp32)
int32_t btype_t
Definition tinymaix.h:38
float sctype_t
Definition tinymaix.h:90
#define TM_FASTSCALE_SHIFT
Definition tinymaix.h:91
int32_t zptype_t
Definition tinymaix.h:40
int8_t mtype_t
Definition tinymaix.h:36
int32_t sumtype_t
Definition tinymaix.h:39
float TM_WEAK tm_fp8to32(uint8_t fp8)
static uint32_t k_oft[TM_MAX_KSIZE]
Definition tm_layers.c:49
#define TM_INLINE
Definition tm_port.h:43