tf_1.8_xla_doc
xla_op_registry.h
Go to the documentation of this file.
1 
3 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
4 Licensed under the Apache License, Version 2.0 (the "License");
5 you may not use this file except in compliance with the License.
6 You may obtain a copy of the License at
7  http://www.apache.org/licenses/LICENSE-2.0
8 Unless required by applicable law or agreed to in writing, software
9 distributed under the License is distributed on an "AS IS" BASIS,
10 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 See the License for the specific language governing permissions and
12 limitations under the License.
13 ==============================================================================*/
14 #ifndef TENSORFLOW_COMPILER_TF2XLA_XLA_OP_REGISTRY_H_
15 #define TENSORFLOW_COMPILER_TF2XLA_XLA_OP_REGISTRY_H_
16 #include <functional>
17 #include <memory>
18 #include <set>
19 #include <unordered_map>
20 #include <vector>
21 #include "tensorflow/core/common_runtime/device_factory.h"
22 #include "tensorflow/core/common_runtime/local_device.h"
23 #include "tensorflow/core/framework/device_base.h"
24 #include "tensorflow/core/framework/tensor.h"
25 #include "tensorflow/core/framework/types.pb.h"
26 #include "tensorflow/core/lib/core/status.h"
27 #include "tensorflow/core/platform/mem.h"
28 #include "tensorflow/core/platform/mutex.h"
29 #include "tensorflow/core/platform/thread_annotations.h"
30 #include "tensorflow/core/public/session_options.h"
31 namespace tensorflow {
32 // Names of the XLA compilation devices. These are not user-visible, and are
33 // used internally by the Tensorflow/XLA bridge to perform symbolic execution of
34 // a Tensorflow graph.
35 extern const char* const DEVICE_CPU_XLA_JIT; // "CPU_XLA_JIT"
36 extern const char* const DEVICE_GPU_XLA_JIT; // "GPU_XLA_JIT"
37 extern const char* const DEVICE_XLA_CPU;
38 extern const char* const DEVICE_XLA_GPU;
// Fixed lists of DataTypes, grouped by category.
// NOTE(review): std::array is used below, but <array> is not among this
// header's direct includes; it is presumably pulled in transitively -- confirm.
//
// Floating-point types: half, float, double and bfloat16.
39 constexpr std::array<DataType, 4> kFloatTypes = {
40  {DT_HALF, DT_FLOAT, DT_DOUBLE, DT_BFLOAT16}};
// Integer, floating-point and complex types. DT_BOOL is not included.
41 constexpr std::array<DataType, 9> kNumericTypes = {
42  {DT_UINT32, DT_UINT64, DT_INT32, DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE,
43  DT_COMPLEX64, DT_BFLOAT16}};
// kNumericTypes without DT_BFLOAT16, plus DT_BOOL.
// NOTE(review): presumably the supported-type list of the CPU backend --
// confirm at the corresponding REGISTER_XLA_BACKEND call site.
44 constexpr std::array<DataType, 9> kCpuAllTypes = {
45  {DT_UINT32, DT_UINT64, DT_INT32, DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE,
46  DT_COMPLEX64, DT_BOOL}};
// kCpuAllTypes plus DT_BFLOAT16. The REGISTER_XLA_BACKEND usage example in this
// file passes this list as the supported types of DEVICE_GPU_XLA_JIT.
47 constexpr std::array<DataType, 10> kGpuAllTypes = {
48  {DT_UINT32, DT_UINT64, DT_INT32, DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE,
49  DT_COMPLEX64, DT_BOOL, DT_BFLOAT16}};
55  public:
56  typedef OpKernel* (*Factory)(OpKernelConstruction*);
57  // Describes how to compile operators assigned to a device.
58  struct DeviceRegistration {
59  // The name of the XLA compilation device to use to compile code.
60  string compilation_device_name;
61  // Do operators assigned to this device require compilation?
62  bool requires_compilation;
63  // If !requires_compilation, should we try to JIT operators on this device
64  // when XLA JIT compilation is enabled globally via the SessionOptions?
65  // (It is still possible to explicitly mark operators to JIT compile, even
66  // if enable_jit_by_default is false.)
67  bool enable_jit_by_default;
68  // Enable compilation of operators that use DT_RESOURCE types?
69  bool compile_resource_ops = false;
70  };
71  // Registers an XLA backend. `compilation_device_name` is the name of the
72  // device used for symbolic execution during compilation. `supported_types`
73  // is the list of non-resource types supported by the device. Each operator
74  // will be registered for the intersection of the operator's supported types
75  // and the device's supported types. `op_filter` is a function used
76  // to exclude or modify operator registrations on the device; it may be
77  // nullptr, in which case all ops are included.
78  // `op_filter` should return true if the op should be registered on
79  // the device; it may optionally modify the KernelDef.
80  typedef bool (*BackendOpFilter)(KernelDef* kdef);
81  static void RegisterBackend(const string& compilation_device_name,
82  gtl::ArraySlice<DataType> supported_types,
83  BackendOpFilter op_filter);
84  // Returns the names of the registered backends.
85  static std::vector<string> BackendNames();
86  // Returns true iff a backend with the given name is registered.
87  static bool IsBackendRegistered(const string& name);
88  // Registers `device_name` for XLA compilation, using information from
89  // `registration`.
90  static void RegisterCompilationDevice(const string& device_name,
91  const DeviceRegistration& registration);
92  // Looks up the JIT (compilation) device registered for 'device_name'. If one
93  // is found, returns true and sets '*registration', if 'registration' is not
94  // null, to its DeviceRegistration. Returns false and leaves '*registration'
95  // unchanged if no matching JIT device is registered.
96  // '(*registration)->enable_jit_by_default' is true if we should try to JIT
97  // using this device when the JIT is enabled via the Session OptimizerOptions.
98  static bool GetCompilationDevice(const string& device_name,
99  const DeviceRegistration** registration);
100  // Registers all JIT kernels on JIT devices, if not already registered.
101  // Does nothing otherwise.
102  static void RegisterCompilationKernels();
103  // Returns KernelDefs for compilation ops registered on
104  // 'compilation_device_name'. Does not include kernels registered as
105  // CompilationOnly, iff include_compilation_only_kernels=false.
106  static std::vector<const KernelDef*> DeviceKernels(
107  const string& compilation_device_name,
108  bool include_compilation_only_kernels);
109  // Returns the set of compile-time constant inputs to 'op'. Returns nullptr
110  // if the op is not registered.
111  static const std::unordered_set<string>* CompileTimeConstantInputs(
112  const string& op);
113  private:
114  friend class XlaBackendRegistrar;
115  friend class XlaOpRegistrar;
116  friend class XlaOpRegistrationBuilder;
117  static XlaOpRegistry& Instance();
118  XlaOpRegistry();
119  ~XlaOpRegistry();
120  mutex mutex_;
121  // Describes an XLA backend.
122  struct Backend {
123  // Which types are supported by this device?
124  std::set<DataType> supported_types;
125  // The per-backend operator filter function. See the comment on
126  // RegisterBackend() for details.
127  BackendOpFilter op_filter;
128  // KernelDefs built by RegisterCompilationKernels() for each op supported
129  // by the device.
130  std::vector<std::unique_ptr<KernelDef>> kernel_defs;
131  };
132  // Map from compilation device names to a description of the backend.
133  std::unordered_map<string, Backend> backends_ GUARDED_BY(mutex_);
134  // Map from Tensorflow device names to the corresponding JIT device metadata.
135  std::unordered_map<string, DeviceRegistration> compilation_devices_
136  GUARDED_BY(mutex_);
137  // A description of a Tensorflow operator that can be compiled to XLA.
138  struct OpRegistration {
139  string name;
140  // Should this operator be registered only on compilation devices, without a
141  // dummy kernel registered on the corresponding XLA device?
142  bool compilation_only = false;
143  // Should we allow resource types for type attributes? Used by _Arg to
144  // allow DT_RESOURCE.
145  bool allow_resource_types = false;
146  // Mapping from attribute name to a list of supported types.
147  std::unordered_map<string, std::set<DataType>> type_constraints;
148  // An optional whitelist of devices. If there is no whitelist, all devices
149  // are permitted.
150  bool has_device_whitelist = false;
151  std::unordered_set<string> device_whitelist;
152  // Names of arguments that must be compile-time constants.
153  std::unordered_set<string> compile_time_constant_inputs;
154  // Factory used to build OpKernels that perform symbolic execution.
155  Factory factory;
156  };
157  // Returns true if registrations x and y can both be added to the registry.
158  // This is always the case if they refer to different ops. If they refer to
159  // the same op name, they must: have the same values for compilation_only and
160  // allow_resource_types; use a device_whitelist; and their
161  // whitelists must not intersect.
162  static bool IsCompatible(const OpRegistration& x, const OpRegistration& y);
163  // Map from operator name to OpRegistrations, populated by REGISTER_XLA_OP.
164  // Registrations present under the same key must satisfy IsCompatible above,
165  // and this is checked during registration.
166  std::unordered_multimap<string, std::unique_ptr<OpRegistration>> ops_
167  GUARDED_BY(mutex_);
168  // Have we already registered the JIT kernels on the JIT devices?
169  bool jit_kernels_registered_ = false;
170  // Holds ownership of OpKernelRegistrars that represent the Tensorflow kernel
171  // registrations created by RegisterCompilationKernels() and
172  // RegisterDeviceKernels().
173  std::vector<std::unique_ptr<kernel_factory::OpKernelRegistrar>>
174  kernel_registrars_ GUARDED_BY(mutex_);
175 };
176 // REGISTER_XLA_OP() registers an XLA OpKernel by name, for example:
177 // REGISTER_XLA_OP(Name("Add"), AddOp);
178 // where 'AddOp' is the name of a JIT OpKernel class that implements "Add".
179 //
180 // We don't use a variadic macro here because we don't expect JIT operators to
181 // be templated.
182 #define REGISTER_XLA_OP(NAME, OP) \
183  REGISTER_XLA_OP_UNIQ_HELPER(__COUNTER__, NAME, OP)
// Builder for the OpRegistration that describes one XLA operator. It is the
// BUILDER argument of REGISTER_XLA_OP: the macro expands to
// XlaOpRegistrationBuilder::<BUILDER>.Build(factory), so a chain starts with
// Name(...) and may continue with any of the modifier methods below, each of
// which returns a reference to the builder so that calls can be chained.
184 class XlaOpRegistrationBuilder {
185  public:
186  // Starts an operator registration chain.
187  static XlaOpRegistrationBuilder Name(StringPiece name);
188  // Specifies a whitelist of devices on which the operator may run.
189  XlaOpRegistrationBuilder& Device(StringPiece devices);
190  XlaOpRegistrationBuilder& Device(gtl::ArraySlice<StringPiece> devices);
191  // Specifies a type constraint for a type variable attribute. Each constraint
192  // specifies the set of types that the type variable may assume.
193  XlaOpRegistrationBuilder& TypeConstraint(StringPiece attr_name,
194  DataType allowed);
195  XlaOpRegistrationBuilder& TypeConstraint(StringPiece attr_name,
196  gtl::ArraySlice<DataType> allowed);
197  // Specifies that a dummy copy of this operator should not be registered on
198  // XLA_* devices, but may be used during compilation.
199  XlaOpRegistrationBuilder& CompilationOnly();
200  // Allow DT_RESOURCE types for type parameters.
201  XlaOpRegistrationBuilder& AllowResourceTypes();
202  // Mark 'input_name' as an argument whose value must be known at compile-time.
203  XlaOpRegistrationBuilder& CompileTimeConstInput(StringPiece input_name);
     // Ends the chain: takes the kernel factory and returns the registration as
     // a unique_ptr. REGISTER_XLA_OP_UNIQ passes the result directly to an
     // XlaOpRegistrar.
     // NOTE(review): presumably stores `factory` in the registration and hands
     // over registration_, leaving this builder unusable afterwards -- confirm
     // in the implementation file.
204  std::unique_ptr<XlaOpRegistry::OpRegistration> Build(
205  XlaOpRegistry::Factory factory);
206  private:
     // Private: builders are obtained through Name().
207  XlaOpRegistrationBuilder(StringPiece name);
     // The registration being assembled by this builder.
208  std::unique_ptr<XlaOpRegistry::OpRegistration> registration_;
209 };
210 // REGISTER_XLA_BACKEND() registers an XLA backend. Example usage:
211 // REGISTER_XLA_BACKEND(DEVICE_GPU_XLA_JIT, kGpuAllTypes, GpuOpFilter);
212 #define REGISTER_XLA_BACKEND(NAME, ...) \
213  REGISTER_XLA_BACKEND_UNIQ_HELPER(__COUNTER__, NAME, __VA_ARGS__)
214 // Implementation details.
// Helper type instantiated as a static object by REGISTER_XLA_OP_UNIQ. Its
// constructor receives ownership of the OpRegistration produced by
// XlaOpRegistrationBuilder::Build().
// NOTE(review): presumably the constructor adds the registration to
// XlaOpRegistry (this class is declared a friend of XlaOpRegistry) -- confirm
// in the implementation file.
215 class XlaOpRegistrar {
216  public:
217  XlaOpRegistrar(std::unique_ptr<XlaOpRegistry::OpRegistration> registration);
218 };
// REGISTER_XLA_OP_UNIQ_HELPER exists only to add one level of macro expansion:
// REGISTER_XLA_OP passes __COUNTER__ as COUNTER, and it must be expanded to its
// numeric value before REGISTER_XLA_OP_UNIQ pastes it into an identifier with
// ##.
219 #define REGISTER_XLA_OP_UNIQ_HELPER(COUNTER, BUILDER, OP) \
220  REGISTER_XLA_OP_UNIQ(COUNTER, BUILDER, OP)
// Defines a file-static XlaOpRegistrar named after CTR. Its constructor
// argument is the OpRegistration built from BUILDER (e.g. Name("Add")) together
// with a factory lambda that allocates a new OP for a given
// OpKernelConstruction. Every tensorflow name is fully qualified so that the
// macro also works when REGISTER_XLA_OP is used outside namespace tensorflow.
221 #define REGISTER_XLA_OP_UNIQ(CTR, BUILDER, OP) \
222  static ::tensorflow::XlaOpRegistrar xla_op_registrar__body__##CTR##__object( \
223  ::tensorflow::XlaOpRegistrationBuilder::BUILDER.Build( \
224  [](::tensorflow::OpKernelConstruction* context) \
225  -> ::tensorflow::OpKernel* { return new OP(context); }));
// Helper type instantiated as a static object by REGISTER_XLA_BACKEND_UNIQ,
// which forwards the backend name and the remaining macro arguments to this
// constructor. `op_filter` may be omitted and then defaults to nullptr.
// NOTE(review): presumably the constructor forwards its arguments to
// XlaOpRegistry::RegisterBackend (this class is declared a friend of
// XlaOpRegistry) -- confirm in the implementation file.
226 class XlaBackendRegistrar {
227  public:
228  XlaBackendRegistrar(StringPiece name, gtl::ArraySlice<DataType> types,
229  XlaOpRegistry::BackendOpFilter op_filter = nullptr);
230 };
// REGISTER_XLA_BACKEND_UNIQ_HELPER adds one level of macro expansion so that
// the __COUNTER__ passed by REGISTER_XLA_BACKEND is expanded to its numeric
// value before REGISTER_XLA_BACKEND_UNIQ pastes it into an identifier with ##.
231 #define REGISTER_XLA_BACKEND_UNIQ_HELPER(COUNTER, NAME, ...) \
232  REGISTER_XLA_BACKEND_UNIQ(COUNTER, NAME, __VA_ARGS__)
// Defines a file-static XlaBackendRegistrar named after CTR, constructed from
// NAME and the remaining (variadic) macro arguments.
233 #define REGISTER_XLA_BACKEND_UNIQ(CTR, NAME, ...) \
234  static ::tensorflow::XlaBackendRegistrar \
235  xla_backend_registrar__body__##CTR##__object(NAME, __VA_ARGS__);
236 } // namespace tensorflow
237 #endif // TENSORFLOW_COMPILER_TF2XLA_XLA_OP_REGISTRY_H_
238 
static void RegisterCompilationKernels()
Definition: xla_op_registry.cc:153
Definition: xla_op_registry.h:54
Definition: compile.cc:35