1/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15#ifndef TENSORFLOW_CORE_FRAMEWORK_SHAPE_INFERENCE_H_
16#define TENSORFLOW_CORE_FRAMEWORK_SHAPE_INFERENCE_H_
17
18#include <vector>
19
20#include "tensorflow/core/framework/node_def_util.h"
21#include "tensorflow/core/framework/tensor.h"
22#include "tensorflow/core/lib/core/errors.h"
23#include "tensorflow/core/lib/core/status.h"
24#include "tensorflow/core/lib/gtl/inlined_vector.h"
25#include "tensorflow/core/platform/macros.h"
26
27namespace tensorflow {
28
29class ShapeRefiner;
30class ShapeRefinerTest;
31
32namespace grappler {
33class GraphProperties;
34class SymbolicShapeManager;
35} // namespace grappler
36
37namespace shape_inference {
38
39struct DimensionOrConstant;
40class InferenceContext;
41
42// Dimension values are accessed through InferenceContext.
43class Dimension {
44 private:
45 Dimension();
46 Dimension(int64 value);
47 ~Dimension() {}
48
49 const int64 value_;
50
51 friend class InferenceContext;
52 friend class ShapeManager;
53 TF_DISALLOW_COPY_AND_ASSIGN(Dimension);
54};
55
56class DimensionHandle {
57 public:
58 DimensionHandle() {}
59 bool SameHandle(DimensionHandle d) const { return ptr_ == d.ptr_; }
60 std::size_t Handle() const { return reinterpret_cast<std::size_t>(ptr_); }
61
62 private:
63 DimensionHandle(const Dimension* dim) { ptr_ = dim; }
64
65 const Dimension* operator->() const { return ptr_; }
66 bool IsSet() const { return ptr_ != nullptr; }
67
68 const Dimension* ptr_ = nullptr;
69
70 friend struct DimensionOrConstant;
71 friend class InferenceContext;
72 friend class ShapeInferenceTest;
73 friend class ShapeInferenceTestutil;
74 friend class ::tensorflow::ShapeRefinerTest;
75 friend class ShapeManager;
76 friend class ::tensorflow::grappler::GraphProperties;
77 friend class ::tensorflow::grappler::SymbolicShapeManager;
78
79 // Intentionally copyable.
80};
81
82// Shape rank and dimensions are accessed through InferenceContext.
83class Shape {
84 private:
85 Shape();
86 Shape(const std::vector<DimensionHandle>& dims);
87 ~Shape() {}
88
89 const int32 rank_;
90 const std::vector<DimensionHandle> dims_;
91
92 friend class InferenceContext;
93 friend class ShapeManager;
94 friend class ::tensorflow::grappler::SymbolicShapeManager;
95
96 TF_DISALLOW_COPY_AND_ASSIGN(Shape);
97};
98
99class ShapeHandle {
100 public:
101 ShapeHandle() {}
102 bool SameHandle(ShapeHandle s) const { return ptr_ == s.ptr_; }
103 std::size_t Handle() const { return reinterpret_cast<std::size_t>(ptr_); }
104
105 private:
106 ShapeHandle(const Shape* shape) { ptr_ = shape; }
107 const Shape* operator->() const { return ptr_; }
108 bool IsSet() const { return ptr_ != nullptr; }
109
110 const Shape* ptr_ = nullptr;
111
112 friend class InferenceContext;
113 friend class ShapeInferenceTest;
114 friend class ShapeInferenceTestutil;
115 friend class ::tensorflow::ShapeRefinerTest;
116 friend class ShapeManager;
117 friend class ::tensorflow::grappler::SymbolicShapeManager;
118
119 // Intentionally copyable.
120};
121
122// Struct used to allow functions to take DimensionHandle or a dimension value.
123// Not meant to be constructed directly.
124struct DimensionOrConstant {
125 public:
126 // Intentionally not explicit.
127 DimensionOrConstant(DimensionHandle dim);
128
129 // val must be non-negative or InferenceContext::kUnknownDim.
130 DimensionOrConstant(int64 val);
131
132 // dim takes precedence. If dim != nullptr, val is ignored.
133 DimensionHandle dim;
134 int64 val;
135
136 private:
137 DimensionOrConstant();
138};
139
140struct ShapeAndType {
141 ShapeAndType() {}
142 ShapeAndType(ShapeHandle s, DataType t) : shape(s), dtype(t) {}
143
144 ShapeHandle shape;
145 DataType dtype = DT_INVALID;
146};
147
148// Shape inference functions registered on ops in REGISTER_OP implement
149// their shape functions in terms of this InferenceContext. An InferenceContext
150// is created by the framework and passed to a shape inference function. The
151// shape inference function calls functions on the context, and should call
152// set_output() to set the shape on all outputs.
153//
154// To infer shapes for user-defined functions see ShapeRefiner.
155//
156// All Shape* and Dimension* returned by functions of InferenceContext are owned
157// by the InferenceContext.
158class InferenceContext {
159 public:
160 static constexpr int64 kUnknownDim = -1;
161 static constexpr int32 kUnknownRank = -1;
162
163 // <input_tensors> is NULL-padded to be the same size as <input_shapes>.
164 //
165 // Elements of <input_tensors_as_shapes> are used for when a shape function
166 // makes a call to MakeShapeFromShapeTensor; in particular, when the
167 // input_tensors[i] is nullptr but the shape represented by it is partially
168 // known from analysis of the graph.
169 // <input_tensors_as_shapes> can have fewer elements than <input_shapes>.
170 // Values of <input_tensors_as_shapes> do not need to outlive the context.
171 //
172 // REQUIRES: <node_def> is not NULL, and must outlive the InferenceContext.
173 InferenceContext(int graph_def_version, const NodeDef* node_def,
174 const OpDef& op_def,
175 const std::vector<ShapeHandle>& input_shapes,
176 const std::vector<const Tensor*>& input_tensors,
177 const std::vector<ShapeHandle>& input_tensors_as_shapes,
178 std::vector<std::unique_ptr<std::vector<ShapeAndType>>>
179 input_handle_shapes_and_types);
180
181 // <input_tensors> is NULL-padded to be the same size as <input_shapes>.
182 //
183 // Elements of <input_tensors_as_shapes> are used for when a shape
184 // function makes a call to MakeShapeFromShapeTensor; in particular, when
185 // the input_tensors[i] is nullptr but the shape represented by it is
186 // partially known from analysis of the graph. <input_tensors_as_shapes>
187 // can have fewer elements than <input_shapes>. Values of
188 // <input_tensors_as_shapes> do not need to outlive the context.
189 //
190 // REQUIRES: <node_def> is not NULL, and must outlive the
191 // InferenceContext.
192 InferenceContext(
193 int graph_def_version, const NodeDef* node_def, const OpDef& op_def,
194 const std::vector<TensorShapeProto>& input_shapes,
195 const std::vector<const Tensor*>& input_tensors,
196 const std::vector<TensorShapeProto>& input_tensors_as_shapes,
197 const std::vector<
198 std::unique_ptr<std::vector<std::pair<TensorShapeProto, DataType>>>>&
199 input_handle_shapes_and_types);
200
201 // <input_tensors> is NULL-padded to be the same size as <input_shapes>.
202 //
203 // Elements of <input_tensors_as_shapes> are used for when a shape
204 // function makes a call to MakeShapeFromShapeTensor; in particular, when
205 // the input_tensors[i] is nullptr but the shape represented by it is
206 // partially known from analysis of the graph. <input_tensors_as_shapes>
207 // can have fewer elements than <input_shapes>. Values of
208 // <input_tensors_as_shapes> do not need to outlive the context.
209 //
210 // REQUIRES: <node_def> is not NULL, and must outlive the
211 // InferenceContext.
212 InferenceContext(
213 int graph_def_version, const NodeDef* node_def, const OpDef& op_def,
214 const std::vector<PartialTensorShape>& input_shapes,
215 const std::vector<const Tensor*>& input_tensors,
216 const std::vector<PartialTensorShape>& input_tensors_as_shapes,
217 const std::vector<std::unique_ptr<
218 std::vector<std::pair<PartialTensorShape, DataType>>>>&
219 input_handle_shapes_and_types);
220
221 ~InferenceContext();
222
223 // Runs the shape inference function 'fn' with 'this' as the
224 // argument, returns the status of the inference.
225 //
226 // On error, additional context is provided in the error message.
227 Status Run(
228 const std::function<Status(shape_inference::InferenceContext* c)>& fn);
229
230 // Merge the stored shape of the input in position idx with <shape> according
231 // to the following rules:
232 //
233 // - If the ShapeHandles are the same or <shape> is unknown, there will be no
234 // change. Otherwise if the stored shape is unknown, the new shape will be
235 // <shape>.
236 // - If both shapes are known, then they must have the same rank.
237 // - For any one dimension, if the values for that dimension in both shapes
238 // are known, then the values must match.
239 // - If one shape has equal or more information than the other shape in every
240 // dimension, the new shape will become the shape with more information.
241 // - Example: merging [2,?] and [?,2] results in [2,2]
242 // - Example: [2,2] cannot be merged with [1,2]
243 //
244 // This requires idx to be in the [0, num_inputs) range. If the merge is
245 // successful, return true. Return false otherwise.
246 bool MergeInput(int idx, ShapeHandle shape) {
247 ShapeHandle new_shape;
248 if (!Merge(inputs_[idx], shape, &new_shape).ok()) return false;
249 inputs_[idx] = new_shape;
250 return true;
251 }
252
253 // Relax the stored shape of the input in position idx with <shape> according
254 // to the following rules:
255 //
256 // - If the ShapeHandles are the same then the stored shape will be returned.
257 // - If either of the ShapeHandles are unknown, then a new UnknownShape will
258 // be returned. A new shape must be returned because we cannot claim that
259 // the resulting shape is necessarily the same as either of the input
260 // shapes.
261 // - If the shapes both have known ranks but their ranks are different, a new
262 // UnknownShape will be returned.
263 // - For any one dimension, if the value for that dimension in either of the
264 // shapes is unknown, a new shape will be returned with a new UnknownDim in
265 // that dimension.
266 // - For any one dimension, if the values for that dimension in both shapes
267 // are known but do not match, a new shape will be returned with a new
268 // UnknownDim in that dimension.
269 // - If both shapes have the same known rank and match in every dimension,
270 // the stored shape will be returned.
271 // - Example: relaxing [2,?] and [?,2] results in [?,?]
272 // - Example: relaxing [2,2] and [3,2] results in [?,2]
273 // - Example: relaxing [2,2] with [1,2,3] results in ?
274 //
275 // This requires idx to be in the [0, num_inputs) range. If the relax is
276 // successful and the new shape differs from the old one, store the new
277 // shape and return true. Return false otherwise.
278 bool RelaxInput(int idx, ShapeHandle shape) {
279 ShapeHandle new_shape;
280 Relax(inputs_[idx], shape, &new_shape);
281 if (inputs_[idx].SameHandle(new_shape)) {
282 return false;
283 }
284 inputs_[idx] = new_shape;
285 return true;
286 }
287
288 ShapeHandle input(int64 idx) const { return inputs_[idx]; }
289 Status input(StringPiece input_name, std::vector<ShapeHandle>* output) const;
290 int num_inputs() const { return inputs_.size(); }
291
292 // Returns the input tensor at index <idx>, or nullptr if the input tensor is
293 // not available at the time of shape inference.
294 const Tensor* input_tensor(int idx) {
295 // Mark that this idx was requested.
296 requested_input_tensor_[idx] = true;
297 return input_tensors_[idx];
298 }
299
300 // Returns true iff input_tensor(idx) was called by the shape function.
301 bool requested_input_tensor(int idx) const {
302 return requested_input_tensor_[idx];
303 }
304
305 // Returns true if MakeShapeFromInputTensor was called but the constant
306 // input_tensor was not present.
307 bool requested_input_tensor_as_partial_shape(int idx) const {
308 return requested_input_tensor_as_partial_shape_[idx];
309 }
310
311 void set_input_tensors(const std::vector<const Tensor*>& input_tensors) {
312 input_tensors_ = input_tensors;
313 }
314
315 void set_input_tensors_as_shapes(
316 const std::vector<ShapeHandle>& input_tensors_as_shapes) {
317 input_tensors_as_shapes_ = input_tensors_as_shapes;
318 }
319
320 ShapeHandle output(int64 idx) const { return outputs_[idx]; }
321 void set_output(int idx, ShapeHandle shape) { outputs_[idx] = shape; }
322 Status set_output(StringPiece output_name,
323 const std::vector<ShapeHandle>& shapes);
324
325 int num_outputs() const { return outputs_.size(); }
326 ShapeHandle output(int idx) const { return outputs_[idx]; }
327 Status output(StringPiece output_name,
328 std::vector<ShapeHandle>* output) const;
329
330 AttrSlice attrs() const { return AttrSlice(*node_def_); }
331
332 string op() const;
333
334 // idx can be negative for an offset from end of dimensions.
335 // idx must be in the range [-1 * s.rank, s.rank).
336 DimensionHandle Dim(ShapeHandle s, int64 idx) {
337 if (s->rank_ == kUnknownRank) {
338 return UnknownDim();
339 }
340 return DimKnownRank(s, idx);
341 }
342 // As above, but asserts that the rank of the shape is known.
343 static DimensionHandle DimKnownRank(ShapeHandle s, int64 idx) {
344 CHECK_NE(s->rank_, kUnknownRank);
345 if (idx < 0) {
346 return s->dims_[s->dims_.size() + idx];
347 }
348 return s->dims_[idx];
349 }
350
351 static int32 Rank(ShapeHandle s) {
352 DCHECK(s.IsSet());
353 return s.IsSet() ? s->rank_ : kUnknownRank;
354 }
355 static bool RankKnown(ShapeHandle s) {
356 return (s.IsSet() && (Rank(s) != kUnknownRank));
357 }
358 static inline int64 Value(DimensionOrConstant d) {
359 return d.dim.IsSet() ? d.dim->value_ : d.val;
360 }
361 static inline bool ValueKnown(DimensionOrConstant d) {
362 return Value(d) != kUnknownDim;
363 }
364
365 // Fills the output proto with the shape defined by the handle.
366 // "proto" is expected to be empty prior to the call.
367 void ShapeHandleToProto(ShapeHandle handle, TensorShapeProto* proto);
368
369 // Returns true if the rank and all dimensions of the Shape are known.
370 bool FullyDefined(ShapeHandle s);
371
372 // Returns the total number of elements, or an unknown dimension for an
373 // incomplete shape.
374 DimensionHandle NumElements(ShapeHandle s);
375
376 string DebugString(ShapeHandle s);
377 string DebugString(DimensionHandle d);
378
379 // Describes the whole context, for debugging purposes.
380 string DebugString() const;
381
382 // If <shape> has rank <rank>, or its rank is unknown, return OK and return
383 // the shape with asserted rank in <*out>. Otherwise return an error.
384 //
385 // Note that <*out> may be set to <shape>.
386 Status WithRank(ShapeHandle shape, int64 rank,
387 ShapeHandle* out) TF_MUST_USE_RESULT;
388 Status WithRankAtLeast(ShapeHandle shape, int64 rank,
389 ShapeHandle* out) TF_MUST_USE_RESULT;
390 Status WithRankAtMost(ShapeHandle shape, int64 rank,
391 ShapeHandle* out) TF_MUST_USE_RESULT;
392
393 // If <dim> has value <value>, or its value is unknown, returns OK and returns
394 // the dimension with asserted value in <*out>. Otherwise returns an error.
395 //
396 // Note that <*out> may be set to <dim>.
397 Status WithValue(DimensionHandle dim, int64 value,
398 DimensionHandle* out) TF_MUST_USE_RESULT;
399
400 // Merges <s0> and <s1> and returns the merged shape in <*out>. See
401 // 'MergeInput' function for full details and examples.
402 Status Merge(ShapeHandle s0, ShapeHandle s1,
403 ShapeHandle* out) TF_MUST_USE_RESULT;
404
405 // Asserts that <s>'s rank >= <prefix>'s rank, and the first
406 // <prefix.rank> dimensions of <s> are compatible with the dimensions of
407 // <prefix>.
408 // Returns the merged results in <*s_out> and <*prefix_out>.
409 Status MergePrefix(ShapeHandle s, ShapeHandle prefix, ShapeHandle* s_out,
410 ShapeHandle* prefix_out) TF_MUST_USE_RESULT;
411
412 // Merges <d0> and <d1> and returns the merged dimension in <*out>. If <d0>
413 // and <d1> have incompatible values, returns an error.
414 //
415 // Note that <*out> may be set to <d0> or <d1>.
416 Status Merge(DimensionHandle d0, DimensionHandle d1,
417 DimensionHandle* out) TF_MUST_USE_RESULT;
418
419 // Returns in <*out> a sub-shape of <s> with dimensions [start:].
420 // <start> can be negative to index from the end of the shape. If <start> >
421 // rank of <s>, then an empty subshape is returned.
422 Status Subshape(ShapeHandle s, int64 start,
423 ShapeHandle* out) TF_MUST_USE_RESULT;
424
425 // Returns in <*out> a sub-shape of <s>, with dimensions [start:end].
426 // <start> and <end> can be negative, to index from the end of the shape.
427 // <start> and <end> are set to the rank of <s> if > rank of <s>.
428 Status Subshape(ShapeHandle s, int64 start, int64 end,
429 ShapeHandle* out) TF_MUST_USE_RESULT;
430
431 // Returns in <*out> the result of appending the dimensions of <s2> to those
432 // of <s1>.
433 Status Concatenate(ShapeHandle s1, ShapeHandle s2,
434 ShapeHandle* out) TF_MUST_USE_RESULT;
435
436 // Returns in <out> the shape from replacing <s.dim[dim_index]> with
437 // <new_dim>.
438 Status ReplaceDim(ShapeHandle s, int64 dim_index, DimensionHandle new_dim,
439 ShapeHandle* out) TF_MUST_USE_RESULT;
440
441 // Returns a new shape with the given dims. The returned value is owned by
442 // this context.
443 ShapeHandle MakeShape(const std::vector<DimensionHandle>& dims);
444 ShapeHandle MakeShape(std::initializer_list<DimensionOrConstant> dims);
445
446 // Returns a new unknown shape.
447 ShapeHandle UnknownShape();
448
449 // Returns a shape with specified rank but unknown dims.
450 ShapeHandle UnknownShapeOfRank(int64 rank);
451
452 // Returns a new shape of zero dimensions.
453 ShapeHandle Scalar();
454
455 // Returns a new shape of one dimension.
456 ShapeHandle Vector(DimensionOrConstant dim);
457
458 // Returns a new shape of two dimensions.
459 ShapeHandle Matrix(DimensionOrConstant dim1, DimensionOrConstant dim2);
460
461 // Returns in <out> a new shape whose dimension sizes come from input tensor
462 // <input_idx>. The tensor must be a 1-dimensional int32 or int64 tensor. If
463 // the input tensor is NULL, then an unknown shape is returned.
464 Status MakeShapeFromShapeTensor(int input_idx, ShapeHandle* out);
465
466 // Like the function above, but treats scalar values as unknown
467 // shapes. **NOTE** If the scalar is statically known, its value
468 // must be -1 or an error is returned.
469 Status MakeShapeFromShapeTensorTreatScalarAsUnknownShape(int input_idx,
470 ShapeHandle* out);
471
472 // Returns in <out> a new shape corresponding to <proto>.
473 Status MakeShapeFromShapeProto(const TensorShapeProto& proto,
474 ShapeHandle* out);
475
476 // Returns in <out> a new shape corresponding to <partial_shape>.
477 Status MakeShapeFromPartialTensorShape(
478 const PartialTensorShape& partial_shape, ShapeHandle* out);
479
480 // Returns in <out> a new shape corresponding to <shape>.
481 Status MakeShapeFromTensorShape(const TensorShape& shape, ShapeHandle* out);
482
483 // Returns a new dimension of the given size. The returned value is owned by
484 // this context.
485 inline DimensionHandle MakeDim(DimensionOrConstant d) {
486 return shape_manager_.MakeDim(d);
487 }
488
489 inline DimensionHandle UnknownDim() { return MakeDim(kUnknownDim); }
490
491 // Returns in <val> a scalar value from an input tensor <t>. The input tensor
492 // must be a 1-dimensional int32 or int64 tensor. Caller must ensure that the
493 // input tensor is not NULL.
494 Status GetScalarFromTensor(const Tensor* t, int64* val);
495
496 // Returns a new dimension whose value is given by a scalar input tensor.
497 // The input tensor must be in host memory, since it is dereferenced to get
498 // the value.
499 Status MakeDimForScalarInput(int idx, DimensionHandle* out);
500
501 // Returns a new dimension whose value is given by a scalar input tensor.
502 // This allows for a negative input dimension given the rank of a separate
503 // tensor. This rank can be negative if unknown.
504 // The input tensor must be in host memory, since it is dereferenced to get
505 // the value.
506 Status MakeDimForScalarInputWithNegativeIndexing(int idx, int input_rank,
507 DimensionHandle* out);
508
509 // Look up the attr for the NodeDef being evaluated with name attr_name and
510 // set *value to its value. If no attr with attr_name is found in def(), or
511 // the attr does not have a matching type, a non-ok status will be returned.
512 template <class T>
513 Status GetAttr(StringPiece attr_name, T* value) const;
514
515 // Returns in <out> the result of dividing <dividend> by <divisor>.
516 // Returns an error if <divisor> is not positive or if <evenly_divisible>
517 // and <divisor> does not evenly divide <dividend>.
518 Status Divide(DimensionHandle dividend, DimensionOrConstant divisor,
519 bool evenly_divisible, DimensionHandle* out);
520
521 // Returns in <out> the sum of <first> and <second>.
522 Status Add(DimensionHandle first, DimensionOrConstant second,
523 DimensionHandle* out);
524
525 // Returns in <out> the dimension that is <first> minus <second>.
526 Status Subtract(DimensionHandle first, DimensionOrConstant second,
527 DimensionHandle* out);
528
529 // Returns in <out> the product of <first> and <second>.
530 Status Multiply(DimensionHandle first, DimensionOrConstant second,
531 DimensionHandle* out);
532
533 // Returns in <out> the minimum of <first> and <second>. If either <first> or
534 // <second> is zero the results is zero. Otherwise, if either <first> or
535 // <second> is unknown the results is unknown.
536 Status Min(DimensionHandle first, DimensionOrConstant second,
537 DimensionHandle* out);
538
539 // Returns in <out> the maximum of <first> and <second>. If either <first> or
540 // <second> is unknown the results is unknown.
541 Status Max(DimensionHandle first, DimensionOrConstant second,
542 DimensionHandle* out);
543
544 Status construction_status() const { return construction_status_; }
545
546 // Methods to propagate shape and dtype on edges of handles. Handles are the
547 // dtype DT_RESOURCE which can be used to access state stored in a
548 // ResourceManager. When ops (such as variables) consume these handles to
549 // produce tensors they might need to know side-information about the shapes
550 // and dtypes of tensors which can be accessed via the handle. These methods
551 // propagate that information. Output handle dtypes and shapes are ignored if
552 // the output tensor is not of type DT_RESOURCE.
553
554 // Merge the stored shapes and types corresponding to the input handle in
555 // position idx with the specified shapes and types. This requires idx to be
556 // in the [0, num_inputs) range.
557 //
558 // If the merge is successful and any of the new shapes differs from the old
559 // one, or any of the old dtypes was DT_INVALID, store the new shapes and
560 // return true. Return false otherwise.
561 //
562 // See 'MergeInput' function for full details and examples.
563 bool MergeInputHandleShapesAndTypes(
564 int idx,
565 const std::vector<ShapeAndType>& shapes_and_types) TF_MUST_USE_RESULT;
566
567 // As MergeInputHandleShapesAndTypes, but for an output.
568 bool MergeOutputHandleShapesAndTypes(
569 int idx,
570 const std::vector<ShapeAndType>& shapes_and_types) TF_MUST_USE_RESULT;
571
572 // Relaxes the stored shapes and types corresponding to the input handle in
573 // position idx with the specified shapes and types. This requires idx to be
574 // in the [0, num_inputs) range.
575 //
576 // If the relax is successful and any of the new shapes differs from the old
577 // one, or any of the old dtypes was DT_INVALID, store the new shapes and
578 // return true. Return false otherwise.
579 //
580 // See 'RelaxInput' function for full details and examples.
581 bool RelaxInputHandleShapesAndMergeTypes(
582 int idx,
583 const std::vector<ShapeAndType>& shapes_and_types) TF_MUST_USE_RESULT;
584
585 // As RelaxInputHandleShapesAndTypes, but for an output.
586 bool RelaxOutputHandleShapesAndMergeTypes(
587 int idx,
588 const std::vector<ShapeAndType>& shapes_and_types) TF_MUST_USE_RESULT;
589
590 // Returns the output handle shapes and types, for the resource tensor output
591 // at index <idx>. Returns NULL if the shape and types were never set.
592 const std::vector<ShapeAndType>* output_handle_shapes_and_types(int idx) {
593 return output_handle_shapes_and_types_[idx].get();
594 }
595
596 // Returns the inputs handle shapes and types, for the resource tensor output
597 // at index <idx>. Returns NULL if the shape and types were not available.
598 const std::vector<ShapeAndType>* input_handle_shapes_and_types(int idx) {
599 return input_handle_shapes_and_types_[idx].get();
600 }
601
602 void set_output_handle_shapes_and_types(
603 int idx, const std::vector<ShapeAndType>& shapes_and_types) {
604 output_handle_shapes_and_types_[idx].reset(
605 new std::vector<ShapeAndType>(shapes_and_types));
606 }
607
608 // Note that shape functions should usually call MakeShapeFromShapeTensor,
609 // as it does more analysis to provide partial shapes.
610 //
611 // Returns in <out> a new shape whose dimension sizes come from tensor <t>.
612 // The tensor must be a 1-dimensional int32 or int64 tensor. If <t> is NULL,
613 // then an unknown shape is returned.
614 Status MakeShapeFromTensor(const Tensor* t, ShapeHandle tensor_shape,
615 ShapeHandle* out);
616
617 int graph_def_version() const { return graph_def_version_; }
618
619 const std::vector<std::pair<ShapeHandle, ShapeHandle>>& MergedShapes() const {
620 return merged_shapes_;
621 }
622 const std::vector<std::pair<DimensionHandle, DimensionHandle>>& MergedDims()
623 const {
624 return merged_dims_;
625 }
626
627 private:
628 // Creates and stores shapes for use in InferenceContext.
629 class ShapeManager {
630 public:
631 ShapeManager();
632 ~ShapeManager();
633
634 // Returns a new shape with the given dims. The returned value is owned by
635 // this class.
636 ShapeHandle MakeShape(const std::vector<DimensionHandle>& dims);
637
638 // Returns a new unknown shape.
639 ShapeHandle UnknownShape();
640
641 // Returns a new dimension of the given size. The returned value
642 // is owned by this class.
643 inline DimensionHandle MakeDim(DimensionOrConstant d) {
644 if (d.dim.IsSet()) {
645 return d.dim;
646 } else {
647 all_dims_.push_back(new Dimension(d.val));
648 return all_dims_.back();
649 }
650 }
651
652 private:
653 std::vector<Shape*> all_shapes_; // values are owned.
654 std::vector<Dimension*> all_dims_; // values are owned.
655 };
656
657 friend class ::tensorflow::grappler::GraphProperties;
658
659 // Friend for user-defined function shape inference purposes.
660 friend class ::tensorflow::ShapeRefiner;
661
662 friend class ShapeInferenceTest; // For testing Relax functions.
663 friend class ShapeInferenceTestutil; // For testing shapes.
664
665 // Shared initialization across the two constructors. Remove
666 // once we get rid of one of them.
667 void PreInputInit(const OpDef& op_def,
668 const std::vector<const Tensor*>& input_tensors,
669 const std::vector<ShapeHandle>& input_tensors_as_shapes);
670 void PostInputInit(std::vector<std::unique_ptr<std::vector<ShapeAndType>>>
671 input_handle_data);
672
673 DimensionHandle GetDimension(const DimensionOrConstant& d);
674
675 Status ReturnUnknownShape(ShapeHandle* out) {
676 *out = UnknownShape();
677 return Status::OK();
678 }
679 Status ReturnCreatedShape(const std::vector<DimensionHandle>& dims,
680 ShapeHandle* out) {
681 *out = MakeShape(dims);
682 return Status::OK();
683 }
684
685 // Adds additional context to the given status.
686 Status AttachContext(const Status& status);
687
688 // Relaxes an existing value <d_old> with a new value <d_new> and returns the
689 // relaxed dimension in <*out>. If <d_old> and <d_new> have incompatible
690 // values, returns an error.
691 //
692 // Note that <*out> may be set to <d_old> or <d_new>.
693 void Relax(DimensionHandle d_old, DimensionHandle d_new,
694 DimensionHandle* out);
695 // Relaxes an existing shape <s_old> with a new shape <s_new> and returns the
696 // relaxed shape in <*out>. See 'RelaxInput' function for full details and
697 // examples.
698 void Relax(ShapeHandle s_old, ShapeHandle s_new, ShapeHandle* out);
699
700 // Used to implement MergeInputHandleShapesAndTypes and
701 // MergeOutputHandleShapesAndTypes.
702 bool MergeHandleShapesAndTypes(
703 const std::vector<ShapeAndType>& shapes_and_types,
704 std::vector<ShapeAndType>* to_update) TF_MUST_USE_RESULT;
705 // Used to implement RelaxInputHandleShapesAndMergeTypes and
706 // RelaxOutputHandleShapesAndMergeTypes.
707 bool RelaxHandleShapesAndMergeTypes(
708 const std::vector<ShapeAndType>& shapes_and_types,
709 std::vector<ShapeAndType>* to_update) TF_MUST_USE_RESULT;
710
711 // Forget all the previous merged shapes and dims.
712 void ForgetMerges() {
713 merged_shapes_.clear();
714 merged_dims_.clear();
715 }
716
717 // Helper method for MakeShapeFromTensor and MakeShapeFromShapeTensor.
718 Status InternalMakeShapeFromTensor(
719 bool treat_unknown_scalar_tensor_as_unknown_shape, const Tensor* t,
720 ShapeHandle tensor_shape, ShapeHandle* out);
721
722 ShapeManager shape_manager_;
723
724 // inputs_, outputs_, and input_tensors_as_shapes_ refer to values from
725 // `shape_manager_`.
726 std::vector<ShapeHandle> inputs_;
727 std::vector<const Tensor*> input_tensors_;
728 std::vector<bool> requested_input_tensor_;
729 std::vector<ShapeHandle> outputs_;
730 // Can have fewer elements than inputs_.
731 std::vector<ShapeHandle> input_tensors_as_shapes_;
732 std::vector<bool> requested_input_tensor_as_partial_shape_;
733
734 // input_handle_shapes_and_types_[i] is the list of shape/type pairs available
735 // through the resource handle passed along input i of the node.
736 //
737 // Values may be NULL.
738 std::vector<std::unique_ptr<std::vector<ShapeAndType>>>
739 input_handle_shapes_and_types_;
740
741 // output_handle_shapes_and_types_[i] is the list of shape/type pairs
742 // available through the resource handle passed along output i of the node.
743 //
744 // Values may be NULL.
745 std::vector<std::unique_ptr<std::vector<ShapeAndType>>>
746 output_handle_shapes_and_types_;
747
748 const int graph_def_version_;
749 const NodeDef* node_def_;
750 NameRangeMap input_name_map_;
751 NameRangeMap output_name_map_;
752
753 // An error set during construction. TODO(cwhipkey): remove when test
754 // constructor is removed.
755 Status construction_status_;
756
757 // Pair of shape or dim handles that are equivalent, ie that represent the
758 // same underlying shape of dimension. Note that for each pair at least one of
759 // the handles must contain an unknown shape, since we don't keep track of
760 // known shapes or dims here.
761 std::vector<std::pair<ShapeHandle, ShapeHandle>> merged_shapes_;
762 std::vector<std::pair<DimensionHandle, DimensionHandle>> merged_dims_;
763
764 TF_DISALLOW_COPY_AND_ASSIGN(InferenceContext);
765};
766
767// -----------------------------------------------------------------------------
768// Template and inline method implementations, please ignore
769
770inline Dimension::Dimension() : value_(InferenceContext::kUnknownDim) {}
771inline Dimension::Dimension(int64 value) : value_(value) {
772 DCHECK(value >= 0 || value == InferenceContext::kUnknownDim)
773 << "Dimension must be non-negative or equal to "
774 "InferenceContext::kUnknownDim but got "
775 << value;
776}
777
778inline Shape::Shape() : rank_(InferenceContext::kUnknownRank) {}
779inline Shape::Shape(const std::vector<DimensionHandle>& dims)
780 : rank_(dims.size()), dims_(dims) {}
781
782inline DimensionOrConstant::DimensionOrConstant(DimensionHandle dim)
783 : dim(dim) {
784 DCHECK(dim.IsSet()) << "Internal error: Got nullptr for Dimension.";
785}
786
787inline DimensionOrConstant::DimensionOrConstant(int64 val) : val(val) {
788 DCHECK(val >= 0 || val == InferenceContext::kUnknownDim)
789 << "Dimension must be non-negative or equal to "
790 "InferenceContext::kUnknownDim but got "
791 << val;
792}
793
794template <class T>
795Status InferenceContext::GetAttr(StringPiece attr_name, T* value) const {
796 return GetNodeAttr(*node_def_, attr_name, value);
797}
798
799} // namespace shape_inference
800} // namespace tensorflow
801
802#endif // TENSORFLOW_CORE_FRAMEWORK_SHAPE_INFERENCE_H_
803