18 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_COMPUTATION_H_ 19 #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_COMPUTATION_H_ 24 #include <unordered_map> 25 #include <unordered_set> 29 #include "tensorflow/compiler/xla/iterator_util.h" 30 #include "tensorflow/compiler/xla/map_util.h" 31 #include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h" 32 #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" 33 #include "tensorflow/compiler/xla/service/hlo.pb.h" 35 #include "tensorflow/compiler/xla/service/hlo_reachability.h" 36 #include "tensorflow/compiler/xla/service/name_uniquer.h" 37 #include "tensorflow/compiler/xla/shape_tree.h" 38 #include "tensorflow/compiler/xla/statusor.h" 39 #include "tensorflow/compiler/xla/types.h" 40 #include "tensorflow/compiler/xla/xla_data.pb.h" 41 #include "tensorflow/core/lib/core/status.h" 42 #include "tensorflow/core/lib/gtl/array_slice.h" 43 #include "tensorflow/core/lib/gtl/flatmap.h" 44 #include "tensorflow/core/lib/gtl/flatset.h" 45 #include "tensorflow/core/platform/macros.h" 46 #include "tensorflow/core/platform/types.h" 67 explicit Builder(
const string& name,
70 last_added_instruction_(
nullptr),
71 fusion_instruction_(fusion_instruction) {}
77 std::unique_ptr<HloComputation>
Build(
84 std::unique_ptr<HloInstruction> instruction) {
85 instructions_.push_back(std::move(instruction));
86 last_added_instruction_ = instructions_.back().get();
87 return last_added_instruction_;
90 Status ForEachInstruction(
92 for (
const auto& instruction : instructions_) {
93 TF_RETURN_IF_ERROR(func(instruction.get()));
100 HloInstruction* last_added_instruction_;
101 HloInstruction* fusion_instruction_;
102 std::vector<std::unique_ptr<HloInstruction>> instructions_;
107 HloInstruction*
AddInstruction(std::unique_ptr<HloInstruction> instruction);
112 Status RemoveParameter(int64 param_no);
117 HloInstruction* AddParameter(std::unique_ptr<HloInstruction> instruction);
121 Status RemoveInstruction(HloInstruction* instruction);
126 Status RemoveInstructionAndUnusedOperands(HloInstruction* instruction);
131 void set_root_instruction(HloInstruction* new_root_instruction);
135 HloInstruction* root_instruction()
const {
return root_instruction_; }
138 int64 num_parameters()
const {
return param_instructions_.size(); }
141 HloInstruction* parameter_instruction(int64 param_no)
const {
142 CHECK_GE(param_no, 0);
143 CHECK_LT(param_no, static_cast<int64>(param_instructions_.size()))
144 <<
"Computation " << name() <<
" has no parameter number " << param_no;
145 return param_instructions_[param_no];
148 const std::vector<HloInstruction*>& parameter_instructions()
const {
149 return param_instructions_;
152 const string& name()
const {
return name_; }
156 void UniquifyName(NameUniquer* name_uniquer);
162 string ToString()
const {
return ToString(HloPrintOptions()); }
163 string ToString(
const HloPrintOptions& options)
const;
166 HloComputationProto ToProto()
const;
176 static StatusOr<std::unique_ptr<HloComputation>> CreateFromProto(
177 HloModule* module,
const HloComputationProto& proto,
178 const tensorflow::gtl::FlatMap<int64, HloComputation*>& computation_map);
187 tensorflow::gtl::iterator_range<UnwrappingIterator<
188 std::list<std::unique_ptr<HloInstruction>>::const_iterator>>
189 instructions()
const {
190 return {MakeUnwrappingIterator(instructions_.begin()),
191 MakeUnwrappingIterator(instructions_.end())};
193 tensorflow::gtl::iterator_range<
194 UnwrappingIterator<std::list<std::unique_ptr<HloInstruction>>::iterator>>
196 return {MakeUnwrappingIterator(instructions_.begin()),
197 MakeUnwrappingIterator(instructions_.end())};
202 std::list<HloInstruction*> MakeInstructionPostOrder()
const;
210 std::unique_ptr<HloReachabilityMap> ComputeReachability()
const;
214 void UpdateReachabilityThroughInstruction(
215 const HloInstruction* instruction, HloReachabilityMap* reachability_map);
217 int64 instruction_count()
const {
return instructions_.size(); }
239 tensorflow::gtl::ArraySlice<HloInstruction*> instructions_to_fuse,
252 StatusOr<HloInstruction*> DeepCopyInstruction(
253 HloInstruction* instruction,
254 const ShapeTree<bool>* indices_to_copy =
nullptr,
255 ShapeTree<HloInstruction*>* copies_added =
nullptr);
259 ProgramShape ComputeProgramShape()
const;
262 bool operator==(
const HloComputation& other)
const;
266 Status ReplaceWithNewInstruction(
267 HloInstruction* old_instruction,
268 std::unique_ptr<HloInstruction> new_instruction);
273 Status ReplaceInstruction(HloInstruction* old_instruction,
274 HloInstruction* new_instruction);
277 void set_parent(HloModule* module) { parent_ = module; }
278 const HloModule* parent()
const {
return parent_; }
279 HloModule* parent() {
return parent_; }
287 template <
typename HloInstructionPtr>
288 Status Accept(DfsHloVisitorBase<HloInstructionPtr>* visitor)
const;
293 Status AcceptWithOperandOrder(
294 DfsHloVisitor* visitor,
295 const HloInstruction::CompareFunction& operand_order)
const;
299 template <
typename HloInstructionPtr>
300 Status AcceptOrdered(DfsHloVisitorBase<HloInstructionPtr>* visitor,
301 const std::vector<const HloInstruction*>& order)
const;
304 Status Accept(
const std::function<Status(HloInstruction*)>& visitor_func);
306 const std::function<Status(
const HloInstruction*)>& visitor_func)
const;
312 std::unique_ptr<HloComputation> Clone(
const string& suffix =
"clone",
313 HloModule* module =
nullptr);
320 std::unique_ptr<HloComputation> CloneWithReplacements(
321 std::unordered_map<
const HloInstruction*, std::unique_ptr<HloInstruction>>
323 HloModule* module =
nullptr,
const string& suffix =
"clone");
336 bool IsRemovable(
const HloInstruction* instruction);
340 bool HasSideEffect()
const;
343 bool IsFusionComputation()
const {
return fusion_instruction_ !=
nullptr; }
347 HloInstruction* FusionInstruction()
const {
return fusion_instruction_; }
348 void SetFusionInstruction(HloInstruction* fusion_instruction) {
349 fusion_instruction_ = fusion_instruction;
353 void SetUniqueId(int64
id) {
354 CHECK_EQ(unique_id_, -1);
359 int64 unique_id()
const {
return unique_id_; }
362 explicit HloComputation(
363 const string& name,
int parameter_count,
364 std::vector<std::unique_ptr<HloInstruction>>* instructions,
365 HloInstruction* root_instruction, HloInstruction* fusion_instruction);
369 std::unique_ptr<HloInstruction> instruction);
375 tensorflow::gtl::ArraySlice<HloInstruction*> instructions_to_fuse,
376 HloInstruction* fusion_instruction);
380 StatusOr<HloInstruction*> DeepCopyHelper(
381 HloInstruction* instruction,
const ShapeTree<bool>* indices_to_copy,
382 ShapeTree<HloInstruction*>* copies_added, ShapeIndex* index);
385 std::vector<HloInstruction*> CollectUnreachableRoots()
const;
389 HloInstruction* root_instruction_;
393 HloInstruction* fusion_instruction_;
396 HloModule* parent_ =
nullptr;
401 using InstructionList = std::list<std::unique_ptr<HloInstruction>>;
402 InstructionList instructions_;
403 std::unordered_map<const HloInstruction*, InstructionList::iterator>
404 instruction_iterators_;
406 std::vector<HloInstruction*> param_instructions_;
408 TF_DISALLOW_COPY_AND_ASSIGN(HloComputation);
413 #endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_COMPUTATION_H_ Builder class for HloComputation.
Definition: hlo_computation.h:65
Definition: hlo_computation.h:60
void FuseInstructionsInto(tensorflow::gtl::ArraySlice< HloInstruction *> instructions_to_fuse, HloInstruction *fusion_instruction)
Definition: hlo_computation.cc:453
HloInstruction * AddInstructionInternal(std::unique_ptr< HloInstruction > instruction)
Definition: hlo_computation.cc:118
HloInstruction * AddInstruction(std::unique_ptr< HloInstruction > instruction)
Definition: hlo_computation.cc:104
Definition: hlo_instruction.h:165
std::list< HloComputation * > MakeEmbeddedComputationsList() const
Definition: hlo_computation.cc:361
std::unique_ptr< HloComputation > Build(HloInstruction *root_instruction=nullptr)
Build and return an HloComputation.
Definition: hlo_computation.cc:56
FusionKind
Definition: hlo_instruction.h:170
HloInstruction * AddInstruction(std::unique_ptr< HloInstruction > instruction)
Adds the instruction to the builder's instructions_ vector, which records instructions in the order they were added, and returns a raw pointer to the newly added instruction.
Definition: hlo_computation.h:83
HloInstruction * CreateFusionInstruction(tensorflow::gtl::ArraySlice< HloInstruction *> instructions_to_fuse, HloInstruction::FusionKind fusion_kind)
Definition: hlo_computation.cc:479
namespace for xla
Definition: client_library.cc:26