nmpp
nmblas_sgemm.h
См. документацию.
1
35
37
38#ifdef __NM__
39#ifdef __GNUC__
40#ifndef NMBLAS_SGEMM_H_
41#define NMBLAS_SGEMM_H_
42
43// aux macros
// Builds one asm template line "fpu <n> <instr>\n\t" for a single FPU index.
#define NMBLAS_FPU_LINE( n, instr ) "fpu " #n " " instr "\n\t"

// Repeats the instruction text `instr` (a string literal) once for each of
// the four FPUs 0..3; the result is a single concatenated string literal
// suitable for use inside an asm template.
#define ALL_FPU( instr ) \
    NMBLAS_FPU_LINE( 0, instr ) \
    NMBLAS_FPU_LINE( 1, instr ) \
    NMBLAS_FPU_LINE( 2, instr ) \
    NMBLAS_FPU_LINE( 3, instr )
48
49// aux functions
50
51// load old C for specific fpu
52static inline __attribute__((always_inline)) void
53loadCFromMemory( const float* pc, int ldc, const int fpu, int* dummy_to_link )
54{
55 asm (
56 "fpu %4 rep vlen vreg7= [%1++%2];\n\t"
57 : "+m"(*dummy_to_link)
58 : "RA0"( pc + fpu*2 ), "RG0"(ldc), "m" (*(const float (*)[])pc), "i"(fpu) );
59}
60
61// load A for all fpu-s
62static inline __attribute__((always_inline)) void
63loadAFromMemory( const float* pa, int lda, const int vrNum, int* dummy_to_link )
64{
65 asm (
66 "fpu 0 rep vlen vreg%3 = [%1++%2];\n\t"
67// All fpu-s got the same vector
68 "fpu 1 vreg%3 = fpu 0 vreg%3;\n\t"
69 "fpu 2 vreg%3 = fpu 1 vreg%3;\n\t"
70 "fpu 3 vreg%3 = fpu 2 vreg%3;\n\t"
71 : "+m" (*dummy_to_link), "+RA0" (pa)
72 : "RG0"(lda), "i"(vrNum), "m"(*(const float (*)[])pa) );
73}
74
75static inline __attribute__((always_inline)) void
76loadBAndMAdd( const float* pb, const float* pb1, int ldb, const int vrNum, int* dummy_to_link )
77{
78 asm (
79 "fpu 0 rep 1 vreg4 = [%1++];\n\t"
80 "fpu 0 rep 1 vreg5 = [%2++];\n\t"
81 "fpu 1 rep 1 vreg4 = [%1++];\n\t"
82 "fpu 1 rep 1 vreg5 = [%2++];\n\t"
83 "fpu 2 rep 1 vreg4 = [%1++];\n\t"
84 "fpu 2 rep 1 vreg5 = [%2++];\n\t"
85 "fpu 3 rep 1 vreg4 = [%1++];\n\t"
86 "fpu 3 rep 1 vreg5 = [%2++];\n\t"
87 ALL_FPU (".matrix vreg7= vreg%3 * .retrieve (vreg4,vreg5) + vreg7;")
88 : "+m" (*dummy_to_link), "+a" (pb), "+a" (pb1)
89 : "i"(vrNum), "m"(*(const float (*)[])pb), "m"(*(const float (*)[])pb1) );
90}
91
92// Same as loadBAndMAdd, but without "+C"
93static inline __attribute__((always_inline)) void
94loadBAndMultiply( const float* pb, const float* pb1, int ldb, const int vrNum, int* dummy_to_link )
95{
96 asm (
97 "fpu 0 rep 1 vreg4 = [%1++];\n\t"
98 "fpu 0 rep 1 vreg5 = [%2++];\n\t"
99 "fpu 1 rep 1 vreg4 = [%1++];\n\t"
100 "fpu 1 rep 1 vreg5 = [%2++];\n\t"
101 "fpu 2 rep 1 vreg4 = [%1++];\n\t"
102 "fpu 2 rep 1 vreg5 = [%2++];\n\t"
103 "fpu 3 rep 1 vreg4 = [%1++];\n\t"
104 "fpu 3 rep 1 vreg5 = [%2++];\n\t"
105 ALL_FPU (".matrix vreg7= vreg%3 * .retrieve (vreg4,vreg5);")
106 : "+m" (*dummy_to_link), "+a" (pb), "+a" (pb1)
107 : "i"(vrNum), "m"(*(const float (*)[])pb), "m"(*(const float (*)[])pb1) );
108}
109
110
111
112static inline __attribute__((always_inline)) void
113storeCToMemory( float* pc, int ldc, const int fpu, int* dummy_to_link )
114{
115 asm (
116 "fpu %4 rep vlen [ar0++gr0] = vreg7;\n\t"
117 : "=m"(*(float (*)[])pc)
118 : "RA0"(pc), "RG0"(ldc), "m"(*dummy_to_link), "i"(fpu) );
119}
120
121
122
123static inline void
124nmblas_sgemm( const enum nm_trans TransA,
125 const enum nm_trans TransB,
126 const int M,
127 const int N,
128 const int K,
129 const float alpha,
130 const float *A,
131 const int lda,
132 const float *B,
133 const int ldb,
134 const float _beta,
135 float *C,
136 const int ldc
137 )
138{
139 float beta= _beta;
140 if( TransA != nm_n || TransB != nm_n ){}
141
142 // Нижеследующие сравнения и деления должны быть локализованы строго до векторного кода,
143 // поскольку реализующие их интринсики испортят значения в векторных регистрах!
144 // { all float intrinsic calls must be here!
145 int beta0 = beta ==0.0f;
146 int alpha1 = alpha==1.0f;
147 int beta1;
148 if ( !beta0 ){
149 if ( !alpha1 && alpha !=0.0f )
150 beta /= alpha;
151 beta1 = beta ==1.0f;
152 }
153 // } all float intrinsic calls must be here!
154
155
156 const int I=M;
157 const int J=N;
158 int i, j, k;
159 int* dummy_to_link; // workaround to reflect dependence by vector registers
160
161 for(i=0; i<I; i+=32){
162 asm volatile(
163 "vlen= %0;\n\t"
164 :
165 : "g"( I-i-1 >= 31 ? 31 : I-i-1 ) );
166
167 for(j=0; j<J; j+=8){
168
169 k=0;
170 const float* pa;
171 const float* pb;
172 const float* pb1;
173 float bufScalar[2] __attribute__ ((aligned (8)));
174
175 asm("":"=m"(*dummy_to_link),"=a"(dummy_to_link));
176
177 if ( !beta0 ){
178 // read C[i][j]
179 loadCFromMemory( C + i *ldc +j, ldc, 0, dummy_to_link );
180 loadCFromMemory( C + i *ldc +j, ldc, 1, dummy_to_link );
181 loadCFromMemory( C + i *ldc +j, ldc, 2, dummy_to_link );
182 loadCFromMemory( C + i *ldc +j, ldc, 3, dummy_to_link );
183 if ( !beta1 ){
184 // C[i][j] *= beta
185 bufScalar[0]=beta;
186 bufScalar[1]=beta;
187 float* pbeta= bufScalar;
188 asm (
189 "fpu 0 rep 1 vreg4 = [%1++];\n\t"
190 "fpu 1 vreg4 = fpu 0 vreg4;\n\t"
191 "fpu 2 vreg4 = fpu 1 vreg4;\n\t"
192 "fpu 3 vreg4 = fpu 2 vreg4;\n\t"
193 ALL_FPU (".float vreg7= vreg7 * .retrieve (vreg4);")
194 : "+m" (*dummy_to_link), "+a" (pbeta)
195 : "m"(*bufScalar) );
196 }
197 }
198 else{
199 pa = A + i*lda +k;
200 pb = B + k*ldb +j;
201 pb1 = B +(k+1)*ldb +j;
202
203 asm("":"=m"(*dummy_to_link));
204 loadAFromMemory ( pa, lda, 0, dummy_to_link );
205
206 loadBAndMultiply( pb, pb1, ldb, 0, dummy_to_link );
207 k+=2;
208 }
209
210
211
212 for( ; k<K; k+=2){
213 //C[i][j] += A[i][k]*B[k][j];
214 pa = A + i*lda +k;
215 pb = B + k*ldb +j;
216 pb1 = B +(k+1)*ldb +j;
217
218 loadAFromMemory( pa, lda, 3, dummy_to_link );
219
220 loadBAndMAdd( pb, pb1, ldb, 3, dummy_to_link );
221
222 k+=2;
223 if ( !( k<K ) )
224 break;
225 pa = A + i*lda +k;
226 pb = B + k*ldb +j;
227 pb1 = B +(k+1)*ldb +j;
228
229 loadAFromMemory( pa, lda, 0, dummy_to_link );
230
231 loadBAndMAdd( pb, pb1, ldb, 0, dummy_to_link );
232 }
233
234 if ( !alpha1 ){
235 // C[i][j] *= alpha
236 bufScalar[0]=alpha;
237 bufScalar[1]=alpha;
238 float* palpha= bufScalar;
239 asm (
240 "fpu 0 rep 1 vreg4 = [%1++];\n\t"
241 "fpu 1 vreg4 = fpu 0 vreg4;\n\t"
242 "fpu 2 vreg4 = fpu 1 vreg4;\n\t"
243 "fpu 3 vreg4 = fpu 2 vreg4;\n\t"
244 ALL_FPU (".float vreg7= vreg7 * .retrieve (vreg4);")
245 : "+m" (*dummy_to_link), "+a" (palpha)
246 : "m"(*bufScalar) );
247 }
248
249 // write C[i][j]
250 storeCToMemory( C + i *ldc +j+0, ldc, 0, dummy_to_link );
251 if ( J-j<=2 )
252 break;
253 storeCToMemory( C + i *ldc +j+2, ldc, 1, dummy_to_link );
254 if ( J-j<=4 )
255 break;
256 storeCToMemory( C + i *ldc +j+4, ldc, 2, dummy_to_link );
257 if ( J-j<=6 )
258 break;
259 storeCToMemory( C + i *ldc +j+6, ldc, 3, dummy_to_link );
260 }
261 }
262}
263
264
265
266#endif /* NMBLAS_SGEMM_H_ */
267#endif /* __GNUC__ */
268#endif /* __NM__ */