14 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_USER_COMPUTATION_H_ 15 #define TENSORFLOW_COMPILER_XLA_SERVICE_USER_COMPUTATION_H_ 22 #include "tensorflow/compiler/xla/service/session.pb.h" 24 #include "tensorflow/compiler/xla/statusor.h" 25 #include "tensorflow/compiler/xla/types.h" 26 #include "tensorflow/compiler/xla/xla.pb.h" 27 #include "tensorflow/compiler/xla/xla_data.pb.h" 28 #include "tensorflow/core/platform/macros.h" 29 #include "tensorflow/core/platform/mutex.h" 30 #include "tensorflow/core/platform/thread_annotations.h" 31 #include "tensorflow/core/platform/types.h" 57 static StatusOr<std::unique_ptr<UserComputation>> MakeWithRemapping(
58 const SessionComputation& session_computation,
59 const ComputationHandle& handle,
60 const std::map<int64, ComputationHandle>& old_to_new);
// Constructor: builds a UserComputation from a name and a ComputationHandle,
// both taken by const reference. Declared `explicit`.
// NOTE(review): the definition is not visible in this header; presumably the
// name is stored in name_ (returned by name()) -- confirm in the .cc file.
62 explicit UserComputation(
const string& name,
const ComputationHandle& handle);
67 const ParameterRequest& parameter_request);
// Instruction-adding entry points (declarations only; bodies live elsewhere).
// Each takes its request message by const reference. Every function in this
// run returns StatusOr<ComputationDataHandle>, except AddTraceInstruction,
// which returns a bare Status.
// NOTE(review): presumably each call enqueues the named instruction onto this
// computation and hands back the handle of the new operation -- confirm
// against the implementation.
// Pad: takes a PadRequest.
69 StatusOr<ComputationDataHandle> AddPadInstruction(
70 const PadRequest& pad_request);
// Trace: takes a TraceRequest; the only one here that returns Status rather
// than a data handle.
73 Status AddTraceInstruction(
const TraceRequest& trace_request);
// Rng: takes an RngRequest.
75 StatusOr<ComputationDataHandle> AddRngInstruction(
76 const RngRequest& rng_request);
// Unary op: takes a UnaryOpRequest.
79 StatusOr<ComputationDataHandle> AddUnaryInstruction(
80 const UnaryOpRequest& unary_request);
// Batch-norm training: takes a BatchNormTrainingRequest.
82 StatusOr<ComputationDataHandle> AddBatchNormTrainingInstruction(
83 const BatchNormTrainingRequest& batch_norm_training_request);
// Batch-norm inference: takes a BatchNormInferenceRequest.
85 StatusOr<ComputationDataHandle> AddBatchNormInferenceInstruction(
86 const BatchNormInferenceRequest& batch_norm_inference_request);
// Batch-norm gradient: takes a BatchNormGradRequest.
88 StatusOr<ComputationDataHandle> AddBatchNormGradInstruction(
89 const BatchNormGradRequest& batch_norm_grad_request);
// Binary op: takes a BinaryOpRequest.
92 StatusOr<ComputationDataHandle> AddBinaryInstruction(
93 const BinaryOpRequest& binary_request);
// Ternary op: takes a TernaryOpRequest.
96 StatusOr<ComputationDataHandle> AddTernaryInstruction(
97 const TernaryOpRequest& ternary_request);
// Variadic op: takes a VariadicOpRequest.
100 StatusOr<ComputationDataHandle> AddVariadicInstruction(
101 const VariadicOpRequest& variadic_request);
// Constant: takes a ConstantRequest.
103 StatusOr<ComputationDataHandle> AddConstantInstruction(
104 const ConstantRequest& constant_request);
// Get-tuple-element: takes a GetTupleElementRequest.
106 StatusOr<ComputationDataHandle> AddGetTupleElementInstruction(
107 const GetTupleElementRequest& get_tuple_element_request);
109 StatusOr<ComputationDataHandle> AddMapInstruction(
110 const MapRequest& map_request,
// More instruction-adding entry points (declarations only). Each takes one
// request message by const reference and returns
// StatusOr<ComputationDataHandle>.
// NOTE(review): presumably each enqueues the named instruction onto this
// computation -- the bodies are not visible here; confirm in the .cc file.
// Reduce-precision: takes a ReducePrecisionRequest.
113 StatusOr<ComputationDataHandle> AddReducePrecisionInstruction(
114 const ReducePrecisionRequest& reduce_precision_request);
// Convolution: takes a ConvolveRequest.
116 StatusOr<ComputationDataHandle> AddConvolveInstruction(
117 const ConvolveRequest& convolve_request);
// FFT: takes an FftRequest.
119 StatusOr<ComputationDataHandle> AddFftInstruction(
120 const FftRequest& fft_request);
// Cross-replica sum: takes a CrossReplicaSumRequest.
122 StatusOr<ComputationDataHandle> AddCrossReplicaSumInstruction(
123 const CrossReplicaSumRequest& cross_replica_sum_request);
// Infeed: takes an InfeedRequest.
125 StatusOr<ComputationDataHandle> AddInfeedInstruction(
126 const InfeedRequest& infeed_request);
// Outfeed: takes an OutfeedRequest.
128 StatusOr<ComputationDataHandle> AddOutfeedInstruction(
129 const OutfeedRequest& outfeed_request);
// Host compute: takes a HostComputeRequest.
131 StatusOr<ComputationDataHandle> AddHostComputeInstruction(
132 const HostComputeRequest& host_compute_request);
134 StatusOr<ComputationDataHandle> AddCallInstruction(
135 const CallRequest& call_request,
// More instruction-adding entry points (declarations only). Each takes one
// request message by const reference and returns
// StatusOr<ComputationDataHandle>.
// NOTE(review): presumably each enqueues the named instruction onto this
// computation -- the bodies are not visible here; confirm in the .cc file.
// Custom call: takes a CustomCallRequest.
138 StatusOr<ComputationDataHandle> AddCustomCallInstruction(
139 const CustomCallRequest& custom_call_request);
// Dot: takes a DotRequest.
141 StatusOr<ComputationDataHandle> AddDotInstruction(
142 const DotRequest& dot_request);
// Broadcast: takes a BroadcastRequest.
144 StatusOr<ComputationDataHandle> AddBroadcastInstruction(
145 const BroadcastRequest& broadcast_request);
// Reshape: takes a ReshapeRequest.
147 StatusOr<ComputationDataHandle> AddReshapeInstruction(
148 const ReshapeRequest& reshape_request);
// Transpose: takes a TransposeRequest.
150 StatusOr<ComputationDataHandle> AddTransposeInstruction(
151 const TransposeRequest& transpose_request);
// Slice: takes a SliceRequest.
153 StatusOr<ComputationDataHandle> AddSliceInstruction(
154 const SliceRequest& slice_request);
// Dynamic slice: takes a DynamicSliceRequest.
156 StatusOr<ComputationDataHandle> AddDynamicSliceInstruction(
157 const DynamicSliceRequest& dynamic_slice_request);
// Dynamic update slice: takes a DynamicUpdateSliceRequest.
159 StatusOr<ComputationDataHandle> AddDynamicUpdateSliceInstruction(
160 const DynamicUpdateSliceRequest& dynamic_update_slice_request);
// Concatenate: takes a ConcatenateRequest.
162 StatusOr<ComputationDataHandle> AddConcatenateInstruction(
163 const ConcatenateRequest& concatenate_request);
// Convert: takes a ConvertRequest.
165 StatusOr<ComputationDataHandle> AddConvertInstruction(
166 const ConvertRequest& convert_request);
// Bitcast-convert: note that this takes the same ConvertRequest type as
// AddConvertInstruction; there is no separate request message in the signature.
168 StatusOr<ComputationDataHandle> AddBitcastConvertInstruction(
169 const ConvertRequest& convert_request);
171 StatusOr<ComputationDataHandle> AddReduceInstruction(
172 const ReduceRequest& reduce_request,
175 StatusOr<ComputationDataHandle> AddReduceWindowInstruction(
176 const ReduceWindowRequest& reduce_window_request,
180 StatusOr<ComputationDataHandle> AddSelectAndScatterInstruction(
181 const SelectAndScatterRequest& select_and_scatter_request,
// Reverse: takes a ReverseRequest by const reference and returns
// StatusOr<ComputationDataHandle>.
// NOTE(review): presumably enqueues a reverse instruction onto this
// computation -- the body is not visible here; confirm in the .cc file.
185 StatusOr<ComputationDataHandle> AddReverseInstruction(
186 const ReverseRequest& reverse_request);
188 StatusOr<ComputationDataHandle> AddWhileInstruction(
189 const WhileRequest& while_request,
193 StatusOr<ComputationDataHandle> AddConditionalInstruction(
194 const ConditionalRequest& conditional_request,
// Send / Recv / Gather entry points (declarations only). Each takes one
// request message by const reference and returns
// StatusOr<ComputationDataHandle>.
// NOTE(review): presumably each enqueues the named instruction onto this
// computation -- the bodies are not visible here; confirm in the .cc file.
// Send: takes a SendRequest.
198 StatusOr<ComputationDataHandle> AddSendInstruction(
199 const SendRequest& send_request);
// Recv: takes a RecvRequest.
201 StatusOr<ComputationDataHandle> AddRecvInstruction(
202 const RecvRequest& recv_request);
// Gather: takes a GatherRequest.
204 StatusOr<ComputationDataHandle> AddGatherInstruction(
205 const GatherRequest& gather_request);
// Inline const accessor: returns a const reference to the name_ member.
// The returned reference aliases the member, so it is only valid while this
// object is alive.
208 const string& name()
const {
return name_; }
// Takes a ComputationDataHandle by const reference and returns a Status.
// NOTE(review): presumably records `handle` as the computation's result
// (compare the handle_to_return_ member) -- body not visible; confirm.
212 Status SetReturnValue(
const ComputationDataHandle& handle);
218 const ComputationDataHandle& operation)
const;
// Const member returning a VersionedComputationHandle::Version by value.
// NOTE(review): presumably the current version of this computation; how the
// version is derived is not visible in this header -- confirm in the .cc file.
221 VersionedComputationHandle::Version version()
const;
// Const member: given a version, returns a shared pointer to an immutable
// ProgramShape wrapped in StatusOr.
// NOTE(review): the class has mutable program_shape_ and
// program_shape_version_ members, which suggests the result is cached across
// calls -- confirm in the implementation.
229 StatusOr<std::shared_ptr<const ProgramShape>> ComputeProgramShape(
230 VersionedComputationHandle::Version version)
const;
// Takes a data handle and an int64 parameter count; returns StatusOr<bool>.
// Declared non-const.
// NOTE(review): the exact meaning of `num_parameters` is not visible from
// this declaration -- check the implementation before relying on it.
234 StatusOr<bool> IsConstant(
const ComputationDataHandle& handle,
235 int64 num_parameters);
// Takes a data handle by const reference and returns a Shape by value,
// wrapped in StatusOr. Declared non-const.
// NOTE(review): presumably the shape of the operation identified by `handle`
// -- confirm in the implementation.
237 StatusOr<Shape> GetShape(
const ComputationDataHandle& handle);
// Takes a data handle and an OpMetadata, both by const reference; returns a
// Status.
// NOTE(review): presumably attaches `metadata` to the operation identified by
// `handle` -- confirm in the implementation.
239 Status SetOpMetadata(
const ComputationDataHandle& handle,
240 const OpMetadata& metadata);
// Takes a data handle and an OpSharding, both by const reference; returns a
// Status. Same signature shape as SetOpMetadata.
// NOTE(review): presumably attaches `sharding` to the operation identified by
// `handle` -- confirm in the implementation.
242 Status SetOpSharding(
const ComputationDataHandle& handle,
243 const OpSharding& sharding);
// Callback type: given a const VersionedComputationHandle&, returns a raw
// HloComputation*. It is the type of the `hlo_resolver` parameter taken by
// BuildHloComputation.
// NOTE(review): ownership and nullability of the returned pointer are not
// expressed in the type -- confirm the contract with the callers.
252 using HloComputationResolver =
253 std::function<HloComputation*(const VersionedComputationHandle& handle)>;
255 VersionedComputationHandle::Version version,
256 HloComputationResolver hlo_resolver,
const DebugOptions& debug_options,
257 bool include_unreachable_instructions =
true)
const;
// Const member: given a version, returns a vector of
// VersionedComputationHandle by value.
// NOTE(review): presumably the computations referenced by this computation's
// requests up to `version` -- confirm in the implementation.
262 std::vector<VersionedComputationHandle> GetEmbeddedComputations(
263 VersionedComputationHandle::Version version)
const;
267 int64 request_count(VersionedComputationHandle::Version version)
const {
// Const member: given a version, returns a SessionComputation by value (a
// copy, not a reference into this object).
// NOTE(review): how `version` limits what is copied is not visible here --
// confirm in the implementation.
274 SessionComputation CloneSessionComputation(
275 VersionedComputationHandle::Version version)
const;
// Const member: maps a data handle to a pointer-to-const OperationRequest,
// wrapped in StatusOr.
// Unlike LookUpRequest, this declaration carries no
// EXCLUSIVE_LOCKS_REQUIRED(mutex_) annotation.
// NOTE(review): presumably it acquires mutex_ itself -- confirm in the .cc.
282 StatusOr<const OperationRequest*> LookUpRequestForErrorReporting(
283 const ComputationDataHandle& handle)
const;
// Const member: takes an int parameter number and returns an optional
// holding a pointer-to-const OpMetadata.
// NOTE(review): the conditions under which the optional is empty (versus
// holding a null pointer) are not visible here -- confirm in the .cc file.
289 tensorflow::gtl::optional<const OpMetadata*> ParameterMetadata(
290 int parameter_number)
const;
// Takes a map from int64 to ComputationHandle (`old_to_new`) and returns a
// Status. The annotation requires the caller to already hold mutex_
// exclusively.
// NOTE(review): presumably rewrites embedded-computation handles using the
// int64 keys as the old handle values -- confirm in the implementation.
300 Status RemapEmbeddedComputations(
301 const std::map<int64, ComputationHandle>& old_to_new)
302 EXCLUSIVE_LOCKS_REQUIRED(mutex_);
// Const member: maps a data handle to a pointer-to-const OperationRequest,
// wrapped in StatusOr. The annotation requires the caller to already hold
// mutex_ exclusively; LookUpRequestForErrorReporting has the same signature
// without that requirement.
304 StatusOr<const OperationRequest*> LookUpRequest(
305 const ComputationDataHandle& handle)
const 306 EXCLUSIVE_LOCKS_REQUIRED(mutex_);
// Returns a ComputationDataHandle by value. Non-const, and the annotation
// requires the caller to already hold mutex_ exclusively.
// NOTE(review): presumably allocates a fresh handle from next_handle_value_
// -- the body is not visible here; confirm in the .cc file.
308 ComputationDataHandle CreateComputationDataHandle()
309 EXCLUSIVE_LOCKS_REQUIRED(mutex_);
// Const member: takes a version and returns a Status. The annotation requires
// the caller to already hold mutex_ exclusively.
// NOTE(review): presumably verifies that parameter numbers form a gap-free
// sequence at `version` -- the exact rule is not visible here; confirm.
312 Status CheckParametersAreContiguous(
313 VersionedComputationHandle::Version version)
const 314 EXCLUSIVE_LOCKS_REQUIRED(mutex_);
316 EXCLUSIVE_LOCKS_REQUIRED(mutex_);
// Data members. Everything below except the mutex itself is annotated
// GUARDED_BY(mutex_), i.e. it must only be accessed with mutex_ held.
// The mutex is `mutable` so that const member functions can lock it.
319 mutable tensorflow::mutex mutex_;
// The SessionComputation held by this object.
321 SessionComputation session_computation_ GUARDED_BY(mutex_);
// Map from an int64 key to a raw OperationRequest pointer.
// NOTE(review): presumably keyed by parameter number, with non-owning
// pointers into session_computation_ -- confirm in the implementation.
324 std::map<int64, OperationRequest*> parameters_ GUARDED_BY(mutex_);
// NOTE(review): presumably the value used for the next data handle created.
327 int64 next_handle_value_ GUARDED_BY(mutex_);
// NOTE(review): presumably the handle set through SetReturnValue -- confirm.
331 ComputationDataHandle handle_to_return_ GUARDED_BY(mutex_);
// Both members below are `mutable`, so const member functions may update
// them; program_shape_version_ starts at 0.
// NOTE(review): looks like a cache for ComputeProgramShape -- confirm.
334 mutable int64 program_shape_version_ GUARDED_BY(mutex_) = 0;
335 mutable std::shared_ptr<const ProgramShape> program_shape_ GUARDED_BY(mutex_);
339 #endif // TENSORFLOW_COMPILER_XLA_SERVICE_USER_COMPUTATION_H_
Definition: versioned_computation_handle.h:37
Definition: user_computation.h:49
StatusOr< std::unique_ptr< HloComputation > > BuildHloComputation(VersionedComputationHandle::Version version, HloComputationResolver hlo_resolver, const DebugOptions &debug_options, bool include_unreachable_instructions=true) const
Build an HLO computation from the UserComputation.
Definition: user_computation.cc:3102
StatusOr< ComputationDataHandle > AddParameterInstruction(const ParameterRequest &parameter_request)
Enqueue a parameter-retrieving instruction onto this UserComputation.
Definition: user_computation.cc:195
namespace for xla
Definition: client_library.cc:26
VersionedComputationHandle GetVersionedHandleInternal() const EXCLUSIVE_LOCKS_REQUIRED(mutex_)
Get the VersionedComputationHandle recorded in UserComputation object.
Definition: user_computation.cc:1263
VersionedComputationHandle GetVersionedHandle() const
Lock mutex and call GetVersionedHandleInternal()
Definition: user_computation.cc:1253