TensorFlow 1.8 XLA documentation — source listing of xla_compiler.h
(tensorflow/compiler/tf2xla/xla_compiler.h). Go to the documentation of this
file for the annotated cross-referenced version.
3 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
4 
5 Licensed under the Apache License, Version 2.0 (the "License");
6 you may not use this file except in compliance with the License.
7 You may obtain a copy of the License at
8 
9  http://www.apache.org/licenses/LICENSE-2.0
10 
11 Unless required by applicable law or agreed to in writing, software
12 distributed under the License is distributed on an "AS IS" BASIS,
13 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 See the License for the specific language governing permissions and
15 limitations under the License.
16 ==============================================================================*/
17 
18 #ifndef TENSORFLOW_COMPILER_TF2XLA_XLA_COMPILER_H_
19 #define TENSORFLOW_COMPILER_TF2XLA_XLA_COMPILER_H_
20 
21 #include "tensorflow/compiler/tf2xla/host_compute_metadata.pb.h"
22 #include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
23 #include "tensorflow/compiler/xla/client/local_client.h"
24 #include "tensorflow/core/common_runtime/device.h"
25 #include "tensorflow/core/common_runtime/device_mgr.h"
26 #include "tensorflow/core/common_runtime/function.h"
27 #include "tensorflow/core/framework/function.h"
28 #include "tensorflow/core/platform/env.h"
29 #include "tensorflow/core/platform/mutex.h"
30 #include "tensorflow/core/platform/notification.h"
31 #include "tensorflow/core/platform/thread_annotations.h"
32 #include "tensorflow/core/public/version.h"
33 
34 namespace tensorflow {
35 
36 class XlaContext;
37 
// XlaCompiler compiles TensorFlow graphs, functions, and single ops into
// xla::Computations (see CompileGraph/CompileFunction/CompileSingleOp below).
// NOTE(review): the full class comment (original lines 38-88) is elided from
// this listing; see the source header for the complete argument-passing
// convention description.
class XlaCompiler {
 public:
  // Describes how to derive the value of each _Arg node in the graph/function
  // being compiled. There must be one Argument for each _Arg index.
  struct Argument {
    enum Kind {
      // Default value; not a valid kind.
      kInvalid,

      // Argument is a compile-time constant. No associated runtime parameter.
      kConstant,

      // Argument is a Variable, TensorArray, or Stack resource. Has an
      // associated runtime parameter iff `initialized` is true.
      kResource,

      // Argument is a run-time parameter.
      kParameter,
    };

    Kind kind = kInvalid;

    // The type of the argument. If the argument is a resource, this
    // is the type of the variable's value, not DT_RESOURCE.
    DataType type;

    // The shape of the argument. For:
    // * a parameter: the shape of the parameter.
    // * a constant: ignored; the shape given by constant_value is used
    //   instead.
    // * an uninitialized resource: ignored. We don't yet know the shape of an
    //   uninitialized resource (otherwise we would have initialized it!)
    // * an initialized variable: the shape of the variable's value.
    // * an initialized TensorArray or Stack resource: the shape of an entry in
    //   the TensorArray/Stack. Note this is the size of a single entry, not the
    //   XLA data structure that represents the complete stack/array.
    TensorShape shape;

    // The value of the argument, if it is a compile-time constant. Must be a
    // host-memory tensor.
    Tensor constant_value;

    // The name of this argument, used for debugging.
    string name;

    // For a kResource, what kind of resource is it?
    XlaResource::Kind resource_kind = XlaResource::kInvalid;

    // For a kResource, has this resource been initialized?
    bool initialized = false;

    // For a TensorArray or Stack resource, what is the array's declared size?
    // (Used for lazy initialization.)
    int64 tensor_array_size = -1;

    // TensorArray resource parameters are passed as (array, gradient array 0,
    // ..., gradient array k), where the gradient arrays are in the same order
    // as `tensor_array_gradients`.
    std::set<string> tensor_array_gradients;

    bool operator==(const Argument& other) const;
  };

  // Options pertaining to an individual call to CompileGraph() or
  // CompileFunction().
  struct CompileOptions {
    // If `use_tuple_arg` is true, a single tuple parameter will be used for all
    // arguments; if false, each argument gets its own parameter.
    bool use_tuple_arg = false;

    // If 'return_updated_values_for_all_resources' is true, then updated
    // values of all resource arguments will be included in the
    // 'resource_updates' of the computation, even if the resource was not
    // modified by the computation. Used when compiling loop bodies to ensure
    // the input and output signatures match.
    bool return_updated_values_for_all_resources = false;

    // If 'resolve_compile_time_constants' is true, then outputs of a
    // computation that are known to be compile-time constants will be returned
    // as Tensors at compile-time, rather than as run-time outputs of the
    // computation.
    bool resolve_compile_time_constants = true;

    // True when compiling the entry computation, false for subcomputations
    // (while, call, etc.)
    bool is_entry_computation = true;
  };

  // Describes one output (_Retval) of the compiled computation.
  struct OutputDescription {
    // Type and shape of the output.
    DataType type;
    TensorShape shape;

    // Constant output value, if known to be constant at JIT compilation time.
    // 'Tensor' is in host memory.
    bool is_constant = false;
    Tensor constant_value;
  };

  // Describes a variable write side effect of the computation.
  struct ResourceUpdate {
    // Index of the input that contains the variable resource to write to.
    int input_index;

    // Type and shape of the tensor to be written back.
    // The `shape` field has the same meaning as the Argument::shape field.
    DataType type;
    TensorShape shape;

    // Was the value of the variable modified by the computation?
    // (Always true, unless `return_updated_values_for_all_resources` is true.)
    bool modified;

    // If the resource is a TensorArray, the set of gradients read or written.
    std::set<string> tensor_array_gradients_accessed;
  };

  // The outcome of one compilation: the XLA computation plus the metadata
  // needed to feed it inputs and interpret its outputs.
  struct CompilationResult {
    // Vector that maps from the parameters of the XLA computation to their
    // original argument positions. To handle compile-time constant inputs and
    // resources, the parameters to the XLA computation may be a subset of the
    // original arguments, and are not necessarily in the same order.)
    std::vector<int> input_mapping;

    // Input shapes of the computation.
    std::vector<xla::Shape> xla_input_shapes;

    // Output shape in XLA format. The output shape is always a tuple.
    xla::Shape xla_output_shape;

    // TensorFlow shapes of outputs, together with the values of any
    // constant arguments. Vector indexed by Tensorflow _Retval number,
    // containing both constant and non-constant results.
    std::vector<OutputDescription> outputs;

    // TensorFlow shapes and types of sends/recvs from HostCompute Ops to their
    // matching RecvAtHost/SendFromHost Ops in the outer graph.
    tf2xla::HostComputeMetadata host_compute_metadata;

    // Resources whose values were updated by the computation, ordered
    // by return value position. Resource updates follow the non-constant
    // results in the outputs of XLA computation.
    std::vector<ResourceUpdate> resource_updates;

    // The XLA computation built from the tensorflow subgraph.
    std::shared_ptr<xla::Computation> computation;
  };

  // Per-XlaCompiler configuration, fixed at construction time.
  struct Options {
    // Name of the compilation device to use. Needs to be live only during
    // XlaCompiler's constructor.
    const DeviceType* device_type = nullptr;

    // XLA client used to build computations. NOTE(review): no null-check is
    // visible here; presumably must be non-null — verify against callers.
    xla::Client* client = nullptr;

    // Function library in which to find function definitions. Must be non-null.
    const FunctionLibraryDefinition* flib_def = nullptr;

    // The graph def version to be compiled.
    int graph_def_version = TF_GRAPH_DEF_VERSION;

    // If 'allow_cpu_custom_calls' is true, kernels may make use of CustomCall()
    // for CPU.
    bool allow_cpu_custom_calls = false;

    // If set, the XLA representation of variables represented to XLA as the
    // shape given by this shape function. Variables are reshaped to this shape
    // on write, and reshaped to their original shape on read.
    std::function<TensorShape(const TensorShape&, DataType)>
        variable_representation_shape_fn;

    // If not nullptr, populate_resource_manager is called with the
    // compilation device's resource manager when the compilation
    // device is created, and can be used to create metadata objects
    // that can be accessed by XLA op kernels.
    std::function<Status(ResourceMgr*)>* populate_resource_manager = nullptr;

    // If not nullptr, this memory allocator can be used by the compiler for
    // temporary allocations it might want to make during compilation.
    //
    // For example, the compiler may want to try out different algorithms and
    // choose the fastest one, and it might run those algorithms over buffers
    // created using this allocator.
    //
    // The compiler can function correctly without an explicit allocator given
    // here, but on some devices (notably, GPUs), TensorFlow tends to eagerly
    // allocate most or all available memory on the device, leaving none for the
    // compiler to access, unless it can use TensorFlow's allocator.
    xla::DeviceMemoryAllocator* device_allocator = nullptr;
  };

  explicit XlaCompiler(Options options);

  ~XlaCompiler();

  // Compiles the TensorFlow function named by `fn_name_attrs` (looked up in
  // the function library — see FindFunctionBody below) into an XLA
  // computation, writing the outcome to `result`.
  Status CompileFunction(const CompileOptions& options,
                         const NameAttrList& fn_name_attrs,
                         std::vector<Argument> args, CompilationResult* result);

  // Compiles a tensorflow::Graph into an xla::Computation.
  // Similar to CompileFunction, but takes a Graph as input rather than a
  // function.
  Status CompileGraph(const CompileOptions& options, string const& name,
                      std::unique_ptr<Graph> graph,
                      const std::vector<Argument>& args,
                      CompilationResult* result);

  // Compiles a single Op, given by an OpKernelContext, into an
  // xla::Computation. Similar to CompileFunction but takes a single Op as
  // input.
  Status CompileSingleOp(const CompileOptions& options, string const& name,
                         OpKernelContext* ctx,
                         const std::vector<Argument>& args,
                         CompilationResult* result);

  // Returns the shape of the XLA parameter for an argument 'arg'.
  // See the class comment for more details about the argument passing
  // convention.
  Status XLAShapeForArgument(const Argument& arg, xla::Shape* xla_shape);

  // Retrieves the channel handle associated with `key`. Allocates
  // a new channel handle if none exists.
  // Channel handles can be used to communicate between different
  // computations. Computations that communicate should be compiled with the
  // same XlaCompiler.
  Status GetChannelHandle(const string& key, xla::ChannelHandle* channel);

  // Sets the shapes and types for the device to host transfer associated with
  // 'key'.
  Status SetDeviceToHostMetadata(const string& key,
                                 gtl::ArraySlice<DataType> types,
                                 gtl::ArraySlice<TensorShape> shapes);

  // Gets the shapes the device to host transfer associated with 'key'.
  Status GetDeviceToHostShapes(const string& key,
                               std::vector<TensorShape>* shapes) const;

  // Sets the shapes and types for the host to device transfer associated with
  // 'key'.
  Status SetHostToDeviceMetadata(const string& key,
                                 gtl::ArraySlice<DataType> types,
                                 gtl::ArraySlice<TensorShape> shapes);

  // Accessors for the construction-time options and derived objects.
  const Options& options() const { return options_; }
  xla::Client* client() const { return options_.client; }
  FunctionLibraryRuntime* flib_runtime() const { return flib_runtime_; }

 private:
  // Sets the function body `fbody` to the one registered as `function`.
  Status FindFunctionBody(const NameAttrList& function,
                          const FunctionBody** fbody);

  // Returns the optimized graph object in this function body.
  std::unique_ptr<Graph> GetGraph(const FunctionBody* fbody);

  // Builds XLA computations for each of the arguments to the computation.
  // `args` are the arguments to the computation.
  Status BuildArguments(const Graph& graph,
                        const std::vector<XlaCompiler::Argument>& args,
                        bool use_tuple_arg, xla::ComputationBuilder* builder,
                        XlaContext* context, std::vector<int>* arg_cores,
                        std::vector<XlaExpression>* arg_expressions,
                        std::vector<int>* input_mapping,
                        std::vector<xla::Shape>* input_shapes,
                        bool is_entry_computation);

  // Graph compiler needs to know how to get an optimized graph from a function
  // body.
  friend class GraphCompiler;
  friend class XlaCompilerTest;

  Options options_;

  // Status set to non-OK in the constructor if initialization fails.
  Status initialization_status_;

  // Returns the next step sequence number.
  int64 NextStepId();

  // Internal sequence number for steps executed on the compilation device.
  int64 next_step_id_;

  XlaCompilationDevice* device_;  // Owned by device_mgr_
  DeviceMgr device_mgr_;

  // To avoid copying the client's function library, use a local function
  // library and runtime for functions created as part of the functionalize
  // control flow transformation.
  std::unique_ptr<FunctionLibraryDefinition> local_flib_def_;
  std::unique_ptr<ProcessFunctionLibraryRuntime> pflr_;
  std::unique_ptr<ProcessFunctionLibraryRuntime> local_pflr_;

  FunctionLibraryRuntime* local_flib_runtime_;  // owned by local_pflr_.
  FunctionLibraryRuntime* flib_runtime_;        // owned by pflr_.

  // Hash functor for the (function name, arguments) compilation cache key.
  struct SignatureHash {
    uint64 operator()(
        const std::pair<string, std::vector<Argument>>& signature) const;
  };

  // Cache of compilation results, keyed by (function name, arguments).
  std::unordered_map<std::pair<string, std::vector<Argument>>,
                     CompilationResult, SignatureHash>
      cache_;

  // Channel handles allocated by GetChannelHandle, keyed by caller key.
  std::unordered_map<string, xla::ChannelHandle> channels_;

  // Host transfer metadata recorded by SetDeviceToHostMetadata /
  // SetHostToDeviceMetadata, keyed by transfer key.
  std::unordered_map<string, tf2xla::HostTransferMetadata> host_compute_sends_;
  std::unordered_map<string, tf2xla::HostTransferMetadata> host_compute_recvs_;

  TF_DISALLOW_COPY_AND_ASSIGN(XlaCompiler);
};
400 
401 } // namespace tensorflow
402 
403 #endif // TENSORFLOW_COMPILER_TF2XLA_XLA_COMPILER_H_
Cross-references (from the generated documentation):
- xla::ComputationBuilder — definition: computation_builder.h:59.
- tensorflow::XlaCompiler — definition: xla_compiler.h:89.
- Status XlaCompiler::CompileGraph(const CompileOptions &options, string const &name,
  std::unique_ptr<Graph> graph, const std::vector<Argument> &args,
  CompilationResult *result) — compiles a tensorflow::Graph into an
  xla::Computation; definition: xla_compiler.cc:656.
- Status XlaCompiler::BuildArguments(const Graph &graph,
  const std::vector<XlaCompiler::Argument> &args, bool use_tuple_arg,
  xla::ComputationBuilder *builder, XlaContext *context,
  std::vector<int> *arg_cores, std::vector<XlaExpression> *arg_expressions,
  std::vector<int> *input_mapping, std::vector<xla::Shape> *input_shapes,
  bool is_entry_computation) — builds XLA computations for each of the
  arguments to the UserComputation; definition: xla_compiler.cc:439.
- tensorflow (namespace) — definition: compile.cc:35.