1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16// Utilities for dealing with Literal protobufs.
17
18#ifndef TENSORFLOW_COMPILER_XLA_LITERAL_UTIL_H_
19#define TENSORFLOW_COMPILER_XLA_LITERAL_UTIL_H_
20
21#include <functional>
22#include <initializer_list>
23#include <iterator>
24#include <memory>
25#include <ostream>
26#include <string>
27#include <type_traits>
28#include <vector>
29
30#include "tensorflow/compiler/xla/array2d.h"
31#include "tensorflow/compiler/xla/array3d.h"
32#include "tensorflow/compiler/xla/array4d.h"
33#include "tensorflow/compiler/xla/index_util.h"
34#include "tensorflow/compiler/xla/layout_util.h"
35#include "tensorflow/compiler/xla/primitive_util.h"
36#include "tensorflow/compiler/xla/ptr_util.h"
37#include "tensorflow/compiler/xla/shape_tree.h"
38#include "tensorflow/compiler/xla/shape_util.h"
39#include "tensorflow/compiler/xla/sparse_index_array.h"
40#include "tensorflow/compiler/xla/status_macros.h"
41#include "tensorflow/compiler/xla/types.h"
42#include "tensorflow/compiler/xla/util.h"
43#include "tensorflow/compiler/xla/xla_data.pb.h"
44#include "tensorflow/core/lib/core/bitmap.h"
45#include "tensorflow/core/lib/core/status.h"
46#include "tensorflow/core/lib/core/stringpiece.h"
47#include "tensorflow/core/lib/gtl/array_slice.h"
48#include "tensorflow/core/platform/logging.h"
49#include "tensorflow/core/platform/macros.h"
50#include "tensorflow/core/platform/protobuf.h"
51#include "tensorflow/core/platform/types.h"
52
53namespace xla {
54
55// Class representing literal values in XLA.
56//
57// TODO(b/67651157): The methods in this class should be reduced to a minimal
// set of methods which construct Literals and accessor methods. Other methods
59// which perform computation on Literals (Reshape, Slice, etc) should be moved
60// elsewhere, and perhaps combined with evaluator code which operates on
61// Literals.
62class Literal {
63 public:
  // Constructs a literal with a nil shape (empty tuple), delegating to the
  // shape constructor below. No array memory is allocated.
  Literal() : Literal(ShapeUtil::MakeNil()) {}
65
66 // Create a literal of the given shape. The literal is allocated sufficient
67 // memory to hold the shape. Memory is uninitialized.
68 explicit Literal(const Shape& shape);
69 virtual ~Literal();
70
71 // Literals are moveable, but not copyable. To copy a literal use
72 // Literal::Clone or Literal::CloneToUnique. This prevents inadvertent copies
73 // of literals which can be expensive.
74 Literal(const Literal& other) = delete;
75 Literal& operator=(const Literal& other) = delete;
76 Literal(Literal&& other);
77 Literal& operator=(Literal&& other);
78
  // Literals are equal if they have compatible shapes and the same data
  // values. Layout is not compared.
  bool operator==(const Literal& other) const;
  // Defined as the negation of operator== so the two can never disagree.
  bool operator!=(const Literal& other) const { return !(*this == other); }
83
84 // Serialize to and from a proto.
85 static StatusOr<std::unique_ptr<Literal>> CreateFromProto(
86 const LiteralProto& proto);
87 LiteralProto ToProto() const;
88
89 // Return the shape of the literal.
90 const Shape& shape() const { return shape_; }
91
92 // TODO(b/67651157): Remove this accessor. Literal users should not be able to
93 // mutate the shape as this can produce malformed Literals.
94 Shape* mutable_shape_do_not_use() { return &shape_; }
95
96 // Returns a (Mutable)ArraySlice view of the array for this literal for the
97 // given NativeT (e.g., float). CHECKs if the subshape of the literal at the
98 // given ShapeIndex is not array. See primitive_util.h for the mapping from
99 // XLA type to native type.
100 template <typename NativeT>
101 tensorflow::gtl::ArraySlice<NativeT> data(
102 const ShapeIndex& shape_index = {}) const;
103 template <typename NativeT>
104 tensorflow::gtl::MutableArraySlice<NativeT> data(
105 const ShapeIndex& shape_index = {});
106
107 // Returns a pointer to the sparse index array. Returns nullptr if the literal
108 // is not a sparse array.
109 const SparseIndexArray* sparse_indices(
110 const ShapeIndex& shape_index = {}) const;
111 SparseIndexArray* sparse_indices(const ShapeIndex& shape_index = {});
112
113 // Returns a pointer to (or size of) the underlying buffer holding the array
114 // at the given shape index. CHECKs if the subshape of the literal at the
115 // given ShapeIndex is not array.
116 const void* untyped_data(const ShapeIndex& shape_index = {}) const;
117 void* untyped_data(const ShapeIndex& shape_index = {});
118 int64 size_bytes(const ShapeIndex& shape_index = {}) const;
119
120 // Creates a new literal of a given rank. To minimize ambiguity (for users
121 // and the compiler) these CreateR[0-2] methods should explicitly specify the
122 // native type. For example:
123 //
124 // CreateR1<float>({1.0, 42.0});
125 // CreateR2<uint32>({{1, 2}, {3, 4}});
126 //
127 // The variants not ending with WithLayout use the default XLA layout for the
128 // literal's linear representation in memory.
129 template <typename NativeT>
130 static std::unique_ptr<Literal> CreateR0(NativeT value);
131 template <typename NativeT>
132 static std::unique_ptr<Literal> CreateR1(
133 tensorflow::gtl::ArraySlice<NativeT> values);
134 static std::unique_ptr<Literal> CreateR1(
135 const tensorflow::core::Bitmap& values);
136 template <typename NativeT>
137 static std::unique_ptr<Literal> CreateR2(
138 std::initializer_list<std::initializer_list<NativeT>> values);
139 template <typename NativeT>
140 static std::unique_ptr<Literal> CreateR2WithLayout(
141 std::initializer_list<std::initializer_list<NativeT>> values,
142 const Layout& layout);
143 template <typename NativeT>
144 static std::unique_ptr<Literal> CreateR3(
145 std::initializer_list<
146 std::initializer_list<std::initializer_list<NativeT>>>
147 values);
148 template <typename NativeT>
149 static std::unique_ptr<Literal> CreateR3WithLayout(
150 std::initializer_list<
151 std::initializer_list<std::initializer_list<NativeT>>>
152 values,
153 const Layout& layout);
154 template <typename NativeT>
155 static std::unique_ptr<Literal> CreateR4(
156 std::initializer_list<std::initializer_list<
157 std::initializer_list<std::initializer_list<NativeT>>>>
158 values);
159 template <typename NativeT>
160 static std::unique_ptr<Literal> CreateR4WithLayout(
161 std::initializer_list<std::initializer_list<
162 std::initializer_list<std::initializer_list<NativeT>>>>
163 values,
164 const Layout& layout);
165
166 // Returns this literal's data as a string. This literal must be a rank-1 U8
167 // array.
168 string GetR1U8AsString() const;
169
170 // Creates a literal with a sparse layout and the given indices and values.
171 // The shape is initialized from the given dimensions. The minor dimension of
172 // the indices array must equal the rank of the shape (i.e. size of the
173 // dimensions array). The major dimension of the indices array must equal the
174 // number of elements in the values array. The maximum number of elements in
175 // the array is taken from the max_indices() value of the index array.
176 //
177 // XLA assumes that sparse literals are in sorted order for all operations. If
178 // the `sort` argument is true, then the indices and values will be sorted
179 // while copying them into the literal. If you have ensured that the indices
180 // and values are already sorted, then you may set the `sort` argument to
181 // false to skip the sorting step.
182 //
183 // For example:
184 //
185 // CreateSparse(
186 // {12, 12, 12},
187 // SparseIndexArray(10, 3,
188 // Array2D{
189 // {0, 1, 2},
190 // {3, 4, 5},
191 // {6, 7, 8},
192 // {9, 10, 11},
193 // }),
  //     {1.0, 2.0, 3.0, 4.0})
195 //
196 // This creates an array with shape F64[12,12,12]sparse{10}, that has the
197 // following non-zero values:
198 //
199 // [0, 1, 2]: 1.0
200 // [3, 4, 5]: 2.0
201 // [6, 7, 8]: 3.0
202 // [9, 10, 11]: 4.0
203 //
204 template <typename NativeT>
205 static std::unique_ptr<Literal> CreateSparse(
206 tensorflow::gtl::ArraySlice<int64> dimensions, SparseIndexArray indices,
207 tensorflow::gtl::ArraySlice<NativeT> values, bool sort = true);
208
209 // Populates a literal with a sparse layout with the given indices and values.
210 // Each index in the indices array is CHECKed against the dimensions in the
211 // literal's shape. If sort is true, then the indices and values will be
212 // sorted. If sort is false, then the indices and values are assumed to
213 // already be in sorted order. See CreateSparse for an example of how data
214 // are populated.
215 template <typename NativeT>
216 void PopulateSparse(SparseIndexArray indices,
217 tensorflow::gtl::ArraySlice<NativeT> values,
218 bool sort = true);
219
220 // Creates a new Literal object with the shape specified as parameter.
221 // The content of the literal values is the default value of the primitive
222 // type of literal itself (0 for numeric types, and false for predicates).
223 static std::unique_ptr<Literal> CreateFromShape(const Shape& shape);
224
  // Creates a new Literal object with its values having the primitive_type
226 // type, and with dimensions defined by the dimensions parameter.
227 // The content of the literal values is the default value of the primitive
228 // type of literal itself (0 for numeric types, and false for predicates).
229 static std::unique_ptr<Literal> CreateFromDimensions(
230 PrimitiveType primitive_type,
231 tensorflow::gtl::ArraySlice<int64> dimensions);
232
233 // Copy values from 'src_literal' rooted at 'src_shape_index' into this
234 // literal rooted at 'dest_shape_index'. The subshape of this literal rooted
235 // at 'dest_shape_index' must be compatible with the subshape of 'src_literal'
236 // rooted at 'src_shape_index', but need not be arrays.
237 Status CopyFrom(const Literal& src_literal,
238 const ShapeIndex& dest_shape_index = {},
239 const ShapeIndex& src_shape_index = {});
240
  // Similar to CopyFrom, but with move semantics. The subshape of this literal
242 // rooted at 'dest_shape_index' must be *equal* to the shape 'src_literal'
243 // (layouts and shapes must match), but need not be arrays. The memory
244 // allocated in this literal for the subshape at dest_shape_index is
245 // deallocated, and the respective buffers are replaced with those in
246 // src_literal. Upon return, src_literal is set to a nil shape (empty tuple).
247 Status MoveFrom(Literal&& src_literal,
248 const ShapeIndex& dest_shape_index = {});
249
250 // Copies the values from src_literal, starting at src_base shape indexes,
251 // to this literal, starting at dest_base, where the copy size in each
252 // dimension is specified by copy_size.
253 // The src_literal and this literal must have the same primitive type,
254 // src_base+copy_size must fit the source literal dimensions, as well as
255 // dest_base+copy_size must fit the destination literal dimensions.
256 // Note: if either src_literal or this literal contains dimensions with zero
257 // element, then copy_size must be 0 in these dimensions while the
258 // corresponding base indices being 0.
259 // This literal and 'src_literal' must be arrays.
260 Status CopySliceFrom(const Literal& src_literal,
261 tensorflow::gtl::ArraySlice<int64> src_base,
262 tensorflow::gtl::ArraySlice<int64> dest_base,
263 tensorflow::gtl::ArraySlice<int64> copy_size);
264
265 // Copies one element from src_literal[src_index] to (*this)[dest_index].
266 Status CopyElementFrom(const Literal& src_literal,
267 tensorflow::gtl::ArraySlice<int64> src_index,
268 tensorflow::gtl::ArraySlice<int64> dest_index);
269
270 // Returns a vector containing the tuple elements of this Literal as separate
271 // Literals. This Literal must be tuple-shaped and can be a nested tuple. The
272 // elements are moved into the new Literals; no data is copied. Upon return
273 // this Literal is set to a nil shape (empty tuple)
274 std::vector<Literal> DecomposeTuple();
275
276 // This operation is the inverse of DecomposeTuple. The given elements are
277 // moved into the tuple elements of a new tuple-shaped Literal which is
278 // returned. Upon return, each of the Literals in 'elements' is set to a nil
279 // shape (empty tuple).
280 static Literal MoveIntoTuple(
281 tensorflow::gtl::MutableArraySlice<Literal> elements);
282
283 // Creates a new value that has the equivalent value as this literal, but
284 // conforms to new_layout; e.g. a literal matrix that was in {0, 1}
285 // minor-to-major dimension layout can be re-layed-out as {1, 0}
286 // minor-to-major dimension layout and the value in the cell at any given
287 // logical index (i0, i1) will be the same.
288 //
289 // For tuple shaped literals, shape_index should be used to select the inner
290 // array that the new layout applies to.
291 //
292 // Note: this is useful when the client wants to ensure that a value placed in
293 // the XLA allocation tracker has a particular layout; for efficiency
294 // purposes or avoiding unimplemented operation/layout combinations.
295 std::unique_ptr<Literal> Relayout(const Layout& new_layout,
296 const ShapeIndex& shape_index = {}) const;
297
298 // An overload of Relayout which changes the layout of the entire shape rather
299 // than being limited to a single array within the shape.
300 std::unique_ptr<Literal> Relayout(const Shape& shape_with_layout) const;
301
302 // Creates a new literal by reshaping this literal to have the given
303 // dimensions. The total number of elements must not change; The
304 // implementation currently only supports monotonic dim0-major layouts.
305 // This literal must be an array.
306 StatusOr<std::unique_ptr<Literal>> Reshape(
307 tensorflow::gtl::ArraySlice<int64> dimensions) const;
308
309 // Creates a new literal by reordering the dimensions of this literal.
310 // The given `permutation` must be a permutation of the dimension numbers
311 // in the original literal, and it specifies the order of the new dimensions
312 // in the result literal (i.e., new_order[i] = old_order[permutation[i]]).
313 // For example, a transpose call on a literal of shape [3 x 8 x 4] and
314 // `permutation` = {2, 0, 1} returns a new literal of shape [4 x 3 x 8].
315 // This literal must be an array.
316 std::unique_ptr<Literal> Transpose(
317 tensorflow::gtl::ArraySlice<int64> permutation) const;
318
319 // Creates a sub-array from this literal by extracting the indices
320 // [start_index, limit_index) of each dimension. The result literal has the
321 // same rank and layout as for the given literal. The number of indices in
322 // start_indices and limit_indices must be the rank of the literal, and the
323 // indices follow the order of the dimensions.
324 // This literal must be an array.
325 std::unique_ptr<Literal> Slice(
326 tensorflow::gtl::ArraySlice<int64> start_indices,
327 tensorflow::gtl::ArraySlice<int64> limit_indices) const;
328
329 // Creates a literal with a prepended dimension with bound "times"; e.g. a
330 // f32[3x2] with times=4 will produce a f32[4x3x2] with the 3x2 from this
331 // literal replicated four times.
332 // This literal must be an array.
333 template <typename NativeT>
334 std::unique_ptr<Literal> Replicate(int64 times) const;
335
336 // Converts this literal to another primitive type using
337 // static_cast<>. Returns an error if the conversion is not possible. This
338 // literal must be array-shaped.
339 StatusOr<std::unique_ptr<Literal>> Convert(
340 PrimitiveType primitive_dest_type) const;
341
342 // Converts this literal to another primitive type using a bitcast
343 // conversion. The to and from primitive types must have the same bit
344 // width. Returns an error if the conversion is not possible. This literal
345 // must be array-shaped.
346 StatusOr<std::unique_ptr<Literal>> BitcastConvert(
347 PrimitiveType primitive_dest_type) const;
348
  // Converts this literal to the given shape. Returns an error if the
  // conversion is not possible.
351 //
352 // round_f32_to_bf16: if true, converting F32 elements to BF16 uses rounding
353 // instead of truncation; otherwise, truncation is used.
354 //
355 // TODO(b/69266521): remove the round_to_bfloat16 flag when rounding becomes
356 // the default behavior.
357 StatusOr<std::unique_ptr<Literal>> ConvertToShape(
358 const Shape& dest_shape, bool round_f32_to_bf16 = false) const;
359
360 // Creates a scalar literal value zero of the given primitive type.
361 static Literal Zero(PrimitiveType primitive_type);
362
363 // Creates a scalar literal value one of the given primitive type.
364 static Literal One(PrimitiveType primitive_type);
365
366 // Creates a scalar literal value containing the minimum value of the given
367 // primitive type. For floating-point types, returns -inf.
368 static Literal MinValue(PrimitiveType primitive_type);
369
370 // Creates a scalar literal value containing the maximum value of the given
371 // primitive type. For floating-point types, returns inf.
372 static Literal MaxValue(PrimitiveType primitive_type);
373
374 // Creates a literal of the given shape where each element is `value`.
375 template <typename NativeT>
376 static std::unique_ptr<Literal> CreateFullWithDescendingLayout(
377 tensorflow::gtl::ArraySlice<int64> dimensions, NativeT value);
378
379 // Creates a new literal from an Array type. The variants not ending with
380 // WithLayout use the default XLA layout for the literal's linear
381 // representation in memory.
382 template <typename NativeT>
383 static std::unique_ptr<Literal> CreateFromArray(const Array<NativeT>& values);
384 template <typename NativeT>
385 static std::unique_ptr<Literal> CreateFromArrayWithLayout(
386 const Array<NativeT>& values, const Layout& layout);
387 template <typename NativeT>
388 static std::unique_ptr<Literal> CreateR2FromArray2D(
389 const Array2D<NativeT>& values);
390 template <typename NativeT>
391 static std::unique_ptr<Literal> CreateR2FromArray2DWithLayout(
392 const Array2D<NativeT>& values, const Layout& layout);
393 template <typename NativeT>
394 static std::unique_ptr<Literal> CreateR3FromArray3D(
395 const Array3D<NativeT>& values);
396 template <typename NativeT>
397 static std::unique_ptr<Literal> CreateR3FromArray3DWithLayout(
398 const Array3D<NativeT>& values, const Layout& layout);
399 template <typename NativeT>
400 static std::unique_ptr<Literal> CreateR4FromArray4D(
401 const Array4D<NativeT>& values);
402 template <typename NativeT>
403 static std::unique_ptr<Literal> CreateR4FromArray4DWithLayout(
404 const Array4D<NativeT>& values, const Layout& layout);
405
406 // Creates a new vector of U8s literal value from a string.
407 static std::unique_ptr<Literal> CreateR1U8(tensorflow::StringPiece value);
408
409 // Creates a linspace-populated literal with the given number of rows and
410 // columns.
411 static std::unique_ptr<Literal> CreateR2F32Linspace(float from, float to,
412 int64 rows, int64 cols);
413
414 // Creates a literal that projects the (x, y) dimensions given in values into
415 // the z dimension given by "projection".
416 template <typename NativeT>
417 static std::unique_ptr<Literal> CreateR3Projected(
418 std::initializer_list<std::initializer_list<NativeT>> values,
419 int64 projection);
420
421 // Creates a literal that projects the (x, y) dimensions given in values into
422 // the z and p dimensions given.
423 template <typename NativeT>
424 static std::unique_ptr<Literal> CreateR4Projected(
425 std::initializer_list<std::initializer_list<NativeT>> values,
426 int64 projection_p, int64 projection_z);
427
428 // Clones this literal into a new Literal, or new std::unique_ptr<Literal>.
429 Literal Clone() const;
430 std::unique_ptr<Literal> CloneToUnique() const;
431
432 // Gets or sets an element in the literal at the given index. The multi_index
433 // is CHECKed against the dimension sizes.
434 template <typename NativeT>
435 NativeT Get(tensorflow::gtl::ArraySlice<int64> multi_index,
436 const ShapeIndex& shape_index) const;
437 template <typename NativeT>
438 void Set(tensorflow::gtl::ArraySlice<int64> multi_index,
439 const ShapeIndex& shape_index, NativeT value);
440
441 // Overloads of Get and Set for array literals. CHECKs if the literal is not
442 // array-shaped and dense.
443 template <typename NativeT>
444 NativeT Get(tensorflow::gtl::ArraySlice<int64> multi_index) const;
445 template <typename NativeT>
446 void Set(tensorflow::gtl::ArraySlice<int64> multi_index, NativeT value);
447
448 // Returns the multi-index of the element in a sparse literal at the given
  // sparse element number. The sparse element number is the position within
450 // the sparse array's list of (index, value) pairs, and is checked against the
451 // total number of (index, value) pairs in the sparse array.
452 tensorflow::gtl::ArraySlice<int64> GetSparseIndex(
453 int64 sparse_element_number, const ShapeIndex& shape_index = {}) const;
454
455 // Returns the value of the element in a sparse literal at the given sparse
  // element number. The sparse element number is the position within the
457 // sparse array's list of (index, value) pairs, and is checked against the
458 // total number of (index, value) pairs in the sparse array.
459 template <typename NativeT>
460 NativeT GetSparseElement(int64 sparse_element_number,
461 const ShapeIndex& shape_index = {}) const;
462
463 // Appends the given element to the literal. If the elements are not appended
464 // in sorted order, then SortSparseElements should be called before calling
465 // other methods. This literal must have a sparse layout.
466 template <typename NativeT>
467 void AppendSparseElement(tensorflow::gtl::ArraySlice<int64> multi_index,
468 NativeT value, const ShapeIndex& shape_index = {});
469
470 // Sorts the elements in a sparse array.
471 void SortSparseElements(const ShapeIndex& shape_index = {});
472
473 // Returns the element value at index (0, ..., 0), however many zeroes are
474 // required for that index.
475 template <typename NativeT>
476 NativeT GetFirstElement() const;
477
478 // Returns a literal scalar representing the first element.
479 Literal GetFirstScalarLiteral() const;
480
481 // As Get(), but determines the correct type and converts the value
482 // into text.
483 string GetAsString(tensorflow::gtl::ArraySlice<int64> multi_index,
484 const ShapeIndex& shape_index = {}) const;
485
486 // As GetSparseElement(), but determines the correct type and converts the
487 // value into text.
488 string GetSparseElementAsString(int64 sparse_element_number,
489 const ShapeIndex& shape_index = {}) const;
490
491 // As Get(), but determines the correct type and converts the value into
492 // int64. This literal must be an array.
493 StatusOr<int64> GetIntegralAsS64(
494 tensorflow::gtl::ArraySlice<int64> multi_index) const;
495
496 // As Set(), but truncates `value` to the literal element type before storing.
497 // This literal must be an array.
498 Status SetIntegralAsS64(tensorflow::gtl::ArraySlice<int64> multi_index,
499 int64 value);
500
501 // Returns an identity matrix (rank 2) with the given row and column count.
502 template <typename NativeT>
503 static std::unique_ptr<Literal> MakeIdentityR2(int64 size);
504
505 // Returns a tuple literal composed of given literals. Data is copied from the
506 // given elements into the returned literal.
507 static std::unique_ptr<Literal> MakeTuple(
508 tensorflow::gtl::ArraySlice<const Literal*> elements);
509
510 // As above, but intended to be invoked with move semantics; i.e.
511 //
512 // std::vector<std::unique_ptr<Literal>> elements = ...;
513 // auto result = Literal::MakeTupleOwned(std::move(elements));
514 //
515 // This would have been declared as an overload, but there is ambiguity
516 // in invocation between the above signature and this one.
517 static std::unique_ptr<Literal> MakeTupleOwned(
518 std::vector<std::unique_ptr<Literal>> elements);
519
520 // This overload lets you pass a braced list of unique_ptr<Literal>s to
521 // MakeTupleOwned:
522 //
523 // Literal::MakeTupleOwned(Literal::CreateR1(...), ...).
524 //
525 // Simply relying on the MakeTupleOwned(std::vector<unique_ptr<Literal>>)
526 // overload doesn't work because std::initializer_list's elements are always
527 // const.
528 //
529 // The arguments to this function must all be unique_ptr<Literal>.
530 template <typename... Ts>
531 static std::unique_ptr<Literal> MakeTupleOwned(
532 std::unique_ptr<Ts>... elements) {
533 std::array<std::unique_ptr<Literal>, sizeof...(Ts)> arr{
534 std::move(elements)...};
535 std::vector<std::unique_ptr<Literal>> v;
536 v.insert(v.begin(), std::make_move_iterator(arr.begin()),
537 std::make_move_iterator(arr.end()));
538 return MakeTupleOwned(std::move(v));
539 }
540
541 // Returns a string representation of the literal value.
542 // Warning: this function can take minutes for multi-million element Literals.
543 string ToString(bool print_layout = false) const;
544
545 // Invokes the "per cell" callback for each element in the provided
546 // literal with the element's indices and a string representation of
547 // the element's value.
548 //
549 // This function is useful if you want a polymorphic representation
550 // of the tensor's elements (turning it to a string for something
551 // like representation in a protobuf).
552 //
553 // This literal must have a dense layout.
554 void EachCellAsString(
555 const std::function<void(tensorflow::gtl::ArraySlice<int64> indices,
556 const string& value)>& per_cell) const;
557 template <typename NativeT>
558 void EachCell(std::function<void(tensorflow::gtl::ArraySlice<int64> indices,
559 NativeT value)>
560 per_cell) const;
561
562 // Populate this literal with the given values. Examples:
563 //
564 // // Populate with floats.
565 // Array2D<float> float_values = ...
566 // literal.PopulateR2FromArray2D(values);
567 //
568 // // Populate with int32s.
569 // literal.PopulateR2<int32>({{1, 2}, {3, 4}});
570 //
571 // The shape and element type of this literal must match given values. For
572 // example, in the call above to literal.PopulateR2(), 'literal' must be a 2x2
573 // array of S32.
574 template <typename NativeT>
575 void PopulateR1(tensorflow::gtl::ArraySlice<NativeT> values);
576 void PopulateR1(const tensorflow::core::Bitmap& values);
577 template <typename NativeT>
578 void PopulateR2(std::initializer_list<std::initializer_list<NativeT>> values);
579 template <typename NativeT>
580 void PopulateFromArray(const Array<NativeT>& values);
581 template <typename NativeT>
582 void PopulateR2FromArray2D(const Array2D<NativeT>& values);
583 template <typename NativeT>
584 void PopulateR3FromArray3D(const Array3D<NativeT>& values);
585 template <typename NativeT>
586 void PopulateR4FromArray4D(const Array4D<NativeT>& values);
587
588 // Populates literal values by calling the generator function for every cell
589 // in this literal object.
590 //
591 // generator must be a callable of the type
592 // NativeT(tensorflow::gtl::ArraySlice<int64> indexes) or compatible.
593 //
594 // This literal must have a dense layout.
595 template <typename NativeT, typename FnType>
596 Status Populate(const FnType& generator);
597
598 // A parallel version of Populate(). This can be used if the generator is
599 // thread-safe and the values for the shape's different elements are
600 // independent.
601 template <typename NativeT, typename FnType>
602 Status PopulateParallel(const FnType& generator);
603
604 // Fills this literal with the given value.
605 template <typename NativeT>
606 void PopulateWithValue(NativeT value);
607
608 // Returns whether every element in this literal is equal to value.
609 //
610 // value is an int8 because we expect this to be called with small
611 // compile-time constants (0, -1, etc.) and so that whatever value you pass
612 // can be represented exactly by floating-point types as small as 16 bits.
613 //
614 // If value doesn't fit in this literal's type, returns false. Values of 1/0
615 // are considered equal to true/false; other values are not considered equal
616 // to true. Also if this literal is not array-shaped false is returned.
617 bool IsAll(int8 value) const;
618
619 // Like IsAll(const Literal&, int8), except we check whether the literal is
620 // equal to a particular floating-point number.
621 //
622 // If the literal is not a floating-point value, this always returns false.
623 //
624 // This casts value to the type of literal, then compares using ==. The usual
625 // admonishments about floating-point equality checks apply. We expect you to
626 // use this to check for values that can be expressed precisely as a float,
627 // e.g. -0.5. Also if this literal is not array-shaped false is returned.
628 bool IsAllFloat(float value) const;
629
630 // Like IsAll(const Literal&, int8), except we check whether the literal is
631 // equal to a particular complex number.
632 //
633 // If the literal is not a complex value, this always returns false.
634 //
635 // This casts value to the type of literal, then compares using ==. The usual
636 // admonishments about floating-point equality checks apply. We expect you to
637 // use this to check for complex values that can be expressed precisely as
638 // float pairs e.g. (-0.5, 1.0).
639 //
640 // This literal must have a dense layout.
641 bool IsAllComplex(complex64 value) const;
642
643 // Literal consists entirely of the first element of the literal.
644 bool IsAllFirst() const;
645
646 // Returns whether this literal is zero at the specified index. This literal
647 // must be an array with a dense layout.
648 bool IsZero(tensorflow::gtl::ArraySlice<int64> indices) const;
649
650 // Return the count of the elements in the array at the given shape index in
651 // this literal.
652 int64 element_count(const ShapeIndex& index = {}) const {
653 return ShapeUtil::ElementsIn(ShapeUtil::GetSubshape(shape(), index));
654 }
655
656 // Return the count of the elements in the sparse array at the given shape
657 // index in this literal, which will be no larger than
  // LayoutUtil::MaxSparseElements(GetSubshape(shape(), index).layout()).
659 int64 sparse_element_count() const;
660
661 protected:
662 // 'allocate_arrays' indicates whether to allocate memory for the arrays in
663 // the shape. If false, buffer pointers inside of the Literal::Pieces are set
664 // to nullptr.
665 Literal(const Shape& shape, bool allocate_arrays);
666
667 // Internal template helper for the Literal::CopySliceFrom(), matching its
668 // arguments one by one.
669 template <typename NativeT>
670 Status CopySliceFromInternal(const Literal& src_literal,
671 tensorflow::gtl::ArraySlice<int64> src_base,
672 tensorflow::gtl::ArraySlice<int64> dest_base,
673 tensorflow::gtl::ArraySlice<int64> copy_size);
674
  // Utility structure which is used to create the optimal configuration for
  // a ShapeUtil::ForEachIndex() scan across two literals.
  struct StrideConfig {
    StrideConfig(const Shape& source_shape, const Shape& dest_shape,
                 tensorflow::gtl::ArraySlice<int64> dimensions);

    // The dimensions of the stride operation. Essentially every dimension
    // will be iterated from base[i] to base[i]+dimensions[i], in step[i]
    // steps.
    tensorflow::gtl::ArraySlice<int64> dimensions;
    // Per-dimension start indices of the iteration.
    DimensionVector base;
    // Per-dimension step sizes of the iteration.
    DimensionVector step;
    // The dimension along which the inner scan loop runs (set by the
    // out-of-line constructor).
    int64 minor_dimension = 0;
    // The size of the strides for source and destination. One of the two
    // (the one looping through its most minor dimension) will be 1, while
    // the other will be the stride size at the dimension matching the other
    // shape most minor dimension being scanned.
    int64 dest_stride = 1;
    int64 source_stride = 1;
    // The size of the inner loop on the most minor dimension.
    int64 minor_loop_size = 1;
  };
697
  // A data structure representing a subshape at a particular ShapeIndex within
  // the literal. For array-shaped ShapeIndexes, this data structure holds the
  // pointer to the memory allocated for the array data.
  class Piece {
   public:
    // Return the buffer holding the array data for this piece as an array
    // slice. This piece must be array-shaped.
    template <typename NativeT>
    tensorflow::gtl::ArraySlice<NativeT> data() const;
    template <typename NativeT>
    tensorflow::gtl::MutableArraySlice<NativeT> data();

    // Return the buffer holding the array data for this piece as a void*. This
    // piece must be array-shaped.
    void* untyped_data();
    const void* untyped_data() const;

    // Gets or sets an element in the array at the given index. The multi_index
    // is CHECKed against the dimension sizes of the array. This piece must be
    // array-shaped.
    template <typename NativeT>
    NativeT Get(tensorflow::gtl::ArraySlice<int64> index) const;
    template <typename NativeT>
    void Set(tensorflow::gtl::ArraySlice<int64> index, NativeT value);

    // Gets/sets the buffer holding the array data. Note that buffer() is a
    // const member but returns a mutable pointer: constness applies to the
    // Piece, not to the pointed-to bytes.
    char* buffer() const { return buffer_; }
    void set_buffer(char* buffer) { buffer_ = buffer; }

    // The array of multi-indices that provide the locations of non-zero
    // elements in a sparse array. Only used if
    // LayoutUtil::IsSparseArray(shape()) is true. Only the raw pointer is
    // stored here; the Piece does not manage its lifetime.
    SparseIndexArray* sparse_indices() const { return sparse_indices_; }
    void set_sparse_indices(SparseIndexArray* sparse_indices) {
      sparse_indices_ = sparse_indices;
    }

    // Gets or sets the subshape of this piece. This reference points to a
    // subshape within the shape in the containing Literal (Literal::shape_).
    const Shape& subshape() const { return *subshape_; }
    void set_subshape(const Shape* subshape) { subshape_ = subshape; }

    // Returns the size in bytes of the buffer holding the array data.
    int64 size_bytes() const { return ShapeUtil::ByteSizeOf(subshape()); }

    // Returns the number of elements in this piece's array.
    int64 element_count() const { return ShapeUtil::ElementsIn(subshape()); }

    // Copy the data from 'src' into this piece's buffer. Shapes of this piece
    // and src must be compatible.
    Status CopyFrom(const Piece& src);

    // Returns true if this piece and 'other' contain the same data. This piece
    // and 'other' must be array-shaped and compatible.
    bool EqualElements(const Piece& other) const;

    // Writes the shape and data (if array-shaped) into the given proto.
    void WriteToProto(LiteralProto* proto) const;

    // Copies the data from the given proto into this piece. The shape of this
    // piece must be equal (not just compatible) to the shape of the proto.
    Status CopyFromProto(const LiteralProto& proto);

    // Sorts the elements in a sparse array.
    void SortSparseElements();

   private:
    // Recursive helper for EqualElements.
    template <typename NativeT>
    bool EqualElementsInternal(const Piece& other,
                               std::vector<int64>* multi_index) const;

    // Helper for SortSparseElements that has the element type as a template
    // parameter.
    template <typename NativeT>
    void SortSparseElementsInternal();

    // For array-shaped pieces, this is the buffer holding the literal data.
    // Allocation/deallocation is handled by the containing Literal (see
    // Literal::DeallocateBuffers() and Literal::owns_buffers_).
    char* buffer_ = nullptr;

    // For sparse arrays, this is the array of indices.
    SparseIndexArray* sparse_indices_ = nullptr;

    // The shape of piece. This points into the shape of the containing Literal
    // (Literal::shape_).
    const Shape* subshape_ = nullptr;
  };
785
  // Returns the piece at the given ShapeIndex (mutable and const overloads).
  Piece& piece(const ShapeIndex& shape_index) {
    return *pieces_.mutable_element(shape_index);
  }
  const Piece& piece(const ShapeIndex& shape_index) const {
    return pieces_.element(shape_index);
  }

  // Returns the piece at the root of the shape (empty ShapeIndex).
  Piece& root_piece() { return piece({}); }
  const Piece& root_piece() const { return piece({}); }
797
798 // Deallocate the buffers held by this literal (if the literal owns the
799 // buffer).
800 void DeallocateBuffers();
801
802 // Implementation details shared between Populate() and PopulateParallel()
803 template <typename NativeT, typename FnType>
804 Status PopulateInternal(const FnType& generator, bool parallel);
805
806 Shape shape_;
807 ShapeTree<Piece> pieces_;
808
809 // Whether the buffers held in pieces_ are owned by this Literal.
810 bool owns_buffers_;
811
812 // LiteralView must access and manipulate Pieces of other Literals.
813 friend class LiteralView;
};  // class Literal
815
816std::ostream& operator<<(std::ostream& out, const Literal& literal);
817
// A read-only view of a Literal. A LiteralView contains pointers to buffers
// owned by the viewed Literal.
//
// TODO(b/71550060): Replace LiteralView with Literal slice classes (immutable
// and mutable) similar to (Mutable)ArraySlice.
class LiteralView : public Literal {
 public:
  // Create and return a view of the given literal rooted at the given shape
  // index within the given literal. A factory is used rather than a public
  // constructor because only const LiteralViews are supported. It's still
  // possible to create non-const LiteralViews via the copy constructors, but
  // the factory method makes it a bit less likely. Implementing literal slices
  // will fix this undesirable situation (b/71550060).
  static const LiteralView Create(const Literal& literal,
                                  const ShapeIndex& view_root = {});

  // Copying a LiteralView copies the view state; the underlying buffers
  // remain owned by the originally viewed Literal.
  LiteralView(const LiteralView& other);
  LiteralView& operator=(const LiteralView& other);

  virtual ~LiteralView();

 private:
  LiteralView(const Literal& literal, const ShapeIndex& view_root);

  // Helper for the copy constructor and copy assignment operator.
  void CopyFrom(const LiteralView& other);
};
845
846template <typename NativeT>
847tensorflow::gtl::ArraySlice<NativeT> Literal::Piece::data() const {
848 CHECK(ShapeUtil::IsArray(subshape())) << ShapeUtil::HumanString(subshape());
849 CHECK_EQ(subshape().element_type(),
850 primitive_util::NativeToPrimitiveType<NativeT>())
851 << "Attempting to access "
852 << PrimitiveType_Name(primitive_util::NativeToPrimitiveType<NativeT>())
853 << " type, but literal element type is "
854 << PrimitiveType_Name(subshape().element_type());
855 return tensorflow::gtl::ArraySlice<NativeT>(
856 reinterpret_cast<const NativeT*>(buffer()),
857 ShapeUtil::ElementsIn(subshape()));
858}
859
860template <typename NativeT>
861tensorflow::gtl::MutableArraySlice<NativeT> Literal::Piece::data() {
862 CHECK(ShapeUtil::IsArray(subshape())) << ShapeUtil::HumanString(subshape());
863 CHECK_EQ(subshape().element_type(),
864 primitive_util::NativeToPrimitiveType<NativeT>())
865 << "Attempting to access "
866 << PrimitiveType_Name(primitive_util::NativeToPrimitiveType<NativeT>())
867 << " type, but literal element type is "
868 << PrimitiveType_Name(subshape().element_type());
869 return tensorflow::gtl::MutableArraySlice<NativeT>(
870 reinterpret_cast<NativeT*>(buffer()), ShapeUtil::ElementsIn(subshape()));
871}
872
873template <typename NativeT>
874NativeT Literal::Piece::Get(
875 tensorflow::gtl::ArraySlice<int64> multi_index) const {
876 CHECK(LayoutUtil::IsDenseArray(subshape()));
877 return data<NativeT>()[IndexUtil::MultidimensionalIndexToLinearIndex(
878 subshape(), multi_index)];
879}
880
881template <typename NativeT>
882void Literal::Piece::Set(tensorflow::gtl::ArraySlice<int64> multi_index,
883 NativeT value) {
884 CHECK(LayoutUtil::IsDenseArray(subshape()));
885 data<NativeT>()[IndexUtil::MultidimensionalIndexToLinearIndex(
886 subshape(), multi_index)] = value;
887}
888
template <typename NativeT>
tensorflow::gtl::ArraySlice<NativeT> Literal::data(
    const ShapeIndex& shape_index) const {
  // Forwards to the Piece at shape_index; Piece::data() CHECKs that the piece
  // is array-shaped with element type NativeT.
  return piece(shape_index).data<NativeT>();
}
894
template <typename NativeT>
tensorflow::gtl::MutableArraySlice<NativeT> Literal::data(
    const ShapeIndex& shape_index) {
  // Mutable counterpart of data() const; same checks apply in Piece::data().
  return piece(shape_index).data<NativeT>();
}
900
template <typename NativeT>
inline NativeT Literal::Get(tensorflow::gtl::ArraySlice<int64> multi_index,
                            const ShapeIndex& shape_index) const {
  // Reads one element of the array piece at shape_index; layout and index
  // validation happen in Piece::Get.
  return piece(shape_index).Get<NativeT>(multi_index);
}
906
template <typename NativeT>
inline NativeT Literal::Get(
    tensorflow::gtl::ArraySlice<int64> multi_index) const {
  // Convenience overload reading from the root piece (empty ShapeIndex).
  return root_piece().Get<NativeT>(multi_index);
}
912
template <typename NativeT>
inline void Literal::Set(tensorflow::gtl::ArraySlice<int64> multi_index,
                         const ShapeIndex& shape_index, NativeT value) {
  // Writes one element of the array piece at shape_index; validation happens
  // in Piece::Set.
  return piece(shape_index).Set<NativeT>(multi_index, value);
}
918
template <typename NativeT>
inline void Literal::Set(tensorflow::gtl::ArraySlice<int64> multi_index,
                         NativeT value) {
  // Convenience overload writing to the root piece (empty ShapeIndex).
  return root_piece().Set<NativeT>(multi_index, value);
}
924
925template <typename NativeT>
926/* static */ std::unique_ptr<Literal> Literal::CreateR0(NativeT value) {
927 auto literal = MakeUnique<Literal>(ShapeUtil::MakeShape(
928 primitive_util::NativeToPrimitiveType<NativeT>(), {}));
929 literal->Set({}, value);
930 return literal;
931}
932
933template <typename NativeT>
934/* static */ std::unique_ptr<Literal> Literal::CreateR1(
935 tensorflow::gtl::ArraySlice<NativeT> values) {
936 auto literal = MakeUnique<Literal>(
937 ShapeUtil::MakeShape(primitive_util::NativeToPrimitiveType<NativeT>(),
938 {static_cast<int64>(values.size())}));
939 literal->PopulateR1(values);
940 return literal;
941}
942
943template <typename NativeT>
944/* static */ std::unique_ptr<Literal> Literal::CreateR2WithLayout(
945 std::initializer_list<std::initializer_list<NativeT>> values,
946 const Layout& layout) {
947 auto literal = MakeUnique<Literal>(ShapeUtil::MakeShapeWithLayout(
948 primitive_util::NativeToPrimitiveType<NativeT>(),
949 {static_cast<int64>(values.size()),
950 static_cast<int64>(values.begin()->size())},
951 AsInt64Slice(layout.minor_to_major())));
952 literal->PopulateR2(values);
953 return literal;
954}
955
template <typename NativeT>
/* static */ std::unique_ptr<Literal> Literal::CreateR2(
    std::initializer_list<std::initializer_list<NativeT>> values) {
  // Rank-2 creation with the default R2 layout.
  return CreateR2WithLayout(values, LayoutUtil::GetDefaultLayoutForR2());
}
961
962template <typename NativeT>
963/* static */ std::unique_ptr<Literal> Literal::CreateR3WithLayout(
964 std::initializer_list<std::initializer_list<std::initializer_list<NativeT>>>
965 values,
966 const Layout& layout) {
967 const int64 d0 = values.size();
968 const int64 d1 = values.begin()->size();
969 const int64 d2 = values.begin()->begin()->size();
970 Array3D<NativeT> tmp(d0, d1, d2);
971 int64 i0 = 0;
972 for (auto d1_values : values) {
973 int64 i1 = 0;
974 for (auto d2_values : d1_values) {
975 int64 i2 = 0;
976 for (auto value : d2_values) {
977 tmp(i0, i1, i2) = value;
978 ++i2;
979 }
980 ++i1;
981 }
982 ++i0;
983 }
984 return CreateR3FromArray3DWithLayout(tmp, layout);
985}
986
template <typename NativeT>
/* static */ std::unique_ptr<Literal> Literal::CreateR3(
    std::initializer_list<std::initializer_list<std::initializer_list<NativeT>>>
        values) {
  // Rank-3 creation with the default R3 layout.
  return CreateR3WithLayout(values, LayoutUtil::GetDefaultLayoutForR3());
}
993
994template <typename NativeT>
995/* static */ std::unique_ptr<Literal> Literal::CreateR4WithLayout(
996 std::initializer_list<std::initializer_list<
997 std::initializer_list<std::initializer_list<NativeT>>>>
998 values,
999 const Layout& layout) {
1000 const int64 d0 = values.size();
1001 const int64 d1 = values.begin()->size();
1002 const int64 d2 = values.begin()->begin()->size();
1003 const int64 d3 = values.begin()->begin()->begin()->size();
1004 Array4D<NativeT> tmp(d0, d1, d2, d3);
1005 int64 i0 = 0;
1006 for (auto d1_values : values) {
1007 int64 i1 = 0;
1008 for (auto d2_values : d1_values) {
1009 int64 i2 = 0;
1010 for (auto d3_values : d2_values) {
1011 int64 i3 = 0;
1012 for (auto value : d3_values) {
1013 tmp(i0, i1, i2, i3) = value;
1014 ++i3;
1015 }
1016 ++i2;
1017 }
1018 ++i1;
1019 }
1020 ++i0;
1021 }
1022 return CreateR4FromArray4DWithLayout(tmp, layout);
1023}
1024
1025template <typename NativeT>
1026/* static */ std::unique_ptr<Literal> Literal::CreateSparse(
1027 tensorflow::gtl::ArraySlice<int64> dimensions, SparseIndexArray indices,
1028 tensorflow::gtl::ArraySlice<NativeT> values, bool sort) {
1029 int64 num_elements = values.size();
1030 int64 rank = dimensions.size();
1031 CHECK_EQ(num_elements, indices.index_count());
1032 CHECK_EQ(rank, indices.rank());
1033 auto literal = MakeUnique<Literal>(ShapeUtil::MakeShapeWithSparseLayout(
1034 primitive_util::NativeToPrimitiveType<NativeT>(), dimensions,
1035 indices.max_indices()));
1036 literal->PopulateSparse(indices, values, sort);
1037 return literal;
1038}
1039
template <typename NativeT>
/* static */ std::unique_ptr<Literal> Literal::CreateR4(
    std::initializer_list<std::initializer_list<
        std::initializer_list<std::initializer_list<NativeT>>>>
        values) {
  // Rank-4 creation with the default R4 layout.
  return CreateR4WithLayout(values, LayoutUtil::GetDefaultLayoutForR4());
}
1047
1048template <typename NativeT>
1049/* static */ std::unique_ptr<Literal> Literal::CreateFromArrayWithLayout(
1050 const Array<NativeT>& values, const Layout& layout) {
1051 auto literal = MakeUnique<Literal>(ShapeUtil::MakeShapeWithLayout(
1052 primitive_util::NativeToPrimitiveType<NativeT>(), values.dimensions(),
1053 AsInt64Slice(layout.minor_to_major())));
1054 literal->PopulateFromArray(values);
1055 return literal;
1056}
1057
template <typename NativeT>
/* static */ std::unique_ptr<Literal> Literal::CreateFromArray(
    const Array<NativeT>& values) {
  // Uses the default layout for the array's rank.
  return CreateFromArrayWithLayout(
      values, LayoutUtil::GetDefaultLayoutForRank(values.num_dimensions()));
}
1064
template <typename NativeT>
/* static */ std::unique_ptr<Literal> Literal::CreateR2FromArray2DWithLayout(
    const Array2D<NativeT>& values, const Layout& layout) {
  // Rank-2 convenience wrapper; Array2D is usable as an Array here.
  return CreateFromArrayWithLayout(values, layout);
}
1070
template <typename NativeT>
/* static */ std::unique_ptr<Literal> Literal::CreateR2FromArray2D(
    const Array2D<NativeT>& values) {
  // Rank-2 convenience wrapper using the default layout.
  return CreateFromArray(values);
}
1076
template <typename NativeT>
/* static */ std::unique_ptr<Literal> Literal::CreateR3FromArray3DWithLayout(
    const Array3D<NativeT>& values, const Layout& layout) {
  // Rank-3 convenience wrapper; Array3D is usable as an Array here.
  return CreateFromArrayWithLayout(values, layout);
}
1082
template <typename NativeT>
/* static */ std::unique_ptr<Literal> Literal::CreateR3FromArray3D(
    const Array3D<NativeT>& values) {
  // Rank-3 convenience wrapper using the default layout.
  return CreateFromArray(values);
}
1088
1089template <typename NativeT>
1090/* static */ std::unique_ptr<Literal> Literal::CreateR3Projected(
1091 std::initializer_list<std::initializer_list<NativeT>> values,
1092 int64 projection) {
1093 int64 dim0_size = projection;
1094 int64 dim1_size = values.size();
1095 int64 dim2_size = values.begin()->size();
1096
1097 Array3D<NativeT> array(dim0_size, dim1_size, dim2_size);
1098 for (int64 dim0 = 0; dim0 < dim0_size; ++dim0) {
1099 int64 dim1 = 0;
1100 for (auto inner_list : values) {
1101 int64 dim2 = 0;
1102 for (auto value : inner_list) {
1103 array(dim0, dim1, dim2) = value;
1104 ++dim2;
1105 }
1106 CHECK_EQ(dim2_size, dim2);
1107 ++dim1;
1108 }
1109 CHECK_EQ(dim1_size, dim1);
1110 }
1111 return CreateR3FromArray3D(array);
1112}
1113
1114template <typename NativeT>
1115/* static */ std::unique_ptr<Literal> Literal::CreateR4Projected(
1116 std::initializer_list<std::initializer_list<NativeT>> values,
1117 int64 projection_p, int64 projection_z) {
1118 int64 dim0_size = projection_p;
1119 int64 dim1_size = projection_z;
1120 int64 dim2_size = values.size();
1121 int64 dim3_size = values.begin()->size();
1122
1123 Array4D<NativeT> array(dim0_size, dim1_size, dim2_size, dim3_size);
1124 for (int64 dim0 = 0; dim0 < dim0_size; ++dim0) {
1125 for (int64 dim1 = 0; dim1 < dim1_size; ++dim1) {
1126 int64 dim2 = 0;
1127 for (auto inner_list : values) {
1128 int64 dim3 = 0;
1129 for (auto value : inner_list) {
1130 array(dim0, dim1, dim2, dim3) = value;
1131 ++dim3;
1132 }
1133 CHECK_EQ(dim3_size, dim3);
1134 ++dim2;
1135 }
1136 CHECK_EQ(dim2_size, dim2);
1137 }
1138 }
1139 return CreateR4FromArray4D(array);
1140}
1141
template <typename NativeT>
/* static */ std::unique_ptr<Literal> Literal::CreateR4FromArray4D(
    const Array4D<NativeT>& values) {
  // Rank-4 convenience wrapper using the default layout.
  return CreateFromArray(values);
}
1147
template <typename NativeT>
/* static */ std::unique_ptr<Literal> Literal::CreateR4FromArray4DWithLayout(
    const Array4D<NativeT>& values, const Layout& layout) {
  // Rank-4 convenience wrapper; Array4D is usable as an Array here.
  return CreateFromArrayWithLayout(values, layout);
}
1153
template <typename NativeT>
NativeT Literal::GetFirstElement() const {
  // Returns the element at linear (buffer) position 0 of the root array.
  return data<NativeT>().at(0);
}
1158
template <typename NativeT>
NativeT Literal::GetSparseElement(int64 sparse_element_number,
                                  const ShapeIndex& shape_index) const {
  // Returns the sparse_element_number-th stored value (a position in the
  // value buffer, not a multi-dimensional index) of the sparse array at
  // shape_index.
  CHECK(
      LayoutUtil::IsSparseArray(ShapeUtil::GetSubshape(shape(), shape_index)));
  return data<NativeT>(shape_index)[sparse_element_number];
}
1166
template <typename NativeT>
void Literal::AppendSparseElement(
    tensorflow::gtl::ArraySlice<int64> multi_index, NativeT value,
    const ShapeIndex& shape_index) {
  // Appends one (multi_index, value) pair to the sparse array at shape_index.
  // The pair is placed after the existing elements; no sorting is performed
  // here.
  Piece& p = piece(shape_index);
  const Shape& subshape = p.subshape();
  CHECK(LayoutUtil::IsSparseArray(subshape));
  int64 rank = ShapeUtil::Rank(subshape);
  CHECK_EQ(multi_index.size(), rank);
  // The sparse layout caps how many elements the buffers can hold; the new
  // element goes at position 'last_element' in both index and value storage.
  int64 last_element = p.sparse_indices()->index_count();
  CHECK_LT(last_element, LayoutUtil::MaxSparseElements(subshape.layout()));
  p.sparse_indices()->Append(multi_index);
  CHECK_LT(last_element, p.data<NativeT>().size());
  p.data<NativeT>()[last_element] = value;
}
1182
1183// Returns an identity matrix (rank 2) with the given row and column count.
1184template <typename NativeT>
1185/* static */ std::unique_ptr<Literal> Literal::MakeIdentityR2(int64 size) {
1186 Array2D<NativeT> array(size, size, 0);
1187 for (int64 i = 0; i < size; ++i) {
1188 array(i, i) = 1;
1189 }
1190 return CreateR2FromArray2D(array);
1191}
1192
template <typename NativeT>
void Literal::EachCell(
    std::function<void(tensorflow::gtl::ArraySlice<int64> indices,
                       NativeT value)>
        per_cell) const {
  // Invokes per_cell once for every element, passing the element's
  // multi-dimensional index and value. Iteration order is defined by
  // IndexUtil::BumpIndices, starting from the all-zero index.
  if (ShapeUtil::HasZeroElements(shape())) {
    return;
  }
  std::vector<int64> indices(ShapeUtil::Rank(shape()), 0);
  do {
    per_cell(indices, Get<NativeT>(indices));
  } while (IndexUtil::BumpIndices(shape(), &indices));
}
1206
template <typename NativeT>
inline void Literal::PopulateR1(tensorflow::gtl::ArraySlice<NativeT> values) {
  // Fills this rank-1 literal with the given values; element count and
  // element type must match the literal's shape.
  CHECK(ShapeUtil::IsArray(shape()));
  CHECK_EQ(ShapeUtil::Rank(shape()), 1);
  CHECK_EQ(ShapeUtil::ElementsIn(shape()), values.size());
  CHECK_EQ(shape().element_type(),
           primitive_util::NativeToPrimitiveType<NativeT>());
  for (int64 i = 0; i < values.size(); ++i) {
    Set({i}, values[i]);
  }
}
1218
1219template <typename NativeT>
1220void Literal::PopulateR2(
1221 std::initializer_list<std::initializer_list<NativeT>> values) {
1222 CHECK(ShapeUtil::IsArray(shape()));
1223 CHECK_EQ(ShapeUtil::Rank(shape()), 2);
1224 CHECK_EQ(shape().element_type(),
1225 primitive_util::NativeToPrimitiveType<NativeT>());
1226
1227 const int64 dim0_size = values.size();
1228 const int64 dim1_size = values.begin()->size();
1229 CHECK_EQ(dim0_size, shape().dimensions(0));
1230 CHECK_EQ(dim1_size, shape().dimensions(1));
1231
1232 int64 dim0 = 0;
1233 for (auto inner_list : values) {
1234 int64 dim1 = 0;
1235 for (auto value : inner_list) {
1236 Set({dim0, dim1}, value);
1237 ++dim1;
1238 }
1239 CHECK_EQ(dim1_size, dim1);
1240 ++dim0;
1241 }
1242}
1243
template <typename NativeT>
void Literal::PopulateFromArray(const Array<NativeT>& values) {
  // Copies 'values' element-by-element into this literal; element type, rank,
  // and every dimension size must match.
  CHECK(ShapeUtil::IsArray(shape()));
  CHECK_EQ(shape().element_type(),
           primitive_util::NativeToPrimitiveType<NativeT>());
  CHECK_EQ(ShapeUtil::Rank(shape()), values.num_dimensions());
  for (int dim = 0; dim < values.num_dimensions(); ++dim) {
    CHECK_EQ(values.dim(dim), shape().dimensions(dim));
  }
  values.Each([this](tensorflow::gtl::ArraySlice<int64> indices,
                     NativeT value) { this->Set(indices, value); });
}
1256
template <typename NativeT>
void Literal::PopulateR2FromArray2D(const Array2D<NativeT>& values) {
  // Rank-2 convenience wrapper over PopulateFromArray.
  PopulateFromArray(values);
}
1261
template <typename NativeT>
void Literal::PopulateR3FromArray3D(const Array3D<NativeT>& values) {
  // Rank-3 convenience wrapper over PopulateFromArray.
  PopulateFromArray(values);
}
1266
template <typename NativeT>
void Literal::PopulateR4FromArray4D(const Array4D<NativeT>& values) {
  // Rank-4 convenience wrapper over PopulateFromArray.
  PopulateFromArray(values);
}
1271
template <typename NativeT>
void Literal::PopulateSparse(SparseIndexArray indices,
                             tensorflow::gtl::ArraySlice<NativeT> values,
                             bool sort) {
  // Fills this sparse literal with the given (indices, values) pairs,
  // optionally sorting them afterwards.
  CHECK(LayoutUtil::IsSparseArray(shape()));
  int rank = ShapeUtil::Rank(shape());
  CHECK_EQ(indices.rank(), rank);
  // The sparse layout bounds how many elements the value buffer can hold.
  int64 max_elements = LayoutUtil::MaxSparseElements(shape().layout());
  CHECK_LE(indices.max_indices(), max_elements);
  int64 num_elements = values.size();
  CHECK_LE(num_elements, max_elements);
  CHECK_EQ(num_elements, indices.index_count());
  // The buffer is sized for max_elements; trim the unused suffix so the copy
  // targets only the leading num_elements slots.
  auto root_data = root_piece().data<NativeT>();
  root_data.remove_suffix(max_elements - values.size());
  std::copy(values.begin(), values.end(), root_data.begin());
  *this->root_piece().sparse_indices() = std::move(indices);
  if (sort) {
    // This inner root_data intentionally shadows the outer one; both views
    // cover the same leading num_elements values.
    auto root_data = this->root_piece().data<NativeT>();
    root_data.remove_suffix(root_data.size() - num_elements);
    this->root_piece().sparse_indices()->SortWithValues(root_data);
  }
  DCHECK(this->root_piece().sparse_indices()->Validate(shape()));
}
1295
template <typename NativeT, typename FnType>
Status Literal::PopulateInternal(const FnType& generator, bool parallel) {
  // Fills this dense array literal by invoking 'generator' once per element
  // with the element's multi-dimensional index. Shared implementation behind
  // Populate() (serial) and PopulateParallel().
  const Shape& this_shape = shape();
  const int64 rank = ShapeUtil::Rank(this_shape);
  TF_RET_CHECK(LayoutUtil::IsDenseArray(this_shape));
  TF_RET_CHECK(this_shape.element_type() ==
               primitive_util::NativeToPrimitiveType<NativeT>());
  tensorflow::gtl::MutableArraySlice<NativeT> literal_data = data<NativeT>();
  if (rank > 0) {
    StrideConfig stride_config(this_shape, this_shape,
                               AsInt64Slice(this_shape.dimensions()));
    int64 minor_dimension_size =
        ShapeUtil::GetDimension(this_shape, stride_config.minor_dimension);

    // Invoked once per minor-dimension run: 'indexes' fixes all other
    // dimensions, and the loop below fills the whole run of elements that are
    // consecutive in the buffer (linear index + i).
    auto init_function = [&](tensorflow::gtl::ArraySlice<int64> indexes) {
      DimensionVector minor_scan_indexes(rank, 0);
      const int64 index =
          IndexUtil::MultidimensionalIndexToLinearIndex(shape(), indexes);
      std::copy(indexes.begin(), indexes.end(), minor_scan_indexes.begin());
      for (int64 i = 0; i < minor_dimension_size; ++i) {
        minor_scan_indexes[stride_config.minor_dimension] = i;
        literal_data.at(index + i) = generator(minor_scan_indexes);
      }
    };
    if (parallel) {
      ShapeUtil::ForEachIndexParallel(this_shape, stride_config.base,
                                      stride_config.dimensions,
                                      stride_config.step, init_function);
    } else {
      // ForEachIndex expects a bool-returning visitor; always continue.
      ShapeUtil::ForEachIndex(
          this_shape, stride_config.base, stride_config.dimensions,
          stride_config.step,
          [&init_function](tensorflow::gtl::ArraySlice<int64> indexes) {
            init_function(indexes);
            return true;
          });
    }
  } else {
    // For scalars.
    literal_data.at(0) = generator({});
  }
  return Status::OK();
}
template <typename NativeT, typename FnType>
Status Literal::Populate(const FnType& generator) {
  // Serial element-wise population; see PopulateInternal for details.
  return PopulateInternal<NativeT>(generator, /*parallel=*/false);
}
1343
template <typename NativeT, typename FnType>
Status Literal::PopulateParallel(const FnType& generator) {
  // Parallel element-wise population; see PopulateInternal for details.
  return PopulateInternal<NativeT>(generator, /*parallel=*/true);
}
1348
1349template <typename NativeT>
1350void Literal::PopulateWithValue(NativeT value) {
1351 CHECK(ShapeUtil::IsArray(shape()));
1352 CHECK_EQ(shape().element_type(),
1353 primitive_util::NativeToPrimitiveType<NativeT>());
1354 for (NativeT& element : data<NativeT>()) {
1355 element = value;
1356 }
1357}
1358
1359template <typename NativeT>
1360/* static */ std::unique_ptr<Literal> Literal::CreateFullWithDescendingLayout(
1361 tensorflow::gtl::ArraySlice<int64> dimensions, NativeT value) {
1362 auto literal = MakeUnique<Literal>(ShapeUtil::MakeShapeWithDescendingLayout(
1363 primitive_util::NativeToPrimitiveType<NativeT>(), dimensions));
1364 literal->PopulateWithValue(value);
1365 return literal;
1366}
1367
template <typename NativeT>
std::unique_ptr<Literal> Literal::Replicate(int64 times) const {
  // Returns a new literal whose shape is this literal's shape with an extra
  // leading dimension of size 'times', containing this literal's contents
  // repeated 'times' times.
  DimensionVector bounds = {times};
  bounds.reserve(shape().dimensions_size() + 1);
  for (int64 bound : shape().dimensions()) {
    bounds.push_back(bound);
  }
  auto literal =
      MakeUnique<Literal>(ShapeUtil::MakeShape(shape().element_type(), bounds));
  int64 elements = ShapeUtil::ElementsIn(literal->shape());
  if (elements == 0) {
    return literal;
  }

  // input_indices is a view aliasing output_indices with the leading
  // (replication) dimension dropped, so the increments below advance both in
  // lock-step.
  DimensionVector output_indices(bounds.size(), 0);
  tensorflow::gtl::ArraySlice<int64> input_indices = output_indices;
  input_indices.remove_prefix(1);

  bool done = false;
  while (!done) {
    const auto element = Get<NativeT>(input_indices);
    literal->Set<NativeT>(output_indices, element);

    // Odometer-style increment of output_indices, advancing dimension 0
    // fastest; 'done' stays true once every dimension has wrapped.
    done = true;
    for (int n = 0; n < output_indices.size(); ++n) {
      ++output_indices[n];
      if (output_indices[n] < bounds[n]) {
        done = false;
        break;
      }
      output_indices[n] = 0;
    }
  }
  return literal;
}
1403
1404} // namespace xla
1405
1406#endif // TENSORFLOW_COMPILER_XLA_LITERAL_UTIL_H_
1407