computation_builder.h

/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_COMPUTATION_BUILDER_H_
#define TENSORFLOW_COMPILER_XLA_CLIENT_COMPUTATION_BUILDER_H_

#include <functional>
#include <initializer_list>
#include <memory>
#include <string>
#include <utility>

#include "tensorflow/compiler/xla/array.h"
#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/array3d.h"
#include "tensorflow/compiler/xla/array4d.h"
#include "tensorflow/compiler/xla/client/client.h"
#include "tensorflow/compiler/xla/client/computation.h"
#include "tensorflow/compiler/xla/client/global_data.h"
#include "tensorflow/compiler/xla/client/padding.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/bitmap.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/stacktrace.h"
#include "tensorflow/core/platform/types.h"

namespace xla {

// Wraps an XLA client with a convenient interface for building up
// computations. Any errors encountered in building up the computation are
// deferred from being handled until Build() is called.
//
// Thread-compatible.
class ComputationBuilder {
 public:
  // client: client in which to build the computation.
  // computation_name: name to use for the built computation.
  ComputationBuilder(Client* client, const string& computation_name);

  ~ComputationBuilder();

  // Returns the client the builder was initialized with.
  Client* client() const { return client_; }

  // Returns the computation name.
  const string& name() const { return name_; }

  // Sets OpMetadata that will be added to all instructions until cleared.
  //
  // OpMetadata is often applied to a series of XLA HLO instructions. As a
  // result, OpMetadata is set on the Computation Builder. All subsequent
  // instructions generated via this Computation Builder will have the same
  // OpMetadata attached until a call to ClearOpMetadata.
  void SetOpMetadata(const OpMetadata& metadata) { metadata_ = metadata; }

  // Clears the OpMetadata state.
  void ClearOpMetadata() { metadata_.Clear(); }

  // Sets an OpSharding that will be attached to all instructions until cleared.
  void SetSharding(const OpSharding& sharding) { sharding_ = sharding; }

  // Clears the sharding. Ops will be sharded according to the default placement
  // policy.
  void ClearSharding() { sharding_ = tensorflow::gtl::nullopt; }

  // Returns the OpSharding that will be attached to all instructions.
  const tensorflow::gtl::optional<OpSharding>& sharding() const {
    return sharding_;
  }

  // Sets the builder to a mode where it will die immediately when an error is
  // encountered, rather than producing it in a deferred fashion when Build() is
  // called (which is the default).
  void set_die_immediately_on_error(bool enabled) {
    die_immediately_on_error_ = enabled;
  }

  // Enqueues a "retrieve parameter value" instruction for a parameter that was
  // passed to the computation.
  ComputationDataHandle Parameter(int64 parameter_number, const Shape& shape,
      const string& name);

  // Retrieves the (inferred) shape of the operand in the computation.
  StatusOr<std::unique_ptr<Shape>> GetShape(
      const ComputationDataHandle& operand);

  // Retrieves the (inferred) program shape for the current computation, i.e.,
  // the shapes of its parameters and of its result.
  StatusOr<ProgramShape> GetProgramShape();

  // Enqueues a constant with the value of the given literal onto the
  // computation.
  ComputationDataHandle ConstantLiteral(const Literal& literal);

  // Enqueues a constant onto the computation. Methods are templated on the
  // native host type (NativeT) which corresponds to a specific XLA
  // PrimitiveType as given in the following table:
  //
  //  Native Type   PrimitiveType
  // -----------------------------
  //   bool           PRED
  //   int32          S32
  //   int64          S64
  //   uint32         U32
  //   uint64         U64
  //   float          F32
  //   double         F64
  //
  // Note: not all primitive types defined in xla_data.proto have a
  // corresponding native type yet.
  template <typename NativeT>
  ComputationDataHandle ConstantR0(NativeT value);
  template <typename NativeT>
  ComputationDataHandle ConstantR1(tensorflow::gtl::ArraySlice<NativeT> values);
  ComputationDataHandle ConstantR1(const tensorflow::core::Bitmap& values);
  template <typename NativeT>
  ComputationDataHandle ConstantR2(
      std::initializer_list<std::initializer_list<NativeT>> values);
  template <typename NativeT>
  ComputationDataHandle ConstantFromArrayWithLayout(
      const Array<NativeT>& values, const Layout& layout);
  template <typename NativeT>
  ComputationDataHandle ConstantFromArray(const Array<NativeT>& values);
  template <typename NativeT>
  ComputationDataHandle ConstantR2FromArray2DWithLayout(
      const Array2D<NativeT>& values, const Layout& layout);
  template <typename NativeT>
  ComputationDataHandle ConstantR2FromArray2D(const Array2D<NativeT>& values);
  template <typename NativeT>
  ComputationDataHandle ConstantR3FromArray3DWithLayout(
      const Array3D<NativeT>& values, const Layout& layout);
  template <typename NativeT>
  ComputationDataHandle ConstantR3FromArray3D(const Array3D<NativeT>& values);
  template <typename NativeT>
  ComputationDataHandle ConstantR4FromArray4DWithLayout(
      const Array4D<NativeT>& values, const Layout& layout);
  template <typename NativeT>
  ComputationDataHandle ConstantR4FromArray4D(const Array4D<NativeT>& values);
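
  // Example (editorial sketch, not part of the original API docs): building a
  // few constants, assuming a ComputationBuilder `b` constructed elsewhere.
  //
  //   auto half = b.ConstantR0<float>(0.5f);            // scalar F32
  //   auto mask = b.ConstantR1<bool>({true, false});    // 2-element PRED
  //   auto m = b.ConstantR2<int32>({{1, 2}, {3, 4}});   // 2x2 S32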

  // Enqueues a rank one constant (vector) onto the computation. The vector has
  // size 'length' and every element has the value 'value'.
  template <typename NativeT>
  ComputationDataHandle ConstantR1(int64 length, NativeT value);

  // Adds dimensions to an array by duplicating the data in the array.
  //
  // The new dimensions are inserted on the left, i.e. if
  // broadcast_sizes has values {a0, ..., aN} and the operand shape
  // has dimensions {b0, ..., bM} then the shape of the output has
  // dimensions {a0, ..., aN, b0, ..., bM}.
  //
  // The new dimensions index into copies of the operand, i.e.
  //
  //   output[i0, ..., iN, j0, ..., jM] = operand[j0, ..., jM]
  ComputationDataHandle Broadcast(
      const ComputationDataHandle& operand,
      tensorflow::gtl::ArraySlice<int64> broadcast_sizes);
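
  // Example (editorial sketch; `b` and `v` are hypothetical, with `v` an
  // f32[2, 3] operand built earlier on the same builder):
  //
  //   auto tiled = b.Broadcast(v, /*broadcast_sizes=*/{4});
  //   // `tiled` has shape f32[4, 2, 3]; tiled[i, j, k] == v[j, k].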

  // Enqueues a pad operation onto the computation that pads the given value on
  // the edges as well as between the elements of the input. padding_config
  // specifies the padding amount for each dimension.
  ComputationDataHandle Pad(const ComputationDataHandle& operand,
      const ComputationDataHandle& padding_value,
      const PaddingConfig& padding_config);

  // Enqueues an operation onto the computation that flattens the operand based
  // on the dimension order (major/slowest-varying to minor/fastest-varying)
  // given, followed by reshaping it into the shape with the given dimension
  // sizes (also major to minor). Conceptually, this is a limited form of
  // "shape casting".
  ComputationDataHandle Reshape(const ComputationDataHandle& operand,
      tensorflow::gtl::ArraySlice<int64> dimensions,
      tensorflow::gtl::ArraySlice<int64> new_sizes);

  // Enqueues an operation onto the computation that collapses the operand, from
  // first to last dimension (C order), then reshapes it to the given dimension
  // sizes. Conceptually, this is a limited form of "shape casting".
  ComputationDataHandle Reshape(const ComputationDataHandle& operand,
      tensorflow::gtl::ArraySlice<int64> new_sizes);

  // Wrapper for Reshape.
  // Enqueues an operation to collapse the provided dimensions; e.g. an
  // operand with dimensions {x=256, y=2, z=2, p=32} can be collapsed to
  // {x=1024, y=32} by collapsing dims {0, 1, 2}. Collapsing dimensions must
  // be a consecutive, in-order subsequence of the operand dimensions.
  //
  // Note that collapsing a single dimension does nothing:
  //
  //   {256} collapsing {0} => {256}
  //   {1} collapsing {0} => {1}
  //
  // Collapsing multiple dimensions produces a single result dimension:
  //
  //   {256, 2} collapsing {0,1} => {512}
  //   {256, 2, 3} collapsing {0,1} => {512, 3}
  //
  // This could potentially cause data to be moved -- it provides a more
  // structured form of reshaping than an arbitrary Reshape operation.
  ComputationDataHandle Collapse(const ComputationDataHandle& operand,
      tensorflow::gtl::ArraySlice<int64> dimensions);
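
  // Example (editorial sketch; `b` is a hypothetical builder and `x` a handle
  // with shape f32[256, 2, 3]):
  //
  //   auto y = b.Collapse(x, /*dimensions=*/{0, 1});
  //   // `y` has shape f32[512, 3].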

  // Enqueues a slice operation onto the computation that slices the operand
  // from the start indices to the limit indices; e.g.
  //
  //        x
  //   [ 0 1 2 3 ]
  // y [ 4 5 6 7 ] => slice(start={1, 1}, limit={2, 3}) => [ 5 6 ]
  //   [ 8 9 a b ]
  //
  // Note that "limit" means up-to-but-not-including; i.e. [start, limit) in 1D
  // range notation.
  // The strides parameter determines the stride over the slice in each
  // dimension.
  ComputationDataHandle Slice(const ComputationDataHandle& operand,
      tensorflow::gtl::ArraySlice<int64> start_indices,
      tensorflow::gtl::ArraySlice<int64> limit_indices,
      tensorflow::gtl::ArraySlice<int64> strides);
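
  // Example (editorial sketch; `b` is a hypothetical builder and `x` a handle
  // with shape f32[3, 4] holding the values drawn above):
  //
  //   auto s = b.Slice(x, /*start_indices=*/{1, 1}, /*limit_indices=*/{2, 3},
  //                    /*strides=*/{1, 1});
  //   // `s` has shape f32[1, 2] and contains {{5, 6}}.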

  // Enqueues a slice operation in a given dimension, taking all other
  // dimensions as they are; e.g. if dimno is 1 from start_index 2 to
  // limit_index 4 by 1, and the shape is f32[7,8,9], this call is short-hand
  // for:
  //
  //   array[:, 2:4:1, :]
  ComputationDataHandle SliceInDim(const ComputationDataHandle& operand,
      int64 start_index, int64 limit_index, int64 stride, int64 dimno);

  // Enqueues a slice operation onto the computation that slices the 'operand'
  // from dynamic start indices which are passed in 'start_indices'.
  // The size of the slice in each dimension is passed in 'slice_sizes',
  // which specify the half-open slice intervals in each dimension:
  // [start, start + size).
  // The shape of 'start_indices' must be rank == 1, with dimension size
  // equal to the rank of the 'operand'.
  // Slice index calculations are computed modulo input dimension sizes to
  // prevent dynamic start indices from generating out-of-bound array accesses.
  ComputationDataHandle DynamicSlice(
      const ComputationDataHandle& operand,
      const ComputationDataHandle& start_indices,
      tensorflow::gtl::ArraySlice<int64> slice_sizes);

  // Enqueues a dynamic update slice operation onto the computation, which
  // updates a slice of 'operand' with 'update' at dynamic 'start_indices'.
  // The shape of 'update' determines the shape of the slice of 'operand'
  // which is updated.
  // The indices specified in 'start_indices' specify the offset of the slice
  // of 'operand' which is updated.
  //
  //               update = {10, 11} // calculated at runtime.
  //   [1 2 3]     start  = {1, 1}   // calculated at runtime.  [1 2  3 ]
  //   [4 5 6]  => DynamicUpdateSlice(data, update, start)   => [4 10 11]
  //   [7 8 9]                                                  [7 8  9 ]
  //
  // The shape of 'start_indices' must be rank == 1, with dimension size
  // equal to the rank of the 'operand'.
  // Slice index calculations are computed modulo update dimension sizes to
  // prevent dynamic start indices from generating out-of-bound array accesses.
  ComputationDataHandle DynamicUpdateSlice(
      const ComputationDataHandle& operand, const ComputationDataHandle& update,
      const ComputationDataHandle& start_indices);
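
  // Example (editorial sketch; `b`, `data`, `update`, and `start` are
  // hypothetical handles; `data` is s32[3, 3], `update` is s32[1, 2], and
  // `start` is a rank-1 vector of start indices computed at runtime):
  //
  //   auto updated = b.DynamicUpdateSlice(data, update, start);
  //   // `updated` has the same shape as `data`, with the 1x2 window at
  //   // `start` overwritten by `update`.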

  // Enqueues a concatenate instruction onto the computation. 'operands' must
  // have >= 1 entry.
  ComputationDataHandle ConcatInDim(
      tensorflow::gtl::ArraySlice<ComputationDataHandle> operands,
      int64 dimension);

  // Enqueue a tracing operation onto the computation; the computation will emit
  // a logging message with the operand.
  void Trace(const string& tag, const ComputationDataHandle& operand);

  // Enqueues a conditional-move-like select operation onto the computation;
  // predicated on pred, selects between on_true and on_false.
  ComputationDataHandle Select(const ComputationDataHandle& pred,
      const ComputationDataHandle& on_true,
      const ComputationDataHandle& on_false);

  // Enqueues a tuple-creation instruction onto the computation.
  ComputationDataHandle Tuple(
      tensorflow::gtl::ArraySlice<ComputationDataHandle> elements);

  // Enqueues a tuple-element-get instruction onto the computation.
  ComputationDataHandle GetTupleElement(const ComputationDataHandle& tuple_data,
      int64 index);

  // Enqueues an equal-to comparison instruction onto the computation.
  ComputationDataHandle Eq(
      const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
      tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});

  // Enqueues a not-equal comparison instruction onto the computation.
  ComputationDataHandle Ne(
      const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
      tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});

  // Enqueues a greater-or-equal comparison instruction onto the computation.
  ComputationDataHandle Ge(
      const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
      tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});

  // Enqueues a greater-than comparison instruction onto the computation.
  ComputationDataHandle Gt(
      const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
      tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});

  // Enqueues a less-than comparison instruction onto the computation.
  ComputationDataHandle Lt(
      const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
      tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});

  // Enqueues a less-or-equal comparison instruction onto the computation.
  ComputationDataHandle Le(
      const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
      tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});

  // Enqueues a dot instruction onto the computation.
  ComputationDataHandle Dot(const ComputationDataHandle& lhs,
      const ComputationDataHandle& rhs);

  // Enqueues a general dot instruction onto the computation.
  ComputationDataHandle DotGeneral(
      const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
      const DotDimensionNumbers& dimension_numbers);

  // Default dimension numbers used for a 2D convolution.
  static constexpr int64 kConvBatchDimension = 0;
  static constexpr int64 kConvFeatureDimension = 1;
  static constexpr int64 kConvFirstSpatialDimension = 2;
  static constexpr int64 kConvSecondSpatialDimension = 3;
  static constexpr int64 kConvKernelOutputDimension = 0;
  static constexpr int64 kConvKernelInputDimension = 1;
  static constexpr int64 kConvKernelFirstSpatialDimension = 2;
  static constexpr int64 kConvKernelSecondSpatialDimension = 3;

  // Creates a default ConvolutionDimensionNumbers. For a 2D convolution, for
  // the input operand {batch, feature, height, width} = {0, 1, 2, 3} and for
  // the kernel operand
  // {output_feature, input_feature, height, width} = {0, 1, 2, 3}.
  static ConvolutionDimensionNumbers CreateDefaultConvDimensionNumbers(
      int num_spatial_dims = 2);

  // Creates a ConvolutionDimensionNumbers with the given arguments. Returns an
  // error if either the input or the weight dimension numbers have conflicts.
  static StatusOr<ConvolutionDimensionNumbers> CreateConvDimensionNumbers(
      int64 input_batch, int64 input_feature, int64 input_first_spatial,
      int64 input_second_spatial, int64 output_batch, int64 output_feature,
      int64 output_first_spatial, int64 output_second_spatial,
      int64 kernel_output_feature, int64 kernel_input_feature,
      int64 kernel_first_spatial, int64 kernel_second_spatial);

  // Enqueues a convolution instruction onto the computation, which uses the
  // default convolution dimension numbers.
  ComputationDataHandle Conv(const ComputationDataHandle& lhs,
      const ComputationDataHandle& rhs,
      tensorflow::gtl::ArraySlice<int64> window_strides, Padding padding);
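
  // Example (editorial sketch; `b`, `input`, and `kernel` are hypothetical
  // handles laid out per CreateDefaultConvDimensionNumbers(), i.e. input is
  // {batch, feature, height, width} and kernel is
  // {output_feature, input_feature, height, width}):
  //
  //   auto conv = b.Conv(input, kernel, /*window_strides=*/{1, 1},
  //                      Padding::kSame);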

  // Enqueues a convolution instruction onto the computation, with the caller
  // provided padding configuration in the format returned by MakePadding().
  ComputationDataHandle ConvWithGeneralPadding(
      const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
      tensorflow::gtl::ArraySlice<int64> window_strides,
      tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding);

  // Enqueues a convolution instruction onto the computation, with the caller
  // provided dimension numbers configuration.
  ComputationDataHandle ConvWithGeneralDimensions(
      const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
      tensorflow::gtl::ArraySlice<int64> window_strides, Padding padding,
      const ConvolutionDimensionNumbers& dimension_numbers);

  // Enqueues a convolution instruction onto the computation, with the caller
  // provided padding configuration as well as the dimension numbers.
  ComputationDataHandle ConvGeneral(
      const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
      tensorflow::gtl::ArraySlice<int64> window_strides,
      tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
      const ConvolutionDimensionNumbers& dimension_numbers);

  // Enqueues a convolution instruction onto the computation, with the caller
  // provided padding configuration, dilation factors and dimension numbers.
  ComputationDataHandle ConvGeneralDilated(
      const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
      tensorflow::gtl::ArraySlice<int64> window_strides,
      tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
      tensorflow::gtl::ArraySlice<int64> lhs_dilation,
      tensorflow::gtl::ArraySlice<int64> rhs_dilation,
      const ConvolutionDimensionNumbers& dimension_numbers);

  // Enqueues an FFT instruction onto the computation, of the given type and
  // with the given FFT length.
  ComputationDataHandle Fft(const ComputationDataHandle& operand,
      FftType fft_type, tensorflow::gtl::ArraySlice<int64> fft_length);

  // Enqueues an infeed instruction onto the computation, which writes data of
  // the given shape to the infeed buffer of the device.
  ComputationDataHandle Infeed(const Shape& shape, const string& config = "");

  // Enqueues an outfeed instruction onto the computation. This instruction
  // generates outgoing data transfers for the given data.
  //
  // shape_with_layout communicates the laid out shape that we want to outfeed
  // -- if !ShapeUtil::Compatible(GetShape(operand), shape_with_layout) an error
  // will occur.
  void Outfeed(const ComputationDataHandle& operand,
      const Shape& shape_with_layout, const string& outfeed_config);

  // Enqueues a call instruction onto the computation.
  ComputationDataHandle Call(
      const Computation& computation,
      tensorflow::gtl::ArraySlice<ComputationDataHandle> operands);

  // Enqueues a custom call instruction onto the computation.
  // During code generation, a call instruction is emitted which targets a
  // symbol with the name |call_target_name|. The |operands| are passed to the
  // call instruction. |shape| is the resultant shape.
  ComputationDataHandle CustomCall(
      const string& call_target_name,
      tensorflow::gtl::ArraySlice<ComputationDataHandle> operands,
      const Shape& shape);

  // Enqueues a pseudo-op to represent host-side computation data-dependencies.
  // During code generation, host send and receive operations will be generated
  // to transfer |operands| to the host and a single result of |shape| back to
  // the device. Host send/recv operations are emitted using |channel_name|.
  // Dataflow dependencies and the |cost_estimate_ns| field may be used in HLO
  // instruction scheduling.
  ComputationDataHandle HostCompute(
      tensorflow::gtl::ArraySlice<ComputationDataHandle> operands,
      const string& channel_name, int64 cost_estimate_ns, const Shape& shape);

  // The following methods enqueue element-wise binary arithmetic operations
  // onto the computation. The shapes of the operands have to match unless one
  // of the operands is a scalar, or an explicit broadcast dimension is given
  // (see the XLA broadcasting semantics documentation for more details).
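  //
  // Example (editorial sketch; `b`, `matrix` (f32[2, 3]) and `row` (f32[3])
  // are hypothetical handles on the same builder):
  //
  //   auto sum = b.Add(matrix, row, /*broadcast_dimensions=*/{1});
  //   // `row` is broadcast along dimension 1 of `matrix`; `sum` is f32[2, 3].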

  // Enqueues a complex compose instruction onto the computation.
  ComputationDataHandle Complex(
      const ComputationDataHandle& real, const ComputationDataHandle& imag,
      tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});

  // Enqueues a complex conjugate instruction onto the computation.
  ComputationDataHandle Conj(const ComputationDataHandle& operand);

  // Enqueues an add instruction onto the computation.
  ComputationDataHandle Add(
      const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
      tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});

  // Enqueues a subtract instruction onto the computation.
  ComputationDataHandle Sub(
      const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
      tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});

  // Enqueues a multiply instruction onto the computation.
  ComputationDataHandle Mul(
      const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
      tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});

  // Enqueues a divide instruction onto the computation.
  ComputationDataHandle Div(
      const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
      tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});

  // Enqueues a remainder instruction onto the computation.
  ComputationDataHandle Rem(
      const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
      tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});

  // Enqueues a max instruction onto the computation.
  ComputationDataHandle Max(
      const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
      tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});

  // Enqueues a min instruction onto the computation.
  ComputationDataHandle Min(
      const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
      tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});

  // Element-wise logical operators
  ComputationDataHandle And(
      const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
      tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});

  ComputationDataHandle Or(
      const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
      tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});

  ComputationDataHandle Xor(
      const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
      tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});

  ComputationDataHandle Not(const ComputationDataHandle& operand);

  ComputationDataHandle ShiftLeft(
      const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
      tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
  ComputationDataHandle ShiftRightArithmetic(
      const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
      tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
  ComputationDataHandle ShiftRightLogical(
      const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
      tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});

  // Reduces an array among the provided dimensions, given "computation" as a
  // reduction operator.
  ComputationDataHandle Reduce(
      const ComputationDataHandle& operand,
      const ComputationDataHandle& init_value, const Computation& computation,
      tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce);
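
  // Example (editorial sketch; builds a sum-reduction over dimension 0 of a
  // hypothetical f32[10, 20] handle `x` on builder `b`):
  //
  //   auto sub = b.CreateSubBuilder("add");
  //   auto lhs = sub->Parameter(0, ShapeUtil::MakeShape(F32, {}), "lhs");
  //   auto rhs = sub->Parameter(1, ShapeUtil::MakeShape(F32, {}), "rhs");
  //   sub->Add(lhs, rhs);
  //   Computation add = sub->BuildAndNoteError();
  //   auto sums = b.Reduce(x, b.ConstantR0<float>(0.0f), add,
  //                        /*dimensions_to_reduce=*/{0});
  //   // `sums` has shape f32[20].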

  // Convenience wrapper around the above that reduces all the dimensions in the
  // operand shape.
  ComputationDataHandle ReduceAll(const ComputationDataHandle& operand,
      const ComputationDataHandle& init_value,
      const Computation& computation);

  // Enqueues a windowed reduce instruction onto the computation.
  ComputationDataHandle ReduceWindow(
      const ComputationDataHandle& operand,
      const ComputationDataHandle& init_value, const Computation& computation,
      tensorflow::gtl::ArraySlice<int64> window_dimensions,
      tensorflow::gtl::ArraySlice<int64> window_strides, Padding padding);

  // As ReduceWindow(), but the padding is given in the format
  // returned by MakePadding().
  ComputationDataHandle ReduceWindowWithGeneralPadding(
      const ComputationDataHandle& operand,
      const ComputationDataHandle& init_value, const Computation& computation,
      tensorflow::gtl::ArraySlice<int64> window_dimensions,
      tensorflow::gtl::ArraySlice<int64> window_strides,
      tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding);

  // Returns the sum of the operand value across all replicas. All replicas
  // supply one input to the sum and all replicas receive the resulting sum.
  ComputationDataHandle CrossReplicaSum(const ComputationDataHandle& operand);

  // Enqueues an operation that scatters the `source` array to the selected
  // indices of each window.
  ComputationDataHandle SelectAndScatter(
      const ComputationDataHandle& operand, const Computation& select,
      tensorflow::gtl::ArraySlice<int64> window_dimensions,
      tensorflow::gtl::ArraySlice<int64> window_strides, Padding padding,
      const ComputationDataHandle& source,
      const ComputationDataHandle& init_value, const Computation& scatter);

  // As SelectAndScatter(), but the padding is given in the format
  // returned by MakePadding().
  ComputationDataHandle SelectAndScatterWithGeneralPadding(
      const ComputationDataHandle& operand, const Computation& select,
      tensorflow::gtl::ArraySlice<int64> window_dimensions,
      tensorflow::gtl::ArraySlice<int64> window_strides,
      tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
      const ComputationDataHandle& source,
      const ComputationDataHandle& init_value, const Computation& scatter);

  // Enqueues an abs instruction onto the computation.
  ComputationDataHandle Abs(const ComputationDataHandle& operand);

  // Enqueues an atan2 instruction onto the computation.
  ComputationDataHandle Atan2(
      const ComputationDataHandle& y, const ComputationDataHandle& x,
      tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});

  // Enqueues an exp instruction onto the computation.
  ComputationDataHandle Exp(const ComputationDataHandle& operand);

  // Enqueues a floor instruction onto the computation.
  ComputationDataHandle Floor(const ComputationDataHandle& operand);

  // Enqueues a ceil instruction onto the computation.
  ComputationDataHandle Ceil(const ComputationDataHandle& operand);

  // Enqueues a round instruction onto the computation, rounding to nearest,
  // with half-way cases rounding away from zero.
  ComputationDataHandle Round(const ComputationDataHandle& operand);

  // Enqueues a log instruction (natural logarithm) onto the computation.
  ComputationDataHandle Log(const ComputationDataHandle& operand);

  // Enqueues a sign instruction onto the computation.
  ComputationDataHandle Sign(const ComputationDataHandle& operand);

  // Enqueues a cosine instruction onto the computation.
  ComputationDataHandle Cos(const ComputationDataHandle& operand);

  // Enqueues a sine instruction onto the computation.
  ComputationDataHandle Sin(const ComputationDataHandle& operand);

  // Enqueues a tanh instruction onto the computation.
  ComputationDataHandle Tanh(const ComputationDataHandle& operand);

  // Enqueues a real-part instruction onto the computation.
  ComputationDataHandle Real(const ComputationDataHandle& operand);

  // Enqueues an imaginary-part instruction onto the computation.
  ComputationDataHandle Imag(const ComputationDataHandle& operand);

  // Enqueues a float32 sqrt instruction onto the computation.
  // (float32 is specified as there is an implicit float32 0.5f constant
  // exponent).
  ComputationDataHandle SqrtF32(const ComputationDataHandle& operand);

  // Enqueues a float32 square instruction onto the computation.
  // (float32 is specified as there is an implicit float32 2.0f constant
  // exponent).
  ComputationDataHandle SquareF32(const ComputationDataHandle& operand);

  // Enqueues a lhs^rhs computation onto the computation.
  ComputationDataHandle Pow(
      const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
      tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});

  // Enqueues an operator that tests if the operand's values are finite, i.e.,
  // not Inf or NaN. Defined only for floating-point types. Returns an array of
  // booleans with the same shape where entries are true iff the corresponding
  // entry was finite.
  ComputationDataHandle IsFinite(const ComputationDataHandle& operand);

  // Enqueues a convert instruction onto the computation that changes the
  // element type of the operand array to new_element_type.
  ComputationDataHandle ConvertElementType(const ComputationDataHandle& operand,
      PrimitiveType new_element_type);

  // Enqueues a no-op instruction onto the computation that changes
  // the element type of the operand array to new_element_type. The
  // bit-widths of the source and destination element types must be
  // identical.
  ComputationDataHandle BitcastConvertType(const ComputationDataHandle& operand,
      PrimitiveType new_element_type);

  // Enqueues a float32 reciprocal instruction onto the computation.
  // (float32 is specified as there is an implicit float32 -1.0f constant
  // exponent).
  //
  // TODO(b/34468990) axe F32 suffix, can be determined by reflecting on the
  // shape of the operand.
  ComputationDataHandle ReciprocalF32(const ComputationDataHandle& operand);

  // Enqueues a negate instruction onto the computation.
  ComputationDataHandle Neg(const ComputationDataHandle& operand);

  // Enqueues a transpose instruction onto the computation.
  ComputationDataHandle Transpose(
      const ComputationDataHandle& operand,
      tensorflow::gtl::ArraySlice<int64> permutation);

  // Enqueues a reverse instruction onto the computation. The order of the
  // elements in the given dimensions is reversed (i.e., the element at index i
  // is moved to index dimension_size - 1 - i).
  ComputationDataHandle Rev(const ComputationDataHandle& operand,
      tensorflow::gtl::ArraySlice<int64> dimensions);

  // Enqueues a sort (in increasing order) instruction onto the computation.
  ComputationDataHandle Sort(const ComputationDataHandle& operand);

  // Enqueues a clamp instruction onto the computation.
  ComputationDataHandle Clamp(const ComputationDataHandle& min,
      const ComputationDataHandle& operand,
      const ComputationDataHandle& max);

  // Enqueues a map instruction onto the computation.
  ComputationDataHandle Map(
      tensorflow::gtl::ArraySlice<ComputationDataHandle> operands,
      const Computation& computation,
      tensorflow::gtl::ArraySlice<int64> dimensions,
      tensorflow::gtl::ArraySlice<ComputationDataHandle> static_operands = {});

  // Enqueues an N(mu, sigma) random number generation instruction onto the
  // computation.
  ComputationDataHandle RngNormal(const ComputationDataHandle& mu,
      const ComputationDataHandle& sigma, const Shape& shape);

  // Enqueues a U(a, b) random number generation instruction onto the
  // computation. Returns values in the semi-open interval [a, b).
  ComputationDataHandle RngUniform(const ComputationDataHandle& a,
      const ComputationDataHandle& b, const Shape& shape);

  // Enqueues a while node onto the computation.
  ComputationDataHandle While(const Computation& condition,
      const Computation& body, const ComputationDataHandle& init);
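
  // Example (editorial sketch; counts a hypothetical scalar S32 handle `start`
  // up to 10, using sub-builders for the condition and body computations):
  //
  //   auto cond_b = b.CreateSubBuilder("cond");
  //   cond_b->Lt(cond_b->Parameter(0, ShapeUtil::MakeShape(S32, {}), "i"),
  //              cond_b->ConstantR0<int32>(10));
  //   auto body_b = b.CreateSubBuilder("body");
  //   body_b->Add(body_b->Parameter(0, ShapeUtil::MakeShape(S32, {}), "i"),
  //               body_b->ConstantR0<int32>(1));
  //   auto result = b.While(cond_b->BuildAndNoteError(),
  //                         body_b->BuildAndNoteError(), start);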

  // Enqueues a conditional node onto the computation.
  ComputationDataHandle Conditional(const ComputationDataHandle& predicate,
      const ComputationDataHandle& true_operand,
      const Computation& true_computation,
      const ComputationDataHandle& false_operand,
      const Computation& false_computation);

  // Enqueues a ReducePrecision node onto the computation.
  ComputationDataHandle ReducePrecision(const ComputationDataHandle& operand,
      const int exponent_bits, const int mantissa_bits);

  // Enqueues a Gather node onto the computation.
  ComputationDataHandle Gather(
      const ComputationDataHandle& input,
      const ComputationDataHandle& gather_indices,
      const GatherDimensionNumbers& dimension_numbers,
      tensorflow::gtl::ArraySlice<int64> window_bounds);

  // Enqueues a Send node onto the computation, to send the given operand to
  // a Recv instruction that shares the same channel handle.
  void Send(const ComputationDataHandle& operand, const ChannelHandle& handle);

  // Enqueues a Recv node onto the computation. The data comes from a Send
  // instruction that shares the same channel handle and its shape must
  // be the same as the given shape.
  ComputationDataHandle Recv(const Shape& shape, const ChannelHandle& handle);

  // Returns true if 'operand' is a compile-time constant. A compile-time
  // constant does not depend on parameters with index greater than or equal to
  // `num_parameters`, or on stateful operators such as `RngNormal` or `Infeed`.
  // Unlike `ComputeConstant`, `IsConstant` tests whether a computation is a
  // compile-time constant without evaluating the computation.
  StatusOr<bool> IsConstant(const ComputationDataHandle& operand,
      int64 num_parameters = 0);

  // Normalizes operand across spatial and batch dimensions for each feature.
  //
  // Returns a tuple (normalized, batch_mean, batch_var) where `normalized`
  // is the normalized result and batch_mean and batch_var are the mean and
  // variance, respectively, across batch for the operand.
  ComputationDataHandle BatchNormTraining(const ComputationDataHandle& operand,
      const ComputationDataHandle& scale,
      const ComputationDataHandle& offset,
      float epsilon, int64 feature_index);

  // Normalizes operand across spatial and batch dimensions for each feature.
  //
  // `BatchNormInference` is equivalent to calling `BatchNormTraining` without
  // computing `mean` and `variance` for each batch inside the operation. It
  // uses the input `mean` and `variance` instead as estimated values. The
  // purpose of this op is to reduce latency in inference, hence the name
  // `BatchNormInference`.
  //
  // The output has the same shape as `operand`, and contains the normalized
  // values for each batch.
  ComputationDataHandle BatchNormInference(
      const ComputationDataHandle& operand, const ComputationDataHandle& scale,
      const ComputationDataHandle& offset, const ComputationDataHandle& mean,
      const ComputationDataHandle& variance, float epsilon,
      int64 feature_index);

  // Calculates the gradients of a batch norm op.
  //
  // The inputs `batch_mean` and `batch_var` represent the mean and variance
  // across the batch.
  //
  // Returns a tuple of three elements:
  //   - grad_operand: Gradient with respect to input `operand`
  //   - grad_offset: Gradient with respect to input `offset`
  //   - grad_scale: Gradient with respect to input `scale`
  ComputationDataHandle BatchNormGrad(const ComputationDataHandle& operand,
      const ComputationDataHandle& scale,
      const ComputationDataHandle& batch_mean,
      const ComputationDataHandle& batch_var,
      const ComputationDataHandle& grad_output,
      float epsilon, int64 feature_index);

  // Computes the value of a constant indicated by a
  // ComputationDataHandle using a non-optimized interpreter on the host.
  //
  // The operand must be from the computation currently being built -
  // i.e., returned from this builder with no intervening call to
  // Build(). This happens to currently work regardless of that, but
  // that may stop working at any time.
  //
  // The operand must represent a constant value, which in this case
  // means that it must not statically depend on any parameter of the
  // computation that is being built other than the ones specified on the
  // parameter list. The parameters in the list will be indexed by their
  // parameter id property so the number of parameters specified should be at
  // least as many as the largest used parameter index.
  //
  // `IsConstant` can be used to test whether a computation is a compile-time
  // constant without evaluating it. `ComputeConstant` only succeeds for
  // computations where `IsConstant` returns true.
  //
  // This functionality can be useful when translating a computation
  // into XLA where something that looked dynamic is required by
  // XLA to be specified as a constant. E.g. the source
  // computation (outside of XLA) may include a dynamic
  // computation of the shape of something and ComputeConstant lets
  // you determine what the value of that computation is in the case
  // where the value can be determined at compile time.
  //
  // If output_layout is non-null, then the output of the computation
  // will be stored using that layout.
  StatusOr<std::unique_ptr<Literal>> ComputeConstant(
      const ComputationDataHandle& operand,
      const Layout* output_layout = nullptr,
      tensorflow::gtl::ArraySlice<Literal> parameters = {});
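
  // Example (editorial sketch; `b` and a handle `dim` built purely from
  // constants are hypothetical):
  //
  //   bool is_const = b.IsConstant(dim).ValueOrDie();
  //   if (is_const) {
  //     std::unique_ptr<Literal> value = b.ComputeConstant(dim).ValueOrDie();
  //     // `value` holds the host-evaluated result of `dim`.
  //   }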

  // Returns a new ComputationBuilder whose resultant Computation is used only
  // by this ComputationBuilder. The sub-ComputationBuilder has the same
  // die_immediately_on_error behavior as the parent.
  std::unique_ptr<ComputationBuilder> CreateSubBuilder(
      const string& computation_name);

  // Modifies the computation being built so that executions of it
  // will return the value associated with operand, rather than the
  // last expression enqueued on the ComputationBuilder. Any subsequent
  // operations added to the ComputationBuilder will not have any effect unless
  // SetReturnValue is called again.
  Status SetReturnValue(const ComputationDataHandle& operand);

  // Builds the computation with the requested operations, or returns a non-ok
  // status.
  StatusOr<Computation> Build();
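
  // Example (editorial sketch of a typical end-to-end flow; `client` is a
  // hypothetical xla::Client*):
  //
  //   ComputationBuilder b(client, "x_plus_y");
  //   auto x = b.Parameter(0, ShapeUtil::MakeShape(F32, {2}), "x");
  //   auto y = b.Parameter(1, ShapeUtil::MakeShape(F32, {2}), "y");
  //   b.Add(x, y);
  //   StatusOr<Computation> computation = b.Build();
  //   // On success, the Computation can be handed to the client for
  //   // execution (e.g. Client::ExecuteAndTransfer) with argument data.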

  // Builds the computation with the requested operations, or notes an error in
  // the parent ComputationBuilder and returns an empty computation if building
  // failed. This function is intended to be used where the returned
  // Computation is only used by the parent ComputationBuilder and hence further
  // operations on the returned Computation will simply be errored out if an
  // error occurred while building this computation. If the built computation is
  // to be used by a ComputationBuilder other than the parent ComputationBuilder
  // then Build() should be used instead.
  Computation BuildAndNoteError();

  // Returns the first error that was encountered while building the
  // computation. When an error is encountered, by default we return a vacuous
  // ComputationDataHandle and inform the user of the error that occurred while
  // building the computation when they make a final call to Build().
  //
  // See also set_die_immediately_on_error().
  Status first_error() const { return first_error_; }

 private:
  // Limited checking of convolution parameters. Returns false on
  // error.
  bool VerifyConvolution(const Shape& lhs_shape, const Shape& rhs_shape,
      const ConvolutionDimensionNumbers& dimension_numbers);

  // The parent ComputationBuilder of a sub-ComputationBuilder. The
  // parent_builder_ will be nullptr if this is not a sub-ComputationBuilder.
  ComputationBuilder* parent_builder_{nullptr};

  // Helper function for creating a Window proto from user-supplied
  // data. Returns true if the user-supplied data was valid.
  bool MakeWindow(tensorflow::gtl::ArraySlice<int64> window_dimensions,
      tensorflow::gtl::ArraySlice<int64> window_strides,
      tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
      tensorflow::gtl::ArraySlice<int64> lhs_dilation,
      tensorflow::gtl::ArraySlice<int64> rhs_dilation,
      Window* window);

  // Internal helper method that does the building for an arbitrary unary op.
  ComputationDataHandle UnaryOp(UnaryOperation unop,
      const ComputationDataHandle& operand);

  // Internal helper method that does the building for an arbitrary binary op.
  // broadcast_dimensions specifies which dimensions to use for broadcasting
  // when the operation is between tensors of different ranks.
  ComputationDataHandle BinaryOp(
      BinaryOperation binop, const ComputationDataHandle& lhs,
      const ComputationDataHandle& rhs,
      tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);

  // Internal helper method that does the building for an arbitrary ternary op.
  ComputationDataHandle TernaryOp(TernaryOperation triop,
      const ComputationDataHandle& lhs,
      const ComputationDataHandle& rhs,
      const ComputationDataHandle& ehs);

  // Internal helper method that does the building for a random number generator
  // of a given distribution with an explicitly specified shape.
  ComputationDataHandle RngOp(
      RandomDistribution distribution,
      tensorflow::gtl::ArraySlice<ComputationDataHandle> parameters,
      const Shape& shape);

  // Populates computation_ with a valid object or returns a failing status.
  // This is used before any given operation is enqueued.
  Status PrepareComputation();

  // Notes that the error occurred by:
  // * storing it internally and capturing a backtrace if it's the first error
  //   (this deferred value will be produced on the call to Build())
  // * dying if die_immediately_on_error_ is true
  void NoteError(const Status& error);

  // Helper function that runs the given op_request, filling in op_response.
  // Before the op is run, PrepareComputation is called, and common fields in
  // the op_request are filled in.
  Status RunOp(OpRequest* op_request, OpResponse* op_response);

  // Helper function that calls RunOp and calls NoteError on failures.
  void RunOpAndNoteError(OpRequest* op_request);

  // Helper function that calls RunOp and either returns the output computation
  // data handle (on success) or a vacuous computation data handle (on failure).
  ComputationDataHandle RunOpAndParseResponse(OpRequest* op_request);

  // Helper function that implements GetShape without noting errors. This makes
  // it easier to ensure the real GetShape will note errors on every error path.
  StatusOr<std::unique_ptr<Shape>> GetShapeWithoutNoteError(
      const ComputationDataHandle& operand);

  string name_;  // Name to use for the built computation.

  // The first error encountered while building the computation.
  // This is OK until the first error is encountered.
  Status first_error_;

  // The saved stack trace from the point at which the first error occurred.
  tensorflow::SavedStackTrace first_error_backtrace_;

  // The computation that operations are enqueued onto.
  Computation computation_;

  // The client that the computation is created in. Not owned.
  Client* client_;

  // Mode bit that indicates whether to die when a first error is encountered.
  bool die_immediately_on_error_ = false;

  // The metadata to attach to each op. This is structured as a "modal"-like
  // operation, in order to simplify client code (and not sprinkle this metadata
  // throughout the TensorFlow op kernel implementations).
  OpMetadata metadata_;

  // Sharding for this operator. This is structured as a "modal"-like operation,
  // in order to simplify client code, similar to metadata_.
  tensorflow::gtl::optional<OpSharding> sharding_;

  TF_DISALLOW_COPY_AND_ASSIGN(ComputationBuilder);
};

template <typename NativeT>
ComputationDataHandle ComputationBuilder::ConstantR0(NativeT value) {
  return ConstantLiteral(*Literal::CreateR0<NativeT>(value));
}

template <typename NativeT>
ComputationDataHandle ComputationBuilder::ConstantR1(
    tensorflow::gtl::ArraySlice<NativeT> values) {
  return ConstantLiteral(*Literal::CreateR1<NativeT>(values));
}

template <typename NativeT>
ComputationDataHandle ComputationBuilder::ConstantR1(int64 length,
    NativeT value) {
  Literal literal(ShapeUtil::MakeShape(
      primitive_util::NativeToPrimitiveType<NativeT>(), {length}));
  literal.PopulateWithValue(value);
  return ConstantLiteral(literal);
}

inline ComputationDataHandle ComputationBuilder::ConstantR1(
    const tensorflow::core::Bitmap& values) {
  return ConstantLiteral(*Literal::CreateR1(values));
}

template <typename NativeT>
ComputationDataHandle ComputationBuilder::ConstantR2(
    std::initializer_list<std::initializer_list<NativeT>> values) {
  return ConstantLiteral(*Literal::CreateR2<NativeT>(values));
}

template <typename NativeT>
ComputationDataHandle ComputationBuilder::ConstantFromArrayWithLayout(
    const Array<NativeT>& values, const Layout& layout) {
  return ConstantLiteral(
      *Literal::CreateFromArrayWithLayout<NativeT>(values, layout));
}

template <typename NativeT>
ComputationDataHandle ComputationBuilder::ConstantFromArray(
    const Array<NativeT>& values) {
  return ConstantLiteral(*Literal::CreateFromArray<NativeT>(values));
}

template <typename NativeT>
ComputationDataHandle ComputationBuilder::ConstantR2FromArray2DWithLayout(
    const Array2D<NativeT>& values, const Layout& layout) {
  return ConstantLiteral(
      *Literal::CreateFromArrayWithLayout<NativeT>(values, layout));
}

template <typename NativeT>
ComputationDataHandle ComputationBuilder::ConstantR2FromArray2D(
    const Array2D<NativeT>& values) {
  return ConstantLiteral(*Literal::CreateR2FromArray2D<NativeT>(values));
}

template <typename NativeT>
ComputationDataHandle ComputationBuilder::ConstantR3FromArray3DWithLayout(
    const Array3D<NativeT>& values, const Layout& layout) {
  return ConstantLiteral(
      *Literal::CreateR3FromArray3DWithLayout<NativeT>(values, layout));
}

template <typename NativeT>
ComputationDataHandle ComputationBuilder::ConstantR3FromArray3D(
    const Array3D<NativeT>& values) {
  return ConstantFromArray(values);
}

template <typename NativeT>
ComputationDataHandle ComputationBuilder::ConstantR4FromArray4DWithLayout(
    const Array4D<NativeT>& values, const Layout& layout) {
  return ConstantFromArrayWithLayout(values, layout);
}

template <typename NativeT>
ComputationDataHandle ComputationBuilder::ConstantR4FromArray4D(
    const Array4D<NativeT>& values) {
  return ConstantFromArray(values);
}

// RAII-style object: sets the current sharding assignment in builder on
// construction, and sets back to the previous assignment on destruction.
class ScopedShardingAssignment {
 public:
  ScopedShardingAssignment(xla::ComputationBuilder* builder,
      tensorflow::gtl::optional<OpSharding> sharding)
      : builder_(builder), prev_sharding_(builder->sharding()) {
    SetSharding(sharding);
  }

  ~ScopedShardingAssignment() { SetSharding(prev_sharding_); }

 private:
  void SetSharding(const tensorflow::gtl::optional<OpSharding>& sharding) {
    if (sharding.has_value()) {
      builder_->SetSharding(sharding.value());
    } else {
      builder_->ClearSharding();
    }
  }

  xla::ComputationBuilder* const builder_;
  tensorflow::gtl::optional<OpSharding> prev_sharding_;

  TF_DISALLOW_COPY_AND_ASSIGN(ScopedShardingAssignment);
};
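
// Example (editorial sketch; `b` is a hypothetical ComputationBuilder,
// `spmd_sharding` an OpSharding proto, and `x`, `y` handles on `b`):
//
//   {
//     ScopedShardingAssignment scoped(&b, spmd_sharding);
//     b.Add(x, y);  // this instruction gets `spmd_sharding` attached
//   }  // the previous sharding (possibly none) is restored here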

}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_CLIENT_COMPUTATION_BUILDER_H_