18 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_IR_EMITTER_H_ 19 #define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_IR_EMITTER_H_ 25 #include <unordered_map> 28 #include "llvm/ADT/Triple.h" 29 #include "llvm/IR/Function.h" 30 #include "llvm/IR/IRBuilder.h" 31 #include "llvm/IR/Module.h" 32 #include "llvm/IR/Value.h" 33 #include "llvm/Target/TargetMachine.h" 34 #include "tensorflow/compiler/xla/service/buffer_assignment.h" 35 #include "tensorflow/compiler/xla/service/cpu/external_constant_pool.h" 36 #include "tensorflow/compiler/xla/service/cpu/ir_function.h" 37 #include "tensorflow/compiler/xla/service/cpu/target_machine_features.h" 38 #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" 41 #include "tensorflow/compiler/xla/service/hlo_module_config.h" 42 #include "tensorflow/compiler/xla/service/llvm_ir/alias_analysis.h" 43 #include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h" 44 #include "tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h" 45 #include "tensorflow/compiler/xla/service/name_uniquer.h" 46 #include "tensorflow/compiler/xla/statusor.h" 47 #include "tensorflow/compiler/xla/types.h" 48 #include "tensorflow/compiler/xla/xla_data.pb.h" 49 #include "tensorflow/core/lib/core/stringpiece.h" 50 #include "tensorflow/core/lib/gtl/array_slice.h" 51 #include "tensorflow/core/lib/gtl/flatmap.h" 52 #include "tensorflow/core/platform/macros.h" 53 #include "tensorflow/core/platform/types.h" 82 llvm::Module* llvm_module,
83 std::unordered_map<const HloInstruction*, int64>
84 instruction_to_profile_idx,
85 std::unordered_map<const HloComputation*, int64>
86 computation_to_profile_idx,
87 llvm::TargetMachine* target_machine,
88 ExternalConstantPool* external_constant_pool);
115 bool is_top_level_computation,
116 std::vector<const HloInstruction*>* instruction_order);
118 llvm::IRBuilder<>* ir_builder() {
return &ir_builder_; }
121 StatusOr<llvm::Value*> EmitScalarCall(
123 const std::vector<llvm::Value*>& arguments, tensorflow::StringPiece name);
136 Status HandleGetTupleElement(
HloInstruction* get_tuple_element)
override;
147 Status HandleReduceWindow(
HloInstruction* reduce_window)
override;
148 Status HandleSelectAndScatter(
HloInstruction* select_and_scatter)
override;
152 Status HandleDynamicSlice(
HloInstruction* dynamic_slice)
override;
153 Status HandleDynamicUpdateSlice(
173 void InitializeIrFunction(
const string& function_name);
175 template <
typename T>
176 llvm::Value* GetProfileCounterCommon(
178 const std::unordered_map<const T*, int64>& profile_index_map);
183 llvm::Value* GetProfileCounterFor(
const HloInstruction& instruction) {
184 return GetProfileCounterCommon<HloInstruction>(instruction,
185 instruction_to_profile_idx_);
188 llvm::Value* GetProfileCounterFor(
const HloComputation& computation) {
189 return GetProfileCounterCommon<HloComputation>(computation,
190 computation_to_profile_idx_);
207 std::vector<llvm_ir::IrArray> GetIrArraysForOperandsOf(
212 llvm_ir::IrArray* array) {
213 alias_analysis_.AddAliasingInformationToIrArray(hlo, array);
217 llvm::Type* IrShapeType(
const Shape& shape);
221 llvm::Value* GetProfileCountersArgument();
225 llvm::Value* GetExecutableRunOptionsArgument();
229 llvm::Value* GetTempBuffersArgument();
234 llvm::Value* EmitTempBufferPointer(
const BufferAllocation::Slice& slice,
235 const Shape& target_shape);
240 StatusOr<llvm::Function*> EmitFunction(
242 tensorflow::StringPiece
243 function_name_suffix);
258 llvm::Value* EmitElementFunctionCall(
259 llvm::Function*
function,
const Shape& return_shape,
260 tensorflow::gtl::ArraySlice<llvm::Value*> parameter_addresses,
261 tensorflow::StringPiece name);
271 void EmitArrayFunctionCallInto(
272 llvm::Function*
function,
273 tensorflow::gtl::ArraySlice<llvm::Value*> parameter_addresses,
274 llvm::Value* return_value_buffer, tensorflow::StringPiece name);
279 llvm::Value* EmitArrayFunctionCall(
280 llvm::Function*
function,
const Shape& return_shape, int64 element_count,
281 tensorflow::gtl::ArraySlice<llvm::Value*> parameter_addresses,
282 tensorflow::StringPiece name);
286 Status ElementTypesSameAndSupported(
288 tensorflow::gtl::ArraySlice<const HloInstruction*> operands,
289 tensorflow::gtl::ArraySlice<PrimitiveType> supported_types);
301 Status EmitTargetElementLoop(
303 const llvm_ir::ElementGenerator& element_generator);
304 Status EmitTargetElementLoop(
306 const llvm_ir::ElementGenerator& element_generator);
323 llvm::Constant* CreateInitializerForConstantArray(
324 const std::vector<llvm::Constant*>& array_elements,
const Shape& shape,
325 int64 dimension_index);
336 StatusOr<bool> EmitVectorizedReduce(
338 tensorflow::gtl::ArraySlice<int64> dimensions,
HloComputation*
function,
339 string* failure_reason);
355 using ShardedVector = std::vector<llvm::Value*>;
359 using ShardedVectorType = std::vector<llvm::Type*>;
363 ShardedVectorType CreateShardedVectorType(PrimitiveType element_type,
364 unsigned element_count);
368 void EmitShardedVectorStore(llvm::Value* store_address,
369 const ShardedVector& value_to_store,
371 const llvm_ir::IrArray& containing_array);
373 using ReductionGenerator = std ::function<llvm::Value*(
374 llvm::IRBuilder<>*, llvm::Value*, llvm::Value*)>;
380 ReductionGenerator MatchReductionGenerator(
HloComputation*
function,
381 string* failure_reason)
const;
385 StatusOr<ShardedVector> EmitInnerLoopForVectorizedReduction(
386 const ReductionGenerator& reduction_generator,
387 const llvm_ir::IrArray::Index& output_index,
388 const ShardedVectorType& accumulator_type,
HloInstruction* init_value,
389 HloInstruction* arg, tensorflow::gtl::ArraySlice<int64> dimensions,
390 unsigned element_alignment);
395 StatusOr<bool> EmitFastConcatenate(
397 tensorflow::gtl::ArraySlice<HloInstruction*> operands,
398 string* failure_reason);
402 void EmitTransferElements(llvm::Value* target, llvm::Value* source,
403 int64 element_count, PrimitiveType primitive_type,
404 const llvm_ir::IrArray& target_array,
405 const llvm_ir::IrArray& source_array);
409 const BufferAssignment& assignment_;
412 llvm::Module* module_;
415 llvm::Triple::ArchType arch_type_;
418 NameUniquer name_uniquer_;
421 std::map<HloComputation*, llvm::Function*> emitted_functions_;
424 std::map<std::pair<llvm::Function*, BufferAllocation::Slice>,
426 thread_local_buffers_;
432 std::unique_ptr<IrFunction> compute_function_;
433 llvm::IRBuilder<> ir_builder_;
436 const std::unordered_map<const HloInstruction*, int64>
437 instruction_to_profile_idx_;
440 const std::unordered_map<const HloComputation*, int64>
441 computation_to_profile_idx_;
444 std::unordered_map<const HloInstruction*, llvm::Value*> emitted_value_;
446 llvm_ir::AliasAnalysis alias_analysis_;
450 int64 num_dynamic_loop_bounds_ = 0;
456 return num_dynamic_loop_bounds_ > 0 &&
457 op.parent()->root_instruction() == &op;
462 class ProfilingState {
464 ProfilingState() : use_rdtscp_(
false), prof_counters_(
nullptr) {}
465 ProfilingState(
bool use_rdtscp, llvm::Value* prof_counters)
466 : use_rdtscp_(use_rdtscp), prof_counters_(prof_counters) {}
469 void RecordCycleStart(llvm::IRBuilder<>* ir_builder,
HloInstruction* hlo);
471 void RecordCycleDelta(llvm::IRBuilder<>* ir_builder,
HloInstruction* hlo,
472 llvm::Value* prof_counter);
475 void RecordCompleteComputation(llvm::IRBuilder<>* ir_builder,
476 llvm::Value* prof_counter);
480 llvm::Value* ReadCycleCounter(llvm::IRBuilder<>* ir_builder);
483 void UpdateProfileCounter(llvm::IRBuilder<>* ir_builder,
484 llvm::Value* prof_counter, llvm::Value* cycle_end,
485 llvm::Value* cycle_start);
493 llvm::Value* prof_counters_;
496 llvm::Value* first_read_cycle_start_ =
nullptr;
499 llvm::Value* last_read_cycle_end_ =
nullptr;
503 llvm::Value* aux_i8ptr_ =
nullptr;
507 std::unordered_map<const HloInstruction*, llvm::Value*> cycle_starts_;
510 ProfilingState profiling_state_;
514 void AttachAlignmentMetadataForLoad(llvm::LoadInst* load,
const Shape& shape);
515 void AttachAlignmentMetadataForLoad(llvm::LoadInst* load, int64 buffer_size);
519 void AttachDereferenceableMetadataForLoad(llvm::LoadInst* load,
521 void AttachDereferenceableMetadataForLoad(llvm::LoadInst* load,
525 int MinimumAlignmentForShape(
const Shape& shape);
528 int MinimumAlignmentForPrimitiveType(PrimitiveType primitive_type);
531 int MinimumAlignmentForBufferSize(int64 buffer_size);
534 int64 ByteSizeOf(
const Shape& shape)
const;
536 enum class XfeedKind {
543 Status EmitXfeedTransfer(XfeedKind kind,
const Shape& shape,
544 llvm::Value* program_buffer_address);
546 const HloModuleConfig& hlo_module_config_;
548 const bool parallel_cpu_backend_;
550 bool is_top_level_computation_;
552 TargetMachineFeatures target_machine_features_;
554 int64 external_global_constant_counter_ = 0;
555 ExternalConstantPool* external_constant_pool_;
563 #endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_IR_EMITTER_H_
Definition: hlo_computation.h:60
Definition: ir_emitter.h:63
IrEmitter(const HloModule &hlo_module, const BufferAssignment &assignment, llvm::Module *llvm_module, std::unordered_map< const HloInstruction *, int64 > instruction_to_profile_idx, std::unordered_map< const HloComputation *, int64 > computation_to_profile_idx, llvm::TargetMachine *target_machine, ExternalConstantPool *external_constant_pool)
Create a new LLVM IR emitter.
Definition: ir_emitter.cc:83
Definition: hlo_instruction.h:165
namespace for xla
Definition: client_library.cc:26
Definition: hlo_module.h:52
StatusOr< llvm::Function * > EmitComputation(HloComputation *computation, const string &function_name_prefix, bool is_top_level_computation, std::vector< const HloInstruction *> *instruction_order)
Definition: ir_emitter.cc:108