/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_FRAMEWORK_TENSOR_H_
#define TENSORFLOW_CORE_FRAMEWORK_TENSOR_H_

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {

// Forward declarations. In particular, we forward declare protos so that their
// symbols can be removed from .so exports.
class AllocationDescription;
class Allocator;
class OpKernelContext;
class TensorBuffer;
class TensorCApi;
class TensorDescription;
class TensorProto;
class VariantTensorData;
namespace batch_util {
Status CopyElementToSlice(Tensor element, Tensor* parent, int64 index);
}  // namespace batch_util

/// @ingroup core
/// Represents an n-dimensional array of values.
class Tensor {
 public:
  /// \brief Creates a 1-dimensional, 0-element float tensor.
  ///
  /// The returned Tensor is not a scalar (shape {}), but is instead
  /// an empty one-dimensional Tensor (shape {0}, NumElements() ==
  /// 0). Since it has no elements, it does not need to be assigned a
  /// value and is initialized by default (IsInitialized() is
  /// true). If this is undesirable, consider creating a one-element
  /// scalar which does require initialization:
  ///
  /// ```c++
  ///
  ///     Tensor(DT_FLOAT, TensorShape({}))
  ///
  /// ```
  Tensor();

  /// \brief Creates a Tensor of the given `type` and `shape`. If
  /// LogMemory::IsEnabled() the allocation is logged as coming from
  /// an unknown kernel and step. Calling the Tensor constructor
  /// directly from within an Op is deprecated: use the
  /// OpKernelConstruction/OpKernelContext allocate_* methods to
  /// allocate a new tensor, which record the kernel and step.
  ///
  /// The underlying buffer is allocated using a `CPUAllocator`.
  Tensor(DataType type, const TensorShape& shape);

  /// \brief Creates a tensor with the input `type` and `shape`, using
  /// the allocator `a` to allocate the underlying buffer. If
  /// LogMemory::IsEnabled() the allocation is logged as coming from
  /// an unknown kernel and step. Calling the Tensor constructor
  /// directly from within an Op is deprecated: use the
  /// OpKernelConstruction/OpKernelContext allocate_* methods to
  /// allocate a new tensor, which record the kernel and step.
  ///
  /// `a` must outlive the lifetime of this Tensor.
  Tensor(Allocator* a, DataType type, const TensorShape& shape);

  /// \brief Creates a tensor with the input `type` and `shape`, using
  /// the allocator `a` and the specified "allocation_attr" to
  /// allocate the underlying buffer. If the kernel and step are known
  /// allocation_attr.allocation_will_be_logged should be set to true
  /// and LogMemory::RecordTensorAllocation should be called after the
  /// tensor is constructed. Calling the Tensor constructor directly
  /// from within an Op is deprecated: use the
  /// OpKernelConstruction/OpKernelContext allocate_* methods to
  /// allocate a new tensor, which record the kernel and step.
  ///
  /// `a` must outlive the lifetime of this Tensor.
  Tensor(Allocator* a, DataType type, const TensorShape& shape,
         const AllocationAttributes& allocation_attr);

  /// \brief Creates an empty Tensor of the given data type.
  ///
  /// Like Tensor(), returns a 1-dimensional, 0-element Tensor with
  /// IsInitialized() returning True. See the Tensor() documentation
  /// for details.
  explicit Tensor(DataType type);

  /// Copy constructor.
  Tensor(const Tensor& other);

  /// \brief Move constructor. After this call, <other> is safely destructible
  /// and can be assigned to, but other calls on it (e.g. shape manipulation)
  /// are not valid.
  Tensor(Tensor&& other);

  ~Tensor();

  /// Returns the data type.
  DataType dtype() const { return shape_.data_type(); }

  /// Returns the shape of the tensor.
  const TensorShape& shape() const { return shape_; }

  /// \brief Convenience accessor for the tensor shape.
  ///
  /// For all shape accessors, see comments for relevant methods of
  /// `TensorShape` in `tensor_shape.h`.
  int dims() const { return shape().dims(); }

  /// Convenience accessor for the tensor shape.
  int64 dim_size(int d) const { return shape().dim_size(d); }

  /// Convenience accessor for the tensor shape.
  int64 NumElements() const { return shape().num_elements(); }

  bool IsSameSize(const Tensor& b) const {
    return shape().IsSameSize(b.shape());
  }

  // True iff the two tensors use the same underlying refcounted storage
  bool SharesBufferWith(const Tensor& b) const;

  /// \brief If necessary, has this Tensor been initialized?
  ///
  /// Zero-element Tensors are always considered initialized, even if they
  /// have never been assigned to and do not have any memory allocated.
  bool IsInitialized() const;

  /// Returns the estimated memory usage of this tensor.
  size_t TotalBytes() const;

  // Returns the size of allocated memory for this tensor.
  size_t AllocatedBytes() const;

  /// Returns true iff this tensor is aligned.
  bool IsAligned() const {
#if EIGEN_MAX_ALIGN_BYTES == 0
    return true;
#else
    void* ptr = base<void>();
    return reinterpret_cast<intptr_t>(ptr) % EIGEN_MAX_ALIGN_BYTES == 0;
#endif
  }

  /// Assignment operator. This tensor shares other's underlying storage.
  Tensor& operator=(const Tensor& other) {
    CopyFromInternal(other, other.shape());
    return *this;
  }

  /// Move operator. See move constructor for details.
  Tensor& operator=(Tensor&& other);

  /// \brief Copy the other tensor into this tensor and reshape it.
  ///
  /// This tensor shares other's underlying storage. Returns `true`
  /// iff `other.shape()` has the same number of elements as the given
  /// `shape`.
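  ///
  /// A minimal illustrative sketch (variable names are hypothetical):
  ///
  /// ```c++
  ///
  ///     Tensor src(DT_FLOAT, TensorShape({2, 3}));
  ///     Tensor dst;
  ///     CHECK(dst.CopyFrom(src, TensorShape({3, 2})));   // same 6 elements
  ///     CHECK(!dst.CopyFrom(src, TensorShape({4, 2})));  // 8 != 6, rejected
  ///
  /// ```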
  bool CopyFrom(const Tensor& other,
                const TensorShape& shape) TF_MUST_USE_RESULT {
    if (other.NumElements() != shape.num_elements()) return false;
    CopyFromInternal(other, shape);
    return true;
  }

  /// \brief Slice this tensor along the 1st dimension.

  /// I.e., the returned tensor satisfies
  ///     returned[i, ...] == this[dim0_start + i, ...].
  /// The returned tensor shares the underlying tensor buffer with this
  /// tensor.
  ///
  /// NOTE: The returned tensor may not satisfy the same alignment
  /// requirement as this tensor depending on the shape. The caller
  /// must check the returned tensor's alignment before calling certain
  /// methods that have alignment requirement (e.g., `flat()`, `tensor()`).
  ///
  /// REQUIRES: `dims()` >= 1
  /// REQUIRES: `0 <= dim0_start <= dim0_limit <= dim_size(0)`
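  ///
  /// For example (an illustrative sketch):
  ///
  /// ```c++
  ///
  ///     Tensor t(DT_FLOAT, TensorShape({10, 20}));
  ///     Tensor rows = t.Slice(2, 5);  // shape {3, 20}, shares t's buffer
  ///     if (rows.IsAligned()) {
  ///       auto mat = rows.matrix<float>();  // only safe when aligned
  ///     }
  ///
  /// ```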
  Tensor Slice(int64 dim0_start, int64 dim0_limit) const;

  /// \brief Parse `other` and construct the tensor.

  /// Returns `true` iff the parsing succeeds. If the parsing fails,
  /// the state of `*this` is unchanged.
  bool FromProto(const TensorProto& other) TF_MUST_USE_RESULT;
  bool FromProto(Allocator* a, const TensorProto& other) TF_MUST_USE_RESULT;

  /// \brief Fills in `proto` with `*this` tensor's content.
  ///
  /// `AsProtoField()` fills in the repeated field for `proto.dtype()`, while
  /// `AsProtoTensorContent()` encodes the content in `proto.tensor_content()`
  /// in a compact form.
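  ///
  /// A round-trip sketch (illustrative):
  ///
  /// ```c++
  ///
  ///     Tensor t(DT_INT32, TensorShape({2, 2}));
  ///     TensorProto proto;
  ///     t.AsProtoTensorContent(&proto);
  ///     Tensor restored;
  ///     CHECK(restored.FromProto(proto));
  ///
  /// ```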
  void AsProtoField(TensorProto* proto) const;
  void AsProtoTensorContent(TensorProto* proto) const;

  /// \brief Return the tensor data as an `Eigen::Tensor` with the type and
  /// sizes of this `Tensor`.
  ///
  /// Use these methods when you know the data type and the number of
  /// dimensions of the Tensor and you want an `Eigen::Tensor`
  /// automatically sized to the `Tensor` sizes. The implementation check
  /// fails if either type or sizes mismatch.
  ///
  /// Example:
  ///
  /// ```c++
  ///
  ///     typedef float T;
  ///     Tensor my_mat(...built with Shape{rows: 3, cols: 5}...);
  ///     auto mat = my_mat.matrix<T>();    // 2D Eigen::Tensor, 3 x 5.
  ///     auto mat = my_mat.tensor<T, 2>(); // 2D Eigen::Tensor, 3 x 5.
  ///     auto vec = my_mat.vec<T>();       // CHECK fails as my_mat is 2D.
  ///     auto vec = my_mat.tensor<T, 3>(); // CHECK fails as my_mat is 2D.
  ///     auto mat = my_mat.matrix<int32>();// CHECK fails as type mismatch.
  ///
  /// ```
  template <typename T>
  typename TTypes<T>::Vec vec() {
    return tensor<T, 1>();
  }

  template <typename T>
  typename TTypes<T>::Matrix matrix() {
    return tensor<T, 2>();
  }

  template <typename T, size_t NDIMS>
  typename TTypes<T, NDIMS>::Tensor tensor();

  /// \brief Return the tensor data to an `Eigen::Tensor` with the
  /// same size but a bitwise cast to the specified dtype `T`.
  ///
  /// Using a bitcast is useful for move and copy operations.
  /// NOTE: this is the same as `tensor()` except a bitcast is allowed.
  template <typename T, size_t NDIMS>
  typename TTypes<T, NDIMS>::Tensor bit_casted_tensor();

  /// \brief Return the tensor data to an `Eigen::Tensor` with the
  /// last dimension elements converted into single elements of a larger type.
  ///
  /// For example, this is useful for kernels that can treat NCHW_VECT_C int8
  /// tensors as NCHW int32 tensors. The sizeof(T) should equal the size of
  /// the original element type * num elements in the original last dimension.
  /// NDIMS should be 1 less than the original number of dimensions.
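  ///
  /// A sketch of the NCHW_VECT_C case above (shapes are illustrative):
  ///
  /// ```c++
  ///
  ///     // int8 data of shape {N, C/4, H, W, 4} viewed as int32 {N, C/4, H, W}.
  ///     Tensor vect_c(DT_INT8, TensorShape({8, 16, 32, 32, 4}));
  ///     auto as_int32 = vect_c.reinterpret_last_dimension<int32, 4>();
  ///
  /// ```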
  template <typename T, size_t NDIMS>
  typename TTypes<T, NDIMS>::Tensor reinterpret_last_dimension();

  /// \brief Return the tensor data as an `Eigen::Tensor` of the data type and a
  /// specified shape.
  ///
  /// These methods allow you to access the data with the dimensions
  /// and sizes of your choice. You do not need to know the number of
  /// dimensions of the Tensor to call them. However, they `CHECK` that
  /// the type matches and that the dimensions requested create an
  /// `Eigen::Tensor` with the same number of elements as the tensor.
  ///
  /// Example:
  ///
  /// ```c++
  ///
  ///     typedef float T;
  ///     Tensor my_ten(...built with Shape{planes: 4, rows: 3, cols: 5}...);
  ///     // 1D Eigen::Tensor, size 60:
  ///     auto flat = my_ten.flat<T>();
  ///     // 2D Eigen::Tensor 12 x 5:
  ///     auto inner = my_ten.flat_inner_dims<T>();
  ///     // 2D Eigen::Tensor 4 x 15:
  ///     auto outer = my_ten.shaped<T, 2>({4, 15});
  ///     // CHECK fails, bad num elements:
  ///     auto outer = my_ten.shaped<T, 2>({4, 8});
  ///     // 3D Eigen::Tensor 6 x 5 x 2:
  ///     auto weird = my_ten.shaped<T, 3>({6, 5, 2});
  ///     // CHECK fails, type mismatch:
  ///     auto bad = my_ten.flat<int32>();
  ///
  /// ```
  template <typename T>
  typename TTypes<T>::Flat flat() {
    return shaped<T, 1>({NumElements()});
  }

  template <typename T>
  typename TTypes<T>::UnalignedFlat unaligned_flat() {
    return unaligned_shaped<T, 1>({NumElements()});
  }

  /// Returns the data as an Eigen::Tensor with NDIMS dimensions, collapsing all
  /// Tensor dimensions but the last NDIMS-1 into the first dimension of the
  /// result. If NDIMS > dims() then leading dimensions of size 1 will be
  /// added to make the output rank NDIMS.
  template <typename T, size_t NDIMS = 2>
  typename TTypes<T, NDIMS>::Tensor flat_inner_dims();

  /// Returns the data as an Eigen::Tensor with NDIMS dimensions, collapsing all
  /// Tensor dimensions but the first NDIMS-1 into the last dimension of the
  /// result. If NDIMS > dims() then trailing dimensions of size 1 will be
  /// added to make the output rank NDIMS.
  template <typename T, size_t NDIMS = 2>
  typename TTypes<T, NDIMS>::Tensor flat_outer_dims();

  /// Returns the data as an Eigen::Tensor with NDIMS dimensions, collapsing the
  /// first 'begin' Tensor dimensions into the first dimension of the result and
  /// the Tensor dimensions of the last dims() - 'begin' - NDIMS into the last
  /// dimension of the result. If 'begin' < 0 then the |'begin'| leading
  /// dimensions of size 1 will be added. If 'begin' + NDIMS > dims() then
  /// 'begin' + NDIMS - dims() trailing dimensions of size 1 will be added.
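  ///
  /// For example (an illustrative sketch):
  ///
  /// ```c++
  ///
  ///     Tensor t(...built with Shape{2, 3, 5, 7}...);
  ///     // 3D Eigen::Tensor 2 x 3 x 35 (begin == 0):
  ///     auto a = t.flat_inner_outer_dims<float, 3>(0);
  ///     // 3D Eigen::Tensor 6 x 5 x 7 (begin == 1):
  ///     auto b = t.flat_inner_outer_dims<float, 3>(1);
  ///
  /// ```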
  template <typename T, size_t NDIMS = 3>
  typename TTypes<T, NDIMS>::Tensor flat_inner_outer_dims(int64 begin);

  template <typename T, size_t NDIMS>
  typename TTypes<T, NDIMS>::Tensor shaped(gtl::ArraySlice<int64> new_sizes);

  /// \brief Return the tensor data to an `Eigen::Tensor` with the new
  /// shape specified in `new_sizes` and cast to a new dtype `T`.
  ///
  /// Using a bitcast is useful for move and copy operations.
  /// The allowed bitcast is the only difference from `shaped()`.
  template <typename T, size_t NDIMS>
  typename TTypes<T, NDIMS>::Tensor bit_casted_shaped(
      gtl::ArraySlice<int64> new_sizes);

  template <typename T, size_t NDIMS>
  typename TTypes<T, NDIMS>::UnalignedTensor unaligned_shaped(
      gtl::ArraySlice<int64> new_sizes);

  /// \brief Return the Tensor data as a `TensorMap` of fixed size 1:
  /// `TensorMap<TensorFixedSize<T, 1>>`.

  /// Using `scalar()` allows the compiler to perform optimizations as
  /// the size of the tensor is known at compile time.
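  ///
  /// For example (an illustrative sketch):
  ///
  /// ```c++
  ///
  ///     Tensor rate(DT_FLOAT, TensorShape({}));
  ///     rate.scalar<float>()() = 0.5f;
  ///     float r = rate.scalar<float>()();
  ///
  /// ```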
  template <typename T>
  typename TTypes<T>::Scalar scalar();

  /// Const versions of all the methods above.
  template <typename T>
  typename TTypes<T>::ConstVec vec() const {
    return tensor<T, 1>();
  }

  template <typename T>
  typename TTypes<T>::ConstMatrix matrix() const {
    return tensor<T, 2>();
  }

  template <typename T, size_t NDIMS>
  typename TTypes<T, NDIMS>::ConstTensor tensor() const;

  /// \brief Return the tensor data to an `Eigen::Tensor` with the
  /// same size but a bitwise cast to the specified dtype `T`.
  ///
  /// Using a bitcast is useful for move and copy operations.
  /// NOTE: this is the same as `tensor()` except a bitcast is allowed.
  template <typename T, size_t NDIMS>
  typename TTypes<T, NDIMS>::ConstTensor bit_casted_tensor() const;

  /// \brief Return the tensor data to an `Eigen::Tensor` with the
  /// last dimension elements converted into single elements of a larger type.
  ///
  /// For example, this is useful for kernels that can treat NCHW_VECT_C int8
  /// tensors as NCHW int32 tensors. The sizeof(T) should equal the size of
  /// the original element type * num elements in the original last dimension.
  /// NDIMS should be 1 less than the original number of dimensions.
  template <typename T, size_t NDIMS>
  typename TTypes<T, NDIMS>::ConstTensor reinterpret_last_dimension() const;

  template <typename T>
  typename TTypes<T>::ConstFlat flat() const {
    return shaped<T, 1>({NumElements()});
  }

  template <typename T>
  typename TTypes<T>::UnalignedConstFlat unaligned_flat() const {
    return unaligned_shaped<T, 1>({NumElements()});
  }

  template <typename T, size_t NDIMS>
  typename TTypes<T, NDIMS>::ConstTensor shaped(
      gtl::ArraySlice<int64> new_sizes) const;

  /// \brief Return the tensor data to an `Eigen::Tensor` with the new
  /// shape specified in `new_sizes` and cast to a new dtype `T`.
  ///
  /// Using a bitcast is useful for move and copy operations.
  /// The allowed bitcast is the only difference from `shaped()`.
  template <typename T, size_t NDIMS>
  typename TTypes<T, NDIMS>::ConstTensor bit_casted_shaped(
      gtl::ArraySlice<int64> new_sizes) const;

  template <typename T, size_t NDIMS>
  typename TTypes<T, NDIMS>::UnalignedConstTensor unaligned_shaped(
      gtl::ArraySlice<int64> new_sizes) const;

  template <typename T>
  typename TTypes<T>::ConstScalar scalar() const;

  template <typename T, size_t NDIMS = 2>
  typename TTypes<T, NDIMS>::ConstTensor flat_inner_dims() const;

  template <typename T, size_t NDIMS = 2>
  typename TTypes<T, NDIMS>::ConstTensor flat_outer_dims() const;

  template <typename T, size_t NDIMS = 3>
  typename TTypes<T, NDIMS>::ConstTensor flat_inner_outer_dims(
      int64 begin) const;

  /// Render the first `max_entries` values in `*this` into a string.
  string SummarizeValue(int64 max_entries) const;

  /// A human-readable summary of the tensor suitable for debugging.
  string DebugString() const;

  /// Fill in the `TensorDescription` proto with metadata about the
  /// tensor that is useful for monitoring and debugging.
  void FillDescription(TensorDescription* description) const;

  /// \brief Returns a `StringPiece` mapping the current tensor's buffer.
  ///
  /// The returned `StringPiece` may point to a memory location on devices
  /// that the CPU cannot address directly.
  ///
  /// NOTE: The underlying tensor buffer is refcounted, so the lifetime
  /// of the contents mapped by the `StringPiece` matches the lifetime of
  /// the buffer; callers should arrange to make sure the buffer does
  /// not get destroyed while the `StringPiece` is still used.
  ///
  /// REQUIRES: `DataTypeCanUseMemcpy(dtype())`.
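  ///
  /// A minimal host-side sketch (valid only for CPU-resident tensors):
  ///
  /// ```c++
  ///
  ///     Tensor t(DT_FLOAT, TensorShape({4}));
  ///     StringPiece bytes = t.tensor_data();
  ///     // bytes.size() == 4 * sizeof(float); copy while `t` is alive.
  ///     std::vector<char> copy(bytes.data(), bytes.data() + bytes.size());
  ///
  /// ```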
  StringPiece tensor_data() const;

  /// Copy the other tensor into this tensor, reshape it and reinterpret the
  /// buffer's datatype.
  ///
  /// This tensor shares other's underlying storage.
  void UnsafeCopyFromInternal(const Tensor&, DataType dtype,
                              const TensorShape&);

 private:
  // Returns true if the refcount on buf_ and any possible underlying root
  // buffer is one.
  bool RefCountIsOne() const;
  void CheckType(DataType expected_dtype) const;
  void CheckTypeAndIsAligned(DataType expected_dtype) const;
  void CheckIsAlignedAndSingleElement() const;
  void set_dtype(DataType t) { shape_.set_data_type(t); }

  // TensorShape's InlineVector.
  static gtl::InlinedVector<int64, 4> ComputeFlatInnerDims(
      gtl::ArraySlice<int64> orig, int64 num_out_dims);
  static gtl::InlinedVector<int64, 4> ComputeFlatOuterDims(
      gtl::ArraySlice<int64> orig, int64 num_out_dims);

  TensorShape shape_;
  TensorBuffer* buf_;

  friend class DMAHelper;
  friend class TensorCApi;
  friend class TensorReference;       // For access to buf_
  friend class VariableOp;            // For access to set_shape
  friend class AutoReloadVariableOp;  // For access to set_shape
  friend class TensorTestHelper;      // For access to set_shape
  friend class OpKernelContext;       // For access to RefCountIsOne().
  friend class ScopedAllocator;       // For access to buf_.
  friend class XlaTensorBuffer;  // For access to the private constructor taking
                                 // the buffer
  template <typename Device, typename T>
  friend class AssignVariableOp;  // For access to RefCountIsOne().
  template <typename Device, typename T>
  friend Status PrepareToUpdateVariable(
      OpKernelContext* ctx, Tensor* tensor);  // For access to RefCountIsOne().
  friend Status batch_util::CopyElementToSlice(
      Tensor element, Tensor* parent,
      int64 index);                  // For access to RefCountIsOne().
  friend class NumpyTensorBuffer;    // For access to the private constructor
                                     // taking the buffer.

  // Creates a tensor with the input datatype, shape and buf.
  //
  // Acquires a ref on buf that belongs to this Tensor.
  Tensor(DataType type, const TensorShape& shape, TensorBuffer* buf);

  bool CanUseDMA() const;

  // Only needed by variable op to set the shape of an uninitialized
  // Tensor.
  // TODO: Remove this when we have a better story for detecting
  // uninitialized tensors.
  void set_shape(const TensorShape& shape) {
    DataType dt = dtype();
    shape_ = shape;
    set_dtype(dt);
  }

  void CopyFromInternal(const Tensor& other, const TensorShape& shape);

  template <typename T>
  T* base() const;

  template <size_t NDIMS>
  void FillDimsAndValidateCompatibleShape(
      gtl::ArraySlice<int64> new_sizes,
      Eigen::array<Eigen::DenseIndex, NDIMS>* dims) const;

  template <typename T, size_t NDIMS>
  void FillDimsAndValidateCompatibleShape(
      gtl::ArraySlice<int64> new_sizes,
      Eigen::array<Eigen::DenseIndex, NDIMS>* dims) const;
};

// Implementation details

// START_SKIP_DOXYGEN

// Interface to access the raw ref-counted data buffer.
class TensorBuffer : public core::RefCounted {
 public:
  ~TensorBuffer() override {}

  // data() points to a memory region of size() bytes.
  virtual void* data() const = 0;
  virtual size_t size() const = 0;

  // If this TensorBuffer is a sub-buffer of another TensorBuffer,
  // returns that TensorBuffer. Otherwise, returns this.
  virtual TensorBuffer* root_buffer() = 0;

  // Fill metadata about the allocation into the proto.
  virtual void FillAllocationDescription(
      AllocationDescription* proto) const = 0;

  template <typename T>
  T* base() const {
    return reinterpret_cast<T*>(data());
  }

  // Whether this TensorBuffer owns the underlying memory.
  virtual bool OwnsMemory() const { return true; }
};

template <typename T>
T* Tensor::base() const {
  return buf_ == nullptr ? nullptr : buf_->base<T>();
}

template <typename T, size_t NDIMS>
typename TTypes<T, NDIMS>::Tensor Tensor::tensor() {
  CheckTypeAndIsAligned(DataTypeToEnum<T>::v());
  return typename TTypes<T, NDIMS>::Tensor(base<T>(),
                                            shape().AsEigenDSizes<NDIMS>());
}

template <typename T, size_t NDIMS>
typename TTypes<T, NDIMS>::ConstTensor Tensor::tensor() const {
  CheckTypeAndIsAligned(DataTypeToEnum<T>::v());
  return typename TTypes<T, NDIMS>::ConstTensor(base<const T>(),
                                                shape().AsEigenDSizes<NDIMS>());
}

template <typename T, size_t NDIMS>
typename TTypes<T, NDIMS>::Tensor Tensor::bit_casted_tensor() {
  CHECK(IsAligned());
  return typename TTypes<T, NDIMS>::Tensor(base<T>(),
                                            shape().AsEigenDSizes<NDIMS>());
}

template <typename T, size_t NDIMS>
typename TTypes<T, NDIMS>::ConstTensor Tensor::bit_casted_tensor() const {
  CHECK(IsAligned());
  return typename TTypes<T, NDIMS>::ConstTensor(base<const T>(),
                                                shape().AsEigenDSizes<NDIMS>());
}

template <typename T, size_t NDIMS>
typename TTypes<T, NDIMS>::Tensor Tensor::reinterpret_last_dimension() {
  if (NDIMS == dims()) {
    return tensor<T, NDIMS>();
  }
  CHECK(IsAligned());
  CHECK_EQ(NDIMS, dims() - 1);
  CHECK_EQ(sizeof(T), shape_.dim_sizes()[NDIMS] * DataTypeSize(dtype()));
  Eigen::array<Eigen::DenseIndex, NDIMS> dims;
  for (int d = 0; d < NDIMS; ++d) {
    dims[d] = shape_.dim_sizes()[d];
  }
  return typename TTypes<T, NDIMS>::Tensor(base<T>(), dims);
}

template <typename T, size_t NDIMS>
typename TTypes<T, NDIMS>::ConstTensor Tensor::reinterpret_last_dimension()
    const {
  if (NDIMS == dims()) {
    return tensor<T, NDIMS>();
  }
  CHECK(IsAligned());
  CHECK_EQ(NDIMS, dims() - 1);
  CHECK_EQ(sizeof(T), shape_.dim_sizes()[NDIMS] * DataTypeSize(dtype()));
  Eigen::array<Eigen::DenseIndex, NDIMS> dims;
  for (int d = 0; d < NDIMS; ++d) {
    dims[d] = shape_.dim_sizes()[d];
  }
  return typename TTypes<T, NDIMS>::ConstTensor(base<const T>(), dims);
}

template <size_t NDIMS>
void Tensor::FillDimsAndValidateCompatibleShape(
    gtl::ArraySlice<int64> new_sizes,
    Eigen::array<Eigen::DenseIndex, NDIMS>* dims) const {
  CHECK_EQ(NDIMS, new_sizes.size());
  int64 new_num_elements = 1;
  for (size_t d = 0; d < NDIMS; d++) {
    new_num_elements *= new_sizes[d];
    (*dims)[d] = new_sizes[d];
  }
  CHECK_EQ(new_num_elements, NumElements());
}

template <typename T, size_t NDIMS>
void Tensor::FillDimsAndValidateCompatibleShape(
    gtl::ArraySlice<int64> new_sizes,
    Eigen::array<Eigen::DenseIndex, NDIMS>* dims) const {
  CHECK_EQ(NDIMS, new_sizes.size());
  int64 new_num_elements = 1;
  for (size_t d = 0; d < NDIMS; d++) {
    new_num_elements *= new_sizes[d];
    (*dims)[d] = new_sizes[d];
  }
  const int element_size = DataTypeSize(BaseType(dtype()));
  if (element_size > 0) {
    CHECK_EQ(new_num_elements * sizeof(T), NumElements() * element_size);
  } else {
    // DataTypeSize() returns 0 for some data types. In this case, assume that
    // T has the same size as the buffer type.
    // NOTE: If we can be sure that DataTypeSize() does not return 0 for all
    // POD types, then we should check DataTypeToEnum<T>::v() == dtype(). Or
    // simply check if `element_size > 0` to err when a bit cast is attempted
    // on a Tensor of unknown data type size.
    CHECK_EQ(new_num_elements, NumElements());
  }
}

template <typename T, size_t NDIMS>
typename TTypes<T, NDIMS>::Tensor Tensor::shaped(
    gtl::ArraySlice<int64> new_sizes) {
  CheckTypeAndIsAligned(DataTypeToEnum<T>::v());
  Eigen::array<Eigen::DenseIndex, NDIMS> dims;
  FillDimsAndValidateCompatibleShape(new_sizes, &dims);
  return typename TTypes<T, NDIMS>::Tensor(base<T>(), dims);
}

template <typename T, size_t NDIMS>
typename TTypes<T, NDIMS>::Tensor Tensor::bit_casted_shaped(
    gtl::ArraySlice<int64> new_sizes) {
  CHECK(IsAligned());
  Eigen::array<Eigen::DenseIndex, NDIMS> dims;
  FillDimsAndValidateCompatibleShape<T>(new_sizes, &dims);
  return typename TTypes<T, NDIMS>::Tensor(base<T>(), dims);
}

template <typename T, size_t NDIMS>
typename TTypes<T, NDIMS>::UnalignedTensor Tensor::unaligned_shaped(
    gtl::ArraySlice<int64> new_sizes) {
  CheckType(DataTypeToEnum<T>::v());
  Eigen::array<Eigen::DenseIndex, NDIMS> dims;
  FillDimsAndValidateCompatibleShape(new_sizes, &dims);
  return typename TTypes<T, NDIMS>::UnalignedTensor(base<T>(), dims);
}

template <typename T, size_t NDIMS>
typename TTypes<T, NDIMS>::ConstTensor Tensor::shaped(
    gtl::ArraySlice<int64> new_sizes) const {
  CheckType(DataTypeToEnum<T>::v());
  CHECK(IsAligned());
  Eigen::array<Eigen::DenseIndex, NDIMS> dims;
  FillDimsAndValidateCompatibleShape(new_sizes, &dims);
  return typename TTypes<T, NDIMS>::ConstTensor(base<T>(), dims);
}

template <typename T, size_t NDIMS>
typename TTypes<T, NDIMS>::ConstTensor Tensor::bit_casted_shaped(
    gtl::ArraySlice<int64> new_sizes) const {
  CHECK(IsAligned());
  Eigen::array<Eigen::DenseIndex, NDIMS> dims;
  FillDimsAndValidateCompatibleShape<T>(new_sizes, &dims);
  return typename TTypes<T, NDIMS>::ConstTensor(base<T>(), dims);
}

template <typename T, size_t NDIMS>
typename TTypes<T, NDIMS>::UnalignedConstTensor Tensor::unaligned_shaped(
    gtl::ArraySlice<int64> new_sizes) const {
  CheckType(DataTypeToEnum<T>::v());
  Eigen::array<Eigen::DenseIndex, NDIMS> dims;
  FillDimsAndValidateCompatibleShape(new_sizes, &dims);
  return typename TTypes<T, NDIMS>::UnalignedConstTensor(base<T>(), dims);
}

template <typename T>
typename TTypes<T>::Scalar Tensor::scalar() {
  CheckIsAlignedAndSingleElement();
  return typename TTypes<T>::Scalar(base<T>());
}

template <typename T>
typename TTypes<T>::ConstScalar Tensor::scalar() const {
  CheckIsAlignedAndSingleElement();
  return typename TTypes<T>::ConstScalar(base<T>());
}

template <typename T, size_t NDIMS>
typename TTypes<T, NDIMS>::Tensor Tensor::flat_inner_dims() {
  return shaped<T, NDIMS>(ComputeFlatInnerDims(shape_.dim_sizes(), NDIMS));
}

template <typename T, size_t NDIMS>
typename TTypes<T, NDIMS>::Tensor Tensor::flat_outer_dims() {
  return shaped<T, NDIMS>(ComputeFlatOuterDims(shape_.dim_sizes(), NDIMS));
}

template <typename T, size_t NDIMS>
typename TTypes<T, NDIMS>::Tensor Tensor::flat_inner_outer_dims(int64 begin) {
  gtl::InlinedVector<int64, 4> flat_outer =
      ComputeFlatOuterDims(shape_.dim_sizes(), begin + NDIMS);
  return shaped<T, NDIMS>(ComputeFlatInnerDims(flat_outer, NDIMS));
}

template <typename T, size_t NDIMS>
typename TTypes<T, NDIMS>::ConstTensor Tensor::flat_inner_dims() const {
  return shaped<T, NDIMS>(ComputeFlatInnerDims(shape_.dim_sizes(), NDIMS));
}

template <typename T, size_t NDIMS>
typename TTypes<T, NDIMS>::ConstTensor Tensor::flat_outer_dims() const {
  return shaped<T, NDIMS>(ComputeFlatOuterDims(shape_.dim_sizes(), NDIMS));
}

template <typename T, size_t NDIMS>
typename TTypes<T, NDIMS>::ConstTensor Tensor::flat_inner_outer_dims(
    int64 begin) const {
  gtl::InlinedVector<int64, 4> flat_outer =
      ComputeFlatOuterDims(shape_.dim_sizes(), begin + NDIMS);
  return shaped<T, NDIMS>(ComputeFlatInnerDims(flat_outer, NDIMS));
}

inline Tensor::Tensor(const Tensor& other)
    : shape_(other.shape()), buf_(other.buf_) {
  if (buf_) buf_->Ref();
}

inline Tensor::Tensor(Tensor&& other)
    : shape_(std::move(other.shape())), buf_(other.buf_) {
  other.buf_ = nullptr;
}

inline Tensor& Tensor::operator=(Tensor&& other) {
  // Avoid self-assignment, since we might destroy our underlying buffer.
  if (&other != this) {
    shape_ = std::move(other.shape_);
    if (buf_) buf_->Unref();
    buf_ = other.buf_;
    other.buf_ = nullptr;
  }
  return *this;
}

// END_SKIP_DOXYGEN

}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_FRAMEWORK_TENSOR_H_