service.h
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_SERVICE_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_SERVICE_H_

#include <functional>
#include <memory>
#include <string>
#include <vector>

#include "tensorflow/compiler/xla/executable_run_options.h"
#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h"
#include "tensorflow/compiler/xla/service/allocation_tracker.h"
#include "tensorflow/compiler/xla/service/backend.h"
#include "tensorflow/compiler/xla/service/channel_tracker.h"
#include "tensorflow/compiler/xla/service/compilation_cache.h"
#include "tensorflow/compiler/xla/service/computation_tracker.h"
#include "tensorflow/compiler/xla/service/device_memory_allocator.h"
#include "tensorflow/compiler/xla/service/executable.h"
#include "tensorflow/compiler/xla/service/execution_tracker.h"
#include "tensorflow/compiler/xla/service/hlo_execution_profile.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_module_config.h"
#include "tensorflow/compiler/xla/service/session.pb.h"
#include "tensorflow/compiler/xla/service/user_computation.h"
#include "tensorflow/compiler/xla/service/versioned_computation_handle.h"
#include "tensorflow/compiler/xla/service_interface.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla.pb.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"

namespace xla {
// Options to configure the service when it is created.
class ServiceOptions {
 public:
  // Sets the platform backing the service, or nullptr for the default
  // platform.
  ServiceOptions& set_platform(perftools::gputools::Platform* platform);
  perftools::gputools::Platform* platform() const;

  // Sets the number of replicas to use when compiling replicated programs.
  ServiceOptions& set_number_of_replicas(int number_of_replicas);
  int number_of_replicas() const;

  // Sets the thread pool size for parallel execution of an individual
  // operator.
  ServiceOptions& set_intra_op_parallelism_threads(int num_threads);
  int intra_op_parallelism_threads() const;

 private:
  perftools::gputools::Platform* platform_ = nullptr;
  int number_of_replicas_ = 1;
  int intra_op_parallelism_threads_ = -1;
};
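
// A minimal usage sketch (the surrounding client code is hypothetical;
// Service::NewService is declared below). The setters return the options
// object, so calls can be chained:
//
//   ServiceOptions options;
//   options.set_number_of_replicas(2).set_intra_op_parallelism_threads(8);
//   StatusOr<std::unique_ptr<Service>> service_or =
//       Service::NewService(options);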

class Service : public ServiceInterface {
 public:
  // Factory methods for creating a new Service.
  static StatusOr<std::unique_ptr<Service>> NewService(
      perftools::gputools::Platform* platform = nullptr);
  static StatusOr<std::unique_ptr<Service>> NewService(
      const ServiceOptions& options);
  // Creates a new computation with the given name.
  // A unique ComputationHandle is returned.
  tensorflow::Status Computation(const ComputationRequest* arg,
                                 ComputationResponse* result) override;

  // Unregisters a previously-allocated global handle.
  //
  // If the handle given is not currently allocated, a NOT_FOUND status is
  // returned.
  tensorflow::Status Unregister(const UnregisterRequest* arg,
                                UnregisterResponse* result) override;

  // Deconstructs a tuple. Returns a newly created GlobalDataHandle for each
  // element in the tuple.
  tensorflow::Status DeconstructTuple(
      const DeconstructTupleRequest* arg,
      DeconstructTupleResponse* result) override;

  // Modifies the provided computation so that subsequent executions
  // will compute the provided ComputationDataHandle, rather than the
  // last expression enqueued on that Computation.
  tensorflow::Status SetReturnValue(const SetReturnValueRequest* arg,
                                    SetReturnValueResponse* results) override;
  // Executes a computation with the provided global data passed as
  // immutable arguments. Returns global data output and execution timing.
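  //
  // For example (an illustrative sketch; "service", "computation_handle",
  // and "argument_handle" are hypothetical, and the proto field names are
  // assumptions based on xla.pb.h in this release):
  //
  //   ExecuteRequest request;
  //   *request.mutable_computation() = computation_handle;
  //   *request.add_arguments() = argument_handle;
  //   ExecuteResponse response;
  //   TF_RETURN_IF_ERROR(service->Execute(&request, &response));
  //   // On success, response.output() identifies the result data.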
  tensorflow::Status Execute(const ExecuteRequest* arg,
                             ExecuteResponse* result) override;
  // Executes a computation with the provided global data passed as
  // immutable arguments. The request contains the whole computation graph.
  // Returns global data output and execution timing.
  //
  // TODO(b/74197823): This is a part of a NOT YET ready refactor.
  tensorflow::Status ExecuteGraph(const ExecuteGraphRequest* arg,
                                  ExecuteResponse* result) override;

  // Executes one or more computations in parallel with the provided global
  // data passed as immutable arguments. Returns global data output for each
  // computation.
  tensorflow::Status ExecuteParallel(const ExecuteParallelRequest* arg,
                                     ExecuteParallelResponse* result) override;

  // Executes one or more computations in parallel with the provided global
  // data passed as immutable arguments. Returns global data output for each
  // computation.
  //
  // TODO(b/74197823): This is a part of a NOT YET ready refactor.
  tensorflow::Status ExecuteGraphParallel(
      const ExecuteGraphParallelRequest* arg,
      ExecuteParallelResponse* result) override;
  // Requests one or more device handles from the target.
  //
  // When N device handles are requested and the number of replicas is R, at
  // least N * R devices must be available. The devices are assigned based on
  // the device ordinals such that the first R available devices are assigned
  // to the first set of replicas, the next R devices to the second set of
  // replicas, and so on. Each returned device handle represents the device
  // with replica id 0.
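  //
  // For example (illustrative numbers): with R = 2 replicas and N = 2
  // requested handles, at least four devices must be available; the first
  // handle covers device ordinals {0, 1} and the second covers {2, 3}.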
  tensorflow::Status GetDeviceHandles(
      const GetDeviceHandlesRequest* arg,
      GetDeviceHandlesResponse* result) override;
  // Asynchronously executes a computation with provided arguments. Invokes
  // the provided computation with the provided global data passed as
  // immutable arguments. Returns a handle to the execution.
  //
  // (Note: The corresponding function in xla::Client was removed as part of
  // b/64116060, in an attempt to simplify our API. We're keeping this around
  // for now in case we want to expose this to clients in a different way.)
  tensorflow::Status ExecuteAsync(const ExecuteAsyncRequest* arg,
                                  ExecuteAsyncResponse* result) override;

  // Waits until the specified execution is complete and returns the result.
  // Calling this API multiple times with the same execution handle fails
  // with an error, since the execution handle is destroyed after the first
  // call.
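  //
  // For example (an illustrative sketch; "service" is hypothetical and the
  // proto field names are assumptions based on xla.pb.h in this release):
  //
  //   ExecuteAsyncRequest async_request;  // populated as for Execute()
  //   ExecuteAsyncResponse async_response;
  //   TF_RETURN_IF_ERROR(
  //       service->ExecuteAsync(&async_request, &async_response));
  //   WaitForExecutionRequest wait_request;
  //   *wait_request.mutable_execution() = async_response.execution();
  //   WaitForExecutionResponse wait_response;  // holds the final result
  //   TF_RETURN_IF_ERROR(
  //       service->WaitForExecution(&wait_request, &wait_response));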
  tensorflow::Status WaitForExecution(
      const WaitForExecutionRequest* arg,
      WaitForExecutionResponse* result) override;
  // Requests that global data be transferred to the client in literal form.
  tensorflow::Status TransferToClient(
      const TransferToClientRequest* arg,
      TransferToClientResponse* result) override;

  // Transfers data from a literal provided by the client into device memory.
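  //
  // For example (an illustrative sketch; "service" and "literal_proto" are
  // hypothetical, and the proto field names are assumptions based on
  // xla.pb.h in this release):
  //
  //   TransferToServerRequest to_server;
  //   *to_server.mutable_literal() = literal_proto;  // client-side literal
  //   TransferToServerResponse uploaded;
  //   TF_RETURN_IF_ERROR(service->TransferToServer(&to_server, &uploaded));
  //
  //   TransferToClientRequest to_client;
  //   *to_client.mutable_data() = uploaded.data();  // handle from upload
  //   TransferToClientResponse downloaded;  // downloaded.literal() round-trips
  //   TF_RETURN_IF_ERROR(service->TransferToClient(&to_client, &downloaded));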
  tensorflow::Status TransferToServer(
      const TransferToServerRequest* arg,
      TransferToServerResponse* result) override;

  // Transfers data from a literal provided by the client into the Infeed
  // buffer of the device.
  tensorflow::Status TransferToInfeed(
      const TransferToInfeedRequest* arg,
      TransferToInfeedResponse* result) override;

  // Transfers data from the Outfeed of the device to the literal provided by
  // the client.
  tensorflow::Status TransferFromOutfeed(
      const TransferFromOutfeedRequest* arg,
      TransferFromOutfeedResponse* result) override;
  // Resets devices, clearing all existing state on all the devices associated
  // with this service (including memory allocated on the devices).
  //
  // ResetDevice may only be called where no previous Execution state on the
  // device is used by the next Execution.
  //
  // ResetDevice should be called before an Execution that expects the device
  // to be in the reset state. For example, if the prior Execution modifies
  // device state (e.g., architectural state) that the next Execution depends
  // on.
  tensorflow::Status ResetDevice(const ResetDeviceRequest* arg,
                                 ResetDeviceResponse* result) override;

  // Tests if an expression is a compile-time constant.
  tensorflow::Status IsConstant(const IsConstantRequest* arg,
                                IsConstantResponse* result) override;

  // Computes the value of a constant expression.
  tensorflow::Status ComputeConstant(const ComputeConstantRequest* arg,
                                     ComputeConstantResponse* result) override;

  // Returns the shape (with layout) of an array associated with a given data
  // handle.
  tensorflow::Status GetShape(const GetShapeRequest* arg,
                              GetShapeResponse* result) override;

  // Returns the program shape of the computation associated with the given
  // handle.
  tensorflow::Status GetComputationShape(
      const GetComputationShapeRequest* arg,
      GetComputationShapeResponse* result) override;
  // Computation-oriented methods.

  // Enqueues an Op on the computation.
  tensorflow::Status Op(const OpRequest* arg, OpResponse* result) override;

  // Retrieves the inferred shape for a value within a computation.
  tensorflow::Status GetLocalShape(const GetLocalShapeRequest* arg,
                                   GetLocalShapeResponse* result) override;

  // Retrieves the statistics of a computation.
  tensorflow::Status GetComputationStats(
      const ComputationStatsRequest* arg,
      ComputationStatsResponse* result) override;

  // Retrieves the statistics of a computation.
  //
  // TODO(b/74197823): This is a part of a NOT YET ready refactor.
  tensorflow::Status GetComputationGraphStats(
      const ComputationGraphStatsRequest* arg,
      ComputationStatsResponse* result) override;
  // Snapshots the current state of a computation handle into a serializable
  // protocol buffer form, so it can be loaded via LoadComputationSnapshot.
  tensorflow::Status SnapshotComputation(
      const SnapshotComputationRequest* arg,
      SnapshotComputationResponse* result) override;

  // Loads a computation from a serialized protocol buffer created via
  // SnapshotComputation.
  tensorflow::Status LoadComputationSnapshot(
      const LoadComputationSnapshotRequest* arg,
      LoadComputationSnapshotResponse* result) override;

  // Creates a unique channel handle that can be used for Send/Recv
  // instructions.
  tensorflow::Status CreateChannelHandle(
      const CreateChannelHandleRequest* arg,
      CreateChannelHandleResponse* result) override;
  // Returns the ComputationTracker of the current service instance.
  // Only used in unit tests to access user computations from the client.
  const ComputationTracker& computation_tracker() {
    return computation_tracker_;
  }

  // Returns the backend used to execute computations.
  const Backend& backend() const { return *execute_backend_; }
  Backend* mutable_backend() { return execute_backend_.get(); }
 private:
  // A private overload for Service itself, used by other methods within this
  // class.
  StatusOr<std::unique_ptr<HloModuleConfig>> CreateModuleConfig(
      const ProgramShape& program_shape,
      tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
      const ExecutionOptions& execution_options,
      const UserComputation* user_computation = nullptr);

  // Picks a parallel response and fills the result.
  Status PickParallelResponse(const ExecuteParallelResponse& parallel_result,
                              ExecuteResponse* result);

  // Prepares the executors for parallel execution.
  StatusOr<std::vector<perftools::gputools::StreamExecutor*>> GetExecutors(
      const ExecutionOptions& execution_options, int64 requests_size,
      int64 request_index) const;

  // Prepares the arguments for parallel execution.
  StatusOr<std::vector<std::vector<const ShapedBuffer*>>> GetArguments(
      const ExecutionOptions& execution_options,
      tensorflow::gtl::ArraySlice<const GlobalDataHandle*> arguments);
 protected:
  friend class LocalExecutable;

  // The constructor is protected. Use the NewService factory to create new
  // service objects.
  Service(const ServiceOptions& options,
          std::unique_ptr<Backend> execute_backend);

  // Resolves the given argument handles in the allocation tracker and returns
  // the corresponding allocations for every replica. The function also
  // verifies that each allocation matches the execution platform and device
  // ordinal of the corresponding replica.
  StatusOr<std::vector<std::vector<const ShapedBuffer*>>>
  ResolveAndValidateArguments(
      tensorflow::gtl::ArraySlice<const GlobalDataHandle*> arguments,
      tensorflow::gtl::ArraySlice<perftools::gputools::StreamExecutor*>
          stream_executors);

  // Creates an HLO module config for the given program shape and arguments.
  // execution_options is optional; if not given a default is used.
  StatusOr<std::unique_ptr<HloModuleConfig>> CreateModuleConfig(
      const ProgramShape& program_shape,
      tensorflow::gtl::ArraySlice<const Shape*> argument_shapes,
      const ExecutionOptions* execution_options,
      const UserComputation* user_computation = nullptr);
  // Builds an Executable for the given parameters.
  //
  // If device_allocator is not null, the compiler may use it to allocate temp
  // buffers, which the compiler is responsible for freeing. The allocator
  // given here need not match the allocator used when running the executable.
  StatusOr<std::unique_ptr<Executable>> BuildExecutable(
      const VersionedComputationHandle& versioned_handle,
      std::unique_ptr<HloModuleConfig> module_config, Backend* backend,
      perftools::gputools::StreamExecutor* executor,
      DeviceMemoryAllocator* device_allocator = nullptr);

  // Builds an Executable for the given HLO module proto.
  //
  // TODO(b/74197823): This is a part of a NOT YET ready refactor.
  StatusOr<std::unique_ptr<Executable>> BuildExecutable(
      const HloModuleProto& module_proto,
      std::unique_ptr<HloModuleConfig> module_config, Backend* backend,
      perftools::gputools::StreamExecutor* executor,
      DeviceMemoryAllocator* device_allocator = nullptr);

  // Same as BuildExecutable() above, but builds a list of Executables for the
  // given computations that may interact with each other.
  StatusOr<std::vector<std::unique_ptr<Executable>>> BuildExecutables(
      std::vector<VersionedComputationHandle> versioned_handles,
      std::vector<std::unique_ptr<HloModuleConfig>> module_configs,
      Backend* backend,
      std::vector<std::vector<perftools::gputools::StreamExecutor*>> executors,
      DeviceMemoryAllocator* device_allocator);
  StatusOr<std::vector<std::unique_ptr<Executable>>> BuildExecutables(
      const std::vector<const HloModuleProto*>& module_protos,
      std::vector<std::unique_ptr<HloModuleConfig>> module_configs,
      Backend* backend,
      std::vector<std::vector<perftools::gputools::StreamExecutor*>> executors,
      DeviceMemoryAllocator* device_allocator);

  // Similar to BuildExecutable, but looks in the compilation cache for the
  // executable first. If the executable is not in the cache, it is built and
  // inserted into the cache.
  StatusOr<std::shared_ptr<Executable>> BuildAndCacheExecutable(
      const VersionedComputationHandle& versioned_handle,
      std::unique_ptr<HloModuleConfig> module_config, Backend* backend,
      perftools::gputools::StreamExecutor* executor, ExecutionProfile* profile,
      DeviceMemoryAllocator* device_allocator = nullptr);
  // Runs the given executable with the given arguments and registers the
  // result in the allocation tracker. The handle of the result from the
  // tracker is returned. If the parameter "profile" is not null, it points to
  // an ExecutionProfile object which will be filled in with profile data.
  StatusOr<GlobalDataHandle> ExecuteAndRegisterResult(
      Executable* executable,
      const tensorflow::gtl::ArraySlice<std::vector<const ShapedBuffer*>>
          arguments,
      Backend* backend, const string& result_tag, ExecutionProfile* profile);

  // Runs the given executables with the given arguments and registers the
  // result from each executable in the allocation tracker. The handles of the
  // results from the tracker are returned.
  StatusOr<std::vector<GlobalDataHandle>> ExecuteParallelAndRegisterResult(
      tensorflow::gtl::ArraySlice<Executable*> executables,
      tensorflow::gtl::ArraySlice<std::vector<std::vector<const ShapedBuffer*>>>
          arguments,
      Backend* backend,
      tensorflow::gtl::ArraySlice<DeviceHandle> device_handles,
      tensorflow::gtl::ArraySlice<string> result_tags,
      ExecutionProfile* profile);
  // Convenience function for adding an instruction to a user computation; the
  // adder callback performs the actual addition and returns the handle of the
  // new instruction.
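  //
  // For example (an illustrative sketch; the specific adder shown is an
  // assumption about how Op() might dispatch a unary operation):
  //
  //   return AddInstruction(arg, result, [&](UserComputation* computation) {
  //     return computation->AddUnaryInstruction(arg->unary_op_request());
  //   });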
  template <typename RequestT, typename ResponseT>
  tensorflow::Status AddInstruction(
      const RequestT* arg, ResponseT* result,
      const std::function<StatusOr<ComputationDataHandle>(UserComputation*)>&
          adder);

  // Executes a single computation which has more than one target device.
  // All but one of the N devices are expected to return an empty tuple; the
  // remaining device's output is the result of this computation.
  tensorflow::Status ExecuteOneToN(const ExecuteRequest* arg,
                                   ExecuteResponse* result);
  tensorflow::Status ExecuteOneToN(const ExecuteGraphRequest* arg,
                                   ExecuteResponse* result);
  // Convenience function which checks whether the given shape_with_layout
  // (presumably passed by the client to set the result layout) is valid for
  // the given computation result shape.
  tensorflow::Status ValidateResultShapeWithLayout(
      const Shape& shape_with_layout, const Shape& result_shape) const;

  // Returns the stream executors assigned to the replicas represented by the
  // given device handle. Each device_handle is a virtual replicated device
  // that represents a set of physical devices for the replicas.
  StatusOr<std::vector<perftools::gputools::StreamExecutor*>> Replicas(
      const Backend& backend, const DeviceHandle& device_handle) const;

  Status MaybeDumpHloModule(const HloModule& module) const;

  // Returns the device handle that represents the replicated device for a
  // single computation that is not model-parallelized.
  DeviceHandle SingleComputationDeviceHandle() const;
  ServiceOptions options_;

  // Tracks computations built via the API.
  ComputationTracker computation_tracker_;

  // Tracks channels created via the API.
  ChannelTracker channel_tracker_;

  // Tracks allocations made via the API and computation execution.
  AllocationTracker allocation_tracker_;

  // Tracks asynchronously launched executions via the API.
  ExecutionTracker execution_tracker_;

  // Cache containing previously built Executables.
  CompilationCache compilation_cache_;

  // Backend to compile and execute computations on.
  std::unique_ptr<Backend> execute_backend_;

  TF_DISALLOW_COPY_AND_ASSIGN(Service);
};

}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_SERVICE_H_