/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/framework/tensor_shape.h"

#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/kernels/bounds_check.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/util/overflow.h"

namespace tensorflow {

// TensorShape and PartialTensorShape should have no fields beyond
// TensorShapeRep. In particular, their sizes should be the same.
static_assert(sizeof(TensorShapeRep) == sizeof(TensorShape),
              "TensorShape must have no fields beyond TensorShapeRep");
static_assert(sizeof(TensorShapeRep) == sizeof(PartialTensorShape),
              "PartialTensorShape must have no fields beyond TensorShapeRep");

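// Appends the dimension sizes of `s` to `*vals`. Used below whenever a shape
// has to be re-materialized as a flat list of int64 sizes before its internal
// representation is changed.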
template <class Shape>
static void AppendTo(const TensorShapeBase<Shape>& s,
                     gtl::InlinedVector<int64, 8>* vals) {
  for (auto dim : s) {
    vals->push_back(dim.size);
  }
}

void TensorShape::CheckDimsEqual(int NDIMS) const {
  CHECK_EQ(NDIMS, dims()) << "Asking for tensor of " << NDIMS << " dimensions"
                          << " from a tensor of " << dims() << " dimensions";
}

void TensorShape::CheckDimsAtLeast(int NDIMS) const {
  CHECK_GE(NDIMS, dims()) << "Asking for tensor of at least " << NDIMS
                          << " dimensions from a tensor of " << dims()
                          << " dimensions";
}

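// Returns true iff `proto` describes a shape this class can represent: the
// number of dimensions does not exceed MaxDimensions(), every dimension size
// is non-negative (or -1 for "unknown" when the shape is partial), and the
// total element count does not overflow int64. IsValidShape() below performs
// the same checks but reports failures as a Status with a readable message.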
template <class Shape>
bool TensorShapeBase<Shape>::IsValid(const TensorShapeProto& proto) {
  // NOTE(irving): Unfortunately, TensorShape allows parsing protos with
  // unknown_shape() set, and it seems hard to remove this without backwards
  // compatibility issues.
  if (kIsPartial && proto.unknown_rank()) return proto.dim_size() == 0;
  int64 num_elements = 1;
  if (proto.dim().size() > MaxDimensions()) return false;
  for (const auto& d : proto.dim()) {
    if (d.size() < (kIsPartial ? -1 : 0)) return false;
    if (d.size() == -1) {
      num_elements = -1;
    } else if (!kIsPartial || num_elements >= 0) {
      num_elements = MultiplyWithoutOverflow(num_elements, d.size());
      if (num_elements < 0) return false;
    }
  }
  return true;
}

template <class Shape>
Status TensorShapeBase<Shape>::IsValidShape(const TensorShapeProto& proto) {
  // NOTE(irving): Unfortunately, TensorShape allows parsing protos with
  // unknown_shape() set, and it seems hard to remove this without backwards
  // compatibility issues.
  if (kIsPartial && proto.unknown_rank()) {
    if (proto.dim_size() > 0) {
      return errors::InvalidArgument(
          "An unknown shape must not have any dimensions set.");
    }
    return Status::OK();
  }
  int64 num_elements = 1;
  if (proto.dim().size() > MaxDimensions()) {
    return errors::InvalidArgument("Shape ", DebugString(proto),
                                   " has too many dimensions");
  }
  for (const auto& d : proto.dim()) {
    if (d.size() < (kIsPartial ? -1 : 0)) {
      if (kIsPartial) {
        return errors::InvalidArgument(
            "Shape ", DebugString(proto),
            " has dimensions with values below -1 (where -1 means unknown)");
      } else {
        return errors::InvalidArgument("Shape ", DebugString(proto),
                                       " is not fully defined");
      }
    }
    if (d.size() == -1) {
      num_elements = -1;
    } else if (!kIsPartial || num_elements >= 0) {
      num_elements = MultiplyWithoutOverflow(num_elements, d.size());
      if (num_elements < 0) {
        return errors::InvalidArgument(
            "Shape ", DebugString(proto),
            " is too large (more than 2**63 - 1 entries)");
      }
    }
  }
  return Status::OK();
}

template <class Shape>
TensorShapeBase<Shape>::TensorShapeBase(const TensorShapeProto& proto) {
  set_tag(REP16);
  set_data_type(DT_INVALID);
  // NOTE(irving): Unfortunately, TensorShape allows parsing protos with
  // unknown_shape() set, and it seems hard to remove this without backwards
  // compatibility issues.
  if (kIsPartial && proto.unknown_rank()) {
    set_ndims_byte(kUnknownRank);
    set_num_elements(-1);
  } else {
    set_ndims_byte(0);
    set_num_elements(1);
    for (const auto& d : proto.dim()) {
      AddDim(d.size());
    }
  }
}

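// Constructs a shape from an explicit list of dimension sizes.
// For example, TensorShape({2, 3, 5}) has dims() == 3 and
// num_elements() == 30.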
template <class Shape>
TensorShapeBase<Shape>::TensorShapeBase(gtl::ArraySlice<int64> dim_sizes) {
  set_tag(REP16);
  set_data_type(DT_INVALID);
  set_ndims_byte(0);
  set_num_elements(1);
  for (int64 s : dim_sizes) {
    AddDim(internal::SubtleMustCopy(s));
  }
}

template <class Shape>
TensorShapeBase<Shape>::TensorShapeBase() {
  set_tag(REP16);
  set_data_type(DT_INVALID);
  if (kIsPartial) {
    set_ndims_byte(kUnknownRank);
    set_num_elements(-1);
  } else {
    set_ndims_byte(0);
    set_num_elements(1);
  }
}

void TensorShapeRep::DestructorOutOfLine() {
  DCHECK(tag() == REP_OUT_OF_LINE);
  delete as64()->dims_;
}

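// Copy helper used when either this shape or `b` may be using the out-of-line
// (heap-allocated) representation, so the dimension vector has to be
// allocated, reused, or freed explicitly instead of being memcpy'd.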
void TensorShapeRep::SlowCopyFrom(const TensorShapeRep& b) {
  if (b.tag() != REP_OUT_OF_LINE) {
    if (tag() == REP_OUT_OF_LINE) {
      delete as64()->dims_;
    }
    memcpy(buf(), b.buf(), sizeof(u_.buf));
    // memcpy above implicitly also does:
    //   set_tag(b.tag());
    //   set_ndims_byte(b.ndims_byte());
    //   set_data_type(b.data_type());
  } else {
    DCHECK_EQ(b.tag(), REP_OUT_OF_LINE);
    set_ndims_byte(b.ndims_byte());
    set_data_type(b.data_type());
    if (tag() == REP_OUT_OF_LINE) {
      // vector already allocated
      *(as64()->dims_) = *(b.as64()->dims_);
    } else {
      set_tag(REP_OUT_OF_LINE);
      as64()->dims_ = new gtl::InlinedVector<int64, 4>(*(b.as64()->dims_));
    }
  }
}

template <class Shape>
int64 TensorShapeBase<Shape>::dim_size(int d) const {
  if (unknown_rank()) return -1;
  DCHECK_GE(d, 0);
  DCHECK_LT(d, dims());
  if (tag() == REP16) {
    uint16 dim = as16()->dims_[d];
    if (kIsPartial && dim == kUnknownRep16) return -1;
    return dim;
  } else if (tag() == REP32) {
    uint32 dim = as32()->dims_[d];
    if (kIsPartial && dim == kUnknownRep32) return -1;
    return dim;
  } else {
    return (*as64()->dims_)[d];
  }
}

void TensorShapeRep::Clear() {
  ClearAllButDataType();
  set_data_type(DT_INVALID);
}

void TensorShapeRep::ClearAllButDataType() {
  if (tag() == REP_OUT_OF_LINE) {
    delete as64()->dims_;
  }
  set_tag(REP16);
  set_ndims_byte(0);
  // Leaves data_type alone
  set_num_elements(1);
}

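// Recomputes the cached element count from the current dimension sizes.
// For partial shapes, any unknown dimension makes the product unknown (-1).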
template <class Shape>
void TensorShapeBase<Shape>::RecomputeNumElements() {
  if (unknown_rank()) {
    set_num_elements(-1);
    return;
  }
  int64 n = 1;
  for (auto dim : *this) {
    if (kIsPartial && dim.size < 0) {
      n = -1;
      break;
    }
    n = MultiplyWithoutOverflow(n, dim.size);
    CHECK_LE(0, n);
  }
  set_num_elements(n);
}

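// Appends one dimension to the shape, keeping num_elements() up to date.
// For example:
//   TensorShape s;  // []
//   s.AddDim(2);    // [2],   num_elements() == 2
//   s.AddDim(3);    // [2,3], num_elements() == 6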
template <class Shape>
void TensorShapeBase<Shape>::AddDim(int64 size) {
  if (!kIsPartial) CHECK_GE(size, 0);
  if (unknown_rank()) return;
  CHECK_LT(ndims_byte(), MaxDimensions()) << "Too many dimensions in tensor";
  int64 new_num_elements;
  if (kIsPartial && (num_elements() < 0 || size < 0)) {
    new_num_elements = -1;
  } else {
    new_num_elements = MultiplyWithoutOverflow(num_elements(), size);
    CHECK_LE(0, new_num_elements);
  }
  UnsafeAddDim(size, new_num_elements);
}

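// Appends `size` without revalidating it, assuming the caller has already
// computed `new_num_elements`. Upgrades the representation
// (REP16 -> REP32 -> REP_OUT_OF_LINE) when the new dimension no longer fits
// in the current inline encoding.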
template <class Shape>
void TensorShapeBase<Shape>::UnsafeAddDim(int64 size, int64 new_num_elements) {
  const int nd = ndims_byte();
  if (tag() == REP16 && nd < 6 && size < kMaxRep16) {
    as16()->dims_[nd] =
        kIsPartial && size < 0 ? kUnknownRep16 : static_cast<uint16>(size);
  } else if (tag() == REP32 && nd < 3 && size < kMaxRep32) {
    as32()->dims_[nd] =
        kIsPartial && size < 0 ? kUnknownRep32 : static_cast<uint32>(size);
  } else if (tag() == REP_OUT_OF_LINE) {
    as64()->dims_->push_back(size);
  } else {
    // Need to change representation
    gtl::InlinedVector<int64, 8> vals;
    AppendTo(*this, &vals);
    vals.push_back(size);
    // We know we can't be REP16. See if we have a small enough
    // number of dimensions and each dimension's size is small enough
    // to allow REP32.
    bool can_be_rep32 = (vals.size() <= 3);
    if (can_be_rep32) {
      for (size_t i = 0; i < vals.size(); i++) {
        if (vals[i] >= kMaxRep32) {
          can_be_rep32 = false;
          break;
        }
      }
    }
    if (can_be_rep32) {
      set_tag(REP32);
      for (size_t d = 0; d < vals.size(); d++) {
        as32()->dims_[d] = kIsPartial && vals[d] < 0
                               ? kUnknownRep32
                               : static_cast<uint32>(vals[d]);
      }
    } else {
      set_tag(REP_OUT_OF_LINE);
      as64()->dims_ =
          new gtl::InlinedVector<int64, 4>(vals.begin(), vals.end());
    }
  }
  set_ndims_byte(nd + 1);
  set_num_elements(new_num_elements);
}

template <class Shape>
void TensorShapeBase<Shape>::AppendShape(const TensorShapeBase& shape) {
  for (auto d : shape) AddDim(d.size);
}

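// Inserts a dimension at index `d`, shifting later dimensions up.
// For example, inserting 4 at index 0 of [2,3] yields [4,2,3].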
template <class Shape>
void TensorShapeBase<Shape>::InsertDim(int d, int64 size) {
  CHECK_GE(d, 0);
  CHECK_LE(d, dims());
  if (!kIsPartial) CHECK_GE(size, 0);
  CHECK_LT(dims(), MaxDimensions());
  gtl::InlinedVector<int64, 8> vals;
  AppendTo(*this, &vals);
  vals.insert(vals.begin() + d, size);
  ClearAllButDataType();
  for (auto dval : vals) {
    AddDim(dval);
  }
}

template <class Shape>
gtl::InlinedVector<int64, 4> TensorShapeBase<Shape>::dim_sizes() const {
  gtl::InlinedVector<int64, 4> result;
  for (auto dim : *this) {
    result.push_back(dim.size);
  }
  return result;
}

template <class Shape>
void TensorShapeBase<Shape>::set_dim(int d, int64 size) {
  CHECK_GE(d, 0);
  CHECK_LT(d, dims());
  // Negative (unknown) sizes are only meaningful for partial shapes; the
  // branches below already encode -1 as kUnknownRep16/kUnknownRep32.
  if (!kIsPartial) CHECK_GE(size, 0);
  if (tag() == REP16 && size < kMaxRep16) {
    as16()->dims_[d] =
        kIsPartial && size < 0 ? kUnknownRep16 : static_cast<uint16>(size);
  } else if (tag() == REP32 && size < kMaxRep32) {
    as32()->dims_[d] =
        kIsPartial && size < 0 ? kUnknownRep32 : static_cast<uint32>(size);
  } else if (tag() == REP_OUT_OF_LINE) {
    (*as64()->dims_)[d] = size;
  } else {
    // Must upgrade
    gtl::InlinedVector<int64, 8> vals;
    AppendTo(*this, &vals);
    vals[d] = size;
    ClearAllButDataType();
    for (auto dval : vals) {
      AddDim(dval);
    }
  }
  RecomputeNumElements();
}

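// Removes dimensions in the half-open index range [begin, end). Negative
// indices count from the end, with -1 addressing one past the last dimension;
// e.g. for a shape [2,3,5,7], RemoveDimRange(1, 3) yields [2,7] and
// RemoveDimRange(-2, -1) removes the last dimension, yielding [2,3,5].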
template <class Shape>
void TensorShapeBase<Shape>::RemoveDimRange(int begin, int end) {
  if (unknown_rank()) return;
  begin = begin < 0 ? dims() + begin + 1 : begin;
  end = end < 0 ? dims() + end + 1 : end;
  CHECK_GE(begin, 0);
  CHECK_LE(begin, dims());
  CHECK_GE(end, 0);
  CHECK_LE(end, dims());
  if (begin >= end) return;
  gtl::InlinedVector<int64, 8> vals;
  AppendTo(*this, &vals);
  vals.erase(vals.begin() + begin, vals.begin() + end);
  ClearAllButDataType();
  for (auto dval : vals) {
    AddDim(dval);
  }
  RecomputeNumElements();
}

bool TensorShape::IsSameSize(const TensorShape& b) const {
  if (b.dims() != dims()) return false;
  for (int d = 0; d < dims(); d++) {
    if (dim_size(d) != b.dim_size(d)) return false;
  }
  return true;
}

template <class Shape>
void TensorShapeBase<Shape>::AsProto(TensorShapeProto* proto) const {
  proto->Clear();
  if (unknown_rank()) {
    proto->set_unknown_rank(true);
  } else {
    for (int i = 0; i < dims(); i++) {
      proto->add_dim()->set_size(dim_size(i));
    }
  }
}

void TensorShapeRep::DumpRep() const {
#if 0
  fprintf(stderr, "Rep: %d %d dims\n", tag(), dims());
  if (tag() == REP16) {
    fprintf(stderr, "REP16 NDIMS: %d\n", ndims_byte());
    for (int i = 0; i < ndims_byte(); i++) {
      fprintf(stderr, "dim %d: %d\n", i, as16()->dims_[i]);
    }
  } else if (tag_ == REP32) {
    fprintf(stderr, "REP32 NDIMS: %d\n", ndims_);
    for (int i = 0; i < ndims_byte(); i++) {
      fprintf(stderr, "dim %d: %d\n", i, as32()->dims_[i]);
    }
  } else if (tag_ == REP_OUT_OF_LINE) {
    fprintf(stderr, "REP_OUT_OF_LINE NDIMS: %d %p\n", ndims_, as16()->dims_);
    for (int i = 0; i < ndims_byte(); i++) {
      fprintf(stderr, "dim %d: %lld\n", i, (*as64()->dims_)[i]);
    }
  }
#endif
}

template <class Shape>
TensorShapeIter<Shape> TensorShapeBase<Shape>::begin() const {
  return TensorShapeIter<Shape>(static_cast<const Shape*>(this), 0);
}

template <class Shape>
TensorShapeIter<Shape> TensorShapeBase<Shape>::end() const {
  CHECK(!unknown_rank());
  return TensorShapeIter<Shape>(static_cast<const Shape*>(this), dims());
}

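// Produces a human-readable string such as "[2,3]". Unknown dimensions of a
// PartialTensorShape are printed as "?" and an unknown rank as "<unknown>".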
string TensorShapeRep::DebugString() const {
  const auto& shape = *static_cast<const PartialTensorShape*>(this);
  if (shape.unknown_rank()) return "<unknown>";
  string s = "[";
  for (int i = 0; i < shape.dims(); i++) {
    if (i > 0) strings::StrAppend(&s, ",");
    int64 dim = shape.dim_size(i);
    if (dim < 0) {
      strings::StrAppend(&s, "?");
    } else {
      strings::StrAppend(&s, dim);
    }
  }
  strings::StrAppend(&s, "]");
  return s;
}

string TensorShapeRep::DebugString(const TensorShapeProto& proto) {
  string s;
  if (proto.unknown_rank()) {
    strings::StrAppend(&s, "<unknown>");
    if (proto.dim_size() == 0) return s;
  }
  strings::StrAppend(&s, "[");
  bool first = true;
  for (const auto& d : proto.dim()) {
    if (!first) strings::StrAppend(&s, ",");
    if (d.size() == -1) {
      strings::StrAppend(&s, "?");
    } else {
      strings::StrAppend(&s, d.size());
    }
    first = false;
  }
  strings::StrAppend(&s, "]");
  return s;
}

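// Returns true iff `prefix` matches the leading dimensions of `shape`;
// e.g. [2,3] is a prefix of [2,3,5] but not of [2,4,5].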
bool TensorShapeUtils::StartsWith(const TensorShape& shape,
                                  const TensorShape& prefix) {
  if (shape.dims() < prefix.dims()) return false;
  for (int i = 0; i < prefix.dims(); ++i) {
    if (shape.dim_size(i) != prefix.dim_size(i)) return false;
  }
  return true;
}

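// Returns true iff `suffix` matches the trailing dimensions of `shape`;
// e.g. [3,5] is a suffix of [2,3,5].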
bool TensorShapeUtils::EndsWith(const TensorShape& shape,
                                const TensorShape& suffix) {
  const int suffix_size = suffix.dims();
  if (shape.dims() < suffix_size) return false;
  for (int i = 0; i < suffix_size; ++i) {
    if (shape.dim_size(shape.dims() - suffix_size + i) != suffix.dim_size(i)) {
      return false;
    }
  }
  return true;
}

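// Builds a TensorShape or PartialTensorShape from `n` dimension sizes,
// returning an error Status (rather than CHECK-failing) on negative sizes,
// too many dimensions, or element-count overflow.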
template <typename T, class Shape>
Status MakeShapeHelper(const T* dims, int64 n, Shape* out) {
  out->Clear();
  if (n > TensorShape::MaxDimensions()) {
    return errors::InvalidArgument("Too many dimensions");
  }
  if (n < 0) {
    return errors::InvalidArgument("Negative number of dimensions ", n);
  }
  for (int64 i = 0; i < n; ++i) {
    T dim = internal::SubtleMustCopy(dims[i]);
    int64 new_num_elements;
    if (dim < 0) {
      if (!out->kIsPartial) {
        return errors::InvalidArgument("Dimension ", dim, " must be >= 0");
      }
      if (dim < -1) {
        return errors::InvalidArgument("Dimension ", dim, " must be >= -1");
      }
      dim = -1;
      new_num_elements = -1;
    } else if (out->num_elements() < 0) {
      new_num_elements = -1;
    } else {
      new_num_elements = MultiplyWithoutOverflow(out->num_elements(), dim);
      if (TF_PREDICT_FALSE(new_num_elements < 0)) {
        TensorShapeProto proto;
        for (int64 j = 0; j < n; ++j) {
          // Report the full requested shape, not just the current dimension.
          proto.add_dim()->set_size(internal::SubtleMustCopy(dims[j]));
        }
        return errors::InvalidArgument(
            "Shape ", TensorShape::DebugString(proto),
            " would have more than 2**63 - 1 elements");
      }
    }
    out->UnsafeAddDim(dim, new_num_elements);
  }
  return Status::OK();
}

#define MAKE_SHAPE(T, Shape)                                                 \
  Status TensorShapeUtils::MakeShape(const T* dims, int64 n, Shape* out) {   \
    return MakeShapeHelper(dims, n, out);                                    \
  }                                                                          \
  Status TensorShapeUtils::MakeShape(gtl::ArraySlice<T> shape, Shape* out) { \
    return MakeShapeHelper(shape.data(), shape.size(), out);                 \
  }
MAKE_SHAPE(int32, TensorShape)
MAKE_SHAPE(int64, TensorShape)
MAKE_SHAPE(int32, PartialTensorShape)
MAKE_SHAPE(int64, PartialTensorShape)
#undef MAKE_SHAPE

string TensorShapeUtils::ShapeListString(
    const gtl::ArraySlice<TensorShape>& shapes) {
  string result = "[";
  bool first = true;
  for (const TensorShape& shape : shapes) {
    strings::StrAppend(&result, (first ? "" : ", "), shape.DebugString());
    first = false;
  }
  strings::StrAppend(&result, "]");
  return result;
}

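// Appends a single (possibly unknown, -1) dimension to a copy of this shape;
// e.g. concatenating 5 to [2,3] yields [2,3,5].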
PartialTensorShape PartialTensorShape::Concatenate(int64 size) const {
  PartialTensorShape out = *this;
  out.AddDim(size);
  return out;
}

PartialTensorShape PartialTensorShape::Concatenate(
    const PartialTensorShape& shape) const {
  if (unknown_rank() || shape.unknown_rank()) {
    return PartialTensorShape();
  }
  PartialTensorShape out = *this;
  for (auto dim : shape) out.AddDim(dim.size);
  return out;
}

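// Merges two partial shapes dimension by dimension; e.g. merging [2,?] with
// [?,3] yields [2,3], while merging [2,?] with [3,?] fails because the known
// first dimensions disagree.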
Status PartialTensorShape::MergeWith(const PartialTensorShape& shape,
                                     PartialTensorShape* result) const {
  if (unknown_rank()) {
    *result = shape;
    return Status::OK();
  }
  if (shape.unknown_rank()) {
    *result = *this;
    return Status::OK();
  }
  const int dims_ = dims();
  if (dims_ != shape.dims()) {
    return errors::InvalidArgument(
        "PartialTensorShape: Incompatible ranks during merge: ", dims_, " vs. ",
        shape.dims());
  }
  CHECK(result != this);
  result->Clear();
  for (int i = 0; i < dims_; ++i) {
    const int64 dim0 = dim_size(i);
    const int64 dim1 = shape.dim_size(i);
    if (dim0 >= 0 && dim1 >= 0 && dim0 != dim1) {
      return errors::InvalidArgument(
          "PartialTensorShape: Incompatible shapes during merge: ",
          DebugString(), " vs. ", shape.DebugString());
    }
    result->AddDim(dim0 >= 0 ? dim0 : dim1);
  }
  return Status::OK();
}

bool PartialTensorShape::AsTensorShape(TensorShape* shape) const {
  if (IsFullyDefined()) {
    const TensorShapeRep* rep = this;
    *shape = *static_cast<const TensorShape*>(rep);
    return true;
  }
  return false;
}

bool PartialTensorShape::IsIdenticalTo(const PartialTensorShape& shape) const {
  if (unknown_rank() || shape.unknown_rank()) {
    return unknown_rank() == shape.unknown_rank();
  }
  if (dims() != shape.dims()) return false;
  for (int i = 0; i < dims(); i++) {
    if (dim_size(i) != shape.dim_size(i)) return false;
  }
  return true;
}

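// Two partial shapes are compatible if they could describe the same fully
// defined shape: an unknown rank or an unknown dimension matches anything,
// so [2,?] is compatible with [2,3] but not with [3,3].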
bool PartialTensorShape::IsCompatibleWith(
    const PartialTensorShape& shape) const {
  if (unknown_rank() || shape.unknown_rank()) return true;
  if (dims() != shape.dims()) return false;
  for (int i = 0; i < dims(); i++) {
    const int64 dim0 = dim_size(i);
    const int64 dim1 = shape.dim_size(i);
    if (dim0 >= 0 && dim1 >= 0 && dim0 != dim1) return false;
  }
  return true;
}

string PartialTensorShapeUtils::PartialShapeListString(
    const gtl::ArraySlice<PartialTensorShape>& shapes) {
  string result = "[";
  bool first = true;
  for (const PartialTensorShape& shape : shapes) {
    strings::StrAppend(&result, (first ? "" : ", "), shape.DebugString());
    first = false;
  }
  strings::StrAppend(&result, "]");
  return result;
}

bool PartialTensorShapeUtils::AreCompatible(
    const gtl::ArraySlice<PartialTensorShape>& shapes0,
    const gtl::ArraySlice<PartialTensorShape>& shapes1) {
  if (shapes0.size() == shapes1.size()) {
    for (size_t i = 0; i < shapes0.size(); ++i) {
      if (!shapes0[i].IsCompatibleWith(shapes1[i])) {
        return false;
      }
    }
    return true;
  } else {
    return false;
  }
}

bool PartialTensorShapeUtils::AreIdentical(
    const gtl::ArraySlice<PartialTensorShape>& shapes0,
    const gtl::ArraySlice<PartialTensorShape>& shapes1) {
  if (shapes0.size() == shapes1.size()) {
    for (size_t i = 0; i < shapes0.size(); ++i) {
      if (!shapes0[i].IsIdenticalTo(shapes1[i])) {
        return false;
      }
    }
    return true;
  } else {
    return false;
  }
}

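// Computes the product of the given dimension sizes, failing cleanly if the
// result would overflow int64; e.g. {2, 3, 5} yields 30.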
Status TensorShapeUtils::NumElements(gtl::ArraySlice<int64> shape,
                                     int64* num_elements) {
  int64 n = 1;
  for (auto dim : shape) {
    n = MultiplyWithoutOverflow(n, dim);
    if (n < 0) {
      return errors::InvalidArgument("Can't compute total size of shape [",
                                     str_util::Join(shape, ","),
                                     "]; product would overflow int64");
    }
  }
  *num_elements = n;
  return Status::OK();
}

template class TensorShapeBase<TensorShape>;
template class TensorShapeBase<PartialTensorShape>;

}  // namespace tensorflow