tf_1.8_xla_doc
xla_resource.h
Go to the documentation of this file.
1 
3 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
4 
5 Licensed under the Apache License, Version 2.0 (the "License");
6 you may not use this file except in compliance with the License.
7 You may obtain a copy of the License at
8 
9  http://www.apache.org/licenses/LICENSE-2.0
10 
11 Unless required by applicable law or agreed to in writing, software
12 distributed under the License is distributed on an "AS IS" BASIS,
13 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 See the License for the specific language governing permissions and
15 limitations under the License.
16 ==============================================================================*/
17 
18 #ifndef TENSORFLOW_COMPILER_TF2XLA_XLA_RESOURCE_H_
19 #define TENSORFLOW_COMPILER_TF2XLA_XLA_RESOURCE_H_
20 
#include <map>
#include <memory>
#include <set>

#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/core/status.h"
28 
32 namespace tensorflow {
33 
34 // Represents a resource, such as a Variable or TensorArray.
35 class XlaResource {
36  public:
37  enum Kind {
38  kInvalid,
39  kVariable,
40  kTensorArray,
41  kStack,
42  };
43 
44  XlaResource(Kind kind, int arg_num, string name, DataType type,
45  TensorShape shape,
46  const xla::ComputationDataHandle& initial_value,
47  int64 tensor_array_size,
48  const std::set<string>& tensor_array_gradients);
49 
50  XlaResource(const XlaResource&) = delete;
51  XlaResource(XlaResource&&) = delete;
52  XlaResource& operator=(const XlaResource&) = delete;
53  XlaResource& operator=(XlaResource&&) = delete;
54 
55  Kind kind() const { return kind_; }
56 
57  // If this resource is visible externally to the computation, what was its
58  // argument number?
59  // < 0 means "not visible externally".
60  int arg_num() const { return arg_num_; }
61 
62  // A descriptive name for the resource, used in error messages.
63  const string& name() const { return name_; }
64 
65  // Current type and value of the resource. Uninitialized resources are
66  // represented by a default (zero) handle and type DT_INVALID.
67  // While the type of a resource is notionally fixed during execution, when
68  // a resource is first initialized we do not yet know its type, so we keep
69  // track of its type dynamically.
70  DataType type() const { return type_; }
71 
72  // Shape of the resource. For an uninitialized resource, this is ignored.
73  // For a Variable, this is the shape of the value. For a TensorArray or Stack
74  // this is the shape of each entry in the TensorArray/Stack.
75  const TensorShape& shape() const { return shape_; }
76 
77  const xla::ComputationDataHandle& value() const { return value_; }
78 
79  // Value of the resource at computation entry. Used to detect which
80  // variables have new values that need to be written back.
81  const xla::ComputationDataHandle& initial_value() const {
82  return initial_value_;
83  }
84 
85  // A variable is initialized if it has a value.
86  bool initialized() const { return value_.handle() > 0; }
87 
88  // Sets the type and shape of the resource. The type and shape of a resource
89  // must not change once the variable has been initialized.
90  Status SetTypeAndShape(DataType type, const TensorShape& shape);
91 
92  // Sets the current value of the resource. Returns an error if the type is not
93  // set to a valid value.
94  Status SetValue(const xla::ComputationDataHandle& value);
95 
96  // Sets the current value of the resource to an all-zero value.
97  Status SetZeroValue(xla::ComputationBuilder* builder);
98 
99  // Looks up the gradient for `source`, or creates it if it does not already
100  // exist. The call target must be an initialized TensorArray resource. A
101  // TensorArray can have multiple named gradients; see the operator
102  // documentation for TensorArrayGradV3 for details.
103  Status GetOrCreateTensorArrayGradient(const string& source,
104  xla::ComputationBuilder* builder,
105  XlaResource** gradient_out);
106 
107  // Packs a resource into a single XLA value `pack`, suitable for use as
108  // an XlaCompiler::Argument. For non-TensorArrays or TensorArrays without
109  // gradients, sets `*pack` to `value`.
110  // For TensorArrays with gradients, packs the value and its gradient values in
111  // a tuple; the gradients values are packed in order by source name.
112  Status Pack(xla::ComputationDataHandle* pack,
113  xla::ComputationBuilder* builder) const;
114 
115  // Updates the resource with values from `pack`. If `gradient_sources` is
116  // non-empty, treats `pack` as a tuple that represents a TensorArray and
117  // its gradients, and unpacks and updates the gradient resources.
118  // If `reset_initial_values` is true, sets the initial_values as well as the
119  // values.
120  // Opposite of Pack().
121  Status SetFromPack(const std::set<string>& gradient_sources,
122  const xla::ComputationDataHandle& pack,
123  xla::ComputationBuilder* builder);
124 
125  // TensorArray and Stack specific fields
126 
127  // 'tensor_array_size' stores the expected size of the TensorArray or Stack.
128  // We need to store this since sometimes TensorArrays must be initialized
129  // lazily since we do not know the element shape at construction time.
130  // Used by both TensorArrays and Stacks.
131  int64 tensor_array_size() const { return tensor_array_size_; }
132  void set_tensor_array_size(int64 size) { tensor_array_size_ = size; }
133 
134  // 'tensor_array_gradient' is a map from TensorArrayGradV3 'source' attributes
135  // to an XlaResource containing the gradient TensorArrays. We store a pointer
136  // here since there should only be one gradient TensorArray per 'source'
137  // string, irrespective of the number of calls to TensorArrayGrad. The map
138  // is ordered since values are packed into tuples by Pack() sorted by name
139  // order.
140  const std::map<string, std::unique_ptr<XlaResource>>& tensor_array_gradients()
141  const {
142  return tensor_array_gradients_;
143  }
144 
145  private:
146  const Kind kind_;
147  const int arg_num_;
148  const string name_;
149 
150  DataType type_;
151  TensorShape shape_;
152  xla::ComputationDataHandle value_;
153  xla::ComputationDataHandle initial_value_;
154 
155  int64 tensor_array_size_ = -1;
156 
157  std::map<string, std::unique_ptr<XlaResource>> tensor_array_gradients_;
158 };
159 
160 } // namespace tensorflow
161 
162 #endif // TENSORFLOW_COMPILER_TF2XLA_XLA_RESOURCE_H_
Definition: computation_builder.h:59
Definition: compile.cc:35