| 1 | /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. |
| 2 | |
| 3 | Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | you may not use this file except in compliance with the License. |
| 5 | You may obtain a copy of the License at |
| 6 | |
| 7 | http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | |
| 9 | Unless required by applicable law or agreed to in writing, software |
| 10 | distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | See the License for the specific language governing permissions and |
| 13 | limitations under the License. |
| 14 | ==============================================================================*/ |
| 15 | |
| 16 | // HLO instructions are in DAG form and represent the computations that the user |
| 17 | // has built up via the XLA service interface. They are ultimately lowered |
| 18 | // in a platform-aware way by traversing the HLO DAG and emitting a lowered |
| 19 | // form; e.g. see DfsHloVisitor. |
| 20 | |
| 21 | #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INSTRUCTION_H_ |
| 22 | #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INSTRUCTION_H_ |
| 23 | |
| 24 | #include <functional> |
| 25 | #include <iosfwd> |
| 26 | #include <list> |
| 27 | #include <memory> |
| 28 | #include <set> |
| 29 | #include <string> |
| 30 | #include <tuple> |
| 31 | #include <unordered_map> |
| 32 | #include <unordered_set> |
| 33 | #include <vector> |
| 34 | |
| 35 | #include "tensorflow/compiler/xla/iterator_util.h" |
| 36 | #include "tensorflow/compiler/xla/literal_util.h" |
| 37 | #include "tensorflow/compiler/xla/map_util.h" |
| 38 | #include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h" |
| 39 | #include "tensorflow/compiler/xla/service/hlo.pb.h" |
| 40 | #include "tensorflow/compiler/xla/service/hlo_opcode.h" |
| 41 | #include "tensorflow/compiler/xla/service/hlo_sharding.h" |
| 42 | #include "tensorflow/compiler/xla/service/name_uniquer.h" |
| 43 | #include "tensorflow/compiler/xla/types.h" |
| 44 | #include "tensorflow/compiler/xla/xla_data.pb.h" |
| 45 | #include "tensorflow/core/lib/core/status.h" |
| 46 | #include "tensorflow/core/lib/core/stringpiece.h" |
| 47 | #include "tensorflow/core/lib/gtl/array_slice.h" |
| 48 | #include "tensorflow/core/lib/gtl/flatmap.h" |
| 49 | #include "tensorflow/core/lib/gtl/inlined_vector.h" |
| 50 | #include "tensorflow/core/lib/gtl/iterator_range.h" |
| 51 | #include "tensorflow/core/platform/logging.h" |
| 52 | #include "tensorflow/core/platform/macros.h" |
| 53 | #include "tensorflow/core/platform/types.h" |
| 54 | |
| 55 | namespace xla { |
| 56 | |
| 57 | class HloComputation; |
| 58 | class HloModule; |
| 59 | |
// A bunch of switches that control how the hlo text should be printed.
//
// Every setter returns *this so that options can be configured with chained
// calls, e.g. HloPrintOptions().set_print_metadata(false).
class HloPrintOptions {
 public:
  // Constructs the default print options: don't print large constants, print
  // subcomputation references, metadata, operand shapes and program shapes,
  // prefix names with '%', don't compact operands, no indentation.
  HloPrintOptions()
      : print_large_constants_(false),
        print_subcomputation_references_(true),
        print_metadata_(true),
        compact_operands_(false),
        print_operand_shape_(true),
        print_program_shape_(true),
        print_percent_(true),
        indent_amount_(0) {}

  // Options for short text that is still parsable. Large constants and
  // subcomputation references are printed (omitting the latter would make the
  // text unparsable). Metadata, operand shapes, program shapes and the '%'
  // name prefix are omitted. Operands are not compacted and there is no
  // indentation.
  static HloPrintOptions ShortParsable() {
    return HloPrintOptions()
        .set_print_large_constants(true)
        .set_print_subcomputation_references(true)
        .set_print_metadata(false)
        .set_print_operand_shape(false)
        .set_print_program_shape(false)
        .set_print_percent(false);
  }

  // If true, large constants will be printed out.
  HloPrintOptions& set_print_large_constants(bool value) {
    print_large_constants_ = value;
    return *this;
  }

  // If false, the names of subcomputations (e.g. a fusion node's fused
  // computation) won't be printed. This makes the resulting text not parsable.
  //
  // A CustomCall's call target is printed even if
  // print_subcomputation_references is false, because the call target isn't an
  // HloComputation.
  HloPrintOptions& set_print_subcomputation_references(bool value) {
    print_subcomputation_references_ = value;
    return *this;
  }

  // If true, metadata will be printed.
  HloPrintOptions& set_print_metadata(bool value) {
    print_metadata_ = value;
    return *this;
  }

  // If true, operands' shapes will be printed.
  HloPrintOptions& set_print_operand_shape(bool value) {
    print_operand_shape_ = value;
    return *this;
  }

  // If true, program shape of hlo computations will be printed.
  HloPrintOptions& set_print_program_shape(bool value) {
    print_program_shape_ = value;
    return *this;
  }

  // If true, names will be printed with prefix '%'.
  HloPrintOptions& set_print_percent(bool value) {
    print_percent_ = value;
    return *this;
  }

  // If true, only a part of operands will be printed out, and their names will
  // be omitted (note that in this case the text will not be parsable).
  HloPrintOptions& set_compact_operands(bool value) {
    compact_operands_ = value;
    return *this;
  }

  // Sets the indent of the hlo text block.
  HloPrintOptions& set_indent_amount(int value) {
    indent_amount_ = value;
    return *this;
  }

  // Accessors for the options configured by the setters above.
  bool print_large_constants() const { return print_large_constants_; }
  bool print_subcomputation_references() const {
    return print_subcomputation_references_;
  }
  bool print_metadata() const { return print_metadata_; }
  bool compact_operands() const { return compact_operands_; }
  bool print_operand_shape() const { return print_operand_shape_; }
  bool print_program_shape() const { return print_program_shape_; }
  bool print_percent() const { return print_percent_; }
  int indent_amount() const { return indent_amount_; }

 private:
  bool print_large_constants_;
  bool print_subcomputation_references_;
  bool print_metadata_;
  bool compact_operands_;
  bool print_operand_shape_;
  bool print_program_shape_;
  bool print_percent_;
  int indent_amount_;
};
| 160 | |
| 161 | // HLO instructions are the IR used by the high-level compiler. |
| 162 | class HloInstruction { |
| 163 | public: |
enum class FusionKind {
  // Fused into a loop.
  kLoop = 0,
  // Op's input is fused into the op itself.
  kInput = 1,
  // Op's output is fused into the op itself.
  // REQUIRES: At least one operand buffer must be able to alias the output
  // buffer.
  kOutput = 2,
  // Fused into a dot with transposed operands.
  kTransposeDot = 3,
  // Custom category for backend-specific fusions that do not match any of the
  // more specific ones.
  kCustom = 4,
};
| 174 | |
// Destructor (only declared here; defined elsewhere).
~HloInstruction();

// Creates an instruction from the given proto. Arguments:
//
//   module: the module which will contain the instruction. The newly created
//     instruction is *not* added to the module or any computation, however.
//   proto: the proto to convert from.
//   instruction_map: a map from instruction id to HloInstruction*. This map
//     must contain all operands of the newly constructed instruction.
//   computation_map: a map from computation id to HloComputation*. This map
//     must contain all computations which the newly constructed instruction
//     calls.
//
// The result is a StatusOr, so the conversion can report an error status
// instead of returning an instruction.
static StatusOr<std::unique_ptr<HloInstruction>> CreateFromProto(
    HloModule* module, const HloInstructionProto& proto,
    const tensorflow::gtl::FlatMap<int64, HloInstruction*>& instruction_map,
    const tensorflow::gtl::FlatMap<int64, HloComputation*>& computation_map);
| 191 | |
// Creates a parameter-retrieving instruction for parameter number
// `parameter_number`, with result shape `shape` and instruction name `name`.
static std::unique_ptr<HloInstruction> CreateParameter(int64 parameter_number,
                                                       const Shape& shape,
                                                       const string& name);

// Creates a literal constant instruction. Ownership of `literal` is
// transferred to the factory (it is passed as a unique_ptr by value).
static std::unique_ptr<HloInstruction> CreateConstant(
    std::unique_ptr<Literal> literal);

// Creates a get tuple element instruction. Presumably it selects element
// `index` of the tuple-shaped `operand`, and `shape` is the shape of that
// element (TODO confirm).
static std::unique_ptr<HloInstruction> CreateGetTupleElement(
    const Shape& shape, HloInstruction* operand, int64 index);

// Creates a trace instruction that logs the input operand in the computation.
static std::unique_ptr<HloInstruction> CreateTrace(const string& tag,
                                                   HloInstruction* operand);

// Creates a random number generation instruction that fills a shape with
// random numbers from a given distribution.
// NOTE(review): the meaning and count of `parameters` presumably depend on
// `distribution` — confirm against the RandomDistribution definition.
static std::unique_ptr<HloInstruction> CreateRng(
    const Shape& shape, RandomDistribution distribution,
    tensorflow::gtl::ArraySlice<HloInstruction*> parameters);

// Creates a unary instruction (one operand).
// Precondition: opcode must be a legitimate unary operation.
static std::unique_ptr<HloInstruction> CreateUnary(const Shape& shape,
                                                   HloOpcode opcode,
                                                   HloInstruction* operand);

// Creates a binary instruction (two operands).
// Precondition: opcode must be a legitimate binary operation.
static std::unique_ptr<HloInstruction> CreateBinary(const Shape& shape,
                                                    HloOpcode opcode,
                                                    HloInstruction* lhs,
                                                    HloInstruction* rhs);
| 227 | |
// Creates a ternary instruction (three operands).
// Precondition: opcode must be a legitimate ternary operation.
// NOTE(review): `ehs` is the third operand; the name presumably stands for
// "else-hand side" (as in a select) — confirm.
static std::unique_ptr<HloInstruction> CreateTernary(const Shape& shape,
                                                     HloOpcode opcode,
                                                     HloInstruction* lhs,
                                                     HloInstruction* rhs,
                                                     HloInstruction* ehs);

// Creates a variadic instruction (variable number of operands).
// Precondition: opcode must be a legitimate variadic operation.
static std::unique_ptr<HloInstruction> CreateVariadic(
    const Shape& shape, HloOpcode opcode,
    tensorflow::gtl::ArraySlice<HloInstruction*> operands);

// Creates a map instruction, where `map_computation` is applied element-wise
// to every element in operands (across the operands, at a given index) with
// the same `static_operands`. `static_operands` defaults to an empty list.
static std::unique_ptr<HloInstruction> CreateMap(
    const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
    HloComputation* map_computation,
    tensorflow::gtl::ArraySlice<HloInstruction*> static_operands = {});
| 249 | |
// Creates a convolution op, where rhs is the convolutional filter
// and window describes how the filter is applied to lhs.
// NOTE(review): `dimension_numbers` presumably identifies the batch, feature
// and spatial dimensions of lhs, rhs and the result — confirm.
static std::unique_ptr<HloInstruction> CreateConvolve(
    const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
    const Window& window,
    const ConvolutionDimensionNumbers& dimension_numbers);

// Creates an FFT op, of the type indicated by fft_type.
static std::unique_ptr<HloInstruction> CreateFft(
    const Shape& shape, HloInstruction* operand, FftType fft_type,
    tensorflow::gtl::ArraySlice<int64> fft_length);

// Creates a dot op with operands 'lhs' and 'rhs' with contracting and batch
// dimensions specified in 'dimension_numbers'.
static std::unique_ptr<HloInstruction> CreateDot(
    const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
    const DotDimensionNumbers& dimension_numbers);

// Creates a dot op with operands 'lhs' and 'rhs' that contracts dimension 1
// of the LHS with dimension 0 of the RHS with no batch dimensions. Both LHS
// and the RHS must be of rank 2. Conceptually this is CreateDot with those
// fixed dimension numbers.
static std::unique_ptr<HloInstruction> CreateCanonicalDot(
    const Shape& shape, HloInstruction* lhs, HloInstruction* rhs);
| 273 | |
| 274 | // Creates a reduce-precision op, where operand is the data to reduce in |
| 275 | // precision, and exponent_bits and mantissa_bits describe the precision to |
| 276 | // reduce it to. |
| 277 | static std::unique_ptr<HloInstruction> CreateReducePrecision( |
| 278 | const Shape& shape, HloInstruction* operand, const int exponent_bits, |
| 279 | const int mantissa_bits); |
| 280 | |
// Creates a cross replica sum op over `operands`.
static std::unique_ptr<HloInstruction> CreateCrossReplicaSum(
    const Shape& shape,
    tensorflow::gtl::ArraySlice<HloInstruction*> operands);

// Creates a conversion instruction, where operand is the data to convert and
// shape is the target shape for the conversion.
static std::unique_ptr<HloInstruction> CreateConvert(const Shape& shape,
                                                     HloInstruction* operand);

// Creates a bitcast conversion instruction, where operand is the data to
// convert and shape is the target shape for the conversion.
static std::unique_ptr<HloInstruction> CreateBitcastConvert(
    const Shape& shape, HloInstruction* operand);

// Creates an infeed instruction, which reads data of the given shape from the
// Infeed interface of the device.
// NOTE(review): `config` is an infeed configuration string whose
// interpretation is presumably backend-defined — confirm.
static std::unique_ptr<HloInstruction> CreateInfeed(const Shape& shape,
                                                    const string& config);

// Creates an outfeed instruction, which outputs data taken from `operand`.
// NOTE(review): `shape` presumably describes the outfed data and
// `outfeed_config` is a backend-defined configuration string — confirm.
static std::unique_ptr<HloInstruction> CreateOutfeed(
    const Shape& shape, HloInstruction* operand,
    tensorflow::StringPiece outfeed_config);
| 305 | |
// Creates an asynchronous send instruction with the given channel id, which
// initiates sending the operand data to a unique receive instruction in
// another computation that has the same channel id. Completion of the
// transfer is awaited with a SendDone instruction (see CreateSendDone).
static std::unique_ptr<HloInstruction> CreateSend(HloInstruction* operand,
                                                  int64 channel_id);

// Blocks until data transfer for the Send instruction (operand) is complete.
// The operand must be kSend.
static std::unique_ptr<HloInstruction> CreateSendDone(
    HloInstruction* operand);

// Creates an asynchronous receive instruction with the given channel id,
// which allocates resources to receive data of the given shape from a unique
// send instruction in another computation that has the same channel id.
// The received buffer is obtained through a RecvDone instruction (see
// CreateRecvDone).
static std::unique_ptr<HloInstruction> CreateRecv(const Shape& shape,
                                                  int64 channel_id);

// Blocks until data transfer for the Recv instruction (operand) is complete
// and returns the receive buffer. The operand must be kRecv.
static std::unique_ptr<HloInstruction> CreateRecvDone(
    HloInstruction* operand);
| 327 | |
// Creates a slice instruction, where the operand is sliced by the given
// start/limit indices, stepping by `strides`.
// NOTE(review): each array presumably holds one entry per operand
// dimension — confirm.
static std::unique_ptr<HloInstruction> CreateSlice(
    const Shape& shape, HloInstruction* operand,
    tensorflow::gtl::ArraySlice<int64> start_indices,
    tensorflow::gtl::ArraySlice<int64> limit_indices,
    tensorflow::gtl::ArraySlice<int64> strides);

// Creates a slice instruction, where the first operand is sliced by
// start indices specified in the second operand, and by size specified in
// 'slice_sizes'. Unlike CreateSlice, the start indices are an HLO value
// rather than integers fixed when the instruction is created.
static std::unique_ptr<HloInstruction> CreateDynamicSlice(
    const Shape& shape, HloInstruction* operand,
    HloInstruction* start_indices,
    tensorflow::gtl::ArraySlice<int64> slice_sizes);

// Creates a dynamic update slice instruction, which updates the slice of
// 'operand' located at 'start_indices' with the contents of 'update'.
static std::unique_ptr<HloInstruction> CreateDynamicUpdateSlice(
    const Shape& shape, HloInstruction* operand, HloInstruction* update,
    HloInstruction* start_indices);

// Creates a concatenate instruction, where the operands are concatenated on
// the provided dimension.
static std::unique_ptr<HloInstruction> CreateConcatenate(
    const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
    int64 dimension);
| 355 | |
// Creates a reduce instruction, where `reduce_computation` is applied
// successively to every element in operand along `dimensions_to_reduce`.
// That is, if f is the function to apply (which either takes 2
// [accumulator, value] or 3 [accumulator, index, value] arguments) and init
// is a reduction operator specified initial value (for example, 0 for
// addition), then this operation will compute:
//   f(...f(f(init, [index0], value0), [index1], value1)..., [indexN], valueN)
static std::unique_ptr<HloInstruction> CreateReduce(
    const Shape& shape, HloInstruction* operand, HloInstruction* init_value,
    tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce,
    HloComputation* reduce_computation);

// Creates a reduce-window instruction, where `reduce_computation` is applied
// window-wise at each valid window position in the operand, with the windows
// described by `window`.
static std::unique_ptr<HloInstruction> CreateReduceWindow(
    const Shape& shape, HloInstruction* operand, HloInstruction* init_value,
    const Window& window, HloComputation* reduce_computation);
| 374 | |
// NOTE(review): in the three batch-norm factories below, `feature_index`
// presumably names the feature dimension of `operand`, and `epsilon` is the
// small constant added to the variance for numerical stability — confirm.

// Creates a batch-norm-training instruction.
static std::unique_ptr<HloInstruction> CreateBatchNormTraining(
    const Shape& shape, HloInstruction* operand, HloInstruction* scale,
    HloInstruction* offset, float epsilon, int64 feature_index);

// Creates a batch-norm-inference instruction.
static std::unique_ptr<HloInstruction> CreateBatchNormInference(
    const Shape& shape, HloInstruction* operand, HloInstruction* scale,
    HloInstruction* offset, HloInstruction* mean, HloInstruction* variance,
    float epsilon, int64 feature_index);

// Creates a batch-norm-grad instruction.
static std::unique_ptr<HloInstruction> CreateBatchNormGrad(
    const Shape& shape, HloInstruction* operand, HloInstruction* scale,
    HloInstruction* mean, HloInstruction* variance,
    HloInstruction* grad_output, float epsilon, int64 feature_index);

// Creates a scatter computation that scatters the `source` array to the
// selected indices of each window.
// NOTE(review): presumably `select` chooses one element within each `window`
// of `operand`, and `scatter` combines `source` values landing on the same
// position, starting from `init_value` — confirm.
static std::unique_ptr<HloInstruction> CreateSelectAndScatter(
    const Shape& shape, HloInstruction* operand, HloComputation* select,
    const Window& window, HloInstruction* source, HloInstruction* init_value,
    HloComputation* scatter);
| 398 | |
// Creates a broadcast instruction.
// NOTE(review): `broadcast_dimensions` presumably maps each operand dimension
// to a dimension of `shape` — confirm.
static std::unique_ptr<HloInstruction> CreateBroadcast(
    const Shape& shape, HloInstruction* operand,
    tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);

// Creates a broadcast-size-one-dimensions instruction.
static std::unique_ptr<HloInstruction> CreateBroadcastDimOne(
    const Shape& shape, HloInstruction* operand);

// Creates a sequence of instructions that performs an explicit broadcast of
// the operand to the target shape.
//
// Interior HLOs are passed to "adder", but the "root" HLO of the sequence is
// returned as a unique_ptr for API consistency with other factory methods in
// this interface. "adder" takes ownership of each interior HLO (it receives a
// unique_ptr) and returns a raw pointer to it; presumably the caller's adder
// inserts it into a computation.
//
// TODO(b/72173833) Ideally HloComputations would always be present, and so
// the adder being passed by the caller would not be necessary.
static std::unique_ptr<HloInstruction> CreateBroadcastSequence(
    const Shape& output_shape, HloInstruction* operand,
    const std::function<HloInstruction*(std::unique_ptr<HloInstruction>)>&
        adder);
| 421 | |
// Creates a pad instruction, where the operand is padded on the edges and
// between the elements with the given padding value.
static std::unique_ptr<HloInstruction> CreatePad(
    const Shape& shape, HloInstruction* operand,
    HloInstruction* padding_value, const PaddingConfig& padding_config);

// Creates a reshape instruction, where the operand is flattened in row-major
// order and then reshaped to the given result shape.
static std::unique_ptr<HloInstruction> CreateReshape(const Shape& shape,
                                                     HloInstruction* operand);

// Creates a transpose instruction which permutes the operand dimensions.
static std::unique_ptr<HloInstruction> CreateTranspose(
    const Shape& shape, HloInstruction* operand,
    tensorflow::gtl::ArraySlice<int64> dimensions);

// Creates a while instruction, given a condition computation, a body
// computation, and the initial value for the input of the computations. For
// example, shape: S32, condition: i -> i < 1000, body: i -> i * 2, init: 1
// corresponds to the C code below.
// int32 i = 1; int32 result = while(i < 1000) { i = i * 2 }
static std::unique_ptr<HloInstruction> CreateWhile(const Shape& shape,
                                                   HloComputation* condition,
                                                   HloComputation* body,
                                                   HloInstruction* init);

// Creates a conditional instruction.
// NOTE(review): judging by the parameter names, `pred` selects whether
// `true_computation` is applied to `true_computation_arg` or
// `false_computation` is applied to `false_computation_arg` — confirm against
// the implementation.
static std::unique_ptr<HloInstruction> CreateConditional(
    const Shape& shape, HloInstruction* pred,
    HloInstruction* true_computation_arg, HloComputation* true_computation,
    HloInstruction* false_computation_arg, HloComputation* false_computation);

// Creates a gather instruction over `operand`, driven by `gather_indices`.
// `gather_dim_numbers` can be built with MakeGatherDimNumbers.
// NOTE(review): `window_bounds` presumably gives the size of each gathered
// window per operand dimension — confirm.
static std::unique_ptr<HloInstruction> CreateGather(
    const Shape& shape, HloInstruction* operand,
    HloInstruction* gather_indices,
    const GatherDimensionNumbers& gather_dim_numbers,
    tensorflow::gtl::ArraySlice<int64> window_bounds);
| 458 | |
// Creates a fusion instruction. A fusion instruction contains one or more
// fused instructions forming an expression with a single root
// "fused_root". Additional instructions can be added to the fusion
// instruction with the method FuseInstruction.
static std::unique_ptr<HloInstruction> CreateFusion(
    const Shape& shape, FusionKind fusion_kind, HloInstruction* fused_root);

// Creates a fusion instruction of kind `fusion_kind` that takes `operands`.
// NOTE(review): presumably `fusion_computation` is an already-built fused
// computation used as the fusion body — confirm.
static std::unique_ptr<HloInstruction> CreateFusion(
    const Shape& shape, FusionKind fusion_kind,
    tensorflow::gtl::ArraySlice<HloInstruction*> operands,
    HloComputation* fusion_computation);

// Creates a call instruction that applies the given computation on the given
// operands. "shape" is the resultant shape.
static std::unique_ptr<HloInstruction> CreateCall(
    const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
    HloComputation* computation);

// Creates a custom call instruction that applies the given custom call target
// to the given operands. "shape" is the resultant shape. The target is named
// by a string, not an HloComputation.
static std::unique_ptr<HloInstruction> CreateCustomCall(
    const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
    tensorflow::StringPiece custom_call_target);
| 482 | |
| 483 | // Creates a HostCompute instruction, which records host-side control and |
| 484 | // data dependencies for use in instruction scheduling. |
| 485 | static std::unique_ptr<HloInstruction> CreateHostCompute( |
| 486 | const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands, |
| 487 | tensorflow::StringPiece channel_name, const int64 cost_estimate_ns); |
| 488 | |
// Creates a tuple instruction with the given elements. This is a convenience
// wrapper around CreateVariadic.
static std::unique_ptr<HloInstruction> CreateTuple(
    tensorflow::gtl::ArraySlice<HloInstruction*> elements);

// Creates a reverse instruction, which reverses the order of the elements
// in the specified dimensions.
static std::unique_ptr<HloInstruction> CreateReverse(
    const Shape& shape, HloInstruction* operand,
    tensorflow::gtl::ArraySlice<int64> dimensions);

// Creates an instance of GatherDimensionNumbers, e.g. for CreateGather.
// NOTE(review): presumably each argument populates the field of the same
// name — confirm.
static GatherDimensionNumbers MakeGatherDimNumbers(
    tensorflow::gtl::ArraySlice<int64> output_window_dims,
    tensorflow::gtl::ArraySlice<int64> elided_window_dims,
    tensorflow::gtl::ArraySlice<int64> gather_dims_to_operand_dims,
    int64 index_vector_dim);
| 506 | |
| 507 | // Returns the opcode for this instruction. |
| 508 | HloOpcode opcode() const { return opcode_; } |
| 509 | |
| 510 | // Returns true if this instruction has a side effect. An instruction has a |
| 511 | // side effect if it uses certain opcodes or calls a computation with a side |
| 512 | // effect. |
| 513 | bool HasSideEffect() const; |
| 514 | |
| 515 | // Returns the result shape of this instruction. |
| 516 | const Shape& shape() const; |
| 517 | |
| 518 | // Returns the (mutable) result shape of this instruction. |
| 519 | Shape* mutable_shape() { return &shape_; } |
| 520 | |
| 521 | // Returns the ith operand to this instruction. |
| 522 | const HloInstruction* operand(int64 i) const; |
| 523 | |
| 524 | // Returns the ith operand to this instruction. |
| 525 | HloInstruction* mutable_operand(int64 i); |
| 526 | |
| 527 | // Returns the number of operands to this instruction. |
| 528 | int64 operand_count() const { return operands_.size(); } |
| 529 | |
| 530 | // Returns the vector of operands of this instruction. |
| 531 | using InstructionVector = tensorflow::gtl::InlinedVector<HloInstruction*, 2>; |
| 532 | const InstructionVector& operands() const { return operands_; } |
| 533 | |
| 534 | // Returns the index of 'target' in the operands sequence. |
| 535 | // Precondition: target must be an operand (or a fatal error will occur). |
| 536 | int64 operand_index(const HloInstruction* target) const; |
| 537 | |
| 538 | // Returns the number of users of this instruction. |
| 539 | int64 user_count() const { return users_.size(); } |
| 540 | |
| 541 | // Returns the users of this instruction. |
| 542 | const std::vector<HloInstruction*>& users() const { return users_; } |
| 543 | |
| 544 | // Returns true if this instruction is a user of 'instruction'. |
| 545 | bool IsUserOf(const HloInstruction* instruction) const { |
| 546 | return ContainsKey(instruction->user_set_, this); |
| 547 | } |
| 548 | |
| 549 | // Adds a control dependency from this instruction to the given |
| 550 | // instruction. This instruction becomes a control predecessor of |
| 551 | // 'instruction', and 'instruction' becomes a control successor of this |
| 552 | // instruction. Returns an error status if either of the given instructions |
| 553 | // does not belong to the same computation. |
| 554 | // |
| 555 | // This is used to enforce an additional ordering requirement that is not |
| 556 | // captured by normal data dependencies, such as ordering among Send or Recv |
| 557 | // operations to avoid deadlock. |
| 558 | Status AddControlDependencyTo(HloInstruction* instruction); |
| 559 | |
| 560 | // Removes a previously added control dependency from this instruction to |
| 561 | // 'instruction'. |
| 562 | Status RemoveControlDependencyTo(HloInstruction* instruction); |
| 563 | |
| 564 | // Returns the set of control predecessors (successors) of this |
| 565 | // instruction. Control predecessors (successors) must execute before (after) |
| 566 | // the current instruction. |
| 567 | const std::vector<HloInstruction*>& control_predecessors() const { |
| 568 | return control_predecessors_; |
| 569 | } |
| 570 | const std::vector<HloInstruction*>& control_successors() const { |
| 571 | return control_successors_; |
| 572 | } |
| 573 | |
| 574 | // Returns true if "other" performs the same computation as this instruction. |
| 575 | bool Identical( |
| 576 | const HloInstruction& other, |
| 577 | const std::function<bool(const HloInstruction*, const HloInstruction*)>& |
| 578 | eq_operands = std::equal_to<const HloInstruction*>(), |
| 579 | const std::function<bool(const HloComputation*, const HloComputation*)>& |
| 580 | eq_computations = std::equal_to<const HloComputation*>(), |
| 581 | bool layout_sensitive = true) const { |
| 582 | // An instruction is always identical to itself. |
| 583 | if (this == &other) { |
| 584 | return true; |
| 585 | } |
| 586 | |
| 587 | // Identical instruction must have the same opcode, shape, and identical |
| 588 | // operands. |
| 589 | if (opcode() != other.opcode()) { |
| 590 | return false; |
| 591 | } |
| 592 | using EqShapeFuncType = bool (*)(const Shape&, const Shape&); |
| 593 | EqShapeFuncType eq_shapes = |
| 594 | layout_sensitive ? ShapeUtil::Equal : ShapeUtil::Compatible; |
| 595 | if (!eq_shapes(shape(), other.shape())) { |
| 596 | return false; |
| 597 | } |
| 598 | if (operands().size() != other.operands().size()) { |
| 599 | return false; |
| 600 | } |
| 601 | |
| 602 | // Use an explicit loop rather than ContainerEquals, because copying around |
| 603 | // std::functions may be too expensive in some cases. |
| 604 | for (size_t i = 0; i < operands().size(); ++i) { |
| 605 | if (!eq_operands(operand(i), other.operand(i))) { |
| 606 | return false; |
| 607 | } |
| 608 | } |
| 609 | |
| 610 | return IdenticalSlowPath(other, eq_computations, eq_shapes); |
| 611 | } |
| 612 | |
// Returns whether the instruction has a constant operand.
bool HasConstantOperand() const;

// Returns whether this instruction does a rank-2 transposition.
bool IsRank2Transpose() const;

// Replaces the use of this instruction in "user" with "new_producer". Note
// that there might be multiple uses of this instruction in "user"; all will
// be replaced.
Status ReplaceUseWith(HloInstruction* user, HloInstruction* new_producer);

// Replaces the operand at index `operand_no` with new_operand.
Status ReplaceOperandWith(int64 operand_no, HloInstruction* new_operand);

// Replaces all uses of this instruction with the new producer. If
// new_producer is a user of this instruction then new_producer remains a user
// of this instruction to avoid introducing cycles into the graph.
//
// If this instruction is the root of its computation, sets the computation's
// root to new_producer.
Status ReplaceAllUsesWith(HloInstruction* new_producer);

// Detaches an instruction from its operands. That is, removes the instruction
// from each operand's user set. This should only be called prior to
// deallocating the instruction.
void DetachFromOperands();
| 639 | |
// Performs a postorder DFS visit using this node as the root. If
// call_finish_visit is true, then DfsHloVisitor::FinishVisit is called when
// complete. If ignore_control_predecessors is true, instructions only
// reachable via control dependencies will not be visited, and the postorder
// will not take control dependencies into account. It is as if the control
// dependencies didn't exist in the graph at all.
template <typename HloInstructionPtr>
Status Accept(DfsHloVisitorBase<HloInstructionPtr>* visitor,
              bool call_finish_visit = true,
              bool ignore_control_predecessors = false);
// Const overload for visitors of type ConstDfsHloVisitor. It casts away the
// constness of `this` and forwards to the templated Accept above.
// NOTE(review): this is only sound if the const visitor never mutates the
// instructions it is handed; confirm that ConstDfsHloVisitor only exposes
// const HloInstruction* to its handlers.
Status Accept(ConstDfsHloVisitor* visitor, bool call_finish_visit = true,
              bool ignore_control_predecessors = false) const {
  return const_cast<HloInstruction*>(this)->Accept(
      visitor, call_finish_visit, ignore_control_predecessors);
}
| 655 | |
// Same as Accept() above, but the order of operand and control predecessor
// visitation is determined by the given operand order: if compare(A, B) ==
// true, A is visited before B.
using CompareFunction =
    std::function<bool(const HloInstruction*, const HloInstruction*)>;
Status AcceptWithOperandOrder(DfsHloVisitor* visitor,
                              const CompareFunction& operand_order,
                              bool call_finish_visit = true);

// Performs a postorder DFS visit using this node as the root. Calls the given
// visitor function at each instruction. The const overload hands each
// instruction to the callback as a const pointer.
Status Accept(const std::function<Status(HloInstruction*)>& visitor_func);
Status Accept(
    const std::function<Status(const HloInstruction*)>& visitor_func) const;

// Visits all instructions rooted at this instruction using the given visitor
// in the given order. 'order' must contain at least the set of instructions
// rooted at this node (ie, those accessible from a DFS traversal from this
// instruction). Instructions contained in 'order' which are not in the set of
// instructions rooted at this node are ignored. 'order' must also be a valid
// topological sort of these instructions (defs appear before uses) though
// need not be a DFS post-order.
Status AcceptOrdered(DfsHloVisitor* visitor,
                     const std::vector<const HloInstruction*>& order);

// Visits this instruction, and only this instruction, with the given visitor.
template <typename HloInstructionPtr>
Status Visit(DfsHloVisitorBase<HloInstructionPtr>* visitor);
| 684 | |
// Returns the literal associated with this instruction.
//
// Note: only constant and parameter opcodes have an associated literal.
const Literal& literal() const;

// Returns the parameter number associated with this instruction.
//
// Note: only parameter opcodes have an associated parameter number; calling
// this on any other opcode fails the CHECK below.
int64 parameter_number() const {
  CHECK_EQ(HloOpcode::kParameter, opcode_);
  return parameter_number_;
}
| 697 | |
// Returns the dimension sizes or numbers associated with this instruction.
// The indexed overload returns a single entry of that list.
//
// Precondition: opcode() is one of: concatenate, reduce, broadcast, reshape,
// and reverse.
const std::vector<int64>& dimensions() const;
int64 dimensions(int64 index) const;

// Accessor for the dimension along which a concatenate HLO is performed.
// Precondition: opcode() == HloOpcode::kConcatenate
int64 concatenate_dimension() const;

// Returns the tuple index associated with this instruction.
//
// Precondition: opcode() == HloOpcode::kGetTupleElement
int64 tuple_index() const;
| 713 | |
| 714 | // Returns the first non-GetTupleElement ancestor instruction of 'hlo'. |
| 715 | // If the first non-GTE ancestor is tuple-shaped, populates 'index' with the |
| 716 | // (possibly nested) tuple indices used on the path from ancestor to 'hlo'. |
| 717 | std::pair<const HloInstruction*, ShapeIndex> LatestNonGteAncestorAndIndex() |
| 718 | const; |
| 719 | |
| 720 | std::pair<HloInstruction*, ShapeIndex> LatestNonGteAncestorAndIndex() { |
| 721 | auto rv = |
| 722 | const_cast<const HloInstruction*>(this)->LatestNonGteAncestorAndIndex(); |
| 723 | return {const_cast<HloInstruction*>(rv.first), rv.second}; |
| 724 | } |
| 725 | |
| 726 | // Same as LatestNonGteAncestorAndIndex, but just returns the HloInstruction. |
| 727 | const HloInstruction* LatestNonGteAncestor() const; |
| 728 | |
| 729 | HloInstruction* LatestNonGteAncestor() { |
| 730 | return const_cast<HloInstruction*>( |
| 731 | const_cast<const HloInstruction*>(this)->LatestNonGteAncestor()); |
| 732 | } |
| 733 | |
// Gets/sets the to_apply HloComputation for Call, Map, Reduce, etc.
// The setter should only be called by HloModule or HloComputation methods.
//
// Precondition: The instruction has a valid to_apply_ field.
HloComputation* to_apply() const;
void set_to_apply(HloComputation* to_apply);

// Returns the custom_call_target for CustomCall.
// Precondition: opcode() == HloOpcode::kCustomCall
const string& custom_call_target() const;

// Returns the config for the Outfeed instruction.
// Precondition: opcode() == HloOpcode::kOutfeed
const string& outfeed_config() const;

// Returns the shape for the Outfeed instruction.
// Precondition: opcode() == HloOpcode::kOutfeed
const Shape& outfeed_shape() const;

// Gets/sets the while_condition or while_body HloComputation for While. The
// setters should only be called by HloModule or HloComputation methods.
//
// Precondition: The instruction is a While instruction.
HloComputation* while_condition() const;
HloComputation* while_body() const;
void set_while_condition(HloComputation* while_condition);
void set_while_body(HloComputation* while_body);

// Gets/sets the select or scatter HloComputation for SelectAndScatter. The
// setters should only be called by HloModule or HloComputation methods.
//
// Precondition: opcode() == HloOpcode::kSelectAndScatter.
HloComputation* select() const;
HloComputation* scatter() const;
void set_select(HloComputation* select);
void set_scatter(HloComputation* scatter);

// Gets/sets the true and false HloComputation for Conditional. The setters
// should only be called by HloModule or HloComputation methods.
//
// Precondition: The instruction is a Conditional instruction.
HloComputation* true_computation() const;
HloComputation* false_computation() const;
void set_true_computation(HloComputation* true_computation);
void set_false_computation(HloComputation* false_computation);
| 779 | |
// Returns a string for the signature of this instruction if considered as a
// function, e.g. the signature of an F32 add is (F32, F32) -> F32.
string SignatureString() const;

// Returns a debugging string that represents this instruction.
//
// (We express the default options using an overload rather than a default
// param because gdb ignores default params, but does resolve overloads.)
//
// TODO(b/73348663): Make ToString() adaptive to the size of the string by
// default, backing off on providing full information for very large strings,
// or provide a different name for a ToString-like function that does that.
string ToString() const { return ToString(HloPrintOptions()); }
string ToString(const HloPrintOptions& options) const;
| 794 | |
| 795 | // Components of the ToString() representation: |
| 796 | |
| 797 | // Returns a string representation of the operand list. |
| 798 | string OperandsToString(const HloPrintOptions& options) const; |
| 799 | |
| 800 | // Returns string representation of op-specific attributes. |
| 801 | std::vector<string> ( |
| 802 | const HloPrintOptions& options) const; |
| 803 | |
// As ToString, but returns a shorter string.
string ToShortString() const;

// Returns a serialized representation of this instruction.
HloInstructionProto ToProto() const;

// Returns a category for the HLO. This could be something like "convolution"
// or "elementwise".
string ToCategory() const;

// Returns a logging instruction, if the output of this instruction is logged.
//
// Postcondition: retval == nullptr || retval->opcode() == HloOpcode::kTrace
HloInstruction* tracing() const;
void set_tracing(HloInstruction* trace_instruction);

// Returns the channel id associated with the instruction. The id is
// shared between each Send/Recv pair and is globally unique to identify each
// channel.
//
// Precondition: opcode() == HloOpcode::kSend or HloOpcode::kRecv
// NOTE(review): the precondition is not CHECK-enforced here; any opcode
// returns the raw channel_id_ field.
int64 channel_id() const { return channel_id_; }

// Returns the channel name associated with the instruction. The name is
// used to identify host Send/Recv operations. Returned by value (a copy).
//
// Precondition: opcode() == HloOpcode::kHostCompute (not CHECK-enforced)
string channel_name() const { return channel_name_; }
| 832 | |
// Returns the feature_index field associated with the instruction. The index
// is the index of the feature dimension.
//
// Precondition: opcode() is one of kBatchNormTraining, kBatchNormInference,
// or kBatchNormGrad (not CHECK-enforced here).
int64 feature_index() const { return feature_index_; }

// Returns the epsilon value associated with the instruction. This is a small
// number added to the variance to avoid a divide-by-zero error.
//
// Precondition: opcode() is one of kBatchNormTraining, kBatchNormInference,
// or kBatchNormGrad (not CHECK-enforced here).
float epsilon() const { return epsilon_; }

// Gets/sets the infeed configuration string. The infeed configuration
// includes any metadata needed for the backend compiler (e.g., infeed buffer
// address) and is target-dependent. The getter returns a copy.
string infeed_config() const { return infeed_config_; }
void set_infeed_config(const string& config) { infeed_config_ = config; }

// Returns a tag to be used in tracing.
//
// Precondition: opcode() == HloOpcode::kTrace
string TracingTag() const;
| 857 | |
// Returns whether the instruction is a constant.
bool IsConstant() const;

// Returns true if this instruction is fused, i.e. contained within a fusion
// instruction.
bool IsFused() const;

// Returns the computation holding the fused instructions of this fusion
// instruction.
//
// Precondition: opcode() == HloOpcode::kFusion
HloComputation* fused_instructions_computation() const;

// Returns true if this instruction can be legally fused into a fusion
// instruction.
bool IsFusable() const;

// Returns the root instruction of the fused expression contained within this
// fusion instruction.
//
// Precondition: opcode() == HloOpcode::kFusion
HloInstruction* fused_expression_root() const;

// Returns the list of fused instructions inside this fusion instruction. The
// returned type is a range of HloInstruction*s (the UnwrappingIterator hides
// the std::unique_ptr ownership layer of the underlying list).
//
// Precondition: opcode() == HloOpcode::kFusion
const tensorflow::gtl::iterator_range<UnwrappingIterator<
    std::list<std::unique_ptr<HloInstruction>>::const_iterator>>
fused_instructions() const;

const tensorflow::gtl::iterator_range<
    UnwrappingIterator<std::list<std::unique_ptr<HloInstruction>>::iterator>>
fused_instructions();

// Gets the number of instructions inside this fusion instruction.
//
// Precondition: opcode() == HloOpcode::kFusion
int64 fused_instruction_count() const;

// Returns the fused parameter instruction in this fusion instruction
// corresponding to the given parameter number.
//
// Precondition: opcode() == HloOpcode::kFusion
HloInstruction* fused_parameter(int64 parameter_number) const;

// Returns the vector of fused parameters inside this fusion instruction.
//
// Precondition: opcode() == HloOpcode::kFusion
const std::vector<HloInstruction*>& fused_parameters() const;
| 907 | |
| 908 | // Returns true if this instruction is a fusion instruction that generates |
| 909 | // multiple outputs. |
| 910 | const bool IsMultiOutputFusion() const { |
| 911 | return opcode() == HloOpcode::kFusion && |
| 912 | fused_expression_root()->opcode() == HloOpcode::kTuple; |
| 913 | } |
| 914 | |
| 915 | FusionKind fusion_kind() const { |
| 916 | CHECK_EQ(HloOpcode::kFusion, opcode_); |
| 917 | return fusion_kind_; |
| 918 | } |
| 919 | |
| 920 | void set_fusion_kind(FusionKind kind) { |
| 921 | CHECK_EQ(HloOpcode::kFusion, opcode_); |
| 922 | fusion_kind_ = kind; |
| 923 | } |
| 924 | |
| 925 | // Returns the sharding applied to this operator. |
| 926 | // REQUIRES: has_sharding() is true. |
| 927 | const HloSharding& sharding() const { |
| 928 | CHECK(has_sharding()); |
| 929 | return *sharding_; |
| 930 | } |
| 931 | // Returns the sharding applied to this operator, or default_ if none exists. |
| 932 | const HloSharding& sharding_or_default(const HloSharding& default_) const { |
| 933 | return sharding_ ? *sharding_ : default_; |
| 934 | } |
| 935 | // Returns the sharding unique device, if any. |
| 936 | tensorflow::gtl::optional<int64> sharding_unique_device() const { |
| 937 | if (sharding_ == nullptr || !sharding_->HasUniqueDevice()) { |
| 938 | return tensorflow::gtl::optional<int64>(); |
| 939 | } |
| 940 | return sharding_->UniqueDevice().ValueOrDie(); |
| 941 | } |
| 942 | // Sets the sharding of this operator. Should only be called by HloModule or |
| 943 | // HloComputation methods. |
| 944 | void set_sharding(const HloSharding& sharding) { |
| 945 | sharding_ = MakeUnique<HloSharding>(sharding); |
| 946 | } |
| 947 | // Remove any sharding from this operator. |
| 948 | void clear_sharding() { sharding_ = nullptr; } |
| 949 | // Return true if this operator has a sharding assigned. |
| 950 | bool has_sharding() const { return sharding_ != nullptr; } |
| 951 | |
// Adds a new operand to this fusion instruction.
// NOTE(review): the returned HloInstruction* is presumably the fused
// parameter created for `new_operand`; confirm against the implementation.
HloInstruction* AddFusionOperand(HloInstruction* new_operand);

// Merges the fused instructions from 'instruction_to_merge' into the
// fused instruction set of 'this', updating operands as necessary.
//
// Precondition: opcode() == HloOpcode::kFusion
// Precondition: 'instruction_to_merge' must be an operand of 'this'.
void MergeFusionInstruction(HloInstruction* instruction_to_merge);

// Merges the fused instructions from instruction_to_merge into the fused
// instruction set of 'this' and generates multi-output fusion instructions.
// All the users of instruction_to_merge will be redirected to 'this'
// instruction. instruction_to_merge will be removed from its parent
// computation.
//
// Precondition: opcode() == HloOpcode::kFusion
void MergeFusionInstructionIntoMultiOutput(
    HloInstruction* instruction_to_merge);
| 971 | |
// Fuses the given instruction into this fusion instruction. instruction_to_fuse
// is cloned and the clone is placed in the fusion instruction;
// instruction_to_fuse itself is unchanged. The instruction is cloned rather
// than moved to cleanly handle the case where it has a use outside the
// fusion instruction. Moving such an instruction into a fusion instruction
// would violate the single-result invariant of HLO instructions and
// significantly complicate code generation.
//
// Thin wrapper over FuseInstructionInternal using its default add_output.
//
// Precondition: this->opcode() == HloOpcode::kFusion
HloInstruction* FuseInstruction(HloInstruction* instruction_to_fuse) {
  return FuseInstructionInternal(instruction_to_fuse);
}

// Fuses the given instruction into this fusion instruction and generates a
// multi-output fusion instruction. A clone of instruction_to_fuse becomes
// part of the outputs of the fusion instruction. The users of
// instruction_to_fuse are redirected to this fusion instruction, and
// instruction_to_fuse is removed from its parent computation.
//
// Precondition: this->opcode() == HloOpcode::kFusion
HloInstruction* FuseInstructionIntoMultiOutput(
    HloInstruction* instruction_to_fuse) {
  return FuseInstructionInternal(instruction_to_fuse, /*add_output=*/true);
}
| 996 | |
// Returns the start index in the given dimension for a slice node.
//
// Precondition: opcode() == HloOpcode::kSlice (CHECK-enforced by the
// per-dimension overload only).
// NOTE(review): unlike slice_limits(), the vector overload below does not
// CHECK the opcode.
int64 slice_starts(int64 dimension) const {
  CHECK_EQ(HloOpcode::kSlice, opcode_);
  return slice_starts_[dimension];
}
const std::vector<int64>& slice_starts() const { return slice_starts_; }

// Returns the (exclusive) limit index in the given dimension for a slice
// node.
//
// Precondition: opcode() == HloOpcode::kSlice (CHECK-enforced by both
// overloads).
int64 slice_limits(int64 dimension) const {
  CHECK_EQ(HloOpcode::kSlice, opcode_);
  return slice_limits_[dimension];
}
const std::vector<int64>& slice_limits() const {
  CHECK_EQ(HloOpcode::kSlice, opcode_);
  return slice_limits_;
}

// Returns the stride in the given dimension for a slice node.
//
// Precondition: opcode() == HloOpcode::kSlice (CHECK-enforced by the
// per-dimension overload only).
// NOTE(review): the vector overload below does not CHECK the opcode.
int64 slice_strides(int64 dimension) const {
  CHECK_EQ(HloOpcode::kSlice, opcode_);
  return slice_strides_[dimension];
}
const std::vector<int64>& slice_strides() const { return slice_strides_; }
| 1027 | |
| 1028 | // Returns the flag that describes whether a slice must be lowered into an |
| 1029 | // offset into the original operand. |
| 1030 | bool IsInPlaceSlice() const { return is_in_place_slice_; } |
| 1031 | |
| 1032 | // Sets and returns the flag that describes whether a slice must be lowered |
| 1033 | // into an offset into the original operand. |
| 1034 | bool SetIsInPlaceSlice(bool value) { |
| 1035 | is_in_place_slice_ = value; |
| 1036 | return value; |
| 1037 | } |
| 1038 | |
// Returns the size of the slice in the given dimension for a dynamic
// slice node. (Note the asymmetric naming: the per-dimension accessor is
// slice_sizes(), the whole-vector accessor is dynamic_slice_sizes().)
//
// Precondition: opcode() == HloOpcode::kDynamicSlice (CHECK-enforced)
int64 slice_sizes(int64 dimension) const {
  CHECK_EQ(HloOpcode::kDynamicSlice, opcode_);
  return dynamic_slice_sizes_[dimension];
}
const std::vector<int64>& dynamic_slice_sizes() const {
  CHECK_EQ(HloOpcode::kDynamicSlice, opcode_);
  return dynamic_slice_sizes_;
}

// Returns the number of exponent bits for a reduce-precision node.
//
// Precondition: opcode() == HloOpcode::kReducePrecision (CHECK-enforced)
int32 exponent_bits() const {
  CHECK_EQ(HloOpcode::kReducePrecision, opcode_);
  return exponent_bits_;
}

// Returns the number of mantissa bits for a reduce-precision node.
//
// Precondition: opcode() == HloOpcode::kReducePrecision (CHECK-enforced)
int32 mantissa_bits() const {
  CHECK_EQ(HloOpcode::kReducePrecision, opcode_);
  return mantissa_bits_;
}
| 1067 | |
// Returns data on the window in a windowed operation such as
// convolution. CHECK-fails if no window has been set.
const Window& window() const {
  CHECK(window_ != nullptr);
  return *window_;
}

// Sets the window data in a windowed operation such as convolution. A copy
// of `window` is stored.
void set_window(const Window& window) {
  window_ = MakeUnique<Window>(window);
}

// Returns the padding configuration for a pad node. CHECK-fails if no
// padding configuration is present.
//
// Precondition: opcode() == HloOpcode::kPad
const PaddingConfig& padding_config() const {
  CHECK(padding_config_ != nullptr);
  return *padding_config_;
}

// Returns data on the dimension numbers used for a convolution operation,
// which may be a kConvolution instruction or a kCustomCall that implements a
// convolution. CHECK-fails if no dimension numbers are present.
const ConvolutionDimensionNumbers& convolution_dimension_numbers() const {
  CHECK(convolution_dimension_numbers_ != nullptr);
  return *convolution_dimension_numbers_;
}

// Sets the convolution dimension numbers on this instruction (a copy of
// `dnums` is stored). In general you shouldn't need to call this; instead,
// specify the convolution dimension numbers when you create the instruction.
void set_convolution_dimension_numbers(
    const ConvolutionDimensionNumbers& dnums) {
  convolution_dimension_numbers_ =
      MakeUnique<ConvolutionDimensionNumbers>(dnums);
}
| 1104 | |
// Returns the FFT type of an FFT node.
// Precondition: opcode() == HloOpcode::kFft (CHECK-enforced)
FftType fft_type() const {
  CHECK_EQ(HloOpcode::kFft, opcode_);
  return fft_type_;
}

// Returns the fft_length list of an FFT node.
// Precondition: opcode() == HloOpcode::kFft (CHECK-enforced)
const std::vector<int64>& fft_length() const {
  CHECK_EQ(HloOpcode::kFft, opcode_);
  return fft_length_;
}

// Returns the dump string of the convolution dimension numbers.
string ConvolutionDimensionNumbersToString() const;

// Returns data on the dimension numbers used for a dot operation.
// CHECK-fails if no dot dimension numbers are present.
const DotDimensionNumbers& dot_dimension_numbers() const {
  CHECK(dot_dimension_numbers_ != nullptr);
  return *dot_dimension_numbers_;
}

// Returns the dump string of the dot dimension numbers.
string DotDimensionNumbersToString() const;

// Returns the dimension numbers of a gather operation.
// CHECK-fails if no gather dimension numbers are present.
const GatherDimensionNumbers& gather_dimension_numbers() const {
  CHECK(gather_dimension_numbers_ != nullptr);
  return *gather_dimension_numbers_;
}

// Returns a non-owning view of the window bounds of a gather operation.
// Precondition: opcode() == HloOpcode::kGather (CHECK-enforced)
tensorflow::gtl::ArraySlice<int64> gather_window_bounds() const {
  CHECK_EQ(opcode(), HloOpcode::kGather);
  return gather_window_bounds_;
}

// Returns the dump string of the gather dimension numbers.
string GatherDimensionNumbersToString() const;

// Returns the random distribution for this rng node.
//
// Precondition: opcode() == HloOpcode::kRng
RandomDistribution random_distribution() const;
| 1144 | |
| 1145 | // Clones the HLO instruction. The clone will have the same opcode, shape, and |
| 1146 | // operands. After creation the clone has no uses. "this" (the instruction |
| 1147 | // cloned from) is not changed. Suffix is the string to append to the name of |
| 1148 | // the instruction to form the name of the cloned instruction. |
| 1149 | // If the module pointer is not nullptr, it will be the module where |
| 1150 | // the cloned computations will be added to (in order to support deep |
| 1151 | // cloning). |
| 1152 | std::unique_ptr<HloInstruction> Clone(const string& suffix = "clone" , |
| 1153 | HloModule* module = nullptr) const; |
| 1154 | |
| 1155 | // Clones the HLO instruction as above but with new shape and operands. |
| 1156 | // If the module pointer is not nullptr, it will be the module where |
| 1157 | // the cloned computations will be added to (in order to support deep |
| 1158 | // cloning). |
| 1159 | std::unique_ptr<HloInstruction> CloneWithNewOperands( |
| 1160 | const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands, |
| 1161 | HloModule* module = nullptr) const; |
| 1162 | |
| 1163 | // Returns the computations this instruction directly calls (if any). |
| 1164 | const std::vector<HloComputation*>& called_computations() const { |
| 1165 | return called_computations_; |
| 1166 | } |
| 1167 | |
| 1168 | // Replaces all called computations based on a map function. This is needed |
| 1169 | // when we clone hlo_computations and want to let the instructions to point |
| 1170 | // to the newly cloned nodes. |
| 1171 | void ReplaceCalledComputations( |
| 1172 | std::function<HloComputation*(HloComputation*)> map_function) { |
| 1173 | for (int64 i = 0; i < called_computations_.size(); ++i) { |
| 1174 | called_computations_[i] = map_function(called_computations_[i]); |
| 1175 | } |
| 1176 | } |
| 1177 | |
| 1178 | // Clears out the called computations. |
| 1179 | // |
| 1180 | // This is, in particular, necessary when inlining function bodies into their |
| 1181 | // caller. If there were side-effecting operations in the called computations, |
| 1182 | // the call itself is considered side-effecting and thus cannot be removed. By |
| 1183 | // clearing out the computations, we reflect the fact that all side-effecting |
| 1184 | // properties have been reflected in the caller, and make the call HLO |
| 1185 | // removable. |
| 1186 | void ClearCalledComputations() { called_computations_.clear(); } |
| 1187 | |
// Returns true if this instruction performs an elementwise operation on
// `operand_idx`-th operand. An instruction is elementwise on an operand iff,
// after performing necessary implicit broadcast
// (cs/IrArray::EmitArrayElementAddress), to compute the output at index
// {i_0,i_1,...,i_n}, the only element required from the operand (if any) is
// the element at {i_0,i_1,...,i_n}.
//
// Note on performance: when this instruction is kFusion, this method, in the
// worst case, scans all fused instructions. We could speed this up by
// caching.
bool IsElementwiseOnOperand(int64 operand_idx) const;

// Returns true if this instruction is elementwise on all its operands.
bool IsElementwise() const;

// Returns true if this elementwise instruction implicitly broadcasts operand
// `operand_idx`.
//
// Precondition: this instruction should be an elementwise operation.
bool ImplicitlyBroadcastsOperand(int64 operand_idx) const;

// Returns true if this instruction is binary and elementwise.
bool IsElementwiseBinary() const;

// Returns whether this instruction may reuse elements of its `i`th operand,
// i.e. whether OperandElementUse(i) classifies the use as UseKind::kReuse.
bool ReusesOperandElements(int64 i) const {
  return OperandElementUse(i) == UseKind::kReuse;
}

// Returns the indices at which the given operand appears in the operand list
// of this instruction. Note that an instruction can use the same operand
// multiple times.
std::vector<int64> OperandIndices(const HloInstruction* operand) const;

// Convenience helper for ShapeUtil::InsertedOrDeleted1SizedDimensions. If
// this reshape merely inserts or deletes 1-sized dimensions, return the input
// indices of the deleted dimensions and the output indices of the inserted
// dimensions.
// NOTE(review): the leading bool presumably reports whether the reshape is of
// that form; confirm against ShapeUtil::InsertedOrDeleted1SizedDimensions.
//
// Precondition: this op must be a reshape.
std::tuple<bool, std::vector<int64>, std::vector<int64>>
ReshapeMerelyInsertsOrDeletes1SizedDimensions() const;
| 1230 | |
// Gets/sets the string identifier for this instruction.
const string& name() const { return name_; }
void set_name(tensorflow::StringPiece name) { name_ = name.ToString(); }

// Uses the given NameUniquer to select a unique name for the instruction
// based on the instruction's existing name.
void UniquifyName(NameUniquer* name_uniquer);

// Sets the unique id for this instruction to "id". CHECK-fails if an id has
// already been assigned (unique_id_ != -1) or if "id" is negative.
void SetUniqueId(int id) {
  CHECK_EQ(unique_id_, -1);  // Should not be assigned already
  CHECK_GE(id, 0);
  unique_id_ = id;
}

// Returns the unique ID assigned to this node via SetUniqueId (or -1
// if no id has been assigned yet).
int unique_id() const { return unique_id_; }

// Sets/gets the debug metadata for this instruction (the setter stores a
// copy).
void set_metadata(const OpMetadata& metadata) { metadata_ = metadata; }
const OpMetadata& metadata() const { return metadata_; }

// Sets/gets the computation containing this instruction. set_parent should
// only be called by HloComputation methods which add/remove instructions to
// computations.
void set_parent(HloComputation* computation) { parent_ = computation; }
const HloComputation* parent() const { return parent_; }
HloComputation* parent() { return parent_; }

// Returns the module for this instruction.
HloModule* GetModule() const;
| 1263 | |
// Returns whether we could assign input and output layouts to this
// instruction to make it a bitcast.
bool CouldBeBitcast() const;

// Gets/sets the number of partitions per outer dimension (in order, starting
// with the outer-most dimension first). Currently used by the parallel cpu
// backend to partition HLOs into parallel tasks.
// TODO(b/62783254) Replace these methods with a more general way to
// annotate HLOs with backend-specific information.
const std::vector<int64>& outer_dimension_partitions() const {
  return outer_dimension_partitions_;
}
void set_outer_dimension_partitions(
    const std::vector<int64>& outer_dimension_partitions);

// Changes the layout of a Constant HLO instruction to match new_layout. For
// tuple-shaped constants, shape_index is the path to the internal array
// subshape whose layout needs to be changed (defaults to the top level).
void RelayoutConstant(const Layout& new_layout,
                      const ShapeIndex& shape_index = {});
| 1284 | |
private:
// Classifies how this instruction uses the elements of one of its operands.
// kReuse is the classification ReusesOperandElements() tests for; the exact
// meaning of the other enumerators is defined by OperandElementUse().
enum class UseKind { kNoUse, kReuse, kUsePermutingElements, kUse };

// Helper class for computing OperandElementUse for kFusion.
class FusionReusesParamElements;
| 1290 | |
// See comments on Identical(). Identical() calls this only after the shapes,
// operand counts and individual operands have already compared equal.
// eq_shapes() is used to check shapes for equality, and would normally be
// expected to be ShapeUtil::Equal or ShapeUtil::Compatible, depending on
// whether we want a layout-sensitive check or not.
bool IdenticalSlowPath(
    const HloInstruction& other,
    const std::function<bool(const HloComputation*, const HloComputation*)>&
        eq_computations,
    const std::function<bool(const Shape&, const Shape&)>& eq_shapes) const;
| 1300 | |
| 1301 | // Creates an n-ary elementwise operation. |
| 1302 | static std::unique_ptr<HloInstruction> CreateNary( |
| 1303 | const Shape& shape, HloOpcode opcode, |
| 1304 | tensorflow::gtl::ArraySlice<HloInstruction*> operands); |
| 1305 | |
| 1306 | // Appends operand to the list of operands and adds this instruction as a user |
| 1307 | // of the operand. |
| 1308 | void AppendOperand(HloInstruction* operand); |
| 1309 | |
| 1310 | // Adds a user for this instruction. |
| 1311 | void AddUser(HloInstruction* user); |
| 1312 | |
| 1313 | // Removes a user for this instruction. |
| 1314 | void RemoveUser(HloInstruction* user); |
| 1315 | |
| 1316 | // Internal constructor for a given opcode/shape, other fields must be filled |
| 1317 | // by factory methods. |
| 1318 | HloInstruction(HloOpcode opcode, const Shape& shape); |
| 1319 | |
| 1320 | // Fuses the given instruction into this fusion instruction. When add_output |
| 1321 | // is false (which is the default), instruction_to_fuse is cloned and the |
| 1322 | // clone is placed in the fusion instruction. instruction_to_fuse is |
| 1323 | // unchanged. |
| 1324 | // |
| 1325 | // When add_output is true, a clone of the instruction_to_fuse will be part |
| 1326 | // of the output of fusion instructions. The users of instruction_to_fuse |
| 1327 | // will be redirected to this fusion instructions. instruction_to_fuse will |
| 1328 | // be removed from its parent computation. |
| 1329 | // |
| 1330 | // Precondition: this->opcode() == HloOpcode::kFusion |
| 1331 | HloInstruction* FuseInstructionInternal(HloInstruction* instruction_to_fuse, |
| 1332 | bool add_output = false); |
| 1333 | |
| 1334 | // Clones the given instruction_to_fuse and insert the clone into this fusion |
| 1335 | // instruction. If add_output is true, a clone of instruction_to_fuse will |
| 1336 | // be in the output of the this fusion instruction (part of the tuple of the |
| 1337 | // fusion root). |
| 1338 | // |
| 1339 | // Precondition: opcode() == HloOpcode::kFusion |
| 1340 | HloInstruction* CloneAndFuseInternal(HloInstruction* instruction_to_fuse, |
| 1341 | bool add_output = false); |
| 1342 | |
| 1343 | // Clones a fusion instruction with a new shape and operands. |
| 1344 | std::unique_ptr<HloInstruction> CloneFusionWithNewOperands( |
| 1345 | const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands, |
| 1346 | HloModule* module = nullptr) const; |
| 1347 | |
| 1348 | // Returns true if this instruction can legally have the dimensions field |
| 1349 | // set. Used for checking precondition of dimensions field accessors. |
| 1350 | bool CanHaveDimensionsField() const; |
| 1351 | |
| 1352 | // Returns how this instruction uses elements of its `i`th operand. |
| 1353 | UseKind OperandElementUse(int64 i) const; |
| 1354 | |
// NOTE(review): unique_id_ has no in-class initializer. It is presumably set
// by code outside this header (constructor / owning module); confirm.
int unique_id_; // Unique to this HloInstruction within a HloModule

// Opcode for this instruction.
HloOpcode opcode_;

// Instruction operands.
InstructionVector operands_;

// The control predecessors of this instruction (kept in a vector).
std::vector<HloInstruction*> control_predecessors_;

// The users of this instruction. Users are HLOs where this instruction is an
// operand. The vector users_ and the set user_set_ contain identical
// members. The set enables fast membership testing and the vector enables
// fast, stable iteration.
std::vector<HloInstruction*> users_;
std::unordered_set<const HloInstruction*> user_set_;

// The control successors of this instruction (kept in a vector).
std::vector<HloInstruction*> control_successors_;

// The computation in which this instruction is contained; nullptr until
// set_parent() is called.
HloComputation* parent_ = nullptr;

// Shape of the outfeed request.
Shape outfeed_shape_;

// Result shape of this instruction.
Shape shape_;
| 1384 | |
// Literal, only present for kConstant.
std::unique_ptr<Literal> literal_;

// Constant index, only present for kGetTupleElement. -1 when unset.
int64 tuple_index_ = -1;

// Dimensions present for some operations that require reshaping or
// broadcasting, including Reshape, Reduce, ReduceWindow, and Reverse.
std::vector<int64> dimensions_;

// Describes the window in a windowed operation such as convolution.
std::unique_ptr<Window> window_;

// Describes the dimension numbers used for a convolution.
std::unique_ptr<ConvolutionDimensionNumbers> convolution_dimension_numbers_;

// Describes the dimension numbers used for a dot.
std::unique_ptr<DotDimensionNumbers> dot_dimension_numbers_;

// Dimension numbers and window bounds for a gather operation. NOTE(review):
// presumably only populated for kGather; confirm.
std::unique_ptr<GatherDimensionNumbers> gather_dimension_numbers_;
std::vector<int64> gather_window_bounds_;

// Describes FFT type for an FFT instruction.
FftType fft_type_ = FftType::FFT;

// Indicates the FFT length for an FFT instruction.
std::vector<int64> fft_length_;

// Describes the [begin, end) index range, and the strides, for a slice.
std::vector<int64> slice_starts_;
std::vector<int64> slice_limits_;
std::vector<int64> slice_strides_;

// Describes whether the slice can be lowered to an offset into the operand.
bool is_in_place_slice_ = false;

// The bit sizes for a reduce-precision operation.
int32 exponent_bits_ = 0;
int32 mantissa_bits_ = 0;

// Describes the [start, start + size) range size for a dynamic slice
// ('start' is specified dynamically in the second operand of the operation).
std::vector<int64> dynamic_slice_sizes_;

// The padding configuration that describes the edge padding and interior
// padding of this pad instruction. Only set for pad instructions.
std::unique_ptr<PaddingConfig> padding_config_;
| 1432 | |
| 1433 | // The type of the fusion. Used by kFusion only. |
| 1434 | FusionKind fusion_kind_; |
| 1435 | |
| 1436 | // The sharding, if one exists. |
| 1437 | std::unique_ptr<HloSharding> sharding_; |
| 1438 | |
| 1439 | // For parameter instructions this field holds the parameter number. |
| 1440 | int64 parameter_number_ = 0; |
| 1441 | |
| 1442 | // Name of a global symbol to call, only present for kCustomCall. |
| 1443 | string custom_call_target_; |
| 1444 | |
| 1445 | // Name to use for host send/recv channels, only present for kHostCompute. |
| 1446 | string channel_name_; |
| 1447 | |
| 1448 | // Estimate of the duration of a host computation in nanoseconds. |
| 1449 | int64 cost_estimate_ns_; |
| 1450 | |
| 1451 | // Computations called by this instruction. |
| 1452 | std::vector<HloComputation*> called_computations_; |
| 1453 | |
// Indices of computations in called_computations_ for instructions which call
// multiple computations. The same index values are deliberately reused across
// opcodes; which enumerators apply depends on the instruction's opcode.
enum {
// kWhile computations.
kBodyComputationIndex = 0,
kConditionComputationIndex = 1,

// kSelectAndScatter computations.
kSelectComputationIndex = 0,
kScatterComputationIndex = 1,

// kConditional computations.
kTrueComputationIndex = 0,
kFalseComputationIndex = 1,
};
| 1469 | |
| 1470 | // Outfeed configuration information, only present for kOutfeed. |
| 1471 | string outfeed_config_; |
| 1472 | |
| 1473 | // A trace instruction that consumes this instruction. |
| 1474 | // |
| 1475 | // Invariant: if trace_instruction_ != nullptr, trace_instruction has this as |
| 1476 | // an operand. |
| 1477 | HloInstruction* trace_instruction_ = nullptr; |
| 1478 | |
| 1479 | // The distribution requested for random number generation. |
| 1480 | // Only present for kRng. |
| 1481 | RandomDistribution distribution_; |
| 1482 | |
| 1483 | // A small float number added to the variance to avoid divide-by-zero error. |
| 1484 | // Only present for kBatchNormTraining. |
| 1485 | float epsilon_ = 0.0f; |
| 1486 | |
| 1487 | // An integer value representing the index of the feature dimension. |
| 1488 | // Only present for kBatchNormTraining. |
| 1489 | int64 feature_index_ = -1; |
| 1490 | |
| 1491 | // Represents a unique identifier for each Send/Recv instruction pair. |
| 1492 | // Only present for kSend or kRecv. |
| 1493 | int64 channel_id_ = -1; |
| 1494 | |
| 1495 | // The string representation of the infeed configuration. |
| 1496 | string infeed_config_; |
| 1497 | |
| 1498 | // String identifier for instruction. |
| 1499 | string name_; |
| 1500 | |
| 1501 | // Metadata for debugging. |
| 1502 | OpMetadata metadata_; |
| 1503 | |
| 1504 | // The number of partitions per outer dimension (listed in order from |
| 1505 | // outer-most dimension first). |
| 1506 | std::vector<int64> outer_dimension_partitions_; |
| 1507 | |
// HloInstruction is neither copy-constructible nor copy-assignable; use the
// Clone* methods to duplicate an instruction.
TF_DISALLOW_COPY_AND_ASSIGN(HloInstruction);
};
| 1510 | |
// Returns the textual name of a fusion kind.
string ToString(HloInstruction::FusionKind kind);
// Parses a fusion-kind name into a FusionKind. The StatusOr return type means
// parsing can fail. NOTE(review): presumably the inverse of
// ToString(FusionKind); confirm the accepted spellings in the implementation.
StatusOr<HloInstruction::FusionKind> StringToFusionKind(
const string& kind_name);
| 1514 | |
// Custom (de)stringification functions for protos that live inside
// HloInstruction. StringToRandomDistribution returns a StatusOr because
// parsing can fail. NOTE(review): it presumably accepts the names produced by
// RandomDistributionToString; confirm.
string PaddingConfigToString(const PaddingConfig& padding);
string OpMetadataToString(const OpMetadata& metadata);
string RandomDistributionToString(const RandomDistribution& distribution);
StatusOr<RandomDistribution> StringToRandomDistribution(const string& name);
| 1521 | |
// Writes a FusionKind to `os` and returns `os`. NOTE(review): presumably emits
// the same text as ToString(kind); confirm in the implementation.
std::ostream& operator<<(std::ostream& os, HloInstruction::FusionKind kind);
| 1523 | |
| 1524 | // Map classes that guarantee a deterministic iteration order when the key is |
| 1525 | // an HloInstruction* or a const HloInstruction*. |
| 1526 | // To make the iteration order over the map deterministic, the comparator |
| 1527 | // should not be using the pointer values, but rather an intrinsic property of |
| 1528 | // the hlo. |
| 1529 | // |
| 1530 | // Note that this cannot be used for HLO instructions across multiple modules |
| 1531 | // since the id of HLO instructions are only unique within each HLO module. |
| 1532 | struct HloPtrComparator { |
| 1533 | bool operator()(const HloInstruction* const& lhs, |
| 1534 | const HloInstruction* const& rhs) const { |
| 1535 | return lhs->unique_id() < rhs->unique_id(); |
| 1536 | } |
| 1537 | }; |
| 1538 | |
| 1539 | template <typename ValueT> |
| 1540 | using HloInstructionMap = std::map<HloInstruction*, ValueT, HloPtrComparator>; |
| 1541 | |
| 1542 | template <typename ValueT> |
| 1543 | using ConstHloInstructionMap = |
| 1544 | std::map<const HloInstruction*, ValueT, HloPtrComparator>; |
| 1545 | |
| 1546 | using HloInstructionSet = std::set<HloInstruction*, HloPtrComparator>; |
| 1547 | using ConstHloInstructionSet = |
| 1548 | std::set<const HloInstruction*, HloPtrComparator>; |
| 1549 | |
| 1550 | } // namespace xla |
| 1551 | |
| 1552 | #endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INSTRUCTION_H_ |
| 1553 | |