18 #ifndef TENSORFLOW_COMPILER_XLA_CLIENT_COMPUTATION_BUILDER_H_ 19 #define TENSORFLOW_COMPILER_XLA_CLIENT_COMPUTATION_BUILDER_H_ 22 #include <initializer_list> 27 #include "tensorflow/compiler/xla/array.h" 28 #include "tensorflow/compiler/xla/array2d.h" 29 #include "tensorflow/compiler/xla/array3d.h" 30 #include "tensorflow/compiler/xla/array4d.h" 31 #include "tensorflow/compiler/xla/client/client.h" 32 #include "tensorflow/compiler/xla/client/computation.h" 33 #include "tensorflow/compiler/xla/client/global_data.h" 34 #include "tensorflow/compiler/xla/client/padding.h" 35 #include "tensorflow/compiler/xla/literal_util.h" 36 #include "tensorflow/compiler/xla/statusor.h" 37 #include "tensorflow/compiler/xla/types.h" 38 #include "tensorflow/compiler/xla/xla_data.pb.h" 39 #include "tensorflow/core/lib/core/bitmap.h" 40 #include "tensorflow/core/lib/core/stringpiece.h" 41 #include "tensorflow/core/lib/gtl/array_slice.h" 42 #include "tensorflow/core/platform/macros.h" 43 #include "tensorflow/core/platform/stacktrace.h" 44 #include "tensorflow/core/platform/types.h" 68 Client* client()
const {
return client_; }
71 const string& name()
const {
return name_; }
79 void SetOpMetadata(
const OpMetadata& metadata) { metadata_ = metadata; }
82 void ClearOpMetadata() { metadata_.Clear(); }
85 void SetSharding(
const OpSharding& sharding) { sharding_ = sharding; }
89 void ClearSharding() { sharding_ = tensorflow::gtl::nullopt; }
92 const tensorflow::gtl::optional<OpSharding>& sharding()
const {
99 void set_die_immediately_on_error(
bool enabled) {
100 die_immediately_on_error_ = enabled;
105 ComputationDataHandle
Parameter(int64 parameter_number,
const Shape& shape,
109 StatusOr<std::unique_ptr<Shape>> GetShape(
110 const ComputationDataHandle& operand);
113 StatusOr<ProgramShape> GetProgramShape();
117 ComputationDataHandle ConstantLiteral(
const Literal& literal);
135 template <
typename NativeT>
136 ComputationDataHandle ConstantR0(NativeT value);
137 template <
typename NativeT>
138 ComputationDataHandle ConstantR1(tensorflow::gtl::ArraySlice<NativeT> values);
139 ComputationDataHandle ConstantR1(
const tensorflow::core::Bitmap& values);
140 template <
typename NativeT>
141 ComputationDataHandle ConstantR2(
142 std::initializer_list<std::initializer_list<NativeT>> values);
143 template <
typename NativeT>
144 ComputationDataHandle ConstantFromArrayWithLayout(
145 const Array<NativeT>& values,
const Layout& layout);
146 template <
typename NativeT>
147 ComputationDataHandle ConstantFromArray(
const Array<NativeT>& values);
148 template <
typename NativeT>
149 ComputationDataHandle ConstantR2FromArray2DWithLayout(
150 const Array2D<NativeT>& values,
const Layout& layout);
151 template <
typename NativeT>
152 ComputationDataHandle ConstantR2FromArray2D(
const Array2D<NativeT>& values);
153 template <
typename NativeT>
154 ComputationDataHandle ConstantR3FromArray3DWithLayout(
155 const Array3D<NativeT>& values,
const Layout& layout);
156 template <
typename NativeT>
157 ComputationDataHandle ConstantR3FromArray3D(
const Array3D<NativeT>& values);
158 template <
typename NativeT>
159 ComputationDataHandle ConstantR4FromArray4DWithLayout(
160 const Array4D<NativeT>& values,
const Layout& layout);
161 template <
typename NativeT>
162 ComputationDataHandle ConstantR4FromArray4D(
const Array4D<NativeT>& values);
166 template <
typename NativeT>
167 ComputationDataHandle ConstantR1(int64 length, NativeT value);
179 ComputationDataHandle Broadcast(
180 const ComputationDataHandle& operand,
181 tensorflow::gtl::ArraySlice<int64> broadcast_sizes);
186 ComputationDataHandle Pad(
const ComputationDataHandle& operand,
187 const ComputationDataHandle& padding_value,
188 const PaddingConfig& padding_config);
195 ComputationDataHandle Reshape(
const ComputationDataHandle& operand,
196 tensorflow::gtl::ArraySlice<int64> dimensions,
197 tensorflow::gtl::ArraySlice<int64> new_sizes);
202 ComputationDataHandle Reshape(
const ComputationDataHandle& operand,
203 tensorflow::gtl::ArraySlice<int64> new_sizes);
223 ComputationDataHandle Collapse(
const ComputationDataHandle& operand,
224 tensorflow::gtl::ArraySlice<int64> dimensions);
237 ComputationDataHandle Slice(
const ComputationDataHandle& operand,
238 tensorflow::gtl::ArraySlice<int64> start_indices,
239 tensorflow::gtl::ArraySlice<int64> limit_indices,
240 tensorflow::gtl::ArraySlice<int64> strides);
248 ComputationDataHandle SliceInDim(
const ComputationDataHandle& operand,
249 int64 start_index, int64 limit_index,
250 int64 stride, int64 dimno);
261 ComputationDataHandle DynamicSlice(
262 const ComputationDataHandle& operand,
263 const ComputationDataHandle& start_indices,
264 tensorflow::gtl::ArraySlice<int64> slice_sizes);
282 ComputationDataHandle DynamicUpdateSlice(
283 const ComputationDataHandle& operand,
const ComputationDataHandle& update,
284 const ComputationDataHandle& start_indices);
288 ComputationDataHandle ConcatInDim(
289 tensorflow::gtl::ArraySlice<ComputationDataHandle> operands,
294 void Trace(
const string& tag,
const ComputationDataHandle& operand);
298 ComputationDataHandle Select(
const ComputationDataHandle& pred,
299 const ComputationDataHandle& on_true,
300 const ComputationDataHandle& on_false);
303 ComputationDataHandle Tuple(
304 tensorflow::gtl::ArraySlice<ComputationDataHandle> elements);
307 ComputationDataHandle GetTupleElement(
const ComputationDataHandle& tuple_data,
311 ComputationDataHandle Eq(
312 const ComputationDataHandle& lhs,
const ComputationDataHandle& rhs,
313 tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
316 ComputationDataHandle Ne(
317 const ComputationDataHandle& lhs,
const ComputationDataHandle& rhs,
318 tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
321 ComputationDataHandle Ge(
322 const ComputationDataHandle& lhs,
const ComputationDataHandle& rhs,
323 tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
326 ComputationDataHandle Gt(
327 const ComputationDataHandle& lhs,
const ComputationDataHandle& rhs,
328 tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
331 ComputationDataHandle Lt(
332 const ComputationDataHandle& lhs,
const ComputationDataHandle& rhs,
333 tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
336 ComputationDataHandle Le(
337 const ComputationDataHandle& lhs,
const ComputationDataHandle& rhs,
338 tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
341 ComputationDataHandle Dot(
const ComputationDataHandle& lhs,
342 const ComputationDataHandle& rhs);
345 ComputationDataHandle DotGeneral(
346 const ComputationDataHandle& lhs,
const ComputationDataHandle& rhs,
347 const DotDimensionNumbers& dimension_numbers);
350 static constexpr int64 kConvBatchDimension = 0;
351 static constexpr int64 kConvFeatureDimension = 1;
352 static constexpr int64 kConvFirstSpatialDimension = 2;
353 static constexpr int64 kConvSecondSpatialDimension = 3;
354 static constexpr int64 kConvKernelOutputDimension = 0;
355 static constexpr int64 kConvKernelInputDimension = 1;
356 static constexpr int64 kConvKernelFirstSpatialDimension = 2;
357 static constexpr int64 kConvKernelSecondSpatialDimension = 3;
363 static ConvolutionDimensionNumbers CreateDefaultConvDimensionNumbers(
364 int num_spatial_dims = 2);
368 static StatusOr<ConvolutionDimensionNumbers> CreateConvDimensionNumbers(
369 int64 input_batch, int64 input_feature, int64 input_first_spatial,
370 int64 input_second_spatial, int64 output_batch, int64 output_feature,
371 int64 output_first_spatial, int64 output_second_spatial,
372 int64 kernel_output_feature, int64 kernel_input_feature,
373 int64 kernel_first_spatial, int64 kernel_second_spatial);
377 ComputationDataHandle Conv(
const ComputationDataHandle& lhs,
378 const ComputationDataHandle& rhs,
379 tensorflow::gtl::ArraySlice<int64> window_strides,
384 ComputationDataHandle ConvWithGeneralPadding(
385 const ComputationDataHandle& lhs,
const ComputationDataHandle& rhs,
386 tensorflow::gtl::ArraySlice<int64> window_strides,
387 tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding);
391 ComputationDataHandle ConvWithGeneralDimensions(
392 const ComputationDataHandle& lhs,
const ComputationDataHandle& rhs,
393 tensorflow::gtl::ArraySlice<int64> window_strides, Padding padding,
394 const ConvolutionDimensionNumbers& dimension_numbers);
398 ComputationDataHandle ConvGeneral(
399 const ComputationDataHandle& lhs,
const ComputationDataHandle& rhs,
400 tensorflow::gtl::ArraySlice<int64> window_strides,
401 tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
402 const ConvolutionDimensionNumbers& dimension_numbers);
406 ComputationDataHandle ConvGeneralDilated(
407 const ComputationDataHandle& lhs,
const ComputationDataHandle& rhs,
408 tensorflow::gtl::ArraySlice<int64> window_strides,
409 tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
410 tensorflow::gtl::ArraySlice<int64> lhs_dilation,
411 tensorflow::gtl::ArraySlice<int64> rhs_dilation,
412 const ConvolutionDimensionNumbers& dimension_numbers);
416 ComputationDataHandle Fft(
const ComputationDataHandle& operand,
418 tensorflow::gtl::ArraySlice<int64> fft_length);
422 ComputationDataHandle Infeed(
const Shape& shape,
const string& config =
"");
430 void Outfeed(
const ComputationDataHandle& operand,
431 const Shape& shape_with_layout,
const string& outfeed_config);
434 ComputationDataHandle Call(
435 const Computation& computation,
436 tensorflow::gtl::ArraySlice<ComputationDataHandle> operands);
442 ComputationDataHandle CustomCall(
443 const string& call_target_name,
444 tensorflow::gtl::ArraySlice<ComputationDataHandle> operands,
453 ComputationDataHandle HostCompute(
454 tensorflow::gtl::ArraySlice<ComputationDataHandle> operands,
455 const string& channel_name, int64 cost_estimate_ns,
const Shape& shape);
463 ComputationDataHandle Complex(
464 const ComputationDataHandle& real,
const ComputationDataHandle& imag,
465 tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
468 ComputationDataHandle Conj(
const ComputationDataHandle& operand);
471 ComputationDataHandle Add(
472 const ComputationDataHandle& lhs,
const ComputationDataHandle& rhs,
473 tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
476 ComputationDataHandle Sub(
477 const ComputationDataHandle& lhs,
const ComputationDataHandle& rhs,
478 tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
481 ComputationDataHandle Mul(
482 const ComputationDataHandle& lhs,
const ComputationDataHandle& rhs,
483 tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
486 ComputationDataHandle Div(
487 const ComputationDataHandle& lhs,
const ComputationDataHandle& rhs,
488 tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
491 ComputationDataHandle Rem(
492 const ComputationDataHandle& lhs,
const ComputationDataHandle& rhs,
493 tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
496 ComputationDataHandle Max(
497 const ComputationDataHandle& lhs,
const ComputationDataHandle& rhs,
498 tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
501 ComputationDataHandle Min(
502 const ComputationDataHandle& lhs,
const ComputationDataHandle& rhs,
503 tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
506 ComputationDataHandle And(
507 const ComputationDataHandle& lhs,
const ComputationDataHandle& rhs,
508 tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
510 ComputationDataHandle Or(
511 const ComputationDataHandle& lhs,
const ComputationDataHandle& rhs,
512 tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
514 ComputationDataHandle Xor(
515 const ComputationDataHandle& lhs,
const ComputationDataHandle& rhs,
516 tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
518 ComputationDataHandle Not(
const ComputationDataHandle& operand);
520 ComputationDataHandle ShiftLeft(
521 const ComputationDataHandle& lhs,
const ComputationDataHandle& rhs,
522 tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
523 ComputationDataHandle ShiftRightArithmetic(
524 const ComputationDataHandle& lhs,
const ComputationDataHandle& rhs,
525 tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
526 ComputationDataHandle ShiftRightLogical(
527 const ComputationDataHandle& lhs,
const ComputationDataHandle& rhs,
528 tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
532 ComputationDataHandle Reduce(
533 const ComputationDataHandle& operand,
534 const ComputationDataHandle& init_value,
const Computation& computation,
535 tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce);
539 ComputationDataHandle ReduceAll(
const ComputationDataHandle& operand,
540 const ComputationDataHandle& init_value,
541 const Computation& computation);
544 ComputationDataHandle ReduceWindow(
545 const ComputationDataHandle& operand,
546 const ComputationDataHandle& init_value,
const Computation& computation,
547 tensorflow::gtl::ArraySlice<int64> window_dimensions,
548 tensorflow::gtl::ArraySlice<int64> window_strides, Padding padding);
552 ComputationDataHandle ReduceWindowWithGeneralPadding(
553 const ComputationDataHandle& operand,
554 const ComputationDataHandle& init_value,
const Computation& computation,
555 tensorflow::gtl::ArraySlice<int64> window_dimensions,
556 tensorflow::gtl::ArraySlice<int64> window_strides,
557 tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding);
561 ComputationDataHandle CrossReplicaSum(
const ComputationDataHandle& operand);
565 ComputationDataHandle SelectAndScatter(
566 const ComputationDataHandle& operand,
const Computation& select,
567 tensorflow::gtl::ArraySlice<int64> window_dimensions,
568 tensorflow::gtl::ArraySlice<int64> window_strides, Padding padding,
569 const ComputationDataHandle& source,
570 const ComputationDataHandle& init_value,
const Computation& scatter);
574 ComputationDataHandle SelectAndScatterWithGeneralPadding(
575 const ComputationDataHandle& operand,
const Computation& select,
576 tensorflow::gtl::ArraySlice<int64> window_dimensions,
577 tensorflow::gtl::ArraySlice<int64> window_strides,
578 tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
579 const ComputationDataHandle& source,
580 const ComputationDataHandle& init_value,
const Computation& scatter);
583 ComputationDataHandle Abs(
const ComputationDataHandle& operand);
586 ComputationDataHandle Atan2(
587 const ComputationDataHandle& y,
const ComputationDataHandle& x,
588 tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
591 ComputationDataHandle Exp(
const ComputationDataHandle& operand);
594 ComputationDataHandle Floor(
const ComputationDataHandle& operand);
597 ComputationDataHandle Ceil(
const ComputationDataHandle& operand);
601 ComputationDataHandle Round(
const ComputationDataHandle& operand);
604 ComputationDataHandle Log(
const ComputationDataHandle& operand);
607 ComputationDataHandle Sign(
const ComputationDataHandle& operand);
610 ComputationDataHandle Cos(
const ComputationDataHandle& operand);
613 ComputationDataHandle Sin(
const ComputationDataHandle& operand);
616 ComputationDataHandle Tanh(
const ComputationDataHandle& operand);
619 ComputationDataHandle Real(
const ComputationDataHandle& operand);
622 ComputationDataHandle Imag(
const ComputationDataHandle& operand);
627 ComputationDataHandle SqrtF32(
const ComputationDataHandle& operand);
632 ComputationDataHandle SquareF32(
const ComputationDataHandle& operand);
635 ComputationDataHandle Pow(
636 const ComputationDataHandle& lhs,
const ComputationDataHandle& rhs,
637 tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
643 ComputationDataHandle IsFinite(
const ComputationDataHandle& operand);
647 ComputationDataHandle ConvertElementType(
const ComputationDataHandle& operand,
648 PrimitiveType new_element_type);
654 ComputationDataHandle BitcastConvertType(
const ComputationDataHandle& operand,
655 PrimitiveType new_element_type);
663 ComputationDataHandle ReciprocalF32(
const ComputationDataHandle& operand);
666 ComputationDataHandle Neg(
const ComputationDataHandle& operand);
669 ComputationDataHandle Transpose(
670 const ComputationDataHandle& operand,
671 tensorflow::gtl::ArraySlice<int64> permutation);
676 ComputationDataHandle Rev(
const ComputationDataHandle& operand,
677 tensorflow::gtl::ArraySlice<int64> dimensions);
680 ComputationDataHandle Sort(
const ComputationDataHandle& operand);
683 ComputationDataHandle Clamp(
const ComputationDataHandle& min,
684 const ComputationDataHandle& operand,
685 const ComputationDataHandle& max);
688 ComputationDataHandle Map(
689 tensorflow::gtl::ArraySlice<ComputationDataHandle> operands,
690 const Computation& computation,
691 tensorflow::gtl::ArraySlice<int64> dimensions,
692 tensorflow::gtl::ArraySlice<ComputationDataHandle> static_operands = {});
696 ComputationDataHandle RngNormal(
const ComputationDataHandle& mu,
697 const ComputationDataHandle& sigma,
702 ComputationDataHandle RngUniform(
const ComputationDataHandle& a,
703 const ComputationDataHandle& b,
707 ComputationDataHandle While(
const Computation& condition,
708 const Computation& body,
709 const ComputationDataHandle& init);
712 ComputationDataHandle Conditional(
const ComputationDataHandle& predicate,
713 const ComputationDataHandle& true_operand,
714 const Computation& true_computation,
715 const ComputationDataHandle& false_operand,
716 const Computation& false_computation);
719 ComputationDataHandle ReducePrecision(
const ComputationDataHandle& operand,
720 const int exponent_bits,
721 const int mantissa_bits);
724 ComputationDataHandle Gather(
725 const ComputationDataHandle& input,
726 const ComputationDataHandle& gather_indices,
727 const GatherDimensionNumbers& dimension_numbers,
728 tensorflow::gtl::ArraySlice<int64> window_bounds);
732 void Send(
const ComputationDataHandle& operand,
const ChannelHandle& handle);
737 ComputationDataHandle Recv(
const Shape& shape,
const ChannelHandle& handle);
744 StatusOr<bool> IsConstant(
const ComputationDataHandle& operand,
745 int64 num_parameters = 0);
752 ComputationDataHandle BatchNormTraining(
const ComputationDataHandle& operand,
753 const ComputationDataHandle& scale,
754 const ComputationDataHandle& offset,
755 float epsilon, int64 feature_index);
767 ComputationDataHandle BatchNormInference(
768 const ComputationDataHandle& operand,
const ComputationDataHandle& scale,
769 const ComputationDataHandle& offset,
const ComputationDataHandle& mean,
770 const ComputationDataHandle& variance,
float epsilon,
771 int64 feature_index);
782 ComputationDataHandle BatchNormGrad(
const ComputationDataHandle& operand,
783 const ComputationDataHandle& scale,
784 const ComputationDataHandle& batch_mean,
785 const ComputationDataHandle& batch_var,
786 const ComputationDataHandle& grad_output,
787 float epsilon, int64 feature_index);
818 StatusOr<std::unique_ptr<Literal>> ComputeConstant(
819 const ComputationDataHandle& operand,
820 const Layout* output_layout =
nullptr,
821 tensorflow::gtl::ArraySlice<Literal> parameters = {});
826 std::unique_ptr<ComputationBuilder> CreateSubBuilder(
827 const string& computation_name);
834 Status SetReturnValue(
const ComputationDataHandle& operand);
838 StatusOr<Computation> Build();
848 Computation BuildAndNoteError();
856 Status first_error()
const {
return first_error_; }
861 bool VerifyConvolution(
const Shape& lhs_shape,
const Shape& rhs_shape,
862 const ConvolutionDimensionNumbers& dimension_numbers);
870 bool MakeWindow(tensorflow::gtl::ArraySlice<int64> window_dimensions,
871 tensorflow::gtl::ArraySlice<int64> window_strides,
872 tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
873 tensorflow::gtl::ArraySlice<int64> lhs_dilation,
874 tensorflow::gtl::ArraySlice<int64> rhs_dilation,
878 ComputationDataHandle UnaryOp(UnaryOperation unop,
879 const ComputationDataHandle& operand);
884 ComputationDataHandle BinaryOp(
885 BinaryOperation binop,
const ComputationDataHandle& lhs,
886 const ComputationDataHandle& rhs,
887 tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
890 ComputationDataHandle TernaryOp(TernaryOperation triop,
891 const ComputationDataHandle& lhs,
892 const ComputationDataHandle& rhs,
893 const ComputationDataHandle& ehs);
897 ComputationDataHandle RngOp(
898 RandomDistribution distribution,
899 tensorflow::gtl::ArraySlice<ComputationDataHandle> parameters,
910 void NoteError(
const Status& error);
915 Status
RunOp(OpRequest* op_request, OpResponse* op_response);
918 void RunOpAndNoteError(OpRequest* op_request);
926 StatusOr<std::unique_ptr<Shape>> GetShapeWithoutNoteError(
927 const ComputationDataHandle& operand);
936 tensorflow::SavedStackTrace first_error_backtrace_;
939 Computation computation_;
945 bool die_immediately_on_error_ =
false;
950 OpMetadata metadata_;
954 tensorflow::gtl::optional<OpSharding> sharding_;
959 template <
typename NativeT>
960 ComputationDataHandle ComputationBuilder::ConstantR0(NativeT value) {
961 return ConstantLiteral(*Literal::CreateR0<NativeT>(value));
964 template <
typename NativeT>
965 ComputationDataHandle ComputationBuilder::ConstantR1(
966 tensorflow::gtl::ArraySlice<NativeT> values) {
967 return ConstantLiteral(*Literal::CreateR1<NativeT>(values));
970 template <
typename NativeT>
971 ComputationDataHandle ComputationBuilder::ConstantR1(int64 length,
973 Literal literal(ShapeUtil::MakeShape(
974 primitive_util::NativeToPrimitiveType<NativeT>(), {length}));
975 literal.PopulateWithValue(value);
976 return ConstantLiteral(literal);
979 inline ComputationDataHandle ComputationBuilder::ConstantR1(
980 const tensorflow::core::Bitmap& values) {
981 return ConstantLiteral(*Literal::CreateR1(values));
984 template <
typename NativeT>
985 ComputationDataHandle ComputationBuilder::ConstantR2(
986 std::initializer_list<std::initializer_list<NativeT>> values) {
987 return ConstantLiteral(*Literal::CreateR2<NativeT>(values));
990 template <
typename NativeT>
991 ComputationDataHandle ComputationBuilder::ConstantFromArrayWithLayout(
992 const Array<NativeT>& values,
const Layout& layout) {
993 return ConstantLiteral(
994 *Literal::CreateFromArrayWithLayout<NativeT>(values, layout));
997 template <
typename NativeT>
998 ComputationDataHandle ComputationBuilder::ConstantFromArray(
999 const Array<NativeT>& values) {
1000 return ConstantLiteral(*Literal::CreateFromArray<NativeT>(values));
1003 template <
typename NativeT>
1004 ComputationDataHandle ComputationBuilder::ConstantR2FromArray2DWithLayout(
1005 const Array2D<NativeT>& values,
const Layout& layout) {
1006 return ConstantLiteral(
1007 *Literal::CreateFromArrayWithLayout<NativeT>(values, layout));
1010 template <
typename NativeT>
1011 ComputationDataHandle ComputationBuilder::ConstantR2FromArray2D(
1012 const Array2D<NativeT>& values) {
1013 return ConstantLiteral(*Literal::CreateR2FromArray2D<NativeT>(values));
1016 template <
typename NativeT>
1017 ComputationDataHandle ComputationBuilder::ConstantR3FromArray3DWithLayout(
1018 const Array3D<NativeT>& values,
const Layout& layout) {
1019 return ConstantLiteral(
1020 *Literal::CreateR3FromArray3DWithLayout<NativeT>(values, layout));
1023 template <
typename NativeT>
1024 ComputationDataHandle ComputationBuilder::ConstantR3FromArray3D(
1025 const Array3D<NativeT>& values) {
1026 return ConstantFromArray(values);
1029 template <
typename NativeT>
1030 ComputationDataHandle ComputationBuilder::ConstantR4FromArray4DWithLayout(
1031 const Array4D<NativeT>& values,
const Layout& layout) {
1032 return ConstantFromArrayWithLayout(values, layout);
1035 template <
typename NativeT>
1036 ComputationDataHandle ComputationBuilder::ConstantR4FromArray4D(
1037 const Array4D<NativeT>& values) {
1038 return ConstantFromArray(values);
1043 class ScopedShardingAssignment {
1046 tensorflow::gtl::optional<OpSharding> sharding)
1047 : builder_(builder), prev_sharding_(builder->sharding()) {
1048 SetSharding(sharding);
1051 ~ScopedShardingAssignment() { SetSharding(prev_sharding_); }
1054 void SetSharding(
const tensorflow::gtl::optional<OpSharding>& sharding) {
1055 if (sharding.has_value()) {
1056 builder_->SetSharding(sharding.value());
1058 builder_->ClearSharding();
1063 tensorflow::gtl::optional<OpSharding> prev_sharding_;
1065 TF_DISALLOW_COPY_AND_ASSIGN(ScopedShardingAssignment);
1070 #endif // TENSORFLOW_COMPILER_XLA_CLIENT_COMPUTATION_BUILDER_H_ ComputationDataHandle Parameter(int64 parameter_number, const Shape &shape, const string &name)
Enqueues a "retrieve parameter value" instruction to the UserComputation.
Definition: computation_builder.cc:228
Definition: computation_builder.h:59
Status RunOp(OpRequest *op_request, OpResponse *op_response)
Run the given parameter op_request and fill in op_response.
Definition: computation_builder.cc:109
ComputationDataHandle RunOpAndParseResponse(OpRequest *op_request)
Call RunOp() and either return the output ComputationDataHandle (on success) or an empty ComputationD...
Definition: computation_builder.cc:142
Status PrepareComputation()
Populates computation_ with a valid object. Used before any given operation is enqueued.
Definition: computation_builder.cc:79
namespace for xla
Definition: client_library.cc:26