tf_1.8_xla_doc
hlo_instruction.h
1 
3 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
4 
5 Licensed under the Apache License, Version 2.0 (the "License");
6 you may not use this file except in compliance with the License.
7 You may obtain a copy of the License at
8 
9  http://www.apache.org/licenses/LICENSE-2.0
10 
11 Unless required by applicable law or agreed to in writing, software
12 distributed under the License is distributed on an "AS IS" BASIS,
13 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 See the License for the specific language governing permissions and
15 limitations under the License.
16 ==============================================================================*/
17 
18 // HLO instructions are in DAG form and represent the computations that the user
19 // has built up via the XLA service interface. They are ultimately lowered
20 // in a platform-aware way by traversing the HLO DAG and emitting a lowered
21 // form; e.g. see DfsHloVisitor.
22 
23 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INSTRUCTION_H_
24 #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INSTRUCTION_H_
25 
26 #include <functional>
27 #include <iosfwd>
28 #include <list>
29 #include <memory>
30 #include <set>
31 #include <string>
32 #include <tuple>
33 #include <unordered_map>
34 #include <unordered_set>
35 #include <vector>
36 
37 #include "tensorflow/compiler/xla/iterator_util.h"
38 #include "tensorflow/compiler/xla/literal_util.h"
39 #include "tensorflow/compiler/xla/map_util.h"
40 #include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h"
41 #include "tensorflow/compiler/xla/service/hlo.pb.h"
42 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
43 #include "tensorflow/compiler/xla/service/hlo_sharding.h"
44 #include "tensorflow/compiler/xla/service/name_uniquer.h"
45 #include "tensorflow/compiler/xla/types.h"
46 #include "tensorflow/compiler/xla/xla_data.pb.h"
47 #include "tensorflow/core/lib/core/status.h"
48 #include "tensorflow/core/lib/core/stringpiece.h"
49 #include "tensorflow/core/lib/gtl/array_slice.h"
50 #include "tensorflow/core/lib/gtl/flatmap.h"
51 #include "tensorflow/core/lib/gtl/inlined_vector.h"
52 #include "tensorflow/core/lib/gtl/iterator_range.h"
53 #include "tensorflow/core/platform/logging.h"
54 #include "tensorflow/core/platform/macros.h"
55 #include "tensorflow/core/platform/types.h"
56 
57 namespace xla {
58 
59 class HloComputation;
60 class HloModule;
61 
62 // A bunch of switches that control how the hlo text should be printed.
63 class HloPrintOptions {
64  public:
65  // Constructs the default print options: don't print large constants, don't
66  // compact operands, no indentation.
67  HloPrintOptions()
68  : print_large_constants_(false),
69  print_subcomputation_references_(true),
70  print_metadata_(true),
71  compact_operands_(false),
72  print_operand_shape_(true),
73  print_program_shape_(true),
74  print_percent_(true),
75  indent_amount_(0) {}
76 
77  static HloPrintOptions ShortParsable() {
78  return HloPrintOptions()
79  .set_print_large_constants(true)
80  .set_print_subcomputation_references(true)
81  .set_print_metadata(false)
82  .set_print_operand_shape(false)
83  .set_print_program_shape(false)
84  .set_print_percent(false);
85  }
86 
87  // If true, large constants will be printed out.
88  HloPrintOptions& set_print_large_constants(bool value) {
89  print_large_constants_ = value;
90  return *this;
91  }
92 
93  // If false, the names of subcomputations (e.g. a fusion node's fused
94  // computation) won't be printed. This makes the resulting text not parsable.
95  //
96  // A CustomCall's call target is printed even if
97  // print_subcomputation_references is false, because the call target isn't an
98  // HloComputation.
99  HloPrintOptions& set_print_subcomputation_references(bool value) {
100  print_subcomputation_references_ = value;
101  return *this;
102  }
103 
104  // If true, metadata will be printed.
105  HloPrintOptions& set_print_metadata(bool value) {
106  print_metadata_ = value;
107  return *this;
108  }
109 
110  // If true, operands' shapes will be printed.
111  HloPrintOptions& set_print_operand_shape(bool value) {
112  print_operand_shape_ = value;
113  return *this;
114  }
115 
116  // If true, program shape of hlo computations will be printed.
117  HloPrintOptions& set_print_program_shape(bool value) {
118  print_program_shape_ = value;
119  return *this;
120  }
121 
122  // If true, names will be printed with prefix '%'.
123  HloPrintOptions& set_print_percent(bool value) {
124  print_percent_ = value;
125  return *this;
126  }
127 
128  // If true, only a part of operands will be printed out, and their names will
129  // be omitted (note that in this case the text will not be parsable).
130  HloPrintOptions& set_compact_operands(bool value) {
131  compact_operands_ = value;
132  return *this;
133  }
134 
135  // The indent of the hlo text block.
136  HloPrintOptions& set_indent_amount(int value) {
137  indent_amount_ = value;
138  return *this;
139  }
140 
141  bool print_large_constants() const { return print_large_constants_; }
142  bool print_subcomputation_references() const {
143  return print_subcomputation_references_;
144  }
145  bool print_metadata() const { return print_metadata_; }
146  bool compact_operands() const { return compact_operands_; }
147  bool print_operand_shape() const { return print_operand_shape_; }
148  bool print_program_shape() const { return print_program_shape_; }
149  bool print_percent() const { return print_percent_; }
150  int indent_amount() const { return indent_amount_; }
151 
152  private:
153  bool print_large_constants_;
154  bool print_subcomputation_references_;
155  bool print_metadata_;
156  bool compact_operands_;
157  bool print_operand_shape_;
158  bool print_program_shape_;
159  bool print_percent_;
160  int indent_amount_;
161 };
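// Illustrative sketch (not part of the original header): how the chained
// setters above are typically combined. The variable names are hypothetical.
//
//   HloPrintOptions options = HloPrintOptions()
//                                 .set_print_metadata(false)
//                                 .set_print_operand_shape(false)
//                                 .set_indent_amount(2);
//   // The resulting options can then be passed to HloInstruction::ToString().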
165 class HloInstruction {
166  public:
170  enum class FusionKind {
171  kLoop,
172  kInput,
173  kOutput,
177  kTransposeDot,
178  kCustom,
180  };
181 
182  ~HloInstruction();
183 
184  // Creates an instruction from the given proto. Arguments:
185  //
186  // module: the module which will contain the instruction. The newly created
187  // instruction is *not* added to the module or any computation, however.
188  // proto: the proto to convert from.
189  // instruction_map: a map from instruction id to HloInstruction*. This map
190  // must contain all operands of the newly constructed instruction.
191  // computation_map: a map from computation id to HloComputation*. This map
192  // must contain all computations which the newly constructed instruction
193  // calls.
194  static StatusOr<std::unique_ptr<HloInstruction>> CreateFromProto(
195  HloModule* module, const HloInstructionProto& proto,
196  const tensorflow::gtl::FlatMap<int64, HloInstruction*>& instruction_map,
197  const tensorflow::gtl::FlatMap<int64, HloComputation*>& computation_map);
198 
199  // Creates a parameter-retrieving instruction.
200  static std::unique_ptr<HloInstruction> CreateParameter(int64 parameter_number,
201  const Shape& shape,
202  const string& name);
203 
204  // Creates a literal constant instruction.
205  static std::unique_ptr<HloInstruction> CreateConstant(
206  std::unique_ptr<Literal> literal);
207 
208  // Creates a get tuple element instruction.
209  static std::unique_ptr<HloInstruction> CreateGetTupleElement(
210  const Shape& shape, HloInstruction* operand, int64 index);
211 
212  // Creates a trace instruction that logs the input operand in the computation.
213  static std::unique_ptr<HloInstruction> CreateTrace(const string& tag,
214  HloInstruction* operand);
215 
216  // Creates a random number generation instruction that fills a shape with
217  // random numbers from a given distribution.
218  static std::unique_ptr<HloInstruction> CreateRng(
219  const Shape& shape, RandomDistribution distribution,
220  tensorflow::gtl::ArraySlice<HloInstruction*> parameters);
221 
222  // Creates a unary instruction (one operand).
223  // Precondition: opcode must be a legitimate unary operation.
224  static std::unique_ptr<HloInstruction> CreateUnary(const Shape& shape,
225  HloOpcode opcode,
226  HloInstruction* operand);
227 
228  // Creates a binary instruction (two operands).
229  // Precondition: opcode must be a legitimate binary operation.
230  static std::unique_ptr<HloInstruction> CreateBinary(const Shape& shape,
231  HloOpcode opcode,
232  HloInstruction* lhs,
233  HloInstruction* rhs);
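  // Illustrative sketch (assumption, not from the original header): combining
  // the factory methods above to build a tiny scalar expression. ShapeUtil and
  // Literal come from shape_util.h and literal_util.h.
  //
  //   Shape r0f32 = ShapeUtil::MakeShape(F32, {});
  //   auto x = HloInstruction::CreateParameter(0, r0f32, "x");
  //   auto one = HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0f));
  //   auto add = HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd,
  //                                           x.get(), one.get());
  //   // In practice each unique_ptr is handed to HloComputation::AddInstruction.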
234 
235  // Creates a ternary instruction (three operands).
236  // Precondition: opcode must be a legitimate ternary operation.
237  static std::unique_ptr<HloInstruction> CreateTernary(const Shape& shape,
238  HloOpcode opcode,
239  HloInstruction* lhs,
240  HloInstruction* rhs,
241  HloInstruction* ehs);
242 
243  // Creates a variadic instruction (variable number of operands).
244  // Precondition: opcode must be a legitimate variadic operation.
245  static std::unique_ptr<HloInstruction> CreateVariadic(
246  const Shape& shape, HloOpcode opcode,
247  tensorflow::gtl::ArraySlice<HloInstruction*> operands);
248 
249  // Creates a map instruction, where the computation (given by the handle) is
250  // applied element-wise to every element in operands (across the operands,
251  // at a given index) with the same `static_operands`.
252  static std::unique_ptr<HloInstruction> CreateMap(
253  const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
254  HloComputation* map_computation,
255  tensorflow::gtl::ArraySlice<HloInstruction*> static_operands = {});
256 
257  // Creates a convolution op, where rhs is the convolutional filter
258  // and window describes how the filter is applied to lhs.
259  static std::unique_ptr<HloInstruction> CreateConvolve(
260  const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
261  const Window& window,
262  const ConvolutionDimensionNumbers& dimension_numbers);
263 
264  // Creates an FFT op, of the type indicated by fft_type.
265  static std::unique_ptr<HloInstruction> CreateFft(
266  const Shape& shape, HloInstruction* operand, FftType fft_type,
267  tensorflow::gtl::ArraySlice<int64> fft_length);
268 
269  // Creates a dot op with operands 'lhs' and 'rhs' with contracting and batch
270  // dimensions specified in 'dimension_numbers'.
271  static std::unique_ptr<HloInstruction> CreateDot(
272  const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
273  const DotDimensionNumbers& dimension_numbers);
274 
275  // Creates a dot op with operands 'lhs' and 'rhs' that contracts dimension 1
276  // of the LHS with dimension 0 of the RHS, with no batch dimensions. Both the
277  // LHS and the RHS must be of rank 2.
278  static std::unique_ptr<HloInstruction> CreateCanonicalDot(
279  const Shape& shape, HloInstruction* lhs, HloInstruction* rhs);
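  // Illustrative sketch (assumption): populating DotDimensionNumbers (the proto
  // from xla_data.proto) for the CreateDot factory above.
  //
  //   DotDimensionNumbers dnums;
  //   dnums.add_lhs_contracting_dimensions(1);
  //   dnums.add_rhs_contracting_dimensions(0);
  //   auto dot = HloInstruction::CreateDot(result_shape, lhs, rhs, dnums);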
280 
281  // Creates a reduce-precision op, where operand is the data to reduce in
282  // precision, and exponent_bits and mantissa_bits describe the precision to
283  // reduce it to.
284  static std::unique_ptr<HloInstruction> CreateReducePrecision(
285  const Shape& shape, HloInstruction* operand, const int exponent_bits,
286  const int mantissa_bits);
287 
288  // Creates a cross replica sum op.
289  static std::unique_ptr<HloInstruction> CreateCrossReplicaSum(
290  const Shape& shape,
291  tensorflow::gtl::ArraySlice<HloInstruction*> operands);
292 
293  // Creates a conversion instruction, where operand is the data to convert and
294  // shape is the target shape for the conversion.
295  static std::unique_ptr<HloInstruction> CreateConvert(const Shape& shape,
296  HloInstruction* operand);
297 
298  // Creates a bitcast conversion instruction, where operand is the data to
299  // convert and shape is the target shape for the conversion.
300  static std::unique_ptr<HloInstruction> CreateBitcastConvert(
301  const Shape& shape, HloInstruction* operand);
302 
303  // Creates an infeed instruction, which reads data of the given shape from the
304  // Infeed interface of the device.
305  static std::unique_ptr<HloInstruction> CreateInfeed(const Shape& shape,
306  const string& config);
307 
308  // Creates an outfeed instruction, which outputs data.
309  static std::unique_ptr<HloInstruction> CreateOutfeed(
310  const Shape& shape, HloInstruction* operand,
311  tensorflow::StringPiece outfeed_config);
312 
313  // Creates an asynchronous send instruction with the given channel id, which
314  // initiates sending the operand data to a unique receive instruction in
315  // another computation that has the same channel id.
316  static std::unique_ptr<HloInstruction> CreateSend(HloInstruction* operand,
317  int64 channel_id);
318 
319  // Blocks until data transfer for the Send instruction (operand) is complete.
320  // The operand must be kSend.
321  static std::unique_ptr<HloInstruction> CreateSendDone(
322  HloInstruction* operand);
323 
324  // Creates an asynchronous receive instruction with the given channel id,
325  // which allocates resources to receive data of the given shape from a unique
326  // send instruction in another computation that has the same channel id.
327  static std::unique_ptr<HloInstruction> CreateRecv(const Shape& shape,
328  int64 channel_id);
329 
330  // Blocks until data transfer for the Recv instruction (operand) is complete
331  // and returns the receive buffer. The operand must be kRecv.
332  static std::unique_ptr<HloInstruction> CreateRecvDone(
333  HloInstruction* operand);
334 
335  // Creates a slice instruction, where the operand is sliced by the given
336  // start/limit indices.
337  static std::unique_ptr<HloInstruction> CreateSlice(
338  const Shape& shape, HloInstruction* operand,
339  tensorflow::gtl::ArraySlice<int64> start_indices,
340  tensorflow::gtl::ArraySlice<int64> limit_indices,
341  tensorflow::gtl::ArraySlice<int64> strides);
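  // Illustrative sketch (assumption): slicing elements [2, 10) with stride 2
  // out of a rank-1 F32[16] operand, which yields an F32[4] result.
  //
  //   auto slice = HloInstruction::CreateSlice(
  //       ShapeUtil::MakeShape(F32, {4}), operand,
  //       /*start_indices=*/{2}, /*limit_indices=*/{10}, /*strides=*/{2});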
342 
343  // Creates a dynamic slice instruction, where the first operand is sliced by
344  // start indices specified in the second operand, and by size specified in
345  // 'slice_sizes'.
346  static std::unique_ptr<HloInstruction> CreateDynamicSlice(
347  const Shape& shape, HloInstruction* operand,
348  HloInstruction* start_indices,
349  tensorflow::gtl::ArraySlice<int64> slice_sizes);
350 
351  // Creates a dynamic update slice instruction, which updates a slice
352  // of 'operand' with 'update' and 'start_indices'.
353  static std::unique_ptr<HloInstruction> CreateDynamicUpdateSlice(
354  const Shape& shape, HloInstruction* operand, HloInstruction* update,
355  HloInstruction* start_indices);
356 
357  // Creates a concatenate instruction, where the operands are concatenated on
358  // the provided dimension.
359  static std::unique_ptr<HloInstruction> CreateConcatenate(
360  const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
361  int64 dimension);
362 
363  // Creates a reduce instruction, where the computation (given by the handle)
364  // is applied successively to every element in operand. That is, if f is the
365  // function to apply (which either takes 2 [accumulator, value] or 3
366  // [accumulator, index, value] arguments) and init is the initial value
367  // specified for the reduction operator (for example, 0 for addition), then
368  // this operation will compute:
369  //   f(... f(f(init, [index0,] value0), [index1,] value1), ..., [indexN,] valueN)
370  static std::unique_ptr<HloInstruction> CreateReduce(
371  const Shape& shape, HloInstruction* operand, HloInstruction* init_value,
372  tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce,
373  HloComputation* reduce_computation);
374 
375  // Creates a reduce-window instruction, where the computation (given
376  // by the handle) is applied window-wise at each valid window
377  // position in the operand.
378  static std::unique_ptr<HloInstruction> CreateReduceWindow(
379  const Shape& shape, HloInstruction* operand, HloInstruction* init_value,
380  const Window& window, HloComputation* reduce_computation);
381 
382  // Creates a batch-norm-training instruction.
383  static std::unique_ptr<HloInstruction> CreateBatchNormTraining(
384  const Shape& shape, HloInstruction* operand, HloInstruction* scale,
385  HloInstruction* offset, float epsilon, int64 feature_index);
386 
387  // Creates a batch-norm-inference instruction.
388  static std::unique_ptr<HloInstruction> CreateBatchNormInference(
389  const Shape& shape, HloInstruction* operand, HloInstruction* scale,
390  HloInstruction* offset, HloInstruction* mean, HloInstruction* variance,
391  float epsilon, int64 feature_index);
392 
393  // Creates a batch-norm-grad instruction.
394  static std::unique_ptr<HloInstruction> CreateBatchNormGrad(
395  const Shape& shape, HloInstruction* operand, HloInstruction* scale,
396  HloInstruction* mean, HloInstruction* variance,
397  HloInstruction* grad_output, float epsilon, int64 feature_index);
398 
399  // Creates a select-and-scatter instruction that scatters the `source` array to the
400  // selected indices of each window.
401  static std::unique_ptr<HloInstruction> CreateSelectAndScatter(
402  const Shape& shape, HloInstruction* operand, HloComputation* select,
403  const Window& window, HloInstruction* source, HloInstruction* init_value,
404  HloComputation* scatter);
405 
406  // Creates a broadcast instruction.
407  static std::unique_ptr<HloInstruction> CreateBroadcast(
408  const Shape& shape, HloInstruction* operand,
409  tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
410 
411  // Creates a broadcast-size-one-dimensions instruction.
412  static std::unique_ptr<HloInstruction> CreateBroadcastDimOne(
413  const Shape& shape, HloInstruction* operand);
414 
415  // Creates a sequence of instructions that performs an explicit broadcast of
416  // the operand to the target shape.
417  //
418  // Interior HLOs are passed to "adder", but the "root" HLO of the sequence is
419  // returned as a unique_ptr for API consistency with other factory methods in
420  // this interface.
421  //
422  // TODO(b/72173833) Ideally HloComputations would always be present, and so
423  // the adder being passed by the caller would not be necessary.
424  static std::unique_ptr<HloInstruction> CreateBroadcastSequence(
425  const Shape& output_shape, HloInstruction* operand,
426  const std::function<HloInstruction*(std::unique_ptr<HloInstruction>)>&
427  adder);
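  // Illustrative sketch (assumption): a typical "adder" callback that simply
  // adds each interior HLO of the broadcast sequence to an enclosing computation.
  //
  //   HloInstruction* root = computation->AddInstruction(
  //       HloInstruction::CreateBroadcastSequence(
  //           target_shape, operand,
  //           [&](std::unique_ptr<HloInstruction> hlo) {
  //             return computation->AddInstruction(std::move(hlo));
  //           }));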
428 
429  // Creates a pad instruction, where the operand is padded on the edges and
430  // between the elements with the given padding value.
431  static std::unique_ptr<HloInstruction> CreatePad(
432  const Shape& shape, HloInstruction* operand,
433  HloInstruction* padding_value, const PaddingConfig& padding_config);
434 
435  // Creates a reshape instruction, where the operand is flattened in row-major
436  // order and then reshaped to the given result shape.
437  static std::unique_ptr<HloInstruction> CreateReshape(const Shape& shape,
438  HloInstruction* operand);
439 
440  // Creates a transpose instruction which permutes the operand dimensions.
441  static std::unique_ptr<HloInstruction> CreateTranspose(
442  const Shape& shape, HloInstruction* operand,
443  tensorflow::gtl::ArraySlice<int64> dimensions);
444 
445  // Creates a while instruction, given a condition computation, a body
446  // computation, and the initial value for the input of the computations. For
447  // example, shape: S32, condition: i -> i < 1000, body: i -> i * 2, init: 1
448  // corresponds to the C code below.
449  // int32 i = 1; int32 result = while(i < 1000) { i = i * 2 }
450  static std::unique_ptr<HloInstruction> CreateWhile(const Shape& shape,
451  HloComputation* condition,
452  HloComputation* body,
453  HloInstruction* init);
454 
455  static std::unique_ptr<HloInstruction> CreateConditional(
456  const Shape& shape, HloInstruction* pred,
457  HloInstruction* true_computation_arg, HloComputation* true_computation,
458  HloInstruction* false_computation_arg, HloComputation* false_computation);
459 
460  static std::unique_ptr<HloInstruction> CreateGather(
461  const Shape& shape, HloInstruction* operand,
462  HloInstruction* gather_indices,
463  const GatherDimensionNumbers& gather_dim_numbers,
464  tensorflow::gtl::ArraySlice<int64> window_bounds);
472  static std::unique_ptr<HloInstruction> CreateFusion(
473  const Shape& shape, FusionKind fusion_kind, HloInstruction* fused_root);
474 
475  static std::unique_ptr<HloInstruction> CreateFusion(
476  const Shape& shape, FusionKind fusion_kind,
477  tensorflow::gtl::ArraySlice<HloInstruction*> operands,
478  HloComputation* fusion_computation);
479 
480  // Creates a call instruction that applies the given computation on the given
481  // operands. "shape" is the resultant shape.
482  static std::unique_ptr<HloInstruction> CreateCall(
483  const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
484  HloComputation* computation);
485 
486  // Creates a custom call instruction that applies the given custom call target
487  // to the given operands. "shape" is the resultant shape.
488  static std::unique_ptr<HloInstruction> CreateCustomCall(
489  const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
490  tensorflow::StringPiece custom_call_target);
491 
492  // Creates a HostCompute instruction, which records host-side control and
493  // data dependencies for use in instruction scheduling.
494  static std::unique_ptr<HloInstruction> CreateHostCompute(
495  const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
496  tensorflow::StringPiece channel_name, const int64 cost_estimate_ns);
497 
498  // Creates a tuple instruction with the given elements. This is a convenience
499  // wrapper around CreateVariadic.
500  static std::unique_ptr<HloInstruction> CreateTuple(
501  tensorflow::gtl::ArraySlice<HloInstruction*> elements);
502 
503  // Creates a reverse instruction, which reverses the order of the elements
504  // in the specified dimensions.
505  static std::unique_ptr<HloInstruction> CreateReverse(
506  const Shape& shape, HloInstruction* operand,
507  tensorflow::gtl::ArraySlice<int64> dimensions);
508 
509  // Creates an instance of GatherDimensionNumbers.
510  static GatherDimensionNumbers MakeGatherDimNumbers(
511  tensorflow::gtl::ArraySlice<int64> output_window_dims,
512  tensorflow::gtl::ArraySlice<int64> elided_window_dims,
513  tensorflow::gtl::ArraySlice<int64> gather_dims_to_operand_dims,
514  int64 index_vector_dim);
515 
516  // Returns the opcode for this instruction.
517  HloOpcode opcode() const { return opcode_; }
518 
519  // Returns true if this instruction has a side effect. An instruction has a
520  // side effect if it uses certain opcodes or calls a computation with a side
521  // effect.
522  bool HasSideEffect() const;
523 
524  // Returns the result shape of this instruction.
525  const Shape& shape() const;
526 
527  // Returns the (mutable) result shape of this instruction.
528  Shape* mutable_shape() { return &shape_; }
529 
530  // Returns the ith operand to this instruction.
531  const HloInstruction* operand(int64 i) const;
532 
533  // Returns the ith operand to this instruction.
534  HloInstruction* mutable_operand(int64 i);
535 
536  // Returns the number of operands to this instruction.
537  int64 operand_count() const { return operands_.size(); }
538 
539  // Returns the vector of operands of this instruction.
540  using InstructionVector = tensorflow::gtl::InlinedVector<HloInstruction*, 2>;
541  const InstructionVector& operands() const { return operands_; }
542 
543  // Returns the index of 'target' in the operands sequence.
544  // Precondition: target must be an operand (or a fatal error will occur).
545  int64 operand_index(const HloInstruction* target) const;
546 
547  // Returns the number of users of this instruction.
548  int64 user_count() const { return users_.size(); }
549 
550  // Returns the users of this instruction.
551  const std::vector<HloInstruction*>& users() const { return users_; }
552 
553  // Returns true if this instruction is a user of 'instruction'.
554  bool IsUserOf(const HloInstruction* instruction) const {
555  return ContainsKey(instruction->user_set_, this);
556  }
557 
558  // Adds a control dependency from this instruction to the given
559  // instruction. This instruction becomes a control predecessor of
560  // 'instruction', and 'instruction' becomes a control successor of this
561  // instruction. Returns an error status if either of the given instructions
562  // does not belong to the same computation.
563  //
564  // This is used to enforce an additional ordering requirement that is not
565  // captured by normal data dependencies, such as ordering among Send or Recv
566  // operations to avoid deadlock.
567  Status AddControlDependencyTo(HloInstruction* instruction);
568 
569  // Removes a previously added control dependency from this instruction to
570  // 'instruction'.
571  Status RemoveControlDependencyTo(HloInstruction* instruction);
572 
573  // Returns the set of control predecessors (successors) of this
574  // instruction. Control predecessors (successors) must execute before (after)
575  // the current instruction.
576  const std::vector<HloInstruction*>& control_predecessors() const {
577  return control_predecessors_;
578  }
579  const std::vector<HloInstruction*>& control_successors() const {
580  return control_successors_;
581  }
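  // Illustrative sketch (assumption): enforcing an ordering between two
  // instructions that have no data dependency, e.g. sequencing a Send before an
  // unrelated Recv in the same computation.
  //
  //   TF_RETURN_IF_ERROR(send->AddControlDependencyTo(recv));
  //   // 'send' is now a control predecessor of 'recv'.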
582 
583  // Returns true if "other" performs the same computation as this instruction.
584  bool Identical(
585  const HloInstruction& other,
586  const std::function<bool(const HloInstruction*, const HloInstruction*)>&
587  eq_operands = std::equal_to<const HloInstruction*>(),
588  const std::function<bool(const HloComputation*, const HloComputation*)>&
589  eq_computations = std::equal_to<const HloComputation*>(),
590  bool layout_sensitive = true) const {
591  // An instruction is always identical to itself.
592  if (this == &other) {
593  return true;
594  }
595 
596  // Identical instructions must have the same opcode, shape, and identical
597  // operands.
598  if (opcode() != other.opcode()) {
599  return false;
600  }
601  using EqShapeFuncType = bool (*)(const Shape&, const Shape&);
602  EqShapeFuncType eq_shapes =
603  layout_sensitive ? ShapeUtil::Equal : ShapeUtil::Compatible;
604  if (!eq_shapes(shape(), other.shape())) {
605  return false;
606  }
607  if (operands().size() != other.operands().size()) {
608  return false;
609  }
610 
611  // Use an explicit loop rather than ContainerEquals, because copying around
612  // std::functions may be too expensive in some cases.
613  for (size_t i = 0; i < operands().size(); ++i) {
614  if (!eq_operands(operand(i), other.operand(i))) {
615  return false;
616  }
617  }
618 
619  return IdenticalSlowPath(other, eq_computations, eq_shapes);
620  }
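  // Illustrative sketch (assumption): a layout-insensitive comparison that
  // treats operands as equal when their names match, rather than requiring
  // pointer equality.
  //
  //   bool same = a->Identical(
  //       *b,
  //       /*eq_operands=*/[](const HloInstruction* x, const HloInstruction* y) {
  //         return x->name() == y->name();
  //       },
  //       /*eq_computations=*/std::equal_to<const HloComputation*>(),
  //       /*layout_sensitive=*/false);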
621 
622  // Returns whether the instruction has a constant operand.
623  bool HasConstantOperand() const;
624 
625  // Returns whether this instruction does a rank-2 transposition.
626  bool IsRank2Transpose() const;
627 
628  // Replaces the use of this instruction in "user" with "new_producer". Note
629  // that there might be multiple uses of this instruction in "user"; all will
630  // be replaced.
631  Status ReplaceUseWith(HloInstruction* user, HloInstruction* new_producer);
632 
633  // Replaces the specified operand with new_operand.
634  Status ReplaceOperandWith(int64 operand_no, HloInstruction* new_operand);
635 
636  // Replaces all uses of this instruction with the new producer. If
637  // new_producer is a user of this instruction then new_producer remains a use
638  // of this instruction to avoid introducing cycles into the graph.
639  //
640  // If this instruction is the root of its computation, sets the computation's
641  // root to new_producer.
642  Status ReplaceAllUsesWith(HloInstruction* new_producer);
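  // Illustrative sketch (assumption): the usual rewrite pattern when a pass
  // substitutes one instruction for another.
  //
  //   TF_RETURN_IF_ERROR(old_instr->ReplaceAllUsesWith(new_instr));
  //   // old_instr now has no users (unless new_instr was already a user of it)
  //   // and can typically be removed from its computation afterwards.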
643 
644  // Detaches an instruction from its operands. That is, remove the instruction
645  // from each operand's user set. This should only be called prior to
646  // deallocating the instruction.
647  void DetachFromOperands();
648 
649  // Performs a postorder DFS visit using this node as the root. If
650  // call_finish_visit is true, then DfsHloVisitor::FinishVisit is called when
651  // complete. If ignore_control_predecessors is true, instructions only
652  // reachable via control dependencies will not be visited, and the postorder
653  // will not take control dependencies into account. It is as if the control
654  // dependencies didn't exist in the graph at all.
655  template <typename HloInstructionPtr>
656  Status Accept(DfsHloVisitorBase<HloInstructionPtr>* visitor,
657  bool call_finish_visit = true,
658  bool ignore_control_predecessors = false);
659  Status Accept(ConstDfsHloVisitor* visitor, bool call_finish_visit = true,
660  bool ignore_control_predecessors = false) const {
661  return const_cast<HloInstruction*>(this)->Accept(
662  visitor, call_finish_visit, ignore_control_predecessors);
663  }
664 
665  // Same as Accept() above, but the order of operand and control predecessor
666  // visitation is determined by the given operand order; if compare(A, B) ==
667  // true, A is visited before B.
668  using CompareFunction =
669  std::function<bool(const HloInstruction*, const HloInstruction*)>;
670  Status AcceptWithOperandOrder(DfsHloVisitor* visitor,
671  const CompareFunction& operand_order,
672  bool call_finish_visit = true);
673 
674  // Performs a postorder DFS visit using this node as the root. Calls the given
675  // visitor function at each instruction.
676  Status Accept(const std::function<Status(HloInstruction*)>& visitor_func);
677  Status Accept(
678  const std::function<Status(const HloInstruction*)>& visitor_func) const;
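  // Illustrative sketch (assumption): visiting every instruction reachable from
  // a root in postorder with a lambda instead of a full DfsHloVisitor subclass.
  //
  //   TF_RETURN_IF_ERROR(root->Accept([](HloInstruction* hlo) {
  //     VLOG(2) << "visited: " << hlo->name();
  //     return Status::OK();
  //   }));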
679 
680  // Visits all instructions rooted at this instruction using the given visitor
681  // in the given order. 'order' must contain at least the set of instructions
682  // rooted at this node (i.e., those accessible from a DFS traversal from this
683  // instruction). Instructions contained in 'order' which are not in the set of
684  // instructions rooted at this node are ignored. 'order' must also be a valid
685  // topological sort of these instructions (defs appear before uses) though
686  // need not be a DFS post-order.
687  Status AcceptOrdered(DfsHloVisitor* visitor,
688  const std::vector<const HloInstruction*>& order);
689 
690  // Visit this instruction and only this instruction with the given visitor.
691  template <typename HloInstructionPtr>
692  Status Visit(DfsHloVisitorBase<HloInstructionPtr>* visitor);
693 
694  // Returns the literal associated with this instruction.
695  //
696  // Note: only constant and parameter opcodes have an associated literal.
697  const Literal& literal() const;
698 
699  // Returns the parameter number associated with this instruction.
700  //
701  // Note: only parameter opcodes have an associated parameter number.
702  int64 parameter_number() const {
703  CHECK_EQ(HloOpcode::kParameter, opcode_);
704  return parameter_number_;
705  }
706 
707  // Returns the dimension sizes or numbers associated with this instruction.
708  //
709  // Precondition: opcode() is one of: concatenate, reduce, broadcast, reshape,
710  // and reverse.
711  const std::vector<int64>& dimensions() const;
712  int64 dimensions(int64 index) const;
713 
714  // Accessor for the dimension in which a concatenate HLO should occur.
715  // Precondition: opcode() == HloOpcode::kConcatenate
716  int64 concatenate_dimension() const;
717 
718  // Returns the tuple index associated with this instruction.
719  //
720  // Precondition: opcode() == HloOpcode::kGetTupleElement
721  int64 tuple_index() const;
722 
723  // Returns the first non-GetTupleElement ancestor of this instruction. If the
724  // first non-GTE ancestor is tuple-shaped, populates 'index' with the (possibly
725  // nested) tuple indices used on the path from that ancestor to this instruction.
726  std::pair<const HloInstruction*, ShapeIndex> LatestNonGteAncestorAndIndex()
727  const;
728 
729  std::pair<HloInstruction*, ShapeIndex> LatestNonGteAncestorAndIndex() {
730  auto rv =
731  const_cast<const HloInstruction*>(this)->LatestNonGteAncestorAndIndex();
732  return {const_cast<HloInstruction*>(rv.first), rv.second};
733  }
734 
735  // Same as LatestNonGteAncestorAndIndex, but just returns the HloInstruction.
736  const HloInstruction* LatestNonGteAncestor() const;
737 
738  HloInstruction* LatestNonGteAncestor() {
739  return const_cast<HloInstruction*>(
740  const_cast<const HloInstruction*>(this)->LatestNonGteAncestor());
741  }
742 
743  // Gets/sets the to_apply HloComputation for Call, Map, Reduce, etc.
744  // The setter should only be called by HloModule or HloComputation methods.
745  //
746  // Precondition: The instruction has a valid to_apply_ field.
747  HloComputation* to_apply() const;
748  void set_to_apply(HloComputation* to_apply);
749 
750  // Returns the custom_call_target for CustomCall.
751  // Precondition: opcode() == HloOpcode::kCustomCall
752  const string& custom_call_target() const;
753 
754  // Returns the config for the Outfeed instruction.
755  // Precondition: opcode() == HloOpcode::kOutfeed
756  const string& outfeed_config() const;
757 
758  // Returns the shape for the Outfeed instruction.
759  // Precondition: opcode() == HloOpcode::kOutfeed
760  const Shape& outfeed_shape() const;
761 
762  // Gets/sets the while_condition or while_body HloComputation for While. The
763  // setters should only be called by HloModule or HloComputation methods.
764  //
765  // Precondition: The instruction is a While instruction.
766  HloComputation* while_condition() const;
767  HloComputation* while_body() const;
768  void set_while_condition(HloComputation* while_condition);
769  void set_while_body(HloComputation* while_body);
770 
771  // Gets/sets the select or scatter HloComputation for SelectAndScatter. The
772  // setters should only be called by HloModule or HloComputation methods.
773  //
774  // Precondition: opcode() == HloOpcode::kSelectAndScatter.
775  HloComputation* select() const;
776  HloComputation* scatter() const;
777  void set_select(HloComputation* select);
778  void set_scatter(HloComputation* scatter);
779 
780  // Gets/sets the true and false HloComputation for Conditional. The setters
781  // should only be called by HloModule or HloComputation methods.
782  //
783  // Precondition: The instruction is a Conditional instruction.
784  HloComputation* true_computation() const;
785  HloComputation* false_computation() const;
786  void set_true_computation(HloComputation* true_computation);
787  void set_false_computation(HloComputation* false_computation);
788 
789  // Returns a string for the signature of this instruction if considered as a
790  // function, e.g. the signature of an F32 add is (F32, F32) -> F32.
791  string SignatureString() const;
792 
793  // Returns a debugging string that represents this instruction.
794  //
795  // (We express the default options using an overload rather than a default
796  // param because gdb ignores default params, but does resolve overloads.)
797  //
798  // TODO(b/73348663): Make ToString() adaptive to the size of the string by
799  // default, backing off on providing full information for very large strings,
800  // or provide a different name for a ToString-like function that does that.
801  string ToString() const { return ToString(HloPrintOptions()); }
802  string ToString(const HloPrintOptions& options) const;
803 
804  // Components of the ToString() representation:
805 
806  // Returns a string representation of the operand list.
807  string OperandsToString(const HloPrintOptions& options) const;
808 
809  // Returns string representation of op-specific attributes.
810  std::vector<string> ExtraAttributesToString(
811  const HloPrintOptions& options) const;
812 
813  // As ToString, but returns a shorter string.
814  string ToShortString() const;
815 
816  // Returns a serialized representation of this instruction.
817  HloInstructionProto ToProto() const;
818 
819  // Returns a category for the HLO. This could be something like "convolution"
820  // or "elementwise".
821  string ToCategory() const;
822 
823  // Returns a logging instruction, if the output of this instruction is logged.
824  //
825  // Postcondition: retval == nullptr || retval->opcode() == HloOpcode::kTrace
826  HloInstruction* tracing() const;
827  void set_tracing(HloInstruction* trace_instruction);
828 
829  // Returns the channel id associated with the instruction. The id is
830  // shared between each Send/Recv pair and is globally unique to identify each
831  // channel.
832  //
833  // Precondition: opcode() == HloOpcode::kSend or HloOpcode::kRecv
834  int64 channel_id() const { return channel_id_; }
835 
836  // Returns the channel name associated with the instruction. The name is
837  // used to identify host Send/Recv operations.
838  //
839  // Precondition: opcode() == HloOpcode::kHostCompute
840  string channel_name() const { return channel_name_; }
841 
842  // Returns the feature_index field associated with the instruction. The index
843  // represents the index of the feature dimension.
844  //
845  // Precondition: opcode() is one of kBatchNormTraining, kBatchNormInference,
846  // or kBatchNormGrad.
847  int64 feature_index() const { return feature_index_; }
848 
849  // Returns the epsilon value associated with the instruction. This is a small
850  // number added to the variance to avoid divide-by-zero errors.
851  //
852  // Precondition: opcode() is one of kBatchNormTraining, kBatchNormInference,
853  // or kBatchNormGrad.
854  float epsilon() const { return epsilon_; }
855 
856  // Returns the infeed configuration string. The infeed configuration includes
857  // any metadata needed for the backend compiler (e.g., infeed buffer address)
858  // and is target-dependent.
859  string infeed_config() const { return infeed_config_; }
860  void set_infeed_config(const string& config) { infeed_config_ = config; }
861 
862  // Returns a tag to be used in tracing.
863  //
864  // Precondition: opcode() == HloOpcode::kTrace
865  string TracingTag() const;
866 
867  // Returns whether the instruction is a constant.
868  bool IsConstant() const;
869 
870  // Returns true if this instruction is fused, i.e., contained within a fusion
871  // instruction.
872  bool IsFused() const;
873 
874  // Returns the computation for this fused instruction.
875  //
876  // Precondition: opcode() == HloOpcode::kFusion
877  HloComputation* fused_instructions_computation() const;
878 
879  // Returns true if this instruction can be legally fused into a fusion
880  // instruction.
881  bool IsFusable() const;
882 
883  // Returns the root instruction of the fused expression contained within this
884  // fusion instruction.
885  //
886  // Precondition: opcode() == HloOpcode::kFusion
887  HloInstruction* fused_expression_root() const;
888 
889  // Returns the list of fused instructions inside this fusion instruction. The
890  // returned type is a range of HloInstruction*s.
891  //
892  // Precondition: opcode() == HloOpcode::kFusion
893  const tensorflow::gtl::iterator_range<UnwrappingIterator<
894  std::list<std::unique_ptr<HloInstruction>>::const_iterator>>
895  fused_instructions() const;
896 
897  const tensorflow::gtl::iterator_range<
898  UnwrappingIterator<std::list<std::unique_ptr<HloInstruction>>::iterator>>
899  fused_instructions();
900 
901  // Gets the number of instructions inside this fusion instruction.
902  //
903  // Precondition: opcode() == HloOpcode::kFusion
904  int64 fused_instruction_count() const;
905 
906  // Returns the fused parameter instruction in this fusion instruction
907  // corresponding to the given parameter number.
908  //
909  // Precondition: opcode() == HloOpcode::kFusion
910  HloInstruction* fused_parameter(int64 parameter_number) const;
911 
912  // Returns the vector of fused parameters inside this fusion instruction.
913  //
914  // Precondition: opcode() == HloOpcode::kFusion
915  const std::vector<HloInstruction*>& fused_parameters() const;
916 
917  // Returns true if this instruction is a fusion instruction that generates
918  // multiple outputs.
919  const bool IsMultiOutputFusion() const {
920  return opcode() == HloOpcode::kFusion &&
921  fused_expression_root()->opcode() == HloOpcode::kTuple;
922  }
923 
924  FusionKind fusion_kind() const {
925  CHECK_EQ(HloOpcode::kFusion, opcode_);
926  return fusion_kind_;
927  }
928 
929  void set_fusion_kind(FusionKind kind) {
930  CHECK_EQ(HloOpcode::kFusion, opcode_);
931  fusion_kind_ = kind;
932  }
933 
934  // Returns the sharding applied to this operator.
935  // REQUIRES: has_sharding() is true.
936  const HloSharding& sharding() const {
937  CHECK(has_sharding());
938  return *sharding_;
939  }
940  // Returns the sharding applied to this operator, or default_ if none exists.
941  const HloSharding& sharding_or_default(const HloSharding& default_) const {
942  return sharding_ ? *sharding_ : default_;
943  }
944  // Returns the sharding unique device, if any.
945  tensorflow::gtl::optional<int64> sharding_unique_device() const {
946  if (sharding_ == nullptr || !sharding_->HasUniqueDevice()) {
947  return tensorflow::gtl::optional<int64>();
948  }
949  return sharding_->UniqueDevice().ValueOrDie();
950  }
951  // Sets the sharding of this operator. Should only be called by HloModule or
952  // HloComputation methods.
953  void set_sharding(const HloSharding& sharding) {
954  sharding_ = MakeUnique<HloSharding>(sharding);
955  }
956  // Remove any sharding from this operator.
957  void clear_sharding() { sharding_ = nullptr; }
958  // Return true if this operator has a sharding assigned.
959  bool has_sharding() const { return sharding_ != nullptr; }
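  // Illustrative sketch (assumption): assigning a whole instruction to a single
  // device when no sharding has been specified yet.
  //
  //   if (!instr->has_sharding()) {
  //     instr->set_sharding(HloSharding::AssignDevice(/*device_id=*/0));
  //   }
  //   int64 device = instr->sharding_unique_device().value();  // 0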
960 
961  // Adds a new operand to the fusion instruction.
962  HloInstruction* AddFusionOperand(HloInstruction* new_operand);
963 
964  // Merges the fused instructions from 'instruction_to_merge' into the
965  // fused instruction set of 'this', updating operands as necessary.
966  //
967  // Precondition: opcode() == HloOpcode::kFusion
968  // Precondition: 'instruction_to_merge' must be an operand of 'this'.
969  void MergeFusionInstruction(HloInstruction* instruction_to_merge);
970 
971  // Merges the fused instructions from instruction_to_merge into the fused
972  // instruction set of 'this', generating a multi-output fusion instruction.
973  // All the users of instruction_to_merge will be redirected to 'this'
974  // instruction. instruction_to_merge will be removed from its parent
975  // computation.
976  //
977  // Precondition: opcode() == HloOpcode::kFusion
978  void MergeFusionInstructionIntoMultiOutput(
979  HloInstruction* instruction_to_merge);
980 
981  // Fuses the given instruction into this fusion instruction. instruction_to_fuse
982  // is cloned and the clone is placed in the fusion
983  // instruction. instruction_to_fuse is unchanged. Instruction is cloned rather
984  // than moved to cleanly handle the case where the instruction has a use
985  // outside the fusion instruction. Moving such an instruction into a fusion
986  // instruction would violate the single-result invariant of HLO instructions
987  // and significantly complicate code generation.
988  //
989  // Precondition: this->opcode() == HloOpcode::kFusion
990  HloInstruction* FuseInstruction(HloInstruction* instruction_to_fuse) {
991  return FuseInstructionInternal(instruction_to_fuse);
992  }
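  // Illustrative sketch (assumption): folding a producer into an existing
  // fusion node during a fusion pass. 'producer' must already be an operand of
  // 'fusion'.
  //
  //   HloInstruction* fused_clone = fusion->FuseInstruction(producer);
  //   // 'producer' itself is unchanged; if it has no remaining users it can be
  //   // removed from the computation by the caller.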
993 
994  // Fuses the given instruction into this fusion instruction and generates a
995  // multi-output fusion instruction. A clone of instruction_to_fuse will be
996  // part of the output of the fusion instruction. The users of
997  // instruction_to_fuse will be redirected to this fusion instruction.
998  // instruction_to_fuse will be removed from its parent computation.
999  //
1000  // Precondition: this->opcode() == HloOpcode::kFusion
1001  HloInstruction* FuseInstructionIntoMultiOutput(
1002  HloInstruction* instruction_to_fuse) {
1003  return FuseInstructionInternal(instruction_to_fuse, /* add_output */ true);
1004  }
1005 
1006  // Returns the start index in the given dimension for a slice node.
1007  //
1008  // Precondition: opcode() == HloOpcode::kSlice
1009  int64 slice_starts(int64 dimension) const {
1010  CHECK_EQ(HloOpcode::kSlice, opcode_);
1011  return slice_starts_[dimension];
1012  }
1013  const std::vector<int64>& slice_starts() const { return slice_starts_; }
1014 
1015  // Returns the (exclusive) limit index in the given dimension for a slice
1016  // node.
1017  //
1018  // Precondition: opcode() == HloOpcode::kSlice
1019  int64 slice_limits(int64 dimension) const {
1020  CHECK_EQ(HloOpcode::kSlice, opcode_);
1021  return slice_limits_[dimension];
1022  }
1023  const std::vector<int64>& slice_limits() const {
1024  CHECK_EQ(HloOpcode::kSlice, opcode_);
1025  return slice_limits_;
1026  }
1027 
1028  // Returns the stride in the given dimension for a slice node.
1029  //
1030  // Precondition: opcode() == HloOpcode::kSlice
1031  int64 slice_strides(int64 dimension) const {
1032  CHECK_EQ(HloOpcode::kSlice, opcode_);
1033  return slice_strides_[dimension];
1034  }
1035  const std::vector<int64>& slice_strides() const { return slice_strides_; }
1036 
1037  // Returns the flag that describes whether a slice must be lowered into an
1038  // offset into the original operand.
1039  bool IsInPlaceSlice() const { return is_in_place_slice_; }
1040 
1041  // Sets and returns the flag that describes whether a slice must be lowered
1042  // into an offset into the original operand.
1043  bool SetIsInPlaceSlice(bool value) {
1044  is_in_place_slice_ = value;
1045  return value;
1046  }
1047 
1048  // Returns the size of the slice in the given dimension for a dynamic
1049  // slice node.
1050  //
1051  // Precondition: opcode() == HloOpcode::kDynamicSlice
1052  int64 slice_sizes(int64 dimension) const {
1053  CHECK_EQ(HloOpcode::kDynamicSlice, opcode_);
1054  return dynamic_slice_sizes_[dimension];
1055  }
1056  const std::vector<int64>& dynamic_slice_sizes() const {
1057  CHECK_EQ(HloOpcode::kDynamicSlice, opcode_);
1058  return dynamic_slice_sizes_;
1059  }
1060 
1061  // Returns the number of exponent bits for a reduce-precision node.
1062  //
1063  // Precondition: opcode() == HloOpcode::kReducePrecision
1064  int32 exponent_bits() const {
1065  CHECK_EQ(HloOpcode::kReducePrecision, opcode_);
1066  return exponent_bits_;
1067  }
1068 
1069  // Returns the number of mantissa bits for a reduce-precision node.
1070  //
1071  // Precondition: opcode() == HloOpcode::kReducePrecision
1072  int32 mantissa_bits() const {
1073  CHECK_EQ(HloOpcode::kReducePrecision, opcode_);
1074  return mantissa_bits_;
1075  }
1076 
1077  // Returns data on the window in a windowed operation such as
1078  // convolution.
1079  const Window& window() const {
1080  CHECK(window_ != nullptr);
1081  return *window_;
1082  }
1083 
1084  // Sets the window data in a windowed operation such as convolution.
1085  void set_window(const Window& window) {
1086  window_ = MakeUnique<Window>(window);
1087  }
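  // Illustrative sketch (assumption): building a Window proto (from
  // xla_data.proto) for a 1D windowed op with window size 3 and stride 2.
  //
  //   Window window;
  //   WindowDimension* dim = window.add_dimensions();
  //   dim->set_size(3);
  //   dim->set_stride(2);
  //   instr->set_window(window);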
1088 
1089  // Returns the padding configuration for a pad node.
1090  //
1091  // Precondition: opcode() == HloOpcode::kPad
1092  const PaddingConfig& padding_config() const {
1093  CHECK(padding_config_ != nullptr);
1094  return *padding_config_;
1095  }
1096 
1097  // Returns data on the dimension numbers used for a convolution operation,
1098  // which may be a kConvolution instruction or a kCustomCall that implements a
1099  // convolution.
1100  const ConvolutionDimensionNumbers& convolution_dimension_numbers() const {
1101  CHECK(convolution_dimension_numbers_ != nullptr);
1102  return *convolution_dimension_numbers_;
1103  }
1104 
1105  // Sets the convolution dimension numbers on this instruction. In general you
1106  // shouldn't need to call this; instead, specify the convolution dimension
1107  // numbers when you create the instruction.
1108  void set_convolution_dimension_numbers(
1109  const ConvolutionDimensionNumbers& dnums) {
1110  convolution_dimension_numbers_ =
1111  MakeUnique<ConvolutionDimensionNumbers>(dnums);
1112  }
1113 
1114  FftType fft_type() const {
1115  CHECK_EQ(HloOpcode::kFft, opcode_);
1116  return fft_type_;
1117  }
1118 
1119  const std::vector<int64>& fft_length() const {
1120  CHECK_EQ(HloOpcode::kFft, opcode_);
1121  return fft_length_;
1122  }
1123 
1124  // Returns the dump string of the convolution dimension numbers.
1125  string ConvolutionDimensionNumbersToString() const;
1126 
1127  // Returns data on the dimension numbers used for a dot operation.
1128  const DotDimensionNumbers& dot_dimension_numbers() const {
1129  CHECK(dot_dimension_numbers_ != nullptr);
1130  return *dot_dimension_numbers_;
1131  }
1132 
1133  // Returns the dump string of the dot dimension numbers.
1134  string DotDimensionNumbersToString() const;
1135 
1136  const GatherDimensionNumbers& gather_dimension_numbers() const {
1137  CHECK(gather_dimension_numbers_ != nullptr);
1138  return *gather_dimension_numbers_;
1139  }
1140 
1141  tensorflow::gtl::ArraySlice<int64> gather_window_bounds() const {
1142  CHECK_EQ(opcode(), HloOpcode::kGather);
1143  return gather_window_bounds_;
1144  }
1145 
1146  // Returns the dump string of the gather dimension numbers.
1147  string GatherDimensionNumbersToString() const;
1148 
1149  // Returns the random distribution for this rng node.
1150  //
1151  // Precondition: opcode() == HloOpcode::kRng
1152  RandomDistribution random_distribution() const;
1153 
1154  // Clones the HLO instruction. The clone will have the same opcode, shape, and
1155  // operands. After creation the clone has no uses. "this" (the instruction
1156  // cloned from) is not changed. Suffix is the string to append to the name of
1157  // the instruction to form the name of the cloned instruction.
1158  // If the module pointer is not nullptr, it will be the module where
1159  // the cloned computations will be added to (in order to support deep
1160  // cloning).
1161  std::unique_ptr<HloInstruction> Clone(const string& suffix = "clone",
1162  HloModule* module = nullptr) const;
1163 
1164  // Clones the HLO instruction as above but with new shape and operands.
1165  // If the module pointer is not nullptr, it will be the module where
1166  // the cloned computations will be added to (in order to support deep
1167  // cloning).
1168  std::unique_ptr<HloInstruction> CloneWithNewOperands(
1169  const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
1170  HloModule* module = nullptr) const;
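  // Illustrative sketch (assumption): cloning an instruction and adding the
  // clone to the same computation.
  //
  //   std::unique_ptr<HloInstruction> clone = instr->Clone("copy");
  //   HloInstruction* new_instr =
  //       instr->parent()->AddInstruction(std::move(clone));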
1171 
1172  // Returns the computations this instruction directly calls (if any).
1173  const std::vector<HloComputation*>& called_computations() const {
1174  return called_computations_;
1175  }
1176 
1177  // Replaces all called computations based on a map function. This is needed
1178  // when we clone hlo_computations and want the instructions to point
1179  // to the newly cloned nodes.
1180  void ReplaceCalledComputations(
1181  std::function<HloComputation*(HloComputation*)> map_function) {
1182  for (int64 i = 0; i < called_computations_.size(); ++i) {
1183  called_computations_[i] = map_function(called_computations_[i]);
1184  }
1185  }
1186 
1187  // Clears out the called computations.
1188  //
1189  // This is, in particular, necessary when inlining function bodies into their
1190  // caller. If there were side-effecting operations in the called computations,
1191  // the call itself is considered side-effecting and thus cannot be removed. By
1192  // clearing out the computations, we reflect the fact that all side-effecting
1193  // properties have been reflected in the caller, and make the call HLO
1194  // removable.
1195  void ClearCalledComputations() { called_computations_.clear(); }
1196 
1197  // Returns true if this instruction performs an elementwise operation on
1198  // `operand_idx`-th operand. An instruction is elementwise on an operand iff,
1199  // after performing necessary implicit broadcast
1200  // (cs/IrArray::EmitArrayElementAddress), to compute the output at index
1201  // {i_0,i_1,...,i_n}, the only element required from the operand (if any) is
1202  // the element at {i_0,i_1,...,i_n}.
1203  //
1204  // Note on performance: when this instruction is kFusion, this method, in the
1205  // worst case, scans all fused instructions. We could speed this up by
1206  // caching.
1207  bool IsElementwiseOnOperand(int64 operand_idx) const;
1208 
1209  // Returns true if this instruction is elementwise on all its operands.
1210  bool IsElementwise() const;
1211 
1212  // Returns true if this elementwise instruction implicitly broadcasts operand
1213  // `operand_idx`.
1214  //
1215  // Precondition: this instruction should be an elementwise operation.
1216  bool ImplicitlyBroadcastsOperand(int64 operand_idx) const;
1217 
1218  // Returns true if this instruction is binary and elementwise.
1219  bool IsElementwiseBinary() const;
1220 
1221  // Returns whether this instruction may reuse elements of its `i`th operand.
1222  bool ReusesOperandElements(int64 i) const {
1223  return OperandElementUse(i) == UseKind::kReuse;
1224  }
1225 
1226  // Returns the indices at which the given operand appears in the operand list of
1227  // this instruction. Note that an instruction can use the same operand
1228  // multiple times.
1229  std::vector<int64> OperandIndices(const HloInstruction* operand) const;
1230 
1231  // Convenience helper for ShapeUtil::InsertedOrDeleted1SizedDimensions. If
1232  // this reshape merely inserts or deletes 1-sized dimensions, return the input
1233  // indices of the deleted dimensions and the output indices of the inserted
1234  // dimensions.
1235  //
1236  // Precondition: this op must be a reshape.
1237  std::tuple<bool, std::vector<int64>, std::vector<int64>>
1238  ReshapeMerelyInsertsOrDeletes1SizedDimensions() const;
1239 
1240  // Gets/sets the string identifier for this instruction.
1241  const string& name() const { return name_; }
1242  void set_name(tensorflow::StringPiece name) { name_ = name.ToString(); }
1243 
1244  // Use the given NameUniquer to select a unique name for the instruction based
1245  // on the instruction's existing name.
1246  void UniquifyName(NameUniquer* name_uniquer);
1247 
1248  // Sets the unique id for this instruction to "id".
1249  void SetUniqueId(int id) {
1250  CHECK_EQ(unique_id_, -1); // Should not be assigned already
1251  CHECK_GE(id, 0);
1252  unique_id_ = id;
1253  }
1254 
1255  // Return the unique ID assigned to this node via SetUniqueId (or -1
1256  // if no id has been assigned yet).
1257  int unique_id() const { return unique_id_; }
1258 
1259  // Sets the debug metadata for this instruction.
1260  void set_metadata(const OpMetadata& metadata) { metadata_ = metadata; }
1261  const OpMetadata& metadata() const { return metadata_; }
1262 
1263  // Set/get the computation containing this instruction. set_parent should only
1264  // be called by HloComputation methods which add/remove instructions to
1265  // computations.
1266  void set_parent(HloComputation* computation) { parent_ = computation; }
1267  const HloComputation* parent() const { return parent_; }
1268  HloComputation* parent() { return parent_; }
1269 
1270  // Returns the module for this instruction.
1271  HloModule* GetModule() const;
1272 
1273  // Returns whether we could assign input and output layouts to this
1274  // instruction to make it a bitcast.
1275  bool CouldBeBitcast() const;
1276 
1277  // Get/Set the number of partitions per outer dimension (in order, starting
1278  // with outer-most dimension first). Currently used by the parallel cpu
1279  // backend to partition HLOs into parallel tasks.
1280  // TODO(b/62783254) Replace these methods with a more general way to
1281  // annotate HLOs with backend-specific information.
1282  const std::vector<int64>& outer_dimension_partitions() const {
1283  return outer_dimension_partitions_;
1284  }
1285  void set_outer_dimension_partitions(
1286  const std::vector<int64>& outer_dimension_partitions);
1287 
1288  // Changes the layout of a constant HLO instruction to match new_layout. For
1289  // tuple shaped constants shape_index is the path to the internal array
1290  // subshape whose layout needs to be changed.
1291  void RelayoutConstant(const Layout& new_layout,
1292  const ShapeIndex& shape_index = {});
1293 
1294  private:
1295  enum class UseKind { kNoUse, kReuse, kUsePermutingElements, kUse };
1296 
1297  // Helper class for computing OperandElementUse for kFusion.
1298  class FusionReusesParamElements;
1299 
1300  // See comments on Identical().
1301  // eq_shapes() is used to check shapes for equality, and would normally be
1302  // expected to be ShapeUtil::Equal or ShapeUtil::Compatible, depending on
1303  // whether we want a layout-sensitive check or not.
1304  bool IdenticalSlowPath(
1305  const HloInstruction& other,
1306  const std::function<bool(const HloComputation*, const HloComputation*)>&
1307  eq_computations,
1308  const std::function<bool(const Shape&, const Shape&)>& eq_shapes) const;
1309 
1310  // Creates an n-ary elementwise operation.
1311  static std::unique_ptr<HloInstruction> CreateNary(
1312  const Shape& shape, HloOpcode opcode,
1313  tensorflow::gtl::ArraySlice<HloInstruction*> operands);
1314 
1315  // Appends operand to the list of operands and adds this instruction as a user
1316  // of the operand.
1317  void AppendOperand(HloInstruction* operand);
1318 
1319  // Adds a user for this instruction.
1320  void AddUser(HloInstruction* user);
1321 
1322  // Removes a user for this instruction.
1323  void RemoveUser(HloInstruction* user);
1324 
1325  // Internal constructor for a given opcode/shape; other fields must be filled
1326  // in by factory methods.
1327  HloInstruction(HloOpcode opcode, const Shape& shape);
1328 
1329  // Fuses the given instruction into this fusion instruction. When add_output
1330  // is false (which is the default), instruction_to_fuse is cloned and the
1331  // clone is placed in the fusion instruction. instruction_to_fuse is
1332  // unchanged.
1333  //
1334  // When add_output is true, a clone of instruction_to_fuse will be part of
1335  // the output of this fusion instruction. The users of instruction_to_fuse
1336  // will be redirected to this fusion instruction, and instruction_to_fuse
1337  // will be removed from its parent computation.
1338  //
1339  // Precondition: this->opcode() == HloOpcode::kFusion
1340  HloInstruction* FuseInstructionInternal(HloInstruction* instruction_to_fuse,
1341  bool add_output = false);
1342 
1343  // Clones the given instruction_to_fuse and inserts the clone into this fusion
1344  // instruction. If add_output is true, a clone of instruction_to_fuse will
1345  // be in the output of this fusion instruction (part of the tuple of the
1346  // fusion root).
1347  //
1348  // Precondition: opcode() == HloOpcode::kFusion
1349  HloInstruction* CloneAndFuseInternal(HloInstruction* instruction_to_fuse,
1350  bool add_output = false);
1351 
1352  // Clones a fusion instruction with a new shape and operands.
1353  std::unique_ptr<HloInstruction> CloneFusionWithNewOperands(
1354  const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
1355  HloModule* module = nullptr) const;
1356 
1357  // Returns true if this instruction can legally have the dimensions field
1358  // set. Used for checking precondition of dimensions field accessors.
1359  bool CanHaveDimensionsField() const;
1360 
1361  // Returns how this instruction uses elements of its `i`th operand.
1362  UseKind OperandElementUse(int64 i) const;
1363 
1364  int unique_id_; // Unique to this HloInstruction within a HloModule
1365 
1366  // Opcode for this instruction.
1367  HloOpcode opcode_;
1368 
1369  // Instruction operands.
1370  InstructionVector operands_;
1371 
1372  // The set of control predecessors of this instruction.
1373  std::vector<HloInstruction*> control_predecessors_;
1374 
1375  // The users of this instruction. Users are HLOs where this instruction is an
1376  // operand. The vector users_ and the set user_set_ contain identical
1377  // members. The set enables fast membership testing and the vector enables
1378  // fast, stable iteration.
1379  std::vector<HloInstruction*> users_;
1380  std::unordered_set<const HloInstruction*> user_set_;
1381 
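A plausible sketch (not necessarily the exact implementation) of how AddUser, declared above, can keep the vector and the set consistent while preserving the properties described here:

   void AddUser(HloInstruction* user) {
     if (user_set_.insert(user).second) {  // inserted, i.e. not already a user
       users_.push_back(user);             // preserves a stable iteration order
     }
   }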
1382  // The set of control successors of this instruction.
1383  std::vector<HloInstruction*> control_successors_;
1384 
1385  // The computation in which this instruction is contained.
1386  HloComputation* parent_ = nullptr;
1387 
1388  // Shape of outfeed request.
1389  Shape outfeed_shape_;
1390 
1391  // Result shape of this instruction.
1392  Shape shape_;
1393 
1394  // Literal, only present for kConstant.
1395  std::unique_ptr<Literal> literal_;
1396 
1397  // Constant index, only present for kGetTupleElement.
1398  int64 tuple_index_ = -1;
1399 
1400  // Dimensions present for some operations that require reshaping or
1401  // broadcasting, including Reshape, Reduce, ReduceWindow, and Reverse.
1402  std::vector<int64> dimensions_;
1403 
1404  // Describes the window in a windowed operation such as convolution.
1405  std::unique_ptr<Window> window_;
1406 
1407  // Describes the dimension numbers used for a convolution.
1408  std::unique_ptr<ConvolutionDimensionNumbers> convolution_dimension_numbers_;
1409 
1410  // Describes the dimension numbers used for a dot.
1411  std::unique_ptr<DotDimensionNumbers> dot_dimension_numbers_;
1412 
1413  std::unique_ptr<GatherDimensionNumbers> gather_dimension_numbers_;
1414  std::vector<int64> gather_window_bounds_;
1415 
1416  // Describes FFT type for an FFT instruction.
1417  FftType fft_type_ = FftType::FFT;
1418 
1419  // Indicates the FFT length for an FFT instruction.
1420  std::vector<int64> fft_length_;
1421 
1422  // Describes the [begin, end) index range for a slice.
1423  std::vector<int64> slice_starts_;
1424  std::vector<int64> slice_limits_;
1425  std::vector<int64> slice_strides_;
1426 
1427  // Describes whether the slice can be lowered to an offset into the operand.
1428  bool is_in_place_slice_ = false;
1429 
1430  // The bit sizes for a reduce-precision operation.
1431  int32 exponent_bits_ = 0;
1432  int32 mantissa_bits_ = 0;
1433 
1434  // Describes the [start, start + size) range size for a dynamic slice
1435  // ('start' is specified dynamically in the second operand of the operation).
1436  std::vector<int64> dynamic_slice_sizes_;
1437 
1438  // The padding configuration that describes the edge padding and interior
1439  // padding of this pad instruction. Only set for pad instructions.
1440  std::unique_ptr<PaddingConfig> padding_config_;
1441 
1442  // The type of the fusion. Used by kFusion only.
1443  FusionKind fusion_kind_;
1444 
1445  // The sharding, if one exists.
1446  std::unique_ptr<HloSharding> sharding_;
1447 
1448  // For parameter instructions this field holds the parameter number.
1449  int64 parameter_number_ = 0;
1450 
1451  // Name of a global symbol to call, only present for kCustomCall.
1452  string custom_call_target_;
1453 
1454  // Name to use for host send/recv channels, only present for kHostCompute.
1455  string channel_name_;
1456 
1457  // Estimate of the duration of a host computation in nanoseconds.
1458  int64 cost_estimate_ns_;
1459 
1460  // Computations called by this instruction.
1461  std::vector<HloComputation*> called_computations_;
1462 
1463  // Indices of computations in called_computations_ for instructions which call
1464  // multiple computations.
1465  enum {
1466  // kWhile computations.
1467  kBodyComputationIndex = 0,
1468  kConditionComputationIndex = 1,
1469 
1470  // kSelectAndScatter computations.
1471  kSelectComputationIndex = 0,
1472  kScatterComputationIndex = 1,
1473 
1474  // kConditional computations.
1475  kTrueComputationIndex = 0,
1476  kFalseComputationIndex = 1,
1477  };
1478 
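For example, a hypothetical kWhile instruction would locate its computations via these indices (illustrative only):

   HloComputation* body = called_computations_[kBodyComputationIndex];
   HloComputation* condition = called_computations_[kConditionComputationIndex];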
1479  // Outfeed configuration information, only present for kOutfeed.
1480  string outfeed_config_;
1481 
1482  // A trace instruction that consumes this instruction.
1483  //
1484  // Invariant: if trace_instruction_ != nullptr, trace_instruction_ has this as
1485  // an operand.
1486  HloInstruction* trace_instruction_ = nullptr;
1487 
1488  // The distribution requested for random number generation.
1489  // Only present for kRng.
1490  RandomDistribution distribution_;
1491 
1492  // A small float added to the variance to avoid a divide-by-zero error.
1493  // Only present for kBatchNormTraining.
1494  float epsilon_ = 0.0f;
1495 
1496  // An integer value representing the index of the feature dimension.
1497  // Only present for kBatchNormTraining.
1498  int64 feature_index_ = -1;
1499 
1500  // Represents a unique identifier for each Send/Recv instruction pair.
1501  // Only present for kSend or kRecv.
1502  int64 channel_id_ = -1;
1503 
1504  // The string representation of the infeed configuration.
1505  string infeed_config_;
1506 
1507  // String identifier for instruction.
1508  string name_;
1509 
1510  // Metadata for debugging.
1511  OpMetadata metadata_;
1512 
1513  // The number of partitions per outer dimension (listed in order from
1514  // outer-most dimension first).
1515  std::vector<int64> outer_dimension_partitions_;
1516 
1517  TF_DISALLOW_COPY_AND_ASSIGN(HloInstruction);
1518 };
1519 
1520 string ToString(HloInstruction::FusionKind kind);
1521 StatusOr<HloInstruction::FusionKind> StringToFusionKind(
1522  const string& kind_name);
1523 
1524 // Custom (de)stringification functions for protos that live inside
1525 // HloInstruction.
1526 string PaddingConfigToString(const PaddingConfig& padding);
1527 string OpMetadataToString(const OpMetadata& metadata);
1528 string RandomDistributionToString(const RandomDistribution& distribution);
1529 StatusOr<RandomDistribution> StringToRandomDistribution(const string& name);
1530 
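A short, hedged example of the intended round trip for one of these helpers (assuming the ToString form is accepted back by the corresponding parser):

   string text = RandomDistributionToString(RNG_UNIFORM);
   StatusOr<RandomDistribution> parsed = StringToRandomDistribution(text);
   CHECK(parsed.ok());
   CHECK_EQ(parsed.ValueOrDie(), RNG_UNIFORM);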
1531 std::ostream& operator<<(std::ostream& os, HloInstruction::FusionKind kind);
1532 
1533 // Map classes that guarantee a deterministic iteration order when the key is
1534 // an HloInstruction* or a const HloInstruction*.
1535 // To make the iteration order over the map deterministic, the comparator
1536 // should not use the pointer values, but rather an intrinsic property of
1537 // the HLO.
1538 //
1539 // Note that this cannot be used for HLO instructions across multiple modules,
1540 // since the ids of HLO instructions are only unique within each HLO module.
1541 struct HloPtrComparator {
1542  bool operator()(const HloInstruction* const& lhs,
1543  const HloInstruction* const& rhs) const {
1544  return lhs->unique_id() < rhs->unique_id();
1545  }
1546 };
1547 
1548 template <typename ValueT>
1549 using HloInstructionMap = std::map<HloInstruction*, ValueT, HloPtrComparator>;
1550 
1551 template <typename ValueT>
1552 using ConstHloInstructionMap =
1553  std::map<const HloInstruction*, ValueT, HloPtrComparator>;
1554 
1555 using HloInstructionSet = std::set<HloInstruction*, HloPtrComparator>;
1556 using ConstHloInstructionSet =
1557  std::set<const HloInstruction*, HloPtrComparator>;
1558 
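Usage sketch, with hypothetical instructions a and b: iteration visits keys in increasing unique_id() order rather than pointer order, so it is stable across runs.

   HloInstructionMap<int> visit_count;
   visit_count[a] += 1;   // a, b are hypothetical HloInstruction* in one module
   visit_count[b] += 1;
   for (const auto& entry : visit_count) {
     // entry.first is visited in increasing unique_id() order.
   }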
1559 } // namespace xla
1560 
1561 #endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INSTRUCTION_H_