Miao Wang | 70ba50c | 2019-08-08 12:30:36 -0700 | [diff] [blame] | 1 | // Copyright 2015 The Gemmlowp Authors. All Rights Reserved. |
| 2 | // |
| 3 | // Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | // you may not use this file except in compliance with the License. |
| 5 | // You may obtain a copy of the License at |
| 6 | // |
| 7 | // http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | // |
| 9 | // Unless required by applicable law or agreed to in writing, software |
| 10 | // distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | // See the License for the specific language governing permissions and |
| 13 | // limitations under the License. |
| 14 | |
// pack_avx.h: optimized AVX2 (and SSE4.1) specializations of the templates in
// pack.h.
| 16 | |
| 17 | #ifndef GEMMLOWP_INTERNAL_PACK_AVX_H_ |
| 18 | #define GEMMLOWP_INTERNAL_PACK_AVX_H_ |
| 19 | |
| 20 | #include <immintrin.h> |
| 21 | #include "pack.h" |
| 22 | |
| 23 | namespace gemmlowp { |
| 24 | |
| 25 | // TODO: Add DepthMajorUint8SideMap |
| 26 | |
| 27 | typedef SideMap<const std::uint8_t, SideMapOrder::WidthMajor> |
| 28 | WidthMajorUint8SideMap; |
| 29 | |
| 30 | template <int Cells> |
| 31 | using WidthMajorSideFormatNCells4x2 = |
| 32 | KernelSideFormat<CellFormat<8, 2, CellOrder::WidthMajor>, Cells>; |
| 33 | |
| 34 | template <int Cells> |
| 35 | class PackingRegisterBlock< |
| 36 | WidthMajorUint8SideMap, |
| 37 | PackedSideBlock<WidthMajorSideFormatNCells4x2<Cells>>> |
| 38 | : public PackingRegisterBlockBase< |
| 39 | WidthMajorUint8SideMap, |
| 40 | PackedSideBlock<WidthMajorSideFormatNCells4x2<Cells>>> { |
| 41 | public: |
| 42 | typedef WidthMajorSideFormatNCells4x2<Cells> KernelSideFormat; |
| 43 | typedef typename KernelSideFormat::Cell CellFormat; |
| 44 | static const int kCells = KernelSideFormat::kCells; |
| 45 | static const int kCellWidth = CellFormat::kWidth; |
| 46 | static const int kKernelWidth = CellFormat::kWidth * kCells; |
| 47 | static const int kCellDepth = CellFormat::kDepth; |
| 48 | static const int kCellSize = CellFormat::kSize; |
| 49 | |
| 50 | void Pack(PackedSideBlock<KernelSideFormat> *dst, int start_width) { |
| 51 | std::uint8_t *dst_ptr = dst->current_data(); |
| 52 | const int width_stride = this->complete_src_.width_stride(); |
| 53 | int depth_step = 16; |
| 54 | |
| 55 | __m256i one = _mm256_set1_epi16(1); |
| 56 | for (int cell_start_depth = 0; cell_start_depth < kRegisterSize; |
| 57 | cell_start_depth += depth_step) { |
| 58 | for (int cell_start_width = 0; cell_start_width < kKernelWidth; |
| 59 | cell_start_width += kCellWidth) { |
| 60 | std::int32_t *cell_sums_of_each_slice_ptr = |
| 61 | dst->sums_of_each_slice() + start_width + cell_start_width; |
| 62 | const std::uint8_t *src_data = |
| 63 | this->complete_src_.data(cell_start_width, cell_start_depth); |
| 64 | |
| 65 | __m128i xmm1 = |
| 66 | _mm_loadu_si128(reinterpret_cast<const __m128i *>(&src_data[0])); |
| 67 | __m128i xmm2 = _mm_loadu_si128( |
| 68 | reinterpret_cast<const __m128i *>(&src_data[1 * width_stride])); |
| 69 | __m128i xmm3 = _mm_loadu_si128( |
| 70 | reinterpret_cast<const __m128i *>(&src_data[2 * width_stride])); |
| 71 | __m128i xmm4 = _mm_loadu_si128( |
| 72 | reinterpret_cast<const __m128i *>(&src_data[3 * width_stride])); |
| 73 | __m128i xmm5 = _mm_loadu_si128( |
| 74 | reinterpret_cast<const __m128i *>(&src_data[4 * width_stride])); |
| 75 | __m128i xmm6 = _mm_loadu_si128( |
| 76 | reinterpret_cast<const __m128i *>(&src_data[5 * width_stride])); |
| 77 | __m128i xmm7 = _mm_loadu_si128( |
| 78 | reinterpret_cast<const __m128i *>(&src_data[6 * width_stride])); |
| 79 | __m128i xmm8 = _mm_loadu_si128( |
| 80 | reinterpret_cast<const __m128i *>(&src_data[7 * width_stride])); |
| 81 | |
| 82 | __m256i ymm1 = _mm256_set_m128i(xmm5, xmm1); |
| 83 | __m256i ymm2 = _mm256_set_m128i(xmm6, xmm2); |
| 84 | __m256i ymm3 = _mm256_set_m128i(xmm7, xmm3); |
| 85 | __m256i ymm4 = _mm256_set_m128i(xmm8, xmm4); |
| 86 | |
| 87 | __m256i ymm5 = _mm256_unpacklo_epi16(ymm1, ymm2); |
| 88 | __m256i ymm6 = _mm256_unpacklo_epi16(ymm3, ymm4); |
| 89 | |
| 90 | __m256i ymm9 = _mm256_unpackhi_epi16(ymm1, ymm2); |
| 91 | __m256i ymm10 = _mm256_unpackhi_epi16(ymm3, ymm4); |
| 92 | |
| 93 | __m256i ymm7 = _mm256_unpacklo_epi32(ymm5, ymm6); |
| 94 | __m256i ymm8 = _mm256_unpackhi_epi32(ymm5, ymm6); |
| 95 | |
| 96 | __m256i ymm13 = _mm256_unpacklo_epi32(ymm9, ymm10); |
| 97 | __m256i ymm14 = _mm256_unpackhi_epi32(ymm9, ymm10); |
| 98 | |
| 99 | __m256i ymm11 = _mm256_permute4x64_epi64(ymm7, 0xd8); |
| 100 | __m256i ymm12 = _mm256_permute4x64_epi64(ymm8, 0xd8); |
| 101 | |
| 102 | __m256i ymm15 = _mm256_permute4x64_epi64(ymm13, 0xd8); |
| 103 | __m256i ymm16 = _mm256_permute4x64_epi64(ymm14, 0xd8); |
| 104 | |
| 105 | __m128i xmm9 = _mm256_castsi256_si128(ymm11); |
| 106 | __m128i xmm10 = _mm256_castsi256_si128(ymm12); |
| 107 | __m128i xmm11 = _mm256_extracti128_si256(ymm11, 1); |
| 108 | __m128i xmm12 = _mm256_extracti128_si256(ymm12, 1); |
| 109 | |
| 110 | xmm1 = _mm256_castsi256_si128(ymm15); |
| 111 | xmm2 = _mm256_castsi256_si128(ymm16); |
| 112 | xmm3 = _mm256_extracti128_si256(ymm15, 1); |
| 113 | xmm4 = _mm256_extracti128_si256(ymm16, 1); |
| 114 | |
| 115 | _mm_storeu_si128(reinterpret_cast<__m128i *>(&dst_ptr[0]), xmm9); |
| 116 | _mm_storeu_si128( |
| 117 | reinterpret_cast<__m128i *>(&dst_ptr[kCellSize * kCells]), xmm11); |
| 118 | _mm_storeu_si128( |
| 119 | reinterpret_cast<__m128i *>(&dst_ptr[2 * kCellSize * kCells]), |
| 120 | xmm10); |
| 121 | _mm_storeu_si128( |
| 122 | reinterpret_cast<__m128i *>(&dst_ptr[3 * kCellSize * kCells]), |
| 123 | xmm12); |
| 124 | _mm_storeu_si128( |
| 125 | reinterpret_cast<__m128i *>(&dst_ptr[4 * kCellSize * kCells]), |
| 126 | xmm1); |
| 127 | _mm_storeu_si128( |
| 128 | reinterpret_cast<__m128i *>(&dst_ptr[5 * kCellSize * kCells]), |
| 129 | xmm3); |
| 130 | |
| 131 | _mm_storeu_si128( |
| 132 | reinterpret_cast<__m128i *>(&dst_ptr[6 * kCellSize * kCells]), |
| 133 | xmm2); |
| 134 | _mm_storeu_si128( |
| 135 | reinterpret_cast<__m128i *>(&dst_ptr[7 * kCellSize * kCells]), |
| 136 | xmm4); |
| 137 | |
| 138 | ymm6 = _mm256_cvtepu8_epi16(xmm9); |
| 139 | ymm7 = _mm256_madd_epi16(ymm6, one); |
| 140 | __m256i sums_of_each_slice_xmm = _mm256_loadu_si256( |
| 141 | reinterpret_cast<const __m256i *>(&cell_sums_of_each_slice_ptr[0])); |
| 142 | sums_of_each_slice_xmm = _mm256_add_epi32(sums_of_each_slice_xmm, ymm7); |
| 143 | |
| 144 | ymm6 = _mm256_cvtepu8_epi16(xmm11); |
| 145 | ymm7 = _mm256_madd_epi16(ymm6, one); |
| 146 | sums_of_each_slice_xmm = _mm256_add_epi32(sums_of_each_slice_xmm, ymm7); |
| 147 | |
| 148 | ymm6 = _mm256_cvtepu8_epi16(xmm10); |
| 149 | ymm7 = _mm256_madd_epi16(ymm6, one); |
| 150 | sums_of_each_slice_xmm = _mm256_add_epi32(sums_of_each_slice_xmm, ymm7); |
| 151 | |
| 152 | ymm6 = _mm256_cvtepu8_epi16(xmm12); |
| 153 | ymm7 = _mm256_madd_epi16(ymm6, one); |
| 154 | sums_of_each_slice_xmm = _mm256_add_epi32(sums_of_each_slice_xmm, ymm7); |
| 155 | |
| 156 | ymm6 = _mm256_cvtepu8_epi16(xmm1); |
| 157 | ymm7 = _mm256_madd_epi16(ymm6, one); |
| 158 | sums_of_each_slice_xmm = _mm256_add_epi32(sums_of_each_slice_xmm, ymm7); |
| 159 | |
| 160 | ymm6 = _mm256_cvtepu8_epi16(xmm3); |
| 161 | ymm7 = _mm256_madd_epi16(ymm6, one); |
| 162 | sums_of_each_slice_xmm = _mm256_add_epi32(sums_of_each_slice_xmm, ymm7); |
| 163 | |
| 164 | ymm6 = _mm256_cvtepu8_epi16(xmm2); |
| 165 | ymm7 = _mm256_madd_epi16(ymm6, one); |
| 166 | sums_of_each_slice_xmm = _mm256_add_epi32(sums_of_each_slice_xmm, ymm7); |
| 167 | |
| 168 | ymm6 = _mm256_cvtepu8_epi16(xmm4); |
| 169 | ymm7 = _mm256_madd_epi16(ymm6, one); |
| 170 | sums_of_each_slice_xmm = _mm256_add_epi32(sums_of_each_slice_xmm, ymm7); |
| 171 | |
| 172 | _mm256_storeu_si256( |
| 173 | reinterpret_cast<__m256i *>(&cell_sums_of_each_slice_ptr[0]), |
| 174 | sums_of_each_slice_xmm); |
| 175 | dst_ptr += kCellSize; |
| 176 | } |
| 177 | dst_ptr += 7 * kCellSize * kCells; |
| 178 | } |
| 179 | dst->seek_forward_n_cells(kCells * kRegisterSize / kCellDepth); |
| 180 | } |
| 181 | }; |
| 182 | |
| 183 | // Pack format for 4x2 rhs format |
| 184 | template <int Cells> |
| 185 | using RhsWidthMajorSideFormatNCells4x2 = |
| 186 | KernelSideFormat<CellFormat<4, 2, CellOrder::WidthMajor>, Cells>; |
| 187 | |
| 188 | template <int Cells> |
| 189 | class PackingRegisterBlock< |
| 190 | WidthMajorUint8SideMap, |
| 191 | PackedSideBlock<RhsWidthMajorSideFormatNCells4x2<Cells>>> |
| 192 | : public PackingRegisterBlockBase< |
| 193 | WidthMajorUint8SideMap, |
| 194 | PackedSideBlock<RhsWidthMajorSideFormatNCells4x2<Cells>>> { |
| 195 | public: |
| 196 | typedef RhsWidthMajorSideFormatNCells4x2<Cells> KernelSideFormat; |
| 197 | typedef typename KernelSideFormat::Cell CellFormat; |
| 198 | static const int kCells = KernelSideFormat::kCells; |
| 199 | static const int kCellWidth = CellFormat::kWidth; |
| 200 | static const int kKernelWidth = CellFormat::kWidth * kCells; |
| 201 | static const int kCellDepth = CellFormat::kDepth; |
| 202 | static const int kCellSize = CellFormat::kSize; |
| 203 | |
| 204 | void Pack(PackedSideBlock<KernelSideFormat> *dst, int start_width) { |
| 205 | std::uint8_t *dst_ptr = dst->current_data(); |
| 206 | const int width_stride = this->complete_src_.width_stride(); |
| 207 | int depth_step = 8; |
| 208 | |
| 209 | __m128i one = _mm_set1_epi16(1); |
| 210 | for (int cell_start_depth = 0; cell_start_depth < kRegisterSize; |
| 211 | cell_start_depth += depth_step) { |
| 212 | for (int cell_start_width = 0; cell_start_width < kKernelWidth; |
| 213 | cell_start_width += kCellWidth) { |
| 214 | std::int32_t *cell_sums_of_each_slice_ptr = |
| 215 | dst->sums_of_each_slice() + start_width + cell_start_width; |
| 216 | const std::uint8_t *src_data = |
| 217 | this->complete_src_.data(cell_start_width, cell_start_depth); |
| 218 | |
| 219 | __m128i xmm1 = |
| 220 | _mm_loadl_epi64(reinterpret_cast<const __m128i *>(&src_data[0])); |
| 221 | __m128i xmm2 = _mm_loadl_epi64( |
| 222 | reinterpret_cast<const __m128i *>(&src_data[1 * width_stride])); |
| 223 | __m128i xmm3 = _mm_loadl_epi64( |
| 224 | reinterpret_cast<const __m128i *>(&src_data[2 * width_stride])); |
| 225 | __m128i xmm4 = _mm_loadl_epi64( |
| 226 | reinterpret_cast<const __m128i *>(&src_data[3 * width_stride])); |
| 227 | |
| 228 | __m128i xmm5 = _mm_unpacklo_epi16(xmm1, xmm2); |
| 229 | __m128i xmm8 = _mm_shuffle_epi32(xmm5, 0x31); |
| 230 | |
| 231 | __m128i xmm6 = _mm_unpacklo_epi16(xmm3, xmm4); |
| 232 | __m128i xmm7 = _mm_shuffle_epi32(xmm6, 0x80); |
| 233 | |
| 234 | __m128i xmm9 = _mm_blend_epi16(xmm5, xmm7, 0xcc); |
| 235 | __m128i xmm10 = _mm_blend_epi16(xmm8, xmm6, 0xcc); |
| 236 | |
| 237 | _mm_storel_epi64(reinterpret_cast<__m128i *>(&dst_ptr[0]), xmm9); |
| 238 | _mm_storel_epi64( |
| 239 | reinterpret_cast<__m128i *>(&dst_ptr[kCellSize * kCells]), xmm10); |
| 240 | |
| 241 | __m128i xmm11 = _mm_shuffle_epi32(xmm9, 0xee); |
| 242 | __m128i xmm12 = _mm_shuffle_epi32(xmm10, 0xee); |
| 243 | |
| 244 | _mm_storel_epi64( |
| 245 | reinterpret_cast<__m128i *>(&dst_ptr[2 * kCellSize * kCells]), |
| 246 | xmm11); |
| 247 | _mm_storel_epi64( |
| 248 | reinterpret_cast<__m128i *>(&dst_ptr[3 * kCellSize * kCells]), |
| 249 | xmm12); |
| 250 | |
| 251 | xmm1 = _mm_cvtepu8_epi16(xmm9); |
| 252 | xmm2 = _mm_madd_epi16(xmm1, one); |
| 253 | __m128i sums_of_each_slice_xmm = _mm_loadu_si128( |
| 254 | reinterpret_cast<const __m128i *>(&cell_sums_of_each_slice_ptr[0])); |
| 255 | sums_of_each_slice_xmm = _mm_add_epi32(sums_of_each_slice_xmm, xmm2); |
| 256 | |
| 257 | xmm1 = _mm_cvtepu8_epi16(xmm10); |
| 258 | xmm2 = _mm_madd_epi16(xmm1, one); |
| 259 | sums_of_each_slice_xmm = _mm_add_epi32(sums_of_each_slice_xmm, xmm2); |
| 260 | |
| 261 | xmm1 = _mm_cvtepu8_epi16(xmm11); |
| 262 | xmm2 = _mm_madd_epi16(xmm1, one); |
| 263 | sums_of_each_slice_xmm = _mm_add_epi32(sums_of_each_slice_xmm, xmm2); |
| 264 | |
| 265 | xmm1 = _mm_cvtepu8_epi16(xmm12); |
| 266 | xmm2 = _mm_madd_epi16(xmm1, one); |
| 267 | sums_of_each_slice_xmm = _mm_add_epi32(sums_of_each_slice_xmm, xmm2); |
| 268 | |
| 269 | _mm_storeu_si128( |
| 270 | reinterpret_cast<__m128i *>(&cell_sums_of_each_slice_ptr[0]), |
| 271 | sums_of_each_slice_xmm); |
| 272 | dst_ptr += kCellSize; |
| 273 | } |
| 274 | dst_ptr += 3 * kCellSize * kCells; |
| 275 | } |
| 276 | dst->seek_forward_n_cells(kCells * kRegisterSize / kCellDepth); |
| 277 | } |
| 278 | }; |
| 279 | |
| 280 | } // namespace gemmlowp |
| 281 | |
| 282 | #endif // GEMMLOWP_INTERNAL_PACK_AVX_H_ |