| 1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
| 2 | |
| 3 | Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | you may not use this file except in compliance with the License. |
| 5 | You may obtain a copy of the License at |
| 6 | |
| 7 | http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | |
| 9 | Unless required by applicable law or agreed to in writing, software |
| 10 | distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | See the License for the specific language governing permissions and |
| 13 | limitations under the License. |
| 14 | ==============================================================================*/ |
| 15 | |
| 16 | #ifndef TENSORFLOW_FRAMEWORK_OP_H_ |
| 17 | #define TENSORFLOW_FRAMEWORK_OP_H_ |
| 18 | |
| 19 | #include <functional> |
| 20 | #include <unordered_map> |
| 21 | |
| 22 | #include <vector> |
| 23 | #include "tensorflow/core/framework/op_def_builder.h" |
| 24 | #include "tensorflow/core/framework/op_def_util.h" |
| 25 | #include "tensorflow/core/framework/selective_registration.h" |
| 26 | #include "tensorflow/core/lib/core/errors.h" |
| 27 | #include "tensorflow/core/lib/core/status.h" |
| 28 | #include "tensorflow/core/lib/strings/str_util.h" |
| 29 | #include "tensorflow/core/lib/strings/strcat.h" |
| 30 | #include "tensorflow/core/platform/logging.h" |
| 31 | #include "tensorflow/core/platform/macros.h" |
| 32 | #include "tensorflow/core/platform/mutex.h" |
| 33 | #include "tensorflow/core/platform/thread_annotations.h" |
| 34 | #include "tensorflow/core/platform/types.h" |
| 35 | |
| 36 | namespace tensorflow { |
| 37 | |
| 38 | // Users that want to look up an OpDef by type name should take an |
| 39 | // OpRegistryInterface. Functions accepting a |
| 40 | // (const) OpRegistryInterface* may call LookUp() from multiple threads. |
| 41 | class OpRegistryInterface { |
| 42 | public: |
| 43 | virtual ~OpRegistryInterface(); |
| 44 | |
| 45 | // Returns an error status and sets *op_reg_data to nullptr if no OpDef is |
| 46 | // registered under that name, otherwise returns the registered OpDef. |
| 47 | // Caller must not delete the returned pointer. |
| 48 | virtual Status LookUp(const string& op_type_name, |
| 49 | const OpRegistrationData** op_reg_data) const = 0; |
| 50 | |
| 51 | // Shorthand for calling LookUp to get the OpDef. |
| 52 | Status LookUpOpDef(const string& op_type_name, const OpDef** op_def) const; |
| 53 | }; |
| 54 | |
| 55 | // The standard implementation of OpRegistryInterface, along with a |
| 56 | // global singleton used for registering ops via the REGISTER |
| 57 | // macros below. Thread-safe. |
| 58 | // |
| 59 | // Example registration: |
| 60 | // OpRegistry::Global()->Register( |
| 61 | // [](OpRegistrationData* op_reg_data)->Status { |
| 62 | // // Populate *op_reg_data here. |
| 63 | // return Status::OK(); |
| 64 | // }); |
| 65 | class OpRegistry : public OpRegistryInterface { |
| 66 | public: |
| 67 | typedef std::function<Status(OpRegistrationData*)> OpRegistrationDataFactory; |
| 68 | |
| 69 | OpRegistry(); |
| 70 | ~OpRegistry() override; |
| 71 | |
| 72 | void Register(const OpRegistrationDataFactory& op_data_factory); |
| 73 | |
| 74 | Status LookUp(const string& op_type_name, |
| 75 | const OpRegistrationData** op_reg_data) const override; |
| 76 | |
| 77 | // Fills *ops with all registered OpDefs (except those with names |
| 78 | // starting with '_' if include_internal == false) sorted in |
| 79 | // ascending alphabetical order. |
| 80 | void Export(bool include_internal, OpList* ops) const; |
| 81 | |
| 82 | // Returns ASCII-format OpList for all registered OpDefs (except |
| 83 | // those with names starting with '_' if include_internal == false). |
| 84 | string DebugString(bool include_internal) const; |
| 85 | |
| 86 | // A singleton available at startup. |
| 87 | static OpRegistry* Global(); |
| 88 | |
| 89 | // Get all registered ops. |
| 90 | void GetRegisteredOps(std::vector<OpDef>* op_defs); |
| 91 | |
| 92 | // Get all `OpRegistrationData`s. |
| 93 | void GetOpRegistrationData(std::vector<OpRegistrationData>* op_data); |
| 94 | |
| 95 | // Watcher, a function object. |
| 96 | // The watcher, if set by SetWatcher(), is called every time an op is |
| 97 | // registered via the Register function. The watcher is passed the Status |
| 98 | // obtained from building and adding the OpDef to the registry, and the OpDef |
| 99 | // itself if it was successfully built. A watcher returns a Status which is in |
| 100 | // turn returned as the final registration status. |
| 101 | typedef std::function<Status(const Status&, const OpDef&)> Watcher; |
| 102 | |
| 103 | // An OpRegistry object has only one watcher. This interface is not thread |
| 104 | // safe, as different clients are free to set the watcher any time. |
| 105 | // Clients are expected to atomically perform the following sequence of |
| 106 | // operations : |
| 107 | // SetWatcher(a_watcher); |
| 108 | // Register some ops; |
| 109 | // op_registry->ProcessRegistrations(); |
| 110 | // SetWatcher(nullptr); |
| 111 | // Returns a non-OK status if a non-null watcher is over-written by another |
| 112 | // non-null watcher. |
| 113 | Status SetWatcher(const Watcher& watcher); |
| 114 | |
| 115 | // Process the current list of deferred registrations. Note that calls to |
| 116 | // Export, LookUp and DebugString would also implicitly process the deferred |
| 117 | // registrations. Returns the status of the first failed op registration or |
| 118 | // Status::OK() otherwise. |
| 119 | Status ProcessRegistrations() const; |
| 120 | |
| 121 | // Defer the registrations until a later call to a function that processes |
| 122 | // deferred registrations are made. Normally, registrations that happen after |
| 123 | // calls to Export, LookUp, ProcessRegistrations and DebugString are processed |
| 124 | // immediately. Call this to defer future registrations. |
| 125 | void DeferRegistrations(); |
| 126 | |
| 127 | // Clear the registrations that have been deferred. |
| 128 | void ClearDeferredRegistrations(); |
| 129 | |
| 130 | private: |
| 131 | // Ensures that all the functions in deferred_ get called, their OpDef's |
| 132 | // registered, and returns with deferred_ empty. Returns true the first |
| 133 | // time it is called. Prints a fatal log if any op registration fails. |
| 134 | bool MustCallDeferred() const EXCLUSIVE_LOCKS_REQUIRED(mu_); |
| 135 | |
| 136 | // Calls the functions in deferred_ and registers their OpDef's |
| 137 | // It returns the Status of the first failed op registration or Status::OK() |
| 138 | // otherwise. |
| 139 | Status CallDeferred() const EXCLUSIVE_LOCKS_REQUIRED(mu_); |
| 140 | |
| 141 | // Add 'def' to the registry with additional data 'data'. On failure, or if |
| 142 | // there is already an OpDef with that name registered, returns a non-okay |
| 143 | // status. |
| 144 | Status RegisterAlreadyLocked(const OpRegistrationDataFactory& op_data_factory) |
| 145 | const EXCLUSIVE_LOCKS_REQUIRED(mu_); |
| 146 | |
| 147 | mutable mutex mu_; |
| 148 | // Functions in deferred_ may only be called with mu_ held. |
| 149 | mutable std::vector<OpRegistrationDataFactory> deferred_ GUARDED_BY(mu_); |
| 150 | // Values are owned. |
| 151 | mutable std::unordered_map<string, const OpRegistrationData*> registry_ |
| 152 | GUARDED_BY(mu_); |
| 153 | mutable bool initialized_ GUARDED_BY(mu_); |
| 154 | |
| 155 | // Registry watcher. |
| 156 | mutable Watcher watcher_ GUARDED_BY(mu_); |
| 157 | }; |
| 158 | |
| 159 | // An adapter to allow an OpList to be used as an OpRegistryInterface. |
| 160 | // |
| 161 | // Note that shape inference functions are not passed in to OpListOpRegistry, so |
| 162 | // it will return an unusable shape inference function for every op it supports; |
| 163 | // therefore, it should only be used in contexts where this is okay. |
| 164 | class OpListOpRegistry : public OpRegistryInterface { |
| 165 | public: |
| 166 | // Does not take ownership of op_list, *op_list must outlive *this. |
| 167 | OpListOpRegistry(const OpList* op_list); |
| 168 | ~OpListOpRegistry() override; |
| 169 | Status LookUp(const string& op_type_name, |
| 170 | const OpRegistrationData** op_reg_data) const override; |
| 171 | |
| 172 | private: |
| 173 | // Values are owned. |
| 174 | std::unordered_map<string, const OpRegistrationData*> index_; |
| 175 | }; |
| 176 | |
| 177 | // Support for defining the OpDef (specifying the semantics of the Op and how |
| 178 | // it should be created) and registering it in the OpRegistry::Global() |
| 179 | // registry. Usage: |
| 180 | // |
| 181 | // REGISTER_OP("my_op_name") |
| 182 | // .Attr("<name>:<type>") |
| 183 | // .Attr("<name>:<type>=<default>") |
| 184 | // .Input("<name>:<type-expr>") |
| 185 | // .Input("<name>:Ref(<type-expr>)") |
| 186 | // .Output("<name>:<type-expr>") |
| 187 | // .Doc(R"( |
| 188 | // <1-line summary> |
| 189 | // <rest of the description (potentially many lines)> |
| 190 | // <name-of-attr-input-or-output>: <description of name> |
| 191 | // <name-of-attr-input-or-output>: <description of name; |
| 192 | // if long, indent the description on subsequent lines> |
| 193 | // )"); |
| 194 | // |
| 195 | // Note: .Doc() should be last. |
| 196 | // For details, see the OpDefBuilder class in op_def_builder.h. |
| 197 | |
| 198 | namespace register_op { |
| 199 | |
// Primary template behind REGISTER_OP's builder chain. It is only declared
// here; the bool argument picks one of the explicit specializations that
// follow (`true` forwards every call to a real OpDefBuilder, `false` ignores
// every call). Because REGISTER_OP feeds SHOULD_REGISTER_OP(name) into this
// argument, selective registration can reduce a whole
// REGISTER_OP(a).Attr("...").Input("...") chain to a no-op while the chained
// syntax keeps compiling.
template <bool kShouldRegister>
class OpDefBuilderWrapper;
| 206 | |
| 207 | // Template specialization that forwards all calls to the contained builder. |
| 208 | template <> |
| 209 | class OpDefBuilderWrapper<true> { |
| 210 | public: |
| 211 | OpDefBuilderWrapper(const char name[]) : builder_(name) {} |
| 212 | OpDefBuilderWrapper<true>& Attr(StringPiece spec) { |
| 213 | builder_.Attr(spec); |
| 214 | return *this; |
| 215 | } |
| 216 | OpDefBuilderWrapper<true>& Input(StringPiece spec) { |
| 217 | builder_.Input(spec); |
| 218 | return *this; |
| 219 | } |
| 220 | OpDefBuilderWrapper<true>& Output(StringPiece spec) { |
| 221 | builder_.Output(spec); |
| 222 | return *this; |
| 223 | } |
| 224 | OpDefBuilderWrapper<true>& SetIsCommutative() { |
| 225 | builder_.SetIsCommutative(); |
| 226 | return *this; |
| 227 | } |
| 228 | OpDefBuilderWrapper<true>& SetIsAggregate() { |
| 229 | builder_.SetIsAggregate(); |
| 230 | return *this; |
| 231 | } |
| 232 | OpDefBuilderWrapper<true>& SetIsStateful() { |
| 233 | builder_.SetIsStateful(); |
| 234 | return *this; |
| 235 | } |
| 236 | OpDefBuilderWrapper<true>& SetAllowsUninitializedInput() { |
| 237 | builder_.SetAllowsUninitializedInput(); |
| 238 | return *this; |
| 239 | } |
| 240 | OpDefBuilderWrapper<true>& Deprecated(int version, StringPiece explanation) { |
| 241 | builder_.Deprecated(version, explanation); |
| 242 | return *this; |
| 243 | } |
| 244 | OpDefBuilderWrapper<true>& Doc(StringPiece text) { |
| 245 | builder_.Doc(text); |
| 246 | return *this; |
| 247 | } |
| 248 | OpDefBuilderWrapper<true>& SetShapeFn( |
| 249 | Status (*fn)(shape_inference::InferenceContext*)) { |
| 250 | builder_.SetShapeFn(fn); |
| 251 | return *this; |
| 252 | } |
| 253 | const ::tensorflow::OpDefBuilder& builder() const { return builder_; } |
| 254 | |
| 255 | private: |
| 256 | mutable ::tensorflow::OpDefBuilder builder_; |
| 257 | }; |
| 258 | |
| 259 | // Template specialization that turns all calls into no-ops. |
| 260 | template <> |
| 261 | class OpDefBuilderWrapper<false> { |
| 262 | public: |
| 263 | constexpr OpDefBuilderWrapper(const char name[]) {} |
| 264 | OpDefBuilderWrapper<false>& Attr(StringPiece spec) { return *this; } |
| 265 | OpDefBuilderWrapper<false>& Input(StringPiece spec) { return *this; } |
| 266 | OpDefBuilderWrapper<false>& Output(StringPiece spec) { return *this; } |
| 267 | OpDefBuilderWrapper<false>& SetIsCommutative() { return *this; } |
| 268 | OpDefBuilderWrapper<false>& SetIsAggregate() { return *this; } |
| 269 | OpDefBuilderWrapper<false>& SetIsStateful() { return *this; } |
| 270 | OpDefBuilderWrapper<false>& SetAllowsUninitializedInput() { return *this; } |
| 271 | OpDefBuilderWrapper<false>& Deprecated(int, StringPiece) { return *this; } |
| 272 | OpDefBuilderWrapper<false>& Doc(StringPiece text) { return *this; } |
| 273 | OpDefBuilderWrapper<false>& SetShapeFn( |
| 274 | Status (*fn)(shape_inference::InferenceContext*)) { |
| 275 | return *this; |
| 276 | } |
| 277 | }; |
| 278 | |
| 279 | struct OpDefBuilderReceiver { |
| 280 | // To call OpRegistry::Global()->Register(...), used by the |
| 281 | // REGISTER_OP macro below. |
| 282 | // Note: These are implicitly converting constructors. |
| 283 | OpDefBuilderReceiver( |
| 284 | const OpDefBuilderWrapper<true>& wrapper); // NOLINT(runtime/explicit) |
| 285 | constexpr OpDefBuilderReceiver(const OpDefBuilderWrapper<false>&) { |
| 286 | } // NOLINT(runtime/explicit) |
| 287 | }; |
| 288 | } // namespace register_op |
| 289 | |
// REGISTER_OP(name) defines a static OpDefBuilderReceiver whose name is made
// unique with __COUNTER__. The counter passes through an extra helper macro
// because `##` suppresses macro expansion of its operands; the extra level
// makes __COUNTER__ expand to a number before it is pasted onto
// `register_op`. SHOULD_REGISTER_OP(name) decides at compile time whether
// the real or the no-op OpDefBuilderWrapper is used.
#define REGISTER_OP(op_name) REGISTER_OP_UNIQ_HELPER(__COUNTER__, op_name)
#define REGISTER_OP_UNIQ_HELPER(counter, op_name) \
  REGISTER_OP_UNIQ(counter, op_name)
#define REGISTER_OP_UNIQ(counter, op_name)               \
  static ::tensorflow::register_op::OpDefBuilderReceiver \
      register_op##counter TF_ATTRIBUTE_UNUSED =         \
          ::tensorflow::register_op::OpDefBuilderWrapper< \
              SHOULD_REGISTER_OP(op_name)>(op_name)
| 297 | |
// REGISTER_SYSTEM_OP(name) works like REGISTER_OP(name), except that the
// wrapper is fixed to OpDefBuilderWrapper<true>: the op is registered
// unconditionally, even when selective registration is in effect. It uses
// the same `register_op<N>` variable prefix as REGISTER_OP; uniqueness comes
// from the shared __COUNTER__.
#define REGISTER_SYSTEM_OP(op_name) \
  REGISTER_SYSTEM_OP_UNIQ_HELPER(__COUNTER__, op_name)
#define REGISTER_SYSTEM_OP_UNIQ_HELPER(counter, op_name) \
  REGISTER_SYSTEM_OP_UNIQ(counter, op_name)
#define REGISTER_SYSTEM_OP_UNIQ(counter, op_name)               \
  static ::tensorflow::register_op::OpDefBuilderReceiver        \
      register_op##counter TF_ATTRIBUTE_UNUSED =                \
          ::tensorflow::register_op::OpDefBuilderWrapper<true>( \
              op_name)
| 309 | |
| 310 | } // namespace tensorflow |
| 311 | |
| 312 | #endif // TENSORFLOW_FRAMEWORK_OP_H_ |
| 313 | |