23 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INSTRUCTION_H_ 24 #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INSTRUCTION_H_ 33 #include <unordered_map> 34 #include <unordered_set> 37 #include "tensorflow/compiler/xla/iterator_util.h" 38 #include "tensorflow/compiler/xla/literal_util.h" 39 #include "tensorflow/compiler/xla/map_util.h" 40 #include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h" 41 #include "tensorflow/compiler/xla/service/hlo.pb.h" 42 #include "tensorflow/compiler/xla/service/hlo_opcode.h" 43 #include "tensorflow/compiler/xla/service/hlo_sharding.h" 44 #include "tensorflow/compiler/xla/service/name_uniquer.h" 45 #include "tensorflow/compiler/xla/types.h" 46 #include "tensorflow/compiler/xla/xla_data.pb.h" 47 #include "tensorflow/core/lib/core/status.h" 48 #include "tensorflow/core/lib/core/stringpiece.h" 49 #include "tensorflow/core/lib/gtl/array_slice.h" 50 #include "tensorflow/core/lib/gtl/flatmap.h" 51 #include "tensorflow/core/lib/gtl/inlined_vector.h" 52 #include "tensorflow/core/lib/gtl/iterator_range.h" 53 #include "tensorflow/core/platform/logging.h" 54 #include "tensorflow/core/platform/macros.h" 55 #include "tensorflow/core/platform/types.h" 63 class HloPrintOptions {
// Fragment of the HloPrintOptions constructor's member-initializer list (the
// constructor header and any further initializers are not visible in this
// listing). Defaults shown here: large constants and compact operands are off;
// subcomputation references, metadata, operand shapes and the program shape
// are on.
68 : print_large_constants_(false),
69 print_subcomputation_references_(true),
70 print_metadata_(true),
71 compact_operands_(false),
72 print_operand_shape_(true),
73 print_program_shape_(true),
// Returns a default-constructed HloPrintOptions with a chain of setters
// applied: large constants and subcomputation references are enabled, while
// metadata, operand shapes, the program shape and print_percent are disabled.
// NOTE(review): the name suggests the resulting text is short yet still
// parsable — confirm against the printer/parser implementation.
77 static HloPrintOptions ShortParsable() {
78 return HloPrintOptions()
79 .set_print_large_constants(
true)
80 .set_print_subcomputation_references(
true)
81 .set_print_metadata(
false)
82 .set_print_operand_shape(
false)
83 .set_print_program_shape(
false)
84 .set_print_percent(
false);
// Setters. Each one stores `value` in the field of the same name. All of them
// are declared to return HloPrintOptions&, which is what allows the call
// chaining used by ShortParsable(). The return statements and closing braces
// are elided in this listing.

// Stores the flag controlling whether large constants are printed.
88 HloPrintOptions& set_print_large_constants(
bool value) {
89 print_large_constants_ = value;
// Stores the flag controlling whether subcomputation references are printed.
99 HloPrintOptions& set_print_subcomputation_references(
bool value) {
100 print_subcomputation_references_ = value;
// Stores the flag controlling whether metadata is printed.
105 HloPrintOptions& set_print_metadata(
bool value) {
106 print_metadata_ = value;
// Stores the flag controlling whether operand shapes are printed.
111 HloPrintOptions& set_print_operand_shape(
bool value) {
112 print_operand_shape_ = value;
// Stores the flag controlling whether the program shape is printed.
117 HloPrintOptions& set_print_program_shape(
bool value) {
118 print_program_shape_ = value;
// Stores the print_percent flag.
// NOTE(review): presumably toggles a '%' prefix on names — confirm.
123 HloPrintOptions& set_print_percent(
bool value) {
124 print_percent_ = value;
// Stores the compact_operands flag.
130 HloPrintOptions& set_compact_operands(
bool value) {
131 compact_operands_ = value;
// Stores the indentation amount (an int; units are not visible here).
136 HloPrintOptions& set_indent_amount(
int value) {
137 indent_amount_ = value;
// Read-only accessors; each returns the current value of the matching field.
141 bool print_large_constants()
const {
return print_large_constants_; }
// (Closing brace of this accessor is elided in the listing.)
142 bool print_subcomputation_references()
const {
143 return print_subcomputation_references_;
145 bool print_metadata()
const {
return print_metadata_; }
146 bool compact_operands()
const {
return compact_operands_; }
147 bool print_operand_shape()
const {
return print_operand_shape_; }
148 bool print_program_shape()
const {
return print_program_shape_; }
149 bool print_percent()
const {
return print_percent_; }
150 int indent_amount()
const {
return indent_amount_; }
// Backing storage for the print options. The declarations of print_percent_
// and indent_amount_, which the accessors above use, are not visible in this
// listing.
153 bool print_large_constants_;
154 bool print_subcomputation_references_;
155 bool print_metadata_;
156 bool compact_operands_;
157 bool print_operand_shape_;
158 bool print_program_shape_;
// Static factory that builds an instruction from `proto` for `module`.
// Returns a StatusOr wrapping the owning pointer to the new instruction.
// `instruction_map` and `computation_map` are keyed by int64.
// NOTE(review): presumably the keys are the ids the proto uses to refer to
// operands and called computations — confirm in the .cc file.
194 static StatusOr<std::unique_ptr<HloInstruction>> CreateFromProto(
195 HloModule* module,
const HloInstructionProto& proto,
196 const tensorflow::gtl::FlatMap<int64, HloInstruction*>& instruction_map,
197 const tensorflow::gtl::FlatMap<int64, HloComputation*>& computation_map);
// Static factories. Each returns an owning std::unique_ptr<HloInstruction>.
// Several signatures below are truncated in this listing; only the visible
// parameters are described.

// Parameter instruction factory; takes the parameter number (remaining
// parameters are elided).
200 static std::unique_ptr<HloInstruction> CreateParameter(int64 parameter_number,
// Constant instruction factory; takes ownership of `literal`.
205 static std::unique_ptr<HloInstruction> CreateConstant(
206 std::unique_ptr<Literal> literal);
// Get-tuple-element factory (parameter list elided).
209 static std::unique_ptr<HloInstruction> CreateGetTupleElement(
// Trace instruction factory; takes a string tag (remaining parameters elided).
213 static std::unique_ptr<HloInstruction> CreateTrace(
const string& tag,
// RNG instruction factory for the given result shape, distribution and
// parameter instructions.
218 static std::unique_ptr<HloInstruction> CreateRng(
219 const Shape& shape, RandomDistribution distribution,
220 tensorflow::gtl::ArraySlice<HloInstruction*> parameters);
// Unary / binary / ternary op factories (signatures truncated after `shape`).
224 static std::unique_ptr<HloInstruction> CreateUnary(
const Shape& shape,
230 static std::unique_ptr<HloInstruction> CreateBinary(
const Shape& shape,
237 static std::unique_ptr<HloInstruction> CreateTernary(
const Shape& shape,
// Factory for an op of the given opcode with an arbitrary number of operands.
245 static std::unique_ptr<HloInstruction> CreateVariadic(
246 const Shape& shape, HloOpcode opcode,
247 tensorflow::gtl::ArraySlice<HloInstruction*> operands);
// Map factory. `static_operands` defaults to an empty slice. One parameter
// line between `operands` and `static_operands` is elided in this listing.
252 static std::unique_ptr<HloInstruction> CreateMap(
253 const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
255 tensorflow::gtl::ArraySlice<HloInstruction*> static_operands = {});
// Convolution factory over `lhs` and `rhs` with the given window and
// convolution dimension numbers.
259 static std::unique_ptr<HloInstruction> CreateConvolve(
260 const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
261 const Window& window,
262 const ConvolutionDimensionNumbers& dimension_numbers);
// FFT factory; takes the FFT type and the FFT lengths.
265 static std::unique_ptr<HloInstruction> CreateFft(
266 const Shape& shape, HloInstruction* operand, FftType fft_type,
267 tensorflow::gtl::ArraySlice<int64> fft_length);
// Dot factory with explicit dot dimension numbers.
271 static std::unique_ptr<HloInstruction> CreateDot(
272 const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
273 const DotDimensionNumbers& dimension_numbers);
// Dot factory that takes no dimension numbers.
// NOTE(review): presumably fills in canonical dimension numbers — confirm.
278 static std::unique_ptr<HloInstruction> CreateCanonicalDot(
279 const Shape& shape, HloInstruction* lhs, HloInstruction* rhs);
// Reduce-precision factory; takes exponent and mantissa bit counts.
284 static std::unique_ptr<HloInstruction> CreateReducePrecision(
285 const Shape& shape, HloInstruction* operand,
const int exponent_bits,
286 const int mantissa_bits);
// Cross-replica-sum factory (first parameter line elided in this listing).
289 static std::unique_ptr<HloInstruction> CreateCrossReplicaSum(
291 tensorflow::gtl::ArraySlice<HloInstruction*> operands);
// Conversion factories; both take the result shape and a single operand.
295 static std::unique_ptr<HloInstruction> CreateConvert(
const Shape& shape,
296 HloInstruction* operand);
300 static std::unique_ptr<HloInstruction> CreateBitcastConvert(
301 const Shape& shape, HloInstruction* operand);
// Infeed factory; `config` is passed as a string.
305 static std::unique_ptr<HloInstruction> CreateInfeed(
const Shape& shape,
306 const string& config);
// Outfeed factory; takes the operand to send and an outfeed config string.
309 static std::unique_ptr<HloInstruction> CreateOutfeed(
310 const Shape& shape, HloInstruction* operand,
311 tensorflow::StringPiece outfeed_config);
// Send factory (signature truncated after `operand`).
316 static std::unique_ptr<HloInstruction> CreateSend(HloInstruction* operand,
// Send-done factory; takes a single operand.
// NOTE(review): presumably the operand must be the matching send — confirm.
321 static std::unique_ptr<HloInstruction> CreateSendDone(
322 HloInstruction* operand);
// Recv factory (signature truncated after `shape`).
327 static std::unique_ptr<HloInstruction> CreateRecv(
const Shape& shape,
// Recv-done factory; takes a single operand.
332 static std::unique_ptr<HloInstruction> CreateRecvDone(
333 HloInstruction* operand);
// Slice factory with per-dimension start indices, limit indices and strides.
337 static std::unique_ptr<HloInstruction> CreateSlice(
338 const Shape& shape, HloInstruction* operand,
339 tensorflow::gtl::ArraySlice<int64> start_indices,
340 tensorflow::gtl::ArraySlice<int64> limit_indices,
341 tensorflow::gtl::ArraySlice<int64> strides);
// Dynamic-slice factory: start indices come from an instruction, slice sizes
// are given as integers.
346 static std::unique_ptr<HloInstruction> CreateDynamicSlice(
347 const Shape& shape, HloInstruction* operand,
348 HloInstruction* start_indices,
349 tensorflow::gtl::ArraySlice<int64> slice_sizes);
// Dynamic-update-slice factory taking `operand`, `update` and `start_indices`.
353 static std::unique_ptr<HloInstruction> CreateDynamicUpdateSlice(
354 const Shape& shape, HloInstruction* operand, HloInstruction* update,
355 HloInstruction* start_indices);
// Concatenate factory (signature truncated after `operands`).
359 static std::unique_ptr<HloInstruction> CreateConcatenate(
360 const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
// Reduce factory: reduces `operand` over `dimensions_to_reduce` starting from
// `init_value`, using `reduce_computation`.
370 static std::unique_ptr<HloInstruction> CreateReduce(
371 const Shape& shape, HloInstruction* operand, HloInstruction* init_value,
372 tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce,
373 HloComputation* reduce_computation);
// Reduce-window factory; like CreateReduce but parameterized by a Window
// rather than a dimension list.
378 static std::unique_ptr<HloInstruction> CreateReduceWindow(
379 const Shape& shape, HloInstruction* operand, HloInstruction* init_value,
380 const Window& window, HloComputation* reduce_computation);
// Batch-norm factories. All take an epsilon and a feature index in addition
// to their instruction operands.
383 static std::unique_ptr<HloInstruction> CreateBatchNormTraining(
384 const Shape& shape, HloInstruction* operand, HloInstruction* scale,
385 HloInstruction* offset,
float epsilon, int64 feature_index);
388 static std::unique_ptr<HloInstruction> CreateBatchNormInference(
389 const Shape& shape, HloInstruction* operand, HloInstruction* scale,
390 HloInstruction* offset, HloInstruction* mean, HloInstruction* variance,
391 float epsilon, int64 feature_index);
394 static std::unique_ptr<HloInstruction> CreateBatchNormGrad(
395 const Shape& shape, HloInstruction* operand, HloInstruction* scale,
396 HloInstruction* mean, HloInstruction* variance,
397 HloInstruction* grad_output,
float epsilon, int64 feature_index);
// Select-and-scatter factory taking two computations (`select`, `scatter`),
// a window, a source and an initial value.
401 static std::unique_ptr<HloInstruction> CreateSelectAndScatter(
402 const Shape& shape, HloInstruction* operand, HloComputation* select,
403 const Window& window, HloInstruction* source, HloInstruction* init_value,
404 HloComputation* scatter);
// Broadcast factory with explicit broadcast dimensions.
407 static std::unique_ptr<HloInstruction> CreateBroadcast(
408 const Shape& shape, HloInstruction* operand,
409 tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
// Broadcast variant taking only a shape and an operand.
// NOTE(review): semantics of "DimOne" are not visible here — confirm.
412 static std::unique_ptr<HloInstruction> CreateBroadcastDimOne(
413 const Shape& shape, HloInstruction* operand);
// Broadcast-sequence factory; takes a callback that receives an owned
// instruction and returns a raw pointer (the callback's parameter name and
// the end of the signature are elided).
424 static std::unique_ptr<HloInstruction> CreateBroadcastSequence(
425 const Shape& output_shape, HloInstruction* operand,
426 const std::function<HloInstruction*(std::unique_ptr<HloInstruction>)>&
// Pad factory: pads `operand` with `padding_value` according to
// `padding_config`.
431 static std::unique_ptr<HloInstruction> CreatePad(
432 const Shape& shape, HloInstruction* operand,
433 HloInstruction* padding_value,
const PaddingConfig& padding_config);
// Reshape factory; takes the result shape and a single operand.
437 static std::unique_ptr<HloInstruction> CreateReshape(
const Shape& shape,
438 HloInstruction* operand);
// Transpose factory with an explicit dimension list.
441 static std::unique_ptr<HloInstruction> CreateTranspose(
442 const Shape& shape, HloInstruction* operand,
443 tensorflow::gtl::ArraySlice<int64> dimensions);
// While factory taking condition and body computations plus an initial value.
450 static std::unique_ptr<HloInstruction> CreateWhile(
const Shape& shape,
451 HloComputation* condition,
452 HloComputation* body,
453 HloInstruction* init);
// Conditional factory: a predicate plus an (argument, computation) pair for
// each of the true and false branches.
455 static std::unique_ptr<HloInstruction> CreateConditional(
456 const Shape& shape, HloInstruction* pred,
457 HloInstruction* true_computation_arg, HloComputation* true_computation,
458 HloInstruction* false_computation_arg, HloComputation* false_computation);
// Gather factory taking gather indices, gather dimension numbers and window
// bounds.
460 static std::unique_ptr<HloInstruction> CreateGather(
461 const Shape& shape, HloInstruction* operand,
462 HloInstruction* gather_indices,
463 const GatherDimensionNumbers& gather_dim_numbers,
464 tensorflow::gtl::ArraySlice<int64> window_bounds);
// Parameter fragments of two declarations whose first lines are elided in
// this listing. NOTE(review): these look like CreateFusion overloads (one
// taking a fused root, one taking operands plus a fusion computation) —
// confirm against the full header.
473 const Shape& shape,
FusionKind fusion_kind, HloInstruction* fused_root);
477 tensorflow::gtl::ArraySlice<HloInstruction*> operands,
478 HloComputation* fusion_computation);
// Call factory: operands plus the computation to call.
482 static std::unique_ptr<HloInstruction> CreateCall(
483 const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
484 HloComputation* computation);
// Custom-call factory: operands plus the target name.
488 static std::unique_ptr<HloInstruction> CreateCustomCall(
489 const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
490 tensorflow::StringPiece custom_call_target);
// Host-compute factory: operands, a channel name and a cost estimate in
// nanoseconds.
494 static std::unique_ptr<HloInstruction> CreateHostCompute(
495 const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
496 tensorflow::StringPiece channel_name,
const int64 cost_estimate_ns);
// Tuple factory; takes only the element instructions (no explicit shape).
500 static std::unique_ptr<HloInstruction> CreateTuple(
501 tensorflow::gtl::ArraySlice<HloInstruction*> elements);
// Reverse factory with the list of dimensions.
505 static std::unique_ptr<HloInstruction> CreateReverse(
506 const Shape& shape, HloInstruction* operand,
507 tensorflow::gtl::ArraySlice<int64> dimensions);
// Helper that assembles a GatherDimensionNumbers value from its components.
// Unlike the factories above it returns a plain value, not an instruction.
510 static GatherDimensionNumbers MakeGatherDimNumbers(
511 tensorflow::gtl::ArraySlice<int64> output_window_dims,
512 tensorflow::gtl::ArraySlice<int64> elided_window_dims,
513 tensorflow::gtl::ArraySlice<int64> gather_dims_to_operand_dims,
514 int64 index_vector_dim);
// Returns this instruction's opcode.
517 HloOpcode opcode()
const {
return opcode_; }
// Declared here, defined elsewhere.
522 bool HasSideEffect()
const;
// Read-only access to the instruction's shape (defined elsewhere).
525 const Shape& shape()
const;
// Returns a mutable pointer to the shape_ member.
528 Shape* mutable_shape() {
return &shape_; }
// Accessors for the i-th operand (defined elsewhere).
531 const HloInstruction* operand(int64 i)
const;
534 HloInstruction* mutable_operand(int64 i);
// Number of operands, i.e. operands_.size().
537 int64 operand_count()
const {
return operands_.size(); }
// Operand container type: an InlinedVector of instruction pointers with
// inline capacity 2.
540 using InstructionVector = tensorflow::gtl::InlinedVector<HloInstruction*, 2>;
// Returns the operand vector by const reference.
541 const InstructionVector& operands()
const {
return operands_; }
// Declared here, defined elsewhere.
// NOTE(review): presumably returns the position of `target` among the
// operands — confirm behavior when `target` is absent.
545 int64 operand_index(
const HloInstruction* target)
const;
// Number of users, i.e. users_.size().
548 int64 user_count()
const {
return users_.size(); }
// Returns the user vector by const reference.
551 const std::vector<HloInstruction*>& users()
const {
return users_; }
// True when `this` is present in `instruction`'s user_set_. Dereferences
// `instruction`, so it must not be null. (Closing brace elided in listing.)
554 bool IsUserOf(
const HloInstruction* instruction)
const {
555 return ContainsKey(instruction->user_set_,
this);
// Add / remove a control dependency involving `instruction`; both report the
// outcome through Status (defined elsewhere).
567 Status AddControlDependencyTo(HloInstruction* instruction);
571 Status RemoveControlDependencyTo(HloInstruction* instruction);
// Return the control predecessor / successor lists by const reference.
// (Closing braces elided in this listing.)
576 const std::vector<HloInstruction*>& control_predecessors()
const {
577 return control_predecessors_;
579 const std::vector<HloInstruction*>& control_successors()
const {
580 return control_successors_;
// Parameter list and body of a comparison member whose first line is elided
// in this listing (NOTE(review): this looks like `Identical` — confirm).
//
// Parameters:
//   other            - instruction to compare against.
//   eq_operands      - predicate for operand pairs; defaults to pointer
//                      equality.
//   eq_computations  - predicate for computation pairs; defaults to pointer
//                      equality.
//   layout_sensitive - selects the shape predicate; defaults to true.
//
// The visible checks run in this order: self-comparison, opcode, shape,
// operand count, each operand pair, and finally IdenticalSlowPath(). The
// statements executed inside each `if` are elided here.
585 const HloInstruction& other,
586 const std::function<
bool(
const HloInstruction*,
const HloInstruction*)>&
587 eq_operands = std::equal_to<const HloInstruction*>(),
588 const std::function<
bool(
const HloComputation*,
const HloComputation*)>&
589 eq_computations = std::equal_to<const HloComputation*>(),
590 bool layout_sensitive =
true)
const {
// Comparing an instruction with itself.
592 if (
this == &other) {
598 if (opcode() != other.opcode()) {
// ShapeUtil::Equal is used when layout-sensitive, ShapeUtil::Compatible
// otherwise.
601 using EqShapeFuncType = bool (*)(
const Shape&,
const Shape&);
602 EqShapeFuncType eq_shapes =
603 layout_sensitive ? ShapeUtil::Equal : ShapeUtil::Compatible;
604 if (!eq_shapes(shape(), other.shape())) {
607 if (operands().size() != other.operands().size()) {
// Operands are compared pairwise, position by position.
613 for (
size_t i = 0; i < operands().size(); ++i) {
614 if (!eq_operands(operand(i), other.operand(i))) {
// Remaining comparison is delegated to the out-of-line slow path.
619 return IdenticalSlowPath(other, eq_computations, eq_shapes);
// Queries declared here and defined elsewhere.
623 bool HasConstantOperand()
const;
626 bool IsRank2Transpose()
const;
// Rewiring operations; all are defined elsewhere and those returning Status
// report failure through it.
// ReplaceUseWith targets a single `user`; ReplaceOperandWith targets the
// operand slot `operand_no`; ReplaceAllUsesWith takes only the new producer.
631 Status ReplaceUseWith(HloInstruction* user, HloInstruction* new_producer);
634 Status ReplaceOperandWith(int64 operand_no, HloInstruction* new_operand);
642 Status ReplaceAllUsesWith(HloInstruction* new_producer);
// NOTE(review): presumably removes this instruction from its operands' user
// lists — confirm in the .cc file.
647 void DetachFromOperands();
// Visitor entry point, templated on the visitor's instruction-pointer type.
// `call_finish_visit` defaults to true and `ignore_control_predecessors`
// defaults to false.
655 template <
typename HloInstructionPtr>
656 Status Accept(DfsHloVisitorBase<HloInstructionPtr>* visitor,
657 bool call_finish_visit =
true,
658 bool ignore_control_predecessors =
false);
// Const overload for ConstDfsHloVisitor: casts away constness of `this` and
// forwards all three arguments to the templated Accept above.
// (Closing brace elided in this listing.)
659 Status Accept(ConstDfsHloVisitor* visitor,
bool call_finish_visit =
true,
660 bool ignore_control_predecessors =
false)
const {
661 return const_cast<HloInstruction*
>(
this)->Accept(
662 visitor, call_finish_visit, ignore_control_predecessors);
// Binary predicate over two instructions, used as an operand ordering.
668 using CompareFunction =
669 std::function<bool(const HloInstruction*, const HloInstruction*)>;
// Accept variant that additionally takes an operand ordering predicate.
670 Status AcceptWithOperandOrder(DfsHloVisitor* visitor,
671 const CompareFunction& operand_order,
672 bool call_finish_visit =
true);
// Accept overload taking a plain callback instead of a visitor object.
676 Status Accept(
const std::function<Status(HloInstruction*)>& visitor_func);
// Parameter of a const overload whose first line is elided in this listing;
// its callback receives const instructions.
678 const std::function<Status(
const HloInstruction*)>& visitor_func)
const;
// Accept variant that takes an explicit instruction order.
687 Status AcceptOrdered(DfsHloVisitor* visitor,
688 const std::vector<const HloInstruction*>& order);
// Templated on the same pointer type as Accept; takes only the visitor.
// NOTE(review): presumably dispatches on this instruction alone — confirm.
691 template <
typename HloInstructionPtr>
692 Status Visit(DfsHloVisitorBase<HloInstructionPtr>* visitor);
// Read-only access to a Literal (defined elsewhere).
697 const Literal& literal()
const;
// Returns parameter_number_. CHECK-fails unless the opcode is kParameter.
// (Closing brace elided in this listing.)
702 int64 parameter_number()
const {
703 CHECK_EQ(HloOpcode::kParameter, opcode_);
704 return parameter_number_;
// Dimension accessors: the whole list, or the entry at `index`.
711 const std::vector<int64>& dimensions()
const;
712 int64 dimensions(int64 index)
const;
// Declared here, defined elsewhere.
716 int64 concatenate_dimension()
const;
721 int64 tuple_index()
const;
// Const overload; returns an (instruction, ShapeIndex) pair. The rest of the
// declaration is elided in this listing.
726 std::pair<const HloInstruction*, ShapeIndex> LatestNonGteAncestorAndIndex()
// Non-const overload: calls the const overload and strips const from the
// returned instruction pointer. The line declaring `rv` is elided.
729 std::pair<HloInstruction*, ShapeIndex> LatestNonGteAncestorAndIndex() {
731 const_cast<const HloInstruction*
>(
this)->LatestNonGteAncestorAndIndex();
732 return {
const_cast<HloInstruction*
>(rv.first), rv.second};
// Const overload (defined elsewhere).
736 const HloInstruction* LatestNonGteAncestor()
const;
// Non-const overload: forwards to the const overload and strips const from
// the result. (Closing brace elided in this listing.)
738 HloInstruction* LatestNonGteAncestor() {
739 return const_cast<HloInstruction*
>(
740 const_cast<const HloInstruction*
>(
this)->LatestNonGteAncestor());
// Getter/setter pairs for computations attached to this instruction. All are
// defined elsewhere.
// NOTE(review): each pair presumably applies only to particular opcodes
// (e.g. the while_* pair to while instructions) — confirm the preconditions.
747 HloComputation* to_apply()
const;
748 void set_to_apply(HloComputation* to_apply);
// String / shape accessors (defined elsewhere).
752 const string& custom_call_target()
const;
756 const string& outfeed_config()
const;
760 const Shape& outfeed_shape()
const;
// While condition and body computations.
766 HloComputation* while_condition()
const;
767 HloComputation* while_body()
const;
768 void set_while_condition(HloComputation* while_condition);
769 void set_while_body(HloComputation* while_body);
// Select and scatter computations.
775 HloComputation* select()
const;
776 HloComputation* scatter()
const;
777 void set_select(HloComputation* select);
778 void set_scatter(HloComputation* scatter);
// True and false branch computations.
784 HloComputation* true_computation()
const;
785 HloComputation* false_computation()
const;
786 void set_true_computation(HloComputation* true_computation);
787 void set_false_computation(HloComputation* false_computation);
// Declared here, defined elsewhere.
791 string SignatureString()
const;
// Convenience overload: prints with default-constructed HloPrintOptions.
801 string ToString()
const {
return ToString(HloPrintOptions()); }
// Printing routines parameterized by HloPrintOptions (defined elsewhere).
802 string ToString(
const HloPrintOptions& options)
const;
807 string OperandsToString(
const HloPrintOptions& options)
const;
810 std::vector<string> ExtraAttributesToString(
811 const HloPrintOptions& options)
const;
814 string ToShortString()
const;
// Serializes this instruction to an HloInstructionProto (defined elsewhere).
817 HloInstructionProto ToProto()
const;
821 string ToCategory()
const;
// Getter/setter for an associated trace instruction (defined elsewhere).
826 HloInstruction* tracing()
const;
827 void set_tracing(HloInstruction* trace_instruction);
// Trivial accessors returning the corresponding member. channel_name() and
// infeed_config() return the string by value, i.e. a copy per call.
834 int64 channel_id()
const {
return channel_id_; }
840 string channel_name()
const {
return channel_name_; }
847 int64 feature_index()
const {
return feature_index_; }
854 float epsilon()
const {
return epsilon_; }
859 string infeed_config()
const {
return infeed_config_; }
// Overwrites the stored infeed config string.
860 void set_infeed_config(
const string& config) { infeed_config_ = config; }
// Declared here, defined elsewhere.
865 string TracingTag()
const;
// Queries declared here and defined elsewhere.
868 bool IsConstant()
const;
872 bool IsFused()
const;
877 HloComputation* fused_instructions_computation()
const;
881 bool IsFusable()
const;
887 HloInstruction* fused_expression_root()
const;
// Const and non-const ranges over a std::list of owned instructions, exposed
// through UnwrappingIterator.
// NOTE(review): presumably the iterator yields raw HloInstruction pointers —
// confirm in iterator_util.h.
893 const tensorflow::gtl::iterator_range<UnwrappingIterator<
894 std::list<std::unique_ptr<HloInstruction>>::const_iterator>>
895 fused_instructions()
const;
897 const tensorflow::gtl::iterator_range<
898 UnwrappingIterator<std::list<std::unique_ptr<HloInstruction>>::iterator>>
899 fused_instructions();
904 int64 fused_instruction_count()
const;
910 HloInstruction* fused_parameter(int64 parameter_number)
const;
915 const std::vector<HloInstruction*>& fused_parameters()
const;
// True when this is a fusion instruction whose fused expression root is a
// tuple. The && short-circuits, so fused_expression_root() is only called for
// fusion instructions.
// NOTE(review): the top-level `const` on the by-value bool return type has no
// effect; consider dropping it.
919 const bool IsMultiOutputFusion()
const {
920 return opcode() == HloOpcode::kFusion &&
921 fused_expression_root()->opcode() == HloOpcode::kTuple;
// Opcode guards from two members whose declarations are elided in this
// listing; each CHECK-fails unless the opcode is kFusion.
925 CHECK_EQ(HloOpcode::kFusion, opcode_);
930 CHECK_EQ(HloOpcode::kFusion, opcode_);
// Returns the sharding; CHECK-fails when none is set. (The return statement
// and closing brace are elided in this listing.)
936 const HloSharding& sharding()
const {
937 CHECK(has_sharding());
// Returns the stored sharding if present, otherwise `default_`. The result
// is a reference, so `default_` must outlive the use of the returned value.
941 const HloSharding& sharding_or_default(
const HloSharding& default_)
const {
942 return sharding_ ? *sharding_ : default_;
// Returns the sharding's unique device, or an empty optional when there is
// no sharding or the sharding reports no unique device.
945 tensorflow::gtl::optional<int64> sharding_unique_device()
const {
946 if (sharding_ ==
nullptr || !sharding_->HasUniqueDevice()) {
947 return tensorflow::gtl::optional<int64>();
// Reached only after HasUniqueDevice() returned true above.
949 return sharding_->UniqueDevice().ValueOrDie();
// Stores a heap-allocated copy of `sharding`, replacing any previous one.
953 void set_sharding(
const HloSharding& sharding) {
954 sharding_ = MakeUnique<HloSharding>(sharding);
// Drops the stored sharding.
957 void clear_sharding() { sharding_ =
nullptr; }
// True when a sharding is stored.
959 bool has_sharding()
const {
return sharding_ !=
nullptr; }
// Fusion mutation API; the first three are defined elsewhere.
962 HloInstruction* AddFusionOperand(HloInstruction* new_operand);
969 void MergeFusionInstruction(HloInstruction* instruction_to_merge);
978 void MergeFusionInstructionIntoMultiOutput(
979 HloInstruction* instruction_to_merge);
// Forwards to FuseInstructionInternal with its second argument left at its
// default. (Closing brace elided in this listing.)
990 HloInstruction* FuseInstruction(HloInstruction* instruction_to_fuse) {
991 return FuseInstructionInternal(instruction_to_fuse);
// Forwards to FuseInstructionInternal, passing true as the second argument.
1001 HloInstruction* FuseInstructionIntoMultiOutput(
1002 HloInstruction* instruction_to_fuse) {
1003 return FuseInstructionInternal(instruction_to_fuse,
true);
// Slice accessors. The per-dimension overloads CHECK that the opcode is
// kSlice and index the vector with operator[], so `dimension` is not
// bounds-checked here.
// NOTE(review): the whole-vector overloads are inconsistent — slice_limits()
// checks the opcode while slice_starts() and slice_strides() do not.
1009 int64 slice_starts(int64 dimension)
const {
1010 CHECK_EQ(HloOpcode::kSlice, opcode_);
1011 return slice_starts_[dimension];
1013 const std::vector<int64>& slice_starts()
const {
return slice_starts_; }
1019 int64 slice_limits(int64 dimension)
const {
1020 CHECK_EQ(HloOpcode::kSlice, opcode_);
1021 return slice_limits_[dimension];
1023 const std::vector<int64>& slice_limits()
const {
1024 CHECK_EQ(HloOpcode::kSlice, opcode_);
1025 return slice_limits_;
1031 int64 slice_strides(int64 dimension)
const {
1032 CHECK_EQ(HloOpcode::kSlice, opcode_);
1033 return slice_strides_[dimension];
1035 const std::vector<int64>& slice_strides()
const {
return slice_strides_; }
// Returns the in-place-slice flag.
1039 bool IsInPlaceSlice()
const {
return is_in_place_slice_; }
// Stores the in-place-slice flag. Declared to return bool; the return
// statement is elided in this listing.
1043 bool SetIsInPlaceSlice(
bool value) {
1044 is_in_place_slice_ = value;
// Dynamic-slice accessors; both CHECK that the opcode is kDynamicSlice.
1052 int64 slice_sizes(int64 dimension)
const {
1053 CHECK_EQ(HloOpcode::kDynamicSlice, opcode_);
1054 return dynamic_slice_sizes_[dimension];
1056 const std::vector<int64>& dynamic_slice_sizes()
const {
1057 CHECK_EQ(HloOpcode::kDynamicSlice, opcode_);
1058 return dynamic_slice_sizes_;
// Reduce-precision accessors; both CHECK that the opcode is kReducePrecision.
1064 int32 exponent_bits()
const {
1065 CHECK_EQ(HloOpcode::kReducePrecision, opcode_);
1066 return exponent_bits_;
1072 int32 mantissa_bits()
const {
1073 CHECK_EQ(HloOpcode::kReducePrecision, opcode_);
1074 return mantissa_bits_;
// Returns the window; CHECK-fails when window_ is null. (The return
// statement is elided in this listing.)
1079 const Window& window()
const {
1080 CHECK(window_ !=
nullptr);
// Stores a heap-allocated copy of `window`.
1085 void set_window(
const Window& window) {
1086 window_ = MakeUnique<Window>(window);
// Returns the padding config; CHECK-fails when it is null.
1092 const PaddingConfig& padding_config()
const {
1093 CHECK(padding_config_ !=
nullptr);
1094 return *padding_config_;
// Returns the convolution dimension numbers; CHECK-fails when they are null.
1100 const ConvolutionDimensionNumbers& convolution_dimension_numbers()
const {
1101 CHECK(convolution_dimension_numbers_ !=
nullptr);
1102 return *convolution_dimension_numbers_;
// Stores a heap-allocated copy of `dnums`.
1108 void set_convolution_dimension_numbers(
1109 const ConvolutionDimensionNumbers& dnums) {
1110 convolution_dimension_numbers_ =
1111 MakeUnique<ConvolutionDimensionNumbers>(dnums);
// FFT accessors; both CHECK that the opcode is kFft. (Return statements are
// elided in this listing.)
1114 FftType fft_type()
const {
1115 CHECK_EQ(HloOpcode::kFft, opcode_);
1119 const std::vector<int64>& fft_length()
const {
1120 CHECK_EQ(HloOpcode::kFft, opcode_);
// Declared here, defined elsewhere.
1125 string ConvolutionDimensionNumbersToString()
const;
// Returns the dot dimension numbers; CHECK-fails when they are null.
1128 const DotDimensionNumbers& dot_dimension_numbers()
const {
1129 CHECK(dot_dimension_numbers_ !=
nullptr);
1130 return *dot_dimension_numbers_;
1134 string DotDimensionNumbersToString()
const;
// Returns the gather dimension numbers; CHECK-fails when they are null.
1136 const GatherDimensionNumbers& gather_dimension_numbers()
const {
1137 CHECK(gather_dimension_numbers_ !=
nullptr);
1138 return *gather_dimension_numbers_;
// Returns a non-owning view over gather_window_bounds_; CHECK-fails unless
// the opcode is kGather.
1141 tensorflow::gtl::ArraySlice<int64> gather_window_bounds()
const {
1142 CHECK_EQ(opcode(), HloOpcode::kGather);
1143 return gather_window_bounds_;
// Declared here, defined elsewhere.
1147 string GatherDimensionNumbersToString()
const;
1152 RandomDistribution random_distribution()
const;
// Returns a newly allocated clone. `suffix` defaults to "clone" and `module`
// defaults to nullptr (defined elsewhere).
// NOTE(review): presumably `suffix` is appended to the clone's name — confirm.
1161 std::unique_ptr<HloInstruction> Clone(
const string& suffix =
"clone",
1162 HloModule* module =
nullptr)
const;
// Clone variant that takes an explicit shape and operand list.
1168 std::unique_ptr<HloInstruction> CloneWithNewOperands(
1169 const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
1170 HloModule* module =
nullptr)
const;
// Returns the called computations by const reference.
1173 const std::vector<HloComputation*>& called_computations()
const {
1174 return called_computations_;
// Replaces every entry of called_computations_ with map_function(entry), in
// index order. `map_function` is taken by value.
// NOTE(review): the loop compares a signed int64 index against the unsigned
// result of size().
1180 void ReplaceCalledComputations(
1181 std::function<HloComputation*(HloComputation*)> map_function) {
1182 for (int64 i = 0; i < called_computations_.size(); ++i) {
1183 called_computations_[i] = map_function(called_computations_[i]);
// Empties the called-computation list.
1195 void ClearCalledComputations() { called_computations_.clear(); }
// Queries declared here and defined elsewhere.
1207 bool IsElementwiseOnOperand(int64 operand_idx)
const;
1210 bool IsElementwise()
const;
1216 bool ImplicitlyBroadcastsOperand(int64 operand_idx)
const;
1219 bool IsElementwiseBinary()
const;
// True when OperandElementUse(i) reports UseKind::kReuse.
// (Closing brace elided in this listing.)
1222 bool ReusesOperandElements(int64 i)
const {
1223 return OperandElementUse(i) == UseKind::kReuse;
// Returns a vector of indices for `operand` (defined elsewhere).
// NOTE(review): presumably every operand position equal to `operand` —
// confirm.
1229 std::vector<int64> OperandIndices(
const HloInstruction* operand)
const;
// Returns a (bool, vector, vector) tuple; the meaning of each element is not
// visible in this listing.
1237 std::tuple<bool, std::vector<int64>, std::vector<int64>>
1238 ReshapeMerelyInsertsOrDeletes1SizedDimensions()
const;
// Returns the instruction name by const reference.
1241 const string& name()
const {
return name_; }
// Replaces the name with a copy of `name`.
1242 void set_name(tensorflow::StringPiece name) { name_ = name.ToString(); }
// Declared here, defined elsewhere.
1246 void UniquifyName(NameUniquer* name_uniquer);
// CHECK-fails unless unique_id_ is still -1 when called. The assignment of
// `id` is elided in this listing.
// NOTE(review): -1 appears to be the "not yet assigned" sentinel — confirm.
1249 void SetUniqueId(
int id) {
1250 CHECK_EQ(unique_id_, -1);
// Returns the stored unique id.
1257 int unique_id()
const {
return unique_id_; }
// Metadata setter (copies `metadata`) and getter.
1260 void set_metadata(
const OpMetadata& metadata) { metadata_ = metadata; }
1261 const OpMetadata& metadata()
const {
return metadata_; }
// Parent computation setter plus const and non-const getters. The pointer is
// stored as-is; ownership is not taken here.
1266 void set_parent(HloComputation* computation) { parent_ = computation; }
1267 const HloComputation* parent()
const {
return parent_; }
1268 HloComputation* parent() {
return parent_; }
// Declared here, defined elsewhere.
1271 HloModule* GetModule()
const;
1275 bool CouldBeBitcast()
const;
// Returns the outer-dimension partitions by const reference.
// (Closing brace elided in this listing.)
1282 const std::vector<int64>& outer_dimension_partitions()
const {
1283 return outer_dimension_partitions_;
1285 void set_outer_dimension_partitions(
1286 const std::vector<int64>& outer_dimension_partitions);
// `shape_index` defaults to an empty ShapeIndex (defined elsewhere).
1291 void RelayoutConstant(
const Layout& new_layout,
1292 const ShapeIndex& shape_index = {});
// Classification of how an instruction uses an operand's elements; compared
// against in ReusesOperandElements().
1295 enum class UseKind { kNoUse, kReuse, kUsePermutingElements, kUse };
// Forward declaration of a nested helper class.
1298 class FusionReusesParamElements;
// Out-of-line slow path of the comparison above; takes the computation and
// shape predicates. One line of the computation-predicate parameter is elided
// in this listing.
1304 bool IdenticalSlowPath(
1305 const HloInstruction& other,
1306 const std::function<
bool(
const HloComputation*,
const HloComputation*)>&
1308 const std::function<
bool(
const Shape&,
const Shape&)>& eq_shapes)
const;
// Factory for an op of the given opcode over an arbitrary operand list.
1311 static std::unique_ptr<HloInstruction> CreateNary(
1312 const Shape& shape, HloOpcode opcode,
1313 tensorflow::gtl::ArraySlice<HloInstruction*> operands);
// Operand / user bookkeeping helpers (defined elsewhere).
1317 void AppendOperand(HloInstruction* operand);
1320 void AddUser(HloInstruction* user);
1323 void RemoveUser(HloInstruction* user);
// Constructor taking the opcode and the shape.
1327 HloInstruction(HloOpcode opcode,
const Shape& shape);
// Shared implementation behind FuseInstruction (add_output left false) and
// FuseInstructionIntoMultiOutput (add_output passed as true).
1340 HloInstruction* FuseInstructionInternal(HloInstruction* instruction_to_fuse,
1341 bool add_output =
false);
// Same parameters as FuseInstructionInternal (defined elsewhere).
1349 HloInstruction* CloneAndFuseInternal(HloInstruction* instruction_to_fuse,
1350 bool add_output =
false);
// Fusion-specific clone helper with the same parameters as
// CloneWithNewOperands.
1353 std::unique_ptr<HloInstruction> CloneFusionWithNewOperands(
1354 const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
1355 HloModule* module =
nullptr)
const;
// Declared here, defined elsewhere.
1359 bool CanHaveDimensionsField()
const;
// Returns how operand `i`'s elements are used; consumed by
// ReusesOperandElements().
1362 UseKind OperandElementUse(int64 i)
const;
// Operand list, exposed through operands().
1370 InstructionVector operands_;
// Control predecessors, exposed through control_predecessors().
1373 std::vector<HloInstruction*> control_predecessors_;
// Users are kept twice: as a vector (returned by users()) and as a set
// (queried by IsUserOf()).
1379 std::vector<HloInstruction*> users_;
1380 std::unordered_set<const HloInstruction*> user_set_;
// Control successors, exposed through control_successors().
1383 std::vector<HloInstruction*> control_successors_;
// Enclosing computation; non-owning raw pointer, null by default.
1386 HloComputation* parent_ =
nullptr;
1389 Shape outfeed_shape_;
// Owned literal value.
1395 std::unique_ptr<Literal> literal_;
// Defaults to -1.
1398 int64 tuple_index_ = -1;
1402 std::vector<int64> dimensions_;
// Optional attributes held by owning pointer; their accessors CHECK that the
// pointer is non-null before dereferencing.
1405 std::unique_ptr<Window> window_;
1408 std::unique_ptr<ConvolutionDimensionNumbers> convolution_dimension_numbers_;
1411 std::unique_ptr<DotDimensionNumbers> dot_dimension_numbers_;
1413 std::unique_ptr<GatherDimensionNumbers> gather_dimension_numbers_;
1414 std::vector<int64> gather_window_bounds_;
// FFT attributes; the type defaults to FftType::FFT.
1417 FftType fft_type_ = FftType::FFT;
1420 std::vector<int64> fft_length_;
// Slice attributes, read by the slice_* accessors.
1423 std::vector<int64> slice_starts_;
1424 std::vector<int64> slice_limits_;
1425 std::vector<int64> slice_strides_;
1428 bool is_in_place_slice_ =
false;
// Reduce-precision attributes; both default to 0.
1431 int32 exponent_bits_ = 0;
1432 int32 mantissa_bits_ = 0;
1436 std::vector<int64> dynamic_slice_sizes_;
1440 std::unique_ptr<PaddingConfig> padding_config_;
// Optional sharding; null means no sharding is set (see has_sharding()).
1446 std::unique_ptr<HloSharding> sharding_;
1449 int64 parameter_number_ = 0;
1452 string custom_call_target_;
1455 string channel_name_;
// Cost estimate in nanoseconds; matches the `cost_estimate_ns` parameter of
// the CreateHostCompute factory. Zero-initialized, like the neighbouring
// scalar members, so that an instruction that never sets it does not hold an
// indeterminate value.
1458 int64 cost_estimate_ns_ = 0;
// Computations called by this instruction; non-owning raw pointers.
1461 std::vector<HloComputation*> called_computations_;
// Enumerator fragments (the enclosing enum declarations are elided in this
// listing). NOTE(review): these look like named indices into
// called_computations_ for while, select-and-scatter and conditional
// instructions respectively — confirm against the full header.
1467 kBodyComputationIndex = 0,
1468 kConditionComputationIndex = 1,
1471 kSelectComputationIndex = 0,
1472 kScatterComputationIndex = 1,
1475 kTrueComputationIndex = 0,
1476 kFalseComputationIndex = 1,
1480 string outfeed_config_;
// Associated trace instruction; non-owning, null by default.
1486 HloInstruction* trace_instruction_ =
nullptr;
// NOTE(review): unlike its neighbours this member has no default
// initializer — confirm that every code path sets it before it is read.
1490 RandomDistribution distribution_;
// Defaults: epsilon 0.0f; feature index and channel id -1.
1494 float epsilon_ = 0.0f;
1498 int64 feature_index_ = -1;
1502 int64 channel_id_ = -1;
1505 string infeed_config_;
1511 OpMetadata metadata_;
1515 std::vector<int64> outer_dimension_partitions_;
// Disables copy construction and copy assignment for HloInstruction.
1517 TF_DISALLOW_COPY_AND_ASSIGN(HloInstruction);
// Free helper declarations (defined elsewhere). The String*To* parsers return
// StatusOr, so a failed parse is reported through the status; the *ToString
// helpers return a plain string.
1521 StatusOr<HloInstruction::FusionKind> StringToFusionKind(
1522 const string& kind_name);
1526 string PaddingConfigToString(
const PaddingConfig& padding);
1527 string OpMetadataToString(
const OpMetadata& metadata);
1528 string RandomDistributionToString(
const RandomDistribution& distribution);
1529 StatusOr<RandomDistribution> StringToRandomDistribution(
const string& name);
// Strict-weak-ordering functor over instruction pointers: orders by
// unique_id() rather than by pointer value. Both pointers are dereferenced,
// so neither may be null. (Closing braces elided in this listing.)
1541 struct HloPtrComparator {
1542 bool operator()(
const HloInstruction*
const& lhs,
1543 const HloInstruction*
const& rhs)
const {
1544 return lhs->unique_id() < rhs->unique_id();
// Ordered map / set aliases keyed by instruction pointer. All of them use
// HloPtrComparator, so iteration order follows unique_id() rather than
// pointer addresses.
1548 template <
typename ValueT>
1549 using HloInstructionMap = std::map<HloInstruction*, ValueT, HloPtrComparator>;
// Same as HloInstructionMap, keyed by pointer-to-const.
1551 template <
typename ValueT>
1552 using ConstHloInstructionMap =
1553 std::map<const HloInstruction*, ValueT, HloPtrComparator>;
// Set counterparts of the two map aliases.
1555 using HloInstructionSet = std::set<HloInstruction*, HloPtrComparator>;
1556 using ConstHloInstructionSet =
1557 std::set<const HloInstruction*, HloPtrComparator>;
1561 #endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INSTRUCTION_H_ Definition: hlo_computation.h:60
static std::unique_ptr< HloInstruction > CreateFusion(const Shape &shape, FusionKind fusion_kind, HloInstruction *fused_root)
Definition: hlo_instruction.cc:800
Definition: hlo_instruction.h:165
FusionKind
Definition: hlo_instruction.h:170
namespace for xla
Definition: client_library.cc:26
Definition: hlo_module.h:52