#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_COMPILER_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_COMPILER_H_

#include <functional>
#include <map>
#include <memory>
#include <string>
#include <vector>

#include "tensorflow/compiler/xla/service/executable.h"
#include "tensorflow/compiler/xla/service/hlo_module_config.h"
#include "tensorflow/compiler/xla/service/logical_buffer.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
#include "tensorflow/core/platform/thread_annotations.h"

// The bytes of an object file produced by ahead-of-time compilation.
using ObjectFileData = std::vector<char>;
41 using BufferSizes = std::vector<int64>;
// Abstract superclass describing the result of an ahead-of-time compilation.
// Non-copyable; concrete backends subclass this to carry their own artifacts.
class AotCompilationResult {
 public:
  AotCompilationResult(const AotCompilationResult&) = delete;
  AotCompilationResult& operator=(AotCompilationResult const&) = delete;

  virtual ~AotCompilationResult() = default;

 protected:
  // Only constructible by subclasses.
  AotCompilationResult() = default;
};
52 class AotCompilationOptions {
54 AotCompilationOptions(
const AotCompilationOptions&) =
delete;
55 AotCompilationOptions& operator=(AotCompilationOptions
const&) =
delete;
56 virtual ~AotCompilationOptions() =
default;
58 virtual perftools::gputools::Platform::Id PlatformId()
const = 0;
61 DeviceMemoryAllocator* device_allocator()
const {
return device_allocator_; }
62 void set_device_allocator(DeviceMemoryAllocator* device_allocator) {
63 device_allocator_ = device_allocator;
65 const DebugOptions& debug_options()
const {
return debug_options_; }
66 DebugOptions* mutable_debug_options() {
return &debug_options_; }
68 AotCompilationOptions();
70 DeviceMemoryAllocator* device_allocator_ =
nullptr;
71 DebugOptions debug_options_;
96 virtual perftools::gputools::Platform::Id PlatformId()
const = 0;
105 virtual StatusOr<std::unique_ptr<HloModule>> RunHloPasses(
106 std::unique_ptr<HloModule> module,
107 perftools::gputools::StreamExecutor* executor,
108 DeviceMemoryAllocator* device_allocator) = 0;
121 virtual StatusOr<std::unique_ptr<Executable>> RunBackend(
122 std::unique_ptr<HloModule> module,
123 perftools::gputools::StreamExecutor* executor,
124 DeviceMemoryAllocator* device_allocator) = 0;
133 virtual StatusOr<std::vector<std::unique_ptr<Executable>>> Compile(
134 std::vector<std::unique_ptr<HloModule>> modules,
135 std::vector<std::vector<perftools::gputools::StreamExecutor*>>
137 DeviceMemoryAllocator* device_allocator) = 0;
140 virtual StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
141 CompileAheadOfTime(std::vector<std::unique_ptr<HloModule>> modules,
142 const AotCompilationOptions& options) = 0;
146 using CompilerFactory = std::function<std::unique_ptr<Compiler>()>;
151 static void RegisterCompilerFactory(
152 perftools::gputools::Platform::Id platform_id,
153 CompilerFactory compiler_factory);
156 static StatusOr<Compiler*> GetForPlatform(
157 const perftools::gputools::Platform* platform);
160 virtual HloCostAnalysis::ShapeSizeFunction ShapeSizeBytesFunction()
const = 0;
163 std::function<int64(const LogicalBuffer&)> BufferSizeBytesFunction() {
164 HloCostAnalysis::ShapeSizeFunction shape_size = ShapeSizeBytesFunction();
165 return [shape_size](
const LogicalBuffer& buffer) {
166 return shape_size(buffer.shape());
171 static tensorflow::mutex platform_compiler_mutex_;
173 static std::map<perftools::gputools::Platform::Id, CompilerFactory>*
174 GetPlatformCompilerFactories();
177 static std::map<perftools::gputools::Platform::Id, std::unique_ptr<Compiler>>*
178 GetPlatformCompilers();
#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_COMPILER_H_