tf_1.8_xla_doc
ir_emitter.h
Go to the documentation of this file.
1 
3 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
4 
5 Licensed under the Apache License, Version 2.0 (the "License");
6 you may not use this file except in compliance with the License.
7 You may obtain a copy of the License at
8 
9  http://www.apache.org/licenses/LICENSE-2.0
10 
11 Unless required by applicable law or agreed to in writing, software
12 distributed under the License is distributed on an "AS IS" BASIS,
13 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 See the License for the specific language governing permissions and
15 limitations under the License.
16 ==============================================================================*/
17 
18 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_IR_EMITTER_H_
19 #define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_IR_EMITTER_H_
20 
21 #include <stddef.h>
22 #include <map>
23 #include <memory>
24 #include <string>
25 #include <unordered_map>
26 #include <vector>
27 
28 #include "llvm/ADT/Triple.h"
29 #include "llvm/IR/Function.h"
30 #include "llvm/IR/IRBuilder.h"
31 #include "llvm/IR/Module.h"
32 #include "llvm/IR/Value.h"
33 #include "llvm/Target/TargetMachine.h"
34 #include "tensorflow/compiler/xla/service/buffer_assignment.h"
35 #include "tensorflow/compiler/xla/service/cpu/external_constant_pool.h"
36 #include "tensorflow/compiler/xla/service/cpu/ir_function.h"
37 #include "tensorflow/compiler/xla/service/cpu/target_machine_features.h"
38 #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
41 #include "tensorflow/compiler/xla/service/hlo_module_config.h"
42 #include "tensorflow/compiler/xla/service/llvm_ir/alias_analysis.h"
43 #include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"
44 #include "tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h"
45 #include "tensorflow/compiler/xla/service/name_uniquer.h"
46 #include "tensorflow/compiler/xla/statusor.h"
47 #include "tensorflow/compiler/xla/types.h"
48 #include "tensorflow/compiler/xla/xla_data.pb.h"
49 #include "tensorflow/core/lib/core/stringpiece.h"
50 #include "tensorflow/core/lib/gtl/array_slice.h"
51 #include "tensorflow/core/lib/gtl/flatmap.h"
52 #include "tensorflow/core/platform/macros.h"
53 #include "tensorflow/core/platform/types.h"
54 
55 namespace xla {
56 namespace cpu {
// IrEmitter is an HLO visitor (it derives from DfsHloVisitorWithDefault) that
// emits LLVM IR for HLO computations into an llvm::Module. Public entry points
// are EmitComputation and EmitScalarCall; the Handle* overrides below emit code
// for individual HLO instructions.
63 class IrEmitter : public DfsHloVisitorWithDefault {
64  public:
  // Create a new LLVM IR emitter.
81  IrEmitter(const HloModule& hlo_module, const BufferAssignment& assignment,
82  llvm::Module* llvm_module,
83  std::unordered_map<const HloInstruction*, int64>
84  instruction_to_profile_idx,
85  std::unordered_map<const HloComputation*, int64>
86  computation_to_profile_idx,
87  llvm::TargetMachine* target_machine,
88  ExternalConstantPool* external_constant_pool);
89  ~IrEmitter() override;
113  StatusOr<llvm::Function*> EmitComputation(
114  HloComputation* computation, const string& function_name_prefix,
115  bool is_top_level_computation,
116  std::vector<const HloInstruction*>* instruction_order);
117 
  // Returns a pointer to this emitter's IR builder (ir_builder_).
118  llvm::IRBuilder<>* ir_builder() { return &ir_builder_; }
119 
120  // Emits a call to `computation` with scalar arguments `arguments`.
121  StatusOr<llvm::Value*> EmitScalarCall(
122  PrimitiveType return_type, HloComputation* computation,
123  const std::vector<llvm::Value*>& arguments, tensorflow::StringPiece name);
124 
125  protected:
126  //
127  // The following methods implement the DfsHloVisitor interface.
128  //
129  // Default action which emits code for most operations. Operations which are
130  // special in some way are handled explicitly in HandleFoo methods.
131  Status DefaultAction(HloInstruction* hlo) override;
132 
133  Status HandleBitcast(HloInstruction* bitcast) override;
134  Status HandleConstant(HloInstruction* constant) override;
135  Status HandleCopy(HloInstruction* copy) override;
136  Status HandleGetTupleElement(HloInstruction* get_tuple_element) override;
137  Status HandleSelect(HloInstruction* select) override;
138  Status HandleDot(HloInstruction* dot) override;
139  Status HandleConvolution(HloInstruction* convolution) override;
140  Status HandleFft(HloInstruction* fft) override;
141  Status HandleCrossReplicaSum(HloInstruction* crs) override;
142  Status HandleInfeed(HloInstruction* infeed) override;
143  Status HandleOutfeed(HloInstruction* outfeed) override;
144  Status HandleSort(HloInstruction* sort) override;
145  Status HandleParameter(HloInstruction* parameter) override;
146  Status HandleReduce(HloInstruction* reduce) override;
147  Status HandleReduceWindow(HloInstruction* reduce_window) override;
148  Status HandleSelectAndScatter(HloInstruction* select_and_scatter) override;
149  Status HandleSend(HloInstruction* send) override;
150  Status HandleSendDone(HloInstruction* send_done) override;
151  Status HandleSlice(HloInstruction* slice) override;
152  Status HandleDynamicSlice(HloInstruction* dynamic_slice) override;
153  Status HandleDynamicUpdateSlice(
154  HloInstruction* dynamic_update_slice) override;
155  Status HandleRecv(HloInstruction* recv) override;
156  Status HandleRecvDone(HloInstruction* recv_done) override;
157  Status HandlePad(HloInstruction* pad) override;
158  Status HandleTuple(HloInstruction* tuple) override;
159  Status HandleMap(HloInstruction* map) override;
160  Status HandleFusion(HloInstruction* fusion) override;
161  Status HandleCall(HloInstruction* call) override;
162  Status HandleCustomCall(HloInstruction* custom_call) override;
163  Status HandleWhile(HloInstruction* xla_while) override;
164  Status HandleConcatenate(HloInstruction* concatenate) override;
165  Status HandleConditional(HloInstruction* conditional) override;
166  Status FinishVisit(HloInstruction* root) override;
167 
168  Status Preprocess(HloInstruction* hlo) override;
169  Status Postprocess(HloInstruction* hlo) override;
170 
171  private:
172  // Private helper to initialize an IR function for the computation.
173  void InitializeIrFunction(const string& function_name);
174 
  // Shared helper called by both GetProfileCounterFor overloads below, each
  // passing the profile-index map that matches its argument type.
175  template <typename T>
176  llvm::Value* GetProfileCounterCommon(
177  const T& hlo,
178  const std::unordered_map<const T*, int64>& profile_index_map);
179 
180  // Convenience functions to generate a GEP into the profile counter parameter
181  // which would correspond to the index for a given HLO instruction or
182  // computation.
183  llvm::Value* GetProfileCounterFor(const HloInstruction& instruction) {
184  return GetProfileCounterCommon<HloInstruction>(instruction,
185  instruction_to_profile_idx_);
186  }
187 
188  llvm::Value* GetProfileCounterFor(const HloComputation& computation) {
189  return GetProfileCounterCommon<HloComputation>(computation,
190  computation_to_profile_idx_);
191  }
192 
193  // Gets the IR Value emitted previously for the given hlo.
194  //
195  // Prefer calling GetIrArrayFor if the value you're reading is a buffer,
196  // because GetIrArrayFor annotates buffer's loads/stores with noalias
197  // metadata.
198  //
199  // Make sure to call this only when you're certain a value *was* emitted - if
200  // not found, this will log a fatal error.
201  llvm::Value* GetEmittedValueFor(const HloInstruction* hlo);
202 
203  // Gets an IrArray representing the given hlo.
204  llvm_ir::IrArray GetIrArrayFor(const HloInstruction* hlo);
205 
206  // Gets a list of IrArrays, one for each of hlo's operands.
207  std::vector<llvm_ir::IrArray> GetIrArraysForOperandsOf(
208  const HloInstruction* hlo);
209 
210  // Augments IrArray with aliasing information. Forwards directly to
  // alias_analysis_.
211  void AddAliasingInformationToIrArray(const HloInstruction& hlo,
212  llvm_ir::IrArray* array) {
213  alias_analysis_.AddAliasingInformationToIrArray(hlo, array);
214  }
215 
216  // Convenience function to get the IR type matching the given shape.
217  llvm::Type* IrShapeType(const Shape& shape);
218 
219  // Get the llvm::Value* that represents the "prof_counters" argument of the
220  // computation function being emitted by this emitter.
221  llvm::Value* GetProfileCountersArgument();
222 
223  // Get the xla::ExecutableRunOptions that represents the "run_options"
224  // argument of the computation function being emitted by this emitter.
225  llvm::Value* GetExecutableRunOptionsArgument();
226 
227  // Get the llvm::Value* that represents the "temps" argument of the
228  // computation function being emitted by this emitter.
229  llvm::Value* GetTempBuffersArgument();
230 
231  // Emits code that computes the address of the given temporary buffer to the
232  // function. target_shape is the shape of this temporary buffer.
233  // The returned Value's type is a pointer to element_type.
234  llvm::Value* EmitTempBufferPointer(const BufferAllocation::Slice& slice,
235  const Shape& target_shape);
236 
237  // Emits a function into the current module. This can be used for
238  // computations embedded inside other computations, such as the
239  // function that a map operation applies.
240  StatusOr<llvm::Function*> EmitFunction(
241  HloComputation* function, // The function to emit.
242  tensorflow::StringPiece
243  function_name_suffix); // Used for LLVM IR register names.
244 
245  // Methods that emit a function call.
246  // Parameters:
247  // function - The LLVM function to call.
248  // return_shape - The return shape of the HLO computation that was used to
249  // make the function. Not the same as the return type of the function
250  // in LLVM, since we use output parameters for the return type.
251  // element_count - number of elements to return (array form only).
252  // parameter_addresses - pointers to be passed to the function as
253  // parameters.
254  // name - used for LLVM IR register names.
255 
256  // Emits a function call, returning a scalar, often an element of a larger
257  // array. Returns a Value for the scalar element returned by the function.
258  llvm::Value* EmitElementFunctionCall(
259  llvm::Function* function, const Shape& return_shape,
260  tensorflow::gtl::ArraySlice<llvm::Value*> parameter_addresses,
261  tensorflow::StringPiece name);
262 
263  // Array function call emitter. Stores the function's result into a supplied
264  // buffer.
265  // Parameters:
266  // function - The LLVM function to call.
267  // parameter_addresses - pointers to be passed to the function as
268  // parameters.
269  // return_value_buffer - pointer to a buffer where the call result is stored.
270 
271  void EmitArrayFunctionCallInto(
272  llvm::Function* function,
273  tensorflow::gtl::ArraySlice<llvm::Value*> parameter_addresses,
274  llvm::Value* return_value_buffer, tensorflow::StringPiece name);
275 
276  // Array function call emitter. Returns a Value for the function's return
277  // value buffer address. The return value buffer is alloca'ed by this
278  // function.
279  llvm::Value* EmitArrayFunctionCall(
280  llvm::Function* function, const Shape& return_shape, int64 element_count,
281  tensorflow::gtl::ArraySlice<llvm::Value*> parameter_addresses,
282  tensorflow::StringPiece name);
283 
284  // Verifies that the element types of all of the given operand instructions
285  // match and are of one of the given supported types.
286  Status ElementTypesSameAndSupported(
287  const HloInstruction& instruction,
288  tensorflow::gtl::ArraySlice<const HloInstruction*> operands,
289  tensorflow::gtl::ArraySlice<PrimitiveType> supported_types);
290 
291  // Emit IR to perform a computation for every element in the given target op.
292  // This produces a series of nested loops (one for each dimension of the op's
293  // shape). The body of the inner-most loop is provided by the
294  // element_generator function.
295  //
296  // desc is an optional human-readable string that's added to the loop name in
297  // IR. Regardless of whether desc is provided, target_op->name() is included
298  // in the loop name.
299  //
300  // TODO(jingyue): target_op should be a `const HloInstruction*`.
301  Status EmitTargetElementLoop(
302  HloInstruction* target_op,
303  const llvm_ir::ElementGenerator& element_generator);
304  Status EmitTargetElementLoop(
305  HloInstruction* target_op, tensorflow::StringPiece desc,
306  const llvm_ir::ElementGenerator& element_generator);
307 
308  // Emits a memcpy from the source instruction's result value to the
309  // destination's. Both source and destination must have an entry in the
310  // emitted_value_ table.
311  Status EmitMemcpy(const HloInstruction& source,
312  const HloInstruction& destination);
313 
314  // Emits IR to compute the target address of the buffer for the given op.
315  // After calling this function, you can get a pointer to this buffer by
316  // calling GetIrArrayFor or GetEmittedValueFor.
317  Status EmitTargetAddressForOp(const HloInstruction* op);
318 
319  // Structurizes "array_elements" into an MD array that represents "shape".
320  // This is a recursive function, and "dimension_index" indicates the index of
321  // the current dimension that the function is considering (0 means the
322  // most-minor dimension).
323  llvm::Constant* CreateInitializerForConstantArray(
324  const std::vector<llvm::Constant*>& array_elements, const Shape& shape,
325  int64 dimension_index);
326 
327  // Tries to codegen a reduction operation using vectorized instructions.
328  // Returns true if successful, and false on failure. On failure, sets
329  // "failure_reason" to a string describing why it could not vectorize the
330  // reduction.
331  //
332  // TODO(sanjoy): Some of the things we do here can be abstracted out into
333  // concepts that generalize over other vectorizable operations. We should
334  // consider pulling out these abstractions into a VectorizingIrEmitter or
335  // something similar.
336  StatusOr<bool> EmitVectorizedReduce(
337  HloInstruction* reduce, HloInstruction* arg, HloInstruction* init_value,
338  tensorflow::gtl::ArraySlice<int64> dimensions, HloComputation* function,
339  string* failure_reason);
340 
341  // We'd like to keep one or two cache lines' worth of data in registers
342  // without generating IR with illegal (e.g. excessively large or
343  // non-power-of-two) vector types. We do this by introducing a layer of
344  // abstraction: we introduce a high level vector-like concept called a
345  // "sharded vector" that models data parallelism, and is mapped to a sequence
346  // of scalar and vector llvm::Values.
347  //
348  // For example, we can represent 29 f32 elements by a sharded vector mapped to
349  // a sequence of LLVM values of types [<16 x f32>, <8 x f32>, <4 x f32>, f32].
350  // Note that the last element is scalar.
351  //
352  // There is no requirement on the ordering or the uniqueness of the elements
353  // mapped to sharded vectors -- we allow repeated elements, and we allow
354  // elements to appear in any order.
355  using ShardedVector = std::vector<llvm::Value*>;
356 
357  // A sharded vector type is the element-wise llvm::Types of some
358  // ShardedVector.
359  using ShardedVectorType = std::vector<llvm::Type*>;
360 
361  // Create a sharded vector type corresponding to an "element_count" long
362  // sequence of "element_type" values.
363  ShardedVectorType CreateShardedVectorType(PrimitiveType element_type,
364  unsigned element_count);
365 
366  // Emit LLVM IR to store the sharded vector "value_to_store" to
367  // "store_address".
368  void EmitShardedVectorStore(llvm::Value* store_address,
369  const ShardedVector& value_to_store,
370  const int alignment,
371  const llvm_ir::IrArray& containing_array);
372 
  // Callable that, given an IR builder and two values, returns the llvm::Value
  // for the reduction step (see MatchReductionGenerator below).
  // NOTE(review): std::function needs <functional>, which is not among the
  // standard headers visibly included by this file -- presumably it arrives
  // transitively; TODO confirm and include it directly.
373  using ReductionGenerator = std ::function<llvm::Value*(
374  llvm::IRBuilder<>*, llvm::Value*, llvm::Value*)>;
375 
376  // Tries to match the reduction function "function" to a known reduction
377  // pattern. Returns a non-null ReductionGenerator on a successful match,
378  // which can be used to generate the LLVM IR corresponding to said reduction.
379  // On failure, this stores a reason string into "failure_reason".
380  ReductionGenerator MatchReductionGenerator(HloComputation* function,
381  string* failure_reason) const;
382 
383  // Emits the inner loop nest that runs the reduction. Helper function for
384  // EmitVectorizedReduce.
385  StatusOr<ShardedVector> EmitInnerLoopForVectorizedReduction(
386  const ReductionGenerator& reduction_generator,
387  const llvm_ir::IrArray::Index& output_index,
388  const ShardedVectorType& accumulator_type, HloInstruction* init_value,
389  HloInstruction* arg, tensorflow::gtl::ArraySlice<int64> dimensions,
390  unsigned element_alignment);
391 
392  // Tries to emit a fast concatenate operation using memcpy. Returns true if
393  // successful, and false on failure. On failure, sets "failure_reason" to a
394  // string describing why it could not emit a fast concatenate.
395  StatusOr<bool> EmitFastConcatenate(
396  HloInstruction* concatenate,
397  tensorflow::gtl::ArraySlice<HloInstruction*> operands,
398  string* failure_reason);
399 
400  // Emits LLVM IR to transfer "element_count" elements of type "primitive_type"
401  // from the address "source" to the address "target".
402  void EmitTransferElements(llvm::Value* target, llvm::Value* source,
403  int64 element_count, PrimitiveType primitive_type,
404  const llvm_ir::IrArray& target_array,
405  const llvm_ir::IrArray& source_array);
406 
407  // Assignment of the temporary buffers needed by the computation and their
408  // shape information.
409  const BufferAssignment& assignment_;
410 
411  // The LLVM module into which IR will be emitted.
412  llvm::Module* module_;
413 
414  // The target architecture.
415  llvm::Triple::ArchType arch_type_;
416 
417  // Used to produce unique names for generated functions.
418  NameUniquer name_uniquer_;
419 
420  // Map containing all previously emitted computations.
421  std::map<HloComputation*, llvm::Function*> emitted_functions_;
422 
423  // Map containing all previously emitted thread-local temporary buffers.
424  std::map<std::pair<llvm::Function*, BufferAllocation::Slice>,
425  llvm::AllocaInst*>
426  thread_local_buffers_;
427 
428  // The following fields track the IR emission state. According to LLVM memory
429  // management rules, their memory is owned by the module (Note that IrFunction
430  // creates the encapsulated llvm::Function s.t. it is added to the llvm
431  // module's function list).
432  std::unique_ptr<IrFunction> compute_function_;
433  llvm::IRBuilder<> ir_builder_;
434 
435  // Maps HLO instructions to their index into the profile counter array.
436  const std::unordered_map<const HloInstruction*, int64>
437  instruction_to_profile_idx_;
438 
439  // Maps HLO computations to their index into the profile counter array.
440  const std::unordered_map<const HloComputation*, int64>
441  computation_to_profile_idx_;
442 
443  // Maps HLOs to Values emitted for them.
444  std::unordered_map<const HloInstruction*, llvm::Value*> emitted_value_;
445 
  // Backs AddAliasingInformationToIrArray above.
446  llvm_ir::AliasAnalysis alias_analysis_;
447 
448  // The number of root instruction outer dimensions used in parallel loop
449  // emission (ParallelLoopEmitter).
450  int64 num_dynamic_loop_bounds_ = 0;
451 
452  // Returns whether the given instruction should be emitted as a parallel loop.
453  bool ShouldEmitParallelLoopFor(const HloInstruction& op) const {
454  // Emit parallel loop for root instruction if dynamic outer-dimension loop
455  // bounds were specified.
456  return num_dynamic_loop_bounds_ > 0 &&
457  op.parent()->root_instruction() == &op;
458  }
459 
460  // This class contains all the state needed to emit instructions for
461  // profiling a computation.
462  class ProfilingState {
463  public:
  // The default constructor leaves profiling without a counter buffer
  // (prof_counters_ is null) and with use_rdtscp_ false.
464  ProfilingState() : use_rdtscp_(false), prof_counters_(nullptr) {}
465  ProfilingState(bool use_rdtscp, llvm::Value* prof_counters)
466  : use_rdtscp_(use_rdtscp), prof_counters_(prof_counters) {}
467 
468  // Record the cycle counter before an HLO executes.
469  void RecordCycleStart(llvm::IRBuilder<>* ir_builder, HloInstruction* hlo);
470  // Record the number of cycles it took for an HLO to execute.
471  void RecordCycleDelta(llvm::IRBuilder<>* ir_builder, HloInstruction* hlo,
472  llvm::Value* prof_counter);
473  // Record the number of cycles it took for the entire computation to
474  // execute.
475  void RecordCompleteComputation(llvm::IRBuilder<>* ir_builder,
476  llvm::Value* prof_counter);
477 
478  // Convenience function to generate a call to an intrinsic which reads the
479  // CPU cycle counter.
480  llvm::Value* ReadCycleCounter(llvm::IRBuilder<>* ir_builder);
481 
482  // Store the cycle counter delta to the per-HLO profile counter.
483  void UpdateProfileCounter(llvm::IRBuilder<>* ir_builder,
484  llvm::Value* prof_counter, llvm::Value* cycle_end,
485  llvm::Value* cycle_start);
486 
487  private:
488  // Should we use the x86-specific rdtscp or the generic readcyclecounter
489  // intrinsic?
490  bool use_rdtscp_;
491 
492  // The argument which corresponds to the profile counter buffer.
493  llvm::Value* prof_counters_;
494 
495  // The first read cycle counter in the program.
496  llvm::Value* first_read_cycle_start_ = nullptr;
497 
498  // The last read cycle counter in the program.
499  llvm::Value* last_read_cycle_end_ = nullptr;
500 
501  // An alloca used to hold the output of the aux value returned by the rdtscp
502  // intrinsic.
503  llvm::Value* aux_i8ptr_ = nullptr;
504 
505  // Maps HLOs to the value the cycle counter contained right before the HLO
506  // began to execute.
507  std::unordered_map<const HloInstruction*, llvm::Value*> cycle_starts_;
508  };
509 
510  ProfilingState profiling_state_;
511 
512  // Given a load instruction and a shape or buffer size, annotate the load's
513  // result with the alignment required by the shape or size.
514  void AttachAlignmentMetadataForLoad(llvm::LoadInst* load, const Shape& shape);
515  void AttachAlignmentMetadataForLoad(llvm::LoadInst* load, int64 buffer_size);
516 
517  // Given a load instruction and a shape or buffer size, annotate the load's
518  // result with the dereferenceable bytes required by the shape / buffer size.
519  void AttachDereferenceableMetadataForLoad(llvm::LoadInst* load,
520  const Shape& shape);
521  void AttachDereferenceableMetadataForLoad(llvm::LoadInst* load,
522  int64 buffer_size);
523 
524  // Calculate the alignment of a buffer allocated for a given shape.
525  int MinimumAlignmentForShape(const Shape& shape);
526 
527  // Calculate the alignment of a buffer allocated for a given primitive type.
528  int MinimumAlignmentForPrimitiveType(PrimitiveType primitive_type);
529 
530  // Calculate the alignment of a buffer with a particular size.
531  int MinimumAlignmentForBufferSize(int64 buffer_size);
532 
533  // Returns the number of bytes within the shape.
534  int64 ByteSizeOf(const Shape& shape) const;
535 
  // Selects which feed direction EmitXfeedTransfer handles.
536  enum class XfeedKind {
537  kInfeed,
538  kOutfeed,
539  };
540 
541  // Emit IR to transfer between a {infeed,outfeed} buffer and an in-program
542  // address.
543  Status EmitXfeedTransfer(XfeedKind kind, const Shape& shape,
544  llvm::Value* program_buffer_address);
545 
  // NOTE(review): the remaining data members carry no documentation in this
  // header; how they are initialized and used is only visible in
  // ir_emitter.cc -- TODO document each one here.
546  const HloModuleConfig& hlo_module_config_;
547 
548  const bool parallel_cpu_backend_;
549 
550  bool is_top_level_computation_;
551 
552  TargetMachineFeatures target_machine_features_;
553 
554  int64 external_global_constant_counter_ = 0;
555  ExternalConstantPool* external_constant_pool_;
556 
557  TF_DISALLOW_COPY_AND_ASSIGN(IrEmitter);
558 };
559 
560 } // namespace cpu
561 } // namespace xla
562 
563 #endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_IR_EMITTER_H_
Definition: hlo_computation.h:60
Definition: ir_emitter.h:63
IrEmitter(const HloModule &hlo_module, const BufferAssignment &assignment, llvm::Module *llvm_module, std::unordered_map< const HloInstruction *, int64 > instruction_to_profile_idx, std::unordered_map< const HloComputation *, int64 > computation_to_profile_idx, llvm::TargetMachine *target_machine, ExternalConstantPool *external_constant_pool)
Create a new LLVM IR emitter.
Definition: ir_emitter.cc:83
Definition: hlo_instruction.h:165
namespace for xla
Definition: client_library.cc:26
Definition: hlo_module.h:52
StatusOr< llvm::Function * > EmitComputation(HloComputation *computation, const string &function_name_prefix, bool is_top_level_computation, std::vector< const HloInstruction *> *instruction_order)
Definition: ir_emitter.cc:108