#ifndef TENSORFLOW_COMPILER_TF2XLA_XLA_COMPILER_H_
#define TENSORFLOW_COMPILER_TF2XLA_XLA_COMPILER_H_

#include "tensorflow/compiler/tf2xla/host_compute_metadata.pb.h"
#include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/notification.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/public/version.h"

namespace tensorflow {

// The XlaCompiler class is responsible for compiling a self-contained subgraph
// of a TensorFlow computation into an XLA computation.
class XlaCompiler {
 public:
  // Describes how to derive the value of each _Arg node in the graph or
  // function being compiled.
  struct Argument {
    enum Kind {
      // Default value; not a valid argument kind.
      kInvalid,

      // Argument is a compile-time constant.
      kConstant,

      // Argument is a Variable, TensorArray, or Stack resource.
      kResource,

      // Argument is a run-time parameter.
      kParameter,
    };

    Kind kind = kInvalid;
    // The value of the argument, if it is a compile-time constant. Must be a
    // host-memory tensor.
    Tensor constant_value;

    // For a kResource argument, what kind of resource it is.
    XlaResource::Kind resource_kind = XlaResource::kInvalid;

    // For a kResource argument, has the resource been initialized?
    bool initialized = false;

    // For a TensorArray or Stack resource, how many elements are allocated.
    int64 tensor_array_size = -1;

    // For a TensorArray resource, the set of gradient TensorArrays that are
    // passed alongside the value, in the order of this set.
    std::set<string> tensor_array_gradients;

    bool operator==(const Argument& other) const;
  };
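  // A minimal sketch of filling in an Argument for a compile-time constant
  // input. The tensor value below is illustrative only and not declared in
  // this header:
  //
  //   XlaCompiler::Argument arg;
  //   arg.kind = XlaCompiler::Argument::kConstant;
  //   arg.constant_value = Tensor(DT_INT32, TensorShape({}));  // scalar value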
  // Options pertaining to an individual call to CompileGraph() or
  // CompileFunction().
  struct CompileOptions {
    // If true, the arguments are passed to the computation as a single tuple
    // parameter; if false, each argument gets its own parameter.
    bool use_tuple_arg = false;

    // If true, updated values of all resource arguments are included in
    // 'resource_updates' of the result, even if a resource was not modified
    // by the computation.
    bool return_updated_values_for_all_resources = false;

    // If true, outputs that are known to be compile-time constants are
    // returned as Tensors at compile time, rather than as run-time outputs of
    // the computation.
    bool resolve_compile_time_constants = true;

    // True when compiling the entry computation, false for subcomputations
    // (e.g. while bodies and function calls).
    bool is_entry_computation = true;
  };
  // Describes how to derive the value of each _Retval node in the graph or
  // function being compiled.
  struct OutputDescription {
    // Constant output value, if the output was a compile-time constant.
    bool is_constant = false;
    Tensor constant_value;
  };
  // Describes a resource write side effect of the computation.
  struct ResourceUpdate {
    // If the resource is a TensorArray, the set of gradients read or written.
    std::set<string> tensor_array_gradients_accessed;
  };
  struct CompilationResult {
    // Maps the parameters of the XLA computation to their original argument
    // positions. The XLA parameters may be a subset of the original arguments
    // (compile-time constants and uninitialized resources are omitted) and
    // are not necessarily in the same order.
    std::vector<int> input_mapping;

    // Input shapes of the computation.
    std::vector<xla::Shape> xla_input_shapes;

    // Output shape in XLA format; always a tuple.
    xla::Shape xla_output_shape;

    // TensorFlow shapes of outputs, together with the values of any constant
    // outputs, indexed by _Retval number.
    std::vector<OutputDescription> outputs;

    // Shapes and types of host transfers performed by HostCompute ops and
    // their matching RecvAtHost/SendFromHost ops in the outer graph.
    tf2xla::HostComputeMetadata host_compute_metadata;

    // Resources whose values were updated by the computation, ordered by
    // return value position.
    std::vector<ResourceUpdate> resource_updates;

    // The XLA computation built from the TensorFlow subgraph.
    std::shared_ptr<xla::Computation> computation;
  };
  struct Options {
    // Name of the compilation device to use. Needs to be live only during the
    // XlaCompiler's constructor.
    const DeviceType* device_type = nullptr;

    // The XLA client on which the compiled computations are built.
    xla::Client* client = nullptr;

    // Function library in which to find function definitions. Must be non-null.
    const FunctionLibraryDefinition* flib_def = nullptr;

    // The graph def version to be compiled.
    int graph_def_version = TF_GRAPH_DEF_VERSION;

    // If true, kernels may make use of CustomCall() on the CPU backend.
    bool allow_cpu_custom_calls = false;

    // If set, variables are represented to XLA with the shape given by this
    // function rather than with their TensorFlow shape.
    std::function<TensorShape(const TensorShape&, DataType)>
        variable_representation_shape_fn;

    // If not nullptr, called with the compilation device's resource manager
    // when the compilation device is created.
    std::function<Status(ResourceMgr*)>* populate_resource_manager = nullptr;

    // If not nullptr, an allocator the compiler may use for temporary
    // allocations during compilation.
    xla::DeviceMemoryAllocator* device_allocator = nullptr;
  };

  explicit XlaCompiler(Options options);
  ~XlaCompiler();
  // Compiles a function into an xla::Computation. The shapes of the arguments
  // and return values are determined by the function signature and `args`.
  Status CompileFunction(const CompileOptions& options,
                         const NameAttrList& fn_name_attrs,
                         std::vector<Argument> args, CompilationResult* result);
  // Compiles a tensorflow::Graph into an xla::Computation. Similar to
  // CompileFunction, but takes a Graph as input rather than a function.
  Status CompileGraph(const CompileOptions& options, string const& name,
                      std::unique_ptr<Graph> graph,
                      const std::vector<Argument>& args,
                      CompilationResult* result);
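  // A minimal usage sketch. It assumes an xla::Client, a
  // FunctionLibraryDefinition, a DeviceType, a Graph, and argument
  // descriptions obtained elsewhere; `client`, `device_type`, `flib_def`,
  // `graph`, and `args` below are placeholders, not declared in this header:
  //
  //   XlaCompiler::Options options;
  //   options.client = client;             // xla::Client*
  //   options.device_type = &device_type;  // e.g. the XLA CPU JIT device type
  //   options.flib_def = flib_def;
  //   XlaCompiler compiler(options);
  //
  //   XlaCompiler::CompileOptions compile_options;
  //   XlaCompiler::CompilationResult result;
  //   TF_RETURN_IF_ERROR(compiler.CompileGraph(compile_options, "my_graph",
  //                                            std::move(graph), args,
  //                                            &result));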
  // Compiles a single Op, given by an OpKernelContext, into an
  // xla::Computation. Similar to CompileFunction but takes a single Op as
  // input.
  Status CompileSingleOp(const CompileOptions& options, string const& name,
                         OpKernelContext* ctx,
                         const std::vector<Argument>& args,
                         CompilationResult* result);
  // Returns the shape of the XLA parameter that corresponds to argument `arg`.
  Status XLAShapeForArgument(const Argument& arg, xla::Shape* xla_shape);
  // Retrieves the channel handle associated with `key`, allocating a new
  // handle if none exists. Channel handles can be used to communicate between
  // different computations; computations that communicate should be compiled
  // with the same XlaCompiler.
  Status GetChannelHandle(const string& key, xla::ChannelHandle* channel);
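  // Sketch (the key string is illustrative only):
  //
  //   xla::ChannelHandle channel;
  //   TF_RETURN_IF_ERROR(compiler.GetChannelHandle("host_compute_0", &channel));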
  // Sets the shapes and types for the device-to-host transfer associated with
  // `key`.
  Status SetDeviceToHostMetadata(const string& key,
                                 gtl::ArraySlice<DataType> types,
                                 gtl::ArraySlice<TensorShape> shapes);
  // Gets the shapes of the device-to-host transfer associated with `key`.
  Status GetDeviceToHostShapes(const string& key,
                               std::vector<TensorShape>* shapes) const;
  // Sets the shapes and types for the host-to-device transfer associated with
  // `key`.
  Status SetHostToDeviceMetadata(const string& key,
                                 gtl::ArraySlice<DataType> types,
                                 gtl::ArraySlice<TensorShape> shapes);
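  // A sketch of registering and querying transfer metadata; the key, types,
  // and shapes are illustrative only:
  //
  //   TF_RETURN_IF_ERROR(compiler.SetDeviceToHostMetadata(
  //       "host_compute_0", {DT_FLOAT}, {TensorShape({2, 2})}));
  //   std::vector<TensorShape> shapes;
  //   TF_RETURN_IF_ERROR(
  //       compiler.GetDeviceToHostShapes("host_compute_0", &shapes));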
  const Options& options() const { return options_; }
  xla::Client* client() const { return options_.client; }
  FunctionLibraryRuntime* flib_runtime() const { return flib_runtime_; }
 private:
  // Sets `*fbody` to the function body registered as `function` in the
  // function library.
  Status FindFunctionBody(const NameAttrList& function,
                          const FunctionBody** fbody);
  // Returns the optimized graph object for the given function body.
  std::unique_ptr<Graph> GetGraph(const FunctionBody* fbody);
  // Builds XLA computation parameters for each of the arguments to the
  // computation being compiled.
  Status BuildArguments(const Graph& graph,
                        const std::vector<XlaCompiler::Argument>& args,
                        bool use_tuple_arg, xla::ComputationBuilder* builder,
                        XlaContext* context, std::vector<int>* arg_cores,
                        std::vector<XlaExpression>* arg_expressions,
                        std::vector<int>* input_mapping,
                        std::vector<xla::Shape>* input_shapes,
                        bool is_entry_computation);
  friend class GraphCompiler;
  friend class XlaCompilerTest;
  Options options_;               // Options passed to the constructor.
  Status initialization_status_;  // Did the constructor initialize successfully?

  XlaCompilationDevice* device_;  // Owned by device_mgr_.
  DeviceMgr device_mgr_;

  // To avoid copying the client's function library, a local function library
  // and runtime are used for functions created as part of graph rewrites.
  std::unique_ptr<FunctionLibraryDefinition> local_flib_def_;
  std::unique_ptr<ProcessFunctionLibraryRuntime> pflr_;
  std::unique_ptr<ProcessFunctionLibraryRuntime> local_pflr_;

  FunctionLibraryRuntime* local_flib_runtime_;  // Owned by local_pflr_.
  FunctionLibraryRuntime* flib_runtime_;        // Owned by pflr_.
  // Hash function for the compilation cache key: a (function name, arguments)
  // pair.
  struct SignatureHash {
    uint64 operator()(
        const std::pair<string, std::vector<Argument>>& signature) const;
  };

  std::unordered_map<std::pair<string, std::vector<Argument>>,
                     CompilationResult, SignatureHash>
      cache_;
  // Maps channel names to channel handles.
  std::unordered_map<string, xla::ChannelHandle> channels_;

  std::unordered_map<string, tf2xla::HostTransferMetadata> host_compute_sends_;
  std::unordered_map<string, tf2xla::HostTransferMetadata> host_compute_recvs_;
  TF_DISALLOW_COPY_AND_ASSIGN(XlaCompiler);
};

}  // namespace tensorflow

#endif  // TENSORFLOW_COMPILER_TF2XLA_XLA_COMPILER_H_