18 #ifndef TENSORFLOW_COMPILER_TF2XLA_XLA_RESOURCE_H_ 19 #define TENSORFLOW_COMPILER_TF2XLA_XLA_RESOURCE_H_ 24 #include "tensorflow/compiler/xla/xla_data.pb.h" 25 #include "tensorflow/core/framework/tensor_shape.h" 26 #include "tensorflow/core/framework/types.pb.h" 27 #include "tensorflow/core/lib/core/status.h" 44 XlaResource(Kind kind,
int arg_num,
string name, DataType type,
46 const xla::ComputationDataHandle& initial_value,
47 int64 tensor_array_size,
48 const std::set<string>& tensor_array_gradients);
50 XlaResource(
const XlaResource&) =
delete;
51 XlaResource(XlaResource&&) =
delete;
52 XlaResource& operator=(
const XlaResource&) =
delete;
53 XlaResource& operator=(XlaResource&&) =
delete;
55 Kind kind()
const {
return kind_; }
60 int arg_num()
const {
return arg_num_; }
63 const string& name()
const {
return name_; }
70 DataType type()
const {
return type_; }
75 const TensorShape& shape()
const {
return shape_; }
77 const xla::ComputationDataHandle& value()
const {
return value_; }
// Returns the resource's initial value handle (presumably the
// `initial_value` passed to the constructor — confirm).
// NOTE(review): the closing `}` of this definition is missing from this
// extract.
81 const xla::ComputationDataHandle& initial_value()
const {
82 return initial_value_;
86 bool initialized()
const {
return value_.handle() > 0; }
90 Status SetTypeAndShape(DataType type,
const TensorShape& shape);
94 Status SetValue(
const xla::ComputationDataHandle& value);
103 Status GetOrCreateTensorArrayGradient(
const string& source,
105 XlaResource** gradient_out);
// Packs the resource into `*pack`. Presumably this includes its TensorArray
// gradients, since SetFromPack() takes the gradient sources alongside the
// packed handle — confirm.
// NOTE(review): the remaining parameters and closing `);` of this
// declaration are missing from this extract.
112 Status Pack(xla::ComputationDataHandle* pack,
// Presumably the inverse of Pack(): restores the resource (and the gradients
// named in `gradient_sources`) from the packed handle `pack` — confirm.
// NOTE(review): the remaining parameters and closing `);` of this
// declaration are missing from this extract.
121 Status SetFromPack(
const std::set<string>& gradient_sources,
122 const xla::ComputationDataHandle& pack,
131 int64 tensor_array_size()
const {
return tensor_array_size_; }
132 void set_tensor_array_size(int64 size) { tensor_array_size_ = size; }
// Returns the TensorArray gradient resources, keyed by string. The keys are
// presumably the `source` names passed to GetOrCreateTensorArrayGradient —
// confirm. The map owns the resources via std::unique_ptr.
// NOTE(review): the `const {` and closing `}` lines of this definition are
// missing from this extract.
140 const std::map<string, std::unique_ptr<XlaResource>>& tensor_array_gradients()
142 return tensor_array_gradients_;
152 xla::ComputationDataHandle value_;
153 xla::ComputationDataHandle initial_value_;
155 int64 tensor_array_size_ = -1;
157 std::map<string, std::unique_ptr<XlaResource>> tensor_array_gradients_;
#endif  // TENSORFLOW_COMPILER_TF2XLA_XLA_RESOURCE_H_