Grok 9.5.0
dot-inl.h
Go to the documentation of this file.
1 // Copyright 2021 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 // Include guard (still compiled once per target)
16 #include <cmath>
17 
18 #if defined(HIGHWAY_HWY_CONTRIB_DOT_DOT_INL_H_) == \
19  defined(HWY_TARGET_TOGGLE)
20 #ifdef HIGHWAY_HWY_CONTRIB_DOT_DOT_INL_H_
21 #undef HIGHWAY_HWY_CONTRIB_DOT_DOT_INL_H_
22 #else
23 #define HIGHWAY_HWY_CONTRIB_DOT_DOT_INL_H_
24 #endif
25 
26 #include "hwy/highway.h"
27 
29 namespace hwy {
30 namespace HWY_NAMESPACE {
31 
32 struct Dot {
33  // Specify zero or more of these, ORed together, as the kAssumptions template
34  // argument to Compute. Each one may improve performance or reduce code size,
35  // at the cost of additional requirements on the arguments.
36  enum Assumptions {
37  // num_elements is at least N, which may be up to HWY_MAX_LANES(T).
39  // num_elements is divisible by N (a power of two, so this can be used if
40  // the problem size is known to be a power of two >= HWY_MAX_LANES(T)).
42  // RoundUpTo(num_elements, N) elements are accessible; their value does not
43  // matter (will be treated as if they were zero).
45  // Pointers pa and pb, respectively, are multiples of N * sizeof(T).
46  // For example, aligned_allocator.h ensures this. Note that it is still
47  // beneficial to ensure such alignment even if these flags are not set.
48  // If not set, the pointers need only be aligned to alignof(T).
51  };
52 
53  // Returns sum{pa[i] * pb[i]} for float or double inputs.
54  template <int kAssumptions, class D, typename T = TFromD<D>,
55  HWY_IF_NOT_LANE_SIZE_D(D, 2)>
56  static HWY_INLINE T Compute(const D d, const T* const HWY_RESTRICT pa,
57  const T* const HWY_RESTRICT pb,
58  const size_t num_elements) {
59  static_assert(IsFloat<T>(), "MulAdd requires float type");
60  using V = decltype(Zero(d));
61 
62  const size_t N = Lanes(d);
63  size_t i = 0;
64 
65  constexpr bool kIsAtLeastOneVector = kAssumptions & kAtLeastOneVector;
66  constexpr bool kIsMultipleOfVector = kAssumptions & kMultipleOfVector;
67  constexpr bool kIsPaddedToVector = kAssumptions & kPaddedToVector;
68  constexpr bool kIsAlignedA = kAssumptions & kVectorAlignedA;
69  constexpr bool kIsAlignedB = kAssumptions & kVectorAlignedB;
70 
71  // Won't be able to do a full vector load without padding => scalar loop.
72  if (!kIsAtLeastOneVector && !kIsMultipleOfVector && !kIsPaddedToVector &&
73  HWY_UNLIKELY(num_elements < N)) {
74  // Only 2x unroll to avoid excessive code size.
75  T sum0 = T(0);
76  T sum1 = T(0);
77  for (; i + 2 <= num_elements; i += 2) {
78  sum0 += pa[i + 0] * pb[i + 0];
79  sum1 += pa[i + 1] * pb[i + 1];
80  }
81  if (i < num_elements) {
82  sum1 += pa[i] * pb[i];
83  }
84  return sum0 + sum1;
85  }
86 
87  // Compiler doesn't make independent sum* accumulators, so unroll manually.
88  // 2 FMA ports * 4 cycle latency = up to 8 in-flight, but that is excessive
89  // for unaligned inputs (each unaligned pointer halves the throughput
90  // because it occupies both L1 load ports for a cycle). We cannot have
91  // arrays of vectors on RVV/SVE, so always unroll 4x.
92  V sum0 = Zero(d);
93  V sum1 = Zero(d);
94  V sum2 = Zero(d);
95  V sum3 = Zero(d);
96 
97  // Main loop: unrolled
98  for (; i + 4 * N <= num_elements; /* i += 4 * N */) { // incr in loop
99  const auto a0 = kIsAlignedA ? Load(d, pa + i) : LoadU(d, pa + i);
100  const auto b0 = kIsAlignedB ? Load(d, pb + i) : LoadU(d, pb + i);
101  i += N;
102  sum0 = MulAdd(a0, b0, sum0);
103  const auto a1 = kIsAlignedA ? Load(d, pa + i) : LoadU(d, pa + i);
104  const auto b1 = kIsAlignedB ? Load(d, pb + i) : LoadU(d, pb + i);
105  i += N;
106  sum1 = MulAdd(a1, b1, sum1);
107  const auto a2 = kIsAlignedA ? Load(d, pa + i) : LoadU(d, pa + i);
108  const auto b2 = kIsAlignedB ? Load(d, pb + i) : LoadU(d, pb + i);
109  i += N;
110  sum2 = MulAdd(a2, b2, sum2);
111  const auto a3 = kIsAlignedA ? Load(d, pa + i) : LoadU(d, pa + i);
112  const auto b3 = kIsAlignedB ? Load(d, pb + i) : LoadU(d, pb + i);
113  i += N;
114  sum3 = MulAdd(a3, b3, sum3);
115  }
116 
117  // Up to 3 iterations of whole vectors
118  for (; i + N <= num_elements; i += N) {
119  const auto a = kIsAlignedA ? Load(d, pa + i) : LoadU(d, pa + i);
120  const auto b = kIsAlignedB ? Load(d, pb + i) : LoadU(d, pb + i);
121  sum0 = MulAdd(a, b, sum0);
122  }
123 
124  if (!kIsMultipleOfVector) {
125  const size_t remaining = num_elements - i;
126  if (remaining != 0) {
127  if (kIsPaddedToVector) {
128  const auto mask = FirstN(d, remaining);
129  const auto a = kIsAlignedA ? Load(d, pa + i) : LoadU(d, pa + i);
130  const auto b = kIsAlignedB ? Load(d, pb + i) : LoadU(d, pb + i);
131  sum1 = MulAdd(IfThenElseZero(mask, a), IfThenElseZero(mask, b), sum1);
132  } else {
133  // Unaligned load such that the last element is in the highest lane -
134  // ensures we do not touch any elements outside the valid range.
135  // If we get here, then num_elements >= N.
136  HWY_DASSERT(i >= N);
137  i += remaining - N;
138  const auto skip = FirstN(d, N - remaining);
139  const auto a = LoadU(d, pa + i); // always unaligned
140  const auto b = LoadU(d, pb + i);
141  sum1 = MulAdd(IfThenZeroElse(skip, a), IfThenZeroElse(skip, b), sum1);
142  }
143  }
144  } // kMultipleOfVector
145 
146  // Reduction tree: sum of all accumulators by pairs, then across lanes.
147  sum0 = Add(sum0, sum1);
148  sum2 = Add(sum2, sum3);
149  sum0 = Add(sum0, sum2);
150  return GetLane(SumOfLanes(d, sum0));
151  }
152 
153  // Returns sum{pa[i] * pb[i]} for bfloat16 inputs.
154  template <int kAssumptions, class D>
155  static HWY_INLINE float Compute(const D d,
156  const bfloat16_t* const HWY_RESTRICT pa,
157  const bfloat16_t* const HWY_RESTRICT pb,
158  const size_t num_elements) {
159  const RebindToUnsigned<D> du16;
160  const Repartition<float, D> df32;
161 
162  using V = decltype(Zero(df32));
163  const size_t N = Lanes(d);
164  size_t i = 0;
165 
166  constexpr bool kIsAtLeastOneVector = kAssumptions & kAtLeastOneVector;
167  constexpr bool kIsMultipleOfVector = kAssumptions & kMultipleOfVector;
168  constexpr bool kIsPaddedToVector = kAssumptions & kPaddedToVector;
169  constexpr bool kIsAlignedA = kAssumptions & kVectorAlignedA;
170  constexpr bool kIsAlignedB = kAssumptions & kVectorAlignedB;
171 
172  // Won't be able to do a full vector load without padding => scalar loop.
173  if (!kIsAtLeastOneVector && !kIsMultipleOfVector && !kIsPaddedToVector &&
174  HWY_UNLIKELY(num_elements < N)) {
175  float_t sum0 = 0.0f; // Only 2x unroll to avoid excessive code size for..
176  float_t sum1 = 0.0f; // this unlikely(?) case.
177  for (; i + 2 <= num_elements; i += 2) {
178  sum0 += F32FromBF16(pa[i + 0]) * F32FromBF16(pb[i + 0]);
179  sum1 += F32FromBF16(pa[i + 1]) * F32FromBF16(pb[i + 1]);
180  }
181  if (i < num_elements) {
182  sum1 += F32FromBF16(pa[i]) * F32FromBF16(pb[i]);
183  }
184  return sum0 + sum1;
185  }
186 
187  // See comment in the other Compute() overload. Unroll 2x, but we need
188  // twice as many sums for ReorderWidenMulAccumulate.
189  V sum0 = Zero(df32);
190  V sum1 = Zero(df32);
191  V sum2 = Zero(df32);
192  V sum3 = Zero(df32);
193 
194  // Main loop: unrolled
195  for (; i + 2 * N <= num_elements; /* i += 2 * N */) { // incr in loop
196  const auto a0 = kIsAlignedA ? Load(d, pa + i) : LoadU(d, pa + i);
197  const auto b0 = kIsAlignedB ? Load(d, pb + i) : LoadU(d, pb + i);
198  i += N;
199  sum0 = ReorderWidenMulAccumulate(df32, a0, b0, sum0, sum1);
200  const auto a1 = kIsAlignedA ? Load(d, pa + i) : LoadU(d, pa + i);
201  const auto b1 = kIsAlignedB ? Load(d, pb + i) : LoadU(d, pb + i);
202  i += N;
203  sum2 = ReorderWidenMulAccumulate(df32, a1, b1, sum2, sum3);
204  }
205 
206  // Possibly one more iteration of whole vectors
207  if (i + N <= num_elements) {
208  const auto a0 = kIsAlignedA ? Load(d, pa + i) : LoadU(d, pa + i);
209  const auto b0 = kIsAlignedB ? Load(d, pb + i) : LoadU(d, pb + i);
210  i += N;
211  sum0 = ReorderWidenMulAccumulate(df32, a0, b0, sum0, sum1);
212  }
213 
214  if (!kIsMultipleOfVector) {
215  const size_t remaining = num_elements - i;
216  if (remaining != 0) {
217  if (kIsPaddedToVector) {
218  const auto mask = FirstN(du16, remaining);
219  const auto va = kIsAlignedA ? Load(d, pa + i) : LoadU(d, pa + i);
220  const auto vb = kIsAlignedB ? Load(d, pb + i) : LoadU(d, pb + i);
221  const auto a16 = BitCast(d, IfThenElseZero(mask, BitCast(du16, va)));
222  const auto b16 = BitCast(d, IfThenElseZero(mask, BitCast(du16, vb)));
223  sum2 = ReorderWidenMulAccumulate(df32, a16, b16, sum2, sum3);
224 
225  } else {
226  // Unaligned load such that the last element is in the highest lane -
227  // ensures we do not touch any elements outside the valid range.
228  // If we get here, then num_elements >= N.
229  HWY_DASSERT(i >= N);
230  i += remaining - N;
231  const auto skip = FirstN(du16, N - remaining);
232  const auto va = LoadU(d, pa + i); // always unaligned
233  const auto vb = LoadU(d, pb + i);
234  const auto a16 = BitCast(d, IfThenZeroElse(skip, BitCast(du16, va)));
235  const auto b16 = BitCast(d, IfThenZeroElse(skip, BitCast(du16, vb)));
236  sum2 = ReorderWidenMulAccumulate(df32, a16, b16, sum2, sum3);
237  }
238  }
239  } // kMultipleOfVector
240 
241  // Reduction tree: sum of all accumulators by pairs, then across lanes.
242  sum0 = Add(sum0, sum1);
243  sum2 = Add(sum2, sum3);
244  sum0 = Add(sum0, sum2);
245  return GetLane(SumOfLanes(df32, sum0));
246  }
247 };
248 
249 // NOLINTNEXTLINE(google-readability-namespace-comments)
250 } // namespace HWY_NAMESPACE
251 } // namespace hwy
253 
254 #endif // HIGHWAY_HWY_CONTRIB_DOT_DOT_INL_H_
#define HWY_RESTRICT
Definition: base.h:58
#define HWY_INLINE
Definition: base.h:59
#define HWY_DASSERT(condition)
Definition: base.h:163
#define HWY_UNLIKELY(expr)
Definition: base.h:64
HWY_AFTER_NAMESPACE()
HWY_BEFORE_NAMESPACE()
HWY_API uint8_t GetLane(const Vec128< uint8_t, 16 > v)
Definition: arm_neon-inl.h:744
HWY_API Mask128< T, N > FirstN(const Simd< T, N > d, size_t num)
Definition: arm_neon-inl.h:1806
HWY_API Vec128< T, N > Load(Simd< T, N > d, const T *HWY_RESTRICT p)
Definition: arm_neon-inl.h:2152
HWY_API Vec128< float, N > MulAdd(const Vec128< float, N > mul, const Vec128< float, N > x, const Vec128< float, N > add)
Definition: arm_neon-inl.h:1232
Rebind< MakeUnsigned< TFromD< D > >, D > RebindToUnsigned
Definition: shared-inl.h:149
constexpr HWY_API size_t Lanes(Simd< T, N >)
Definition: arm_sve-inl.h:226
HWY_API Vec128< T, N > IfThenElseZero(const Mask128< T, N > mask, const Vec128< T, N > yes)
Definition: arm_neon-inl.h:1642
HWY_API V Add(V a, V b)
Definition: arm_neon-inl.h:5000
HWY_API Vec128< uint8_t > LoadU(Full128< uint8_t >, const uint8_t *HWY_RESTRICT unaligned)
Definition: arm_neon-inl.h:1953
HWY_API Vec128< T, N > BitCast(Simd< T, N > d, Vec128< FromT, N *sizeof(T)/sizeof(FromT)> v)
Definition: arm_neon-inl.h:687
HWY_API Vec128< T, N > IfThenZeroElse(const Mask128< T, N > mask, const Vec128< T, N > no)
Definition: arm_neon-inl.h:1649
typename D::template Repartition< T > Repartition
Definition: shared-inl.h:155
HWY_API Vec128< T, N > SumOfLanes(Simd< T, N >, const Vec128< T, N > v)
Definition: arm_neon-inl.h:4203
HWY_API Vec128< T, N > Zero(Simd< T, N > d)
Definition: arm_neon-inl.h:710
HWY_API Vec128< float, N > ReorderWidenMulAccumulate(Simd< float, N > df32, Vec128< bfloat16_t, 2 *N > a, Vec128< bfloat16_t, 2 *N > b, const Vec128< float, N > sum0, Vec128< float, N > &sum1)
Definition: arm_neon-inl.h:3545
Definition: aligned_allocator.h:23
HWY_API float F32FromBF16(bfloat16_t bf)
Definition: base.h:648
#define HWY_NAMESPACE
Definition: set_macros-inl.h:77
Definition: dot-inl.h:32
static HWY_INLINE T Compute(const D d, const T *const HWY_RESTRICT pa, const T *const HWY_RESTRICT pb, const size_t num_elements)
Definition: dot-inl.h:56
static HWY_INLINE float Compute(const D d, const bfloat16_t *const HWY_RESTRICT pa, const bfloat16_t *const HWY_RESTRICT pb, const size_t num_elements)
Definition: dot-inl.h:155
Assumptions
Definition: dot-inl.h:36
@ kMultipleOfVector
Definition: dot-inl.h:41
@ kPaddedToVector
Definition: dot-inl.h:44
@ kVectorAlignedA
Definition: dot-inl.h:49
@ kAtLeastOneVector
Definition: dot-inl.h:38
@ kVectorAlignedB
Definition: dot-inl.h:50
Definition: base.h:227