1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16// HLO instructions are in DAG form and represent the computations that the user
17// has built up via the XLA service interface. They are ultimately lowered
18// in a platform-aware way by traversing the HLO DAG and emitting a lowered
19// form; e.g. see DfsHloVisitor.
20
21#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INSTRUCTION_H_
22#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INSTRUCTION_H_
23
24#include <functional>
25#include <iosfwd>
26#include <list>
27#include <memory>
28#include <set>
29#include <string>
30#include <tuple>
31#include <unordered_map>
32#include <unordered_set>
33#include <vector>
34
35#include "tensorflow/compiler/xla/iterator_util.h"
36#include "tensorflow/compiler/xla/literal_util.h"
37#include "tensorflow/compiler/xla/map_util.h"
38#include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h"
39#include "tensorflow/compiler/xla/service/hlo.pb.h"
40#include "tensorflow/compiler/xla/service/hlo_opcode.h"
41#include "tensorflow/compiler/xla/service/hlo_sharding.h"
42#include "tensorflow/compiler/xla/service/name_uniquer.h"
43#include "tensorflow/compiler/xla/types.h"
44#include "tensorflow/compiler/xla/xla_data.pb.h"
45#include "tensorflow/core/lib/core/status.h"
46#include "tensorflow/core/lib/core/stringpiece.h"
47#include "tensorflow/core/lib/gtl/array_slice.h"
48#include "tensorflow/core/lib/gtl/flatmap.h"
49#include "tensorflow/core/lib/gtl/inlined_vector.h"
50#include "tensorflow/core/lib/gtl/iterator_range.h"
51#include "tensorflow/core/platform/logging.h"
52#include "tensorflow/core/platform/macros.h"
53#include "tensorflow/core/platform/types.h"
54
55namespace xla {
56
57class HloComputation;
58class HloModule;
59
60// A bunch of switches that control how the hlo text should be printed.
61class HloPrintOptions {
62 public:
63 // Constructs the default print options: don't print large constants, don't
64 // compact operands, no indentation.
65 HloPrintOptions()
66 : print_large_constants_(false),
67 print_subcomputation_references_(true),
68 print_metadata_(true),
69 compact_operands_(false),
70 print_operand_shape_(true),
71 print_program_shape_(true),
72 print_percent_(true),
73 indent_amount_(0) {}
74
75 static HloPrintOptions ShortParsable() {
76 return HloPrintOptions()
77 .set_print_large_constants(true)
78 .set_print_subcomputation_references(true)
79 .set_print_metadata(false)
80 .set_print_operand_shape(false)
81 .set_print_program_shape(false)
82 .set_print_percent(false);
83 }
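
  // Example (illustrative sketch): building a custom set of print options by
  // chaining the setters below; `instr` stands for any HloInstruction* and is
  // not defined in this file.
  //
  //   HloPrintOptions options = HloPrintOptions()
  //                                 .set_print_metadata(false)
  //                                 .set_print_operand_shape(false)
  //                                 .set_indent_amount(2);
  //   string text = instr->ToString(options);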
84
85 // If true, large constants will be printed out.
86 HloPrintOptions& set_print_large_constants(bool value) {
87 print_large_constants_ = value;
88 return *this;
89 }
90
  // If false, the names of subcomputations (e.g. a fusion node's fused
  // computation) won't be printed, which makes the resulting text not
  // parsable.
  //
  // A CustomCall's call target is printed even if
  // print_subcomputation_references is false, because the call target isn't an
  // HloComputation.
97 HloPrintOptions& set_print_subcomputation_references(bool value) {
98 print_subcomputation_references_ = value;
99 return *this;
100 }
101
  // If true, metadata will be printed.
103 HloPrintOptions& set_print_metadata(bool value) {
104 print_metadata_ = value;
105 return *this;
106 }
107
108 // If true, operands' shapes will be printed.
109 HloPrintOptions& set_print_operand_shape(bool value) {
110 print_operand_shape_ = value;
111 return *this;
112 }
113
114 // If true, program shape of hlo computations will be printed.
115 HloPrintOptions& set_print_program_shape(bool value) {
116 print_program_shape_ = value;
117 return *this;
118 }
119
120 // If true, names will be printed with prefix '%'.
121 HloPrintOptions& set_print_percent(bool value) {
122 print_percent_ = value;
123 return *this;
124 }
125
126 // If true, only a part of operands will be printed out, and their names will
127 // be omitted (note that in this case the text will not be parsable).
128 HloPrintOptions& set_compact_operands(bool value) {
129 compact_operands_ = value;
130 return *this;
131 }
132
133 // The indent of the hlo text block.
134 HloPrintOptions& set_indent_amount(int value) {
135 indent_amount_ = value;
136 return *this;
137 }
138
139 bool print_large_constants() const { return print_large_constants_; }
140 bool print_subcomputation_references() const {
141 return print_subcomputation_references_;
142 }
143 bool print_metadata() const { return print_metadata_; }
144 bool compact_operands() const { return compact_operands_; }
145 bool print_operand_shape() const { return print_operand_shape_; }
146 bool print_program_shape() const { return print_program_shape_; }
147 bool print_percent() const { return print_percent_; }
148 int indent_amount() const { return indent_amount_; }
149
150 private:
151 bool print_large_constants_;
152 bool print_subcomputation_references_;
153 bool print_metadata_;
154 bool compact_operands_;
155 bool print_operand_shape_;
156 bool print_program_shape_;
157 bool print_percent_;
158 int indent_amount_;
159};
160
161// HLO instructions are the IR used by the high-level compiler.
162class HloInstruction {
163 public:
164 enum class FusionKind {
165 kLoop, // Fused into a loop.
166 kInput, // Op's input is fused into the op itself.
167 kOutput, // Op's output is fused into the op itself.
168 // REQUIRES: At least one operand buffer must be able
169 // to alias the output buffer.
170 kTransposeDot, // Fused into a dot with transposed operands.
171 kCustom, // Custom category for backend-specific fusions that
172 // do not match any of the more specific ones.
173 };
174
175 ~HloInstruction();
176
177 // Creates an instruction from the given proto. Arguments:
178 //
179 // module: the module which will contain the instruction. The newly created
180 // instruction is *not* added to the module or any computation, however.
181 // proto: the proto to convert from.
182 // instruction_map: a map from instruction id to HloInstruction*. This map
183 // must contain all operands of the newly constructed instruction.
184 // computation_map: a map from computation id to HloComputation*. This map
185 // must contain all computations which the newly constructed instruction
186 // calls.
187 static StatusOr<std::unique_ptr<HloInstruction>> CreateFromProto(
188 HloModule* module, const HloInstructionProto& proto,
189 const tensorflow::gtl::FlatMap<int64, HloInstruction*>& instruction_map,
190 const tensorflow::gtl::FlatMap<int64, HloComputation*>& computation_map);
191
192 // Creates a parameter-retrieving instruction.
193 static std::unique_ptr<HloInstruction> CreateParameter(int64 parameter_number,
194 const Shape& shape,
195 const string& name);
196
197 // Creates a literal constant instruction.
198 static std::unique_ptr<HloInstruction> CreateConstant(
199 std::unique_ptr<Literal> literal);
200
201 // Creates a get tuple element instruction.
202 static std::unique_ptr<HloInstruction> CreateGetTupleElement(
203 const Shape& shape, HloInstruction* operand, int64 index);
204
205 // Creates a trace instruction that logs the input operand in the computation.
206 static std::unique_ptr<HloInstruction> CreateTrace(const string& tag,
207 HloInstruction* operand);
208
209 // Creates a random number generation instruction that fills a shape with
210 // random numbers from a given distribution.
211 static std::unique_ptr<HloInstruction> CreateRng(
212 const Shape& shape, RandomDistribution distribution,
213 tensorflow::gtl::ArraySlice<HloInstruction*> parameters);
214
215 // Creates a unary instruction (one operand).
216 // Precondition: opcode must be a legitimate unary operation.
217 static std::unique_ptr<HloInstruction> CreateUnary(const Shape& shape,
218 HloOpcode opcode,
219 HloInstruction* operand);
220
221 // Creates a binary instruction (two operands).
222 // Precondition: opcode must be a legitimate binary operation.
223 static std::unique_ptr<HloInstruction> CreateBinary(const Shape& shape,
224 HloOpcode opcode,
225 HloInstruction* lhs,
226 HloInstruction* rhs);
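
  // Example (illustrative sketch): building a scalar f32 add out of the factory
  // methods above. ShapeUtil::MakeShape and HloOpcode::kAdd come from the usual
  // XLA headers; ownership handling is elided.
  //
  //   Shape r0f32 = ShapeUtil::MakeShape(F32, {});
  //   auto x = HloInstruction::CreateParameter(0, r0f32, "x");
  //   auto y = HloInstruction::CreateParameter(1, r0f32, "y");
  //   auto add = HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd,
  //                                           x.get(), y.get());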
227
228 // Creates a ternary instruction (three operands).
229 // Precondition: opcode must be a legitimate ternary operation.
230 static std::unique_ptr<HloInstruction> CreateTernary(const Shape& shape,
231 HloOpcode opcode,
232 HloInstruction* lhs,
233 HloInstruction* rhs,
234 HloInstruction* ehs);
235
236 // Creates a variadic instruction (variable number of operands).
237 // Precondition: opcode must be a legitimate variadic operation.
238 static std::unique_ptr<HloInstruction> CreateVariadic(
239 const Shape& shape, HloOpcode opcode,
240 tensorflow::gtl::ArraySlice<HloInstruction*> operands);
241
242 // Creates a map instruction, where the computation (given by the handle) is
243 // applied element-wise to every element in operands (across the operands,
244 // at a given index) with the same `static_operands`.
245 static std::unique_ptr<HloInstruction> CreateMap(
246 const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
247 HloComputation* map_computation,
248 tensorflow::gtl::ArraySlice<HloInstruction*> static_operands = {});
249
250 // Creates a convolution op, where rhs is the convolutional filter
251 // and window describes how the filter is applied to lhs.
252 static std::unique_ptr<HloInstruction> CreateConvolve(
253 const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
254 const Window& window,
255 const ConvolutionDimensionNumbers& dimension_numbers);
256
257 // Creates an FFT op, of the type indicated by fft_type.
258 static std::unique_ptr<HloInstruction> CreateFft(
259 const Shape& shape, HloInstruction* operand, FftType fft_type,
260 tensorflow::gtl::ArraySlice<int64> fft_length);
261
262 // Creates a dot op with operands 'lhs' and 'rhs' with contracting and batch
263 // dimensions specified in 'dimension_numbers'.
264 static std::unique_ptr<HloInstruction> CreateDot(
265 const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
266 const DotDimensionNumbers& dimension_numbers);
267
268 // Creates a dot op with operands 'lhs' and 'rhs' that contracts dimension 1
269 // of the LHS with dimension 0 of the RHS with no batch dimensions. Both LHS
270 // and the RHS must be of rank 2.
271 static std::unique_ptr<HloInstruction> CreateCanonicalDot(
272 const Shape& shape, HloInstruction* lhs, HloInstruction* rhs);
273
274 // Creates a reduce-precision op, where operand is the data to reduce in
275 // precision, and exponent_bits and mantissa_bits describe the precision to
276 // reduce it to.
277 static std::unique_ptr<HloInstruction> CreateReducePrecision(
278 const Shape& shape, HloInstruction* operand, const int exponent_bits,
279 const int mantissa_bits);
280
281 // Creates a cross replica sum op.
282 static std::unique_ptr<HloInstruction> CreateCrossReplicaSum(
283 const Shape& shape,
284 tensorflow::gtl::ArraySlice<HloInstruction*> operands);
285
286 // Creates a conversion instruction, where operand is the data to convert and
287 // shape is the target shape for the conversion.
288 static std::unique_ptr<HloInstruction> CreateConvert(const Shape& shape,
289 HloInstruction* operand);
290
291 // Creates a bitcast conversion instruction, where operand is the data to
292 // convert and shape is the target shape for the conversion.
293 static std::unique_ptr<HloInstruction> CreateBitcastConvert(
294 const Shape& shape, HloInstruction* operand);
295
296 // Creates an infeed instruction, which reads data of the given shape from the
297 // Infeed interface of the device.
298 static std::unique_ptr<HloInstruction> CreateInfeed(const Shape& shape,
299 const string& config);
300
  // Creates an outfeed instruction, which outputs data to the Outfeed interface
  // of the device.
302 static std::unique_ptr<HloInstruction> CreateOutfeed(
303 const Shape& shape, HloInstruction* operand,
304 tensorflow::StringPiece outfeed_config);
305
306 // Creates an asynchronous send instruction with the given channel id, which
307 // initiates sending the operand data to a unique receive instruction in
308 // another computation that has the same channel id.
309 static std::unique_ptr<HloInstruction> CreateSend(HloInstruction* operand,
310 int64 channel_id);
311
312 // Blocks until data transfer for the Send instruction (operand) is complete.
313 // The operand must be kSend.
314 static std::unique_ptr<HloInstruction> CreateSendDone(
315 HloInstruction* operand);
316
317 // Creates an asynchronous receive instruction with the given channel id,
318 // which allocates resources to receive data of the given shape from a unique
319 // send instruction in another computation that has the same channel id.
320 static std::unique_ptr<HloInstruction> CreateRecv(const Shape& shape,
321 int64 channel_id);
322
323 // Blocks until data transfer for the Recv instruction (operand) is complete
324 // and returns the receive buffer. The operand must be kRecv.
325 static std::unique_ptr<HloInstruction> CreateRecvDone(
326 HloInstruction* operand);
327
328 // Creates a slice instruction, where the operand is sliced by the given
329 // start/limit indices.
330 static std::unique_ptr<HloInstruction> CreateSlice(
331 const Shape& shape, HloInstruction* operand,
332 tensorflow::gtl::ArraySlice<int64> start_indices,
333 tensorflow::gtl::ArraySlice<int64> limit_indices,
334 tensorflow::gtl::ArraySlice<int64> strides);
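
  // Example (illustrative sketch): take every other element of a rank-1 f32[8]
  // operand, producing f32[4] (elements 0, 2, 4, 6). `operand` stands for an
  // existing instruction of shape f32[8] and is not defined here.
  //
  //   auto sliced = HloInstruction::CreateSlice(
  //       ShapeUtil::MakeShape(F32, {4}), operand,
  //       /*start_indices=*/{0}, /*limit_indices=*/{8}, /*strides=*/{2});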
335
336 // Creates a slice instruction, where the first operand is sliced by
337 // start indices specified in the second operand, and by size specified in
338 // 'slice_sizes'.
339 static std::unique_ptr<HloInstruction> CreateDynamicSlice(
340 const Shape& shape, HloInstruction* operand,
341 HloInstruction* start_indices,
342 tensorflow::gtl::ArraySlice<int64> slice_sizes);
343
344 // Creates a dynamic update slice instruction, which updates a slice
345 // of 'operand' with 'update' and 'start_indices'.
346 static std::unique_ptr<HloInstruction> CreateDynamicUpdateSlice(
347 const Shape& shape, HloInstruction* operand, HloInstruction* update,
348 HloInstruction* start_indices);
349
350 // Creates a concatenate instruction, where the operands are concatenated on
351 // the provided dimension.
352 static std::unique_ptr<HloInstruction> CreateConcatenate(
353 const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
354 int64 dimension);
355
  // Creates a reduce instruction, where the computation (given by the handle)
  // is applied successively to every element in operand. That is, if f is the
  // function to apply (which takes either 2 [accumulator, value] or 3
  // [accumulator, index, value] arguments) and init is the initial value of
  // the reduction (for example, 0 for addition), then this operation will
  // compute:
  //   f(... f(f(init, value0), value1) ..., valueN)
363 static std::unique_ptr<HloInstruction> CreateReduce(
364 const Shape& shape, HloInstruction* operand, HloInstruction* init_value,
365 tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce,
366 HloComputation* reduce_computation);
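
  // Example (illustrative sketch): summing a f32[4] operand down to a scalar.
  // `operand` is an existing f32[4] instruction, `zero` is a constant 0.0f
  // instruction, and `add_computation` is a hypothetical HloComputation
  // computing f32[] add(f32[], f32[]); none of these are defined in this file.
  //
  //   auto sum = HloInstruction::CreateReduce(
  //       ShapeUtil::MakeShape(F32, {}), operand, zero,
  //       /*dimensions_to_reduce=*/{0}, add_computation);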
367
368 // Creates a reduce-window instruction, where the computation (given
369 // by the handle) is applied window-wise at each valid window
370 // position in the operand.
371 static std::unique_ptr<HloInstruction> CreateReduceWindow(
372 const Shape& shape, HloInstruction* operand, HloInstruction* init_value,
373 const Window& window, HloComputation* reduce_computation);
374
375 // Creates a batch-norm-training instruction.
376 static std::unique_ptr<HloInstruction> CreateBatchNormTraining(
377 const Shape& shape, HloInstruction* operand, HloInstruction* scale,
378 HloInstruction* offset, float epsilon, int64 feature_index);
379
380 // Creates a batch-norm-inference instruction.
381 static std::unique_ptr<HloInstruction> CreateBatchNormInference(
382 const Shape& shape, HloInstruction* operand, HloInstruction* scale,
383 HloInstruction* offset, HloInstruction* mean, HloInstruction* variance,
384 float epsilon, int64 feature_index);
385
386 // Creates a batch-norm-grad instruction.
387 static std::unique_ptr<HloInstruction> CreateBatchNormGrad(
388 const Shape& shape, HloInstruction* operand, HloInstruction* scale,
389 HloInstruction* mean, HloInstruction* variance,
390 HloInstruction* grad_output, float epsilon, int64 feature_index);
391
  // Creates a select-and-scatter instruction, which scatters the `source` array
  // to the selected indices of each window.
394 static std::unique_ptr<HloInstruction> CreateSelectAndScatter(
395 const Shape& shape, HloInstruction* operand, HloComputation* select,
396 const Window& window, HloInstruction* source, HloInstruction* init_value,
397 HloComputation* scatter);
398
399 // Creates a broadcast instruction.
400 static std::unique_ptr<HloInstruction> CreateBroadcast(
401 const Shape& shape, HloInstruction* operand,
402 tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
403
  // Creates a broadcast-size-one-dimensions instruction, which broadcasts the
  // operand's size-one dimensions to the corresponding dimension sizes of the
  // given shape.
405 static std::unique_ptr<HloInstruction> CreateBroadcastDimOne(
406 const Shape& shape, HloInstruction* operand);
407
408 // Creates a sequence of instructions that performs an explicit broadcast of
409 // the operand to the target shape.
410 //
411 // Interior HLOs are passed to "adder", but the "root" HLO of the sequence is
412 // returned as a unique_ptr for API consistency with other factory methods in
413 // this interface.
414 //
415 // TODO(b/72173833) Ideally HloComputations would always be present, and so
416 // the adder being passed by the caller would not be necessary.
417 static std::unique_ptr<HloInstruction> CreateBroadcastSequence(
418 const Shape& output_shape, HloInstruction* operand,
419 const std::function<HloInstruction*(std::unique_ptr<HloInstruction>)>&
420 adder);
421
422 // Creates a pad instruction, where the operand is padded on the edges and
423 // between the elements with the given padding value.
424 static std::unique_ptr<HloInstruction> CreatePad(
425 const Shape& shape, HloInstruction* operand,
426 HloInstruction* padding_value, const PaddingConfig& padding_config);
427
428 // Creates a reshape instruction, where the operand is flattened row-major
429 // order and then reshaped to the given result shape.
430 static std::unique_ptr<HloInstruction> CreateReshape(const Shape& shape,
431 HloInstruction* operand);
432
433 // Creates a transpose instruction which permutes the operand dimensions.
434 static std::unique_ptr<HloInstruction> CreateTranspose(
435 const Shape& shape, HloInstruction* operand,
436 tensorflow::gtl::ArraySlice<int64> dimensions);
437
  // Creates a while instruction, given a condition computation, a body
  // computation, and the initial value for the input of the computations. For
  // example, shape: S32, condition: i -> i < 1000, body: i -> i * 2, init: 1
  // corresponds to the C-like pseudocode below, where the result is the final
  // value of i.
  //   int32 i = 1;
  //   while (i < 1000) { i = i * 2; }
443 static std::unique_ptr<HloInstruction> CreateWhile(const Shape& shape,
444 HloComputation* condition,
445 HloComputation* body,
446 HloInstruction* init);
447
448 static std::unique_ptr<HloInstruction> CreateConditional(
449 const Shape& shape, HloInstruction* pred,
450 HloInstruction* true_computation_arg, HloComputation* true_computation,
451 HloInstruction* false_computation_arg, HloComputation* false_computation);
452
453 static std::unique_ptr<HloInstruction> CreateGather(
454 const Shape& shape, HloInstruction* operand,
455 HloInstruction* gather_indices,
456 const GatherDimensionNumbers& gather_dim_numbers,
457 tensorflow::gtl::ArraySlice<int64> window_bounds);
458
459 // Creates a fusion instruction. A fusion instruction contains one or more
460 // fused instructions forming an expression with a single root
461 // "fused_root". Additional instructions can be added to the fusion
462 // instruction with the method FuseInstruction.
463 static std::unique_ptr<HloInstruction> CreateFusion(
464 const Shape& shape, FusionKind fusion_kind, HloInstruction* fused_root);
465
466 static std::unique_ptr<HloInstruction> CreateFusion(
467 const Shape& shape, FusionKind fusion_kind,
468 tensorflow::gtl::ArraySlice<HloInstruction*> operands,
469 HloComputation* fusion_computation);
470
471 // Creates a call instruction that applies the given computation on the given
472 // operands. "shape" is the resultant shape.
473 static std::unique_ptr<HloInstruction> CreateCall(
474 const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
475 HloComputation* computation);
476
477 // Creates a custom call instruction that applies the given custom call target
478 // to the given operands. "shape" is the resultant shape.
479 static std::unique_ptr<HloInstruction> CreateCustomCall(
480 const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
481 tensorflow::StringPiece custom_call_target);
482
483 // Creates a HostCompute instruction, which records host-side control and
484 // data dependencies for use in instruction scheduling.
485 static std::unique_ptr<HloInstruction> CreateHostCompute(
486 const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
487 tensorflow::StringPiece channel_name, const int64 cost_estimate_ns);
488
489 // Creates a tuple instruction with the given elements. This is a convenience
490 // wrapper around CreateVariadic.
491 static std::unique_ptr<HloInstruction> CreateTuple(
492 tensorflow::gtl::ArraySlice<HloInstruction*> elements);
493
494 // Creates a reverse instruction, which reverses the order of the elements
495 // in the specified dimensions.
496 static std::unique_ptr<HloInstruction> CreateReverse(
497 const Shape& shape, HloInstruction* operand,
498 tensorflow::gtl::ArraySlice<int64> dimensions);
499
500 // Creates an instance of GatherDimensionNumbers.
501 static GatherDimensionNumbers MakeGatherDimNumbers(
502 tensorflow::gtl::ArraySlice<int64> output_window_dims,
503 tensorflow::gtl::ArraySlice<int64> elided_window_dims,
504 tensorflow::gtl::ArraySlice<int64> gather_dims_to_operand_dims,
505 int64 index_vector_dim);
506
507 // Returns the opcode for this instruction.
508 HloOpcode opcode() const { return opcode_; }
509
510 // Returns true if this instruction has a side effect. An instruction has a
511 // side effect if it uses certain opcodes or calls a computation with a side
512 // effect.
513 bool HasSideEffect() const;
514
515 // Returns the result shape of this instruction.
516 const Shape& shape() const;
517
518 // Returns the (mutable) result shape of this instruction.
519 Shape* mutable_shape() { return &shape_; }
520
521 // Returns the ith operand to this instruction.
522 const HloInstruction* operand(int64 i) const;
523
524 // Returns the ith operand to this instruction.
525 HloInstruction* mutable_operand(int64 i);
526
527 // Returns the number of operands to this instruction.
528 int64 operand_count() const { return operands_.size(); }
529
530 // Returns the vector of operands of this instruction.
531 using InstructionVector = tensorflow::gtl::InlinedVector<HloInstruction*, 2>;
532 const InstructionVector& operands() const { return operands_; }
533
534 // Returns the index of 'target' in the operands sequence.
535 // Precondition: target must be an operand (or a fatal error will occur).
536 int64 operand_index(const HloInstruction* target) const;
537
538 // Returns the number of users of this instruction.
539 int64 user_count() const { return users_.size(); }
540
541 // Returns the users of this instruction.
542 const std::vector<HloInstruction*>& users() const { return users_; }
543
544 // Returns true if this instruction is a user of 'instruction'.
545 bool IsUserOf(const HloInstruction* instruction) const {
546 return ContainsKey(instruction->user_set_, this);
547 }
548
549 // Adds a control dependency from this instruction to the given
550 // instruction. This instruction becomes a control predecessor of
551 // 'instruction', and 'instruction' becomes a control successor of this
552 // instruction. Returns an error status if either of the given instructions
553 // does not belong to the same computation.
554 //
555 // This is used to enforce an additional ordering requirement that is not
556 // captured by normal data dependencies, such as ordering among Send or Recv
557 // operations to avoid deadlock.
558 Status AddControlDependencyTo(HloInstruction* instruction);
559
560 // Removes a previously added control dependency from this instruction to
561 // 'instruction'.
562 Status RemoveControlDependencyTo(HloInstruction* instruction);
563
564 // Returns the set of control predecessors (successors) of this
565 // instruction. Control predecessors (successors) must execute before (after)
566 // the current instruction.
567 const std::vector<HloInstruction*>& control_predecessors() const {
568 return control_predecessors_;
569 }
570 const std::vector<HloInstruction*>& control_successors() const {
571 return control_successors_;
572 }
573
574 // Returns true if "other" performs the same computation as this instruction.
575 bool Identical(
576 const HloInstruction& other,
577 const std::function<bool(const HloInstruction*, const HloInstruction*)>&
578 eq_operands = std::equal_to<const HloInstruction*>(),
579 const std::function<bool(const HloComputation*, const HloComputation*)>&
580 eq_computations = std::equal_to<const HloComputation*>(),
581 bool layout_sensitive = true) const {
582 // An instruction is always identical to itself.
583 if (this == &other) {
584 return true;
585 }
586
    // Identical instructions must have the same opcode, shape, and identical
    // operands.
589 if (opcode() != other.opcode()) {
590 return false;
591 }
592 using EqShapeFuncType = bool (*)(const Shape&, const Shape&);
593 EqShapeFuncType eq_shapes =
594 layout_sensitive ? ShapeUtil::Equal : ShapeUtil::Compatible;
595 if (!eq_shapes(shape(), other.shape())) {
596 return false;
597 }
598 if (operands().size() != other.operands().size()) {
599 return false;
600 }
601
602 // Use an explicit loop rather than ContainerEquals, because copying around
603 // std::functions may be too expensive in some cases.
604 for (size_t i = 0; i < operands().size(); ++i) {
605 if (!eq_operands(operand(i), other.operand(i))) {
606 return false;
607 }
608 }
609
610 return IdenticalSlowPath(other, eq_computations, eq_shapes);
611 }
612
613 // Returns whether the instruction has a constant operand.
614 bool HasConstantOperand() const;
615
616 // Returns whether this instruction does a rank-2 transposition.
617 bool IsRank2Transpose() const;
618
619 // Replaces the use of this instruction in "user" with "new_producer". Note
620 // that there might be multiple uses of this instruction in "user"; all will
621 // be replaced.
622 Status ReplaceUseWith(HloInstruction* user, HloInstruction* new_producer);
623
624 // Replaces the specified operand with new_operand.
625 Status ReplaceOperandWith(int64 operand_no, HloInstruction* new_operand);
626
627 // Replaces all uses of this instruction with the new producer. If
628 // new_producer is a user of this instruction then new_producer remains a use
629 // of this instruction to avoid introducing cycles into the graph.
630 //
631 // If this instruction is the root of its computation, sets the computation's
632 // root to new_producer.
633 Status ReplaceAllUsesWith(HloInstruction* new_producer);
634
635 // Detaches an instruction from its operands. That is, remove the instruction
636 // from each operand's user set. This should only be called prior to
637 // deallocating the instruction.
638 void DetachFromOperands();
639
640 // Performs a postorder DFS visit using this node as the root. If
641 // call_finish_visit is true, then DfsHloVisitor::FinishVisit is called when
642 // complete. If ignore_control_predecessors is true, instructions only
643 // reachable via control dependencies will not be visited, and the postorder
644 // will not take control dependencies into account. It is as if the control
645 // dependencies didn't exist in the graph at all.
646 template <typename HloInstructionPtr>
647 Status Accept(DfsHloVisitorBase<HloInstructionPtr>* visitor,
648 bool call_finish_visit = true,
649 bool ignore_control_predecessors = false);
650 Status Accept(ConstDfsHloVisitor* visitor, bool call_finish_visit = true,
651 bool ignore_control_predecessors = false) const {
652 return const_cast<HloInstruction*>(this)->Accept(
653 visitor, call_finish_visit, ignore_control_predecessors);
654 }
655
656 // Same as Accept() above, but the order of operand and control predecessor
657 // visitation is determined by the given operand order; if compare(A, B) ==
658 // true, A is visited before B.
659 using CompareFunction =
660 std::function<bool(const HloInstruction*, const HloInstruction*)>;
661 Status AcceptWithOperandOrder(DfsHloVisitor* visitor,
662 const CompareFunction& operand_order,
663 bool call_finish_visit = true);
664
665 // Performs a postorder DFS visit using this node as the root. Calls the given
666 // visitor function at each instruction.
667 Status Accept(const std::function<Status(HloInstruction*)>& visitor_func);
668 Status Accept(
669 const std::function<Status(const HloInstruction*)>& visitor_func) const;
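
  // Example (illustrative sketch): counting the instructions reachable from a
  // root instruction with the function-based Accept overload above; `root` and
  // `count` are stand-ins not defined in this file.
  //
  //   int64 count = 0;
  //   TF_CHECK_OK(root->Accept([&count](HloInstruction*) {
  //     ++count;
  //     return Status::OK();
  //   }));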
670
671 // Visits all instructions rooted at this instruction using the given visitor
672 // in the given order. 'order' must contain at least the set of instructions
673 // rooted at this node (ie, those accessible from a DFS traversal from this
674 // instruction). Instructions contained in 'order' which are not in the set of
675 // instructions rooted at this node are ignored. 'order' must also be a valid
676 // topological sort of these instructions (defs appear before uses) though
677 // need not be a DFS post-order.
678 Status AcceptOrdered(DfsHloVisitor* visitor,
679 const std::vector<const HloInstruction*>& order);
680
681 // Visit this instruction and only this instruction with the given visitor.
682 template <typename HloInstructionPtr>
683 Status Visit(DfsHloVisitorBase<HloInstructionPtr>* visitor);
684
685 // Returns the literal associated with this instruction.
686 //
687 // Note: only constant and parameter opcodes have an associated literal.
688 const Literal& literal() const;
689
690 // Returns the parameter number associated with this instruction.
691 //
692 // Note: only parameter opcodes have an associated parameter number.
693 int64 parameter_number() const {
694 CHECK_EQ(HloOpcode::kParameter, opcode_);
695 return parameter_number_;
696 }
697
698 // Returns the dimension sizes or numbers associated with this instruction.
699 //
700 // Precondition: opcode() is one of: concatenate, reduce, broadcast, reshape,
701 // and reverse.
702 const std::vector<int64>& dimensions() const;
703 int64 dimensions(int64 index) const;
704
705 // Accessor for the dimension in which a concatenate HLO should occur.
706 // Precondition: opcode() == HloOpcode::kConcatenate
707 int64 concatenate_dimension() const;
708
709 // Returns the tuple index associated with this instruction.
710 //
711 // Precondition: opcode() == HloOpcode::kGetTupleElement
712 int64 tuple_index() const;
713
  // Returns the first non-GetTupleElement ancestor instruction of this
  // instruction. If that ancestor is tuple-shaped, the returned ShapeIndex
  // holds the (possibly nested) tuple indices used on the path from the
  // ancestor down to this instruction.
717 std::pair<const HloInstruction*, ShapeIndex> LatestNonGteAncestorAndIndex()
718 const;
719
720 std::pair<HloInstruction*, ShapeIndex> LatestNonGteAncestorAndIndex() {
721 auto rv =
722 const_cast<const HloInstruction*>(this)->LatestNonGteAncestorAndIndex();
723 return {const_cast<HloInstruction*>(rv.first), rv.second};
724 }
725
726 // Same as LatestNonGteAncestorAndIndex, but just returns the HloInstruction.
727 const HloInstruction* LatestNonGteAncestor() const;
728
729 HloInstruction* LatestNonGteAncestor() {
730 return const_cast<HloInstruction*>(
731 const_cast<const HloInstruction*>(this)->LatestNonGteAncestor());
732 }
733
734 // Gets/sets the to_apply HloComputation for Call, Map, Reduce, etc.
735 // The setter should only be called by HloModule or HloComputation methods.
736 //
737 // Precondition: The instruction has a valid to_apply_ field.
738 HloComputation* to_apply() const;
739 void set_to_apply(HloComputation* to_apply);
740
741 // Returns the custom_call_target for CustomCall.
742 // Precondition: opcode() == HloOpcode::kCustomCall
743 const string& custom_call_target() const;
744
745 // Returns the config for the Outfeed instruction.
746 // Precondition: opcode() == HloOpcode::kOutfeed
747 const string& outfeed_config() const;
748
749 // Returns the shape for the Outfeed instruction.
750 // Precondition: opcode() == HloOpcode::kOutfeed
751 const Shape& outfeed_shape() const;
752
753 // Gets/sets the while_condition or while_body HloComputation for While. The
754 // setters should only be called by HloModule or HloComputation methods.
755 //
756 // Precondition: The instruction is a While instruction.
757 HloComputation* while_condition() const;
758 HloComputation* while_body() const;
759 void set_while_condition(HloComputation* while_condition);
760 void set_while_body(HloComputation* while_body);
761
762 // Gets/sets the select or scatter HloComputation for SelectAndScatter. The
763 // setters should only be called by HloModule or HloComputation methods.
764 //
765 // Precondition: opcode() == HloOpcode::kSelectAndScatter.
766 HloComputation* select() const;
767 HloComputation* scatter() const;
768 void set_select(HloComputation* select);
769 void set_scatter(HloComputation* scatter);
770
771 // Gets/sets the true and false HloComputation for Conditional. The setters
772 // should only be called by HloModule or HloComputation methods.
773 //
774 // Precondition: The instruction is a Conditional instruction.
775 HloComputation* true_computation() const;
776 HloComputation* false_computation() const;
777 void set_true_computation(HloComputation* true_computation);
778 void set_false_computation(HloComputation* false_computation);
779
780 // Returns a string for the signature of this instruction if considered as a
781 // function, e.g. the signature of an F32 add is (F32, F32) -> F32.
782 string SignatureString() const;
783
784 // Returns a debugging string that represents this instruction.
785 //
786 // (We express the default options using an overload rather than a default
787 // param because gdb ignores default params, but does resolve overloads.)
788 //
789 // TODO(b/73348663): Make ToString() adaptive to the size of the string by
790 // default, backing off on providing full information for very large strings,
791 // or provide a different name for a ToString-like function that does that.
792 string ToString() const { return ToString(HloPrintOptions()); }
793 string ToString(const HloPrintOptions& options) const;
794
795 // Components of the ToString() representation:
796
797 // Returns a string representation of the operand list.
798 string OperandsToString(const HloPrintOptions& options) const;
799
800 // Returns string representation of op-specific attributes.
801 std::vector<string> ExtraAttributesToString(
802 const HloPrintOptions& options) const;
803
804 // As ToString, but returns a shorter string.
805 string ToShortString() const;
806
807 // Returns a serialized representation of this instruction.
808 HloInstructionProto ToProto() const;
809
810 // Returns a category for the HLO. This could be something like "convolution"
811 // or "elementwise".
812 string ToCategory() const;
813
814 // Returns a logging instruction, if the output of this instruction is logged.
815 //
816 // Postcondition: retval == nullptr || retval->opcode() == HloOpcode::kTrace
817 HloInstruction* tracing() const;
818 void set_tracing(HloInstruction* trace_instruction);
819
820 // Returns the channel id associated with the instruction. The id is
821 // shared between each Send/Recv pair and is globally unique to identify each
822 // channel.
823 //
824 // Precondition: opcode() == HloOpcode::kSend or HloOpcode::kRecv
825 int64 channel_id() const { return channel_id_; }
826
827 // Returns the channel name associated with the instruction. The name is
828 // used to identify host Send/Recv operations.
829 //
830 // Precondition: opcode() == HloOpcode::kHostCompute
831 string channel_name() const { return channel_name_; }
832
  // Returns the feature_index field associated with the instruction. The index
  // represents the index of the feature dimension.
835 //
836 // Precondition: opcode() is one of kBatchNormTraining, kBatchNormInference,
837 // or kBatchNormGrad.
838 int64 feature_index() const { return feature_index_; }
839
  // Returns the epsilon value associated with the instruction. This is a small
  // number added to the variance to avoid divide-by-zero errors.
842 //
843 // Precondition: opcode() is one of kBatchNormTraining, kBatchNormInference,
844 // or kBatchNormGrad.
845 float epsilon() const { return epsilon_; }
846
847 // Returns the infeed configuration string. The infeed configuration includes
848 // any metadata needed for the backend compiler (e.g., infeed buffer address)
849 // and is target-dependent.
850 string infeed_config() const { return infeed_config_; }
851 void set_infeed_config(const string& config) { infeed_config_ = config; }
852
853 // Returns a tag to be used in tracing.
854 //
855 // Precondition: opcode() == HloOpcode::kTrace
856 string TracingTag() const;
857
858 // Returns whether the instruction is a constant.
859 bool IsConstant() const;
860
  // Returns true if this instruction is fused, i.e., contained within a fusion
  // instruction.
863 bool IsFused() const;
864
865 // Returns the computation for this fused instruction.
866 //
867 // Precondition: opcode() == HloOpcode::kFusion
868 HloComputation* fused_instructions_computation() const;
869
870 // Returns true if this instruction can be legally fused into a fusion
871 // instruction.
872 bool IsFusable() const;
873
874 // Returns the root instruction of the fused expression contained within this
875 // fusion instruction.
876 //
877 // Precondition: opcode() == HloOpcode::kFusion
878 HloInstruction* fused_expression_root() const;
879
880 // Returns the list of fused instructions inside this fusion instruction. The
881 // returned type is a range of HloInstruction*s.
882 //
883 // Precondition: opcode() == HloOpcode::kFusion
884 const tensorflow::gtl::iterator_range<UnwrappingIterator<
885 std::list<std::unique_ptr<HloInstruction>>::const_iterator>>
886 fused_instructions() const;
887
888 const tensorflow::gtl::iterator_range<
889 UnwrappingIterator<std::list<std::unique_ptr<HloInstruction>>::iterator>>
890 fused_instructions();
891
892 // Gets the number of instructions inside this fusion instruction.
893 //
894 // Precondition: opcode() == HloOpcode::kFusion
895 int64 fused_instruction_count() const;
896
897 // Returns the fused parameter instruction in this fusion instruction
898 // corresponding to the given parameter number.
899 //
900 // Precondition: opcode() == HloOpcode::kFusion
901 HloInstruction* fused_parameter(int64 parameter_number) const;
902
903 // Returns the vector of fused parameters inside this fusion instruction.
904 //
905 // Precondition: opcode() == HloOpcode::kFusion
906 const std::vector<HloInstruction*>& fused_parameters() const;
907
908 // Returns true if this instruction is a fusion instruction that generates
909 // multiple outputs.
  bool IsMultiOutputFusion() const {
911 return opcode() == HloOpcode::kFusion &&
912 fused_expression_root()->opcode() == HloOpcode::kTuple;
913 }
914
915 FusionKind fusion_kind() const {
916 CHECK_EQ(HloOpcode::kFusion, opcode_);
917 return fusion_kind_;
918 }
919
920 void set_fusion_kind(FusionKind kind) {
921 CHECK_EQ(HloOpcode::kFusion, opcode_);
922 fusion_kind_ = kind;
923 }
924
925 // Returns the sharding applied to this operator.
926 // REQUIRES: has_sharding() is true.
927 const HloSharding& sharding() const {
928 CHECK(has_sharding());
929 return *sharding_;
930 }
931 // Returns the sharding applied to this operator, or default_ if none exists.
932 const HloSharding& sharding_or_default(const HloSharding& default_) const {
933 return sharding_ ? *sharding_ : default_;
934 }
935 // Returns the sharding unique device, if any.
936 tensorflow::gtl::optional<int64> sharding_unique_device() const {
937 if (sharding_ == nullptr || !sharding_->HasUniqueDevice()) {
938 return tensorflow::gtl::optional<int64>();
939 }
940 return sharding_->UniqueDevice().ValueOrDie();
941 }
942 // Sets the sharding of this operator. Should only be called by HloModule or
943 // HloComputation methods.
944 void set_sharding(const HloSharding& sharding) {
945 sharding_ = MakeUnique<HloSharding>(sharding);
946 }
947 // Remove any sharding from this operator.
948 void clear_sharding() { sharding_ = nullptr; }
949 // Return true if this operator has a sharding assigned.
950 bool has_sharding() const { return sharding_ != nullptr; }
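
  // Example (illustrative sketch): pinning an instruction to device 0, assuming
  // the HloSharding::AssignDevice factory from hlo_sharding.h; `instr` is a
  // stand-in for any HloInstruction*.
  //
  //   instr->set_sharding(HloSharding::AssignDevice(/*device_id=*/0));
  //   CHECK(instr->has_sharding());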
951
  // Adds a new operand to the fusion instruction.
953 HloInstruction* AddFusionOperand(HloInstruction* new_operand);
954
955 // Merges the fused instructions from 'instruction_to_merge' into the
956 // fused instruction set of 'this', updating operands as necessary.
957 //
958 // Precondition: opcode() == HloOpcode::kFusion
  // Precondition: 'instruction_to_merge' must be an operand of 'this'.
960 void MergeFusionInstruction(HloInstruction* instruction_to_merge);
961
  // Merges the fused instructions from instruction_to_merge into the fused
  // instruction set of 'this' and generates a multioutput fusion instruction.
  // All the users of instruction_to_merge will be redirected to 'this'
  // instruction. instruction_to_merge will be removed from its parent
  // computation.
967 //
968 // Precondition: opcode() == HloOpcode::kFusion
969 void MergeFusionInstructionIntoMultiOutput(
970 HloInstruction* instruction_to_merge);
971
  // Fuses the given instruction into this fusion instruction.
  // instruction_to_fuse is cloned and the clone is placed in the fusion
  // instruction; instruction_to_fuse itself is unchanged. The instruction is
  // cloned rather than moved to cleanly handle the case where it has a use
  // outside the fusion instruction. Moving such an instruction into a fusion
  // instruction would violate the single-result invariant of HLO instructions
  // and significantly complicate code generation.
979 //
980 // Precondition: this->opcode() == HloOpcode::kFusion
981 HloInstruction* FuseInstruction(HloInstruction* instruction_to_fuse) {
982 return FuseInstructionInternal(instruction_to_fuse);
983 }
984
  // Fuses the given instruction into this fusion instruction and generates a
  // multioutput fusion instruction. A clone of instruction_to_fuse will be
  // part of the output of the fusion instruction. The users of
  // instruction_to_fuse will be redirected to this fusion instruction.
  // instruction_to_fuse will be removed from its parent computation.
990 //
991 // Precondition: this->opcode() == HloOpcode::kFusion
992 HloInstruction* FuseInstructionIntoMultiOutput(
993 HloInstruction* instruction_to_fuse) {
994 return FuseInstructionInternal(instruction_to_fuse, /* add_output */ true);
995 }
996
997 // Returns the start index in the given dimension for a slice node.
998 //
999 // Precondition: opcode() == HloOpcode::kSlice
1000 int64 slice_starts(int64 dimension) const {
1001 CHECK_EQ(HloOpcode::kSlice, opcode_);
1002 return slice_starts_[dimension];
1003 }
1004 const std::vector<int64>& slice_starts() const { return slice_starts_; }
1005
1006 // Returns the (exclusive) limit index in the given dimension for a slice
1007 // node.
1008 //
1009 // Precondition: opcode() == HloOpcode::kSlice
1010 int64 slice_limits(int64 dimension) const {
1011 CHECK_EQ(HloOpcode::kSlice, opcode_);
1012 return slice_limits_[dimension];
1013 }
1014 const std::vector<int64>& slice_limits() const {
1015 CHECK_EQ(HloOpcode::kSlice, opcode_);
1016 return slice_limits_;
1017 }
1018
1019 // Returns the stride in the given dimension for a slice node.
1020 //
1021 // Precondition: opcode() == HloOpcode::kSlice
1022 int64 slice_strides(int64 dimension) const {
1023 CHECK_EQ(HloOpcode::kSlice, opcode_);
1024 return slice_strides_[dimension];
1025 }
1026 const std::vector<int64>& slice_strides() const { return slice_strides_; }
1027
1028 // Returns the flag that describes whether a slice must be lowered into an
1029 // offset into the original operand.
1030 bool IsInPlaceSlice() const { return is_in_place_slice_; }
1031
1032 // Sets and returns the flag that describes whether a slice must be lowered
1033 // into an offset into the original operand.
1034 bool SetIsInPlaceSlice(bool value) {
1035 is_in_place_slice_ = value;
1036 return value;
1037 }
1038
1039 // Returns the size of the slice in the given dimension for a dynamic
1040 // slice node.
1041 //
1042 // Precondition: opcode() == HloOpcode::kDynamicSlice
1043 int64 slice_sizes(int64 dimension) const {
1044 CHECK_EQ(HloOpcode::kDynamicSlice, opcode_);
1045 return dynamic_slice_sizes_[dimension];
1046 }
1047 const std::vector<int64>& dynamic_slice_sizes() const {
1048 CHECK_EQ(HloOpcode::kDynamicSlice, opcode_);
1049 return dynamic_slice_sizes_;
1050 }
1051
1052 // Returns the number of exponent bits for a reduce-precision node.
1053 //
1054 // Precondition: opcode() == HloOpcode::kReducePrecision
1055 int32 exponent_bits() const {
1056 CHECK_EQ(HloOpcode::kReducePrecision, opcode_);
1057 return exponent_bits_;
1058 }
1059
1060 // Returns the number of mantissa bits for a reduce-precision node.
1061 //
1062 // Precondition: opcode() == HloOpcode::kReducePrecision
1063 int32 mantissa_bits() const {
1064 CHECK_EQ(HloOpcode::kReducePrecision, opcode_);
1065 return mantissa_bits_;
1066 }
1067
1068 // Returns data on the window in a windowed operation such as
1069 // convolution.
1070 const Window& window() const {
1071 CHECK(window_ != nullptr);
1072 return *window_;
1073 }
1074
1075 // Sets the window data in a windowed operation such as convolution.
1076 void set_window(const Window& window) {
1077 window_ = MakeUnique<Window>(window);
1078 }
1079
1080 // Returns the padding configuration for a pad node.
1081 //
1082 // Precondition: opcode() == HloOpcode::kPad
1083 const PaddingConfig& padding_config() const {
1084 CHECK(padding_config_ != nullptr);
1085 return *padding_config_;
1086 }
1087
1088 // Returns data on the dimension numbers used for a convolution operation,
1089 // which may be a kConvolution instruction or a kCustomCall that implements a
1090 // convolution.
1091 const ConvolutionDimensionNumbers& convolution_dimension_numbers() const {
1092 CHECK(convolution_dimension_numbers_ != nullptr);
1093 return *convolution_dimension_numbers_;
1094 }
1095
1096 // Sets the convolution dimension numbers on this instruction. In general you
1097 // shouldn't need to call this; instead, specify the convolution dimension
1098 // numbers when you create the instruction.
1099 void set_convolution_dimension_numbers(
1100 const ConvolutionDimensionNumbers& dnums) {
1101 convolution_dimension_numbers_ =
1102 MakeUnique<ConvolutionDimensionNumbers>(dnums);
1103 }
1104
1105 FftType fft_type() const {
1106 CHECK_EQ(HloOpcode::kFft, opcode_);
1107 return fft_type_;
1108 }
1109
1110 const std::vector<int64>& fft_length() const {
1111 CHECK_EQ(HloOpcode::kFft, opcode_);
1112 return fft_length_;
1113 }
1114
1115 // Returns the dump string of the convolution dimension numbers.
1116 string ConvolutionDimensionNumbersToString() const;
1117
1118 // Returns data on the dimension numbers used for a dot operation.
1119 const DotDimensionNumbers& dot_dimension_numbers() const {
1120 CHECK(dot_dimension_numbers_ != nullptr);
1121 return *dot_dimension_numbers_;
1122 }
1123
1124 // Returns the dump string of the dot dimension numbers.
1125 string DotDimensionNumbersToString() const;
1126
1127 const GatherDimensionNumbers& gather_dimension_numbers() const {
1128 CHECK(gather_dimension_numbers_ != nullptr);
1129 return *gather_dimension_numbers_;
1130 }
1131
1132 tensorflow::gtl::ArraySlice<int64> gather_window_bounds() const {
1133 CHECK_EQ(opcode(), HloOpcode::kGather);
1134 return gather_window_bounds_;
1135 }
1136
1137 // Returns the dump string of the gather dimension numbers.
1138 string GatherDimensionNumbersToString() const;
1139
1140 // Returns the random distribution for this rng node.
1141 //
1142 // Precondition: opcode() == HloOpcode::kRng
1143 RandomDistribution random_distribution() const;
1144
1145 // Clones the HLO instruction. The clone will have the same opcode, shape, and
1146 // operands. After creation the clone has no uses. "this" (the instruction
1147 // cloned from) is not changed. Suffix is the string to append to the name of
1148 // the instruction to form the name of the cloned instruction.
  // If the module pointer is not nullptr, it is the module to which the
  // cloned computations will be added (in order to support deep cloning).
1152 std::unique_ptr<HloInstruction> Clone(const string& suffix = "clone",
1153 HloModule* module = nullptr) const;
1154
1155 // Clones the HLO instruction as above but with new shape and operands.
  // If the module pointer is not nullptr, it is the module to which the
  // cloned computations will be added (in order to support deep cloning).
1159 std::unique_ptr<HloInstruction> CloneWithNewOperands(
1160 const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
1161 HloModule* module = nullptr) const;
1162
1163 // Returns the computations this instruction directly calls (if any).
1164 const std::vector<HloComputation*>& called_computations() const {
1165 return called_computations_;
1166 }
1167
  // Replaces all called computations based on a map function. This is needed
  // when we clone hlo_computations and want the instructions to point to the
  // newly cloned nodes.
1171 void ReplaceCalledComputations(
1172 std::function<HloComputation*(HloComputation*)> map_function) {
1173 for (int64 i = 0; i < called_computations_.size(); ++i) {
1174 called_computations_[i] = map_function(called_computations_[i]);
1175 }
1176 }
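
  // Example (illustrative sketch): redirecting called computations through a
  // hypothetical `clone_map` from original computations to their clones;
  // `instr` and `clone_map` are stand-ins not defined in this file.
  //
  //   instr->ReplaceCalledComputations(
  //       [&clone_map](HloComputation* c) { return clone_map.at(c); });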
1177
1178 // Clears out the called computations.
1179 //
1180 // This is, in particular, necessary when inlining function bodies into their
1181 // caller. If there were side-effecting operations in the called computations,
1182 // the call itself is considered side-effecting and thus cannot be removed. By
1183 // clearing out the computations, we reflect the fact that all side-effecting
1184 // properties have been reflected in the caller, and make the call HLO
1185 // removable.
1186 void ClearCalledComputations() { called_computations_.clear(); }
1187
1188 // Returns true if this instruction performs an elementwise operation on
1189 // `operand_idx`-th operand. An instruction is elementwise on an operand iff,
1190 // after performing necessary implicit broadcast
1191 // (cs/IrArray::EmitArrayElementAddress), to compute the output at index
1192 // {i_0,i_1,...,i_n}, the only element required from the operand (if any) is
1193 // the element at {i_0,i_1,...,i_n}.
1194 //
1195 // Note on performance: when this instruction is kFusion, this method, in the
1196 // worst case, scans all fused instructions. We could speed this up by
1197 // caching.
1198 bool IsElementwiseOnOperand(int64 operand_idx) const;
1199
1200 // Returns true if this instruction is elementwise on all its operands.
1201 bool IsElementwise() const;
1202
1203 // Returns true if this elementwise instruction implicitly broadcasts operand
1204 // `operand_idx`.
1205 //
1206 // Precondition: this instruction should be an elementwise operation.
1207 bool ImplicitlyBroadcastsOperand(int64 operand_idx) const;
1208
1209 // Returns true if this instruction is binary and elementwise.
1210 bool IsElementwiseBinary() const;
1211
1212 // Returns whether this instruction may reuse elements of its `i`th operand.
1213 bool ReusesOperandElements(int64 i) const {
1214 return OperandElementUse(i) == UseKind::kReuse;
1215 }
1216
  // Returns the indices at which the given operand appears in the operand list
  // of this instruction. Note that an instruction can use the same operand
  // multiple times.
1220 std::vector<int64> OperandIndices(const HloInstruction* operand) const;
1221
1222 // Convenience helper for ShapeUtil::InsertedOrDeleted1SizedDimensions. If
1223 // this reshape merely inserts or deletes 1-sized dimensions, return the input
1224 // indices of the deleted dimensions and the output indices of the inserted
1225 // dimensions.
1226 //
1227 // Precondition: this op must be a reshape.
1228 std::tuple<bool, std::vector<int64>, std::vector<int64>>
1229 ReshapeMerelyInsertsOrDeletes1SizedDimensions() const;
1230
1231 // Gets/sets the string identifier for this instruction.
1232 const string& name() const { return name_; }
1233 void set_name(tensorflow::StringPiece name) { name_ = name.ToString(); }
1234
1235 // Use the given NameUniquer to select a unique name for the instruction based
1236 // on the instruction's existing name.
1237 void UniquifyName(NameUniquer* name_uniquer);
1238
1239 // Set the unique id for this instruction to "id"
1240 void SetUniqueId(int id) {
1241 CHECK_EQ(unique_id_, -1); // Should not be assigned already
1242 CHECK_GE(id, 0);
1243 unique_id_ = id;
1244 }
1245
1246 // Return the unique ID assigned to this node via SetUniqueId (or -1
1247 // if no id has been assigned yet).
1248 int unique_id() const { return unique_id_; }
1249
1250 // Sets the debug metadata for this instruction.
1251 void set_metadata(const OpMetadata& metadata) { metadata_ = metadata; }
1252 const OpMetadata& metadata() const { return metadata_; }
1253
1254 // Set/get the computation containing this instruction. set_parent should only
1255 // be called by HloComputation methods which add/remove instructions to
1256 // computations.
1257 void set_parent(HloComputation* computation) { parent_ = computation; }
1258 const HloComputation* parent() const { return parent_; }
1259 HloComputation* parent() { return parent_; }
1260
1261 // Returns the module for this instruction.
1262 HloModule* GetModule() const;
1263
1264 // Returns whether we could assign input and output layouts to this
1265 // instruction to make it a bitcast.
1266 bool CouldBeBitcast() const;
1267
1268 // Get/Set the number of partitions per outer dimension (in order, starting
1269 // with outer-most dimension first). Currently used by the parallel cpu
1270 // backend to partition HLOs into parallel tasks.
1271 // TODO(b/62783254) Replace these methods with a more general way to
1272 // annotate HLOs with backend-specific information.
1273 const std::vector<int64>& outer_dimension_partitions() const {
1274 return outer_dimension_partitions_;
1275 }
1276 void set_outer_dimension_partitions(
1277 const std::vector<int64>& outer_dimension_partitions);
1278
  // Changes the layout of a constant HLO instruction to match new_layout. For
  // tuple-shaped constants, shape_index is the path to the internal array
  // subshape whose layout needs to be changed.
1282 void RelayoutConstant(const Layout& new_layout,
1283 const ShapeIndex& shape_index = {});
1284
1285 private:
1286 enum class UseKind { kNoUse, kReuse, kUsePermutingElements, kUse };
1287
1288 // Helper class for computing OperandElementUse for kFusion.
1289 class FusionReusesParamElements;
1290
1291 // See comments on Identical().
  // eq_shapes() is used to check shapes for equality, and would normally be
  // expected to be ShapeUtil::Equal or ShapeUtil::Compatible, depending on
  // whether we want a layout-sensitive check or not.
1295 bool IdenticalSlowPath(
1296 const HloInstruction& other,
1297 const std::function<bool(const HloComputation*, const HloComputation*)>&
1298 eq_computations,
1299 const std::function<bool(const Shape&, const Shape&)>& eq_shapes) const;

  // Creates an n-ary elementwise operation.
  static std::unique_ptr<HloInstruction> CreateNary(
      const Shape& shape, HloOpcode opcode,
      tensorflow::gtl::ArraySlice<HloInstruction*> operands);

  // Appends operand to the list of operands and adds this instruction as a
  // user of the operand.
  void AppendOperand(HloInstruction* operand);

  // Adds a user for this instruction.
  void AddUser(HloInstruction* user);

  // Removes a user for this instruction.
  void RemoveUser(HloInstruction* user);

  // Internal constructor for a given opcode/shape; other fields must be
  // filled in by factory methods.
  HloInstruction(HloOpcode opcode, const Shape& shape);

  // Fuses the given instruction into this fusion instruction. When add_output
  // is false (which is the default), instruction_to_fuse is cloned and the
  // clone is placed in the fusion instruction. instruction_to_fuse is
  // unchanged.
  //
  // When add_output is true, a clone of instruction_to_fuse will be part of
  // the output of this fusion instruction. The users of instruction_to_fuse
  // will be redirected to this fusion instruction. instruction_to_fuse will
  // be removed from its parent computation.
  //
  // Precondition: this->opcode() == HloOpcode::kFusion
  HloInstruction* FuseInstructionInternal(HloInstruction* instruction_to_fuse,
                                          bool add_output = false);
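
  // Illustration (a sketch with hypothetical instructions `fusion`, `add`,
  // and `mul`; this is an internal helper, normally reached via the public
  // fusion API):
  //
  //   // add_output = false: `mul` is cloned into `fusion`; the original
  //   // `mul` stays untouched in its computation.
  //   HloInstruction* fused_mul = fusion->FuseInstructionInternal(mul);
  //
  //   // add_output = true: the clone of `add` also becomes part of the
  //   // fusion's (tuple) output, the users of `add` are rewired to `fusion`,
  //   // and `add` is removed from its parent computation.
  //   HloInstruction* fused_add =
  //       fusion->FuseInstructionInternal(add, /*add_output=*/true);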

  // Clones the given instruction_to_fuse and inserts the clone into this
  // fusion instruction. If add_output is true, a clone of instruction_to_fuse
  // will be in the output of this fusion instruction (part of the tuple of
  // the fusion root).
  //
  // Precondition: opcode() == HloOpcode::kFusion
  HloInstruction* CloneAndFuseInternal(HloInstruction* instruction_to_fuse,
                                       bool add_output = false);

  // Clones a fusion instruction with a new shape and operands.
  std::unique_ptr<HloInstruction> CloneFusionWithNewOperands(
      const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
      HloModule* module = nullptr) const;

  // Returns true if this instruction can legally have the dimensions field
  // set. Used for checking precondition of dimensions field accessors.
  bool CanHaveDimensionsField() const;

  // Returns how this instruction uses elements of its `i`th operand.
  UseKind OperandElementUse(int64 i) const;

  int unique_id_;  // Unique to this HloInstruction within a HloModule

  // Opcode for this instruction.
  HloOpcode opcode_;

  // Instruction operands.
  InstructionVector operands_;

  // The set of control predecessors of this instruction.
  std::vector<HloInstruction*> control_predecessors_;

  // The users of this instruction. Users are HLOs where this instruction is
  // an operand. The vector users_ and the set user_set_ contain identical
  // members. The set enables fast membership testing and the vector enables
  // fast, stable iteration.
  std::vector<HloInstruction*> users_;
  std::unordered_set<const HloInstruction*> user_set_;

  // The set of control successors of this instruction.
  std::vector<HloInstruction*> control_successors_;

  // The computation in which this instruction is contained.
  HloComputation* parent_ = nullptr;

  // Shape of outfeed request.
  Shape outfeed_shape_;

  // Result shape of this instruction.
  Shape shape_;

  // Literal, only present for kConstant.
  std::unique_ptr<Literal> literal_;

  // Constant index, only present for kGetTupleElement.
  int64 tuple_index_ = -1;

  // Dimensions present for some operations that require reshaping or
  // broadcasting, including Reshape, Reduce, ReduceWindow, and Reverse.
  std::vector<int64> dimensions_;

  // Describes the window in a windowed operation such as convolution.
  std::unique_ptr<Window> window_;

  // Describes the dimension numbers used for a convolution.
  std::unique_ptr<ConvolutionDimensionNumbers> convolution_dimension_numbers_;

  // Describes the dimension numbers used for a dot.
  std::unique_ptr<DotDimensionNumbers> dot_dimension_numbers_;

  std::unique_ptr<GatherDimensionNumbers> gather_dimension_numbers_;
  std::vector<int64> gather_window_bounds_;

  // Describes FFT type for an FFT instruction.
  FftType fft_type_ = FftType::FFT;

  // Indicates the FFT length for an FFT instruction.
  std::vector<int64> fft_length_;

  // Describes the [begin, end) index range for a slice.
  std::vector<int64> slice_starts_;
  std::vector<int64> slice_limits_;
  std::vector<int64> slice_strides_;
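
  // For example (an illustrative reading of these fields, not an API
  // guarantee): a slice that takes elements [2, 10) of dimension 0 with
  // stride 2 would have slice_starts_[0] == 2, slice_limits_[0] == 10, and
  // slice_strides_[0] == 2.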

  // Describes whether the slice can be lowered to an offset into the operand.
  bool is_in_place_slice_ = false;

  // The bit sizes for a reduce-precision operation.
  int32 exponent_bits_ = 0;
  int32 mantissa_bits_ = 0;

  // Describes the [start, start + size) range size for a dynamic slice
  // ('start' is specified dynamically in the second operand of the operation).
  std::vector<int64> dynamic_slice_sizes_;

  // The padding configuration that describes the edge padding and interior
  // padding of this pad instruction. Only set for pad instructions.
  std::unique_ptr<PaddingConfig> padding_config_;

  // The type of the fusion. Used by kFusion only.
  FusionKind fusion_kind_;

  // The sharding, if one exists.
  std::unique_ptr<HloSharding> sharding_;

  // For parameter instructions this field holds the parameter number.
  int64 parameter_number_ = 0;

  // Name of a global symbol to call, only present for kCustomCall.
  string custom_call_target_;

  // Name to use for host send/recv channels, only present for kHostCompute.
  string channel_name_;

  // Estimate of the duration of a host computation in nanoseconds.
  int64 cost_estimate_ns_;

  // Computations called by this instruction.
  std::vector<HloComputation*> called_computations_;

  // Indices of computations in called_computations_ for instructions which
  // call multiple computations.
  enum {
    // kWhile computations.
    kBodyComputationIndex = 0,
    kConditionComputationIndex = 1,

    // kSelectAndScatter computations.
    kSelectComputationIndex = 0,
    kScatterComputationIndex = 1,

    // kConditional computations.
    kTrueComputationIndex = 0,
    kFalseComputationIndex = 1,
  };
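
  // For example, for a kWhile instruction the while body lives at
  // called_computations_[kBodyComputationIndex] and the condition at
  // called_computations_[kConditionComputationIndex]; the public accessors
  // (e.g. while_body() and while_condition()) retrieve entries from this
  // vector using these constants. The index values overlap intentionally,
  // since no single opcode uses more than one of the groupings above.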

  // Outfeed configuration information, only present for kOutfeed.
  string outfeed_config_;

  // A trace instruction that consumes this instruction.
  //
  // Invariant: if trace_instruction_ != nullptr, trace_instruction has this as
  // an operand.
  HloInstruction* trace_instruction_ = nullptr;

  // The distribution requested for random number generation.
  // Only present for kRng.
  RandomDistribution distribution_;

  // A small float number added to the variance to avoid divide-by-zero error.
  // Only present for kBatchNormTraining.
  float epsilon_ = 0.0f;

  // An integer value representing the index of the feature dimension.
  // Only present for kBatchNormTraining.
  int64 feature_index_ = -1;

  // Represents a unique identifier for each Send/Recv instruction pair.
  // Only present for kSend or kRecv.
  int64 channel_id_ = -1;

  // The string representation of the infeed configuration.
  string infeed_config_;

  // String identifier for instruction.
  string name_;

  // Metadata for debugging.
  OpMetadata metadata_;

  // The number of partitions per outer dimension (listed in order from
  // outer-most dimension first).
  std::vector<int64> outer_dimension_partitions_;

  TF_DISALLOW_COPY_AND_ASSIGN(HloInstruction);
};

string ToString(HloInstruction::FusionKind kind);
StatusOr<HloInstruction::FusionKind> StringToFusionKind(
    const string& kind_name);

// Custom (de)stringification functions for protos that live inside
// HloInstruction.
string PaddingConfigToString(const PaddingConfig& padding);
string OpMetadataToString(const OpMetadata& metadata);
string RandomDistributionToString(const RandomDistribution& distribution);
StatusOr<RandomDistribution> StringToRandomDistribution(const string& name);

std::ostream& operator<<(std::ostream& os, HloInstruction::FusionKind kind);

// Map classes that guarantee a deterministic iteration order when the key is
// an HloInstruction* or a const HloInstruction*.
// To make the iteration order over the map deterministic, the comparator
// does not use the pointer values, but rather an intrinsic property of the
// HLO, namely its unique id.
//
// Note that this cannot be used for HLO instructions across multiple modules,
// since the ids of HLO instructions are only unique within each HLO module.
struct HloPtrComparator {
  bool operator()(const HloInstruction* const& lhs,
                  const HloInstruction* const& rhs) const {
    return lhs->unique_id() < rhs->unique_id();
  }
};

template <typename ValueT>
using HloInstructionMap = std::map<HloInstruction*, ValueT, HloPtrComparator>;

template <typename ValueT>
using ConstHloInstructionMap =
    std::map<const HloInstruction*, ValueT, HloPtrComparator>;

using HloInstructionSet = std::set<HloInstruction*, HloPtrComparator>;
using ConstHloInstructionSet =
    std::set<const HloInstruction*, HloPtrComparator>;
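
// Example usage (a sketch; `size_of` is a hypothetical helper and `instrs` a
// hypothetical range of HloInstruction*):
//
//   HloInstructionMap<int64> buffer_sizes;
//   for (HloInstruction* instr : instrs) {
//     buffer_sizes[instr] = size_of(instr);
//   }
//   // Iteration visits entries ordered by unique_id(), not by pointer value,
//   // so the order is stable across runs.
//   for (const auto& pair : buffer_sizes) {
//     VLOG(1) << pair.first->name() << ": " << pair.second;
//   }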

}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INSTRUCTION_H_