1/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#ifndef TENSORFLOW_FRAMEWORK_OP_H_
17#define TENSORFLOW_FRAMEWORK_OP_H_
18
19#include <functional>
20#include <unordered_map>
21
22#include <vector>
23#include "tensorflow/core/framework/op_def_builder.h"
24#include "tensorflow/core/framework/op_def_util.h"
25#include "tensorflow/core/framework/selective_registration.h"
26#include "tensorflow/core/lib/core/errors.h"
27#include "tensorflow/core/lib/core/status.h"
28#include "tensorflow/core/lib/strings/str_util.h"
29#include "tensorflow/core/lib/strings/strcat.h"
30#include "tensorflow/core/platform/logging.h"
31#include "tensorflow/core/platform/macros.h"
32#include "tensorflow/core/platform/mutex.h"
33#include "tensorflow/core/platform/thread_annotations.h"
34#include "tensorflow/core/platform/types.h"
35
36namespace tensorflow {
37
38// Users that want to look up an OpDef by type name should take an
39// OpRegistryInterface. Functions accepting a
40// (const) OpRegistryInterface* may call LookUp() from multiple threads.
41class OpRegistryInterface {
42 public:
43 virtual ~OpRegistryInterface();
44
45 // Returns an error status and sets *op_reg_data to nullptr if no OpDef is
46 // registered under that name, otherwise returns the registered OpDef.
47 // Caller must not delete the returned pointer.
48 virtual Status LookUp(const string& op_type_name,
49 const OpRegistrationData** op_reg_data) const = 0;
50
51 // Shorthand for calling LookUp to get the OpDef.
52 Status LookUpOpDef(const string& op_type_name, const OpDef** op_def) const;
53};
54
55// The standard implementation of OpRegistryInterface, along with a
56// global singleton used for registering ops via the REGISTER
57// macros below. Thread-safe.
58//
59// Example registration:
60// OpRegistry::Global()->Register(
61// [](OpRegistrationData* op_reg_data)->Status {
62// // Populate *op_reg_data here.
63// return Status::OK();
64// });
65class OpRegistry : public OpRegistryInterface {
66 public:
67 typedef std::function<Status(OpRegistrationData*)> OpRegistrationDataFactory;
68
69 OpRegistry();
70 ~OpRegistry() override;
71
72 void Register(const OpRegistrationDataFactory& op_data_factory);
73
74 Status LookUp(const string& op_type_name,
75 const OpRegistrationData** op_reg_data) const override;
76
77 // Fills *ops with all registered OpDefs (except those with names
78 // starting with '_' if include_internal == false) sorted in
79 // ascending alphabetical order.
80 void Export(bool include_internal, OpList* ops) const;
81
82 // Returns ASCII-format OpList for all registered OpDefs (except
83 // those with names starting with '_' if include_internal == false).
84 string DebugString(bool include_internal) const;
85
86 // A singleton available at startup.
87 static OpRegistry* Global();
88
89 // Get all registered ops.
90 void GetRegisteredOps(std::vector<OpDef>* op_defs);
91
92 // Get all `OpRegistrationData`s.
93 void GetOpRegistrationData(std::vector<OpRegistrationData>* op_data);
94
95 // Watcher, a function object.
96 // The watcher, if set by SetWatcher(), is called every time an op is
97 // registered via the Register function. The watcher is passed the Status
98 // obtained from building and adding the OpDef to the registry, and the OpDef
99 // itself if it was successfully built. A watcher returns a Status which is in
100 // turn returned as the final registration status.
101 typedef std::function<Status(const Status&, const OpDef&)> Watcher;
102
103 // An OpRegistry object has only one watcher. This interface is not thread
104 // safe, as different clients are free to set the watcher any time.
105 // Clients are expected to atomically perform the following sequence of
106 // operations :
107 // SetWatcher(a_watcher);
108 // Register some ops;
109 // op_registry->ProcessRegistrations();
110 // SetWatcher(nullptr);
111 // Returns a non-OK status if a non-null watcher is over-written by another
112 // non-null watcher.
113 Status SetWatcher(const Watcher& watcher);
114
115 // Process the current list of deferred registrations. Note that calls to
116 // Export, LookUp and DebugString would also implicitly process the deferred
117 // registrations. Returns the status of the first failed op registration or
118 // Status::OK() otherwise.
119 Status ProcessRegistrations() const;
120
121 // Defer the registrations until a later call to a function that processes
122 // deferred registrations are made. Normally, registrations that happen after
123 // calls to Export, LookUp, ProcessRegistrations and DebugString are processed
124 // immediately. Call this to defer future registrations.
125 void DeferRegistrations();
126
127 // Clear the registrations that have been deferred.
128 void ClearDeferredRegistrations();
129
130 private:
131 // Ensures that all the functions in deferred_ get called, their OpDef's
132 // registered, and returns with deferred_ empty. Returns true the first
133 // time it is called. Prints a fatal log if any op registration fails.
134 bool MustCallDeferred() const EXCLUSIVE_LOCKS_REQUIRED(mu_);
135
136 // Calls the functions in deferred_ and registers their OpDef's
137 // It returns the Status of the first failed op registration or Status::OK()
138 // otherwise.
139 Status CallDeferred() const EXCLUSIVE_LOCKS_REQUIRED(mu_);
140
141 // Add 'def' to the registry with additional data 'data'. On failure, or if
142 // there is already an OpDef with that name registered, returns a non-okay
143 // status.
144 Status RegisterAlreadyLocked(const OpRegistrationDataFactory& op_data_factory)
145 const EXCLUSIVE_LOCKS_REQUIRED(mu_);
146
147 mutable mutex mu_;
148 // Functions in deferred_ may only be called with mu_ held.
149 mutable std::vector<OpRegistrationDataFactory> deferred_ GUARDED_BY(mu_);
150 // Values are owned.
151 mutable std::unordered_map<string, const OpRegistrationData*> registry_
152 GUARDED_BY(mu_);
153 mutable bool initialized_ GUARDED_BY(mu_);
154
155 // Registry watcher.
156 mutable Watcher watcher_ GUARDED_BY(mu_);
157};
158
159// An adapter to allow an OpList to be used as an OpRegistryInterface.
160//
161// Note that shape inference functions are not passed in to OpListOpRegistry, so
162// it will return an unusable shape inference function for every op it supports;
163// therefore, it should only be used in contexts where this is okay.
164class OpListOpRegistry : public OpRegistryInterface {
165 public:
166 // Does not take ownership of op_list, *op_list must outlive *this.
167 OpListOpRegistry(const OpList* op_list);
168 ~OpListOpRegistry() override;
169 Status LookUp(const string& op_type_name,
170 const OpRegistrationData** op_reg_data) const override;
171
172 private:
173 // Values are owned.
174 std::unordered_map<string, const OpRegistrationData*> index_;
175};
176
177// Support for defining the OpDef (specifying the semantics of the Op and how
178// it should be created) and registering it in the OpRegistry::Global()
179// registry. Usage:
180//
181// REGISTER_OP("my_op_name")
182// .Attr("<name>:<type>")
183// .Attr("<name>:<type>=<default>")
184// .Input("<name>:<type-expr>")
185// .Input("<name>:Ref(<type-expr>)")
186// .Output("<name>:<type-expr>")
187// .Doc(R"(
188// <1-line summary>
189// <rest of the description (potentially many lines)>
190// <name-of-attr-input-or-output>: <description of name>
191// <name-of-attr-input-or-output>: <description of name;
192// if long, indent the description on subsequent lines>
193// )");
194//
195// Note: .Doc() should be last.
196// For details, see the OpDefBuilder class in op_def_builder.h.
197
198namespace register_op {
199
200// OpDefBuilderWrapper is a templated class that is used in the REGISTER_OP
201// calls. This allows the result of REGISTER_OP to be used in chaining, as in
202// REGISTER_OP(a).Attr("...").Input("...");, while still allowing selective
203// registration to turn the entire call-chain into a no-op.
204template <bool should_register>
205class OpDefBuilderWrapper;
206
207// Template specialization that forwards all calls to the contained builder.
208template <>
209class OpDefBuilderWrapper<true> {
210 public:
211 OpDefBuilderWrapper(const char name[]) : builder_(name) {}
212 OpDefBuilderWrapper<true>& Attr(StringPiece spec) {
213 builder_.Attr(spec);
214 return *this;
215 }
216 OpDefBuilderWrapper<true>& Input(StringPiece spec) {
217 builder_.Input(spec);
218 return *this;
219 }
220 OpDefBuilderWrapper<true>& Output(StringPiece spec) {
221 builder_.Output(spec);
222 return *this;
223 }
224 OpDefBuilderWrapper<true>& SetIsCommutative() {
225 builder_.SetIsCommutative();
226 return *this;
227 }
228 OpDefBuilderWrapper<true>& SetIsAggregate() {
229 builder_.SetIsAggregate();
230 return *this;
231 }
232 OpDefBuilderWrapper<true>& SetIsStateful() {
233 builder_.SetIsStateful();
234 return *this;
235 }
236 OpDefBuilderWrapper<true>& SetAllowsUninitializedInput() {
237 builder_.SetAllowsUninitializedInput();
238 return *this;
239 }
240 OpDefBuilderWrapper<true>& Deprecated(int version, StringPiece explanation) {
241 builder_.Deprecated(version, explanation);
242 return *this;
243 }
244 OpDefBuilderWrapper<true>& Doc(StringPiece text) {
245 builder_.Doc(text);
246 return *this;
247 }
248 OpDefBuilderWrapper<true>& SetShapeFn(
249 Status (*fn)(shape_inference::InferenceContext*)) {
250 builder_.SetShapeFn(fn);
251 return *this;
252 }
253 const ::tensorflow::OpDefBuilder& builder() const { return builder_; }
254
255 private:
256 mutable ::tensorflow::OpDefBuilder builder_;
257};
258
259// Template specialization that turns all calls into no-ops.
260template <>
261class OpDefBuilderWrapper<false> {
262 public:
263 constexpr OpDefBuilderWrapper(const char name[]) {}
264 OpDefBuilderWrapper<false>& Attr(StringPiece spec) { return *this; }
265 OpDefBuilderWrapper<false>& Input(StringPiece spec) { return *this; }
266 OpDefBuilderWrapper<false>& Output(StringPiece spec) { return *this; }
267 OpDefBuilderWrapper<false>& SetIsCommutative() { return *this; }
268 OpDefBuilderWrapper<false>& SetIsAggregate() { return *this; }
269 OpDefBuilderWrapper<false>& SetIsStateful() { return *this; }
270 OpDefBuilderWrapper<false>& SetAllowsUninitializedInput() { return *this; }
271 OpDefBuilderWrapper<false>& Deprecated(int, StringPiece) { return *this; }
272 OpDefBuilderWrapper<false>& Doc(StringPiece text) { return *this; }
273 OpDefBuilderWrapper<false>& SetShapeFn(
274 Status (*fn)(shape_inference::InferenceContext*)) {
275 return *this;
276 }
277};
278
279struct OpDefBuilderReceiver {
280 // To call OpRegistry::Global()->Register(...), used by the
281 // REGISTER_OP macro below.
282 // Note: These are implicitly converting constructors.
283 OpDefBuilderReceiver(
284 const OpDefBuilderWrapper<true>& wrapper); // NOLINT(runtime/explicit)
285 constexpr OpDefBuilderReceiver(const OpDefBuilderWrapper<false>&) {
286 } // NOLINT(runtime/explicit)
287};
288} // namespace register_op
289
// REGISTER_OP(name) declares a file-local OpDefBuilderReceiver initialized
// from an OpDefBuilderWrapper for `name`; SHOULD_REGISTER_OP(name) picks the
// wrapper specialization. The extra helper level makes sure __COUNTER__ is
// expanded before it is pasted onto the variable name.
#define REGISTER_OP(op_name) REGISTER_OP_UNIQ_HELPER(__COUNTER__, op_name)
#define REGISTER_OP_UNIQ_HELPER(counter, op_name) \
  REGISTER_OP_UNIQ(counter, op_name)
#define REGISTER_OP_UNIQ(counter, op_name) \
  static ::tensorflow::register_op::OpDefBuilderReceiver \
      register_op##counter TF_ATTRIBUTE_UNUSED = \
          ::tensorflow::register_op::OpDefBuilderWrapper<SHOULD_REGISTER_OP( \
              op_name)>(op_name)
297
// `REGISTER_SYSTEM_OP()` behaves like `REGISTER_OP()`, except that the op is
// registered unconditionally, even when selective registration is in use: it
// always instantiates OpDefBuilderWrapper<true>.
#define REGISTER_SYSTEM_OP(op_name) \
  REGISTER_SYSTEM_OP_UNIQ_HELPER(__COUNTER__, op_name)
#define REGISTER_SYSTEM_OP_UNIQ_HELPER(counter, op_name) \
  REGISTER_SYSTEM_OP_UNIQ(counter, op_name)
#define REGISTER_SYSTEM_OP_UNIQ(counter, op_name) \
  static ::tensorflow::register_op::OpDefBuilderReceiver \
      register_op##counter TF_ATTRIBUTE_UNUSED = \
          ::tensorflow::register_op::OpDefBuilderWrapper<true>(op_name)
309
310} // namespace tensorflow
311
312#endif // TENSORFLOW_FRAMEWORK_OP_H_
313