tf_1.8_xla_doc
hlo_computation.h
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_COMPUTATION_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_COMPUTATION_H_

#include <list>
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include "tensorflow/compiler/xla/iterator_util.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_reachability.h"
#include "tensorflow/compiler/xla/service/name_uniquer.h"
#include "tensorflow/compiler/xla/shape_tree.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
namespace xla {

class HloModule;

// Describes a computation at the HLO level.
//
// An HloComputation contains a directed acyclic graph of HLO instructions. The
// computation has a single root instruction which produces the output of the
// computation.
class HloComputation {
 public:
  // Builder class for HloComputation.
  class Builder {
   public:
    explicit Builder(const string& name,
                     HloInstruction* fusion_instruction = nullptr)
        : name_(name),
          last_added_instruction_(nullptr),
          fusion_instruction_(fusion_instruction) {}

    // Build and return an HloComputation. The parameter root_instruction
    // specifies the already-added instruction to use as the root. If
    // root_instruction is nullptr then use the last added instruction as the
    // root.
    std::unique_ptr<HloComputation> Build(
        HloInstruction* root_instruction = nullptr);

    // Adds the instruction to the builder's list of instructions, which
    // records the instructions that the built computation will own.
    HloInstruction* AddInstruction(
        std::unique_ptr<HloInstruction> instruction) {
      instructions_.push_back(std::move(instruction));
      last_added_instruction_ = instructions_.back().get();
      return last_added_instruction_;
    }

    Status ForEachInstruction(
        const std::function<Status(const HloInstruction*)>& func) const {
      for (const auto& instruction : instructions_) {
        TF_RETURN_IF_ERROR(func(instruction.get()));
      }
      return Status::OK();
    }

   private:
    const string name_;
    HloInstruction* last_added_instruction_;
    HloInstruction* fusion_instruction_;
    std::vector<std::unique_ptr<HloInstruction>> instructions_;
  };
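  // Example (illustrative sketch of how client code might drive the builder;
  // `shape` and the HloInstruction::Create* factory calls below are
  // assumptions, see hlo_instruction.h for the actual signatures):
  //
  //   HloComputation::Builder builder("add");
  //   HloInstruction* x = builder.AddInstruction(
  //       HloInstruction::CreateParameter(0, shape, "x"));
  //   HloInstruction* y = builder.AddInstruction(
  //       HloInstruction::CreateParameter(1, shape, "y"));
  //   builder.AddInstruction(
  //       HloInstruction::CreateBinary(shape, HloOpcode::kAdd, x, y));
  //   std::unique_ptr<HloComputation> computation = builder.Build();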

  // Add an instruction to the computation. The computation takes ownership of
  // the instruction.
  HloInstruction* AddInstruction(std::unique_ptr<HloInstruction> instruction);

  // Remove the param_no'th parameter from the computation.
  // Note this is only applicable to the computation for the fusion
  // instruction.
  Status RemoveParameter(int64 param_no);

  // Add a new parameter instruction to the computation.
  // This should be a new parameter. The instruction will be appended to the
  // parameters and inserted into the instruction list.
  HloInstruction* AddParameter(std::unique_ptr<HloInstruction> instruction);

  // Remove an instruction from the computation. The instruction must have no
  // users. The instruction is deallocated with this call.
  Status RemoveInstruction(HloInstruction* instruction);

  // Remove an instruction from the computation and also transitively any
  // operand that has no users after the instruction is removed. The
  // instruction must have no users. The instruction is deallocated with this
  // call.
  Status RemoveInstructionAndUnusedOperands(HloInstruction* instruction);

  // Set the root of the computation to the given instruction. The instruction
  // must have already been added to the computation and, for non-fusion
  // computations, must have the same shape as the result of the computation.
  void set_root_instruction(HloInstruction* new_root_instruction);

  // Return the root instruction of the computation. The root instruction is
  // the instruction which produces the output of the computation.
  HloInstruction* root_instruction() const { return root_instruction_; }

  // Returns the number of parameters for this computation.
  int64 num_parameters() const { return param_instructions_.size(); }

  // Returns the parameter instruction for the given parameter number.
  HloInstruction* parameter_instruction(int64 param_no) const {
    CHECK_GE(param_no, 0);
    CHECK_LT(param_no, static_cast<int64>(param_instructions_.size()))
        << "Computation " << name() << " has no parameter number " << param_no;
    return param_instructions_[param_no];
  }

  const std::vector<HloInstruction*>& parameter_instructions() const {
    return param_instructions_;
  }

  const string& name() const { return name_; }

  // Use the given NameUniquer to select a unique name for the computation
  // based on the computation's existing name.
  void UniquifyName(NameUniquer* name_uniquer);

  // Return a string representation of the computation.
  //
  // (We express the default options using an overload rather than a default
  // param because gdb ignores default params, but does resolve overloads.)
  string ToString() const { return ToString(HloPrintOptions()); }
  string ToString(const HloPrintOptions& options) const;

  // Returns a serialized representation of this computation.
  HloComputationProto ToProto() const;

  // Creates a computation from the given proto. Arguments:
  //
  //   module: the module which will contain the computation. The newly created
  //     computation is *not* added to the module, however.
  //   proto: the proto to convert from.
  //   computation_map: a map from computation id to HloComputation*. This map
  //     must contain all computations which the newly constructed computation
  //     calls.
  static StatusOr<std::unique_ptr<HloComputation>> CreateFromProto(
      HloModule* module, const HloComputationProto& proto,
      const tensorflow::gtl::FlatMap<int64, HloComputation*>& computation_map);

  // Gets the instructions in this computation.
  //
  // The returned type is a range of HloInstruction*s, so you can iterate over
  // it using a range-based for loop in the natural way:
  //
  //   for (HloInstruction* instr : computation->instructions()) { ... }
  //
  tensorflow::gtl::iterator_range<UnwrappingIterator<
      std::list<std::unique_ptr<HloInstruction>>::const_iterator>>
  instructions() const {
    return {MakeUnwrappingIterator(instructions_.begin()),
            MakeUnwrappingIterator(instructions_.end())};
  }
  tensorflow::gtl::iterator_range<
      UnwrappingIterator<std::list<std::unique_ptr<HloInstruction>>::iterator>>
  instructions() {
    return {MakeUnwrappingIterator(instructions_.begin()),
            MakeUnwrappingIterator(instructions_.end())};
  }

  // Compute and return a post-order of the instructions in the computation.
  // In this order, definitions of values always appear before their uses.
  std::list<HloInstruction*> MakeInstructionPostOrder() const;
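  // Example (illustrative): because definitions precede uses in this order, a
  // single forward pass can safely consult already-visited operands:
  //
  //   for (HloInstruction* instr : computation->MakeInstructionPostOrder()) {
  //     // Every operand of `instr` has been visited before `instr` itself.
  //   }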

  // Computes and returns the reachability between HLO instructions in the
  // computation. The returned HloReachabilityMap is constructed such that
  // HloReachabilityMap::IsReachable(a, b) returns true iff there exists a
  // directed path (from producer to consumer) from 'a' to 'b'. Both data
  // dependencies (operands) and control dependencies are considered for
  // reachability. Trivially an instruction is reachable from itself.
  std::unique_ptr<HloReachabilityMap> ComputeReachability() const;
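  // Example (illustrative sketch; `producer` and `consumer` are hypothetical
  // instruction pointers in this computation):
  //
  //   std::unique_ptr<HloReachabilityMap> reachability =
  //       computation->ComputeReachability();
  //   if (reachability->IsReachable(producer, consumer)) {
  //     // Some data or control path leads from producer to consumer.
  //   }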

  // Updates the given reachability map after the immediate predecessor set
  // (operands and control predecessors) of 'instruction' has changed.
  void UpdateReachabilityThroughInstruction(
      const HloInstruction* instruction, HloReachabilityMap* reachability_map);

  int64 instruction_count() const { return instructions_.size(); }

  // Creates and returns a list of the embedded computations called by this
  // computation. This includes all embedded computations called directly or
  // transitively. The embedded computations are sorted such that if
  // computation A calls computation B (e.g., via a map instruction) then A
  // will appear after B in the list.
  std::list<HloComputation*> MakeEmbeddedComputationsList() const;

  // Creates a fusion instruction containing the given instructions.
  // `fusion_kind` indicates the type of the fusion instruction. The
  // instructions must be in reverse topological order (root of the fused
  // expression first).
  HloInstruction* CreateFusionInstruction(
      tensorflow::gtl::ArraySlice<HloInstruction*> instructions_to_fuse,
      HloInstruction::FusionKind fusion_kind);
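  // Example (illustrative sketch; `add` and `mul` are hypothetical
  // instructions, listed root-first as required above):
  //
  //   HloInstruction* fusion = computation->CreateFusionInstruction(
  //       /*instructions_to_fuse=*/{add, mul},
  //       HloInstruction::FusionKind::kLoop);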

  // Create a deep copy of the given instruction and return the instruction
  // producing the copied result. All instructions performing the copy are
  // added to the computation. For array-shaped values, this method trivially
  // returns a kCopy instruction. For tuple-shaped instructions, the copy is
  // performed with a series of kGetTupleElement and kTuple instructions. If
  // indices_to_copy is non-null then this ShapeTree indicates which elements
  // (arrays) of the shape to copy. Non-copied elements are passed through
  // transparently. If copies_added is non-null, then the added kCopy
  // instructions will be inserted in the respective index in the given
  // ShapeTree.
  StatusOr<HloInstruction*> DeepCopyInstruction(
      HloInstruction* instruction,
      const ShapeTree<bool>* indices_to_copy = nullptr,
      ShapeTree<HloInstruction*>* copies_added = nullptr);
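  // Example (illustrative sketch; `tuple_instr` is a hypothetical
  // tuple-shaped instruction, and with the default arguments every leaf
  // array of the tuple is copied):
  //
  //   TF_ASSIGN_OR_RETURN(HloInstruction * copy,
  //                       computation->DeepCopyInstruction(tuple_instr));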

  // Computes and returns the ProgramShape of this computation (shape of
  // parameters and result with layout).
  ProgramShape ComputeProgramShape() const;

  // Return whether `*this` and `other` are functionally equivalent.
  bool operator==(const HloComputation& other) const;

  // Replaces the old instruction with a newly created instruction. Removes
  // the old instruction from the computation. Updates uses and the root
  // instruction.
  Status ReplaceWithNewInstruction(
      HloInstruction* old_instruction,
      std::unique_ptr<HloInstruction> new_instruction);

  // Replace the old instruction with the new instruction. Updates uses and
  // the root instruction. Removes the old instruction from the computation.
  // Precondition: old_instruction and new_instruction must have compatible
  // shapes.
  Status ReplaceInstruction(HloInstruction* old_instruction,
                            HloInstruction* new_instruction);
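  // Example (illustrative sketch; `old_instr` and `operand` are hypothetical,
  // and the replacement here is a newly created kCopy of the operand):
  //
  //   TF_RETURN_IF_ERROR(computation->ReplaceWithNewInstruction(
  //       old_instr, HloInstruction::CreateUnary(old_instr->shape(),
  //                                              HloOpcode::kCopy, operand)));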

  // Set/get the module containing this computation.
  void set_parent(HloModule* module) { parent_ = module; }
  const HloModule* parent() const { return parent_; }
  HloModule* parent() { return parent_; }

  // Visit every node in the computation in DFS post-order with the given
  // visitor. This is similar to calling HloInstruction::Accept on the root of
  // the computation except this method also visits instructions not reachable
  // via the root. The root instruction of the computation is visited last,
  // and the visitor's FinishVisit method is called once upon completion (with
  // the root instruction as the argument).
  template <typename HloInstructionPtr>
  Status Accept(DfsHloVisitorBase<HloInstructionPtr>* visitor) const;

  // Same as Accept() above, but the order of operand and control predecessor
  // visitation is determined by the given operand order; if compare(A, B) ==
  // true, A is visited before B.
  Status AcceptWithOperandOrder(
      DfsHloVisitor* visitor,
      const HloInstruction::CompareFunction& operand_order) const;

  // Visit every node in the computation in the given order. 'order' must be a
  // topological sort of all instructions in the computation.
  template <typename HloInstructionPtr>
  Status AcceptOrdered(DfsHloVisitorBase<HloInstructionPtr>* visitor,
                       const std::vector<const HloInstruction*>& order) const;

  // Same as Accept() above, but the visitor is given as a function.
  Status Accept(const std::function<Status(HloInstruction*)>& visitor_func);
  Status Accept(
      const std::function<Status(const HloInstruction*)>& visitor_func) const;
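  // Example (illustrative) using the function form of Accept:
  //
  //   TF_RETURN_IF_ERROR(computation->Accept([](HloInstruction* instr) {
  //     VLOG(2) << "visiting: " << instr->ToString();
  //     return Status::OK();
  //   }));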

  // Returns a deep copy of this computation, including all instructions. If
  // the module pointer is not nullptr, the cloned computations will be added
  // to that module (in order to support deep cloning).
  std::unique_ptr<HloComputation> Clone(const string& suffix = "clone",
                                        HloModule* module = nullptr);

  // Like Clone(), but if an instruction is present in replacements, we use
  // the map's value to replace that instruction in the cloned computation.
  //
  // If replacements maps a key to nullptr, we remove that instruction from
  // the new computation.
  std::unique_ptr<HloComputation> CloneWithReplacements(
      std::unordered_map<const HloInstruction*,
                         std::unique_ptr<HloInstruction>>
          replacements,
      HloModule* module = nullptr, const string& suffix = "clone");
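  // Example (illustrative sketch; `old_instr` and `literal` are assumptions,
  // showing a clone that swaps one instruction for a constant):
  //
  //   std::unordered_map<const HloInstruction*,
  //                      std::unique_ptr<HloInstruction>>
  //       replacements;
  //   replacements[old_instr] =
  //       HloInstruction::CreateConstant(std::move(literal));
  //   std::unique_ptr<HloComputation> clone =
  //       computation->CloneWithReplacements(std::move(replacements));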

  // Returns true if the given instruction can be removed from the
  // computation. Parameter instructions cannot be removed without violating
  // invariants of the HLO computation, with the exception of fusion
  // computations: a parameter instruction is removable for a fusion
  // computation.
  //
  // Note that IsRemovable() is a necessary condition to remove an instruction
  // rather than a sufficient condition. For example, instructions with side
  // effects (e.g., Send, Infeed) may be removed from a computation, but the
  // transformation must guarantee the invariants relevant to the instructions
  // still hold (e.g., Send and Recv must be removed together to make each
  // channel complete).
  bool IsRemovable(const HloInstruction* instruction);

  // Returns true if this computation has a side effect. A computation has a
  // side effect if it contains one or more instructions with a side effect.
  bool HasSideEffect() const;

  // Returns whether this computation is a fusion computation.
  bool IsFusionComputation() const { return fusion_instruction_ != nullptr; }

  // Returns the owning fusion instruction, or nullptr if this is not a fusion
  // computation.
  HloInstruction* FusionInstruction() const { return fusion_instruction_; }
  void SetFusionInstruction(HloInstruction* fusion_instruction) {
    fusion_instruction_ = fusion_instruction;
  }

  // The id of this computation should be unique within the module.
  void SetUniqueId(int64 id) {
    CHECK_EQ(unique_id_, -1);
    CHECK_GE(id, 0);
    unique_id_ = id;
  }

  int64 unique_id() const { return unique_id_; }

 private:
  explicit HloComputation(
      const string& name, int parameter_count,
      std::vector<std::unique_ptr<HloInstruction>>* instructions,
      HloInstruction* root_instruction, HloInstruction* fusion_instruction);

  // Internal helper for adding instructions.
  HloInstruction* AddInstructionInternal(
      std::unique_ptr<HloInstruction> instruction);

  // Fuses HLOs in instructions_to_fuse into fusion_instruction.
  //
  // Pre-condition: fusion_instruction's opcode is kFusion.
  void FuseInstructionsInto(
      tensorflow::gtl::ArraySlice<HloInstruction*> instructions_to_fuse,
      HloInstruction* fusion_instruction);

  // Internal helper for recursive copying of an instruction. Creates and
  // returns a deep copy of the given instruction.
  StatusOr<HloInstruction*> DeepCopyHelper(
      HloInstruction* instruction, const ShapeTree<bool>* indices_to_copy,
      ShapeTree<HloInstruction*>* copies_added, ShapeIndex* index);

  // Internal helper to collect unreachable roots.
  std::vector<HloInstruction*> CollectUnreachableRoots() const;

  string name_;
  int64 unique_id_;
  HloInstruction* root_instruction_;

  // If this computation is a fusion computation, this field points to the
  // corresponding fusion instruction. Otherwise, this is null.
  HloInstruction* fusion_instruction_;

  // Module containing this computation.
  HloModule* parent_ = nullptr;

  // Store instructions in a std::list as they can be added and removed
  // arbitrarily and we want a stable iteration order. Keep a map from
  // instruction pointer to location in the list for fast lookup.
  using InstructionList = std::list<std::unique_ptr<HloInstruction>>;
  InstructionList instructions_;
  std::unordered_map<const HloInstruction*, InstructionList::iterator>
      instruction_iterators_;

  std::vector<HloInstruction*> param_instructions_;

  TF_DISALLOW_COPY_AND_ASSIGN(HloComputation);
};

}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_COMPUTATION_H_