1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #ifndef TENSORFLOW_CORE_FRAMEWORK_TENSOR_SHAPE_H_ |
17 | #define TENSORFLOW_CORE_FRAMEWORK_TENSOR_SHAPE_H_ |
18 | |
19 | #include <string> |
20 | |
21 | #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" |
22 | #include "tensorflow/core/framework/types.pb.h" |
23 | #include "tensorflow/core/lib/core/errors.h" |
24 | #include "tensorflow/core/lib/core/status.h" |
25 | #include "tensorflow/core/lib/core/stringpiece.h" |
26 | #include "tensorflow/core/lib/gtl/array_slice.h" |
27 | #include "tensorflow/core/lib/gtl/inlined_vector.h" |
28 | #include "tensorflow/core/lib/strings/str_util.h" |
29 | #include "tensorflow/core/platform/logging.h" |
30 | |
31 | namespace tensorflow { |
32 | |
33 | // START_SKIP_DOXYGEN |
34 | template <class Shape> |
35 | class TensorShapeIter; |
36 | class TensorShape; |
37 | class TensorShapeProto; |
38 | class PartialTensorShape; |
39 | // END_SKIP_DOXYGEN |
40 | |
/// Internal representation for both TensorShape and PartialTensorShape.
class TensorShapeRep {
 public:
  ~TensorShapeRep();

  /// Copy the specified shape.
  TensorShapeRep(const TensorShapeRep& b);
  void operator=(const TensorShapeRep& b);

  /// Move the specified shape. After moving, `b` is safe for destruction and
  /// can be reassigned into, but its dimensions and number of elements can be
  /// nonsensical (e.g., negative dimension sizes, or number of elements not
  /// properly recomputed).
  TensorShapeRep(TensorShapeRep&& b);
  void operator=(TensorShapeRep&& b);

  /// Clear a tensor shape, producing the scalar shape.
  void Clear();

  // Maximum number of dimensions in a tensor.
  // It's 254 because 255 = kUnknownRank is used to represent unknown rank.
  static constexpr int MaxDimensions() { return 254; }

  /// \brief Returns the number of elements in the tensor.
  ///
  /// We use `int64` and not `size_t` to be compatible with `Eigen::Tensor`
  /// which uses `ptrdiff_t`. For PartialTensorShape, -1 means not fully
  /// defined.
  int64 num_elements() const { return num_elements_; }

  /// For error messages.
  string DebugString() const;
  static string DebugString(const TensorShapeProto& proto);

  // Debug-only: dumps the internal representation (defined in the .cc file).
  void DumpRep() const;

 protected:
  // Constructable only via TensorShapeBase.
  TensorShapeRep() = default;

  // By its name, resets the shape while preserving the data-type byte stored
  // at buf()[13] (see set_data_type below); implementation in the .cc file.
  void ClearAllButDataType();

  // We use 16 bytes to represent a TensorShape. Because we need to
  // be able to support full 64-bit dimension sizes and an arbitrary
  // number of dimensions for a Tensor, but most tensor dimensions are
  // significantly smaller than 64 bits and most tensors are 1, 2, or 3
  // dimensions, we have several representations.
  // Rep16: Supports up to 6 dimensions where each dimension is < 2^16 - 1
  // Rep32: Supports up to 3 dimensions where each dimension is < 2^32 - 1
  // Rep64: Supports arbitrary dimensionality, 64-bit dimensions using
  //        an out of line vector.
  // For PartialTensorShape, a dimension equal to the max value of the
  // unsigned type (kUnknownRep16/kUnknownRep32 below) means "unknown". This
  // value is not allowed in TensorShape either for format compatibility.
  struct Rep16 {
    uint16 dims_[6];
  };
  struct Rep32 {
    uint32 dims_[3];
  };
  struct Rep64 {
    // Heap-allocated dimension vector, released via DestructorOutOfLine().
    gtl::InlinedVector<int64, 4>* dims_;
  };

  // We use the max value of uint16 or uint32 to represent unknown shapes, so
  // the maximum representable valid shape in these representations is one less.
  static const int64 kMaxRep16 = std::numeric_limits<uint16>::max() - 1;
  static const int64 kMaxRep32 = std::numeric_limits<uint32>::max() - 1;
  static const uint16 kUnknownRep16 = std::numeric_limits<uint16>::max();
  static const uint32 kUnknownRep32 = std::numeric_limits<uint32>::max();

  // Reinterpret the 16-byte buffer as one of the representations above.
  // Which accessor is valid is determined by tag().
  Rep16* as16() { return reinterpret_cast<Rep16*>(buf()); }
  Rep32* as32() { return reinterpret_cast<Rep32*>(buf()); }
  Rep64* as64() { return reinterpret_cast<Rep64*>(buf()); }

  const Rep16* as16() const { return reinterpret_cast<const Rep16*>(buf()); }
  const Rep32* as32() const { return reinterpret_cast<const Rep32*>(buf()); }
  const Rep64* as64() const { return reinterpret_cast<const Rep64*>(buf()); }

  // Identifies which representation currently occupies buf(); stored in
  // byte 15 (see tag()/set_tag below).
  enum RepTag { REP16 = 0, REP32 = 1, REP_OUT_OF_LINE = 2 };

  // Since we have a convenient extra byte available, we allow the
  // Tensor class to store an 8-bit value in this extra storage. This
  // allows it to store the Tensor's datatype enum value here and avoid
  // an extra word of storage.
  friend class Tensor;
  friend class TensorShapeTestHelper;
  DataType data_type() const { return static_cast<DataType>(buf()[13]); }
  void set_data_type(DataType dt) {
    // We only have 8 bits available to store DataType, so make sure it fits.
    DCHECK_LT(static_cast<uint32>(dt), 256u);
    buf()[13] = static_cast<uint8>(dt);
  }

  // We store the number of dimensions in byte 14, and the RepTag in byte 15.
  // Bytes [0..13] vary depending on the representation.
  // A value of 255 indicates unknown rank in the PartialTensorShape case.
  static const uint8 kUnknownRank = 255;
  uint8 ndims_byte() const { return buf()[14]; }
  void set_ndims_byte(uint8 nd) { buf()[14] = nd; }

  RepTag tag() const { return static_cast<RepTag>(buf()[15]); }
  void set_tag(RepTag tag) { buf()[15] = static_cast<uint8>(tag); }

  // Overwrites the cached element count; callers are responsible for keeping
  // it consistent with the stored dimensions.
  void set_num_elements(int64 n) { num_elements_ = n; }

 private:
  // Frees the out-of-line vector; invoked by the destructor (and assignment)
  // when tag() == REP_OUT_OF_LINE.
  void DestructorOutOfLine();
  // Copy helper used when either side holds an out-of-line representation.
  void SlowCopyFrom(const TensorShapeRep& b);

  uint8* buf() { return &u_.buf[0]; }
  const uint8* buf() const { return &u_.buf[0]; }

  union {
    uint8 buf[16];
    // Force data to be aligned enough for a pointer.
    Rep64* unused_aligner;
  } u_;
  // Cached product of the dimension sizes; -1 for a not-fully-defined
  // PartialTensorShape (see num_elements() above).
  int64 num_elements_;
};
160 | |
/// Base class for TensorShape and PartialTensorShape.
/// The class is templatized by either TensorShape or PartialTensorShape to
/// allow skipping known/unknown checks in the TensorShape case, but the
/// representation is shared exactly for fast conversion.
template <class Shape>
class TensorShapeBase : public TensorShapeRep {
 public:
  /// \brief Construct a `TensorShapeBase` from the provided sizes.
  /// REQUIRES: `dim_sizes[i] >= 0` (or >= -1 for PartialTensorShape)
  explicit TensorShapeBase(gtl::ArraySlice<int64> dim_sizes);
  TensorShapeBase(std::initializer_list<int64> dim_sizes)
      : TensorShapeBase(gtl::ArraySlice<int64>(dim_sizes)) {}

  /// Construct an empty TensorShape, or an unknown rank PartialTensorShape
  TensorShapeBase();

  /// Construct a shape from the given `TensorShapeProto` message.
  TensorShapeBase(const TensorShapeProto& proto);

  /// Returns `true` iff `proto` is a valid tensor shape.
  // For TensorShape, the proto shape must be fully defined.
  static bool IsValid(const TensorShapeProto& proto);

  /// Returns `OK` iff `proto` is a valid tensor shape, and a descriptive error
  /// status otherwise.
  static Status IsValidShape(const TensorShapeProto& proto);

  /// \brief Add a dimension to the end ("inner-most").
  /// REQUIRES: `size >= 0`
  void AddDim(int64 size);

  /// Appends all the dimensions from `shape`.
  void AppendShape(const TensorShapeBase& shape);

  /// \brief Insert a dimension somewhere in the `TensorShape`.
  /// REQUIRES: `0 <= d <= dims()`
  /// REQUIRES: `size >= 0`
  void InsertDim(int d, int64 size);

  /// \brief Modifies the size of the dimension `d` to be `size`
  /// REQUIRES: `0 <= d < dims()`
  /// REQUIRES: `size >= 0`
  void set_dim(int d, int64 size);

  /// \brief Removes dimension `d` from the `TensorShape`.
  /// REQUIRES: `0 <= d < dims()`
  void RemoveDim(int d) {
    CHECK_GE(d, 0);
    RemoveDimRange(d, d + 1);
  }

  /// \brief Removes last `n` dimensions from the `TensorShape`.
  /// REQUIRES: `0 <= n <= dims()`
  void RemoveLastDims(int n) {
    CHECK_LE(n, dims());
    RemoveDimRange(dims() - n, dims());
  }

  /// \brief Removes the dimensions in range `[begin:end)` from `TensorShape`.
  /// Negative values of `end` are interpreted as `dims() + end + 1` (as in
  /// Python). The same is true for negative values of `begin`.
  /// REQUIRES: `-(dims()+1) <= begin <= dims()`
  /// REQUIRES: `-(dims()+1) <= end <= dims()`
  void RemoveDimRange(int begin, int end);

  /// Return whether the rank is unknown
  bool unknown_rank() const {
    // kIsPartial is a compile-time constant, so for TensorShape this whole
    // expression folds to `false`.
    return kIsPartial && ndims_byte() == kUnknownRank;
  }

  /// Return the number of dimensions in the tensor.
  /// Can be -1 meaning unknown rank for PartialTensorShape.
  int dims() const {
    uint8 dims = ndims_byte();
    return kIsPartial && dims == kUnknownRank ? -1 : dims;
  }

  /// \brief Returns the number of elements in dimension `d`.
  /// REQUIRES: `0 <= d < dims()`
  // TODO(touts): Rename to `dimension()` to match
  // `Eigen::Tensor::dimension()`?
  int64 dim_size(int d) const;

  /// Returns sizes of all dimensions.
  // Returns an empty list for unknown rank PartialTensorShape.
  gtl::InlinedVector<int64, 4> dim_sizes() const;

  /// Return true iff the rank and all of the dimensions are well defined
  // TODO(irving): Rename to is_fully_defined now that it's fast.
  bool IsFullyDefined() const { return !kIsPartial || num_elements() != -1; }

  /// Fill `*proto` from `*this`.
  void AsProto(TensorShapeProto* proto) const;

  /// For iterating through the dimensions.
  TensorShapeIter<Shape> begin() const;
  TensorShapeIter<Shape> end() const;

 private:
  // Recomputes the cached num_elements_ from the current dimensions.
  void RecomputeNumElements();

  // True for PartialTensorShape, false for TensorShape
  static constexpr bool kIsPartial =
      std::is_same<Shape, PartialTensorShape>::value;
  static_assert(kIsPartial || std::is_same<Shape, TensorShape>::value,
                "Shape is neither TensorShape nor PartialTensorShape" );

  // Used by AddDim and MakeShapeHelper. Does no error checking.
  void UnsafeAddDim(int64 size, int64 new_num_elements);

  // For use by TensorShapeUtils::MakeShape
  template <class T, class S>
  friend Status MakeShapeHelper(const T*, int64, S*);
};
273 | |
274 | /// Outputs `TensorShapeBase` to `std::ostream`. |
275 | template <typename Shape> |
276 | std::ostream& operator<<(std::ostream& os, const TensorShapeBase<Shape>& tsb) { |
277 | return os << tsb.DebugString(); |
278 | } |
279 | |
/// Represents the shape of a Tensor.
///
/// A tensor's shape is denoted by its number of dimensions and a size for each
/// dimension. For example, a Tensor represented by a 3 x 4 matrix would have
/// a shape of 2-D, [3,4].
///
/// If you know the exact shape of your Tensor when you create the TensorShape
/// object, you can specify it then, or you can create a TensorShape with
/// zero dimensions and one element, and call AddDim() to add dimensions later.
class TensorShape : public TensorShapeBase<TensorShape> {
 public:
  // Inherit all constructors from the shared base.
  using TensorShapeBase<TensorShape>::TensorShapeBase;

  /// Allow a TensorShape to be used as a PartialTensorShape without copying
  /// (the two share TensorShapeRep's representation; see the inline
  /// definition below).
  operator const PartialTensorShape&() const;  // NOLINT(runtime/explicit)

  /// Returns true if `*this` and `b` have the same sizes. Ignores
  /// dimension names.
  bool IsSameSize(const TensorShape& b) const;
  bool operator==(const TensorShape& b) const { return IsSameSize(b); }
  bool operator!=(const TensorShape& b) const { return !IsSameSize(b); }

  /// Fill `*dsizes` from `*this`.
  /// REQUIRES: `dims() == NDIMS` (checked via CheckDimsEqual).
  template <int NDIMS>
  Eigen::DSizes<Eigen::DenseIndex, NDIMS> AsEigenDSizes() const;

  /// Same as `AsEigenDSizes()` but allows for `NDIMS > dims()` -- in
  /// which case we pad the rest of the sizes with 1.
  template <int NDIMS>
  Eigen::DSizes<Eigen::DenseIndex, NDIMS> AsEigenDSizesWithPadding() const;

 private:
  // These CHECK fail to ease debugging.
  // REQUIRES: dims() == NDIMS
  void CheckDimsEqual(int NDIMS) const;
  // REQUIRES: dims() >= NDIMS
  void CheckDimsAtLeast(int NDIMS) const;
};
318 | |
319 | /// Represents the value of one dimension in a TensorShape. |
320 | struct TensorShapeDim { |
321 | explicit TensorShapeDim(int64 s) : size(s) {} |
322 | int64 size; |
323 | }; |
324 | |
325 | // START_SKIP_DOXYGEN |
326 | template <class Shape> |
327 | class TensorShapeIter { |
328 | public: |
329 | TensorShapeIter(const Shape* shape, int d) : shape_(shape), d_(d) {} |
330 | bool operator==(const TensorShapeIter& rhs) { |
331 | DCHECK(shape_ == rhs.shape_); |
332 | return d_ == rhs.d_; |
333 | } |
334 | bool operator!=(const TensorShapeIter& rhs) { |
335 | DCHECK(shape_ == rhs.shape_); |
336 | return d_ != rhs.d_; |
337 | } |
338 | void operator++() { ++d_; } |
339 | TensorShapeDim operator*() { return TensorShapeDim(shape_->dim_size(d_)); } |
340 | |
341 | private: |
342 | const Shape* shape_; |
343 | int d_; |
344 | }; |
345 | // END_SKIP_DOXYGEN |
346 | |
347 | /// \brief Static helper routines for `TensorShape`. Includes a few common |
348 | /// predicates on a tensor shape. |
349 | class TensorShapeUtils { |
350 | public: |
351 | static bool IsScalar(const TensorShape& shape) { return shape.dims() == 0; } |
352 | |
353 | static bool IsVector(const TensorShape& shape) { return shape.dims() == 1; } |
354 | |
355 | static bool IsVectorOrHigher(const TensorShape& shape) { |
356 | return shape.dims() >= 1; |
357 | } |
358 | |
359 | static bool IsMatrix(const TensorShape& shape) { return shape.dims() == 2; } |
360 | |
361 | static bool IsSquareMatrix(const TensorShape& shape) { |
362 | return shape.dims() == 2 && shape.dim_size(0) == shape.dim_size(1); |
363 | } |
364 | |
365 | static bool IsMatrixOrHigher(const TensorShape& shape) { |
366 | return shape.dims() >= 2; |
367 | } |
368 | |
369 | /// \brief Returns a `TensorShape` whose dimensions are |
370 | /// `dims[0]`, `dims[1]`, ..., `dims[n-1]`. |
371 | static Status MakeShape(const int32* dims, int64 n, TensorShape* out); |
372 | static Status MakeShape(const int64* dims, int64 n, TensorShape* out); |
373 | static Status MakeShape(gtl::ArraySlice<int32> shape, TensorShape* out); |
374 | static Status MakeShape(gtl::ArraySlice<int64> shape, TensorShape* out); |
375 | static Status MakeShape(const int32* dims, int64 n, PartialTensorShape* out); |
376 | static Status MakeShape(const int64* dims, int64 n, PartialTensorShape* out); |
377 | static Status MakeShape(gtl::ArraySlice<int32> shape, |
378 | PartialTensorShape* out); |
379 | static Status MakeShape(gtl::ArraySlice<int64> shape, |
380 | PartialTensorShape* out); |
381 | |
382 | static string ShapeListString(const gtl::ArraySlice<TensorShape>& shapes); |
383 | |
384 | /// \brief Returns true iff `shape` starts with `prefix`. |
385 | static bool StartsWith(const TensorShape& shape, const TensorShape& prefix); |
386 | |
387 | /// \brief Returns true iff `shape` ends with `suffix`. |
388 | static bool EndsWith(const TensorShape& shape, const TensorShape& suffix); |
389 | |
390 | /// \brief Returns the product of values in an int64 array, |
391 | /// or a failing Status if the array represents a value larger than |
392 | /// a `TensorShape` can hold. |
393 | static Status NumElements(gtl::ArraySlice<int64> shape, int64* num_elements); |
394 | }; |
395 | |
/// Manages the partially known dimensions of a Tensor and their sizes.
class PartialTensorShape : public TensorShapeBase<PartialTensorShape> {
 public:
  // Default-constructs via TensorShapeBase(), which for PartialTensorShape
  // produces an unknown-rank shape (see the base-class constructor docs).
  PartialTensorShape() {}
  using TensorShapeBase<PartialTensorShape>::TensorShapeBase;

  /// Add a dimension to the end ("inner-most"), returns a new
  /// PartialTensorShape.
  /// REQUIRES: `size >= -1`, where -1 means unknown
  PartialTensorShape Concatenate(int64 size) const;

  /// Appends all the dimensions from `shape`. Returns a new
  /// PartialTensorShape.
  PartialTensorShape Concatenate(const PartialTensorShape& shape) const;

  /// Merges all the dimensions from `shape`. Returns
  /// `InvalidArgument` error if either `shape` has a different rank
  /// or if any of the dimensions are incompatible.
  Status MergeWith(const PartialTensorShape& shape,
                   PartialTensorShape* result) const;

  /// Exact equality test. Returns true iff the ranks match (i.e., both are
  /// unknown, or both are known and equal), and all dimensions are equal
  /// (i.e., both are unknown, or both are known and equal). This is a
  /// stronger condition than IsCompatibleWith.
  bool IsIdenticalTo(const PartialTensorShape& shape) const;

  /// Return true iff the ranks match, and if the
  /// dimensions all either match or one is unknown.
  bool IsCompatibleWith(const PartialTensorShape& shape) const;

  // Fill `*shape` from `*this`.
  // If `*this` is not fully defined, returns false and
  // `*shape` is left in an intermediate state. Otherwise
  // returns true.
  bool AsTensorShape(TensorShape* shape) const;

  /// \brief Returns a `PartialTensorShape` whose dimensions are
  /// `dims[0]`, `dims[1]`, ..., `dims[n-1]`. Values of -1 are
  /// considered "unknown".
  template <class T>
  static Status MakePartialShape(const T* dims, int n,
                                 PartialTensorShape* out) {
    // Delegates to the shared MakeShape overload set.
    return TensorShapeUtils::MakeShape(dims, n, out);
  }
};
442 | |
/// \brief Static helper routines for `PartialTensorShape`. Includes a few
/// common predicates on a partially known tensor shape.
class PartialTensorShapeUtils {
 public:
  // Formats a list of partial shapes as a single string (e.g. for error
  // messages); implementation in the .cc file.
  static string PartialShapeListString(
      const gtl::ArraySlice<PartialTensorShape>& shapes);

  // NOTE(review): by their names these presumably apply IsIdenticalTo /
  // IsCompatibleWith pairwise across the two lists — confirm in the .cc.
  static bool AreIdentical(const gtl::ArraySlice<PartialTensorShape>& shapes0,
                           const gtl::ArraySlice<PartialTensorShape>& shapes1);

  static bool AreCompatible(const gtl::ArraySlice<PartialTensorShape>& shapes0,
                            const gtl::ArraySlice<PartialTensorShape>& shapes1);
};
456 | |
457 | // ---------------------------------------------------------------------------- |
458 | // Template method implementation details below |
459 | // ---------------------------------------------------------------------------- |
460 | |
// Converts this shape to an Eigen::DSizes with exactly NDIMS entries.
// CHECK-fails (via CheckDimsEqual) unless dims() == NDIMS, in which case the
// padding performed by AsEigenDSizesWithPadding is a no-op.
template <int NDIMS>
Eigen::DSizes<Eigen::DenseIndex, NDIMS> TensorShape::AsEigenDSizes() const {
  CheckDimsEqual(NDIMS);
  return AsEigenDSizesWithPadding<NDIMS>();
}
466 | |
467 | template <int NDIMS> |
468 | Eigen::DSizes<Eigen::DenseIndex, NDIMS> TensorShape::AsEigenDSizesWithPadding() |
469 | const { |
470 | CheckDimsAtLeast(NDIMS); |
471 | static_assert(NDIMS <= TensorShape::MaxDimensions(), "Too many dimensions" ); |
472 | Eigen::DSizes<Eigen::DenseIndex, NDIMS> dsizes; |
473 | for (int d = 0; d < dims(); d++) { |
474 | dsizes[d] = dim_size(d); |
475 | } |
476 | for (int d = dims(); d < NDIMS; d++) { |
477 | dsizes[d] = 1; |
478 | } |
479 | return dsizes; |
480 | } |
481 | |
482 | // ---------------------------------------------------------------------------- |
483 | // Inlining of some performance critical routines |
484 | // ---------------------------------------------------------------------------- |
485 | |
486 | inline TensorShapeRep::TensorShapeRep(const TensorShapeRep& b) { |
487 | num_elements_ = b.num_elements_; |
488 | if (b.tag() != REP_OUT_OF_LINE) { |
489 | memcpy(buf(), b.buf(), sizeof(u_.buf)); |
490 | // memcpy above Implicitly does: |
491 | // set_ndims_byte(b.ndims_byte()); |
492 | // set_tag(b.tag()); |
493 | } else { |
494 | set_tag(REP16); // So that SlowCopyFrom does not try to deallocate |
495 | SlowCopyFrom(b); |
496 | } |
497 | } |
498 | |
499 | inline TensorShapeRep::TensorShapeRep(TensorShapeRep&& b) { |
500 | num_elements_ = b.num_elements_; |
501 | memcpy(buf(), b.buf(), sizeof(u_.buf)); |
502 | // memcpy above Implicitly does: |
503 | // set_ndims_byte(b.ndims_byte()); |
504 | // set_tag(b.tag()); |
505 | b.set_tag(REP16); // other shape no longer owns out-of-line data, if any. |
506 | } |
507 | |
508 | inline TensorShapeRep::~TensorShapeRep() { |
509 | if (tag() == REP_OUT_OF_LINE) { |
510 | DestructorOutOfLine(); |
511 | } |
512 | } |
513 | |
514 | inline void TensorShapeRep::operator=(const TensorShapeRep& b) { |
515 | num_elements_ = b.num_elements_; |
516 | if (tag() != REP_OUT_OF_LINE && b.tag() != REP_OUT_OF_LINE) { |
517 | memcpy(buf(), b.buf(), sizeof(u_.buf)); |
518 | // memcpy above implicitly also does: |
519 | // set_tag(b.tag()); |
520 | // set_ndims_byte(b.ndims_byte()); |
521 | } else { |
522 | SlowCopyFrom(b); |
523 | } |
524 | } |
525 | |
526 | inline void TensorShapeRep::operator=(TensorShapeRep&& b) { |
527 | if (tag() == REP_OUT_OF_LINE) { |
528 | DestructorOutOfLine(); |
529 | } |
530 | num_elements_ = b.num_elements_; |
531 | memcpy(buf(), b.buf(), sizeof(u_.buf)); |
532 | // memcpy above Implicitly does: |
533 | // set_ndims_byte(b.ndims_byte()); |
534 | // set_tag(b.tag()); |
535 | b.set_tag(REP16); // other shape no longer owns out-of-line data, if any. |
536 | } |
537 | |
538 | inline TensorShape::operator const PartialTensorShape&() const { |
539 | // Downcast to the shared representation and upcast to PartialTensorShape |
540 | const TensorShapeRep* rep = this; |
541 | return *static_cast<const PartialTensorShape*>(rep); |
542 | } |
543 | |
544 | // Declare explicit instantiations in .cc file |
545 | extern template class TensorShapeBase<TensorShape>; |
546 | extern template class TensorShapeBase<PartialTensorShape>; |
547 | |
548 | } // namespace tensorflow |
549 | |
550 | #endif // TENSORFLOW_CORE_FRAMEWORK_TENSOR_SHAPE_H_ |
551 | |