tf_1.8_xla_doc
user_computation.h
Go to the documentation of this file.
1 
3 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
4 Licensed under the Apache License, Version 2.0 (the "License");
5 you may not use this file except in compliance with the License.
6 You may obtain a copy of the License at
7  http://www.apache.org/licenses/LICENSE-2.0
8 Unless required by applicable law or agreed to in writing, software
9 distributed under the License is distributed on an "AS IS" BASIS,
10 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 See the License for the specific language governing permissions and
12 limitations under the License.
13 ==============================================================================*/
14 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_USER_COMPUTATION_H_
15 #define TENSORFLOW_COMPILER_XLA_SERVICE_USER_COMPUTATION_H_
16 #include <functional>
17 #include <map>
18 #include <memory>
19 #include <string>
20 #include <vector>
22 #include "tensorflow/compiler/xla/service/session.pb.h"
24 #include "tensorflow/compiler/xla/statusor.h"
25 #include "tensorflow/compiler/xla/types.h"
26 #include "tensorflow/compiler/xla/xla.pb.h"
27 #include "tensorflow/compiler/xla/xla_data.pb.h"
28 #include "tensorflow/core/platform/macros.h"
29 #include "tensorflow/core/platform/mutex.h"
30 #include "tensorflow/core/platform/thread_annotations.h"
31 #include "tensorflow/core/platform/types.h"
35 namespace xla {
49 class UserComputation {
50  public:
51  // Factory used when restoring a computation from serialized session
52  // computation (computation snapshot) data. Remaps any references to
53  // computation handles via the old_to_new mapping.
54  //
55  // An error will occur if the old_to_new mapping cannot resolve a reference to
56  // a computation that is present in session_computation.
57  static StatusOr<std::unique_ptr<UserComputation>> MakeWithRemapping(
58  const SessionComputation& session_computation,
59  const ComputationHandle& handle,
60  const std::map<int64, ComputationHandle>& old_to_new);
61  // Creates an empty computation with the given name and computation handle.
62  explicit UserComputation(const string& name, const ComputationHandle& handle);
63  // Enqueues a parameter-retrieving instruction onto this user computation.
64  // Returns an error status if the parameter number is already registered with
65  // different values.
66  StatusOr<ComputationDataHandle> AddParameterInstruction(
67  const ParameterRequest& parameter_request);
68  // Enqueues a pad instruction onto this user computation.
69  StatusOr<ComputationDataHandle> AddPadInstruction(
70  const PadRequest& pad_request);
71  // Enqueues a tracing instruction onto this user computation.
72  // Returns an error status if the operand cannot be resolved.
73  Status AddTraceInstruction(const TraceRequest& trace_request);
74  // Enqueues a random number generation instruction onto this user computation.
75  StatusOr<ComputationDataHandle> AddRngInstruction(
76  const RngRequest& rng_request);
77  // Enqueues a unary instruction onto this user computation.
78  // Returns an error status if the operand index is out of bounds.
79  StatusOr<ComputationDataHandle> AddUnaryInstruction(
80  const UnaryOpRequest& unary_request);
81  // Enqueues a batch norm training instruction onto this user computation.
82  StatusOr<ComputationDataHandle> AddBatchNormTrainingInstruction(
83  const BatchNormTrainingRequest& batch_norm_training_request);
84  // Enqueues a batch norm inference instruction onto this user computation.
85  StatusOr<ComputationDataHandle> AddBatchNormInferenceInstruction(
86  const BatchNormInferenceRequest& batch_norm_inference_request);
87  // Enqueues a batch norm grad instruction onto this user computation.
88  StatusOr<ComputationDataHandle> AddBatchNormGradInstruction(
89  const BatchNormGradRequest& batch_norm_grad_request);
90  // Enqueues a binary instruction onto this user computation.
91  // Returns an error status if the operand indices are out of bounds.
92  StatusOr<ComputationDataHandle> AddBinaryInstruction(
93  const BinaryOpRequest& binary_request);
94  // Enqueues a ternary instruction onto this user computation.
95  // Returns an error status if the operand indices are out of bounds.
96  StatusOr<ComputationDataHandle> AddTernaryInstruction(
97  const TernaryOpRequest& ternary_request);
98  // Enqueues a variadic instruction onto this user computation.
99  // Returns an error status if the operand indices are out of bounds.
100  StatusOr<ComputationDataHandle> AddVariadicInstruction(
101  const VariadicOpRequest& variadic_request);
102  // Enqueues a constant instruction onto this user computation.
103  StatusOr<ComputationDataHandle> AddConstantInstruction(
104  const ConstantRequest& constant_request);
105  // Enqueues a get tuple element instruction onto this user computation.
106  StatusOr<ComputationDataHandle> AddGetTupleElementInstruction(
107  const GetTupleElementRequest& get_tuple_element_request);
108  // Enqueues a map instruction onto this user computation.
109  StatusOr<ComputationDataHandle> AddMapInstruction(
110  const MapRequest& map_request,
111  const UserComputation& to_apply_computation);
112  // Enqueues a reduce-precision instruction onto this user computation.
113  StatusOr<ComputationDataHandle> AddReducePrecisionInstruction(
114  const ReducePrecisionRequest& reduce_precision_request);
115  // Enqueues a convolution instruction onto this user computation.
116  StatusOr<ComputationDataHandle> AddConvolveInstruction(
117  const ConvolveRequest& convolve_request);
118  // Enqueues an FFT instruction onto this user computation.
119  StatusOr<ComputationDataHandle> AddFftInstruction(
120  const FftRequest& fft_request);
121  // Enqueues a cross replica sum instruction onto this user computation.
122  StatusOr<ComputationDataHandle> AddCrossReplicaSumInstruction(
123  const CrossReplicaSumRequest& cross_replica_sum_request);
124  // Enqueues an infeed instruction onto this user computation.
125  StatusOr<ComputationDataHandle> AddInfeedInstruction(
126  const InfeedRequest& infeed_request);
127  // Enqueues an outfeed instruction onto this user computation.
128  StatusOr<ComputationDataHandle> AddOutfeedInstruction(
129  const OutfeedRequest& outfeed_request);
130  // Enqueues a host compute instruction onto this user computation.
131  StatusOr<ComputationDataHandle> AddHostComputeInstruction(
132  const HostComputeRequest& host_compute_request);
133  // Enqueues a call instruction onto this user computation.
134  StatusOr<ComputationDataHandle> AddCallInstruction(
135  const CallRequest& call_request,
136  const UserComputation& to_apply_computation);
137  // Enqueues a custom call instruction onto this user computation.
138  StatusOr<ComputationDataHandle> AddCustomCallInstruction(
139  const CustomCallRequest& custom_call_request);
140  // Enqueues a dot instruction onto this user computation.
141  StatusOr<ComputationDataHandle> AddDotInstruction(
142  const DotRequest& dot_request);
143  // Enqueues a broadcast instruction onto this user computation.
144  StatusOr<ComputationDataHandle> AddBroadcastInstruction(
145  const BroadcastRequest& broadcast_request);
146  // Enqueues a reshape instruction onto this user computation.
147  StatusOr<ComputationDataHandle> AddReshapeInstruction(
148  const ReshapeRequest& reshape_request);
149  // Enqueues a transpose instruction onto this user computation.
150  StatusOr<ComputationDataHandle> AddTransposeInstruction(
151  const TransposeRequest& transpose_request);
152  // Enqueues a slice instruction onto this user computation.
153  StatusOr<ComputationDataHandle> AddSliceInstruction(
154  const SliceRequest& slice_request);
155  // Enqueues a dynamic slice instruction onto this user computation.
156  StatusOr<ComputationDataHandle> AddDynamicSliceInstruction(
157  const DynamicSliceRequest& dynamic_slice_request);
158  // Enqueues a dynamic update slice instruction onto this user computation.
159  StatusOr<ComputationDataHandle> AddDynamicUpdateSliceInstruction(
160  const DynamicUpdateSliceRequest& dynamic_update_slice_request);
161  // Enqueues a concatenate instruction onto this user computation.
162  StatusOr<ComputationDataHandle> AddConcatenateInstruction(
163  const ConcatenateRequest& concatenate_request);
164  // Enqueues a convert instruction onto this user computation.
165  StatusOr<ComputationDataHandle> AddConvertInstruction(
166  const ConvertRequest& convert_request);
167  // Enqueues a bitcast-convert instruction onto this user computation.
168  StatusOr<ComputationDataHandle> AddBitcastConvertInstruction(
169  const ConvertRequest& convert_request);
170  // Enqueues a reduce instruction onto this user computation.
171  StatusOr<ComputationDataHandle> AddReduceInstruction(
172  const ReduceRequest& reduce_request,
173  const UserComputation& to_apply_computation);
174  // Enqueues a windowed reduce instruction onto this user computation.
175  StatusOr<ComputationDataHandle> AddReduceWindowInstruction(
176  const ReduceWindowRequest& reduce_window_request,
177  const UserComputation& to_apply_computation);
178  // Enqueues a select-and-scatter instruction onto this user
179  // computation.
180  StatusOr<ComputationDataHandle> AddSelectAndScatterInstruction(
181  const SelectAndScatterRequest& select_and_scatter_request,
182  const UserComputation& select_computation,
183  const UserComputation& scatter_computation);
184  // Enqueues a reverse instruction onto this user computation.
185  StatusOr<ComputationDataHandle> AddReverseInstruction(
186  const ReverseRequest& reverse_request);
187  // Enqueues a while instruction onto this user computation.
188  StatusOr<ComputationDataHandle> AddWhileInstruction(
189  const WhileRequest& while_request,
190  const UserComputation& condition_computation,
191  const UserComputation& body_computation);
192  // Enqueues a conditional instruction onto this user computation.
193  StatusOr<ComputationDataHandle> AddConditionalInstruction(
194  const ConditionalRequest& conditional_request,
195  const UserComputation& true_computation,
196  const UserComputation& false_computation);
197  // Enqueues a Send instruction onto this user computation.
198  StatusOr<ComputationDataHandle> AddSendInstruction(
199  const SendRequest& send_request);
200  // Enqueues a Recv instruction onto this user computation.
201  StatusOr<ComputationDataHandle> AddRecvInstruction(
202  const RecvRequest& recv_request);
203  // Enqueues a Gather instruction onto this user computation.
204  StatusOr<ComputationDataHandle> AddGatherInstruction(
205  const GatherRequest& gather_request);
206  // Returns the user-provided name of this user computation, which is provided
207  // via the XLA computation-building API.
208  const string& name() const { return name_; }
209  // Subsequent executions of this computation will compute the value
210  // represented by handle, rather than the last expression enqueued
211  // on the computation.
212  Status SetReturnValue(const ComputationDataHandle& handle);
213  // Returns a versioned handle for this computation.
214  VersionedComputationHandle GetVersionedHandle() const;
215  // Returns a versioned handle for this computation with a version equal to the
216  // point at which the given operation was added to the computation.
217  VersionedComputationHandle GetVersionedHandleAtOperation(
218  const ComputationDataHandle& operation) const;
219  // Returns a version value representing the current state of the
220  // computation.
221  VersionedComputationHandle::Version version() const;
222  // Computes and returns the program shape for the user computation -- gathers
223  // parameters and result type into a single proto. A shared_ptr is used
224  // because the returned pointer refers to an internally cached value which may
225  // be discarded by the UserComputation object. This avoids unnecessary copies.
226  //
227  // If the parameter space is not dense (i.e. there are holes in the parameter
228  // numbers provided) then an error status is returned.
229  StatusOr<std::shared_ptr<const ProgramShape>> ComputeProgramShape(
230  VersionedComputationHandle::Version version) const;
231  // Returns true if the given data handle does not depend on any parameter with
232  // index >= num_parameters. That is, the value can be computed at
233  // compile time if we know the first num_parameters arguments.
234  StatusOr<bool> IsConstant(const ComputationDataHandle& handle,
235  int64 num_parameters);
236  // Returns the output shape of the operation indicated by the given handle.
237  StatusOr<Shape> GetShape(const ComputationDataHandle& handle);
238  // Sets metadata on the Hlo instruction referenced by the given handle.
239  Status SetOpMetadata(const ComputationDataHandle& handle,
240  const OpMetadata& metadata);
241  // Sets the sharding (device assignment) on the Hlo instruction referenced by 'handle'.
242  Status SetOpSharding(const ComputationDataHandle& handle,
243  const OpSharding& sharding);
244  // Builds an HLO computation from the UserComputation. The parameter "hlo_resolver"
245  // is a function which returns a pointer to the HloComputation corresponding
246  // to the given ComputationHandle at the given version. The resolver is used
247  // for operations, such as map, which call other computations and need a
248  // pointer to the called HloComputation to construct the respective HLO
249  // instructions. If include_unreachable_instructions is true (the default), then
250  // instructions which are not reachable from the root are lowered into
251  // HloInstructions as well.
252  using HloComputationResolver =
253  std::function<HloComputation*(const VersionedComputationHandle& handle)>;
254  StatusOr<std::unique_ptr<HloComputation>> BuildHloComputation(
255  VersionedComputationHandle::Version version,
256  HloComputationResolver hlo_resolver, const DebugOptions& debug_options,
257  bool include_unreachable_instructions = true) const;
258  // Return a vector containing the embedded computations used by this
259  // UserComputation. Only embedded computations which are called directly by
260  // this UserComputation are included. That is, the transitive closure of
261  // embedded computations is not included.
262  std::vector<VersionedComputationHandle> GetEmbeddedComputations(
263  VersionedComputationHandle::Version version) const;
264  // Returns the number of OperationRequest objects in the given version of this
265  // UserComputation. The 'version' of a computation is identical to the number
266  // of OperationRequests in the UserComputation, so this simply returns it.
267  int64 request_count(VersionedComputationHandle::Version version) const {
268  return version;
269  }
270  // Returns a copy of the internal session state for this computation -- this
271  // is useful for serializing the guts of a user computation, though references
272  // to other handles (e.g. referred-to computations) must be handled with care
273  // in the serialization / de-serialization process.
274  SessionComputation CloneSessionComputation(
275  VersionedComputationHandle::Version version) const;
276  // Warning: typically we don't want to look up computation data handles until
277  // the computation has finished being built, for consistency purposes. We
278  // expose this routine for error reporting purposes so that we can provide
279  // more meaningful error messages from the XLA service layer.
280  //
281  // Returns the operation request that the handle comes from.
282  StatusOr<const OperationRequest*> LookUpRequestForErrorReporting(
283  const ComputationDataHandle& handle) const;
284  // Retrieves the parameter metadata for the given parameter number.
285  //
286  // If the parameter number is invalid for this computation, nullopt is
287  // returned. When the return value has_value(), nullptr will never be
288  // the held value.
289  tensorflow::gtl::optional<const OpMetadata*> ParameterMetadata(
290  int parameter_number) const;
291  private:
292  // Warning: dangerous mutating operation that doesn't respect versioning.
293  // This is only used at initialization time when constructing from a
294  // SessionComputation a la MakeWithRemapping.
295  //
296  // Remaps references to old computations (with handle values in the keys of
297  // old_to_new) to the computation handle given in the values. This is useful
298  // when loading computations from snapshots, to finish initialization, before
299  // the user computation is released into the wild.
300  Status RemapEmbeddedComputations(
301  const std::map<int64, ComputationHandle>& old_to_new)
302  EXCLUSIVE_LOCKS_REQUIRED(mutex_);
303  // Returns the OperationRequest corresponding to the given handle.
304  StatusOr<const OperationRequest*> LookUpRequest(
305  const ComputationDataHandle& handle) const
306  EXCLUSIVE_LOCKS_REQUIRED(mutex_);
307  // Creates a new ComputationDataHandle with the next available handle value.
308  ComputationDataHandle CreateComputationDataHandle()
309  EXCLUSIVE_LOCKS_REQUIRED(mutex_);
310  // Checks whether the parameter numbers of the parameter operations are
311  // contiguous starting from zero. Returns an error status if not.
312  Status CheckParametersAreContiguous(
313  VersionedComputationHandle::Version version) const
314  EXCLUSIVE_LOCKS_REQUIRED(mutex_);
315  VersionedComputationHandle GetVersionedHandleInternal() const
316  EXCLUSIVE_LOCKS_REQUIRED(mutex_);  // Unlocked variant; GetVersionedHandle() locks mutex_ and calls this.
317  // Name of the computation.
318  string name_;
319  mutable tensorflow::mutex mutex_;  // Guards the GUARDED_BY(mutex_) members below; mutable so const methods can lock it.
320  // State of the computation as a record of all operation-building requests.
321  SessionComputation session_computation_ GUARDED_BY(mutex_);
322  // Mapping from parameter number to operation request containing the
323  // respective ParameterRequest.
324  std::map<int64, OperationRequest*> parameters_ GUARDED_BY(mutex_);
325  // The next ComputationDataHandle value to assign. Handle values are assigned
326  // sequentially.
327  int64 next_handle_value_ GUARDED_BY(mutex_);
328  // If handle_to_return_.has_handle() then an Execution of this Computation
329  // will compute the value represented by handle_to_return_, otherwise it will
330  // compute the value of (next_handle_value_ - 1).
331  ComputationDataHandle handle_to_return_ GUARDED_BY(mutex_);
332  // Memoized ProgramShape and its version. A shared_ptr is used because
333  // references to this object are returned by ComputeProgramShape.
334  mutable int64 program_shape_version_ GUARDED_BY(mutex_) = 0;
335  mutable std::shared_ptr<const ProgramShape> program_shape_ GUARDED_BY(mutex_);
336  TF_DISALLOW_COPY_AND_ASSIGN(UserComputation);  // Non-copyable, non-assignable.
337 };
338 } // namespace xla
339 #endif // TENSORFLOW_COMPILER_XLA_SERVICE_USER_COMPUTATION_H_
340 
Definition: versioned_computation_handle.h:37
Definition: user_computation.h:49
StatusOr< std::unique_ptr< HloComputation > > BuildHloComputation(VersionedComputationHandle::Version version, HloComputationResolver hlo_resolver, const DebugOptions &debug_options, bool include_unreachable_instructions=true) const
Build a HLO computation from the UserComputation
Definition: user_computation.cc:3102
StatusOr< ComputationDataHandle > AddParameterInstruction(const ParameterRequest &parameter_request)
Enqueue a parameter-retrieving instruction onto this UserComputation.
Definition: user_computation.cc:195
namespace for xla
Definition: client_library.cc:26
VersionedComputationHandle GetVersionedHandleInternal() const EXCLUSIVE_LOCKS_REQUIRED(mutex_)
Get the VersionedComputationHandle recorded in UserComputation object.
Definition: user_computation.cc:1263
VersionedComputationHandle GetVersionedHandle() const
Lock mutex and call GetVersionedHandleInternal()
Definition: user_computation.cc:1253