14 #ifndef TENSORFLOW_COMPILER_TF2XLA_XLA_OP_REGISTRY_H_ 15 #define TENSORFLOW_COMPILER_TF2XLA_XLA_OP_REGISTRY_H_ 19 #include <unordered_map> 21 #include "tensorflow/core/common_runtime/device_factory.h" 22 #include "tensorflow/core/common_runtime/local_device.h" 23 #include "tensorflow/core/framework/device_base.h" 24 #include "tensorflow/core/framework/tensor.h" 25 #include "tensorflow/core/framework/types.pb.h" 26 #include "tensorflow/core/lib/core/status.h" 27 #include "tensorflow/core/platform/mem.h" 28 #include "tensorflow/core/platform/mutex.h" 29 #include "tensorflow/core/platform/thread_annotations.h" 30 #include "tensorflow/core/public/session_options.h" 35 extern const char*
const DEVICE_CPU_XLA_JIT;
// Device-name string constants (like DEVICE_CPU_XLA_JIT just above). Each is
// declared `extern` with no initializer, so the value comes from a definition
// outside this header.
// NOTE(review): going by the identifiers, the *_XLA_JIT names look like JIT
// compilation devices and the DEVICE_XLA_* names like standalone XLA devices --
// confirm against the definitions.
36 extern const char*
const DEVICE_GPU_XLA_JIT;
37 extern const char*
const DEVICE_XLA_CPU;
38 extern const char*
const DEVICE_XLA_GPU;
// Compile-time list of four DataType enumerators: DT_HALF, DT_FLOAT,
// DT_DOUBLE and DT_BFLOAT16. The declared array size (4) must match the
// number of entries.
39 constexpr std::array<DataType, 4> kFloatTypes = {
40 {DT_HALF, DT_FLOAT, DT_DOUBLE, DT_BFLOAT16}};
// Compile-time list of nine DataType enumerators: the 32/64-bit unsigned and
// signed integer types, the four entries that also appear in kFloatTypes, and
// DT_COMPLEX64. DT_BOOL is not part of this list.
41 constexpr std::array<DataType, 9> kNumericTypes = {
42 {DT_UINT32, DT_UINT64, DT_INT32, DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE,
43 DT_COMPLEX64, DT_BFLOAT16}};
// Compile-time list of nine DataType enumerators. Compared with kNumericTypes
// it contains DT_BOOL and does not contain DT_BFLOAT16.
// NOTE(review): presumably the type set offered to the CPU backend -- confirm
// at the registration site.
44 constexpr std::array<DataType, 9> kCpuAllTypes = {
45 {DT_UINT32, DT_UINT64, DT_INT32, DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE,
46 DT_COMPLEX64, DT_BOOL}};
// Compile-time list of ten DataType enumerators: every entry of kCpuAllTypes
// followed by DT_BFLOAT16.
// NOTE(review): presumably the type set offered to the GPU backend -- confirm
// at the registration site.
47 constexpr std::array<DataType, 10> kGpuAllTypes = {
48 {DT_UINT32, DT_UINT64, DT_INT32, DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE,
49 DT_COMPLEX64, DT_BOOL, DT_BFLOAT16}};
// Kernel factory signature: a plain function pointer that receives an
// OpKernelConstruction* and returns an OpKernel*. The REGISTER_XLA_OP_UNIQ
// macro in this file supplies a capture-less lambda returning `new OP(context)`,
// so in that case the result is a freshly heap-allocated kernel.
// NOTE(review): who deletes the returned kernel is not visible in this header.
56 typedef OpKernel* (*Factory)(OpKernelConstruction*);
// Record describing one compilation device. It is passed by const reference to
// RegisterCompilationDevice and exposed through the out-parameter of
// GetCompilationDevice.
58 struct DeviceRegistration {
// Name of a compilation device.
// NOTE(review): presumably the JIT device used to compile for the device being
// registered -- confirm.
60 string compilation_device_name;
// No default member initializer is given; the creator of the record must set it.
62 bool requires_compilation;
// No default member initializer is given; the creator of the record must set it.
67 bool enable_jit_by_default;
// Defaults to false.
// NOTE(review): the name suggests it gates compilation of ops that touch
// resources -- confirm with the code that reads it.
69 bool compile_resource_ops =
false;
// Per-backend filter callback: a plain function pointer that receives a
// non-const KernelDef* (so the callee is able to modify it) and returns bool.
// NOTE(review): the meaning of the boolean result (accept vs. reject) is not
// visible in this header -- confirm in the implementation.
80 typedef bool (*BackendOpFilter)(KernelDef* kdef);
// Static registration entry point for a backend.
//   compilation_device_name: name identifying the backend.
//   supported_types: slice of DataType values associated with the backend.
//   op_filter: BackendOpFilter callback stored with the backend.
// Returns nothing, so failures cannot be reported through the return value.
// NOTE(review): XlaBackendRegistrar's constructor defaults its op_filter to
// nullptr, so a null filter is probably acceptable here -- confirm.
81 static void RegisterBackend(
const string& compilation_device_name,
82 gtl::ArraySlice<DataType> supported_types,
83 BackendOpFilter op_filter);
// Static accessor returning a vector of strings by value.
// NOTE(review): presumably the names of all registered backends, in
// unspecified order -- confirm in the implementation.
85 static std::vector<string> BackendNames();
// Static predicate taking a name and returning bool.
// NOTE(review): presumably true when a backend was registered under `name`
// -- confirm in the implementation.
87 static bool IsBackendRegistered(
const string& name);
// Static registration entry point associating `device_name` with a
// DeviceRegistration. The registration is taken by const reference, so the
// caller keeps its own object.
// NOTE(review): presumably the record is copied into compilation_devices_;
// behavior on duplicate names is not visible here.
90 static void RegisterCompilationDevice(
const string& device_name,
91 const DeviceRegistration& registration);
// Static lookup by device name. The result is delivered through the
// `registration` out-parameter as a pointer to a const DeviceRegistration; the
// bool return carries the status.
// NOTE(review): presumably returns false when `device_name` is unknown, and
// the pointed-to record is owned by the registry -- confirm both, and whether
// `registration` may be null.
98 static bool GetCompilationDevice(
const string& device_name,
99 const DeviceRegistration** registration);
// Static query returning a vector of raw pointers to const KernelDef, selected
// by compilation device name. `include_compilation_only_kernels` toggles
// whether compilation-only kernels are part of the result.
// NOTE(review): the pointers are presumably non-owning views of KernelDefs
// held by the registry (cf. the unique_ptr<KernelDef> vector declared below)
// -- confirm their lifetime.
106 static std::vector<const KernelDef*> DeviceKernels(
107 const string& compilation_device_name,
108 bool include_compilation_only_kernels);
111 static const std::unordered_set<string>* CompileTimeConstantInputs(
// Grant the three registration helper classes declared later in this header
// access to this class's non-public members.
114 friend class XlaBackendRegistrar;
115 friend class XlaOpRegistrar;
116 friend class XlaOpRegistrationBuilder;
// NOTE(review): these three fields appear to belong to the Backend record used
// as the mapped type of backends_; its struct header is not visible here.
// Set of DataType values for the backend.
124 std::set<DataType> supported_types;
// Filter callback; see BackendOpFilter for the signature.
127 BackendOpFilter op_filter;
// KernelDef objects owned through unique_ptr, so they are destroyed together
// with this vector.
130 std::vector<std::unique_ptr<KernelDef>> kernel_defs;
// Map from string key to Backend record. Annotated GUARDED_BY(mutex_): accesses
// are expected to happen with mutex_ held.
// NOTE(review): the key is presumably the compilation device name passed to
// RegisterBackend -- confirm.
133 std::unordered_map<string, Backend> backends_ GUARDED_BY(mutex_);
135 std::unordered_map<string, DeviceRegistration> compilation_devices_
// Record describing one op registration. XlaOpRegistrationBuilder::Build
// returns it wrapped in a unique_ptr and XlaOpRegistrar's constructor consumes it.
// NOTE(review): further members may exist between the lines shown here.
138 struct OpRegistration {
// Defaults to false; XlaOpRegistrationBuilder::CompilationOnly() is the
// builder call with the matching name.
142 bool compilation_only =
false;
// Defaults to false; XlaOpRegistrationBuilder::AllowResourceTypes() is the
// builder call with the matching name.
145 bool allow_resource_types =
false;
// Map from a string key to a set of DataType values.
// NOTE(review): the key is presumably the attribute name given to
// TypeConstraint() -- confirm.
147 std::unordered_map<string, std::set<DataType>> type_constraints;
// Defaults to false.
// NOTE(review): presumably indicates whether device_whitelist is in effect,
// which would distinguish "no whitelist" from an empty one -- confirm.
150 bool has_device_whitelist =
false;
// Set of strings; presumably device names supplied via Device() -- confirm.
151 std::unordered_set<string> device_whitelist;
// Set of strings; presumably input names supplied via
// CompileTimeConstInput() -- confirm.
153 std::unordered_set<string> compile_time_constant_inputs;
// Static predicate over two OpRegistration records, both taken by const
// reference.
// NOTE(review): presumably decides whether two registrations may coexist for
// the same op; the exact compatibility rule is in the implementation.
162 static bool IsCompatible(
const OpRegistration& x,
const OpRegistration& y);
166 std::unordered_multimap<string, std::unique_ptr<OpRegistration>> ops_
// Flag that starts out false.
// NOTE(review): presumably set once kernel registration has been performed so
// the work is not repeated; no thread-safety annotation is visible on this
// member in the lines shown -- confirm how it is protected.
169 bool jit_kernels_registered_ =
false;
// OpKernelRegistrar objects owned through unique_ptr, so they live as long as
// this vector does. Annotated GUARDED_BY(mutex_): accesses are expected to
// happen with mutex_ held.
173 std::vector<std::unique_ptr<kernel_factory::OpKernelRegistrar>>
174 kernel_registrars_ GUARDED_BY(mutex_);
182 #define REGISTER_XLA_OP(NAME, OP) \ 183 REGISTER_XLA_OP_UNIQ_HELPER(__COUNTER__, NAME, OP) 184 class XlaOpRegistrationBuilder {
// Static starting point of the builder chain: takes the name and returns a
// builder by value. REGISTER_XLA_OP_UNIQ expands its BUILDER argument as
// `XlaOpRegistrationBuilder::BUILDER`, so builder expressions begin with this
// call.
187 static XlaOpRegistrationBuilder Name(StringPiece name);
// Two overloads, one taking a single StringPiece and one taking a slice of
// them. Both return a reference to a builder, which permits call chaining.
// NOTE(review): presumably restricts the op to the named devices (cf.
// OpRegistration::device_whitelist) -- confirm.
189 XlaOpRegistrationBuilder& Device(StringPiece devices);
190 XlaOpRegistrationBuilder& Device(gtl::ArraySlice<StringPiece> devices);
193 XlaOpRegistrationBuilder& TypeConstraint(StringPiece attr_name,
// Overload taking an attribute name and a slice of DataType values; returns a
// reference to a builder, which permits call chaining.
// NOTE(review): presumably limits attribute `attr_name` to the `allowed` types
// (cf. OpRegistration::type_constraints) -- confirm.
195 XlaOpRegistrationBuilder& TypeConstraint(StringPiece attr_name,
196 gtl::ArraySlice<DataType> allowed);
// Takes no arguments; returns a reference to a builder for call chaining.
// NOTE(review): presumably sets OpRegistration::compilation_only -- confirm.
199 XlaOpRegistrationBuilder& CompilationOnly();
// Takes no arguments; returns a reference to a builder for call chaining.
// NOTE(review): presumably sets OpRegistration::allow_resource_types -- confirm.
201 XlaOpRegistrationBuilder& AllowResourceTypes();
// Takes an input name; returns a reference to a builder for call chaining.
// NOTE(review): presumably adds `input_name` to
// OpRegistration::compile_time_constant_inputs -- confirm.
203 XlaOpRegistrationBuilder& CompileTimeConstInput(StringPiece input_name);
// Final step of the builder chain: takes the kernel factory and returns the
// finished OpRegistration in a unique_ptr, transferring ownership to the
// caller. REGISTER_XLA_OP_UNIQ feeds the result straight into an
// XlaOpRegistrar constructor.
// NOTE(review): presumably leaves registration_ empty afterwards, making the
// builder single-use -- confirm.
204 std::unique_ptr<XlaOpRegistry::OpRegistration> Build(
205 XlaOpRegistry::Factory factory);
// Constructor taking the name.
// NOTE(review): single-argument and not marked `explicit`, so a StringPiece
// can convert implicitly wherever this constructor is accessible; its access
// level is not visible in the lines shown.
207 XlaOpRegistrationBuilder(StringPiece name);
// The OpRegistration held by the builder, owned through unique_ptr; Build()
// returns the same pointer type.
208 std::unique_ptr<XlaOpRegistry::OpRegistration> registration_;
212 #define REGISTER_XLA_BACKEND(NAME, ...) \ 213 REGISTER_XLA_BACKEND_UNIQ_HELPER(__COUNTER__, NAME, __VA_ARGS__) 215 class XlaOpRegistrar {
// Constructor taking the OpRegistration by value in a unique_ptr, i.e. the
// caller hands over ownership. REGISTER_XLA_OP_UNIQ defines a static instance
// of this class, so the constructor runs during static initialization of the
// translation unit that uses the macro.
// NOTE(review): single-argument and not marked `explicit`.
217 XlaOpRegistrar(std::unique_ptr<XlaOpRegistry::OpRegistration> registration);
219 #define REGISTER_XLA_OP_UNIQ_HELPER(COUNTER, BUILDER, OP) \ 220 REGISTER_XLA_OP_UNIQ(COUNTER, BUILDER, OP) 221 #define REGISTER_XLA_OP_UNIQ(CTR, BUILDER, OP) \ 222 static ::tensorflow::XlaOpRegistrar xla_op_registrar__body__##CTR##__object( \ 223 XlaOpRegistrationBuilder::BUILDER.Build( \ 224 [](::tensorflow::OpKernelConstruction* context) \ 225 -> ::tensorflow::OpKernel* { return new OP(context); })); 226 class XlaBackendRegistrar {
// Constructor taking the backend name, a slice of DataType values and an
// optional filter callback that defaults to nullptr.
// REGISTER_XLA_BACKEND_UNIQ defines a static instance of this class, so the
// constructor runs during static initialization of the translation unit that
// uses the macro.
// NOTE(review): presumably forwards to XlaOpRegistry::RegisterBackend, whose
// parameter list matches -- confirm.
228 XlaBackendRegistrar(StringPiece name, gtl::ArraySlice<DataType> types,
229 XlaOpRegistry::BackendOpFilter op_filter =
nullptr);
231 #define REGISTER_XLA_BACKEND_UNIQ_HELPER(COUNTER, NAME, ...) \ 232 REGISTER_XLA_BACKEND_UNIQ(COUNTER, NAME, __VA_ARGS__) 233 #define REGISTER_XLA_BACKEND_UNIQ(CTR, NAME, ...) \ 234 static ::tensorflow::XlaBackendRegistrar \ 235 xla_backend_registrar__body__##CTR##__object(NAME, __VA_ARGS__); 237 #endif // TENSORFLOW_COMPILER_TF2XLA_XLA_OP_REGISTRY_H_ static void RegisterCompilationKernels()
Definition: xla_op_registry.cc:153
Definition: xla_op_registry.h:54
Definition: compile.cc:35