1 | /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #ifndef TENSORFLOW_CORE_LIB_BFLOAT16_BFLOAT16_H_ |
17 | #define TENSORFLOW_CORE_LIB_BFLOAT16_BFLOAT16_H_ |
18 | |
#include <cmath>
#include <complex>
#include <cstdint>  // uint16_t / uint32_t used throughout this header.
#include <cstring>  // memcpy, used for alias-safe type punning.

// We need cpu_info.h here in order to pick up __BYTE_ORDER__.
#include "tensorflow/core/platform/cpu_info.h"
24 | |
25 | #ifdef __CUDACC__ |
26 | // All functions callable from CUDA code must be qualified with __device__ |
27 | #define B16_DEVICE_FUNC __host__ __device__ |
28 | |
29 | #else |
30 | #define B16_DEVICE_FUNC |
31 | |
32 | #endif |
33 | |
34 | namespace Eigen { |
35 | struct half; |
36 | } |
37 | |
38 | namespace tensorflow { |
39 | |
// Single precision complex.
using complex64 = std::complex<float>;
// Double precision complex.
using complex128 = std::complex<double>;
44 | |
45 | // see framework/bfloat16.h for description. |
46 | struct bfloat16 { |
47 | B16_DEVICE_FUNC bfloat16() {} |
48 | |
49 | B16_DEVICE_FUNC explicit bfloat16(const float v) { |
50 | if (float_isnan(v)) { |
51 | value = NAN_VALUE; |
52 | return; |
53 | } |
54 | const uint16_t* p = reinterpret_cast<const uint16_t*>(&v); |
55 | #if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__ |
56 | value = p[0]; |
57 | #else |
58 | value = p[1]; |
59 | #endif |
60 | } |
61 | |
62 | B16_DEVICE_FUNC explicit bfloat16(const double val) |
63 | : bfloat16(static_cast<float>(val)) {} |
64 | // Following the convention of numpy, converting between complex and |
65 | // float will lead to loss of imag value. |
66 | B16_DEVICE_FUNC explicit bfloat16(const complex64& val) |
67 | : bfloat16(val.real()) {} |
68 | |
69 | B16_DEVICE_FUNC explicit bfloat16(const complex128& val) |
70 | : bfloat16(static_cast<float>(val.real())) {} |
71 | |
72 | B16_DEVICE_FUNC explicit bfloat16(const unsigned short val) |
73 | : bfloat16(static_cast<float>(val)) {} |
74 | |
75 | B16_DEVICE_FUNC explicit bfloat16(const unsigned int val) |
76 | : bfloat16(static_cast<float>(val)) {} |
77 | |
78 | B16_DEVICE_FUNC explicit bfloat16(const int val) |
79 | : bfloat16(static_cast<float>(val)) {} |
80 | |
81 | B16_DEVICE_FUNC explicit bfloat16(const long val) |
82 | : bfloat16(static_cast<float>(val)) {} |
83 | |
84 | B16_DEVICE_FUNC explicit bfloat16(const long long val) |
85 | : bfloat16(static_cast<float>(val)) {} |
86 | |
87 | template <class T> |
88 | B16_DEVICE_FUNC explicit bfloat16(const T& val) |
89 | : bfloat16(static_cast<float>(val)) {} |
90 | |
91 | B16_DEVICE_FUNC explicit operator float() const { |
92 | float result; |
93 | |
94 | uint16_t* q = reinterpret_cast<uint16_t*>(&result); |
95 | |
96 | #if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__ |
97 | q[0] = value; |
98 | q[1] = 0; |
99 | #else |
100 | q[0] = 0; |
101 | q[1] = value; |
102 | #endif |
103 | return result; |
104 | } |
105 | |
106 | B16_DEVICE_FUNC explicit operator bool() const { |
107 | return static_cast<bool>(float(*this)); |
108 | } |
109 | |
110 | B16_DEVICE_FUNC explicit operator Eigen::half() const; |
111 | |
112 | B16_DEVICE_FUNC explicit operator short() const { |
113 | return static_cast<short>(float(*this)); |
114 | } |
115 | |
116 | B16_DEVICE_FUNC explicit operator int() const { |
117 | return static_cast<int>(float(*this)); |
118 | } |
119 | |
120 | B16_DEVICE_FUNC explicit operator long() const { |
121 | return static_cast<long>(float(*this)); |
122 | } |
123 | |
124 | B16_DEVICE_FUNC explicit operator char() const { |
125 | return static_cast<char>(float(*this)); |
126 | } |
127 | |
128 | B16_DEVICE_FUNC explicit operator signed char() const { |
129 | return static_cast<signed char>(float(*this)); |
130 | } |
131 | |
132 | B16_DEVICE_FUNC explicit operator unsigned char() const { |
133 | return static_cast<unsigned char>(float(*this)); |
134 | } |
135 | |
136 | B16_DEVICE_FUNC explicit operator unsigned short() const { |
137 | return static_cast<unsigned short>(float(*this)); |
138 | } |
139 | |
140 | B16_DEVICE_FUNC explicit operator unsigned int() const { |
141 | return static_cast<unsigned int>(float(*this)); |
142 | } |
143 | |
144 | B16_DEVICE_FUNC explicit operator unsigned long() const { |
145 | return static_cast<unsigned long>(float(*this)); |
146 | } |
147 | |
148 | B16_DEVICE_FUNC explicit operator unsigned long long() const { |
149 | return static_cast<unsigned long long>(float(*this)); |
150 | } |
151 | |
152 | B16_DEVICE_FUNC explicit operator long long() const { |
153 | return static_cast<long long>(float(*this)); |
154 | } |
155 | |
156 | B16_DEVICE_FUNC explicit operator double() const { |
157 | return static_cast<double>(float(*this)); |
158 | } |
159 | |
160 | B16_DEVICE_FUNC explicit operator complex64() const { |
161 | return complex64(float(*this), float(0.0)); |
162 | } |
163 | |
164 | B16_DEVICE_FUNC explicit operator complex128() const { |
165 | return complex128(double(*this), double(0.0)); |
166 | } |
167 | |
168 | union FP32 { |
169 | unsigned int u; |
170 | float f; |
171 | }; |
172 | |
173 | // Converts a float point to bfloat16, with round-nearest-to-even as rounding |
174 | // method. |
175 | // TODO(b/69266521): Add a truncate_to_bfloat16 function and make this |
176 | // function as default behavior. |
177 | // TODO: There is a slightly faster implementation (8% faster on CPU) |
178 | // than this (documented in cl/175987786), that is exponentially harder to |
179 | // understand and document. Switch to the faster version when converting to |
180 | // BF16 becomes compute-bound. |
181 | B16_DEVICE_FUNC static bfloat16 round_to_bfloat16(float v) { |
182 | uint32_t input; |
183 | FP32 f; |
184 | f.f = v; |
185 | input = f.u; |
186 | bfloat16 output; |
187 | |
188 | if (float_isnan(v)) { |
189 | // If the value is a NaN, squash it to a qNaN with msb of fraction set, |
190 | // this makes sure after truncation we don't end up with an inf. |
191 | // |
192 | // qNaN magic: All exponent bits set + most significant bit of fraction |
193 | // set. |
194 | output.value = 0x7fc0; |
195 | } else { |
196 | // Fast rounding algorithm that rounds a half value to nearest even. This |
197 | // reduces expected error when we convert a large number of floats. Here |
198 | // is how it works: |
199 | // |
200 | // Definitions: |
201 | // To convert a float 32 to bfloat16, a float 32 can be viewed as 32 bits |
202 | // with the following tags: |
203 | // |
204 | // Sign | Exp (8 bits) | Frac (23 bits) |
205 | // S EEEEEEEE FFFFFFLRTTTTTTTTTTTTTTT |
206 | // |
207 | // S: Sign bit. |
208 | // E: Exponent bits. |
209 | // F: First 6 bits of fraction. |
210 | // L: Least significant bit of resulting bfloat16 if we truncate away the |
211 | // rest of the float32. This is also the 7th bit of fraction |
212 | // R: Rounding bit, 8th bit of fraction. |
213 | // T: Sticky bits, rest of fraction, 15 bits. |
214 | // |
215 | // To round half to nearest even, there are 3 cases where we want to round |
216 | // down (simply truncate the result of the bits away, which consists of |
217 | // rounding bit and sticky bits) and two cases where we want to round up |
218 | // (truncate then add one to the result). |
219 | // |
220 | // The fast converting algorithm simply adds lsb (L) to 0x7fff (15 bits of |
221 | // 1s) as the rounding bias, adds the rounding bias to the input, then |
222 | // truncates the last 16 bits away. |
223 | // |
224 | // To understand how it works, we can analyze this algorithm case by case: |
225 | // |
226 | // 1. L = 0, R = 0: |
227 | // Expect: round down, this is less than half value. |
228 | // |
229 | // Algorithm: |
230 | // - Rounding bias: 0x7fff + 0 = 0x7fff |
231 | // - Adding rounding bias to input may create any carry, depending on |
232 | // whether there is any value set to 1 in T bits. |
233 | // - R may be set to 1 if there is a carry. |
234 | // - L remains 0. |
235 | // - Note that this case also handles Inf and -Inf, where all fraction |
236 | // bits, including L, R and Ts are all 0. The output remains Inf after |
237 | // this algorithm. |
238 | // |
239 | // 2. L = 1, R = 0: |
240 | // Expect: round down, this is less than half value. |
241 | // |
242 | // Algorithm: |
243 | // - Rounding bias: 0x7fff + 1 = 0x8000 |
244 | // - Adding rounding bias to input doesn't change sticky bits but |
245 | // adds 1 to rounding bit. |
246 | // - L remains 1. |
247 | // |
248 | // 3. L = 0, R = 1, all of T are 0: |
249 | // Expect: round down, this is exactly at half, the result is already |
250 | // even (L=0). |
251 | // |
252 | // Algorithm: |
253 | // - Rounding bias: 0x7fff + 0 = 0x7fff |
254 | // - Adding rounding bias to input sets all sticky bits to 1, but |
255 | // doesn't create a carry. |
256 | // - R remains 1. |
257 | // - L remains 0. |
258 | // |
259 | // 4. L = 1, R = 1: |
260 | // Expect: round up, this is exactly at half, the result needs to be |
261 | // round to the next even number. |
262 | // |
263 | // Algorithm: |
264 | // - Rounding bias: 0x7fff + 1 = 0x8000 |
265 | // - Adding rounding bias to input doesn't change sticky bits, but |
266 | // creates a carry from rounding bit. |
267 | // - The carry sets L to 0, creates another carry bit and propagate |
268 | // forward to F bits. |
269 | // - If all the F bits are 1, a carry then propagates to the exponent |
270 | // bits, which then creates the minimum value with the next exponent |
271 | // value. Note that we won't have the case where exponents are all 1, |
272 | // since that's either a NaN (handled in the other if condition) or inf |
273 | // (handled in case 1). |
274 | // |
275 | // 5. L = 0, R = 1, any of T is 1: |
276 | // Expect: round up, this is greater than half. |
277 | // |
278 | // Algorithm: |
279 | // - Rounding bias: 0x7fff + 0 = 0x7fff |
280 | // - Adding rounding bias to input creates a carry from sticky bits, |
281 | // sets rounding bit to 0, then create another carry. |
282 | // - The second carry sets L to 1. |
283 | // |
284 | // Examples: |
285 | // |
286 | // Exact half value that is already even: |
287 | // Input: |
288 | // Sign | Exp (8 bit) | Frac (first 7 bit) | Frac (last 16 bit) |
289 | // S E E E E E E E E F F F F F F L RTTTTTTTTTTTTTTT |
290 | // 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 1000000000000000 |
291 | // |
292 | // This falls into case 3. We truncate the rest of 16 bits and no |
293 | // carry is created into F and L: |
294 | // |
295 | // Output: |
296 | // Sign | Exp (8 bit) | Frac (first 7 bit) |
297 | // S E E E E E E E E F F F F F F L |
298 | // 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 |
299 | // |
300 | // Exact half value, round to next even number: |
301 | // Input: |
302 | // Sign | Exp (8 bit) | Frac (first 7 bit) | Frac (last 16 bit) |
303 | // S E E E E E E E E F F F F F F L RTTTTTTTTTTTTTTT |
304 | // 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1000000000000000 |
305 | // |
306 | // This falls into case 4. We create a carry from R and T, |
307 | // which then propagates into L and F: |
308 | // |
309 | // Output: |
310 | // Sign | Exp (8 bit) | Frac (first 7 bit) |
311 | // S E E E E E E E E F F F F F F L |
312 | // 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 |
313 | // |
314 | // |
315 | // Max denormal value round to min normal value: |
316 | // Input: |
317 | // Sign | Exp (8 bit) | Frac (first 7 bit) | Frac (last 16 bit) |
318 | // S E E E E E E E E F F F F F F L RTTTTTTTTTTTTTTT |
319 | // 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1111111111111111 |
320 | // |
321 | // This falls into case 4. We create a carry from R and T, |
322 | // propagate into L and F, which then propagates into exponent |
323 | // bits: |
324 | // |
325 | // Output: |
326 | // Sign | Exp (8 bit) | Frac (first 7 bit) |
327 | // S E E E E E E E E F F F F F F L |
328 | // 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 |
329 | // |
330 | // Max normal value round to Inf: |
331 | // Input: |
332 | // Sign | Exp (8 bit) | Frac (first 7 bit) | Frac (last 16 bit) |
333 | // S E E E E E E E E F F F F F F L RTTTTTTTTTTTTTTT |
334 | // 0 1 1 1 1 1 1 1 0 1 1 1 1 1 1 1 1111111111111111 |
335 | // |
336 | // This falls into case 4. We create a carry from R and T, |
337 | // propagate into L and F, which then propagates into exponent |
338 | // bits: |
339 | // |
340 | // Sign | Exp (8 bit) | Frac (first 7 bit) |
341 | // S E E E E E E E E F F F F F F L |
342 | // 0 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 |
343 | // |
344 | // |
345 | // Least significant bit of resulting bfloat. |
346 | uint32_t lsb = (input >> 16) & 1; |
347 | uint32_t rounding_bias = 0x7fff + lsb; |
348 | input += rounding_bias; |
349 | output.value = static_cast<uint16_t>(input >> 16); |
350 | } |
351 | return output; |
352 | } |
353 | |
354 | static bfloat16 epsilon() { |
355 | bfloat16 x; |
356 | x.value = 0x3c00; // 0x1.0p-7 |
357 | return x; |
358 | } |
359 | |
360 | uint16_t value; |
361 | |
362 | // A value that represents "not a number". |
363 | static const uint16_t NAN_VALUE = 0x7FC0; |
364 | |
365 | private: |
366 | B16_DEVICE_FUNC static bool float_isnan(const float& x) { |
367 | #ifdef __CUDA_ARCH__ |
368 | return ::isnan(x); |
369 | #else |
370 | return std::isnan(x); |
371 | #endif |
372 | } |
373 | }; |
374 | |
375 | B16_DEVICE_FUNC inline std::ostream& operator<<(std::ostream& os, |
376 | const bfloat16& dt) { |
377 | os << static_cast<float>(dt); |
378 | return os; |
379 | } |
380 | |
381 | B16_DEVICE_FUNC inline bfloat16 operator+(bfloat16 a, bfloat16 b) { |
382 | return bfloat16(static_cast<float>(a) + static_cast<float>(b)); |
383 | } |
384 | B16_DEVICE_FUNC inline bfloat16 operator+(bfloat16 a, int b) { |
385 | return bfloat16(static_cast<float>(a) + static_cast<float>(b)); |
386 | } |
387 | B16_DEVICE_FUNC inline bfloat16 operator+(int a, bfloat16 b) { |
388 | return bfloat16(static_cast<float>(a) + static_cast<float>(b)); |
389 | } |
390 | B16_DEVICE_FUNC inline bfloat16 operator-(bfloat16 a, bfloat16 b) { |
391 | return bfloat16(static_cast<float>(a) - static_cast<float>(b)); |
392 | } |
393 | B16_DEVICE_FUNC inline bfloat16 operator*(bfloat16 a, bfloat16 b) { |
394 | return bfloat16(static_cast<float>(a) * static_cast<float>(b)); |
395 | } |
396 | B16_DEVICE_FUNC inline bfloat16 operator/(bfloat16 a, bfloat16 b) { |
397 | return bfloat16(static_cast<float>(a) / static_cast<float>(b)); |
398 | } |
399 | B16_DEVICE_FUNC inline bfloat16 operator-(bfloat16 a) { |
400 | a.value ^= 0x8000; |
401 | return a; |
402 | } |
403 | B16_DEVICE_FUNC inline bool operator<(bfloat16 a, bfloat16 b) { |
404 | return static_cast<float>(a) < static_cast<float>(b); |
405 | } |
406 | B16_DEVICE_FUNC inline bool operator<=(bfloat16 a, bfloat16 b) { |
407 | return static_cast<float>(a) <= static_cast<float>(b); |
408 | } |
409 | B16_DEVICE_FUNC inline bool operator==(bfloat16 a, bfloat16 b) { |
410 | return static_cast<float>(a) == static_cast<float>(b); |
411 | } |
412 | B16_DEVICE_FUNC inline bool operator!=(bfloat16 a, bfloat16 b) { |
413 | return static_cast<float>(a) != static_cast<float>(b); |
414 | } |
415 | B16_DEVICE_FUNC inline bool operator>(bfloat16 a, bfloat16 b) { |
416 | return static_cast<float>(a) > static_cast<float>(b); |
417 | } |
418 | B16_DEVICE_FUNC inline bool operator>=(bfloat16 a, bfloat16 b) { |
419 | return static_cast<float>(a) >= static_cast<float>(b); |
420 | } |
421 | B16_DEVICE_FUNC inline bfloat16& operator+=(bfloat16& a, bfloat16 b) { |
422 | a = a + b; |
423 | return a; |
424 | } |
425 | B16_DEVICE_FUNC inline bfloat16& operator-=(bfloat16& a, bfloat16 b) { |
426 | a = a - b; |
427 | return a; |
428 | } |
429 | B16_DEVICE_FUNC inline bfloat16 operator++(bfloat16& a) { |
430 | a += bfloat16(1); |
431 | return a; |
432 | } |
433 | B16_DEVICE_FUNC inline bfloat16 operator--(bfloat16& a) { |
434 | a -= bfloat16(1); |
435 | return a; |
436 | } |
437 | B16_DEVICE_FUNC inline bfloat16 operator++(bfloat16& a, int) { |
438 | bfloat16 original_value = a; |
439 | ++a; |
440 | return original_value; |
441 | } |
442 | B16_DEVICE_FUNC inline bfloat16 operator--(bfloat16& a, int) { |
443 | bfloat16 original_value = a; |
444 | --a; |
445 | return original_value; |
446 | } |
447 | B16_DEVICE_FUNC inline bfloat16& operator*=(bfloat16& a, bfloat16 b) { |
448 | a = a * b; |
449 | return a; |
450 | } |
451 | B16_DEVICE_FUNC inline bfloat16& operator/=(bfloat16& a, bfloat16 b) { |
452 | a = a / b; |
453 | return a; |
454 | } |
455 | } // end namespace tensorflow |
456 | |
457 | namespace std { |
458 | template <> |
459 | struct hash<tensorflow::bfloat16> { |
460 | size_t operator()(const tensorflow::bfloat16& v) const { |
461 | return hash<float>()(static_cast<float>(v)); |
462 | } |
463 | }; |
464 | |
465 | using tensorflow::bfloat16; |
466 | inline bool isinf(const bfloat16& a) { return std::isinf(float(a)); } |
467 | inline bool isnan(const bfloat16& a) { return std::isnan(float(a)); } |
468 | inline bool isfinite(const bfloat16& a) { return std::isfinite(float(a)); } |
469 | inline bfloat16 abs(const bfloat16& a) { return bfloat16(std::abs(float(a))); } |
470 | inline bfloat16 exp(const bfloat16& a) { return bfloat16(std::exp(float(a))); } |
471 | inline bfloat16 log(const bfloat16& a) { return bfloat16(std::log(float(a))); } |
472 | inline bfloat16 log10(const bfloat16& a) { |
473 | return bfloat16(std::log10(float(a))); |
474 | } |
475 | inline bfloat16 sqrt(const bfloat16& a) { |
476 | return bfloat16(std::sqrt(float(a))); |
477 | } |
478 | inline bfloat16 pow(const bfloat16& a, const bfloat16& b) { |
479 | return bfloat16(std::pow(float(a), float(b))); |
480 | } |
481 | inline bfloat16 sin(const bfloat16& a) { return bfloat16(std::sin(float(a))); } |
482 | inline bfloat16 cos(const bfloat16& a) { return bfloat16(std::cos(float(a))); } |
483 | inline bfloat16 tan(const bfloat16& a) { return bfloat16(std::tan(float(a))); } |
484 | inline bfloat16 tanh(const bfloat16& a) { |
485 | return bfloat16(std::tanh(float(a))); |
486 | } |
487 | inline bfloat16 floor(const bfloat16& a) { |
488 | return bfloat16(std::floor(float(a))); |
489 | } |
490 | inline bfloat16 ceil(const bfloat16& a) { |
491 | return bfloat16(std::ceil(float(a))); |
492 | } |
493 | } // namespace std |
494 | |
495 | #endif // TENSORFLOW_CORE_LIB_BFLOAT16_BFLOAT16_H_ |
496 | |