#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_SERVICE_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_SERVICE_H_

#include "tensorflow/compiler/xla/executable_run_options.h"
#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h"
#include "tensorflow/compiler/xla/service/allocation_tracker.h"
#include "tensorflow/compiler/xla/service/backend.h"
#include "tensorflow/compiler/xla/service/channel_tracker.h"
#include "tensorflow/compiler/xla/service/compilation_cache.h"
#include "tensorflow/compiler/xla/service/computation_tracker.h"
#include "tensorflow/compiler/xla/service/device_memory_allocator.h"
#include "tensorflow/compiler/xla/service/executable.h"
#include "tensorflow/compiler/xla/service/execution_tracker.h"
#include "tensorflow/compiler/xla/service/hlo_execution_profile.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_module_config.h"
#include "tensorflow/compiler/xla/service/session.pb.h"
#include "tensorflow/compiler/xla/service/user_computation.h"
#include "tensorflow/compiler/xla/service/versioned_computation_handle.h"
#include "tensorflow/compiler/xla/service_interface.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla.pb.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"

namespace xla {

// Options to configure the service when it is created.
class ServiceOptions {
 public:
  // Sets the platform backing the service, or nullptr for the default
  // platform.
  ServiceOptions& set_platform(perftools::gputools::Platform* platform);
  perftools::gputools::Platform* platform() const;

  // Sets the number of replicas to use when compiling replicated programs.
  ServiceOptions& set_number_of_replicas(int number_of_replicas);
  int number_of_replicas() const;

  // Sets the thread pool size for parallel execution of an individual op.
  ServiceOptions& set_intra_op_parallelism_threads(int num_threads);
  int intra_op_parallelism_threads() const;

 private:
  perftools::gputools::Platform* platform_ = nullptr;
  int number_of_replicas_ = 1;
  int intra_op_parallelism_threads_ = -1;
};
// The XLA service: tracks computations and allocations made via the API, and
// delegates compilation and execution to a target-specific backend.
class Service : public ServiceInterface {
 public:
  // Factory method for creating a new Service.
  static StatusOr<std::unique_ptr<Service>> NewService(
      perftools::gputools::Platform* platform = nullptr);
  static StatusOr<std::unique_ptr<Service>> NewService(
      const ServiceOptions& options);
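  // A minimal creation sketch (illustrative; error handling shortened). Both
  // overloads return a StatusOr, which must be checked before unwrapping:
  //
  //   StatusOr<std::unique_ptr<Service>> service_or =
  //       Service::NewService(options);
  //   CHECK(service_or.ok());
  //   std::unique_ptr<Service> service = service_or.ConsumeValueOrDie();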
  // Creates a new computation. Populates computation_ with a valid object;
  // used before any given operation is enqueued.
  tensorflow::Status Computation(const ComputationRequest* arg,
                                 ComputationResponse* result) override;
  // Unregisters a previously-registered global data handle.
  tensorflow::Status Unregister(const UnregisterRequest* arg,
                                UnregisterResponse* result) override;

  // Deconstructs a tuple, returning a new global data handle for each element
  // of the tuple.
  tensorflow::Status DeconstructTuple(
      const DeconstructTupleRequest* arg,
      DeconstructTupleResponse* result) override;
  // Specifies the return value of a computation.
  tensorflow::Status SetReturnValue(const SetReturnValueRequest* arg,
                                    SetReturnValueResponse* results) override;

  // Executes a computation with the provided global data passed as immutable
  // arguments, returning global data for the result.
  tensorflow::Status Execute(const ExecuteRequest* arg,
                             ExecuteResponse* result) override;
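  // A request/response usage sketch (illustrative; the proto field shown is
  // an assumption about ExecuteRequest, not a guarantee from this header):
  //
  //   ExecuteRequest request;
  //   *request.mutable_computation() = computation_handle;
  //   ExecuteResponse response;
  //   tensorflow::Status status = service->Execute(&request, &response);
  //   // On success, the response holds a GlobalDataHandle for the result.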
  // Executes a computation given as an HLO module proto (a computation
  // graph), rather than as a tracked computation handle.
  tensorflow::Status ExecuteGraph(const ExecuteGraphRequest* arg,
                                  ExecuteResponse* result) override;

  // Executes one or more computations in parallel with the provided global
  // data passed as immutable arguments.
  tensorflow::Status ExecuteParallel(const ExecuteParallelRequest* arg,
                                     ExecuteParallelResponse* result) override;

  // Graph-based counterpart of ExecuteParallel.
  tensorflow::Status ExecuteGraphParallel(
      const ExecuteGraphParallelRequest* arg,
      ExecuteParallelResponse* result) override;
  // Requests one or more device handles from the target. The returned handles
  // identify the device(s) on which to execute computations or transfer data.
  tensorflow::Status GetDeviceHandles(
      const GetDeviceHandlesRequest* arg,
      GetDeviceHandlesResponse* result) override;

  // Asynchronously executes a computation with the provided global data
  // passed as immutable arguments, returning a handle to the execution.
  tensorflow::Status ExecuteAsync(const ExecuteAsyncRequest* arg,
                                  ExecuteAsyncResponse* result) override;

  // Waits until the specified execution is complete and returns the result.
  tensorflow::Status WaitForExecution(
      const WaitForExecutionRequest* arg,
      WaitForExecutionResponse* result) override;
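  // An asynchronous execution sketch (illustrative): ExecuteAsync launches
  // the computation, and WaitForExecution blocks on the returned handle. The
  // execution() field on the response proto is an assumption:
  //
  //   ExecuteAsyncRequest async_request;
  //   ExecuteAsyncResponse async_response;
  //   TF_RETURN_IF_ERROR(service->ExecuteAsync(&async_request,
  //                                            &async_response));
  //   WaitForExecutionRequest wait_request;
  //   *wait_request.mutable_execution() = async_response.execution();
  //   WaitForExecutionResponse wait_response;
  //   TF_RETURN_IF_ERROR(service->WaitForExecution(&wait_request,
  //                                                &wait_response));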
  // Requests that global data be transferred to the client in literal form.
  tensorflow::Status TransferToClient(
      const TransferToClientRequest* arg,
      TransferToClientResponse* result) override;

  // Transfers data from a literal provided by the client into device memory,
  // returning a handle to the allocation.
  tensorflow::Status TransferToServer(
      const TransferToServerRequest* arg,
      TransferToServerResponse* result) override;

  // Transfers data from a literal provided by the client to the device infeed
  // buffer.
  tensorflow::Status TransferToInfeed(
      const TransferToInfeedRequest* arg,
      TransferToInfeedResponse* result) override;

  // Transfers data from the device outfeed buffer back to the client.
  tensorflow::Status TransferFromOutfeed(
      const TransferFromOutfeedRequest* arg,
      TransferFromOutfeedResponse* result) override;
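  // A host-device round-trip sketch (illustrative; the proto field names are
  // assumptions): TransferToServer uploads a literal and yields a global data
  // handle, which TransferToClient can later download back as a literal.
  //
  //   TransferToServerRequest up;
  //   *up.mutable_literal() = literal_proto;
  //   TransferToServerResponse up_result;
  //   TF_RETURN_IF_ERROR(service->TransferToServer(&up, &up_result));
  //
  //   TransferToClientRequest down;
  //   *down.mutable_data() = up_result.data();
  //   TransferToClientResponse down_result;
  //   TF_RETURN_IF_ERROR(service->TransferToClient(&down, &down_result));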
  // Resets the device, clearing all existing state on it.
  tensorflow::Status ResetDevice(const ResetDeviceRequest* arg,
                                 ResetDeviceResponse* result) override;

  // Tests whether the given computation produces a compile-time constant.
  tensorflow::Status IsConstant(const IsConstantRequest* arg,
                                IsConstantResponse* result) override;

  // Evaluates a computation that is a compile-time constant and returns the
  // resulting literal.
  tensorflow::Status ComputeConstant(const ComputeConstantRequest* arg,
                                     ComputeConstantResponse* result) override;
  // Returns the shape (with layout) of a previously-registered global data
  // handle.
  tensorflow::Status GetShape(const GetShapeRequest* arg,
                              GetShapeResponse* result) override;

  // Returns the program shape of the given computation.
  tensorflow::Status GetComputationShape(
      const GetComputationShapeRequest* arg,
      GetComputationShapeResponse* result) override;

  // Enqueues an Op on the computation.
  tensorflow::Status Op(const OpRequest* arg, OpResponse* result) override;

  // Retrieves the inferred shape for a value within a computation.
  tensorflow::Status GetLocalShape(const GetLocalShapeRequest* arg,
                                   GetLocalShapeResponse* result) override;
  // Retrieves statistics for a tracked computation.
  tensorflow::Status GetComputationStats(
      const ComputationStatsRequest* arg,
      ComputationStatsResponse* result) override;

  // Graph-based counterpart of GetComputationStats.
  tensorflow::Status GetComputationGraphStats(
      const ComputationGraphStatsRequest* arg,
      ComputationStatsResponse* result) override;

  // Snapshots a computation into serializable protocol buffer form, so that
  // it can be reloaded via LoadComputationSnapshot.
  tensorflow::Status SnapshotComputation(
      const SnapshotComputationRequest* arg,
      SnapshotComputationResponse* result) override;
  // Loads a computation from a snapshot created via SnapshotComputation.
  tensorflow::Status LoadComputationSnapshot(
      const LoadComputationSnapshotRequest* arg,
      LoadComputationSnapshotResponse* result) override;

  // Creates a unique channel handle for use by Send/Recv instructions.
  tensorflow::Status CreateChannelHandle(
      const CreateChannelHandleRequest* arg,
      CreateChannelHandleResponse* result) override;
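  // A channel-creation sketch (illustrative): the returned handle is what a
  // matching Send/Recv instruction pair shares. The channel() accessor is an
  // assumption about the response proto:
  //
  //   CreateChannelHandleRequest channel_request;
  //   CreateChannelHandleResponse channel_response;
  //   TF_RETURN_IF_ERROR(
  //       service->CreateChannelHandle(&channel_request, &channel_response));
  //   // channel_response.channel() identifies the Send/Recv pair.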
  // Returns the tracker of computations registered with this service.
  const ComputationTracker& computation_tracker() const {
    return computation_tracker_;
  }

  // Returns the backend used to compile and execute computations.
  const Backend& backend() const { return *execute_backend_; }
  Backend* mutable_backend() { return execute_backend_.get(); }
  // Creates an HloModuleConfig for the given program shape, arguments, and
  // execution options.
  StatusOr<std::unique_ptr<HloModuleConfig>> CreateModuleConfig(
      const ProgramShape& program_shape,
      tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
      const ExecutionOptions& execution_options);

  // Picks a parallel response and fills the result.
  Status PickParallelResponse(const ExecuteParallelResponse& parallel_result,
                              ExecuteResponse* result);
  // Returns the stream executors assigned to the replicas represented by the
  // given execution options and request index.
  StatusOr<std::vector<perftools::gputools::StreamExecutor*>> GetExecutors(
      const ExecutionOptions& execution_options, int64 requests_size,
      int64 request_index) const;

  // Resolves the given argument handles into per-replica ShapedBuffers.
  StatusOr<std::vector<std::vector<const ShapedBuffer*>>> GetArguments(
      const ExecutionOptions& execution_options,
      tensorflow::gtl::ArraySlice<const GlobalDataHandle*> arguments);
 private:
  friend class LocalExecutable;

  // The constructor is private; use the NewService factory to create new
  // service objects.
  Service(const ServiceOptions& options,
          std::unique_ptr<Backend> execute_backend);
  // Resolves the given argument handles in the allocation tracker and returns
  // the corresponding allocations for every replica, validating them against
  // the given stream executors.
  StatusOr<std::vector<std::vector<const ShapedBuffer*>>>
  ResolveAndValidateArguments(
      tensorflow::gtl::ArraySlice<const GlobalDataHandle*> arguments,
      tensorflow::gtl::ArraySlice<perftools::gputools::StreamExecutor*>
          stream_executors);

  // Overload of CreateModuleConfig that takes argument shapes rather than
  // resolved allocations.
  StatusOr<std::unique_ptr<HloModuleConfig>> CreateModuleConfig(
      const ProgramShape& program_shape,
      tensorflow::gtl::ArraySlice<const Shape*> argument_shapes,
      const ExecutionOptions* execution_options);
  // Builds an Executable for the given parameters. If device_allocator is not
  // null, the compiler may use it to allocate device memory while compiling
  // (e.g. for autotuning).
  StatusOr<std::unique_ptr<Executable>> BuildExecutable(
      const VersionedComputationHandle& versioned_handle,
      std::unique_ptr<HloModuleConfig> module_config, Backend* backend,
      perftools::gputools::StreamExecutor* executor,
      DeviceMemoryAllocator* device_allocator = nullptr);

  // Overload of BuildExecutable that builds from an HLO module proto rather
  // than a versioned computation handle.
  StatusOr<std::unique_ptr<Executable>> BuildExecutable(
      const HloModuleProto& module_proto,
      std::unique_ptr<HloModuleConfig> module_config, Backend* backend,
      perftools::gputools::StreamExecutor* executor,
      DeviceMemoryAllocator* device_allocator = nullptr);
  // Builds one Executable per handle/config pair, compiling on the given
  // backend and executors.
  StatusOr<std::vector<std::unique_ptr<Executable>>> BuildExecutables(
      std::vector<VersionedComputationHandle> versioned_handles,
      std::vector<std::unique_ptr<HloModuleConfig>> module_configs,
      Backend* backend,
      std::vector<std::vector<perftools::gputools::StreamExecutor*>> executors,
      DeviceMemoryAllocator* device_allocator);

  // Overload of BuildExecutables that builds from HLO module protos.
  StatusOr<std::vector<std::unique_ptr<Executable>>> BuildExecutables(
      const std::vector<const HloModuleProto*>& module_protos,
      std::vector<std::unique_ptr<HloModuleConfig>> module_configs,
      Backend* backend,
      std::vector<std::vector<perftools::gputools::StreamExecutor*>> executors,
      DeviceMemoryAllocator* device_allocator);
  // Similar to BuildExecutable, but returns a cached Executable when an
  // equivalent one has already been built, and caches newly built executables
  // for later reuse.
  StatusOr<std::shared_ptr<Executable>> BuildAndCacheExecutable(
      const VersionedComputationHandle& versioned_handle,
      std::unique_ptr<HloModuleConfig> module_config, Backend* backend,
      perftools::gputools::StreamExecutor* executor, ExecutionProfile* profile,
      DeviceMemoryAllocator* device_allocator = nullptr);
  // Runs the given executable with the given arguments and registers the
  // result in the allocation tracker, returning a handle through which the
  // result can later be retrieved or deallocated.
  StatusOr<GlobalDataHandle> ExecuteAndRegisterResult(
      Executable* executable,
      const tensorflow::gtl::ArraySlice<std::vector<const ShapedBuffer*>>
          arguments,
      Backend* backend, const string& result_tag, ExecutionProfile* profile);

  // Runs the given executables in parallel with the given arguments and
  // registers each result in the allocation tracker.
  StatusOr<std::vector<GlobalDataHandle>> ExecuteParallelAndRegisterResult(
      tensorflow::gtl::ArraySlice<Executable*> executables,
      tensorflow::gtl::ArraySlice<std::vector<std::vector<const ShapedBuffer*>>>
          arguments,
      Backend* backend,
      tensorflow::gtl::ArraySlice<DeviceHandle> device_handles,
      tensorflow::gtl::ArraySlice<string> result_tags,
      ExecutionProfile* profile);
  // Convenience function for adding an instruction to a user computation via
  // the given adder callback.
  template <typename RequestT, typename ResponseT>
  tensorflow::Status AddInstruction(
      const RequestT* arg, ResponseT* result,
      const std::function<StatusOr<ComputationDataHandle>(UserComputation*)>&
          adder);

  // Executes a single computation that has more than one target device; all
  // devices but one are expected to return an empty tuple, and the remaining
  // device returns the result of the computation.
  tensorflow::Status ExecuteOneToN(const ExecuteRequest* arg,
                                   ExecuteResponse* result);
  tensorflow::Status ExecuteOneToN(const ExecuteGraphRequest* arg,
                                   ExecuteResponse* result);
  // Checks that the given shape-with-layout is valid for the given
  // computation result shape.
  tensorflow::Status ValidateResultShapeWithLayout(
      const Shape& shape_with_layout, const Shape& result_shape) const;

  // Returns the stream executors for the replicas of the given device handle.
  StatusOr<std::vector<perftools::gputools::StreamExecutor*>> Replicas(
      const Backend& backend, const DeviceHandle& device_handle) const;

  // Dumps the given HLO module if dumping is enabled via the debug options.
  Status MaybeDumpHloModule(const HloModule& module) const;

  // Returns the device handle that represents the replicated device for a
  // single computation that is not model-parallelized.
  DeviceHandle SingleComputationDeviceHandle() const;
  ServiceOptions options_;

  // Tracks computations for the XLA service; computations are registered with
  // a xla::UserComputation instance.
  ComputationTracker computation_tracker_;

  // Tracks channels created via the API.
  ChannelTracker channel_tracker_;

  // Tracks allocations made via the API and computation execution.
  AllocationTracker allocation_tracker_;

  // Tracks asynchronously launched executions.
  ExecutionTracker execution_tracker_;

  // Cache containing previously built Executables.
  CompilationCache compilation_cache_;

  // Backend to compile and execute computations on.
  std::unique_ptr<Backend> execute_backend_;

  TF_DISALLOW_COPY_AND_ASSIGN(Service);
};

}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_SERVICE_H_