18 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_PARALLEL_TASK_ASSIGNMENT_H_ 19 #define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_PARALLEL_TASK_ASSIGNMENT_H_ 21 #include "tensorflow/compiler/xla/service/hlo_cost_analysis.h" 23 #include "tensorflow/compiler/xla/service/hlo_pass_interface.h" 29 class ParallelCostModel {
// Virtual destructor, explicitly defaulted.
31 virtual ~ParallelCostModel() =
default;
// Pure virtual: every concrete cost model must implement this. Takes a
// mutable HloInstruction pointer and returns an int64.
// NOTE(review): judging by the name, the result is the number of parallel
// tasks to use for `instruction` -- confirm against the implementations.
32 virtual int64 GetParallelTaskCount(HloInstruction* instruction) = 0;
36 class ParallelTaskAssignment {
42 ParallelTaskAssignment(
const int64 max_parallelism,
43 const HloCostAnalysis::ShapeSizeFunction& shape_size,
// Destructor with an empty body: no explicit cleanup is written here.
45 ~ParallelTaskAssignment() {}
// Takes a mutable HloInstruction pointer and returns an int64.
// NOTE(review): presumably the target number of parallel tasks for
// `instruction`; the definition is not in this header -- confirm in the .cc.
48 int64 GetTargetParallelTaskCount(HloInstruction* instruction);
// Cost model, held through std::unique_ptr and typed as the ParallelCostModel
// interface, so the concrete model type is not fixed by this declaration.
51 std::unique_ptr<ParallelCostModel> cost_model_;
70 const HloCostAnalysis::ShapeSizeFunction& shape_size)
71 : max_parallelism_(max_parallelism), shape_size_function_(shape_size) {}
74 tensorflow::StringPiece name()
const override {
75 return "cpu-parallel-task-assigner";
// Overrides the base-class Run() and operates on `module`, returning a
// StatusOr<bool>.
// NOTE(review): the bool presumably reports whether the module was changed
// -- confirm against the base pass interface.
80 StatusOr<bool> Run(
HloModule* module)
override;
// Hash map keyed by pointer-to-const HloInstruction, with an int64 value.
// NOTE(review): the value is presumably the parallel task count chosen for
// that instruction -- confirm where the map is filled.
83 using HloToParallelTasks = std::unordered_map<const HloInstruction*, int64>;
// Takes the module by mutable pointer and the instruction-to-task-count map
// by const reference (the map is read-only here); returns a bool.
// NOTE(review): the bool presumably indicates whether `module` was modified
// -- confirm in the definition.
88 bool AssignParallelTasks(
HloModule* module,
89 const HloToParallelTasks& hlo_to_parallel_tasks);
90 bool AssignParallelTasksHelper(
92 const HloToParallelTasks& hlo_to_parallel_tasks);
// Returns nothing; `hlo_to_parallel_tasks` is passed by non-const pointer.
// NOTE(review): it is presumably an output map filled with the task counts
// computed for instructions in `module` -- confirm in the definition.
96 void ComputeTargetParallelTasks(
HloModule* module,
97 HloToParallelTasks* hlo_to_parallel_tasks);
// Set from the constructor's `max_parallelism` argument in the member
// initializer list.
99 int64 max_parallelism_;
// Stored by value; copy-initialized from the constructor's `shape_size`
// argument, which is passed by const reference.
100 HloCostAnalysis::ShapeSizeFunction shape_size_function_;
106 #endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_PARALLEL_TASK_ASSIGNMENT_H_ Definition: parallel_task_assignment.h:64
Definition: hlo_computation.h:60
namespace for xla
Definition: client_library.cc:26
Definition: hlo_module.h:52