| 1 | /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. |
| 2 | |
| 3 | Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | you may not use this file except in compliance with the License. |
| 5 | You may obtain a copy of the License at |
| 6 | |
| 7 | http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | |
| 9 | Unless required by applicable law or agreed to in writing, software |
| 10 | distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | See the License for the specific language governing permissions and |
| 13 | limitations under the License. |
| 14 | ==============================================================================*/ |
| 15 | #ifndef TENSORFLOW_CORE_FRAMEWORK_SHAPE_INFERENCE_H_ |
| 16 | #define TENSORFLOW_CORE_FRAMEWORK_SHAPE_INFERENCE_H_ |
| 17 | |
| 18 | #include <vector> |
| 19 | |
| 20 | #include "tensorflow/core/framework/node_def_util.h" |
| 21 | #include "tensorflow/core/framework/tensor.h" |
| 22 | #include "tensorflow/core/lib/core/errors.h" |
| 23 | #include "tensorflow/core/lib/core/status.h" |
| 24 | #include "tensorflow/core/lib/gtl/inlined_vector.h" |
| 25 | #include "tensorflow/core/platform/macros.h" |
| 26 | |
| 27 | namespace tensorflow { |
| 28 | |
| 29 | class ShapeRefiner; |
| 30 | class ShapeRefinerTest; |
| 31 | |
| 32 | namespace grappler { |
| 33 | class GraphProperties; |
| 34 | class SymbolicShapeManager; |
| 35 | } // namespace grappler |
| 36 | |
| 37 | namespace shape_inference { |
| 38 | |
| 39 | struct DimensionOrConstant; |
| 40 | class InferenceContext; |
| 41 | |
| 42 | // Dimension values are accessed through InferenceContext. |
| 43 | class Dimension { |
| 44 | private: |
| 45 | Dimension(); |
| 46 | Dimension(int64 value); |
| 47 | ~Dimension() {} |
| 48 | |
| 49 | const int64 value_; |
| 50 | |
| 51 | friend class InferenceContext; |
| 52 | friend class ShapeManager; |
| 53 | TF_DISALLOW_COPY_AND_ASSIGN(Dimension); |
| 54 | }; |
| 55 | |
| 56 | class DimensionHandle { |
| 57 | public: |
| 58 | DimensionHandle() {} |
| 59 | bool SameHandle(DimensionHandle d) const { return ptr_ == d.ptr_; } |
| 60 | std::size_t Handle() const { return reinterpret_cast<std::size_t>(ptr_); } |
| 61 | |
| 62 | private: |
| 63 | DimensionHandle(const Dimension* dim) { ptr_ = dim; } |
| 64 | |
| 65 | const Dimension* operator->() const { return ptr_; } |
| 66 | bool IsSet() const { return ptr_ != nullptr; } |
| 67 | |
| 68 | const Dimension* ptr_ = nullptr; |
| 69 | |
| 70 | friend struct DimensionOrConstant; |
| 71 | friend class InferenceContext; |
| 72 | friend class ShapeInferenceTest; |
| 73 | friend class ShapeInferenceTestutil; |
| 74 | friend class ::tensorflow::ShapeRefinerTest; |
| 75 | friend class ShapeManager; |
| 76 | friend class ::tensorflow::grappler::GraphProperties; |
| 77 | friend class ::tensorflow::grappler::SymbolicShapeManager; |
| 78 | |
| 79 | // Intentionally copyable. |
| 80 | }; |
| 81 | |
| 82 | // Shape rank and dimensions are accessed through InferenceContext. |
| 83 | class Shape { |
| 84 | private: |
| 85 | Shape(); |
| 86 | Shape(const std::vector<DimensionHandle>& dims); |
| 87 | ~Shape() {} |
| 88 | |
| 89 | const int32 rank_; |
| 90 | const std::vector<DimensionHandle> dims_; |
| 91 | |
| 92 | friend class InferenceContext; |
| 93 | friend class ShapeManager; |
| 94 | friend class ::tensorflow::grappler::SymbolicShapeManager; |
| 95 | |
| 96 | TF_DISALLOW_COPY_AND_ASSIGN(Shape); |
| 97 | }; |
| 98 | |
| 99 | class ShapeHandle { |
| 100 | public: |
| 101 | ShapeHandle() {} |
| 102 | bool SameHandle(ShapeHandle s) const { return ptr_ == s.ptr_; } |
| 103 | std::size_t Handle() const { return reinterpret_cast<std::size_t>(ptr_); } |
| 104 | |
| 105 | private: |
| 106 | ShapeHandle(const Shape* shape) { ptr_ = shape; } |
| 107 | const Shape* operator->() const { return ptr_; } |
| 108 | bool IsSet() const { return ptr_ != nullptr; } |
| 109 | |
| 110 | const Shape* ptr_ = nullptr; |
| 111 | |
| 112 | friend class InferenceContext; |
| 113 | friend class ShapeInferenceTest; |
| 114 | friend class ShapeInferenceTestutil; |
| 115 | friend class ::tensorflow::ShapeRefinerTest; |
| 116 | friend class ShapeManager; |
| 117 | friend class ::tensorflow::grappler::SymbolicShapeManager; |
| 118 | |
| 119 | // Intentionally copyable. |
| 120 | }; |
| 121 | |
| 122 | // Struct used to allow functions to take DimensionHandle or a dimension value. |
| 123 | // Not meant to be constructed directly. |
| 124 | struct DimensionOrConstant { |
| 125 | public: |
| 126 | // Intentionally not explicit. |
| 127 | DimensionOrConstant(DimensionHandle dim); |
| 128 | |
| 129 | // val must be non-negative or InferenceContext::kUnknownDim. |
| 130 | DimensionOrConstant(int64 val); |
| 131 | |
| 132 | // dim takes precedence. If dim != nullptr, val is ignored. |
| 133 | DimensionHandle dim; |
| 134 | int64 val; |
| 135 | |
| 136 | private: |
| 137 | DimensionOrConstant(); |
| 138 | }; |
| 139 | |
| 140 | struct ShapeAndType { |
| 141 | ShapeAndType() {} |
| 142 | ShapeAndType(ShapeHandle s, DataType t) : shape(s), dtype(t) {} |
| 143 | |
| 144 | ShapeHandle shape; |
| 145 | DataType dtype = DT_INVALID; |
| 146 | }; |
| 147 | |
| 148 | // Shape inference functions registered on ops in REGISTER_OP implement |
| 149 | // their shape functions in terms of this InferenceContext. An InferenceContext |
| 150 | // is created by the framework and passed to a shape inference function. The |
| 151 | // shape inference function calls functions on the context, and should call |
| 152 | // set_output() to set the shape on all outputs. |
| 153 | // |
| 154 | // To infer shapes for user-defined functions see ShapeRefiner. |
| 155 | // |
| 156 | // All Shape* and Dimension* returned by functions of InferenceContext are owned |
| 157 | // by the InferenceContext. |
| 158 | class InferenceContext { |
| 159 | public: |
| 160 | static constexpr int64 kUnknownDim = -1; |
| 161 | static constexpr int32 kUnknownRank = -1; |
| 162 | |
| 163 | // <input_tensors> is NULL-padded to be the same size as <input_shapes>. |
| 164 | // |
| 165 | // Elements of <input_tensors_as_shapes> are used for when a shape function |
| 166 | // makes a call to MakeShapeFromShapeTensor; in particular, when the |
| 167 | // input_tensors[i] is nullptr but the shape represented by it is partially |
| 168 | // known from analysis of the graph. |
| 169 | // <input_tensors_as_shapes> can have fewer elements than <input_shapes>. |
| 170 | // Values of <input_tensors_as_shapes> do not need to outlive the context. |
| 171 | // |
| 172 | // REQUIRES: <node_def> is not NULL, and must outlive the InferenceContext. |
| 173 | InferenceContext(int graph_def_version, const NodeDef* node_def, |
| 174 | const OpDef& op_def, |
| 175 | const std::vector<ShapeHandle>& input_shapes, |
| 176 | const std::vector<const Tensor*>& input_tensors, |
| 177 | const std::vector<ShapeHandle>& input_tensors_as_shapes, |
| 178 | std::vector<std::unique_ptr<std::vector<ShapeAndType>>> |
| 179 | input_handle_shapes_and_types); |
| 180 | |
| 181 | // <input_tensors> is NULL-padded to be the same size as <input_shapes>. |
| 182 | // |
| 183 | // Elements of <input_tensors_as_shapes> are used for when a shape |
| 184 | // function makes a call to MakeShapeFromShapeTensor; in particular, when |
| 185 | // the input_tensors[i] is nullptr but the shape represented by it is |
| 186 | // partially known from analysis of the graph. <input_tensors_as_shapes> |
| 187 | // can have fewer elements than <input_shapes>. Values of |
| 188 | // <input_tensors_as_shapes> do not need to outlive the context. |
| 189 | // |
| 190 | // REQUIRES: <node_def> is not NULL, and must outlive the |
| 191 | // InferenceContext. |
| 192 | InferenceContext( |
| 193 | int graph_def_version, const NodeDef* node_def, const OpDef& op_def, |
| 194 | const std::vector<TensorShapeProto>& input_shapes, |
| 195 | const std::vector<const Tensor*>& input_tensors, |
| 196 | const std::vector<TensorShapeProto>& input_tensors_as_shapes, |
| 197 | const std::vector< |
| 198 | std::unique_ptr<std::vector<std::pair<TensorShapeProto, DataType>>>>& |
| 199 | input_handle_shapes_and_types); |
| 200 | |
| 201 | // <input_tensors> is NULL-padded to be the same size as <input_shapes>. |
| 202 | // |
| 203 | // Elements of <input_tensors_as_shapes> are used for when a shape |
| 204 | // function makes a call to MakeShapeFromShapeTensor; in particular, when |
| 205 | // the input_tensors[i] is nullptr but the shape represented by it is |
| 206 | // partially known from analysis of the graph. <input_tensors_as_shapes> |
| 207 | // can have fewer elements than <input_shapes>. Values of |
| 208 | // <input_tensors_as_shapes> do not need to outlive the context. |
| 209 | // |
| 210 | // REQUIRES: <node_def> is not NULL, and must outlive the |
| 211 | // InferenceContext. |
| 212 | InferenceContext( |
| 213 | int graph_def_version, const NodeDef* node_def, const OpDef& op_def, |
| 214 | const std::vector<PartialTensorShape>& input_shapes, |
| 215 | const std::vector<const Tensor*>& input_tensors, |
| 216 | const std::vector<PartialTensorShape>& input_tensors_as_shapes, |
| 217 | const std::vector<std::unique_ptr< |
| 218 | std::vector<std::pair<PartialTensorShape, DataType>>>>& |
| 219 | input_handle_shapes_and_types); |
| 220 | |
| 221 | ~InferenceContext(); |
| 222 | |
| 223 | // Runs the shape inference function 'fn' with 'this' as the |
| 224 | // argument, returns the status of the inference. |
| 225 | // |
| 226 | // On error, additional context is provided in the error message. |
| 227 | Status Run( |
| 228 | const std::function<Status(shape_inference::InferenceContext* c)>& fn); |
| 229 | |
| 230 | // Merge the stored shape of the input in position idx with <shape> according |
| 231 | // to the following rules: |
| 232 | // |
| 233 | // - If the ShapeHandles are the same or <shape> is unknown, there will be no |
| 234 | // change. Otherwise if the stored shape is unknown, the new shape will be |
| 235 | // <shape>. |
| 236 | // - If both shapes are known, then they must have the same rank. |
| 237 | // - For any one dimension, if the values for that dimension in both shapes |
| 238 | // are known, then the values must match. |
| 239 | // - If one shape has equal or more information than the other shape in every |
| 240 | // dimension, the new shape will become the shape with more information. |
| 241 | // - Example: merging [2,?] and [?,2] results in [2,2] |
| 242 | // - Example: [2,2] cannot be merged with [1,2] |
| 243 | // |
| 244 | // This requires idx to be in the [0, num_inputs) range. If the merge is |
| 245 | // successful, return true. Return false otherwise. |
| 246 | bool MergeInput(int idx, ShapeHandle shape) { |
| 247 | ShapeHandle new_shape; |
| 248 | if (!Merge(inputs_[idx], shape, &new_shape).ok()) return false; |
| 249 | inputs_[idx] = new_shape; |
| 250 | return true; |
| 251 | } |
| 252 | |
| 253 | // Relax the stored shape of the input in position idx with <shape> according |
| 254 | // to the following rules: |
| 255 | // |
| 256 | // - If the ShapeHandles are the same then the stored shape will be returned. |
| 257 | // - If either of the ShapeHandles are unknown, then a new UnknownShape will |
| 258 | // be returned. A new shape must be returned because we cannot claim that |
| 259 | // the resulting shape is necessarily the same as either of the input |
| 260 | // shapes. |
| 261 | // - If the shapes both have known ranks but their ranks are different, a new |
| 262 | // UnknownShape will be returned. |
| 263 | // - For any one dimension, if the value for that dimension in either of the |
| 264 | // shapes is unknown, a new shape will be returned with a new UnknownDim in |
| 265 | // that dimension. |
| 266 | // - For any one dimension, if the values for that dimension in both shapes |
| 267 | // are known but do not match, a new shape will be returned with a new |
| 268 | // UnknownDim in that dimension. |
| 269 | // - If both shapes have the same known rank and match in every dimension, |
| 270 | // the stored shape will be returned. |
| 271 | // - Example: relaxing [2,?] and [?,2] results in [?,?] |
| 272 | // - Example: relaxing [2,2] and [3,2] results in [?,2] |
| 273 | // - Example: relaxing [2,2] with [1,2,3] results in ? |
| 274 | // |
| 275 | // This requires idx to be in the [0, num_inputs) range. If the relax is |
| 276 | // successful and the new shape differs from the old one, store the new |
| 277 | // shape and return true. Return false otherwise. |
| 278 | bool RelaxInput(int idx, ShapeHandle shape) { |
| 279 | ShapeHandle new_shape; |
| 280 | Relax(inputs_[idx], shape, &new_shape); |
| 281 | if (inputs_[idx].SameHandle(new_shape)) { |
| 282 | return false; |
| 283 | } |
| 284 | inputs_[idx] = new_shape; |
| 285 | return true; |
| 286 | } |
| 287 | |
| 288 | ShapeHandle input(int64 idx) const { return inputs_[idx]; } |
| 289 | Status input(StringPiece input_name, std::vector<ShapeHandle>* output) const; |
| 290 | int num_inputs() const { return inputs_.size(); } |
| 291 | |
| 292 | // Returns the input tensor at index <idx>, or nullptr if the input tensor is |
| 293 | // not available at the time of shape inference. |
| 294 | const Tensor* input_tensor(int idx) { |
| 295 | // Mark that this idx was requested. |
| 296 | requested_input_tensor_[idx] = true; |
| 297 | return input_tensors_[idx]; |
| 298 | } |
| 299 | |
| 300 | // Returns true iff input_tensor(idx) was called by the shape function. |
| 301 | bool requested_input_tensor(int idx) const { |
| 302 | return requested_input_tensor_[idx]; |
| 303 | } |
| 304 | |
| 305 | // Returns true if MakeShapeFromInputTensor was called but the constant |
| 306 | // input_tensor was not present. |
| 307 | bool requested_input_tensor_as_partial_shape(int idx) const { |
| 308 | return requested_input_tensor_as_partial_shape_[idx]; |
| 309 | } |
| 310 | |
| 311 | void set_input_tensors(const std::vector<const Tensor*>& input_tensors) { |
| 312 | input_tensors_ = input_tensors; |
| 313 | } |
| 314 | |
| 315 | void set_input_tensors_as_shapes( |
| 316 | const std::vector<ShapeHandle>& input_tensors_as_shapes) { |
| 317 | input_tensors_as_shapes_ = input_tensors_as_shapes; |
| 318 | } |
| 319 | |
| 320 | ShapeHandle output(int64 idx) const { return outputs_[idx]; } |
| 321 | void set_output(int idx, ShapeHandle shape) { outputs_[idx] = shape; } |
| 322 | Status set_output(StringPiece output_name, |
| 323 | const std::vector<ShapeHandle>& shapes); |
| 324 | |
| 325 | int num_outputs() const { return outputs_.size(); } |
| 326 | ShapeHandle output(int idx) const { return outputs_[idx]; } |
| 327 | Status output(StringPiece output_name, |
| 328 | std::vector<ShapeHandle>* output) const; |
| 329 | |
| 330 | AttrSlice attrs() const { return AttrSlice(*node_def_); } |
| 331 | |
| 332 | string op() const; |
| 333 | |
| 334 | // idx can be negative for an offset from end of dimensions. |
| 335 | // idx must be in the range [-1 * s.rank, s.rank). |
| 336 | DimensionHandle Dim(ShapeHandle s, int64 idx) { |
| 337 | if (s->rank_ == kUnknownRank) { |
| 338 | return UnknownDim(); |
| 339 | } |
| 340 | return DimKnownRank(s, idx); |
| 341 | } |
| 342 | // As above, but asserts that the rank of the shape is known. |
| 343 | static DimensionHandle DimKnownRank(ShapeHandle s, int64 idx) { |
| 344 | CHECK_NE(s->rank_, kUnknownRank); |
| 345 | if (idx < 0) { |
| 346 | return s->dims_[s->dims_.size() + idx]; |
| 347 | } |
| 348 | return s->dims_[idx]; |
| 349 | } |
| 350 | |
| 351 | static int32 Rank(ShapeHandle s) { |
| 352 | DCHECK(s.IsSet()); |
| 353 | return s.IsSet() ? s->rank_ : kUnknownRank; |
| 354 | } |
| 355 | static bool RankKnown(ShapeHandle s) { |
| 356 | return (s.IsSet() && (Rank(s) != kUnknownRank)); |
| 357 | } |
| 358 | static inline int64 Value(DimensionOrConstant d) { |
| 359 | return d.dim.IsSet() ? d.dim->value_ : d.val; |
| 360 | } |
| 361 | static inline bool ValueKnown(DimensionOrConstant d) { |
| 362 | return Value(d) != kUnknownDim; |
| 363 | } |
| 364 | |
| 365 | // Fills the output proto with the shape defined by the handle. |
| 366 | // "proto" is expected to be empty prior to the call. |
| 367 | void ShapeHandleToProto(ShapeHandle handle, TensorShapeProto* proto); |
| 368 | |
| 369 | // Returns true if the rank and all dimensions of the Shape are known. |
| 370 | bool FullyDefined(ShapeHandle s); |
| 371 | |
| 372 | // Returns the total number of elements, or an unknown dimension for an |
| 373 | // incomplete shape. |
| 374 | DimensionHandle NumElements(ShapeHandle s); |
| 375 | |
| 376 | string DebugString(ShapeHandle s); |
| 377 | string DebugString(DimensionHandle d); |
| 378 | |
| 379 | // Describes the whole context, for debugging purposes. |
| 380 | string DebugString() const; |
| 381 | |
| 382 | // If <shape> has rank <rank>, or its rank is unknown, return OK and return |
| 383 | // the shape with asserted rank in <*out>. Otherwise return an error. |
| 384 | // |
| 385 | // Note that <*out> may be set to <shape>. |
| 386 | Status WithRank(ShapeHandle shape, int64 rank, |
| 387 | ShapeHandle* out) TF_MUST_USE_RESULT; |
| 388 | Status WithRankAtLeast(ShapeHandle shape, int64 rank, |
| 389 | ShapeHandle* out) TF_MUST_USE_RESULT; |
| 390 | Status WithRankAtMost(ShapeHandle shape, int64 rank, |
| 391 | ShapeHandle* out) TF_MUST_USE_RESULT; |
| 392 | |
| 393 | // If <dim> has value <value>, or its value is unknown, returns OK and returns |
| 394 | // the dimension with asserted value in <*out>. Otherwise returns an error. |
| 395 | // |
| 396 | // Note that <*out> may be set to <dim>. |
| 397 | Status WithValue(DimensionHandle dim, int64 value, |
| 398 | DimensionHandle* out) TF_MUST_USE_RESULT; |
| 399 | |
| 400 | // Merges <s0> and <s1> and returns the merged shape in <*out>. See |
| 401 | // 'MergeInput' function for full details and examples. |
| 402 | Status Merge(ShapeHandle s0, ShapeHandle s1, |
| 403 | ShapeHandle* out) TF_MUST_USE_RESULT; |
| 404 | |
| 405 | // Asserts that <s>'s rank >= <prefix>'s rank, and the first |
| 406 | // <prefix.rank> dimensions of <s> are compatible with the dimensions of |
| 407 | // <prefix>. |
| 408 | // Returns the merged results in <*s_out> and <*prefix_out>. |
| 409 | Status MergePrefix(ShapeHandle s, ShapeHandle prefix, ShapeHandle* s_out, |
| 410 | ShapeHandle* prefix_out) TF_MUST_USE_RESULT; |
| 411 | |
| 412 | // Merges <d0> and <d1> and returns the merged dimension in <*out>. If <d0> |
| 413 | // and <d1> have incompatible values, returns an error. |
| 414 | // |
| 415 | // Note that <*out> may be set to <d0> or <d1>. |
| 416 | Status Merge(DimensionHandle d0, DimensionHandle d1, |
| 417 | DimensionHandle* out) TF_MUST_USE_RESULT; |
| 418 | |
| 419 | // Returns in <*out> a sub-shape of <s> with dimensions [start:]. |
| 420 | // <start> can be negative to index from the end of the shape. If <start> > |
| 421 | // rank of <s>, then an empty subshape is returned. |
| 422 | Status Subshape(ShapeHandle s, int64 start, |
| 423 | ShapeHandle* out) TF_MUST_USE_RESULT; |
| 424 | |
| 425 | // Returns in <*out> a sub-shape of <s>, with dimensions [start:end]. |
| 426 | // <start> and <end> can be negative, to index from the end of the shape. |
| 427 | // <start> and <end> are set to the rank of <s> if > rank of <s>. |
| 428 | Status Subshape(ShapeHandle s, int64 start, int64 end, |
| 429 | ShapeHandle* out) TF_MUST_USE_RESULT; |
| 430 | |
| 431 | // Returns in <*out> the result of appending the dimensions of <s2> to those |
| 432 | // of <s1>. |
| 433 | Status Concatenate(ShapeHandle s1, ShapeHandle s2, |
| 434 | ShapeHandle* out) TF_MUST_USE_RESULT; |
| 435 | |
| 436 | // Returns in <out> the shape from replacing <s.dim[dim_index]> with |
| 437 | // <new_dim>. |
| 438 | Status ReplaceDim(ShapeHandle s, int64 dim_index, DimensionHandle new_dim, |
| 439 | ShapeHandle* out) TF_MUST_USE_RESULT; |
| 440 | |
| 441 | // Returns a new shape with the given dims. The returned value is owned by |
| 442 | // this context. |
| 443 | ShapeHandle MakeShape(const std::vector<DimensionHandle>& dims); |
| 444 | ShapeHandle MakeShape(std::initializer_list<DimensionOrConstant> dims); |
| 445 | |
| 446 | // Returns a new unknown shape. |
| 447 | ShapeHandle UnknownShape(); |
| 448 | |
| 449 | // Returns a shape with specified rank but unknown dims. |
| 450 | ShapeHandle UnknownShapeOfRank(int64 rank); |
| 451 | |
| 452 | // Returns a new shape of zero dimensions. |
| 453 | ShapeHandle Scalar(); |
| 454 | |
| 455 | // Returns a new shape of one dimension. |
| 456 | ShapeHandle Vector(DimensionOrConstant dim); |
| 457 | |
| 458 | // Returns a new shape of two dimensions. |
| 459 | ShapeHandle Matrix(DimensionOrConstant dim1, DimensionOrConstant dim2); |
| 460 | |
| 461 | // Returns in <out> a new shape whose dimension sizes come from input tensor |
| 462 | // <input_idx>. The tensor must be a 1-dimensional int32 or int64 tensor. If |
| 463 | // the input tensor is NULL, then an unknown shape is returned. |
| 464 | Status MakeShapeFromShapeTensor(int input_idx, ShapeHandle* out); |
| 465 | |
| 466 | // Like the function above, but treats scalar values as unknown |
| 467 | // shapes. **NOTE** If the scalar is statically known, its value |
| 468 | // must be -1 or an error is returned. |
| 469 | Status MakeShapeFromShapeTensorTreatScalarAsUnknownShape(int input_idx, |
| 470 | ShapeHandle* out); |
| 471 | |
| 472 | // Returns in <out> a new shape corresponding to <proto>. |
| 473 | Status MakeShapeFromShapeProto(const TensorShapeProto& proto, |
| 474 | ShapeHandle* out); |
| 475 | |
| 476 | // Returns in <out> a new shape corresponding to <partial_shape>. |
| 477 | Status MakeShapeFromPartialTensorShape( |
| 478 | const PartialTensorShape& partial_shape, ShapeHandle* out); |
| 479 | |
| 480 | // Returns in <out> a new shape corresponding to <shape>. |
| 481 | Status MakeShapeFromTensorShape(const TensorShape& shape, ShapeHandle* out); |
| 482 | |
| 483 | // Returns a new dimension of the given size. The returned value is owned by |
| 484 | // this context. |
| 485 | inline DimensionHandle MakeDim(DimensionOrConstant d) { |
| 486 | return shape_manager_.MakeDim(d); |
| 487 | } |
| 488 | |
| 489 | inline DimensionHandle UnknownDim() { return MakeDim(kUnknownDim); } |
| 490 | |
| 491 | // Returns in <val> a scalar value from an input tensor <t>. The input tensor |
| 492 | // must be a 0-dimensional (scalar) int32 or int64 tensor. Caller must ensure |
| 493 | // that the input tensor is not NULL. |
| 494 | Status GetScalarFromTensor(const Tensor* t, int64* val); |
| 495 | |
| 496 | // Returns a new dimension whose value is given by a scalar input tensor. |
| 497 | // The input tensor must be in host memory, since it is dereferenced to get |
| 498 | // the value. |
| 499 | Status MakeDimForScalarInput(int idx, DimensionHandle* out); |
| 500 | |
| 501 | // Returns a new dimension whose value is given by a scalar input tensor. |
| 502 | // This allows for a negative input dimension given the rank of a separate |
| 503 | // tensor. This rank can be negative if unknown. |
| 504 | // The input tensor must be in host memory, since it is dereferenced to get |
| 505 | // the value. |
| 506 | Status MakeDimForScalarInputWithNegativeIndexing(int idx, int input_rank, |
| 507 | DimensionHandle* out); |
| 508 | |
| 509 | // Look up the attr for the NodeDef being evaluated with name attr_name and |
| 510 | // set *value to its value. If no attr with attr_name is found in def(), or |
| 511 | // the attr does not have a matching type, a non-ok status will be returned. |
| 512 | template <class T> |
| 513 | Status GetAttr(StringPiece attr_name, T* value) const; |
| 514 | |
| 515 | // Returns in <out> the result of dividing <dividend> by <divisor>. |
| 516 | // Returns an error if <divisor> is not positive or if <evenly_divisible> |
| 517 | // and <divisor> does not evenly divide <dividend>. |
| 518 | Status Divide(DimensionHandle dividend, DimensionOrConstant divisor, |
| 519 | bool evenly_divisible, DimensionHandle* out); |
| 520 | |
| 521 | // Returns in <out> the sum of <first> and <second>. |
| 522 | Status Add(DimensionHandle first, DimensionOrConstant second, |
| 523 | DimensionHandle* out); |
| 524 | |
| 525 | // Returns in <out> the dimension that is <first> minus <second>. |
| 526 | Status Subtract(DimensionHandle first, DimensionOrConstant second, |
| 527 | DimensionHandle* out); |
| 528 | |
| 529 | // Returns in <out> the product of <first> and <second>. |
| 530 | Status Multiply(DimensionHandle first, DimensionOrConstant second, |
| 531 | DimensionHandle* out); |
| 532 | |
| 533 | // Returns in <out> the minimum of <first> and <second>. If either <first> or |
| 534 | // <second> is zero the result is zero. Otherwise, if either <first> or |
| 535 | // <second> is unknown the result is unknown. |
| 536 | Status Min(DimensionHandle first, DimensionOrConstant second, |
| 537 | DimensionHandle* out); |
| 538 | |
| 539 | // Returns in <out> the maximum of <first> and <second>. If either <first> or |
| 540 | // <second> is unknown the result is unknown. |
| 541 | Status Max(DimensionHandle first, DimensionOrConstant second, |
| 542 | DimensionHandle* out); |
| 543 | |
| 544 | Status construction_status() const { return construction_status_; } |
| 545 | |
| 546 | // Methods to propagate shape and dtype on edges of handles. Handles are tensors |
| 547 | // of dtype DT_RESOURCE, which can be used to access state stored in a |
| 548 | // ResourceManager. When ops (such as variables) consume these handles to |
| 549 | // produce tensors they might need to know side-information about the shapes |
| 550 | // and dtypes of tensors which can be accessed via the handle. These methods |
| 551 | // propagate that information. Output handle dtypes and shapes are ignored if |
| 552 | // the output tensor is not of type DT_RESOURCE. |
| 553 | |
| 554 | // Merge the stored shapes and types corresponding to the input handle in |
| 555 | // position idx with the specified shapes and types. This requires idx to be |
| 556 | // in the [0, num_inputs) range. |
| 557 | // |
| 558 | // If the merge is successful and any of the new shapes differs from the old |
| 559 | // one, or any of the old dtypes was DT_INVALID, store the new shapes and |
| 560 | // return true. Return false otherwise. |
| 561 | // |
| 562 | // See 'MergeInput' function for full details and examples. |
| 563 | bool MergeInputHandleShapesAndTypes( |
| 564 | int idx, |
| 565 | const std::vector<ShapeAndType>& shapes_and_types) TF_MUST_USE_RESULT; |
| 566 | |
| 567 | // As MergeInputHandleShapesAndTypes, but for an output. |
| 568 | bool MergeOutputHandleShapesAndTypes( |
| 569 | int idx, |
| 570 | const std::vector<ShapeAndType>& shapes_and_types) TF_MUST_USE_RESULT; |
| 571 | |
| 572 | // Relaxes the stored shapes and types corresponding to the input handle in |
| 573 | // position idx with the specified shapes and types. This requires idx to be |
| 574 | // in the [0, num_inputs) range. |
| 575 | // |
| 576 | // If the relax is successful and any of the new shapes differs from the old |
| 577 | // one, or any of the old dtypes was DT_INVALID, store the new shapes and |
| 578 | // return true. Return false otherwise. |
| 579 | // |
| 580 | // See 'RelaxInput' function for full details and examples. |
| 581 | bool RelaxInputHandleShapesAndMergeTypes( |
| 582 | int idx, |
| 583 | const std::vector<ShapeAndType>& shapes_and_types) TF_MUST_USE_RESULT; |
| 584 | |
| 585 | // As RelaxInputHandleShapesAndMergeTypes, but for an output. |
| 586 | bool RelaxOutputHandleShapesAndMergeTypes( |
| 587 | int idx, |
| 588 | const std::vector<ShapeAndType>& shapes_and_types) TF_MUST_USE_RESULT; |
| 589 | |
| 590 | // Returns the output handle shapes and types, for the resource tensor output |
| 591 | // at index <idx>. Returns NULL if the shape and types were never set. |
| 592 | const std::vector<ShapeAndType>* output_handle_shapes_and_types(int idx) { |
| 593 | return output_handle_shapes_and_types_[idx].get(); |
| 594 | } |
| 595 | |
| 596 | // Returns the inputs handle shapes and types, for the resource tensor output |
| 597 | // at index <idx>. Returns NULL if the shape and types were not available. |
| 598 | const std::vector<ShapeAndType>* input_handle_shapes_and_types(int idx) { |
| 599 | return input_handle_shapes_and_types_[idx].get(); |
| 600 | } |
| 601 | |
| 602 | void set_output_handle_shapes_and_types( |
| 603 | int idx, const std::vector<ShapeAndType>& shapes_and_types) { |
| 604 | output_handle_shapes_and_types_[idx].reset( |
| 605 | new std::vector<ShapeAndType>(shapes_and_types)); |
| 606 | } |
| 607 | |
| 608 | // Note that shape functions should usually call MakeShapeFromShapeTensor, |
| 609 | // as it does more analysis to provide partial shapes. |
| 610 | // |
| 611 | // Returns in <out> a new shape whose dimension sizes come from tensor <t>. |
| 612 | // The tensor must be a 1-dimensional int32 or int64 tensor. If <t> is NULL, |
| 613 | // then an unknown shape is returned. |
| 614 | Status MakeShapeFromTensor(const Tensor* t, ShapeHandle tensor_shape, |
| 615 | ShapeHandle* out); |
| 616 | |
| 617 | int graph_def_version() const { return graph_def_version_; } |
| 618 | |
| 619 | const std::vector<std::pair<ShapeHandle, ShapeHandle>>& MergedShapes() const { |
| 620 | return merged_shapes_; |
| 621 | } |
| 622 | const std::vector<std::pair<DimensionHandle, DimensionHandle>>& MergedDims() |
| 623 | const { |
| 624 | return merged_dims_; |
| 625 | } |
| 626 | |
| 627 | private: |
| 628 | // Creates and stores shapes for use in InferenceContext. |
| 629 | class ShapeManager { |
| 630 | public: |
| 631 | ShapeManager(); |
| 632 | ~ShapeManager(); |
| 633 | |
| 634 | // Returns a new shape with the given dims. The returned value is owned by |
| 635 | // this class. |
| 636 | ShapeHandle MakeShape(const std::vector<DimensionHandle>& dims); |
| 637 | |
| 638 | // Returns a new unknown shape. |
| 639 | ShapeHandle UnknownShape(); |
| 640 | |
| 641 | // Returns a new dimension of the given size. The returned value |
| 642 | // is owned by this class. |
| 643 | inline DimensionHandle MakeDim(DimensionOrConstant d) { |
| 644 | if (d.dim.IsSet()) { |
| 645 | return d.dim; |
| 646 | } else { |
| 647 | all_dims_.push_back(new Dimension(d.val)); |
| 648 | return all_dims_.back(); |
| 649 | } |
| 650 | } |
| 651 | |
| 652 | private: |
| 653 | std::vector<Shape*> all_shapes_; // values are owned. |
| 654 | std::vector<Dimension*> all_dims_; // values are owned. |
| 655 | }; |
| 656 | |
| 657 | friend class ::tensorflow::grappler::GraphProperties; |
| 658 | |
| 659 | // Friend for user-defined function shape inference purposes. |
| 660 | friend class ::tensorflow::ShapeRefiner; |
| 661 | |
| 662 | friend class ShapeInferenceTest; // For testing Relax functions. |
| 663 | friend class ShapeInferenceTestutil; // For testing shapes. |
| 664 | |
| 665 | // Shared initialization across the constructors. Remove |
| 666 | // once only one of them remains. |
| 667 | void PreInputInit(const OpDef& op_def, |
| 668 | const std::vector<const Tensor*>& input_tensors, |
| 669 | const std::vector<ShapeHandle>& input_tensors_as_shapes); |
| 670 | void PostInputInit(std::vector<std::unique_ptr<std::vector<ShapeAndType>>> |
| 671 | input_handle_data); |
| 672 | |
| 673 | DimensionHandle GetDimension(const DimensionOrConstant& d); |
| 674 | |
| 675 | Status ReturnUnknownShape(ShapeHandle* out) { |
| 676 | *out = UnknownShape(); |
| 677 | return Status::OK(); |
| 678 | } |
| 679 | Status ReturnCreatedShape(const std::vector<DimensionHandle>& dims, |
| 680 | ShapeHandle* out) { |
| 681 | *out = MakeShape(dims); |
| 682 | return Status::OK(); |
| 683 | } |
| 684 | |
| 685 | // Adds additional context to the given status. |
| 686 | Status AttachContext(const Status& status); |
| 687 | |
// Relaxes an existing value <d_old> with a new value <d_new> and returns the
// relaxed dimension in <*out>. This function returns void, so it does not
// report an error when <d_old> and <d_new> have incompatible values.
//
// Note that <*out> may be set to <d_old> or <d_new>.
| 693 | void Relax(DimensionHandle d_old, DimensionHandle d_new, |
| 694 | DimensionHandle* out); |
| 695 | // Relaxes an existing shape <s_old> with a new shape <s_new> and returns the |
| 696 | // relaxed shape in <*out>. See 'RelaxInput' function for full details and |
| 697 | // examples. |
| 698 | void Relax(ShapeHandle s_old, ShapeHandle s_new, ShapeHandle* out); |
| 699 | |
| 700 | // Used to implement MergeInputHandleShapesAndTypes and |
| 701 | // MergeOutputHandleShapesAndTypes. |
| 702 | bool MergeHandleShapesAndTypes( |
| 703 | const std::vector<ShapeAndType>& shapes_and_types, |
| 704 | std::vector<ShapeAndType>* to_update) TF_MUST_USE_RESULT; |
| 705 | // Used to implement RelaxInputHandleShapesAndMergeTypes and |
| 706 | // RelaxOutputHandleShapesAndMergeTypes. |
| 707 | bool RelaxHandleShapesAndMergeTypes( |
| 708 | const std::vector<ShapeAndType>& shapes_and_types, |
| 709 | std::vector<ShapeAndType>* to_update) TF_MUST_USE_RESULT; |
| 710 | |
| 711 | // Forget all the previous merged shapes and dims. |
| 712 | void ForgetMerges() { |
| 713 | merged_shapes_.clear(); |
| 714 | merged_dims_.clear(); |
| 715 | } |
| 716 | |
| 717 | // Helper method for MakeShapeFromTensor and MakeShapeFromShapeTensor. |
| 718 | Status InternalMakeShapeFromTensor( |
| 719 | bool treat_unknown_scalar_tensor_as_unknown_shape, const Tensor* t, |
| 720 | ShapeHandle tensor_shape, ShapeHandle* out); |
| 721 | |
| 722 | ShapeManager shape_manager_; |
| 723 | |
| 724 | // inputs_, outputs_, and input_tensors_as_shapes_ refer to values from |
| 725 | // `shape_manager_`. |
| 726 | std::vector<ShapeHandle> inputs_; |
| 727 | std::vector<const Tensor*> input_tensors_; |
| 728 | std::vector<bool> requested_input_tensor_; |
| 729 | std::vector<ShapeHandle> outputs_; |
| 730 | // Can have fewer elements than inputs_. |
| 731 | std::vector<ShapeHandle> input_tensors_as_shapes_; |
| 732 | std::vector<bool> requested_input_tensor_as_partial_shape_; |
| 733 | |
| 734 | // input_handle_shapes_and_types_[i] is the list of shape/type pairs available |
| 735 | // through the resource handle passed along input i of the node. |
| 736 | // |
| 737 | // Values may be NULL. |
| 738 | std::vector<std::unique_ptr<std::vector<ShapeAndType>>> |
| 739 | input_handle_shapes_and_types_; |
| 740 | |
| 741 | // output_handle_shapes_and_types_[i] is the list of shape/type pairs |
| 742 | // available through the resource handle passed along output i of the node. |
| 743 | // |
| 744 | // Values may be NULL. |
| 745 | std::vector<std::unique_ptr<std::vector<ShapeAndType>>> |
| 746 | output_handle_shapes_and_types_; |
| 747 | |
| 748 | const int graph_def_version_; |
| 749 | const NodeDef* node_def_; |
| 750 | NameRangeMap input_name_map_; |
| 751 | NameRangeMap output_name_map_; |
| 752 | |
| 753 | // An error set during construction. TODO(cwhipkey): remove when test |
| 754 | // constructor is removed. |
| 755 | Status construction_status_; |
| 756 | |
// Pairs of shape or dim handles that are equivalent, i.e., that represent the
// same underlying shape or dimension. Note that for each pair at least one of
// the handles must contain an unknown shape, since we don't keep track of
// known shapes or dims here.
| 761 | std::vector<std::pair<ShapeHandle, ShapeHandle>> merged_shapes_; |
| 762 | std::vector<std::pair<DimensionHandle, DimensionHandle>> merged_dims_; |
| 763 | |
| 764 | TF_DISALLOW_COPY_AND_ASSIGN(InferenceContext); |
| 765 | }; |
| 766 | |
| 767 | // ----------------------------------------------------------------------------- |
| 768 | // Template and inline method implementations, please ignore |
| 769 | |
| 770 | inline Dimension::Dimension() : value_(InferenceContext::kUnknownDim) {} |
| 771 | inline Dimension::Dimension(int64 value) : value_(value) { |
| 772 | DCHECK(value >= 0 || value == InferenceContext::kUnknownDim) |
| 773 | << "Dimension must be non-negative or equal to " |
| 774 | "InferenceContext::kUnknownDim but got " |
| 775 | << value; |
| 776 | } |
| 777 | |
| 778 | inline Shape::Shape() : rank_(InferenceContext::kUnknownRank) {} |
| 779 | inline Shape::Shape(const std::vector<DimensionHandle>& dims) |
| 780 | : rank_(dims.size()), dims_(dims) {} |
| 781 | |
| 782 | inline DimensionOrConstant::DimensionOrConstant(DimensionHandle dim) |
| 783 | : dim(dim) { |
| 784 | DCHECK(dim.IsSet()) << "Internal error: Got nullptr for Dimension." ; |
| 785 | } |
| 786 | |
| 787 | inline DimensionOrConstant::DimensionOrConstant(int64 val) : val(val) { |
| 788 | DCHECK(val >= 0 || val == InferenceContext::kUnknownDim) |
| 789 | << "Dimension must be non-negative or equal to " |
| 790 | "InferenceContext::kUnknownDim but got " |
| 791 | << val; |
| 792 | } |
| 793 | |
| 794 | template <class T> |
| 795 | Status InferenceContext::GetAttr(StringPiece attr_name, T* value) const { |
| 796 | return GetNodeAttr(*node_def_, attr_name, value); |
| 797 | } |
| 798 | |
| 799 | } // namespace shape_inference |
| 800 | } // namespace tensorflow |
| 801 | |
| 802 | #endif // TENSORFLOW_CORE_FRAMEWORK_SHAPE_INFERENCE_H_ |
| 803 | |