1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "tensorflow/core/framework/tensor_shape.h" |
17 | |
18 | #include "tensorflow/core/framework/tensor_shape.pb.h" |
19 | #include "tensorflow/core/kernels/bounds_check.h" |
20 | #include "tensorflow/core/lib/core/errors.h" |
21 | #include "tensorflow/core/lib/strings/str_util.h" |
22 | #include "tensorflow/core/lib/strings/strcat.h" |
23 | #include "tensorflow/core/platform/logging.h" |
24 | #include "tensorflow/core/util/overflow.h" |
25 | |
26 | namespace tensorflow { |
27 | |
// TensorShape and PartialTensorShape should have no fields beyond
// TensorShapeRep. In particular, their sizes should be the same.
// Code in this file relies on that by static_cast-ing between
// TensorShapeRep, TensorShape and PartialTensorShape (e.g. in
// TensorShapeRep::DebugString() and PartialTensorShape::AsTensorShape()).
static_assert(sizeof(TensorShapeRep) == sizeof(TensorShape),
              "TensorShape must have no fields beyond TensorShapeRep");
static_assert(sizeof(TensorShapeRep) == sizeof(PartialTensorShape),
              "PartialTensorShape must have no fields beyond TensorShapeRep");
34 | |
35 | template <class Shape> |
36 | static void AppendTo(const TensorShapeBase<Shape>& s, |
37 | gtl::InlinedVector<int64, 8>* vals) { |
38 | for (auto dim : s) { |
39 | vals->push_back(dim.size); |
40 | } |
41 | } |
42 | |
// Crashes (via CHECK) unless this shape has exactly NDIMS dimensions.
void TensorShape::CheckDimsEqual(int NDIMS) const {
  CHECK_EQ(NDIMS, dims()) << "Asking for tensor of " << NDIMS << " dimensions"
                          << " from a tensor of " << dims() << " dimensions";
}
47 | |
// Crashes (via CHECK) unless NDIMS >= dims(), i.e. unless the requested rank
// NDIMS is at least this shape's rank.
// Note the direction: despite the name, this bounds dims() from above by
// NDIMS. It does NOT require the shape to have at least NDIMS dimensions.
void TensorShape::CheckDimsAtLeast(int NDIMS) const {
  CHECK_GE(NDIMS, dims()) << "Asking for tensor of at least " << NDIMS
                          << " dimensions from a tensor of " << dims()
                          << " dimensions";
}
53 | |
54 | template <class Shape> |
55 | bool TensorShapeBase<Shape>::IsValid(const TensorShapeProto& proto) { |
56 | // NOTE(irving): Unfortunately, TensorShape allows parsing protos with |
57 | // unknown_shape() set, and it seems hard to remove this without backwards |
58 | // compatibility issues. |
59 | if (kIsPartial && proto.unknown_rank()) return proto.dim_size() == 0; |
60 | int64 num_elements = 1; |
61 | if (proto.dim().size() > MaxDimensions()) return false; |
62 | for (const auto& d : proto.dim()) { |
63 | if (d.size() < (kIsPartial ? -1 : 0)) return false; |
64 | if (d.size() == -1) { |
65 | num_elements = -1; |
66 | } else if (!kIsPartial || num_elements >= 0) { |
67 | num_elements = MultiplyWithoutOverflow(num_elements, d.size()); |
68 | if (num_elements < 0) return false; |
69 | } |
70 | } |
71 | return true; |
72 | } |
73 | |
74 | template <class Shape> |
75 | Status TensorShapeBase<Shape>::IsValidShape(const TensorShapeProto& proto) { |
76 | // NOTE(irving): Unfortunately, TensorShape allows parsing protos with |
77 | // unknown_shape() set, and it seems hard to remove this without backwards |
78 | // compatibility issues. |
79 | if (kIsPartial && proto.unknown_rank()) { |
80 | if (proto.dim_size() > 0) { |
81 | return errors::InvalidArgument( |
82 | "An unknown shape must not have any dimensions set." ); |
83 | } |
84 | return Status::OK(); |
85 | } |
86 | int64 num_elements = 1; |
87 | if (proto.dim().size() > MaxDimensions()) { |
88 | return errors::InvalidArgument("Shape " , DebugString(proto), |
89 | " has too many dimensions" ); |
90 | } |
91 | for (const auto& d : proto.dim()) { |
92 | if (d.size() < (kIsPartial ? -1 : 0)) { |
93 | if (kIsPartial) { |
94 | return errors::InvalidArgument( |
95 | "Shape " , DebugString(proto), |
96 | " has dimensions with values below -1 (where -1 means unknown)" ); |
97 | } else { |
98 | return errors::InvalidArgument("Shape " , DebugString(proto), |
99 | " is not fully defined" ); |
100 | } |
101 | } |
102 | if (d.size() == -1) { |
103 | num_elements = -1; |
104 | } else if (!kIsPartial || num_elements >= 0) { |
105 | num_elements = MultiplyWithoutOverflow(num_elements, d.size()); |
106 | if (num_elements < 0) { |
107 | return errors::InvalidArgument( |
108 | "Shape " , DebugString(proto), |
109 | " is too large (more than 2**63 - 1 entries)" ); |
110 | } |
111 | } |
112 | } |
113 | return Status::OK(); |
114 | } |
115 | |
116 | template <class Shape> |
117 | TensorShapeBase<Shape>::TensorShapeBase(const TensorShapeProto& proto) { |
118 | set_tag(REP16); |
119 | set_data_type(DT_INVALID); |
120 | // NOTE(irving): Unfortunately, TensorShape allows parsing protos with |
121 | // unknown_shape() set, and it seems hard to remove this without backwards |
122 | // compatibility issues. |
123 | if (kIsPartial && proto.unknown_rank()) { |
124 | set_ndims_byte(kUnknownRank); |
125 | set_num_elements(-1); |
126 | } else { |
127 | set_ndims_byte(0); |
128 | set_num_elements(1); |
129 | for (const auto& d : proto.dim()) { |
130 | AddDim(d.size()); |
131 | } |
132 | } |
133 | } |
134 | |
135 | template <class Shape> |
136 | TensorShapeBase<Shape>::TensorShapeBase(gtl::ArraySlice<int64> dim_sizes) { |
137 | set_tag(REP16); |
138 | set_data_type(DT_INVALID); |
139 | set_ndims_byte(0); |
140 | set_num_elements(1); |
141 | for (int64 s : dim_sizes) { |
142 | AddDim(internal::SubtleMustCopy(s)); |
143 | } |
144 | } |
145 | |
146 | template <class Shape> |
147 | TensorShapeBase<Shape>::TensorShapeBase() { |
148 | set_tag(REP16); |
149 | set_data_type(DT_INVALID); |
150 | if (kIsPartial) { |
151 | set_ndims_byte(kUnknownRank); |
152 | set_num_elements(-1); |
153 | } else { |
154 | set_ndims_byte(0); |
155 | set_num_elements(1); |
156 | } |
157 | } |
158 | |
// Frees the heap-allocated dimension vector owned by an out-of-line
// representation. Must only be called when tag() == REP_OUT_OF_LINE
// (DCHECKed in debug builds).
void TensorShapeRep::DestructorOutOfLine() {
  DCHECK(tag() == REP_OUT_OF_LINE);
  delete as64()->dims_;
}
163 | |
// Makes *this hold the same dimensions, rank byte, tag and data type as `b`.
// Handles every combination of inline (REP16/REP32) and out-of-line
// (REP_OUT_OF_LINE) representations on the two sides.
// NOTE(review): num_elements is not copied here; presumably the caller
// (copy construction/assignment in tensor_shape.h) copies it separately.
// Confirm before calling this from anywhere new.
void TensorShapeRep::SlowCopyFrom(const TensorShapeRep& b) {
  if (b.tag() != REP_OUT_OF_LINE) {
    // Source is inline. Release our heap vector (if any) before the memcpy
    // below overwrites the buffer that presumably holds its owning pointer.
    if (tag() == REP_OUT_OF_LINE) {
      delete as64()->dims_;
    }
    memcpy(buf(), b.buf(), sizeof(u_.buf));
    // memcpy above implicitly also does:
    // set_tag(b.tag());
    // set_ndims_byte(b.ndims_byte());
    // set_data_type(b.data_type());
  } else {
    // Source is out-of-line: deep-copy its vector rather than sharing the
    // pointer, so each rep owns its own allocation.
    DCHECK_EQ(b.tag(), REP_OUT_OF_LINE);
    set_ndims_byte(b.ndims_byte());
    set_data_type(b.data_type());
    if (tag() == REP_OUT_OF_LINE) {
      // vector already allocated
      *(as64()->dims_) = *(b.as64()->dims_);
    } else {
      set_tag(REP_OUT_OF_LINE);
      as64()->dims_ = new gtl::InlinedVector<int64, 4>(*(b.as64()->dims_));
    }
  }
}
187 | |
188 | template <class Shape> |
189 | int64 TensorShapeBase<Shape>::dim_size(int d) const { |
190 | if (unknown_rank()) return -1; |
191 | DCHECK_GE(d, 0); |
192 | DCHECK_LT(d, dims()); |
193 | if (tag() == REP16) { |
194 | uint16 dim = as16()->dims_[d]; |
195 | if (kIsPartial && dim == kUnknownRep16) return -1; |
196 | return dim; |
197 | } else if (tag() == REP32) { |
198 | uint32 dim = as32()->dims_[d]; |
199 | if (kIsPartial && dim == kUnknownRep32) return -1; |
200 | return dim; |
201 | } else { |
202 | return (*as64()->dims_)[d]; |
203 | } |
204 | } |
205 | |
// Resets to an inline rank-0 shape with one element (see
// ClearAllButDataType) and also resets the data type to DT_INVALID.
void TensorShapeRep::Clear() {
  ClearAllButDataType();
  set_data_type(DT_INVALID);
}
210 | |
// Resets to an inline (REP16) rank-0 shape with num_elements() == 1, freeing
// any out-of-line dimension vector first. The stored data type is preserved.
// Note that this sets ndims_byte to 0, so a partial shape becomes a
// known-rank scalar rather than unknown-rank.
void TensorShapeRep::ClearAllButDataType() {
  // Must check the tag before overwriting it, or the heap vector would leak.
  if (tag() == REP_OUT_OF_LINE) {
    delete as64()->dims_;
  }
  set_tag(REP16);
  set_ndims_byte(0);
  // Leaves data_type alone
  set_num_elements(1);
}
220 | |
221 | template <class Shape> |
222 | void TensorShapeBase<Shape>::RecomputeNumElements() { |
223 | if (unknown_rank()) { |
224 | set_num_elements(-1); |
225 | return; |
226 | } |
227 | int64 n = 1; |
228 | for (auto dim : *this) { |
229 | if (kIsPartial && dim.size < 0) { |
230 | n = -1; |
231 | break; |
232 | } |
233 | n = MultiplyWithoutOverflow(n, dim.size); |
234 | CHECK_LE(0, n); |
235 | } |
236 | set_num_elements(n); |
237 | } |
238 | |
239 | template <class Shape> |
240 | void TensorShapeBase<Shape>::AddDim(int64 size) { |
241 | if (!kIsPartial) CHECK_GE(size, 0); |
242 | if (unknown_rank()) return; |
243 | CHECK_LT(ndims_byte(), MaxDimensions()) << "Too many dimensions in tensor" ; |
244 | int64 new_num_elements; |
245 | if (kIsPartial && (num_elements() < 0 || size < 0)) { |
246 | new_num_elements = -1; |
247 | } else { |
248 | new_num_elements = MultiplyWithoutOverflow(num_elements(), size); |
249 | CHECK_LE(0, new_num_elements); |
250 | } |
251 | UnsafeAddDim(size, new_num_elements); |
252 | } |
253 | |
// Appends a dimension of `size` and sets num_elements() to `new_num_elements`.
// Neither value is validated here; the callers in this file (AddDim,
// MakeShapeHelper) check them before calling.
//
// Storage selection:
//  - stay REP16 if fewer than 6 dims exist so far and size < kMaxRep16;
//  - stay REP32 if fewer than 3 dims exist so far and size < kMaxRep32;
//  - push onto the heap vector if already REP_OUT_OF_LINE;
//  - otherwise re-encode every dim: as REP32 if there are at most 3 dims and
//    each is < kMaxRep32, else as a new REP_OUT_OF_LINE vector.
// Negative sizes (partial shapes only) satisfy the "< kMax" tests and are
// stored as the kUnknownRep16/kUnknownRep32 sentinels in packed forms.
template <class Shape>
void TensorShapeBase<Shape>::UnsafeAddDim(int64 size, int64 new_num_elements) {
  const int nd = ndims_byte();
  if (tag() == REP16 && nd < 6 && size < kMaxRep16) {
    as16()->dims_[nd] =
        kIsPartial && size < 0 ? kUnknownRep16 : static_cast<uint16>(size);
  } else if (tag() == REP32 && nd < 3 && size < kMaxRep32) {
    as32()->dims_[nd] =
        kIsPartial && size < 0 ? kUnknownRep32 : static_cast<uint32>(size);
  } else if (tag() == REP_OUT_OF_LINE) {
    as64()->dims_->push_back(size);
  } else {
    // Need to change representation
    // Snapshot the existing dims before the tag (and thus the buffer's
    // interpretation) changes below.
    gtl::InlinedVector<int64, 8> vals;
    AppendTo(*this, &vals);
    vals.push_back(size);
    // We know we can't be REP16. See if we have a small enough
    // number of dimensions and each dimension's size is small enough
    // to allow REP32.
    bool can_be_rep32 = (vals.size() <= 3);
    if (can_be_rep32) {
      for (size_t i = 0; i < vals.size(); i++) {
        if (vals[i] >= kMaxRep32) {
          can_be_rep32 = false;
          break;
        }
      }
    }
    if (can_be_rep32) {
      set_tag(REP32);
      for (size_t d = 0; d < vals.size(); d++) {
        as32()->dims_[d] = kIsPartial && vals[d] < 0
                               ? kUnknownRep32
                               : static_cast<uint32>(vals[d]);
      }
    } else {
      set_tag(REP_OUT_OF_LINE);
      as64()->dims_ =
          new gtl::InlinedVector<int64, 4>(vals.begin(), vals.end());
    }
  }
  set_ndims_byte(nd + 1);
  set_num_elements(new_num_elements);
}
298 | |
299 | template <class Shape> |
300 | void TensorShapeBase<Shape>::AppendShape(const TensorShapeBase& shape) { |
301 | for (auto d : shape) AddDim(d.size); |
302 | } |
303 | |
304 | template <class Shape> |
305 | void TensorShapeBase<Shape>::InsertDim(int d, int64 size) { |
306 | CHECK_GE(d, 0); |
307 | CHECK_LE(d, dims()); |
308 | if (!kIsPartial) CHECK_GE(size, 0); |
309 | CHECK_LT(dims(), MaxDimensions()); |
310 | gtl::InlinedVector<int64, 8> vals; |
311 | AppendTo(*this, &vals); |
312 | vals.insert(vals.begin() + d, size); |
313 | ClearAllButDataType(); |
314 | for (auto dval : vals) { |
315 | AddDim(dval); |
316 | } |
317 | } |
318 | |
319 | template <class Shape> |
320 | gtl::InlinedVector<int64, 4> TensorShapeBase<Shape>::dim_sizes() const { |
321 | gtl::InlinedVector<int64, 4> result; |
322 | for (auto dim : *this) { |
323 | result.push_back(dim.size); |
324 | } |
325 | return result; |
326 | } |
327 | |
328 | template <class Shape> |
329 | void TensorShapeBase<Shape>::set_dim(int d, int64 size) { |
330 | CHECK_GE(d, 0); |
331 | CHECK_LT(d, dims()); |
332 | CHECK_GE(size, 0); |
333 | if (tag() == REP16 && size < kMaxRep16) { |
334 | as16()->dims_[d] = |
335 | kIsPartial && size < 0 ? kUnknownRep16 : static_cast<uint16>(size); |
336 | } else if (tag() == REP32 && size < kMaxRep32) { |
337 | as32()->dims_[d] = |
338 | kIsPartial && size < 0 ? kUnknownRep32 : static_cast<uint32>(size); |
339 | } else if (tag() == REP_OUT_OF_LINE) { |
340 | (*as64()->dims_)[d] = size; |
341 | } else { |
342 | // Must upgrade |
343 | gtl::InlinedVector<int64, 8> vals; |
344 | AppendTo(*this, &vals); |
345 | vals[d] = size; |
346 | ClearAllButDataType(); |
347 | for (auto dval : vals) { |
348 | AddDim(dval); |
349 | } |
350 | } |
351 | RecomputeNumElements(); |
352 | } |
353 | |
354 | template <class Shape> |
355 | void TensorShapeBase<Shape>::RemoveDimRange(int begin, int end) { |
356 | if (unknown_rank()) return; |
357 | begin = begin < 0 ? dims() + begin + 1 : begin; |
358 | end = end < 0 ? dims() + end + 1 : end; |
359 | CHECK_GE(begin, 0); |
360 | CHECK_LE(begin, dims()); |
361 | CHECK_GE(end, 0); |
362 | CHECK_LE(end, dims()); |
363 | if (begin >= end) return; |
364 | gtl::InlinedVector<int64, 8> vals; |
365 | AppendTo(*this, &vals); |
366 | vals.erase(vals.begin() + begin, vals.begin() + end); |
367 | ClearAllButDataType(); |
368 | for (auto dval : vals) { |
369 | AddDim(dval); |
370 | } |
371 | RecomputeNumElements(); |
372 | } |
373 | |
374 | bool TensorShape::IsSameSize(const TensorShape& b) const { |
375 | if (b.dims() != dims()) return false; |
376 | for (int d = 0; d < dims(); d++) { |
377 | if (dim_size(d) != b.dim_size(d)) return false; |
378 | } |
379 | return true; |
380 | } |
381 | |
382 | template <class Shape> |
383 | void TensorShapeBase<Shape>::AsProto(TensorShapeProto* proto) const { |
384 | proto->Clear(); |
385 | if (unknown_rank()) { |
386 | proto->set_unknown_rank(true); |
387 | } else { |
388 | for (int i = 0; i < dims(); i++) { |
389 | proto->add_dim()->set_size(dim_size(i)); |
390 | } |
391 | } |
392 | } |
393 | |
394 | void TensorShapeRep::DumpRep() const { |
395 | #if 0 |
396 | fprintf(stderr, "Rep: %d %d dims\n" , tag(), dims()); |
397 | if (tag() == REP16) { |
398 | fprintf(stderr, "REP16 NDIMS: %d\n" , ndims_byte()); |
399 | for (int i = 0; i < ndims_byte(); i++) { |
400 | fprintf(stderr, "dim %d: %d\n" , i, as16()->dims_[i]); |
401 | } |
402 | } else if (tag_ == REP32) { |
403 | fprintf(stderr, "REP32 NDIMS: %d\n" , ndims_); |
404 | for (int i = 0; i < ndims_byte(); i++) { |
405 | fprintf(stderr, "dim %d: %d\n" , i, as32()->dims_[i]); |
406 | } |
407 | } else if (tag_ == REP_OUT_OF_LINE) { |
408 | fprintf(stderr, "REP_OUT_OF_LINE NDIMS: %d %p\n" , ndims_, as16()->dims_); |
409 | for (int i = 0; i < ndims_byte(); i++) { |
410 | fprintf(stderr, "dim %d: %lld\n" , i, (*as64()->dims_)[i]); |
411 | } |
412 | } |
413 | #endif |
414 | } |
415 | |
416 | template <class Shape> |
417 | TensorShapeIter<Shape> TensorShapeBase<Shape>::begin() const { |
418 | return TensorShapeIter<Shape>(static_cast<const Shape*>(this), 0); |
419 | } |
420 | |
421 | template <class Shape> |
422 | TensorShapeIter<Shape> TensorShapeBase<Shape>::end() const { |
423 | CHECK(!unknown_rank()); |
424 | return TensorShapeIter<Shape>(static_cast<const Shape*>(this), dims()); |
425 | } |
426 | |
427 | string TensorShapeRep::DebugString() const { |
428 | const auto& shape = *static_cast<const PartialTensorShape*>(this); |
429 | if (shape.unknown_rank()) return "<unknown>" ; |
430 | string s = "[" ; |
431 | for (int i = 0; i < shape.dims(); i++) { |
432 | if (i > 0) strings::StrAppend(&s, "," ); |
433 | int64 dim = shape.dim_size(i); |
434 | if (dim < 0) { |
435 | strings::StrAppend(&s, "?" ); |
436 | } else { |
437 | strings::StrAppend(&s, dim); |
438 | } |
439 | } |
440 | strings::StrAppend(&s, "]" ); |
441 | return s; |
442 | } |
443 | |
444 | string TensorShapeRep::DebugString(const TensorShapeProto& proto) { |
445 | string s; |
446 | if (proto.unknown_rank()) { |
447 | strings::StrAppend(&s, "<unknown>" ); |
448 | if (proto.dim_size() == 0) return s; |
449 | } |
450 | strings::StrAppend(&s, "[" ); |
451 | bool first = true; |
452 | for (const auto& d : proto.dim()) { |
453 | if (!first) strings::StrAppend(&s, "," ); |
454 | if (d.size() == -1) { |
455 | strings::StrAppend(&s, "?" ); |
456 | } else { |
457 | strings::StrAppend(&s, d.size()); |
458 | } |
459 | first = false; |
460 | } |
461 | strings::StrAppend(&s, "]" ); |
462 | return s; |
463 | } |
464 | |
465 | bool TensorShapeUtils::StartsWith(const TensorShape& shape, |
466 | const TensorShape& prefix) { |
467 | if (shape.dims() < prefix.dims()) return false; |
468 | for (int i = 0; i < prefix.dims(); ++i) { |
469 | if (shape.dim_size(i) != prefix.dim_size(i)) return false; |
470 | } |
471 | return true; |
472 | } |
473 | |
474 | bool TensorShapeUtils::EndsWith(const TensorShape& shape, |
475 | const TensorShape& suffix) { |
476 | const int suffix_size = suffix.dims(); |
477 | if (shape.dims() < suffix_size) return false; |
478 | for (int i = 0; i < suffix_size; ++i) { |
479 | if (shape.dim_size(shape.dims() - suffix_size + i) != suffix.dim_size(i)) { |
480 | return false; |
481 | } |
482 | } |
483 | return true; |
484 | } |
485 | |
486 | template <typename T, class Shape> |
487 | Status MakeShapeHelper(const T* dims, int64 n, Shape* out) { |
488 | out->Clear(); |
489 | if (n > TensorShape::MaxDimensions()) { |
490 | return errors::InvalidArgument("Too many dimensions" ); |
491 | } |
492 | if (n < 0) { |
493 | return errors::InvalidArgument("Negative number of dimensions " , n); |
494 | } |
495 | for (int64 i = 0; i < n; ++i) { |
496 | T dim = internal::SubtleMustCopy(dims[i]); |
497 | int64 new_num_elements; |
498 | if (dim < 0) { |
499 | if (!out->kIsPartial) { |
500 | return errors::InvalidArgument("Dimension " , dim, " must be >= 0" ); |
501 | } |
502 | if (dim < -1) { |
503 | return errors::InvalidArgument("Dimension " , dim, " must be >= -1" ); |
504 | } |
505 | dim = -1; |
506 | new_num_elements = -1; |
507 | } else if (out->num_elements() < 0) { |
508 | new_num_elements = -1; |
509 | } else { |
510 | new_num_elements = MultiplyWithoutOverflow(out->num_elements(), dim); |
511 | if (TF_PREDICT_FALSE(new_num_elements < 0)) { |
512 | TensorShapeProto proto; |
513 | for (int64 j = 0; j < n; ++j) { |
514 | proto.add_dim()->set_size(dim); |
515 | } |
516 | return errors::InvalidArgument( |
517 | "Shape " , TensorShape::DebugString(proto), |
518 | " would have more than 2**63 - 1 elements" ); |
519 | } |
520 | } |
521 | out->UnsafeAddDim(dim, new_num_elements); |
522 | } |
523 | return Status::OK(); |
524 | } |
525 | |
// Defines both TensorShapeUtils::MakeShape overloads (pointer + count, and
// ArraySlice) for element type T and output type Shape. Both forward to
// MakeShapeHelper. The macro is instantiated for int32/int64 sizes and for
// TensorShape/PartialTensorShape outputs, then undefined.
#define MAKE_SHAPE(T, Shape)                                                 \
  Status TensorShapeUtils::MakeShape(const T* dims, int64 n, Shape* out) {   \
    return MakeShapeHelper(dims, n, out);                                    \
  }                                                                          \
  Status TensorShapeUtils::MakeShape(gtl::ArraySlice<T> shape, Shape* out) { \
    return MakeShapeHelper(shape.data(), shape.size(), out);                 \
  }
MAKE_SHAPE(int32, TensorShape)
MAKE_SHAPE(int64, TensorShape)
MAKE_SHAPE(int32, PartialTensorShape)
MAKE_SHAPE(int64, PartialTensorShape)
#undef MAKE_SHAPE
538 | |
539 | string TensorShapeUtils::ShapeListString( |
540 | const gtl::ArraySlice<TensorShape>& shapes) { |
541 | string result = "[" ; |
542 | bool first = true; |
543 | for (const TensorShape& shape : shapes) { |
544 | strings::StrAppend(&result, (first ? "" : ", " ), shape.DebugString()); |
545 | first = false; |
546 | } |
547 | strings::StrAppend(&result, "]" ); |
548 | return result; |
549 | } |
550 | |
551 | PartialTensorShape PartialTensorShape::Concatenate(int64 size) const { |
552 | PartialTensorShape out = *this; |
553 | out.AddDim(size); |
554 | return out; |
555 | } |
556 | |
557 | PartialTensorShape PartialTensorShape::Concatenate( |
558 | const PartialTensorShape& shape) const { |
559 | if (unknown_rank() || shape.unknown_rank()) { |
560 | return PartialTensorShape(); |
561 | } |
562 | PartialTensorShape out = *this; |
563 | for (auto dim : shape) out.AddDim(dim.size); |
564 | return out; |
565 | } |
566 | |
567 | Status PartialTensorShape::MergeWith(const PartialTensorShape& shape, |
568 | PartialTensorShape* result) const { |
569 | if (unknown_rank()) { |
570 | *result = shape; |
571 | return Status::OK(); |
572 | } |
573 | if (shape.unknown_rank()) { |
574 | *result = *this; |
575 | return Status::OK(); |
576 | } |
577 | const int dims_ = dims(); |
578 | if (dims_ != shape.dims()) { |
579 | return errors::InvalidArgument( |
580 | "PartialTensorShape: Incompatible ranks during merge: " , dims_, " vs. " , |
581 | shape.dims()); |
582 | } |
583 | CHECK(result != this); |
584 | result->Clear(); |
585 | for (int i = 0; i < dims_; ++i) { |
586 | const int64 dim0 = dim_size(i); |
587 | const int64 dim1 = shape.dim_size(i); |
588 | if (dim0 >= 0 && dim1 >= 0 && dim0 != dim1) { |
589 | return errors::InvalidArgument( |
590 | "PartialTensorShape: Incompatible shapes during merge: " , |
591 | DebugString(), " vs. " , shape.DebugString()); |
592 | } |
593 | result->AddDim(dim0 >= 0 ? dim0 : dim1); |
594 | } |
595 | return Status::OK(); |
596 | } |
597 | |
598 | bool PartialTensorShape::AsTensorShape(TensorShape* shape) const { |
599 | if (IsFullyDefined()) { |
600 | const TensorShapeRep* rep = this; |
601 | *shape = *static_cast<const TensorShape*>(rep); |
602 | return true; |
603 | } |
604 | return false; |
605 | } |
606 | |
607 | bool PartialTensorShape::IsIdenticalTo(const PartialTensorShape& shape) const { |
608 | if (unknown_rank() || shape.unknown_rank()) { |
609 | return unknown_rank() == shape.unknown_rank(); |
610 | } |
611 | if (dims() != shape.dims()) return false; |
612 | for (int i = 0; i < dims(); i++) { |
613 | if (dim_size(i) != shape.dim_size(i)) return false; |
614 | } |
615 | return true; |
616 | } |
617 | |
618 | bool PartialTensorShape::IsCompatibleWith( |
619 | const PartialTensorShape& shape) const { |
620 | if (unknown_rank() || shape.unknown_rank()) return true; |
621 | if (dims() != shape.dims()) return false; |
622 | for (int i = 0; i < dims(); i++) { |
623 | const int64 dim0 = dim_size(i); |
624 | const int64 dim1 = shape.dim_size(i); |
625 | if (dim0 >= 0 && dim1 >= 0 && dim0 != dim1) return false; |
626 | } |
627 | return true; |
628 | } |
629 | |
630 | string PartialTensorShapeUtils::PartialShapeListString( |
631 | const gtl::ArraySlice<PartialTensorShape>& shapes) { |
632 | string result = "[" ; |
633 | bool first = true; |
634 | for (const PartialTensorShape& shape : shapes) { |
635 | strings::StrAppend(&result, (first ? "" : ", " ), shape.DebugString()); |
636 | first = false; |
637 | } |
638 | strings::StrAppend(&result, "]" ); |
639 | return result; |
640 | } |
641 | |
642 | bool PartialTensorShapeUtils::AreCompatible( |
643 | const gtl::ArraySlice<PartialTensorShape>& shapes0, |
644 | const gtl::ArraySlice<PartialTensorShape>& shapes1) { |
645 | if (shapes0.size() == shapes1.size()) { |
646 | for (size_t i = 0; i < shapes0.size(); ++i) { |
647 | if (!shapes0[i].IsCompatibleWith(shapes1[i])) { |
648 | return false; |
649 | } |
650 | } |
651 | return true; |
652 | } else { |
653 | return false; |
654 | } |
655 | } |
656 | |
657 | bool PartialTensorShapeUtils::AreIdentical( |
658 | const gtl::ArraySlice<PartialTensorShape>& shapes0, |
659 | const gtl::ArraySlice<PartialTensorShape>& shapes1) { |
660 | if (shapes0.size() == shapes1.size()) { |
661 | for (size_t i = 0; i < shapes0.size(); ++i) { |
662 | if (!shapes0[i].IsIdenticalTo(shapes1[i])) { |
663 | return false; |
664 | } |
665 | } |
666 | return true; |
667 | } else { |
668 | return false; |
669 | } |
670 | } |
671 | |
672 | Status TensorShapeUtils::NumElements(gtl::ArraySlice<int64> shape, |
673 | int64* num_elements) { |
674 | int64 n = 1; |
675 | for (auto dim : shape) { |
676 | n = MultiplyWithoutOverflow(n, dim); |
677 | if (n < 0) { |
678 | return errors::InvalidArgument("Can't compute total size of shape [" , |
679 | str_util::Join(shape, "," ), |
680 | "]; product would overflow int64" ); |
681 | } |
682 | } |
683 | *num_elements = n; |
684 | return Status::OK(); |
685 | } |
686 | |
// Explicit instantiations: the TensorShapeBase<Shape> member templates are
// defined in this .cc file, so both shape flavors are instantiated here to
// make them available to other translation units.
template class TensorShapeBase<TensorShape>;
template class TensorShapeBase<PartialTensorShape>;
689 | |
690 | } // namespace tensorflow |
691 | |