/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_LIB_BFLOAT16_BFLOAT16_H_
#define TENSORFLOW_CORE_LIB_BFLOAT16_BFLOAT16_H_

#include <cmath>
#include <complex>

// We need cpu_info.h here in order to pick up __BYTE_ORDER__.
#include "tensorflow/core/platform/cpu_info.h"

#ifdef __CUDACC__
// All functions callable from CUDA code must be qualified with __device__
#define B16_DEVICE_FUNC __host__ __device__

#else
#define B16_DEVICE_FUNC

#endif

namespace Eigen {
struct half;
}

namespace tensorflow {

// Single precision complex.
typedef std::complex<float> complex64;
// Double precision complex.
typedef std::complex<double> complex128;

// See framework/bfloat16.h for a description.
struct bfloat16 {
  B16_DEVICE_FUNC bfloat16() {}

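  // Truncating conversion: keeps the upper 16 bits of the float
  // representation (sign, exponent, and top 7 fraction bits) and drops the
  // rest; NaNs are canonicalized to NAN_VALUE so truncation cannot turn them
  // into an Inf. A hand-worked illustration (not from the original header):
  // 1.75f is 0x3FE00000, whose upper half 0x3FE0 is exactly representable,
  // while 1.3f is 0x3FA66666, which truncates to 0x3FA6 (~1.296875).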
  B16_DEVICE_FUNC explicit bfloat16(const float v) {
    if (float_isnan(v)) {
      value = NAN_VALUE;
      return;
    }
    const uint16_t* p = reinterpret_cast<const uint16_t*>(&v);
#if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
    value = p[0];
#else
    value = p[1];
#endif
  }

  B16_DEVICE_FUNC explicit bfloat16(const double val)
      : bfloat16(static_cast<float>(val)) {}
  // Following the convention of numpy, converting from complex to real
  // discards the imaginary part.
  B16_DEVICE_FUNC explicit bfloat16(const complex64& val)
      : bfloat16(val.real()) {}

  B16_DEVICE_FUNC explicit bfloat16(const complex128& val)
      : bfloat16(static_cast<float>(val.real())) {}

  B16_DEVICE_FUNC explicit bfloat16(const unsigned short val)
      : bfloat16(static_cast<float>(val)) {}

  B16_DEVICE_FUNC explicit bfloat16(const unsigned int val)
      : bfloat16(static_cast<float>(val)) {}

  B16_DEVICE_FUNC explicit bfloat16(const int val)
      : bfloat16(static_cast<float>(val)) {}

  B16_DEVICE_FUNC explicit bfloat16(const long val)
      : bfloat16(static_cast<float>(val)) {}

  B16_DEVICE_FUNC explicit bfloat16(const long long val)
      : bfloat16(static_cast<float>(val)) {}

  template <class T>
  B16_DEVICE_FUNC explicit bfloat16(const T& val)
      : bfloat16(static_cast<float>(val)) {}

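  // Widening to float is exact: the stored 16 bits become the upper half of
  // the float representation and the low 16 fraction bits are zero, so the
  // round trip bfloat16 -> float -> bfloat16 is lossless.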
  B16_DEVICE_FUNC explicit operator float() const {
    float result;

    uint16_t* q = reinterpret_cast<uint16_t*>(&result);

#if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
    q[0] = value;
    q[1] = 0;
#else
    q[0] = 0;
    q[1] = value;
#endif
    return result;
  }

  B16_DEVICE_FUNC explicit operator bool() const {
    return static_cast<bool>(float(*this));
  }

  B16_DEVICE_FUNC explicit operator Eigen::half() const;

  B16_DEVICE_FUNC explicit operator short() const {
    return static_cast<short>(float(*this));
  }

  B16_DEVICE_FUNC explicit operator int() const {
    return static_cast<int>(float(*this));
  }

  B16_DEVICE_FUNC explicit operator long() const {
    return static_cast<long>(float(*this));
  }

  B16_DEVICE_FUNC explicit operator char() const {
    return static_cast<char>(float(*this));
  }

  B16_DEVICE_FUNC explicit operator signed char() const {
    return static_cast<signed char>(float(*this));
  }

  B16_DEVICE_FUNC explicit operator unsigned char() const {
    return static_cast<unsigned char>(float(*this));
  }

  B16_DEVICE_FUNC explicit operator unsigned short() const {
    return static_cast<unsigned short>(float(*this));
  }

  B16_DEVICE_FUNC explicit operator unsigned int() const {
    return static_cast<unsigned int>(float(*this));
  }

  B16_DEVICE_FUNC explicit operator unsigned long() const {
    return static_cast<unsigned long>(float(*this));
  }

  B16_DEVICE_FUNC explicit operator unsigned long long() const {
    return static_cast<unsigned long long>(float(*this));
  }

  B16_DEVICE_FUNC explicit operator long long() const {
    return static_cast<long long>(float(*this));
  }

  B16_DEVICE_FUNC explicit operator double() const {
    return static_cast<double>(float(*this));
  }

  B16_DEVICE_FUNC explicit operator complex64() const {
    return complex64(float(*this), float(0.0));
  }

  B16_DEVICE_FUNC explicit operator complex128() const {
    return complex128(double(*this), double(0.0));
  }

  union FP32 {
    unsigned int u;
    float f;
  };
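  // Note: type punning through a union is technically undefined behavior in
  // standard C++, but gcc, clang, msvc, and nvcc all define it, which makes
  // the union a common idiom for bit-level access to a float.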

  // Converts a single-precision float to bfloat16, with round-to-nearest-even
  // as the rounding method.
  // TODO(b/69266521): Add a truncate_to_bfloat16 function and make this
  // function the default behavior.
  // TODO: There is a slightly faster implementation (8% faster on CPU)
  // than this (documented in cl/175987786), but it is exponentially harder to
  // understand and document. Switch to the faster version when converting to
  // BF16 becomes compute-bound.
  B16_DEVICE_FUNC static bfloat16 round_to_bfloat16(float v) {
    uint32_t input;
    FP32 f;
    f.f = v;
    input = f.u;
    bfloat16 output;

    if (float_isnan(v)) {
      // If the value is a NaN, squash it to a qNaN with the msb of the
      // fraction set; this makes sure we don't end up with an Inf after
      // truncation.
      //
      // qNaN magic: All exponent bits set + most significant bit of fraction
      // set.
      output.value = 0x7fc0;
    } else {
      // Fast rounding algorithm that rounds a half value to nearest even.
      // This reduces expected error when we convert a large number of floats.
      // Here is how it works:
      //
      // Definitions:
      // To convert a float32 to bfloat16, a float32 can be viewed as 32 bits
      // with the following tags:
      //
      // Sign |  Exp (8 bits) | Frac (23 bits)
      //   S     EEEEEEEE        FFFFFFLRTTTTTTTTTTTTTTT
      //
      //  S: Sign bit.
      //  E: Exponent bits.
      //  F: First 6 bits of fraction.
      //  L: Least significant bit of the resulting bfloat16 if we truncate
      //     away the rest of the float32. This is also the 7th bit of the
      //     fraction.
      //  R: Rounding bit, 8th bit of fraction.
      //  T: Sticky bits, rest of fraction, 15 bits.
      //
      // To round half to nearest even, there are three cases where we want to
      // round down (simply truncate away the rounding bit and sticky bits) and
      // two cases where we want to round up (truncate, then add one to the
      // result).
      //
      // The fast converting algorithm simply adds lsb (L) to 0x7fff (15 bits
      // of 1s) as the rounding bias, adds the rounding bias to the input, then
      // truncates the last 16 bits away.
      //
      // To understand how it works, we can analyze this algorithm case by
      // case:
      //
      // 1. L = 0, R = 0:
      //   Expect: round down, this is less than half value.
      //
      //   Algorithm:
      //   - Rounding bias: 0x7fff + 0 = 0x7fff
      //   - Adding the rounding bias to the input may or may not carry into R,
      //     depending on whether any of the T bits is set to 1.
      //   - R may be set to 1 if there is a carry.
      //   - L remains 0.
      //   - Note that this case also handles Inf and -Inf, where all fraction
      //     bits, including L, R and Ts, are 0. The output remains Inf after
      //     this algorithm.
      //
      // 2. L = 1, R = 0:
      //   Expect: round down, this is less than half value.
      //
      //   Algorithm:
      //   - Rounding bias: 0x7fff + 1 = 0x8000
      //   - Adding the rounding bias to the input doesn't change the sticky
      //     bits but adds 1 to the rounding bit.
      //   - L remains 1.
      //
      // 3. L = 0, R = 1, all of T are 0:
      //   Expect: round down, this is exactly at half, the result is already
      //   even (L = 0).
      //
      //   Algorithm:
      //   - Rounding bias: 0x7fff + 0 = 0x7fff
      //   - Adding the rounding bias to the input sets all sticky bits to 1,
      //     but doesn't create a carry.
      //   - R remains 1.
      //   - L remains 0.
      //
      // 4. L = 1, R = 1:
      //   Expect: round up, this is at least half (exactly at half when all T
      //   are 0), so the result needs to be rounded to the next even number.
      //
      //   Algorithm:
      //   - Rounding bias: 0x7fff + 1 = 0x8000
      //   - Adding the rounding bias to the input doesn't change the sticky
      //     bits, but creates a carry from the rounding bit.
      //   - The carry sets L to 0, creates another carry bit, and propagates
      //     forward into the F bits.
      //   - If all the F bits are 1, the carry then propagates to the exponent
      //     bits, producing the smallest value with the next exponent value.
      //     Note that we won't have the case where exponents are all 1, since
      //     that's either a NaN (handled in the other if condition) or Inf
      //     (handled in case 1).
      //
      // 5. L = 0, R = 1, any of T is 1:
      //   Expect: round up, this is greater than half.
      //
      //   Algorithm:
      //   - Rounding bias: 0x7fff + 0 = 0x7fff
      //   - Adding the rounding bias to the input creates a carry from the
      //     sticky bits, which sets the rounding bit to 0 and then creates
      //     another carry.
      //   - The second carry sets L to 1.
      //
      // Examples:
      //
      // Exact half value that is already even:
      //   Input:
      //     Sign |  Exp (8 bit)    | Frac (first 7 bit) | Frac (last 16 bit)
      //      S     E E E E E E E E   F F F F F F L        RTTTTTTTTTTTTTTT
      //      0     0 0 0 0 0 0 0 0   0 0 0 0 0 1 0        1000000000000000
      //
      //   This falls into case 3. We truncate away the last 16 bits and no
      //   carry is created into F and L:
      //
      //   Output:
      //     Sign |  Exp (8 bit)    | Frac (first 7 bit)
      //      S     E E E E E E E E   F F F F F F L
      //      0     0 0 0 0 0 0 0 0   0 0 0 0 0 1 0
      //
      // Exact half value, round to next even number:
      //   Input:
      //     Sign |  Exp (8 bit)    | Frac (first 7 bit) | Frac (last 16 bit)
      //      S     E E E E E E E E   F F F F F F L        RTTTTTTTTTTTTTTT
      //      0     0 0 0 0 0 0 0 0   0 0 0 0 0 0 1        1000000000000000
      //
      //   This falls into case 4. We create a carry from R and T,
      //   which then propagates into L and F:
      //
      //   Output:
      //     Sign |  Exp (8 bit)    | Frac (first 7 bit)
      //      S     E E E E E E E E   F F F F F F L
      //      0     0 0 0 0 0 0 0 0   0 0 0 0 0 1 0
      //
      //
      // Max denormal value rounds to min normal value:
      //   Input:
      //     Sign |  Exp (8 bit)    | Frac (first 7 bit) | Frac (last 16 bit)
      //      S     E E E E E E E E   F F F F F F L        RTTTTTTTTTTTTTTT
      //      0     0 0 0 0 0 0 0 0   1 1 1 1 1 1 1        1111111111111111
      //
      //   This falls into case 4. We create a carry from R and T,
      //   propagate into L and F, which then propagates into the exponent
      //   bits:
      //
      //   Output:
      //     Sign |  Exp (8 bit)    | Frac (first 7 bit)
      //      S     E E E E E E E E   F F F F F F L
      //      0     0 0 0 0 0 0 0 1   0 0 0 0 0 0 0
      //
      // Max normal value rounds to Inf:
      //   Input:
      //     Sign |  Exp (8 bit)    | Frac (first 7 bit) | Frac (last 16 bit)
      //      S     E E E E E E E E   F F F F F F L        RTTTTTTTTTTTTTTT
      //      0     1 1 1 1 1 1 1 0   1 1 1 1 1 1 1        1111111111111111
      //
      //   This falls into case 4. We create a carry from R and T,
      //   propagate into L and F, which then propagates into the exponent
      //   bits:
      //
      //   Output:
      //     Sign |  Exp (8 bit)    | Frac (first 7 bit)
      //      S     E E E E E E E E   F F F F F F L
      //      0     1 1 1 1 1 1 1 1   0 0 0 0 0 0 0
      //
      // The least significant bit of the resulting bfloat16.
      uint32_t lsb = (input >> 16) & 1;
      uint32_t rounding_bias = 0x7fff + lsb;
      input += rounding_bias;
      output.value = static_cast<uint16_t>(input >> 16);
    }
    return output;
  }
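
  // A hand-worked example of round_to_bfloat16 (not from the original
  // header): for v = 3.14159265f the bits are input = 0x40490fdb, so
  // lsb = 0x4049 & 1 = 1 and rounding_bias = 0x8000. The sum 0x40498fdb
  // produces no carry into the upper half (R = 0, round down), so
  // output.value = 0x4049, i.e. 3.140625. For v = 1 + 2^-7 + 2^-8
  // (input = 0x3f818000, an exact tie with L = 1), the bias addition
  // carries into the upper half and yields 0x3f82, the even neighbor.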

  static bfloat16 epsilon() {
    bfloat16 x;
    x.value = 0x3c00;  // 0x1.0p-7: the gap between 1.0 and the next
                       // representable value, since bfloat16 has a 7-bit
                       // fraction.
    return x;
  }

  uint16_t value;

  // A value that represents "not a number".
  static const uint16_t NAN_VALUE = 0x7FC0;

 private:
  B16_DEVICE_FUNC static bool float_isnan(const float& x) {
#ifdef __CUDA_ARCH__
    return ::isnan(x);
#else
    return std::isnan(x);
#endif
  }
};

B16_DEVICE_FUNC inline std::ostream& operator<<(std::ostream& os,
                                                const bfloat16& dt) {
  os << static_cast<float>(dt);
  return os;
}

B16_DEVICE_FUNC inline bfloat16 operator+(bfloat16 a, bfloat16 b) {
  return bfloat16(static_cast<float>(a) + static_cast<float>(b));
}
B16_DEVICE_FUNC inline bfloat16 operator+(bfloat16 a, int b) {
  return bfloat16(static_cast<float>(a) + static_cast<float>(b));
}
B16_DEVICE_FUNC inline bfloat16 operator+(int a, bfloat16 b) {
  return bfloat16(static_cast<float>(a) + static_cast<float>(b));
}
B16_DEVICE_FUNC inline bfloat16 operator-(bfloat16 a, bfloat16 b) {
  return bfloat16(static_cast<float>(a) - static_cast<float>(b));
}
B16_DEVICE_FUNC inline bfloat16 operator*(bfloat16 a, bfloat16 b) {
  return bfloat16(static_cast<float>(a) * static_cast<float>(b));
}
B16_DEVICE_FUNC inline bfloat16 operator/(bfloat16 a, bfloat16 b) {
  return bfloat16(static_cast<float>(a) / static_cast<float>(b));
}
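// Negation just flips the sign bit, so it is exact and never needs rounding;
// it maps +0/-0 and +Inf/-Inf onto each other as expected.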
B16_DEVICE_FUNC inline bfloat16 operator-(bfloat16 a) {
  a.value ^= 0x8000;
  return a;
}
B16_DEVICE_FUNC inline bool operator<(bfloat16 a, bfloat16 b) {
  return static_cast<float>(a) < static_cast<float>(b);
}
B16_DEVICE_FUNC inline bool operator<=(bfloat16 a, bfloat16 b) {
  return static_cast<float>(a) <= static_cast<float>(b);
}
B16_DEVICE_FUNC inline bool operator==(bfloat16 a, bfloat16 b) {
  return static_cast<float>(a) == static_cast<float>(b);
}
B16_DEVICE_FUNC inline bool operator!=(bfloat16 a, bfloat16 b) {
  return static_cast<float>(a) != static_cast<float>(b);
}
B16_DEVICE_FUNC inline bool operator>(bfloat16 a, bfloat16 b) {
  return static_cast<float>(a) > static_cast<float>(b);
}
B16_DEVICE_FUNC inline bool operator>=(bfloat16 a, bfloat16 b) {
  return static_cast<float>(a) >= static_cast<float>(b);
}
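// The comparisons above are evaluated in float, so they follow IEEE 754
// semantics: every ordered comparison involving a NaN is false, and
// operator!= with a NaN operand is true.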
B16_DEVICE_FUNC inline bfloat16& operator+=(bfloat16& a, bfloat16 b) {
  a = a + b;
  return a;
}
B16_DEVICE_FUNC inline bfloat16& operator-=(bfloat16& a, bfloat16 b) {
  a = a - b;
  return a;
}
B16_DEVICE_FUNC inline bfloat16 operator++(bfloat16& a) {
  a += bfloat16(1);
  return a;
}
B16_DEVICE_FUNC inline bfloat16 operator--(bfloat16& a) {
  a -= bfloat16(1);
  return a;
}
B16_DEVICE_FUNC inline bfloat16 operator++(bfloat16& a, int) {
  bfloat16 original_value = a;
  ++a;
  return original_value;
}
B16_DEVICE_FUNC inline bfloat16 operator--(bfloat16& a, int) {
  bfloat16 original_value = a;
  --a;
  return original_value;
}
B16_DEVICE_FUNC inline bfloat16& operator*=(bfloat16& a, bfloat16 b) {
  a = a * b;
  return a;
}
B16_DEVICE_FUNC inline bfloat16& operator/=(bfloat16& a, bfloat16 b) {
  a = a / b;
  return a;
}
}  // end namespace tensorflow

namespace std {
template <>
struct hash<tensorflow::bfloat16> {
  size_t operator()(const tensorflow::bfloat16& v) const {
    return hash<float>()(static_cast<float>(v));
  }
};

using tensorflow::bfloat16;
inline bool isinf(const bfloat16& a) { return std::isinf(float(a)); }
inline bool isnan(const bfloat16& a) { return std::isnan(float(a)); }
inline bool isfinite(const bfloat16& a) { return std::isfinite(float(a)); }
inline bfloat16 abs(const bfloat16& a) { return bfloat16(std::abs(float(a))); }
inline bfloat16 exp(const bfloat16& a) { return bfloat16(std::exp(float(a))); }
inline bfloat16 log(const bfloat16& a) { return bfloat16(std::log(float(a))); }
inline bfloat16 log10(const bfloat16& a) {
  return bfloat16(std::log10(float(a)));
}
inline bfloat16 sqrt(const bfloat16& a) {
  return bfloat16(std::sqrt(float(a)));
}
inline bfloat16 pow(const bfloat16& a, const bfloat16& b) {
  return bfloat16(std::pow(float(a), float(b)));
}
inline bfloat16 sin(const bfloat16& a) { return bfloat16(std::sin(float(a))); }
inline bfloat16 cos(const bfloat16& a) { return bfloat16(std::cos(float(a))); }
inline bfloat16 tan(const bfloat16& a) { return bfloat16(std::tan(float(a))); }
inline bfloat16 tanh(const bfloat16& a) {
  return bfloat16(std::tanh(float(a)));
}
inline bfloat16 floor(const bfloat16& a) {
  return bfloat16(std::floor(float(a)));
}
inline bfloat16 ceil(const bfloat16& a) {
  return bfloat16(std::ceil(float(a)));
}
}  // namespace std
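
// A minimal usage sketch (illustrative only, not part of the original
// header):
//
//   tensorflow::bfloat16 a(1.5f);  // truncating conversion; 1.5f is exact
//   tensorflow::bfloat16 b =
//       tensorflow::bfloat16::round_to_bfloat16(3.14159265f);  // 3.140625
//   float sum = static_cast<float>(a + b);  // arithmetic is done in float
//   bool ok = std::isfinite(a);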

#endif  // TENSORFLOW_CORE_LIB_BFLOAT16_BFLOAT16_H_