1 | /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | // Utilities for dealing with Literal protobufs. |
17 | |
18 | #ifndef TENSORFLOW_COMPILER_XLA_LITERAL_UTIL_H_ |
19 | #define TENSORFLOW_COMPILER_XLA_LITERAL_UTIL_H_ |
20 | |
#include <array>
#include <functional>
#include <initializer_list>
#include <iterator>
#include <memory>
#include <ostream>
#include <string>
#include <type_traits>
#include <utility>
#include <vector>
29 | |
30 | #include "tensorflow/compiler/xla/array2d.h" |
31 | #include "tensorflow/compiler/xla/array3d.h" |
32 | #include "tensorflow/compiler/xla/array4d.h" |
33 | #include "tensorflow/compiler/xla/index_util.h" |
34 | #include "tensorflow/compiler/xla/layout_util.h" |
35 | #include "tensorflow/compiler/xla/primitive_util.h" |
36 | #include "tensorflow/compiler/xla/ptr_util.h" |
37 | #include "tensorflow/compiler/xla/shape_tree.h" |
38 | #include "tensorflow/compiler/xla/shape_util.h" |
39 | #include "tensorflow/compiler/xla/sparse_index_array.h" |
40 | #include "tensorflow/compiler/xla/status_macros.h" |
41 | #include "tensorflow/compiler/xla/types.h" |
42 | #include "tensorflow/compiler/xla/util.h" |
43 | #include "tensorflow/compiler/xla/xla_data.pb.h" |
44 | #include "tensorflow/core/lib/core/bitmap.h" |
45 | #include "tensorflow/core/lib/core/status.h" |
46 | #include "tensorflow/core/lib/core/stringpiece.h" |
47 | #include "tensorflow/core/lib/gtl/array_slice.h" |
48 | #include "tensorflow/core/platform/logging.h" |
49 | #include "tensorflow/core/platform/macros.h" |
50 | #include "tensorflow/core/platform/protobuf.h" |
51 | #include "tensorflow/core/platform/types.h" |
52 | |
53 | namespace xla { |
54 | |
55 | // Class representing literal values in XLA. |
56 | // |
57 | // TODO(b/67651157): The methods in this class should be reduced to a minimal |
58 | // set of methods which construct Literals and accessor methods. Other methods |
59 | // which perform computation on Literals (Reshape, Slice, etc) should be moved |
60 | // elsewhere, and perhaps combined with evaluator code which operates on |
61 | // Literals. |
62 | class Literal { |
63 | public: |
64 | Literal() : Literal(ShapeUtil::MakeNil()) {} |
65 | |
66 | // Create a literal of the given shape. The literal is allocated sufficient |
67 | // memory to hold the shape. Memory is uninitialized. |
68 | explicit Literal(const Shape& shape); |
69 | virtual ~Literal(); |
70 | |
71 | // Literals are moveable, but not copyable. To copy a literal use |
72 | // Literal::Clone or Literal::CloneToUnique. This prevents inadvertent copies |
73 | // of literals which can be expensive. |
74 | Literal(const Literal& other) = delete; |
75 | Literal& operator=(const Literal& other) = delete; |
76 | Literal(Literal&& other); |
77 | Literal& operator=(Literal&& other); |
78 | |
79 | // Literals are equal if they have compatible shapes and the same data |
80 | // values. Layout is not compared. |
81 | bool operator==(const Literal& other) const; |
82 | bool operator!=(const Literal& other) const { return !(*this == other); } |
83 | |
84 | // Serialize to and from a proto. |
85 | static StatusOr<std::unique_ptr<Literal>> CreateFromProto( |
86 | const LiteralProto& proto); |
87 | LiteralProto ToProto() const; |
88 | |
89 | // Return the shape of the literal. |
90 | const Shape& shape() const { return shape_; } |
91 | |
92 | // TODO(b/67651157): Remove this accessor. Literal users should not be able to |
93 | // mutate the shape as this can produce malformed Literals. |
94 | Shape* mutable_shape_do_not_use() { return &shape_; } |
95 | |
96 | // Returns a (Mutable)ArraySlice view of the array for this literal for the |
97 | // given NativeT (e.g., float). CHECKs if the subshape of the literal at the |
98 | // given ShapeIndex is not array. See primitive_util.h for the mapping from |
99 | // XLA type to native type. |
100 | template <typename NativeT> |
101 | tensorflow::gtl::ArraySlice<NativeT> data( |
102 | const ShapeIndex& shape_index = {}) const; |
103 | template <typename NativeT> |
104 | tensorflow::gtl::MutableArraySlice<NativeT> data( |
105 | const ShapeIndex& shape_index = {}); |
106 | |
107 | // Returns a pointer to the sparse index array. Returns nullptr if the literal |
108 | // is not a sparse array. |
109 | const SparseIndexArray* sparse_indices( |
110 | const ShapeIndex& shape_index = {}) const; |
111 | SparseIndexArray* sparse_indices(const ShapeIndex& shape_index = {}); |
112 | |
113 | // Returns a pointer to (or size of) the underlying buffer holding the array |
114 | // at the given shape index. CHECKs if the subshape of the literal at the |
115 | // given ShapeIndex is not array. |
116 | const void* untyped_data(const ShapeIndex& shape_index = {}) const; |
117 | void* untyped_data(const ShapeIndex& shape_index = {}); |
118 | int64 size_bytes(const ShapeIndex& shape_index = {}) const; |
119 | |
120 | // Creates a new literal of a given rank. To minimize ambiguity (for users |
121 | // and the compiler) these CreateR[0-2] methods should explicitly specify the |
122 | // native type. For example: |
123 | // |
124 | // CreateR1<float>({1.0, 42.0}); |
125 | // CreateR2<uint32>({{1, 2}, {3, 4}}); |
126 | // |
127 | // The variants not ending with WithLayout use the default XLA layout for the |
128 | // literal's linear representation in memory. |
129 | template <typename NativeT> |
130 | static std::unique_ptr<Literal> CreateR0(NativeT value); |
131 | template <typename NativeT> |
132 | static std::unique_ptr<Literal> CreateR1( |
133 | tensorflow::gtl::ArraySlice<NativeT> values); |
134 | static std::unique_ptr<Literal> CreateR1( |
135 | const tensorflow::core::Bitmap& values); |
136 | template <typename NativeT> |
137 | static std::unique_ptr<Literal> CreateR2( |
138 | std::initializer_list<std::initializer_list<NativeT>> values); |
139 | template <typename NativeT> |
140 | static std::unique_ptr<Literal> CreateR2WithLayout( |
141 | std::initializer_list<std::initializer_list<NativeT>> values, |
142 | const Layout& layout); |
143 | template <typename NativeT> |
144 | static std::unique_ptr<Literal> CreateR3( |
145 | std::initializer_list< |
146 | std::initializer_list<std::initializer_list<NativeT>>> |
147 | values); |
148 | template <typename NativeT> |
149 | static std::unique_ptr<Literal> CreateR3WithLayout( |
150 | std::initializer_list< |
151 | std::initializer_list<std::initializer_list<NativeT>>> |
152 | values, |
153 | const Layout& layout); |
154 | template <typename NativeT> |
155 | static std::unique_ptr<Literal> CreateR4( |
156 | std::initializer_list<std::initializer_list< |
157 | std::initializer_list<std::initializer_list<NativeT>>>> |
158 | values); |
159 | template <typename NativeT> |
160 | static std::unique_ptr<Literal> CreateR4WithLayout( |
161 | std::initializer_list<std::initializer_list< |
162 | std::initializer_list<std::initializer_list<NativeT>>>> |
163 | values, |
164 | const Layout& layout); |
165 | |
166 | // Returns this literal's data as a string. This literal must be a rank-1 U8 |
167 | // array. |
168 | string GetR1U8AsString() const; |
169 | |
170 | // Creates a literal with a sparse layout and the given indices and values. |
171 | // The shape is initialized from the given dimensions. The minor dimension of |
172 | // the indices array must equal the rank of the shape (i.e. size of the |
173 | // dimensions array). The major dimension of the indices array must equal the |
174 | // number of elements in the values array. The maximum number of elements in |
175 | // the array is taken from the max_indices() value of the index array. |
176 | // |
177 | // XLA assumes that sparse literals are in sorted order for all operations. If |
178 | // the `sort` argument is true, then the indices and values will be sorted |
179 | // while copying them into the literal. If you have ensured that the indices |
180 | // and values are already sorted, then you may set the `sort` argument to |
181 | // false to skip the sorting step. |
182 | // |
183 | // For example: |
184 | // |
185 | // CreateSparse( |
186 | // {12, 12, 12}, |
187 | // SparseIndexArray(10, 3, |
188 | // Array2D{ |
189 | // {0, 1, 2}, |
190 | // {3, 4, 5}, |
191 | // {6, 7, 8}, |
192 | // {9, 10, 11}, |
193 | // }), |
194 | // {1.0, 2.0, 3.0, 4.0}) |
195 | // |
196 | // This creates an array with shape F64[12,12,12]sparse{10}, that has the |
197 | // following non-zero values: |
198 | // |
199 | // [0, 1, 2]: 1.0 |
200 | // [3, 4, 5]: 2.0 |
201 | // [6, 7, 8]: 3.0 |
202 | // [9, 10, 11]: 4.0 |
203 | // |
204 | template <typename NativeT> |
205 | static std::unique_ptr<Literal> CreateSparse( |
206 | tensorflow::gtl::ArraySlice<int64> dimensions, SparseIndexArray indices, |
207 | tensorflow::gtl::ArraySlice<NativeT> values, bool sort = true); |
208 | |
209 | // Populates a literal with a sparse layout with the given indices and values. |
210 | // Each index in the indices array is CHECKed against the dimensions in the |
211 | // literal's shape. If sort is true, then the indices and values will be |
212 | // sorted. If sort is false, then the indices and values are assumed to |
213 | // already be in sorted order. See CreateSparse for an example of how data |
214 | // are populated. |
215 | template <typename NativeT> |
216 | void PopulateSparse(SparseIndexArray indices, |
217 | tensorflow::gtl::ArraySlice<NativeT> values, |
218 | bool sort = true); |
219 | |
220 | // Creates a new Literal object with the shape specified as parameter. |
221 | // The content of the literal values is the default value of the primitive |
222 | // type of literal itself (0 for numeric types, and false for predicates). |
223 | static std::unique_ptr<Literal> CreateFromShape(const Shape& shape); |
224 | |
225 | // Creates a new Literal object with its values having the primitive_type |
226 | // type, and with dimensions defined by the dimensions parameter. |
227 | // The content of the literal values is the default value of the primitive |
228 | // type of literal itself (0 for numeric types, and false for predicates). |
229 | static std::unique_ptr<Literal> CreateFromDimensions( |
230 | PrimitiveType primitive_type, |
231 | tensorflow::gtl::ArraySlice<int64> dimensions); |
232 | |
233 | // Copy values from 'src_literal' rooted at 'src_shape_index' into this |
234 | // literal rooted at 'dest_shape_index'. The subshape of this literal rooted |
235 | // at 'dest_shape_index' must be compatible with the subshape of 'src_literal' |
236 | // rooted at 'src_shape_index', but need not be arrays. |
237 | Status CopyFrom(const Literal& src_literal, |
238 | const ShapeIndex& dest_shape_index = {}, |
239 | const ShapeIndex& src_shape_index = {}); |
240 | |
241 | // Similar to CopyFrom, but with move semantics. The subshape of this literal |
242 | // rooted at 'dest_shape_index' must be *equal* to the shape 'src_literal' |
243 | // (layouts and shapes must match), but need not be arrays. The memory |
244 | // allocated in this literal for the subshape at dest_shape_index is |
245 | // deallocated, and the respective buffers are replaced with those in |
246 | // src_literal. Upon return, src_literal is set to a nil shape (empty tuple). |
247 | Status MoveFrom(Literal&& src_literal, |
248 | const ShapeIndex& dest_shape_index = {}); |
249 | |
250 | // Copies the values from src_literal, starting at src_base shape indexes, |
251 | // to this literal, starting at dest_base, where the copy size in each |
252 | // dimension is specified by copy_size. |
253 | // The src_literal and this literal must have the same primitive type, |
254 | // src_base+copy_size must fit the source literal dimensions, as well as |
255 | // dest_base+copy_size must fit the destination literal dimensions. |
256 | // Note: if either src_literal or this literal contains dimensions with zero |
257 | // element, then copy_size must be 0 in these dimensions while the |
258 | // corresponding base indices being 0. |
259 | // This literal and 'src_literal' must be arrays. |
260 | Status CopySliceFrom(const Literal& src_literal, |
261 | tensorflow::gtl::ArraySlice<int64> src_base, |
262 | tensorflow::gtl::ArraySlice<int64> dest_base, |
263 | tensorflow::gtl::ArraySlice<int64> copy_size); |
264 | |
265 | // Copies one element from src_literal[src_index] to (*this)[dest_index]. |
266 | Status CopyElementFrom(const Literal& src_literal, |
267 | tensorflow::gtl::ArraySlice<int64> src_index, |
268 | tensorflow::gtl::ArraySlice<int64> dest_index); |
269 | |
270 | // Returns a vector containing the tuple elements of this Literal as separate |
271 | // Literals. This Literal must be tuple-shaped and can be a nested tuple. The |
272 | // elements are moved into the new Literals; no data is copied. Upon return |
273 | // this Literal is set to a nil shape (empty tuple). |
274 | std::vector<Literal> DecomposeTuple(); |
275 | |
276 | // This operation is the inverse of DecomposeTuple. The given elements are |
277 | // moved into the tuple elements of a new tuple-shaped Literal which is |
278 | // returned. Upon return, each of the Literals in 'elements' is set to a nil |
279 | // shape (empty tuple). |
280 | static Literal MoveIntoTuple( |
281 | tensorflow::gtl::MutableArraySlice<Literal> elements); |
282 | |
283 | // Creates a new value that has the equivalent value as this literal, but |
284 | // conforms to new_layout; e.g. a literal matrix that was in {0, 1} |
285 | // minor-to-major dimension layout can be re-layed-out as {1, 0} |
286 | // minor-to-major dimension layout and the value in the cell at any given |
287 | // logical index (i0, i1) will be the same. |
288 | // |
289 | // For tuple shaped literals, shape_index should be used to select the inner |
290 | // array that the new layout applies to. |
291 | // |
292 | // Note: this is useful when the client wants to ensure that a value placed in |
293 | // the XLA allocation tracker has a particular layout; for efficiency |
294 | // purposes or avoiding unimplemented operation/layout combinations. |
295 | std::unique_ptr<Literal> Relayout(const Layout& new_layout, |
296 | const ShapeIndex& shape_index = {}) const; |
297 | |
298 | // An overload of Relayout which changes the layout of the entire shape rather |
299 | // than being limited to a single array within the shape. |
300 | std::unique_ptr<Literal> Relayout(const Shape& shape_with_layout) const; |
301 | |
302 | // Creates a new literal by reshaping this literal to have the given |
303 | // dimensions. The total number of elements must not change. The |
304 | // implementation currently only supports monotonic dim0-major layouts. |
305 | // This literal must be an array. |
306 | StatusOr<std::unique_ptr<Literal>> Reshape( |
307 | tensorflow::gtl::ArraySlice<int64> dimensions) const; |
308 | |
309 | // Creates a new literal by reordering the dimensions of this literal. |
310 | // The given `permutation` must be a permutation of the dimension numbers |
311 | // in the original literal, and it specifies the order of the new dimensions |
312 | // in the result literal (i.e., new_order[i] = old_order[permutation[i]]). |
313 | // For example, a transpose call on a literal of shape [3 x 8 x 4] and |
314 | // `permutation` = {2, 0, 1} returns a new literal of shape [4 x 3 x 8]. |
315 | // This literal must be an array. |
316 | std::unique_ptr<Literal> Transpose( |
317 | tensorflow::gtl::ArraySlice<int64> permutation) const; |
318 | |
319 | // Creates a sub-array from this literal by extracting the indices |
320 | // [start_index, limit_index) of each dimension. The result literal has the |
321 | // same rank and layout as for the given literal. The number of indices in |
322 | // start_indices and limit_indices must be the rank of the literal, and the |
323 | // indices follow the order of the dimensions. |
324 | // This literal must be an array. |
325 | std::unique_ptr<Literal> Slice( |
326 | tensorflow::gtl::ArraySlice<int64> start_indices, |
327 | tensorflow::gtl::ArraySlice<int64> limit_indices) const; |
328 | |
329 | // Creates a literal with a prepended dimension with bound "times"; e.g. a |
330 | // f32[3x2] with times=4 will produce a f32[4x3x2] with the 3x2 from this |
331 | // literal replicated four times. |
332 | // This literal must be an array. |
333 | template <typename NativeT> |
334 | std::unique_ptr<Literal> Replicate(int64 times) const; |
335 | |
336 | // Converts this literal to another primitive type using |
337 | // static_cast<>. Returns an error if the conversion is not possible. This |
338 | // literal must be array-shaped. |
339 | StatusOr<std::unique_ptr<Literal>> Convert( |
340 | PrimitiveType primitive_dest_type) const; |
341 | |
342 | // Converts this literal to another primitive type using a bitcast |
343 | // conversion. The to and from primitive types must have the same bit |
344 | // width. Returns an error if the conversion is not possible. This literal |
345 | // must be array-shaped. |
346 | StatusOr<std::unique_ptr<Literal>> BitcastConvert( |
347 | PrimitiveType primitive_dest_type) const; |
348 | |
349 | // Converts this literal to the given shape. Returns an error if the |
350 | // conversion is not possible. |
351 | // |
352 | // round_f32_to_bf16: if true, converting F32 elements to BF16 uses rounding |
353 | // instead of truncation; otherwise, truncation is used. |
354 | // |
355 | // TODO(b/69266521): remove the round_f32_to_bf16 flag when rounding becomes |
356 | // the default behavior. |
357 | StatusOr<std::unique_ptr<Literal>> ConvertToShape( |
358 | const Shape& dest_shape, bool round_f32_to_bf16 = false) const; |
359 | |
360 | // Creates a scalar literal value zero of the given primitive type. |
361 | static Literal Zero(PrimitiveType primitive_type); |
362 | |
363 | // Creates a scalar literal value one of the given primitive type. |
364 | static Literal One(PrimitiveType primitive_type); |
365 | |
366 | // Creates a scalar literal value containing the minimum value of the given |
367 | // primitive type. For floating-point types, returns -inf. |
368 | static Literal MinValue(PrimitiveType primitive_type); |
369 | |
370 | // Creates a scalar literal value containing the maximum value of the given |
371 | // primitive type. For floating-point types, returns inf. |
372 | static Literal MaxValue(PrimitiveType primitive_type); |
373 | |
374 | // Creates a literal of the given shape where each element is `value`. |
375 | template <typename NativeT> |
376 | static std::unique_ptr<Literal> CreateFullWithDescendingLayout( |
377 | tensorflow::gtl::ArraySlice<int64> dimensions, NativeT value); |
378 | |
379 | // Creates a new literal from an Array type. The variants not ending with |
380 | // WithLayout use the default XLA layout for the literal's linear |
381 | // representation in memory. |
382 | template <typename NativeT> |
383 | static std::unique_ptr<Literal> CreateFromArray(const Array<NativeT>& values); |
384 | template <typename NativeT> |
385 | static std::unique_ptr<Literal> CreateFromArrayWithLayout( |
386 | const Array<NativeT>& values, const Layout& layout); |
387 | template <typename NativeT> |
388 | static std::unique_ptr<Literal> CreateR2FromArray2D( |
389 | const Array2D<NativeT>& values); |
390 | template <typename NativeT> |
391 | static std::unique_ptr<Literal> CreateR2FromArray2DWithLayout( |
392 | const Array2D<NativeT>& values, const Layout& layout); |
393 | template <typename NativeT> |
394 | static std::unique_ptr<Literal> CreateR3FromArray3D( |
395 | const Array3D<NativeT>& values); |
396 | template <typename NativeT> |
397 | static std::unique_ptr<Literal> CreateR3FromArray3DWithLayout( |
398 | const Array3D<NativeT>& values, const Layout& layout); |
399 | template <typename NativeT> |
400 | static std::unique_ptr<Literal> CreateR4FromArray4D( |
401 | const Array4D<NativeT>& values); |
402 | template <typename NativeT> |
403 | static std::unique_ptr<Literal> CreateR4FromArray4DWithLayout( |
404 | const Array4D<NativeT>& values, const Layout& layout); |
405 | |
406 | // Creates a new vector of U8s literal value from a string. |
407 | static std::unique_ptr<Literal> CreateR1U8(tensorflow::StringPiece value); |
408 | |
409 | // Creates a linspace-populated literal with the given number of rows and |
410 | // columns. |
411 | static std::unique_ptr<Literal> CreateR2F32Linspace(float from, float to, |
412 | int64 rows, int64 cols); |
413 | |
414 | // Creates a literal that projects the (x, y) dimensions given in values into |
415 | // the z dimension given by "projection". |
416 | template <typename NativeT> |
417 | static std::unique_ptr<Literal> CreateR3Projected( |
418 | std::initializer_list<std::initializer_list<NativeT>> values, |
419 | int64 projection); |
420 | |
421 | // Creates a literal that projects the (x, y) dimensions given in values into |
422 | // the z and p dimensions given. |
423 | template <typename NativeT> |
424 | static std::unique_ptr<Literal> CreateR4Projected( |
425 | std::initializer_list<std::initializer_list<NativeT>> values, |
426 | int64 projection_p, int64 projection_z); |
427 | |
428 | // Clones this literal into a new Literal, or new std::unique_ptr<Literal>. |
429 | Literal Clone() const; |
430 | std::unique_ptr<Literal> CloneToUnique() const; |
431 | |
432 | // Gets or sets an element in the literal at the given index. The multi_index |
433 | // is CHECKed against the dimension sizes. |
434 | template <typename NativeT> |
435 | NativeT Get(tensorflow::gtl::ArraySlice<int64> multi_index, |
436 | const ShapeIndex& shape_index) const; |
437 | template <typename NativeT> |
438 | void Set(tensorflow::gtl::ArraySlice<int64> multi_index, |
439 | const ShapeIndex& shape_index, NativeT value); |
440 | |
441 | // Overloads of Get and Set for array literals. CHECKs if the literal is not |
442 | // array-shaped and dense. |
443 | template <typename NativeT> |
444 | NativeT Get(tensorflow::gtl::ArraySlice<int64> multi_index) const; |
445 | template <typename NativeT> |
446 | void Set(tensorflow::gtl::ArraySlice<int64> multi_index, NativeT value); |
447 | |
448 | // Returns the multi-index of the element in a sparse literal at the given |
449 | // sparse element number. The sparse element number is the position within |
450 | // the sparse array's list of (index, value) pairs, and is checked against the |
451 | // total number of (index, value) pairs in the sparse array. |
452 | tensorflow::gtl::ArraySlice<int64> GetSparseIndex( |
453 | int64 sparse_element_number, const ShapeIndex& shape_index = {}) const; |
454 | |
455 | // Returns the value of the element in a sparse literal at the given sparse |
456 | // element number. The sparse element number is the position within the |
457 | // sparse array's list of (index, value) pairs, and is checked against the |
458 | // total number of (index, value) pairs in the sparse array. |
459 | template <typename NativeT> |
460 | NativeT GetSparseElement(int64 sparse_element_number, |
461 | const ShapeIndex& shape_index = {}) const; |
462 | |
463 | // Appends the given element to the literal. If the elements are not appended |
464 | // in sorted order, then SortSparseElements should be called before calling |
465 | // other methods. This literal must have a sparse layout. |
466 | template <typename NativeT> |
467 | void AppendSparseElement(tensorflow::gtl::ArraySlice<int64> multi_index, |
468 | NativeT value, const ShapeIndex& shape_index = {}); |
469 | |
470 | // Sorts the elements in a sparse array. |
471 | void SortSparseElements(const ShapeIndex& shape_index = {}); |
472 | |
473 | // Returns the element value at index (0, ..., 0), however many zeroes are |
474 | // required for that index. |
475 | template <typename NativeT> |
476 | NativeT GetFirstElement() const; |
477 | |
478 | // Returns a literal scalar representing the first element. |
479 | Literal GetFirstScalarLiteral() const; |
480 | |
481 | // As Get(), but determines the correct type and converts the value |
482 | // into text. |
483 | string GetAsString(tensorflow::gtl::ArraySlice<int64> multi_index, |
484 | const ShapeIndex& shape_index = {}) const; |
485 | |
486 | // As GetSparseElement(), but determines the correct type and converts the |
487 | // value into text. |
488 | string GetSparseElementAsString(int64 sparse_element_number, |
489 | const ShapeIndex& shape_index = {}) const; |
490 | |
491 | // As Get(), but determines the correct type and converts the value into |
492 | // int64. This literal must be an array. |
493 | StatusOr<int64> GetIntegralAsS64( |
494 | tensorflow::gtl::ArraySlice<int64> multi_index) const; |
495 | |
496 | // As Set(), but truncates `value` to the literal element type before storing. |
497 | // This literal must be an array. |
498 | Status SetIntegralAsS64(tensorflow::gtl::ArraySlice<int64> multi_index, |
499 | int64 value); |
500 | |
501 | // Returns an identity matrix (rank 2) with the given row and column count. |
502 | template <typename NativeT> |
503 | static std::unique_ptr<Literal> MakeIdentityR2(int64 size); |
504 | |
505 | // Returns a tuple literal composed of given literals. Data is copied from the |
506 | // given elements into the returned literal. |
507 | static std::unique_ptr<Literal> MakeTuple( |
508 | tensorflow::gtl::ArraySlice<const Literal*> elements); |
509 | |
510 | // As above, but intended to be invoked with move semantics; i.e. |
511 | // |
512 | // std::vector<std::unique_ptr<Literal>> elements = ...; |
513 | // auto result = Literal::MakeTupleOwned(std::move(elements)); |
514 | // |
515 | // This would have been declared as an overload, but there is ambiguity |
516 | // in invocation between the above signature and this one. |
517 | static std::unique_ptr<Literal> MakeTupleOwned( |
518 | std::vector<std::unique_ptr<Literal>> elements); |
519 | |
520 | // This overload lets you pass a braced list of unique_ptr<Literal>s to |
521 | // MakeTupleOwned: |
522 | // |
523 | // Literal::MakeTupleOwned(Literal::CreateR1(...), ...). |
524 | // |
525 | // Simply relying on the MakeTupleOwned(std::vector<unique_ptr<Literal>>) |
526 | // overload doesn't work because std::initializer_list's elements are always |
527 | // const. |
528 | // |
529 | // The arguments to this function must all be unique_ptr<Literal>. |
530 | template <typename... Ts> |
531 | static std::unique_ptr<Literal> MakeTupleOwned( |
532 | std::unique_ptr<Ts>... elements) { |
533 | std::array<std::unique_ptr<Literal>, sizeof...(Ts)> arr{ |
534 | std::move(elements)...}; |
535 | std::vector<std::unique_ptr<Literal>> v; |
536 | v.insert(v.begin(), std::make_move_iterator(arr.begin()), |
537 | std::make_move_iterator(arr.end())); |
538 | return MakeTupleOwned(std::move(v)); |
539 | } |
540 | |
541 | // Returns a string representation of the literal value. |
542 | // Warning: this function can take minutes for multi-million element Literals. |
543 | string ToString(bool print_layout = false) const; |
544 | |
545 | // Invokes the "per cell" callback for each element in the provided |
546 | // literal with the element's indices and a string representation of |
547 | // the element's value. |
548 | // |
549 | // This function is useful if you want a polymorphic representation |
550 | // of the tensor's elements (turning it to a string for something |
551 | // like representation in a protobuf). |
552 | // |
553 | // This literal must have a dense layout. |
554 | void EachCellAsString( |
555 | const std::function<void(tensorflow::gtl::ArraySlice<int64> indices, |
556 | const string& value)>& per_cell) const; |
557 | template <typename NativeT> |
558 | void EachCell(std::function<void(tensorflow::gtl::ArraySlice<int64> indices, |
559 | NativeT value)> |
560 | per_cell) const; |
561 | |
562 | // Populate this literal with the given values. Examples: |
563 | // |
564 | // // Populate with floats. |
565 | // Array2D<float> float_values = ... |
566 | // literal.PopulateR2FromArray2D(values); |
567 | // |
568 | // // Populate with int32s. |
569 | // literal.PopulateR2<int32>({{1, 2}, {3, 4}}); |
570 | // |
571 | // The shape and element type of this literal must match given values. For |
572 | // example, in the call above to literal.PopulateR2(), 'literal' must be a 2x2 |
573 | // array of S32. |
574 | template <typename NativeT> |
575 | void PopulateR1(tensorflow::gtl::ArraySlice<NativeT> values); |
576 | void PopulateR1(const tensorflow::core::Bitmap& values); |
577 | template <typename NativeT> |
578 | void PopulateR2(std::initializer_list<std::initializer_list<NativeT>> values); |
579 | template <typename NativeT> |
580 | void PopulateFromArray(const Array<NativeT>& values); |
581 | template <typename NativeT> |
582 | void PopulateR2FromArray2D(const Array2D<NativeT>& values); |
583 | template <typename NativeT> |
584 | void PopulateR3FromArray3D(const Array3D<NativeT>& values); |
585 | template <typename NativeT> |
586 | void PopulateR4FromArray4D(const Array4D<NativeT>& values); |
587 | |
588 | // Populates literal values by calling the generator function for every cell |
589 | // in this literal object. |
590 | // |
591 | // generator must be a callable of the type |
592 | // NativeT(tensorflow::gtl::ArraySlice<int64> indexes) or compatible. |
593 | // |
594 | // This literal must have a dense layout. |
595 | template <typename NativeT, typename FnType> |
596 | Status Populate(const FnType& generator); |
597 | |
598 | // A parallel version of Populate(). This can be used if the generator is |
599 | // thread-safe and the values for the shape's different elements are |
600 | // independent. |
601 | template <typename NativeT, typename FnType> |
602 | Status PopulateParallel(const FnType& generator); |
603 | |
604 | // Fills this literal with the given value. |
605 | template <typename NativeT> |
606 | void PopulateWithValue(NativeT value); |
607 | |
608 | // Returns whether every element in this literal is equal to value. |
609 | // |
610 | // value is an int8 because we expect this to be called with small |
611 | // compile-time constants (0, -1, etc.) and so that whatever value you pass |
612 | // can be represented exactly by floating-point types as small as 16 bits. |
613 | // |
614 | // If value doesn't fit in this literal's type, returns false. Values of 1/0 |
615 | // are considered equal to true/false; other values are not considered equal |
616 | // to true. Also if this literal is not array-shaped false is returned. |
617 | bool IsAll(int8 value) const; |
618 | |
619 | // Like IsAll(int8), except we check whether the literal is |
620 | // equal to a particular floating-point number. |
621 | // |
622 | // If the literal is not a floating-point value, this always returns false. |
623 | // |
624 | // This casts value to the type of literal, then compares using ==. The usual |
625 | // admonishments about floating-point equality checks apply. We expect you to |
626 | // use this to check for values that can be expressed precisely as a float, |
627 | // e.g. -0.5. Also if this literal is not array-shaped false is returned. |
628 | bool IsAllFloat(float value) const; |
629 | |
630 | // Like IsAll(int8), except we check whether the literal is |
631 | // equal to a particular complex number. |
632 | // |
633 | // If the literal is not a complex value, this always returns false. |
634 | // |
635 | // This casts value to the type of literal, then compares using ==. The usual |
636 | // admonishments about floating-point equality checks apply. We expect you to |
637 | // use this to check for complex values that can be expressed precisely as |
638 | // float pairs e.g. (-0.5, 1.0). |
639 | // |
640 | // This literal must have a dense layout. |
641 | bool IsAllComplex(complex64 value) const; |
642 | |
// Returns whether this literal consists entirely of the first element of the
// literal, i.e. whether every element is equal to the first one.
bool IsAllFirst() const;
645 | |
// Returns whether this literal is zero at the specified index. 'indices' is
// the multi-dimensional index of the element to test. This literal must be an
// array with a dense layout.
bool IsZero(tensorflow::gtl::ArraySlice<int64> indices) const;
649 | |
650 | // Return the count of the elements in the array at the given shape index in |
651 | // this literal. |
652 | int64 element_count(const ShapeIndex& index = {}) const { |
653 | return ShapeUtil::ElementsIn(ShapeUtil::GetSubshape(shape(), index)); |
654 | } |
655 | |
// Returns the count of the elements currently stored in this literal's sparse
// array, which will be no larger than
// LayoutUtil::MaxSparseElements(shape().layout()).
// NOTE(review): this overload takes no ShapeIndex, so it presumably refers to
// the root array of the literal -- confirm against the definition.
int64 sparse_element_count() const;
660 | |
661 | protected: |
// Constructs a literal of the given shape.
//
// 'allocate_arrays' indicates whether to allocate memory for the arrays in
// the shape. If false, buffer pointers inside of the Literal::Pieces are set
// to nullptr.
Literal(const Shape& shape, bool allocate_arrays);
666 | |
// Internal template helper for Literal::CopySliceFrom(), matching its
// arguments one by one. NativeT is the native element type of the literals
// involved. 'src_base' and 'dest_base' are multi-indices into 'src_literal'
// and this literal respectively, and 'copy_size' gives the per-dimension
// extent of the slice.
// NOTE(review): the parameter meanings above are inferred from their names;
// confirm against the CopySliceFrom() documentation.
template <typename NativeT>
Status CopySliceFromInternal(const Literal& src_literal,
                             tensorflow::gtl::ArraySlice<int64> src_base,
                             tensorflow::gtl::ArraySlice<int64> dest_base,
                             tensorflow::gtl::ArraySlice<int64> copy_size);
674 | |
// Utility structure which is used to create the optimal configuration for
// a ShapeUtil::ForEachIndex() scan across two literals.
struct StrideConfig {
  // Builds the scan configuration for copying 'dimensions' elements between
  // arrays of shape 'source_shape' and 'dest_shape'.
  StrideConfig(const Shape& source_shape, const Shape& dest_shape,
               tensorflow::gtl::ArraySlice<int64> dimensions);

  // The dimensions of the stride operation. Essentially every dimension
  // will be iterated from base[i] to base[i]+dimensions[i], in step[i]
  // steps.
  // NOTE(review): 'dimensions' is a slice rather than an owning container, so
  // presumably the caller must keep the underlying storage alive -- confirm.
  tensorflow::gtl::ArraySlice<int64> dimensions;
  // Per-dimension starting index of the scan.
  DimensionVector base;
  // Per-dimension increment of the scan.
  DimensionVector step;
  // Index of the dimension that is walked by the inner (most minor) loop.
  int64 minor_dimension = 0;
  // The size of the strides for source and destination. One of the two
  // (the one looping through its most minor dimension) will be 1, while
  // the other will be the stride size at the dimension matching the other
  // shape most minor dimension being scanned.
  int64 dest_stride = 1;
  int64 source_stride = 1;
  // The size of the inner loop on the most minor dimension.
  int64 minor_loop_size = 1;
};
697 | |
// A data structure representing a subshape at a particular ShapeIndex within
// the literal. For array-shaped ShapeIndexes, this data structure holds the
// pointer to the memory allocated for the array data.
//
// A Piece only stores raw pointers (buffer, sparse indices, subshape); the
// setters below simply record the pointer they are given.
class Piece {
 public:
  // Return the buffer holding the array data for this piece as an array
  // slice. This piece must be array-shaped, and NativeT must correspond to
  // the element type of the subshape (both are CHECKed).
  template <typename NativeT>
  tensorflow::gtl::ArraySlice<NativeT> data() const;
  template <typename NativeT>
  tensorflow::gtl::MutableArraySlice<NativeT> data();

  // Return the buffer holding the array data for this piece as a void*. This
  // piece must be array-shaped.
  void* untyped_data();
  const void* untyped_data() const;

  // Gets or sets an element in the array at the given index. The multi_index
  // is CHECKed against the dimension sizes of the array. This piece must be
  // array-shaped with a dense layout.
  template <typename NativeT>
  NativeT Get(tensorflow::gtl::ArraySlice<int64> index) const;
  template <typename NativeT>
  void Set(tensorflow::gtl::ArraySlice<int64> index, NativeT value);

  // Gets/sets the buffer holding the array data. Note that the getter hands
  // out a mutable pointer even through a const Piece.
  char* buffer() const { return buffer_; }
  void set_buffer(char* buffer) { buffer_ = buffer; }

  // The array of multi-indices that provide the locations of non-zero
  // elements in a sparse array. Only used if
  // LayoutUtil::IsSparseArray(shape()) is true.
  SparseIndexArray* sparse_indices() const { return sparse_indices_; }
  void set_sparse_indices(SparseIndexArray* sparse_indices) {
    sparse_indices_ = sparse_indices;
  }

  // Gets or sets the subshape of this piece. This reference points to a
  // subshape within the shape in the containing Literal (Literal::shape_).
  // subshape() dereferences the stored pointer, so set_subshape() must have
  // been called with a non-null pointer first.
  const Shape& subshape() const { return *subshape_; }
  void set_subshape(const Shape* subshape) { subshape_ = subshape; }

  // Returns the size in bytes of the buffer holding the array data.
  int64 size_bytes() const { return ShapeUtil::ByteSizeOf(subshape()); }

  // Returns the number of elements in this piece's array.
  int64 element_count() const { return ShapeUtil::ElementsIn(subshape()); }

  // Copy the data from 'src' into this piece's buffer. Shapes of this piece
  // and src must be compatible.
  Status CopyFrom(const Piece& src);

  // Returns true if this piece and 'other' contain the same data. This piece
  // and 'other' must be array-shaped and compatible.
  bool EqualElements(const Piece& other) const;

  // Writes the shape and data (if array-shaped) into the given proto.
  void WriteToProto(LiteralProto* proto) const;

  // Copies the data from the given proto into this piece. The shape of this
  // piece must be equal (not just compatible) to the shape of the proto.
  Status CopyFromProto(const LiteralProto& proto);

  // Sorts the elements in a sparse array.
  void SortSparseElements();

 private:
  // Recursive helper for EqualElements. 'multi_index' holds the index prefix
  // built up so far by the recursion.
  template <typename NativeT>
  bool EqualElementsInternal(const Piece& other,
                             std::vector<int64>* multi_index) const;

  // Helper for SortSparseElements that has the element type as a template
  // parameter.
  template <typename NativeT>
  void SortSparseElementsInternal();

  // For array-shaped pieces, this is the buffer holding the literal data.
  char* buffer_ = nullptr;

  // For sparse arrays, this is the array of indices.
  SparseIndexArray* sparse_indices_ = nullptr;

  // The shape of piece. This points into the shape of the containing Literal
  // (Literal::shape_).
  const Shape* subshape_ = nullptr;
};
785 | |
786 | // Returns the piece at the given ShapeIndex. |
787 | Piece& piece(const ShapeIndex& shape_index) { |
788 | return *pieces_.mutable_element(shape_index); |
789 | } |
790 | const Piece& piece(const ShapeIndex& shape_index) const { |
791 | return pieces_.element(shape_index); |
792 | } |
793 | |
794 | // Returns the piece at the root of the shape (empty ShapeIndex). |
795 | Piece& root_piece() { return piece({}); } |
796 | const Piece& root_piece() const { return piece({}); } |
797 | |
// Deallocates the buffers held by this literal, but only if this literal owns
// them (see owns_buffers_).
void DeallocateBuffers();
801 | |
// Implementation details shared between Populate() and PopulateParallel().
// 'generator' is called with the multi-index of each element and its result
// is stored at that index; 'parallel' selects
// ShapeUtil::ForEachIndexParallel() over ShapeUtil::ForEachIndex() for the
// scan. The literal must be a dense array whose element type corresponds to
// NativeT, otherwise an error Status is returned.
template <typename NativeT, typename FnType>
Status PopulateInternal(const FnType& generator, bool parallel);
805 | |
// The shape of this literal. The subshape pointers stored in the Pieces point
// into this object.
Shape shape_;
// The Pieces of this literal, indexed by ShapeIndex within shape_.
ShapeTree<Piece> pieces_;

// Whether the buffers held in pieces_ are owned by this Literal.
bool owns_buffers_;

// LiteralView must access and manipulate Pieces of other Literals.
friend class LiteralView;
};  // class Literal
815 | |
// Stream-insertion operator for Literal; writes 'literal' to 'out' and returns
// the stream.
// NOTE(review): the exact textual format is defined out of line -- presumably
// the literal's string form; confirm in the implementation file.
std::ostream& operator<<(std::ostream& out, const Literal& literal);
817 | |
// A read-only view of a Literal. A LiteralView contains pointers to buffers
// owned by the viewed Literal.
// NOTE(review): because the view only points at the viewed Literal's buffers,
// it presumably must not outlive that Literal -- confirm.
//
// TODO(b/71550060): Replace LiteralView with Literal slice classes (immutable
// and mutable) similar to (Mutable)ArraySlice.
class LiteralView : public Literal {
 public:
  // Create and return a view of the given literal rooted at the given shape
  // index within the given literal. A factory is used rather than a public
  // constructor because only const LiteralViews are supported. It's still
  // possible to create non-const LiteralViews via the copy constructors, but
  // the factory method makes it a bit less likely. Implementing literal slices
  // will fix this undesirable situation (b/71550060).
  static const LiteralView Create(const Literal& literal,
                                  const ShapeIndex& view_root = {});

  // Copying a view yields another view; both operations are implemented in
  // terms of the private CopyFrom() helper.
  LiteralView(const LiteralView& other);
  LiteralView& operator=(const LiteralView& other);

  virtual ~LiteralView();

 private:
  // Builds a view of 'literal' rooted at 'view_root'. Private so that views
  // are obtained through Create().
  LiteralView(const Literal& literal, const ShapeIndex& view_root);

  // Helper for the copy constructor and copy assignment operator.
  void CopyFrom(const LiteralView& other);
};
845 | |
846 | template <typename NativeT> |
847 | tensorflow::gtl::ArraySlice<NativeT> Literal::Piece::data() const { |
848 | CHECK(ShapeUtil::IsArray(subshape())) << ShapeUtil::HumanString(subshape()); |
849 | CHECK_EQ(subshape().element_type(), |
850 | primitive_util::NativeToPrimitiveType<NativeT>()) |
851 | << "Attempting to access " |
852 | << PrimitiveType_Name(primitive_util::NativeToPrimitiveType<NativeT>()) |
853 | << " type, but literal element type is " |
854 | << PrimitiveType_Name(subshape().element_type()); |
855 | return tensorflow::gtl::ArraySlice<NativeT>( |
856 | reinterpret_cast<const NativeT*>(buffer()), |
857 | ShapeUtil::ElementsIn(subshape())); |
858 | } |
859 | |
860 | template <typename NativeT> |
861 | tensorflow::gtl::MutableArraySlice<NativeT> Literal::Piece::data() { |
862 | CHECK(ShapeUtil::IsArray(subshape())) << ShapeUtil::HumanString(subshape()); |
863 | CHECK_EQ(subshape().element_type(), |
864 | primitive_util::NativeToPrimitiveType<NativeT>()) |
865 | << "Attempting to access " |
866 | << PrimitiveType_Name(primitive_util::NativeToPrimitiveType<NativeT>()) |
867 | << " type, but literal element type is " |
868 | << PrimitiveType_Name(subshape().element_type()); |
869 | return tensorflow::gtl::MutableArraySlice<NativeT>( |
870 | reinterpret_cast<NativeT*>(buffer()), ShapeUtil::ElementsIn(subshape())); |
871 | } |
872 | |
873 | template <typename NativeT> |
874 | NativeT Literal::Piece::Get( |
875 | tensorflow::gtl::ArraySlice<int64> multi_index) const { |
876 | CHECK(LayoutUtil::IsDenseArray(subshape())); |
877 | return data<NativeT>()[IndexUtil::MultidimensionalIndexToLinearIndex( |
878 | subshape(), multi_index)]; |
879 | } |
880 | |
881 | template <typename NativeT> |
882 | void Literal::Piece::Set(tensorflow::gtl::ArraySlice<int64> multi_index, |
883 | NativeT value) { |
884 | CHECK(LayoutUtil::IsDenseArray(subshape())); |
885 | data<NativeT>()[IndexUtil::MultidimensionalIndexToLinearIndex( |
886 | subshape(), multi_index)] = value; |
887 | } |
888 | |
889 | template <typename NativeT> |
890 | tensorflow::gtl::ArraySlice<NativeT> Literal::data( |
891 | const ShapeIndex& shape_index) const { |
892 | return piece(shape_index).data<NativeT>(); |
893 | } |
894 | |
895 | template <typename NativeT> |
896 | tensorflow::gtl::MutableArraySlice<NativeT> Literal::data( |
897 | const ShapeIndex& shape_index) { |
898 | return piece(shape_index).data<NativeT>(); |
899 | } |
900 | |
901 | template <typename NativeT> |
902 | inline NativeT Literal::Get(tensorflow::gtl::ArraySlice<int64> multi_index, |
903 | const ShapeIndex& shape_index) const { |
904 | return piece(shape_index).Get<NativeT>(multi_index); |
905 | } |
906 | |
907 | template <typename NativeT> |
908 | inline NativeT Literal::Get( |
909 | tensorflow::gtl::ArraySlice<int64> multi_index) const { |
910 | return root_piece().Get<NativeT>(multi_index); |
911 | } |
912 | |
913 | template <typename NativeT> |
914 | inline void Literal::Set(tensorflow::gtl::ArraySlice<int64> multi_index, |
915 | const ShapeIndex& shape_index, NativeT value) { |
916 | return piece(shape_index).Set<NativeT>(multi_index, value); |
917 | } |
918 | |
919 | template <typename NativeT> |
920 | inline void Literal::Set(tensorflow::gtl::ArraySlice<int64> multi_index, |
921 | NativeT value) { |
922 | return root_piece().Set<NativeT>(multi_index, value); |
923 | } |
924 | |
925 | template <typename NativeT> |
926 | /* static */ std::unique_ptr<Literal> Literal::CreateR0(NativeT value) { |
927 | auto literal = MakeUnique<Literal>(ShapeUtil::MakeShape( |
928 | primitive_util::NativeToPrimitiveType<NativeT>(), {})); |
929 | literal->Set({}, value); |
930 | return literal; |
931 | } |
932 | |
933 | template <typename NativeT> |
934 | /* static */ std::unique_ptr<Literal> Literal::CreateR1( |
935 | tensorflow::gtl::ArraySlice<NativeT> values) { |
936 | auto literal = MakeUnique<Literal>( |
937 | ShapeUtil::MakeShape(primitive_util::NativeToPrimitiveType<NativeT>(), |
938 | {static_cast<int64>(values.size())})); |
939 | literal->PopulateR1(values); |
940 | return literal; |
941 | } |
942 | |
943 | template <typename NativeT> |
944 | /* static */ std::unique_ptr<Literal> Literal::CreateR2WithLayout( |
945 | std::initializer_list<std::initializer_list<NativeT>> values, |
946 | const Layout& layout) { |
947 | auto literal = MakeUnique<Literal>(ShapeUtil::MakeShapeWithLayout( |
948 | primitive_util::NativeToPrimitiveType<NativeT>(), |
949 | {static_cast<int64>(values.size()), |
950 | static_cast<int64>(values.begin()->size())}, |
951 | AsInt64Slice(layout.minor_to_major()))); |
952 | literal->PopulateR2(values); |
953 | return literal; |
954 | } |
955 | |
956 | template <typename NativeT> |
957 | /* static */ std::unique_ptr<Literal> Literal::CreateR2( |
958 | std::initializer_list<std::initializer_list<NativeT>> values) { |
959 | return CreateR2WithLayout(values, LayoutUtil::GetDefaultLayoutForR2()); |
960 | } |
961 | |
962 | template <typename NativeT> |
963 | /* static */ std::unique_ptr<Literal> Literal::CreateR3WithLayout( |
964 | std::initializer_list<std::initializer_list<std::initializer_list<NativeT>>> |
965 | values, |
966 | const Layout& layout) { |
967 | const int64 d0 = values.size(); |
968 | const int64 d1 = values.begin()->size(); |
969 | const int64 d2 = values.begin()->begin()->size(); |
970 | Array3D<NativeT> tmp(d0, d1, d2); |
971 | int64 i0 = 0; |
972 | for (auto d1_values : values) { |
973 | int64 i1 = 0; |
974 | for (auto d2_values : d1_values) { |
975 | int64 i2 = 0; |
976 | for (auto value : d2_values) { |
977 | tmp(i0, i1, i2) = value; |
978 | ++i2; |
979 | } |
980 | ++i1; |
981 | } |
982 | ++i0; |
983 | } |
984 | return CreateR3FromArray3DWithLayout(tmp, layout); |
985 | } |
986 | |
987 | template <typename NativeT> |
988 | /* static */ std::unique_ptr<Literal> Literal::CreateR3( |
989 | std::initializer_list<std::initializer_list<std::initializer_list<NativeT>>> |
990 | values) { |
991 | return CreateR3WithLayout(values, LayoutUtil::GetDefaultLayoutForR3()); |
992 | } |
993 | |
994 | template <typename NativeT> |
995 | /* static */ std::unique_ptr<Literal> Literal::CreateR4WithLayout( |
996 | std::initializer_list<std::initializer_list< |
997 | std::initializer_list<std::initializer_list<NativeT>>>> |
998 | values, |
999 | const Layout& layout) { |
1000 | const int64 d0 = values.size(); |
1001 | const int64 d1 = values.begin()->size(); |
1002 | const int64 d2 = values.begin()->begin()->size(); |
1003 | const int64 d3 = values.begin()->begin()->begin()->size(); |
1004 | Array4D<NativeT> tmp(d0, d1, d2, d3); |
1005 | int64 i0 = 0; |
1006 | for (auto d1_values : values) { |
1007 | int64 i1 = 0; |
1008 | for (auto d2_values : d1_values) { |
1009 | int64 i2 = 0; |
1010 | for (auto d3_values : d2_values) { |
1011 | int64 i3 = 0; |
1012 | for (auto value : d3_values) { |
1013 | tmp(i0, i1, i2, i3) = value; |
1014 | ++i3; |
1015 | } |
1016 | ++i2; |
1017 | } |
1018 | ++i1; |
1019 | } |
1020 | ++i0; |
1021 | } |
1022 | return CreateR4FromArray4DWithLayout(tmp, layout); |
1023 | } |
1024 | |
1025 | template <typename NativeT> |
1026 | /* static */ std::unique_ptr<Literal> Literal::CreateSparse( |
1027 | tensorflow::gtl::ArraySlice<int64> dimensions, SparseIndexArray indices, |
1028 | tensorflow::gtl::ArraySlice<NativeT> values, bool sort) { |
1029 | int64 num_elements = values.size(); |
1030 | int64 rank = dimensions.size(); |
1031 | CHECK_EQ(num_elements, indices.index_count()); |
1032 | CHECK_EQ(rank, indices.rank()); |
1033 | auto literal = MakeUnique<Literal>(ShapeUtil::MakeShapeWithSparseLayout( |
1034 | primitive_util::NativeToPrimitiveType<NativeT>(), dimensions, |
1035 | indices.max_indices())); |
1036 | literal->PopulateSparse(indices, values, sort); |
1037 | return literal; |
1038 | } |
1039 | |
1040 | template <typename NativeT> |
1041 | /* static */ std::unique_ptr<Literal> Literal::CreateR4( |
1042 | std::initializer_list<std::initializer_list< |
1043 | std::initializer_list<std::initializer_list<NativeT>>>> |
1044 | values) { |
1045 | return CreateR4WithLayout(values, LayoutUtil::GetDefaultLayoutForR4()); |
1046 | } |
1047 | |
1048 | template <typename NativeT> |
1049 | /* static */ std::unique_ptr<Literal> Literal::CreateFromArrayWithLayout( |
1050 | const Array<NativeT>& values, const Layout& layout) { |
1051 | auto literal = MakeUnique<Literal>(ShapeUtil::MakeShapeWithLayout( |
1052 | primitive_util::NativeToPrimitiveType<NativeT>(), values.dimensions(), |
1053 | AsInt64Slice(layout.minor_to_major()))); |
1054 | literal->PopulateFromArray(values); |
1055 | return literal; |
1056 | } |
1057 | |
1058 | template <typename NativeT> |
1059 | /* static */ std::unique_ptr<Literal> Literal::CreateFromArray( |
1060 | const Array<NativeT>& values) { |
1061 | return CreateFromArrayWithLayout( |
1062 | values, LayoutUtil::GetDefaultLayoutForRank(values.num_dimensions())); |
1063 | } |
1064 | |
1065 | template <typename NativeT> |
1066 | /* static */ std::unique_ptr<Literal> Literal::CreateR2FromArray2DWithLayout( |
1067 | const Array2D<NativeT>& values, const Layout& layout) { |
1068 | return CreateFromArrayWithLayout(values, layout); |
1069 | } |
1070 | |
1071 | template <typename NativeT> |
1072 | /* static */ std::unique_ptr<Literal> Literal::CreateR2FromArray2D( |
1073 | const Array2D<NativeT>& values) { |
1074 | return CreateFromArray(values); |
1075 | } |
1076 | |
1077 | template <typename NativeT> |
1078 | /* static */ std::unique_ptr<Literal> Literal::CreateR3FromArray3DWithLayout( |
1079 | const Array3D<NativeT>& values, const Layout& layout) { |
1080 | return CreateFromArrayWithLayout(values, layout); |
1081 | } |
1082 | |
1083 | template <typename NativeT> |
1084 | /* static */ std::unique_ptr<Literal> Literal::CreateR3FromArray3D( |
1085 | const Array3D<NativeT>& values) { |
1086 | return CreateFromArray(values); |
1087 | } |
1088 | |
1089 | template <typename NativeT> |
1090 | /* static */ std::unique_ptr<Literal> Literal::CreateR3Projected( |
1091 | std::initializer_list<std::initializer_list<NativeT>> values, |
1092 | int64 projection) { |
1093 | int64 dim0_size = projection; |
1094 | int64 dim1_size = values.size(); |
1095 | int64 dim2_size = values.begin()->size(); |
1096 | |
1097 | Array3D<NativeT> array(dim0_size, dim1_size, dim2_size); |
1098 | for (int64 dim0 = 0; dim0 < dim0_size; ++dim0) { |
1099 | int64 dim1 = 0; |
1100 | for (auto inner_list : values) { |
1101 | int64 dim2 = 0; |
1102 | for (auto value : inner_list) { |
1103 | array(dim0, dim1, dim2) = value; |
1104 | ++dim2; |
1105 | } |
1106 | CHECK_EQ(dim2_size, dim2); |
1107 | ++dim1; |
1108 | } |
1109 | CHECK_EQ(dim1_size, dim1); |
1110 | } |
1111 | return CreateR3FromArray3D(array); |
1112 | } |
1113 | |
1114 | template <typename NativeT> |
1115 | /* static */ std::unique_ptr<Literal> Literal::CreateR4Projected( |
1116 | std::initializer_list<std::initializer_list<NativeT>> values, |
1117 | int64 projection_p, int64 projection_z) { |
1118 | int64 dim0_size = projection_p; |
1119 | int64 dim1_size = projection_z; |
1120 | int64 dim2_size = values.size(); |
1121 | int64 dim3_size = values.begin()->size(); |
1122 | |
1123 | Array4D<NativeT> array(dim0_size, dim1_size, dim2_size, dim3_size); |
1124 | for (int64 dim0 = 0; dim0 < dim0_size; ++dim0) { |
1125 | for (int64 dim1 = 0; dim1 < dim1_size; ++dim1) { |
1126 | int64 dim2 = 0; |
1127 | for (auto inner_list : values) { |
1128 | int64 dim3 = 0; |
1129 | for (auto value : inner_list) { |
1130 | array(dim0, dim1, dim2, dim3) = value; |
1131 | ++dim3; |
1132 | } |
1133 | CHECK_EQ(dim3_size, dim3); |
1134 | ++dim2; |
1135 | } |
1136 | CHECK_EQ(dim2_size, dim2); |
1137 | } |
1138 | } |
1139 | return CreateR4FromArray4D(array); |
1140 | } |
1141 | |
1142 | template <typename NativeT> |
1143 | /* static */ std::unique_ptr<Literal> Literal::CreateR4FromArray4D( |
1144 | const Array4D<NativeT>& values) { |
1145 | return CreateFromArray(values); |
1146 | } |
1147 | |
1148 | template <typename NativeT> |
1149 | /* static */ std::unique_ptr<Literal> Literal::CreateR4FromArray4DWithLayout( |
1150 | const Array4D<NativeT>& values, const Layout& layout) { |
1151 | return CreateFromArrayWithLayout(values, layout); |
1152 | } |
1153 | |
1154 | template <typename NativeT> |
1155 | NativeT Literal::GetFirstElement() const { |
1156 | return data<NativeT>().at(0); |
1157 | } |
1158 | |
1159 | template <typename NativeT> |
1160 | NativeT Literal::GetSparseElement(int64 sparse_element_number, |
1161 | const ShapeIndex& shape_index) const { |
1162 | CHECK( |
1163 | LayoutUtil::IsSparseArray(ShapeUtil::GetSubshape(shape(), shape_index))); |
1164 | return data<NativeT>(shape_index)[sparse_element_number]; |
1165 | } |
1166 | |
1167 | template <typename NativeT> |
1168 | void Literal::AppendSparseElement( |
1169 | tensorflow::gtl::ArraySlice<int64> multi_index, NativeT value, |
1170 | const ShapeIndex& shape_index) { |
1171 | Piece& p = piece(shape_index); |
1172 | const Shape& subshape = p.subshape(); |
1173 | CHECK(LayoutUtil::IsSparseArray(subshape)); |
1174 | int64 rank = ShapeUtil::Rank(subshape); |
1175 | CHECK_EQ(multi_index.size(), rank); |
1176 | int64 last_element = p.sparse_indices()->index_count(); |
1177 | CHECK_LT(last_element, LayoutUtil::MaxSparseElements(subshape.layout())); |
1178 | p.sparse_indices()->Append(multi_index); |
1179 | CHECK_LT(last_element, p.data<NativeT>().size()); |
1180 | p.data<NativeT>()[last_element] = value; |
1181 | } |
1182 | |
1183 | // Returns an identity matrix (rank 2) with the given row and column count. |
1184 | template <typename NativeT> |
1185 | /* static */ std::unique_ptr<Literal> Literal::MakeIdentityR2(int64 size) { |
1186 | Array2D<NativeT> array(size, size, 0); |
1187 | for (int64 i = 0; i < size; ++i) { |
1188 | array(i, i) = 1; |
1189 | } |
1190 | return CreateR2FromArray2D(array); |
1191 | } |
1192 | |
1193 | template <typename NativeT> |
1194 | void Literal::EachCell( |
1195 | std::function<void(tensorflow::gtl::ArraySlice<int64> indices, |
1196 | NativeT value)> |
1197 | per_cell) const { |
1198 | if (ShapeUtil::HasZeroElements(shape())) { |
1199 | return; |
1200 | } |
1201 | std::vector<int64> indices(ShapeUtil::Rank(shape()), 0); |
1202 | do { |
1203 | per_cell(indices, Get<NativeT>(indices)); |
1204 | } while (IndexUtil::BumpIndices(shape(), &indices)); |
1205 | } |
1206 | |
1207 | template <typename NativeT> |
1208 | inline void Literal::PopulateR1(tensorflow::gtl::ArraySlice<NativeT> values) { |
1209 | CHECK(ShapeUtil::IsArray(shape())); |
1210 | CHECK_EQ(ShapeUtil::Rank(shape()), 1); |
1211 | CHECK_EQ(ShapeUtil::ElementsIn(shape()), values.size()); |
1212 | CHECK_EQ(shape().element_type(), |
1213 | primitive_util::NativeToPrimitiveType<NativeT>()); |
1214 | for (int64 i = 0; i < values.size(); ++i) { |
1215 | Set({i}, values[i]); |
1216 | } |
1217 | } |
1218 | |
1219 | template <typename NativeT> |
1220 | void Literal::PopulateR2( |
1221 | std::initializer_list<std::initializer_list<NativeT>> values) { |
1222 | CHECK(ShapeUtil::IsArray(shape())); |
1223 | CHECK_EQ(ShapeUtil::Rank(shape()), 2); |
1224 | CHECK_EQ(shape().element_type(), |
1225 | primitive_util::NativeToPrimitiveType<NativeT>()); |
1226 | |
1227 | const int64 dim0_size = values.size(); |
1228 | const int64 dim1_size = values.begin()->size(); |
1229 | CHECK_EQ(dim0_size, shape().dimensions(0)); |
1230 | CHECK_EQ(dim1_size, shape().dimensions(1)); |
1231 | |
1232 | int64 dim0 = 0; |
1233 | for (auto inner_list : values) { |
1234 | int64 dim1 = 0; |
1235 | for (auto value : inner_list) { |
1236 | Set({dim0, dim1}, value); |
1237 | ++dim1; |
1238 | } |
1239 | CHECK_EQ(dim1_size, dim1); |
1240 | ++dim0; |
1241 | } |
1242 | } |
1243 | |
1244 | template <typename NativeT> |
1245 | void Literal::PopulateFromArray(const Array<NativeT>& values) { |
1246 | CHECK(ShapeUtil::IsArray(shape())); |
1247 | CHECK_EQ(shape().element_type(), |
1248 | primitive_util::NativeToPrimitiveType<NativeT>()); |
1249 | CHECK_EQ(ShapeUtil::Rank(shape()), values.num_dimensions()); |
1250 | for (int dim = 0; dim < values.num_dimensions(); ++dim) { |
1251 | CHECK_EQ(values.dim(dim), shape().dimensions(dim)); |
1252 | } |
1253 | values.Each([this](tensorflow::gtl::ArraySlice<int64> indices, |
1254 | NativeT value) { this->Set(indices, value); }); |
1255 | } |
1256 | |
1257 | template <typename NativeT> |
1258 | void Literal::PopulateR2FromArray2D(const Array2D<NativeT>& values) { |
1259 | PopulateFromArray(values); |
1260 | } |
1261 | |
1262 | template <typename NativeT> |
1263 | void Literal::PopulateR3FromArray3D(const Array3D<NativeT>& values) { |
1264 | PopulateFromArray(values); |
1265 | } |
1266 | |
1267 | template <typename NativeT> |
1268 | void Literal::PopulateR4FromArray4D(const Array4D<NativeT>& values) { |
1269 | PopulateFromArray(values); |
1270 | } |
1271 | |
1272 | template <typename NativeT> |
1273 | void Literal::PopulateSparse(SparseIndexArray indices, |
1274 | tensorflow::gtl::ArraySlice<NativeT> values, |
1275 | bool sort) { |
1276 | CHECK(LayoutUtil::IsSparseArray(shape())); |
1277 | int rank = ShapeUtil::Rank(shape()); |
1278 | CHECK_EQ(indices.rank(), rank); |
1279 | int64 max_elements = LayoutUtil::MaxSparseElements(shape().layout()); |
1280 | CHECK_LE(indices.max_indices(), max_elements); |
1281 | int64 num_elements = values.size(); |
1282 | CHECK_LE(num_elements, max_elements); |
1283 | CHECK_EQ(num_elements, indices.index_count()); |
1284 | auto root_data = root_piece().data<NativeT>(); |
1285 | root_data.remove_suffix(max_elements - values.size()); |
1286 | std::copy(values.begin(), values.end(), root_data.begin()); |
1287 | *this->root_piece().sparse_indices() = std::move(indices); |
1288 | if (sort) { |
1289 | auto root_data = this->root_piece().data<NativeT>(); |
1290 | root_data.remove_suffix(root_data.size() - num_elements); |
1291 | this->root_piece().sparse_indices()->SortWithValues(root_data); |
1292 | } |
1293 | DCHECK(this->root_piece().sparse_indices()->Validate(shape())); |
1294 | } |
1295 | |
1296 | template <typename NativeT, typename FnType> |
1297 | Status Literal::PopulateInternal(const FnType& generator, bool parallel) { |
1298 | const Shape& this_shape = shape(); |
1299 | const int64 rank = ShapeUtil::Rank(this_shape); |
1300 | TF_RET_CHECK(LayoutUtil::IsDenseArray(this_shape)); |
1301 | TF_RET_CHECK(this_shape.element_type() == |
1302 | primitive_util::NativeToPrimitiveType<NativeT>()); |
1303 | tensorflow::gtl::MutableArraySlice<NativeT> literal_data = data<NativeT>(); |
1304 | if (rank > 0) { |
1305 | StrideConfig stride_config(this_shape, this_shape, |
1306 | AsInt64Slice(this_shape.dimensions())); |
1307 | int64 minor_dimension_size = |
1308 | ShapeUtil::GetDimension(this_shape, stride_config.minor_dimension); |
1309 | |
1310 | auto init_function = [&](tensorflow::gtl::ArraySlice<int64> indexes) { |
1311 | DimensionVector minor_scan_indexes(rank, 0); |
1312 | const int64 index = |
1313 | IndexUtil::MultidimensionalIndexToLinearIndex(shape(), indexes); |
1314 | std::copy(indexes.begin(), indexes.end(), minor_scan_indexes.begin()); |
1315 | for (int64 i = 0; i < minor_dimension_size; ++i) { |
1316 | minor_scan_indexes[stride_config.minor_dimension] = i; |
1317 | literal_data.at(index + i) = generator(minor_scan_indexes); |
1318 | } |
1319 | }; |
1320 | if (parallel) { |
1321 | ShapeUtil::ForEachIndexParallel(this_shape, stride_config.base, |
1322 | stride_config.dimensions, |
1323 | stride_config.step, init_function); |
1324 | } else { |
1325 | ShapeUtil::ForEachIndex( |
1326 | this_shape, stride_config.base, stride_config.dimensions, |
1327 | stride_config.step, |
1328 | [&init_function](tensorflow::gtl::ArraySlice<int64> indexes) { |
1329 | init_function(indexes); |
1330 | return true; |
1331 | }); |
1332 | } |
1333 | } else { |
1334 | // For scalars. |
1335 | literal_data.at(0) = generator({}); |
1336 | } |
1337 | return Status::OK(); |
1338 | } |
1339 | template <typename NativeT, typename FnType> |
1340 | Status Literal::Populate(const FnType& generator) { |
1341 | return PopulateInternal<NativeT>(generator, /*parallel=*/false); |
1342 | } |
1343 | |
1344 | template <typename NativeT, typename FnType> |
1345 | Status Literal::PopulateParallel(const FnType& generator) { |
1346 | return PopulateInternal<NativeT>(generator, /*parallel=*/true); |
1347 | } |
1348 | |
1349 | template <typename NativeT> |
1350 | void Literal::PopulateWithValue(NativeT value) { |
1351 | CHECK(ShapeUtil::IsArray(shape())); |
1352 | CHECK_EQ(shape().element_type(), |
1353 | primitive_util::NativeToPrimitiveType<NativeT>()); |
1354 | for (NativeT& element : data<NativeT>()) { |
1355 | element = value; |
1356 | } |
1357 | } |
1358 | |
1359 | template <typename NativeT> |
1360 | /* static */ std::unique_ptr<Literal> Literal::CreateFullWithDescendingLayout( |
1361 | tensorflow::gtl::ArraySlice<int64> dimensions, NativeT value) { |
1362 | auto literal = MakeUnique<Literal>(ShapeUtil::MakeShapeWithDescendingLayout( |
1363 | primitive_util::NativeToPrimitiveType<NativeT>(), dimensions)); |
1364 | literal->PopulateWithValue(value); |
1365 | return literal; |
1366 | } |
1367 | |
1368 | template <typename NativeT> |
1369 | std::unique_ptr<Literal> Literal::Replicate(int64 times) const { |
1370 | DimensionVector bounds = {times}; |
1371 | bounds.reserve(shape().dimensions_size() + 1); |
1372 | for (int64 bound : shape().dimensions()) { |
1373 | bounds.push_back(bound); |
1374 | } |
1375 | auto literal = |
1376 | MakeUnique<Literal>(ShapeUtil::MakeShape(shape().element_type(), bounds)); |
1377 | int64 elements = ShapeUtil::ElementsIn(literal->shape()); |
1378 | if (elements == 0) { |
1379 | return literal; |
1380 | } |
1381 | |
1382 | DimensionVector output_indices(bounds.size(), 0); |
1383 | tensorflow::gtl::ArraySlice<int64> input_indices = output_indices; |
1384 | input_indices.remove_prefix(1); |
1385 | |
1386 | bool done = false; |
1387 | while (!done) { |
1388 | const auto element = Get<NativeT>(input_indices); |
1389 | literal->Set<NativeT>(output_indices, element); |
1390 | |
1391 | done = true; |
1392 | for (int n = 0; n < output_indices.size(); ++n) { |
1393 | ++output_indices[n]; |
1394 | if (output_indices[n] < bounds[n]) { |
1395 | done = false; |
1396 | break; |
1397 | } |
1398 | output_indices[n] = 0; |
1399 | } |
1400 | } |
1401 | return literal; |
1402 | } |
1403 | |
1404 | } // namespace xla |
1405 | |
1406 | #endif // TENSORFLOW_COMPILER_XLA_LITERAL_UTIL_H_ |
1407 | |