// Copyright 2015 The Gemmlowp Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// pack_avx.h: optimized AVX specializations of the templates in pack.h.

#ifndef GEMMLOWP_INTERNAL_PACK_AVX_H_
#define GEMMLOWP_INTERNAL_PACK_AVX_H_

#include <immintrin.h>
#include "pack.h"

namespace gemmlowp {

// TODO: Add DepthMajorUint8SideMap

typedef SideMap<const std::uint8_t, SideMapOrder::WidthMajor>
    WidthMajorUint8SideMap;

template <int Cells>
using WidthMajorSideFormatNCells4x2 =
    KernelSideFormat<CellFormat<8, 2, CellOrder::WidthMajor>, Cells>;

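// Packing specialization for the 8x2 WidthMajor cell format: transposes a
// WidthMajor uint8 source block into packed cells using 256-bit integer
// intrinsics, 8 rows of 16 depth entries at a time, and updates
// dst->sums_of_each_slice() along the way.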
template <int Cells>
class PackingRegisterBlock<
    WidthMajorUint8SideMap,
    PackedSideBlock<WidthMajorSideFormatNCells4x2<Cells>>>
    : public PackingRegisterBlockBase<
          WidthMajorUint8SideMap,
          PackedSideBlock<WidthMajorSideFormatNCells4x2<Cells>>> {
 public:
  typedef WidthMajorSideFormatNCells4x2<Cells> KernelSideFormat;
  typedef typename KernelSideFormat::Cell CellFormat;
  static const int kCells = KernelSideFormat::kCells;
  static const int kCellWidth = CellFormat::kWidth;
  static const int kKernelWidth = CellFormat::kWidth * kCells;
  static const int kCellDepth = CellFormat::kDepth;
  static const int kCellSize = CellFormat::kSize;

  void Pack(PackedSideBlock<KernelSideFormat> *dst, int start_width) {
    std::uint8_t *dst_ptr = dst->current_data();
    const int width_stride = this->complete_src_.width_stride();
    int depth_step = 16;

    __m256i one = _mm256_set1_epi16(1);
    for (int cell_start_depth = 0; cell_start_depth < kRegisterSize;
         cell_start_depth += depth_step) {
      for (int cell_start_width = 0; cell_start_width < kKernelWidth;
           cell_start_width += kCellWidth) {
        std::int32_t *cell_sums_of_each_slice_ptr =
            dst->sums_of_each_slice() + start_width + cell_start_width;
        const std::uint8_t *src_data =
            this->complete_src_.data(cell_start_width, cell_start_depth);

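        // Load 16 consecutive depth entries (bytes) from each of the 8 rows
        // of this cell; the rows are width_stride apart in the source.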
        __m128i xmm1 =
            _mm_loadu_si128(reinterpret_cast<const __m128i *>(&src_data[0]));
        __m128i xmm2 = _mm_loadu_si128(
            reinterpret_cast<const __m128i *>(&src_data[1 * width_stride]));
        __m128i xmm3 = _mm_loadu_si128(
            reinterpret_cast<const __m128i *>(&src_data[2 * width_stride]));
        __m128i xmm4 = _mm_loadu_si128(
            reinterpret_cast<const __m128i *>(&src_data[3 * width_stride]));
        __m128i xmm5 = _mm_loadu_si128(
            reinterpret_cast<const __m128i *>(&src_data[4 * width_stride]));
        __m128i xmm6 = _mm_loadu_si128(
            reinterpret_cast<const __m128i *>(&src_data[5 * width_stride]));
        __m128i xmm7 = _mm_loadu_si128(
            reinterpret_cast<const __m128i *>(&src_data[6 * width_stride]));
        __m128i xmm8 = _mm_loadu_si128(
            reinterpret_cast<const __m128i *>(&src_data[7 * width_stride]));

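        // Combine rows i and i + 4 into 256-bit registers: the low 128-bit
        // lane holds row i, the high lane row i + 4.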
        __m256i ymm1 = _mm256_set_m128i(xmm5, xmm1);
        __m256i ymm2 = _mm256_set_m128i(xmm6, xmm2);
        __m256i ymm3 = _mm256_set_m128i(xmm7, xmm3);
        __m256i ymm4 = _mm256_set_m128i(xmm8, xmm4);

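        // Transpose within each lane: interleave 16-bit (depth-2) groups
        // across the row pairs, then 32-bit groups, and finally permute the
        // 64-bit quarters so that each 128-bit half holds all 8 rows for one
        // depth-2 slice.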
        __m256i ymm5 = _mm256_unpacklo_epi16(ymm1, ymm2);
        __m256i ymm6 = _mm256_unpacklo_epi16(ymm3, ymm4);

        __m256i ymm9 = _mm256_unpackhi_epi16(ymm1, ymm2);
        __m256i ymm10 = _mm256_unpackhi_epi16(ymm3, ymm4);

        __m256i ymm7 = _mm256_unpacklo_epi32(ymm5, ymm6);
        __m256i ymm8 = _mm256_unpackhi_epi32(ymm5, ymm6);

        __m256i ymm13 = _mm256_unpacklo_epi32(ymm9, ymm10);
        __m256i ymm14 = _mm256_unpackhi_epi32(ymm9, ymm10);

        __m256i ymm11 = _mm256_permute4x64_epi64(ymm7, 0xd8);
        __m256i ymm12 = _mm256_permute4x64_epi64(ymm8, 0xd8);

        __m256i ymm15 = _mm256_permute4x64_epi64(ymm13, 0xd8);
        __m256i ymm16 = _mm256_permute4x64_epi64(ymm14, 0xd8);

        __m128i xmm9 = _mm256_castsi256_si128(ymm11);
        __m128i xmm10 = _mm256_castsi256_si128(ymm12);
        __m128i xmm11 = _mm256_extracti128_si256(ymm11, 1);
        __m128i xmm12 = _mm256_extracti128_si256(ymm12, 1);

        xmm1 = _mm256_castsi256_si128(ymm15);
        xmm2 = _mm256_castsi256_si128(ymm16);
        xmm3 = _mm256_extracti128_si256(ymm15, 1);
        xmm4 = _mm256_extracti128_si256(ymm16, 1);

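        // Store the eight packed 16-byte cells (8 width x 2 depth each);
        // consecutive depth-2 slices are kCellSize * kCells bytes apart in
        // the packed block.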
        _mm_storeu_si128(reinterpret_cast<__m128i *>(&dst_ptr[0]), xmm9);
        _mm_storeu_si128(
            reinterpret_cast<__m128i *>(&dst_ptr[kCellSize * kCells]), xmm11);
        _mm_storeu_si128(
            reinterpret_cast<__m128i *>(&dst_ptr[2 * kCellSize * kCells]),
            xmm10);
        _mm_storeu_si128(
            reinterpret_cast<__m128i *>(&dst_ptr[3 * kCellSize * kCells]),
            xmm12);
        _mm_storeu_si128(
            reinterpret_cast<__m128i *>(&dst_ptr[4 * kCellSize * kCells]),
            xmm1);
        _mm_storeu_si128(
            reinterpret_cast<__m128i *>(&dst_ptr[5 * kCellSize * kCells]),
            xmm3);

        _mm_storeu_si128(
            reinterpret_cast<__m128i *>(&dst_ptr[6 * kCellSize * kCells]),
            xmm2);
        _mm_storeu_si128(
            reinterpret_cast<__m128i *>(&dst_ptr[7 * kCellSize * kCells]),
            xmm4);

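        // Accumulate sums_of_each_slice: widen each packed cell to 16 bits
        // and multiply-add against 1 so the two depth entries of every width
        // position fold into one 32-bit partial sum.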
        ymm6 = _mm256_cvtepu8_epi16(xmm9);
        ymm7 = _mm256_madd_epi16(ymm6, one);
        __m256i sums_of_each_slice_xmm = _mm256_loadu_si256(
            reinterpret_cast<const __m256i *>(&cell_sums_of_each_slice_ptr[0]));
        sums_of_each_slice_xmm = _mm256_add_epi32(sums_of_each_slice_xmm, ymm7);

        ymm6 = _mm256_cvtepu8_epi16(xmm11);
        ymm7 = _mm256_madd_epi16(ymm6, one);
        sums_of_each_slice_xmm = _mm256_add_epi32(sums_of_each_slice_xmm, ymm7);

        ymm6 = _mm256_cvtepu8_epi16(xmm10);
        ymm7 = _mm256_madd_epi16(ymm6, one);
        sums_of_each_slice_xmm = _mm256_add_epi32(sums_of_each_slice_xmm, ymm7);

        ymm6 = _mm256_cvtepu8_epi16(xmm12);
        ymm7 = _mm256_madd_epi16(ymm6, one);
        sums_of_each_slice_xmm = _mm256_add_epi32(sums_of_each_slice_xmm, ymm7);

        ymm6 = _mm256_cvtepu8_epi16(xmm1);
        ymm7 = _mm256_madd_epi16(ymm6, one);
        sums_of_each_slice_xmm = _mm256_add_epi32(sums_of_each_slice_xmm, ymm7);

        ymm6 = _mm256_cvtepu8_epi16(xmm3);
        ymm7 = _mm256_madd_epi16(ymm6, one);
        sums_of_each_slice_xmm = _mm256_add_epi32(sums_of_each_slice_xmm, ymm7);

        ymm6 = _mm256_cvtepu8_epi16(xmm2);
        ymm7 = _mm256_madd_epi16(ymm6, one);
        sums_of_each_slice_xmm = _mm256_add_epi32(sums_of_each_slice_xmm, ymm7);

        ymm6 = _mm256_cvtepu8_epi16(xmm4);
        ymm7 = _mm256_madd_epi16(ymm6, one);
        sums_of_each_slice_xmm = _mm256_add_epi32(sums_of_each_slice_xmm, ymm7);

        _mm256_storeu_si256(
            reinterpret_cast<__m256i *>(&cell_sums_of_each_slice_ptr[0]),
            sums_of_each_slice_xmm);
        dst_ptr += kCellSize;
      }
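      // The stores above covered 8 depth-2 slices; the inner loop only
      // advanced dst_ptr by one slice, so skip over the remaining 7.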
      dst_ptr += 7 * kCellSize * kCells;
    }
    dst->seek_forward_n_cells(kCells * kRegisterSize / kCellDepth);
  }
};

// Pack format for the 4x2 RHS kernel format.
template <int Cells>
using RhsWidthMajorSideFormatNCells4x2 =
    KernelSideFormat<CellFormat<4, 2, CellOrder::WidthMajor>, Cells>;

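// Packing specialization for the 4x2 WidthMajor RHS cell format: the same
// approach as above, but with 128-bit registers, packing 4 rows of 8 depth
// entries per iteration.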
template <int Cells>
class PackingRegisterBlock<
    WidthMajorUint8SideMap,
    PackedSideBlock<RhsWidthMajorSideFormatNCells4x2<Cells>>>
    : public PackingRegisterBlockBase<
          WidthMajorUint8SideMap,
          PackedSideBlock<RhsWidthMajorSideFormatNCells4x2<Cells>>> {
 public:
  typedef RhsWidthMajorSideFormatNCells4x2<Cells> KernelSideFormat;
  typedef typename KernelSideFormat::Cell CellFormat;
  static const int kCells = KernelSideFormat::kCells;
  static const int kCellWidth = CellFormat::kWidth;
  static const int kKernelWidth = CellFormat::kWidth * kCells;
  static const int kCellDepth = CellFormat::kDepth;
  static const int kCellSize = CellFormat::kSize;

  void Pack(PackedSideBlock<KernelSideFormat> *dst, int start_width) {
    std::uint8_t *dst_ptr = dst->current_data();
    const int width_stride = this->complete_src_.width_stride();
    int depth_step = 8;

    __m128i one = _mm_set1_epi16(1);
    for (int cell_start_depth = 0; cell_start_depth < kRegisterSize;
         cell_start_depth += depth_step) {
      for (int cell_start_width = 0; cell_start_width < kKernelWidth;
           cell_start_width += kCellWidth) {
        std::int32_t *cell_sums_of_each_slice_ptr =
            dst->sums_of_each_slice() + start_width + cell_start_width;
        const std::uint8_t *src_data =
            this->complete_src_.data(cell_start_width, cell_start_depth);

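        // Load 8 consecutive depth entries (bytes) from each of the 4 rows
        // of this cell.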
        __m128i xmm1 =
            _mm_loadl_epi64(reinterpret_cast<const __m128i *>(&src_data[0]));
        __m128i xmm2 = _mm_loadl_epi64(
            reinterpret_cast<const __m128i *>(&src_data[1 * width_stride]));
        __m128i xmm3 = _mm_loadl_epi64(
            reinterpret_cast<const __m128i *>(&src_data[2 * width_stride]));
        __m128i xmm4 = _mm_loadl_epi64(
            reinterpret_cast<const __m128i *>(&src_data[3 * width_stride]));

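        // Interleave depth-2 pairs from rows 0/1 and 2/3, then shuffle and
        // blend so that each 64-bit half of xmm9 and xmm10 holds one 4-wide,
        // depth-2 slice of the cell.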
        __m128i xmm5 = _mm_unpacklo_epi16(xmm1, xmm2);
        __m128i xmm8 = _mm_shuffle_epi32(xmm5, 0x31);

        __m128i xmm6 = _mm_unpacklo_epi16(xmm3, xmm4);
        __m128i xmm7 = _mm_shuffle_epi32(xmm6, 0x80);

        __m128i xmm9 = _mm_blend_epi16(xmm5, xmm7, 0xcc);
        __m128i xmm10 = _mm_blend_epi16(xmm8, xmm6, 0xcc);

        _mm_storel_epi64(reinterpret_cast<__m128i *>(&dst_ptr[0]), xmm9);
        _mm_storel_epi64(
            reinterpret_cast<__m128i *>(&dst_ptr[kCellSize * kCells]), xmm10);

        __m128i xmm11 = _mm_shuffle_epi32(xmm9, 0xee);
        __m128i xmm12 = _mm_shuffle_epi32(xmm10, 0xee);

        _mm_storel_epi64(
            reinterpret_cast<__m128i *>(&dst_ptr[2 * kCellSize * kCells]),
            xmm11);
        _mm_storel_epi64(
            reinterpret_cast<__m128i *>(&dst_ptr[3 * kCellSize * kCells]),
            xmm12);

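        // Accumulate sums_of_each_slice: widen each stored 8-byte slice to
        // 16 bits and multiply-add against 1 so the two depth entries of
        // every width position fold into one 32-bit partial sum.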
        xmm1 = _mm_cvtepu8_epi16(xmm9);
        xmm2 = _mm_madd_epi16(xmm1, one);
        __m128i sums_of_each_slice_xmm = _mm_loadu_si128(
            reinterpret_cast<const __m128i *>(&cell_sums_of_each_slice_ptr[0]));
        sums_of_each_slice_xmm = _mm_add_epi32(sums_of_each_slice_xmm, xmm2);

        xmm1 = _mm_cvtepu8_epi16(xmm10);
        xmm2 = _mm_madd_epi16(xmm1, one);
        sums_of_each_slice_xmm = _mm_add_epi32(sums_of_each_slice_xmm, xmm2);

        xmm1 = _mm_cvtepu8_epi16(xmm11);
        xmm2 = _mm_madd_epi16(xmm1, one);
        sums_of_each_slice_xmm = _mm_add_epi32(sums_of_each_slice_xmm, xmm2);

        xmm1 = _mm_cvtepu8_epi16(xmm12);
        xmm2 = _mm_madd_epi16(xmm1, one);
        sums_of_each_slice_xmm = _mm_add_epi32(sums_of_each_slice_xmm, xmm2);

        _mm_storeu_si128(
            reinterpret_cast<__m128i *>(&cell_sums_of_each_slice_ptr[0]),
            sums_of_each_slice_xmm);
        dst_ptr += kCellSize;
      }
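      // The stores above covered 4 depth-2 slices; the inner loop only
      // advanced dst_ptr by one slice, so skip over the remaining 3.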
      dst_ptr += 3 * kCellSize * kCells;
    }
    dst->seek_forward_n_cells(kCells * kRegisterSize / kCellDepth);
  }
};

}  // namespace gemmlowp

#endif  // GEMMLOWP_INTERNAL_PACK_AVX_H_