1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #ifndef TENSORFLOW_FRAMEWORK_OP_KERNEL_H_ |
17 | #define TENSORFLOW_FRAMEWORK_OP_KERNEL_H_ |
18 | |
#include <functional>
#include <utility>
#include <vector>
23 | #include "tensorflow/core/framework/allocator.h" |
24 | #include "tensorflow/core/framework/cancellation.h" |
25 | #include "tensorflow/core/framework/control_flow.h" |
26 | #include "tensorflow/core/framework/device_base.h" |
27 | #include "tensorflow/core/framework/kernel_def_builder.h" |
28 | #include "tensorflow/core/framework/node_def_util.h" |
29 | #include "tensorflow/core/framework/op.h" // TODO(b/62899350): Remove |
30 | #include "tensorflow/core/framework/rendezvous.h" |
31 | #include "tensorflow/core/framework/selective_registration.h" |
32 | #include "tensorflow/core/framework/session_state.h" |
33 | #include "tensorflow/core/framework/tensor.h" |
34 | #include "tensorflow/core/framework/tensor_shape.h" |
35 | #include "tensorflow/core/framework/tensor_shape.pb.h" // TODO(b/62899350): Remove |
36 | #include "tensorflow/core/framework/tracking_allocator.h" |
37 | #include "tensorflow/core/framework/types.h" |
38 | #include "tensorflow/core/framework/types.pb.h" |
39 | #include "tensorflow/core/framework/unique_tensor_references.h" |
40 | #include "tensorflow/core/lib/core/errors.h" |
41 | #include "tensorflow/core/lib/core/status.h" |
42 | #include "tensorflow/core/lib/gtl/array_slice.h" |
43 | #include "tensorflow/core/lib/gtl/manual_constructor.h" |
44 | #include "tensorflow/core/platform/env.h" |
45 | #include "tensorflow/core/platform/logging.h" |
46 | #include "tensorflow/core/platform/macros.h" |
47 | #include "tensorflow/core/platform/mutex.h" |
48 | #include "tensorflow/core/platform/thread_annotations.h" |
49 | #include "tensorflow/core/platform/types.h" |
50 | |
51 | namespace Eigen { |
52 | struct ThreadPoolDevice; |
53 | struct GpuDevice; |
54 | struct SyclDevice; |
55 | } // end namespace Eigen |
56 | |
57 | namespace tensorflow { |
58 | |
59 | namespace checkpoint { |
60 | class TensorSliceReaderCacheWrapper; |
61 | } // namespace checkpoint |
62 | |
63 | class AsyncOpKernel; |
64 | class CallFrameInterface; |
65 | class FunctionLibraryRuntime; |
66 | class OpKernelConstruction; // declared below |
class OpKernelContext;      // declared below
68 | class OpRegistryInterface; |
69 | class ResourceMgr; |
70 | class ScopedStepContainer; |
71 | class CollectiveExecutor; |
72 | class StepStatsCollector; |
73 | |
74 | class OpKernel { |
75 | public: |
76 | // OpKernel won't be instantiated by the scheduler, so you may perform |
77 | // expensive initialization in the descendant's constructor. |
78 | explicit OpKernel(OpKernelConstruction* context); |
79 | |
80 | // Specialized constructor that enables the descendant to provide a different |
81 | // `NodeDef` value. For example, this constructor can be used to provide a |
82 | // stripped-down `NodeDef` that does not contain the full set of attrs (such |
83 | // as tensor values) if the descendant stores them in a different form. |
84 | explicit OpKernel(OpKernelConstruction* context, |
85 | std::unique_ptr<const NodeDef> node_def); |
86 | |
87 | virtual ~OpKernel(); |
88 | |
89 | // An OpKernel's computation can be either synchronous or |
90 | // asynchronous. All OpKernel Compute() methods must be thread-safe as they |
91 | // may be called concurrently (e.g. by multiple executions of the same graph |
92 | // concurrently). |
93 | // |
94 | // Most OpKernels should compute synchronously. They should |
95 | // subclass OpKernel and override the Compute() method and have it |
96 | // return after completing the supplied work. |
97 | // |
98 | // A few special kernels might need to be asynchronous to bound the |
99 | // number of threads (e.g., network receive operations). These |
100 | // kernels must subclass AsyncOpKernel and override |
101 | // AsyncOpKernel::ComputeAsync(). |
102 | // |
103 | // In both cases, implementations of Compute() and ComputeAsync() |
104 | // get inputs and write outputs through the given OpKernelContext |
  // and return a status via context->SetStatus(). They must be
106 | // thread-safe. |
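  //
  // For illustration only, a minimal synchronous kernel might look like the
  // following sketch (MyAddOneOp is a hypothetical name, not part of this
  // header; OP_REQUIRES_OK is the convenience macro declared later in this
  // header):
  //
  //   class MyAddOneOp : public OpKernel {
  //    public:
  //     explicit MyAddOneOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
  //     void Compute(OpKernelContext* ctx) override {
  //       const Tensor& input = ctx->input(0);
  //       Tensor* output = nullptr;
  //       OP_REQUIRES_OK(ctx,
  //                      ctx->allocate_output(0, input.shape(), &output));
  //       auto in = input.flat<float>();
  //       auto out = output->flat<float>();
  //       for (int64 i = 0; i < in.size(); ++i) out(i) = in(i) + 1.0f;
  //     }
  //   };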
107 | |
108 | // Synchronous compute. |
109 | // |
110 | // "context" is guaranteed to be alive until Compute() returns. |
111 | virtual void Compute(OpKernelContext* context) = 0; |
112 | |
113 | // Returns nullptr iff this op kernel is synchronous. |
114 | virtual AsyncOpKernel* AsAsync() { return nullptr; } |
115 | |
  // Returns true iff this op kernel is considered "expensive". The
  // runtime may use this flag to optimize graph execution, for example
  // by "inlining" inexpensive kernels.
119 | virtual bool IsExpensive() { return expensive_; } |
120 | |
121 | // Accessors. |
122 | const NodeDef& def() const { return *def_; } |
123 | const string& name() const; // Same as def().name() |
124 | const string& type_string() const; // Same as def().op() |
125 | const string& requested_device() const; // Same as def().device() |
126 | bool is_internal() const { return is_internal_; } |
127 | |
128 | int num_inputs() const { return input_types_.size(); } |
129 | DataType input_type(int i) const { return input_types_[i]; } |
130 | const DataTypeVector& input_types() const { return input_types_; } |
131 | const MemoryTypeVector& input_memory_types() const { |
132 | return input_memory_types_; |
133 | } |
134 | const string& requested_input(int i) const; // Same as def().input(i) |
135 | |
136 | int num_outputs() const { return output_types_.size(); } |
137 | DataType output_type(int o) const { return output_types_[o]; } |
138 | const DataTypeVector& output_types() const { return output_types_; } |
139 | const MemoryTypeVector& output_memory_types() const { |
140 | return output_memory_types_; |
141 | } |
142 | |
143 | Status InputRange(StringPiece input_name, int* start, int* stop) const; |
144 | Status OutputRange(StringPiece output_name, int* start, int* stop) const; |
145 | |
146 | // We allow legacy scalars within Google up until GraphDef version 6. |
147 | // TODO(irving): Remove when we can drop support for GraphDef version 5. |
148 | bool allow_legacy_scalars() const { |
149 | #if defined(PLATFORM_GOOGLE) || defined(PLATFORM_GOOGLE_ANDROID) |
150 | return graph_def_version_ < 6; |
151 | #else |
152 | return false; |
153 | #endif |
154 | } |
155 | |
156 | // Allow either scalars or (if allowing legacy scalars) shape (1,). |
157 | bool IsLegacyScalar(const TensorShape& shape) const { |
158 | return shape.dims() == 0 || (allow_legacy_scalars() && shape.dims() == 1 && |
159 | shape.dim_size(0) == 1); |
160 | } |
161 | |
162 | // Allow rank 1 or (if allowing legacy scalars) rank 0. |
163 | bool IsLegacyVector(const TensorShape& shape) const { |
164 | return shape.dims() == 1 || (allow_legacy_scalars() && shape.dims() == 0); |
165 | } |
166 | |
167 | // Turn a shape Tensor into a TensorShape |
168 | // TODO(irving): Move to TensorShapeUtils once !allow_legacy_scalars |
169 | Status MakeShape(const Tensor& shape, TensorShape* out) const; |
170 | |
171 | private: |
172 | const std::unique_ptr<const NodeDef> def_; |
173 | const DataTypeVector input_types_; |
174 | const MemoryTypeVector input_memory_types_; |
175 | const DataTypeVector output_types_; |
176 | const MemoryTypeVector output_memory_types_; |
177 | const int graph_def_version_; |
178 | const bool is_internal_; // True if this is an internal operation |
179 | NameRangeMap input_name_map_; |
180 | NameRangeMap output_name_map_; |
181 | bool expensive_; |
182 | |
183 | TF_DISALLOW_COPY_AND_ASSIGN(OpKernel); |
184 | }; |
185 | |
186 | class AsyncOpKernel : public OpKernel { |
187 | public: |
188 | using OpKernel::OpKernel; // Lift OpKernel constructors. |
189 | |
190 | // Asynchronous compute. |
191 | // |
192 | // Implementations of ComputeAsync() must run "done" to signal the |
193 | // completion of the computation. "context" is guaranteed to be |
194 | // alive until the "done" callback starts. |
195 | typedef std::function<void()> DoneCallback; |
196 | virtual void ComputeAsync(OpKernelContext* context, DoneCallback done) = 0; |
197 | |
198 | AsyncOpKernel* AsAsync() final { return this; } |
199 | |
200 | void Compute(OpKernelContext* context) final; |
201 | |
202 | bool IsExpensive() override { return true; } |
203 | }; |
204 | |
205 | // Wraps a tensor that is held by an Op across calls to Compute(). For |
206 | // memory safety when using asynchronous devices like GPUs, the system |
207 | // must be notified when a Tensor is used inside an Op execution. The |
208 | // wrapper ensures that all uses of the Tensor are tracked, because in |
209 | // order to retrieve the Tensor the caller must use AccessTensor which |
210 | // notifies the context. |
211 | class PersistentTensor { |
212 | public: |
213 | PersistentTensor() {} |
214 | explicit PersistentTensor(const Tensor& tensor) : tensor_(tensor) {} |
215 | |
216 | // Caller does not own the returned Tensor*. |
217 | Tensor* AccessTensor(OpKernelConstruction* context); |
218 | // Caller does not own the returned Tensor*. |
219 | Tensor* AccessTensor(OpKernelContext* context); |
220 | |
221 | // The check for initialization does not need to access the |
222 | // underlying tensor buffer. |
223 | bool IsInitialized() const { return tensor_.IsInitialized(); } |
224 | |
225 | int64 NumElements() const { return tensor_.NumElements(); } |
226 | |
227 | int64 AllocatedBytes() const { return tensor_.AllocatedBytes(); } |
228 | |
229 | private: |
230 | Tensor tensor_; |
231 | }; |
232 | |
233 | class OpKernelConstruction { |
234 | public: |
235 | OpKernelConstruction(DeviceType device_type, DeviceBase* device, |
236 | Allocator* allocator, const NodeDef* node_def, |
237 | const OpDef* op_def, FunctionLibraryRuntime* flib, |
238 | const DataTypeSlice& input_types, |
239 | const MemoryTypeSlice& input_memory_types, |
240 | const DataTypeSlice& output_types, |
241 | const MemoryTypeSlice& output_memory_types, |
242 | int graph_def_version, Status* status); |
243 | |
244 | Env* env() const { return device_->env(); } |
245 | |
246 | // Allocation of tensors during kernel construction: |
247 | // |
248 | // It is legal to temporarily allocate scratch tensor storage during |
249 | // Op kernel construction. Scratch tensors should be allocated using |
250 | // allocate_temp below. Some kernels need to keep tensors in between |
251 | // invocations. If such a Tensor is allocated during kernel |
252 | // construction this must be done using allocate_persistent, and the |
253 | // Op may only store the returned PersistentTensor object. When the |
254 | // Tensor is needed in a subsequent invocation, it can be retrieved |
255 | // from the PersistentTensor using the AccessTensor method. This |
256 | // ensures that the system is made aware of any use of the tensor's |
257 | // allocated memory, which is needed for correctness on asynchronous |
258 | // devices such as GPUs. |
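  //
  // A sketch of the intended pattern (MyStatefulOp is a hypothetical name;
  // OP_REQUIRES_OK is the macro declared later in this header):
  //
  //   class MyStatefulOp : public OpKernel {
  //    public:
  //     explicit MyStatefulOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
  //       Tensor* scratch = nullptr;
  //       OP_REQUIRES_OK(ctx, ctx->allocate_persistent(
  //                               DT_FLOAT, TensorShape({64}), &state_,
  //                               &scratch));
  //     }
  //     void Compute(OpKernelContext* ctx) override {
  //       Tensor* state = state_.AccessTensor(ctx);  // notifies the runtime
  //       // ... read or update *state ...
  //     }
  //
  //    private:
  //     PersistentTensor state_;
  //   };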
259 | |
260 | // Allocates a temporary Tensor of the specified type and shape. The |
261 | // Tensor must not be used after kernel construction is |
262 | // complete. See comment above. |
263 | Status allocate_temp(DataType type, const TensorShape& shape, |
264 | Tensor* out_temp); |
265 | |
266 | // Allocates a Tensor of the specified type and shape which the Op |
267 | // plans to maintain as persistent state. out_persistent holds the |
268 | // PersistentTensor which is the object the caller should store. For |
269 | // convenience, if out_tensor is non-null then it will be filled in |
270 | // with a Tensor* pointing to the newly-allocated tensor which the |
271 | // caller can use instead of calling |
272 | // out_persistent->AccessTensor. The caller does not own out_tensor |
273 | // and should not keep a copy of it. See comment above. |
274 | Status allocate_persistent(DataType type, const TensorShape& shape, |
275 | PersistentTensor* out_persistent, |
276 | Tensor** out_tensor); |
277 | |
278 | // User-supplied configuration of this operation. |
279 | const NodeDef& def() const { return *def_; } |
280 | |
281 | // For inspecting the inputs to this operation. |
282 | int num_inputs() const { return input_types_.size(); } |
283 | DataType input_type(int i) const { return input_types_[i]; } |
284 | const DataTypeSlice& input_types() const { return input_types_; } |
285 | const MemoryTypeSlice& input_memory_types() const { |
286 | return input_memory_types_; |
287 | } |
288 | |
289 | // For inspecting the outputs expected from this operation. |
290 | int num_outputs() const { return output_types_.size(); } |
291 | DataType output_type(int i) const { return output_types_[i]; } |
292 | const DataTypeSlice& output_types() const { return output_types_; } |
293 | const MemoryTypeSlice& output_memory_types() const { |
294 | return output_memory_types_; |
295 | } |
296 | |
  // If expected_inputs == input_types() and expected_outputs ==
  // output_types(), returns OK, else returns INVALID_ARGUMENT with an error
  // message. Recommended for Ops with dynamic signatures.
300 | Status MatchSignature(const DataTypeSlice expected_inputs, |
301 | const DataTypeSlice expected_outputs); |
302 | |
303 | // For recording configuration errors during construction. |
304 | void SetStatus(const Status& status); |
305 | const Status& status() const { return *status_; } |
306 | |
307 | // Look up the attr with name attr_name and set *value to its value. If no |
308 | // attr with attr_name is found in def(), or the attr does not have |
309 | // a matching type, a non-ok status will be returned. |
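  //
  // Typical usage in a kernel constructor (the attr name "axis" is
  // illustrative only):
  //
  //   int32 axis;
  //   OP_REQUIRES_OK(context, context->GetAttr("axis", &axis));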
310 | template <class T> |
311 | Status GetAttr(StringPiece attr_name, T* value) const; |
312 | |
313 | // Return true if the attr_name is defined in def(). |
314 | bool HasAttr(StringPiece attr_name) const; |
315 | |
316 | // Return the device type. |
317 | const DeviceType& device_type() const { return device_type_; } |
318 | |
319 | // If not nullptr, the kernel can instantiate functions defined in |
320 | // the library. E.g., |
321 | // CHECK_NOTNULL(function_library())->Instantiate("Foo", ...). |
322 | FunctionLibraryRuntime* function_library() const { return flib_; } |
323 | |
324 | // The GraphDef version whose behavior we should follow. |
325 | int graph_def_version() const { return graph_def_version_; } |
326 | |
327 | // Helper routines for the OP_REQUIRES macros |
328 | void CtxFailure(const Status& s); |
329 | void CtxFailureWithWarning(const Status& s); |
330 | void CtxFailure(const char* file, int line, const Status& s); |
331 | void CtxFailureWithWarning(const char* file, int line, const Status& s); |
332 | |
333 | // Unrecommended functions: these are functions that have some |
334 | // current uses but are not recommended for use, and may go away at |
335 | // some future major version release. |
336 | |
337 | // May be used, e.g., to get GPU handles, etc. |
338 | // |
339 | // Currently only used to call MakeTensorFromProto() for |
340 | // implementing ConstantOp for every device. See comments |
341 | // on Device::MakeTensorFromProto for longer-term replacement |
342 | // ideas. |
343 | DeviceBase* device() const { return device_; } |
344 | |
345 | private: |
346 | const DeviceType device_type_; |
347 | DeviceBase* const device_; |
348 | Allocator* allocator_; |
349 | const NodeDef* def_; |
350 | const OpDef* op_def_; |
351 | FunctionLibraryRuntime* flib_; |
352 | DataTypeSlice input_types_; |
353 | MemoryTypeSlice input_memory_types_; |
354 | DataTypeSlice output_types_; |
355 | MemoryTypeSlice output_memory_types_; |
356 | const int graph_def_version_; |
357 | Status* status_; |
358 | |
  // Allow access to op_def_ from OpKernel, but not from subclasses.
360 | // TODO(irving): Remove protos from this header entirely. |
361 | friend class OpKernel; |
362 | |
363 | TF_DISALLOW_COPY_AND_ASSIGN(OpKernelConstruction); |
364 | }; |
365 | |
366 | // TODO(mrry): Consider converting to a random_access_iterator, and upgrading |
367 | // tensorflow::gtl::iterator_range to make the below container classes |
368 | // unnecessary. |
369 | template <typename ListType, typename ElementType> |
370 | class OpArgIterator { |
371 | public: |
372 | typedef OpArgIterator<ListType, ElementType> ME; |
373 | OpArgIterator(const ListType* list, int i) : list_(list), i_(i) {} |
374 | bool operator==(const ME& rhs) { |
375 | DCHECK(list_ == rhs.list_); |
376 | return i_ == rhs.i_; |
377 | } |
378 | bool operator!=(const ME& rhs) { |
379 | DCHECK(list_ == rhs.list_); |
380 | return i_ != rhs.i_; |
381 | } |
382 | void operator++() { ++i_; } |
383 | ElementType& operator*() { return (*list_)[i_]; } |
384 | |
385 | private: |
386 | const ListType* const list_; |
387 | int i_; |
388 | }; |
389 | |
390 | // Utility class for representing a list of immutable input tensors |
391 | // that are passed to the op as a single named argument. |
392 | class OpInputList { |
393 | public: |
394 | typedef OpArgIterator<OpInputList, const Tensor&> Iterator; |
395 | OpInputList() : ctx_(nullptr), start_(0), stop_(0) {} |
396 | OpInputList(OpKernelContext* ctx, int start, int stop) |
397 | : ctx_(ctx), start_(start), stop_(stop) {} |
398 | OpInputList& operator=(const OpInputList& other) = default; |
399 | const Tensor& operator[](int i) const; |
400 | int size() const { return stop_ - start_; } |
401 | Iterator begin() const { return Iterator(this, 0); } |
402 | Iterator end() const { return Iterator(this, size()); } |
403 | |
404 | private: |
405 | OpKernelContext* ctx_; // not owned |
406 | int start_; |
407 | int stop_; |
408 | }; |
409 | |
410 | // Utility class for representing a list of mutable ("ref") input tensors |
411 | // that are passed to the op as a single named argument. |
412 | class OpMutableInputList { |
413 | public: |
414 | typedef OpArgIterator<OpMutableInputList, Tensor*> Iterator; |
415 | OpMutableInputList(OpKernelContext* ctx, int start, int stop) |
416 | : ctx_(ctx), start_(start), stop_(stop) {} |
417 | OpMutableInputList() : ctx_(nullptr), start_(0), stop_(0) {} |
418 | OpMutableInputList& operator=(const OpMutableInputList& other) = default; |
419 | Tensor at(int i, bool lock_held); |
420 | mutex* ref_mutex(int i); |
421 | int size() const { return stop_ - start_; } |
422 | Iterator begin() const { return Iterator(this, 0); } |
423 | Iterator end() const { return Iterator(this, size()); } |
424 | |
425 | private: |
426 | OpKernelContext* ctx_; // not owned |
427 | int start_; |
428 | int stop_; |
429 | }; |
430 | |
431 | // Utility class for representing a list of output tensors that are |
432 | // grouped as a single named output. |
433 | class OpOutputList { |
434 | public: |
435 | typedef OpArgIterator<OpOutputList, const Tensor*> Iterator; |
436 | OpOutputList() : ctx_(nullptr), start_(0), stop_(0) {} |
437 | OpOutputList(OpKernelContext* ctx, int start, int stop) |
438 | : ctx_(ctx), start_(start), stop_(stop) {} |
439 | OpOutputList& operator=(const OpOutputList& other) = default; |
440 | Tensor* operator[](int i); |
441 | bool required(int i) const; |
442 | DataType expected_output_dtype(int i) const; |
443 | Status allocate(int i, const TensorShape& shape, Tensor** output); |
444 | void set(int i, const Tensor& tensor); |
445 | void set_ref(int i, mutex* mu, Tensor* tensor_for_ref); |
446 | int size() const { return stop_ - start_; } |
447 | Iterator begin() const { return Iterator(this, 0); } |
448 | Iterator end() const { return Iterator(this, size()); } |
449 | |
450 | private: |
451 | OpKernelContext* ctx_; // not owned |
452 | int start_; |
453 | int stop_; |
454 | }; |
455 | |
456 | // Holds a tensor or tensor reference. For tensor references, we need |
457 | // a mutex to prevent concurrent access to the tensor. |
458 | struct TensorValue { |
459 | TensorValue() : mutex_if_ref(nullptr), tensor(nullptr) {} |
460 | TensorValue(Tensor* t) // NOLINT(runtime/explicit) |
461 | : mutex_if_ref(nullptr), tensor(t) {} |
462 | TensorValue(mutex* mu, Tensor* t) : mutex_if_ref(mu), tensor(t) {} |
463 | Tensor* operator->() const { return tensor; } |
464 | bool is_ref() const { return mutex_if_ref != nullptr; } |
465 | |
466 | mutex* mutex_if_ref; // nullptr if not a ref, != nullptr if a ref |
467 | Tensor* tensor; |
468 | }; |
469 | |
470 | class OpKernelContext { |
471 | public: |
472 | // The first element of a WrappedAllocator is a "base" Allocator and |
473 | // the second element is that Allocator wrapped by a |
474 | // TrackingAllocator |
475 | typedef std::pair<Allocator*, TrackingAllocator*> WrappedAllocator; |
476 | |
477 | // TODO(zhifengc): Do some cleanup of Params. |
478 | // The Params struct is passed in to initialize an OpKernelContext, |
479 | // and must outlive the OpKernelContext. |
480 | struct Params { |
481 | ~Params() { delete eigen_gpu_device; } |
482 | |
483 | // The step being executed. |
484 | int64 step_id = 0; |
485 | |
486 | // The op kernel being computed. |
487 | OpKernel* op_kernel = nullptr; |
488 | |
489 | // The device on which the kernel is running. |
490 | DeviceBase* device = nullptr; |
491 | |
492 | // The Eigen GPU device wrapper, which may include a per-op |
493 | // wrapped allocator. The concrete type of this object depends on |
494 | // the type of this->device, so eigen_gpu_device can't be an |
495 | // inline member and must be heap allocated. However, we don't |
496 | // want to allocate a new eigen_gpu_device for every Op that is |
497 | // executed. Instead this member is allocated on first use using |
498 | // ensure_eigen_gpu_device, and then if the Params structure is |
499 | // re-used for subsequent Ops, the eigen_gpu_device is |
500 | // ReInitialized in the OpKernelContext constructor. Unlike the |
501 | // other pointers in Params, this one is owned by Params. |
502 | PerOpGpuDevice* eigen_gpu_device = nullptr; |
503 | |
504 | inline void ensure_eigen_gpu_device() { |
505 | DCHECK(device); |
506 | if (nullptr == eigen_gpu_device) { |
507 | // Surprisingly, MakeGpuDevice will return nullptr if the |
508 | // device is not a GPU device. This is ok, since those devices |
509 | // will never use eigen_gpu_device. It seems better to have |
510 | // ensure_eigen_gpu_device fall through and regenerate the |
511 | // nullptr every time an OpKernelContext is instantiated, than |
512 | // to do an unnecessary allocation of a dummy eigen GPU |
513 | // device for CPU device Ops. |
514 | eigen_gpu_device = device->MakeGpuDevice(); |
515 | } |
516 | } |
517 | |
518 | bool track_allocations = false; |
519 | bool log_memory = false; |
520 | bool record_tensor_accesses = false; |
521 | |
522 | // Array indexed by output number for this node |
523 | const AllocatorAttributes* output_attr_array = nullptr; |
524 | |
525 | // Shared resources accessible by this op kernel invocation. |
526 | ResourceMgr* resource_manager = nullptr; |
527 | |
528 | // Per-step resources accessible by this op kernel invocation should be |
    // stored in this container.
530 | ScopedStepContainer* step_container = nullptr; |
531 | |
532 | // Mechanism used by this op kernel invocation to communicate with |
533 | // computations running on other devices. |
534 | Rendezvous* rendezvous = nullptr; |
535 | |
536 | // Mechanism for executing a collective op that needs to coordinate |
    // with parallel instances running on other devices.
538 | CollectiveExecutor* collective_executor = nullptr; |
539 | |
540 | // The session state for this op. |
541 | SessionState* session_state = nullptr; |
542 | |
543 | // The tensor store for this op. |
544 | TensorStore* tensor_store = nullptr; |
545 | |
546 | // Mechanism used by this op kernel invocation to register a callback |
547 | // for its cancellation. |
548 | CancellationManager* cancellation_manager = nullptr; |
549 | |
550 | // Inputs to this op kernel. |
551 | const gtl::InlinedVector<TensorValue, 4>* inputs = nullptr; |
552 | bool is_input_dead = false; |
553 | |
554 | const gtl::InlinedVector<AllocatorAttributes, 4>* input_alloc_attrs = |
555 | nullptr; |
556 | |
557 | // Device contexts. |
558 | const gtl::InlinedVector<DeviceContext*, 4>* input_device_contexts = |
559 | nullptr; |
560 | DeviceContext* op_device_context = nullptr; |
561 | |
    // Control-flow op support.
563 | FrameAndIter frame_iter; |
564 | |
    // Function call support.
566 | CallFrameInterface* call_frame = nullptr; |
567 | FunctionLibraryRuntime* function_library = nullptr; |
568 | std::function<void(std::function<void()>)>* runner = nullptr; |
569 | StepStatsCollector* stats_collector = nullptr; |
570 | |
571 | // TensorSliceReaderCache support. |
572 | checkpoint::TensorSliceReaderCacheWrapper* slice_reader_cache = nullptr; |
573 | |
574 | // Support for forwarding reservations (used by ScopedAllocator). |
575 | static const int kNeverForward = -2; |
576 | static const int kNoReservation = -1; |
577 | // Values in [0,...) represent reservations for the indexed output. |
578 | const int* forward_from_array = nullptr; |
579 | }; |
580 | |
581 | // params must outlive the OpKernelContext. |
582 | explicit OpKernelContext(Params* params); |
583 | OpKernelContext(Params* params, int noutputs); |
584 | ~OpKernelContext(); |
585 | |
586 | Env* env() const { return params_->device->env(); } |
587 | |
588 | int64 step_id() const { return params_->step_id; } |
589 | |
590 | const OpKernel& op_kernel() const { return *params_->op_kernel; } |
591 | |
592 | // Input/output signature. |
593 | |
594 | int num_inputs() const { return params_->inputs->size(); } |
595 | DataType input_dtype(int index) const; |
596 | Status input_dtype(StringPiece name, DataType* dtype) const; |
597 | MemoryType input_memory_type(int index) const; |
598 | |
599 | int num_outputs() const { return outputs_.size(); } |
600 | DataType expected_output_dtype(int index) const; |
601 | MemoryType output_memory_type(int index) const; |
602 | |
603 | // Input |
604 | |
605 | // Returns an immutable input tensor. May only be used for non-Ref |
606 | // inputs. For Ref inputs use mutable_input below. |
607 | // REQUIRES: !IsRefType(input_dtype(index)) |
608 | // TODO(mrry): Convert this to return Status. |
609 | const Tensor& input(int index); |
610 | |
611 | // Returns the named immutable input tensor in "tensor", as defined |
612 | // in the OpDef. May only be used for non-Ref inputs. For Ref inputs |
613 | // use mutable_input below. |
614 | // REQUIRES: !IsRefType(input_dtype(index)) |
615 | // REQUIRES: the named input must not be a list. |
616 | Status input(StringPiece name, const Tensor** tensor); |
617 | |
618 | // Returns the named list-valued immutable input in "list", as |
  // defined in the OpDef. If the named input is not list-valued,
620 | // returns a one-element list. May only be used for non-Ref |
621 | // inputs. For Ref inputs use mutable_input below. |
622 | // REQUIRES: !IsRefType(input_dtype(index)) |
623 | Status input_list(StringPiece name, OpInputList* list); |
624 | |
625 | // For mutable inputs, use the following together to make sure there |
626 | // is no concurrent access to mutable_input(), e.g.: |
627 | // { |
628 | // Tensor& t = context->mutable_input(index); |
629 | // mutex_lock lock(*context->input_ref_mutex(index)); |
630 | // // modify the values in t |
631 | // } |
632 | // REQUIRES: IsRefType(input_dtype(index)) |
633 | Status input_ref_mutex(StringPiece name, mutex** out_mutex); |
634 | |
635 | // Returns a mutable input tensor. Must be used to access Ref |
636 | // inputs. REQUIRES: IsRefType(input_dtype(index)). The caller may |
637 | // modify the values stored in the Tensor buffer, and modifications |
638 | // will be visible to other Ops reading the same ref tensor. If |
639 | // !lock_held the input mutex will be acquired before returning the |
640 | // Tensor. |
641 | // TODO(mrry): Convert this to return Status. |
642 | Tensor mutable_input(int index, bool lock_held); |
643 | |
644 | // Returns the named mutable input tensor in "tensor", as defined in |
645 | // the OpDef. Must be used to access Ref inputs. The values stored |
646 | // in the Tensor buffer may be modified, and modifications will be |
647 | // visible to other Ops reading the same ref tensor. If !lock_held |
648 | // the input mutex will be acquired before returning the Tensor. |
649 | // REQUIRES: the named input must not be a list. |
650 | // REQUIRES: the named input must be a ref tensor. |
651 | Status mutable_input(StringPiece name, Tensor* tensor, bool lock_held); |
652 | |
653 | // Returns the named list-valued mutable input in "list", as defined |
654 | // in the OpDef. If the named input is not list-valued, returns a |
655 | // one-element list. Must be used to access Ref inputs. The values |
656 | // stored in the Tensor buffer may be modified, and modifications |
657 | // will be visible to other Ops reading the same ref tensor. |
658 | // REQUIRES: the named input must be a ref tensor. |
659 | Status mutable_input_list(StringPiece name, OpMutableInputList* list); |
660 | |
661 | // Replace the corresponding Ref Input to use the storage buffer |
662 | // used by tensor. If !lock_held the input mutex will be acquired |
663 | // before returning the Tensor. |
664 | // REQUIRES: IsRefType(input_dtype(index)). |
665 | void replace_ref_input(int index, const Tensor& tensor, bool lock_held); |
666 | |
667 | // Replace the corresponding named Ref Input to use the storage |
668 | // buffer used by tensor. If !lock_held the input mutex will be |
669 | // acquired before returning the Tensor. |
670 | // REQUIRES: IsRefType(input_dtype(index)). |
671 | Status replace_ref_input(StringPiece name, const Tensor& tensor, |
672 | bool lock_held); |
673 | |
674 | // Deletes the Tensor object used as the Ref Input at |
675 | // input_index. This is not usually necessary and should be used |
676 | // with caution. If !lock_held the input mutex will be acquired |
677 | // before returning the Tensor. |
678 | // REQUIRES: IsRefType(input_dtype(input_index)). |
679 | void delete_ref_input(int input_index, bool lock_held); |
680 | |
681 | // Return true if there is input at the given index. An operator has no |
682 | // input at index if its tensor is null. This is primarily used by the |
683 | // merge operator. |
684 | // TODO(mrry): Convert this to return Status. |
685 | bool has_input(int index) const; |
686 | |
687 | // Returns true if all inputs are the same shape, otherwise sets the |
688 | // status to a non-OK value and returns false. |
689 | // Usage: if (!context->ValidateInputsAreSameShape(this)) return; |
690 | bool ValidateInputsAreSameShape(OpKernel* op); |
691 | |
692 | // Input to output forwarding. |
693 | |
694 | // Set the output Ref Tensor at output_index to be an alias of the |
695 | // input Ref Tensor at input_index. |
696 | // REQUIRES: IsRefType(input_dtype(input_index)). |
697 | // REQUIRES: IsRefType(output_dtype(output_index)). |
698 | void forward_ref_input_to_ref_output(int input_index, int output_index); |
699 | |
  // Returns true when an alias to input[input_index], reshaped to
  // output_shape and safe to use for in-place computation, was written to
  // *output.
702 | // Returns false if input[input_index] has a refcount greater than one, or if |
703 | // its type does not match the expected output type of output[output_index], |
704 | // or the number of elements in input[input_index] does not equal the number |
705 | // of elements in output_shape. |
706 | bool forward_input_to_output_with_shape(int input_index, int output_index, |
707 | const TensorShape& output_shape, |
708 | Tensor** output) TF_MUST_USE_RESULT; |
709 | Status forward_input_to_output_with_shape(StringPiece input_name, |
710 | StringPiece output_name, |
711 | const TensorShape& output_shape, |
712 | Tensor** output) TF_MUST_USE_RESULT; |
713 | |
714 | // Returns a pointer to a Tensor aliasing the underlying buffer backing |
715 | // input[input_index] iff |
716 | // * input[input_index] is not a ref, |
717 | // * the data type, shape, memory type, and allocator attributes of |
718 | // input[input_index] are compatible with those given in dtype, shape, |
719 | // memory_type, and attr, |
720 | // * refcount on the underlying buffer is one. |
721 | // * Either there is no forwarding reservation for either input_index |
722 | // or output_index or the specified input is reserved for the specified |
723 | // output. More precisely: |
724 | // |
725 | // These cases mean neither input nor output has a reservation: |
726 | // forward_from_array = nullptr |
727 | // OR (input_index is not in forward_from_array AND |
728 | // (output_index == kNoReservation OR |
729 | // forward_from_array[output_index] == kNoReservation)) |
730 | // |
731 | // This case means that input_index is reserved for output_index: |
732 | // forward_from_array[output_index] == input_index |
733 | // |
734 | // This case means the output is reserved to always be allocated, |
735 | // never assigned a forwarded input: |
736 | // forward_from_array[output_index] == kNeverForward |
737 | // |
738 | // Otherwise returns nullptr. |
  // NOTE: For CUDA kernels that read inputs using the __ldg() intrinsic,
740 | // forwarding is only safe if there are no reads via __ldg() after writes |
741 | // to the same address. |
742 | std::unique_ptr<Tensor> forward_input( |
743 | int input_index, int output_index, DataType output_dtype, |
744 | const TensorShape& output_shape, MemoryType output_memory_type, |
745 | const AllocatorAttributes& output_attr) TF_MUST_USE_RESULT; |
746 | |
747 | // Tries to forward one of the inputs given in input_indices to |
748 | // output[output_index]. If none of the given inputs can be forwarded, calls |
749 | // allocate_output() to allocate a new output buffer. |
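  //
  // For example, an element-wise kernel that can safely operate in place
  // might do (sketch only):
  //
  //   Tensor* output = nullptr;
  //   OP_REQUIRES_OK(ctx, ctx->forward_input_or_allocate_output(
  //                           {0}, 0, ctx->input(0).shape(), &output));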
750 | Status forward_input_or_allocate_output( |
751 | gtl::ArraySlice<int> candidate_input_indices, int output_index, |
752 | const TensorShape& output_shape, Tensor** output) TF_MUST_USE_RESULT; |
753 | Status forward_input_or_allocate_output( |
754 | gtl::ArraySlice<StringPiece> candidate_input_names, |
755 | StringPiece output_name, const TensorShape& output_shape, |
756 | Tensor** output) TF_MUST_USE_RESULT; |
757 | |
758 | // Tries to reuse one of the inputs given in input_indices as a temporary. |
759 | // If none of the given inputs can be forwarded, calls |
760 | // allocate_temp() to allocate a new temporary buffer. |
761 | Status forward_input_or_allocate_temp( |
762 | gtl::ArraySlice<int> candidate_input_indices, DataType type, |
763 | const TensorShape& shape, const AllocatorAttributes& allocator_attr, |
764 | Tensor* out_temp) TF_MUST_USE_RESULT; |
765 | |
766 | Status forward_input_or_allocate_temp( |
767 | gtl::ArraySlice<int> candidate_input_indices, DataType type, |
768 | const TensorShape& shape, Tensor* out_temp) TF_MUST_USE_RESULT { |
769 | return forward_input_or_allocate_temp(candidate_input_indices, type, shape, |
770 | AllocatorAttributes(), out_temp); |
771 | } |
772 | |
773 | // Output |
774 | |
775 | // Returns the named list-valued output in "list", as defined in the OpDef. |
776 | // If the named output is not list-valued, returns a one-element list. |
777 | Status output_list(StringPiece name, OpOutputList* list); |
778 | |
779 | // If output_required(index) returns true, the OpKernel's Compute() method |
780 | // should call allocate_output(index, ...), set_output(index, ...), |
781 | // set_output_ref(index, ...), or set the status to a non-ok value. |
  // If it returns false, the kernel may still produce the output, but is not
  // required to do so.
783 | // TODO(mrry): Convert this to return Status, and implement a string |
784 | // name version. |
785 | bool output_required(int index) const { |
786 | return true; // TODO(josh11b): implement |
787 | } |
788 | |
789 | // Allocation of tensors during kernel execution inside the Compute |
790 | // method: |
791 | // |
792 | // There are three methods to allocate Tensors when an Op kernel |
793 | // executes. |
794 | // |
795 | // 1) allocate_persistent. This is only needed for Tensors that will |
796 | // be stored by the Op between invocations, and it *must* be used |
797 | // for those Tensors. The call returns a PersistentTensor, and that |
798 | // is the only object the Op is allowed to hold on to between |
799 | // invocations. When the Tensor is needed in a subsequent |
800 | // invocation, it can be retrieved from the PersistentTensor using |
801 | // the AccessTensor method. This ensures that the system is made |
802 | // aware of any use of the tensor's allocated memory, which is |
803 | // needed for correctness on asynchronous devices such as GPUs. |
804 | // |
805 | // 2) allocate_output. This should be used to allocate any tensor |
806 | // that is going to be used as an output from the Op at the end of |
807 | // the current execution. The caller indicates which output the |
808 | // Tensor will be assigned to, and the call returns the |
809 | // newly-allocated Tensor. The Tensor can subsequently be assigned |
810 | // to during kernel execution, and will be used as the designated |
811 | // output when the kernel execution completes. |
812 | // |
813 | // 3) allocate_temp. This should be used to allocate any scratch |
814 | // storage that is needed while the kernel is executing, and will |
815 | // not be retained by the Op. |
816 | // |
817 | // In some cases a Tensor needs to be used as an output even though |
818 | // it was previously allocated elsewhere. The Tensor may have been |
819 | // passed as an input, or stored in a PersistentTensor during a |
820 | // previous kernel execution, or allocated earlier in the kernel |
821 | // execution at a time when it was not known which output it would |
822 | // be assigned to. In this case the kernel can use set_output or |
823 | // set_output_ref to indicate that the tensor should be used as the |
824 | // designated output. It is legal to use any previously-allocated |
825 | // Tensor as an argument to set_output or set_output_ref, including |
826 | // Tensors allocated via allocate_temp. There may be a performance |
827 | // penalty to using a Tensor that was not allocated using |
828 | // allocate_output. This is because allocate_output uses the |
829 | // AllocatorAttributes stored in output_attr_array for the |
830 | // designated output. In some cases, using the wrong attributes may |
831 | // cause an extra copy of the Tensor's buffer. |
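  //
  // A typical Compute() body combining these calls might look like the
  // following sketch (shapes and dtypes are illustrative only):
  //
  //   const Tensor& input = ctx->input(0);
  //   Tensor* output = nullptr;
  //   OP_REQUIRES_OK(ctx, ctx->allocate_output(0, input.shape(), &output));
  //   Tensor scratch;
  //   OP_REQUIRES_OK(ctx, ctx->allocate_temp(DT_FLOAT, TensorShape({128}),
  //                                          &scratch));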
832 | |
833 | // Allocates output for the specified output index with shape. |
834 | // OpKernelContext retains ownership of the returned pointer. See |
835 | // comment above. |
836 | // |
837 | // If memory allocation fails, returns an error status. |
838 | // |
839 | // REQUIRES: !IsRefType(expected_output_dtype(index)) |
840 | Status allocate_output(int index, const TensorShape& shape, |
841 | Tensor** tensor) TF_MUST_USE_RESULT; |
842 | Status allocate_output(StringPiece name, const TensorShape& shape, |
843 | Tensor** tensor) TF_MUST_USE_RESULT; |
844 | // The following methods use the supplied attributes instead of |
845 | // those in output_attr_array. The caller is responsible for |
846 | // ensuring that the attributes are "compatible" with the |
847 | // output_attr_array, e.g. the tensor is allocated on the correct |
848 | // device. See comment above. |
849 | Status allocate_output(int index, const TensorShape& shape, Tensor** tensor, |
850 | AllocatorAttributes attr) TF_MUST_USE_RESULT; |
851 | Status allocate_output(StringPiece name, const TensorShape& shape, |
852 | Tensor** tensor, |
853 | AllocatorAttributes attr) TF_MUST_USE_RESULT; |
854 | |
855 | // Allocates a temporary Tensor of the specified type and |
856 | // shape. Devices such as GPUs that enqueue Ops for lazy execution |
857 | // may retain references to the temporary tensors after the Op's |
858 | // Compute method has run. See comment above. |
859 | Status allocate_temp(DataType type, const TensorShape& shape, |
860 | Tensor* out_temp, AllocatorAttributes allocator_attr, |
861 | const AllocationAttributes& allocation_attr); |
862 | Status allocate_temp(DataType type, const TensorShape& shape, |
863 | Tensor* out_temp, AllocatorAttributes allocator_attr) { |
864 | return allocate_temp(type, shape, out_temp, allocator_attr, |
865 | AllocationAttributes()); |
866 | } |
867 | Status allocate_temp(DataType type, const TensorShape& shape, |
868 | Tensor* out_temp) { |
869 | return allocate_temp(type, shape, out_temp, AllocatorAttributes()); |
870 | } |
871 | |
872 | // Allocates a Tensor of the specified type and shape which the Op |
873 | // plans to maintain as persistent state. out_persistent holds the |
874 | // PersistentTensor which is the object the caller should store. For |
875 | // convenience, if out_tensor is non-null then it will be filled in |
876 | // with a Tensor* pointing to the newly-allocated tensor which the |
877 | // caller can use instead of calling |
878 | // out_persistent->AccessTensor. The caller does not own out_tensor |
879 | // and should not keep a copy of it. See comment above. |
880 | Status allocate_persistent(DataType type, const TensorShape& shape, |
881 | PersistentTensor* out_persistent, |
882 | Tensor** out_tensor, AllocatorAttributes attr); |
883 | Status allocate_persistent(DataType type, const TensorShape& shape, |
884 | PersistentTensor* out_persistent, |
885 | Tensor** out_tensor) { |
886 | return allocate_persistent(type, shape, out_persistent, out_tensor, |
887 | AllocatorAttributes()); |
888 | } |
889 | |
890 | // Copies a tensor (allocated by the caller) to the specified output |
891 | // index. REQUIRES: !IsRefType(expected_output_dtype(index)) |
892 | // REQUIRES: 'tensor' must have the same MemoryType as |
893 | // output_memory_types[index]. See comment above. |
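  //
  // For example, a pass-through kernel might simply forward its input (the
  // output name "output" is illustrative only):
  //
  //   OP_REQUIRES_OK(ctx, ctx->set_output("output", ctx->input(0)));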
894 | Status set_output(StringPiece name, const Tensor& tensor); |
895 | |
896 | // To output a reference. Caller retains ownership of mu and tensor_for_ref, |
897 | // and they must outlive all uses within the step. See comment above. |
898 | // REQUIRES: IsRefType(expected_output_dtype(index)) |
899 | Status set_output_ref(StringPiece name, mutex* mu, Tensor* tensor_for_ref); |
900 | |
  // Sets *tensor to nullptr if allocate_output() or set_output() has not been
  // called for the named output.
902 | Status mutable_output(StringPiece name, Tensor** tensor); |
903 | |
904 | // Transfers ownership of an output tensor to the caller. |
905 | // NOTE: For non-reference outputs, the caller takes responsibility |
906 | // for deletion. For reference outputs, the caller does NOT take |
907 | // responsibility for deletion. |
908 | Status release_output(StringPiece name, TensorValue* value); |
909 | |
910 | // Records device specific state about how the input tensors were |
911 | // computed. |
912 | // |
913 | // If using the templated function, the type must be a subclass |
914 | // of DeviceContext. |
915 | // |
916 | // Get the DeviceContext used for the index input. Returns nullptr |
917 | // if no DeviceContext was provided. |
918 | template <typename T> |
919 | T* input_device_context(int index); |
920 | DeviceContext* input_device_context(int index); |
921 | |
922 | // Return the DeviceContext that should be used for this Op. |
923 | // |
924 | // If using the templated function, the type must be a subclass |
925 | // of DeviceContext. |
926 | // |
927 | // Returns nullptr if the device did not provide one. |
928 | template <typename T> |
929 | T* op_device_context(); |
930 | DeviceContext* op_device_context() { |
931 | DeviceContext* ret = params_->op_device_context; |
932 | if (ret == nullptr) { |
933 | auto* dev_info = device()->tensorflow_gpu_device_info(); |
934 | if (dev_info) ret = dev_info->default_context; |
935 | } |
936 | return ret; |
937 | } |
938 | |
939 | AllocatorAttributes input_alloc_attr(int index) const { |
940 | if (params_->input_alloc_attrs == nullptr) { |
941 | return AllocatorAttributes(); |
942 | } else { |
943 | DCHECK_GE(index, 0); |
944 | DCHECK_LT(index, params_->input_alloc_attrs->size()); |
945 | return (*params_->input_alloc_attrs)[index]; |
946 | } |
947 | } |
948 | |
949 | AllocatorAttributes output_alloc_attr(int index) const { |
950 | return params_->output_attr_array[index]; |
951 | } |
952 | |
953 | gtl::InlinedVector<WrappedAllocator, 4> wrapped_allocators() const { |
954 | mutex_lock lock(mu_); |
955 | gtl::InlinedVector<WrappedAllocator, 4> retrieved = wrapped_allocators_; |
956 | return retrieved; |
957 | } |
958 | |
959 | // Communication. |
960 | // |
961 | // An op kernel communicates with outside environment through |
962 | // Rendezvous Send() and Recv(). |
963 | Rendezvous* rendezvous() const { return params_->rendezvous; } |
964 | |
965 | CollectiveExecutor* collective_executor() const { |
966 | return params_->collective_executor; |
967 | } |
968 | |
969 | // An op kernel can access the session state it belongs to. |
970 | SessionState* session_state() const { return params_->session_state; } |
971 | |
972 | // An op kernel can access the tensor store of the run it belongs to. |
973 | TensorStore* tensor_store() const { return params_->tensor_store; } |
974 | |
975 | // Function call support. |
976 | // |
977 | // If this kernel invocation is within a function execution, |
978 | // call_frame() returns the call frame for the function call. |
979 | CallFrameInterface* call_frame() const { return params_->call_frame; } |
980 | |
  // If not nullptr, the kernel can invoke functions defined in the
982 | // library. E.g., CHECK_NOTNULL(function_library())->Run("Foo", ...). |
983 | FunctionLibraryRuntime* function_library() const { |
984 | return params_->function_library; |
985 | } |
986 | |
987 | std::function<void(std::function<void()>)>* runner() const { |
988 | return params_->runner; |
989 | } |
990 | StepStatsCollector* stats_collector() const { |
991 | return params_->stats_collector; |
992 | } |
993 | |
994 | // Shared resources accessible to this kernel. |
995 | ResourceMgr* resource_manager() const { return params_->resource_manager; } |
996 | |
997 | checkpoint::TensorSliceReaderCacheWrapper* slice_reader_cache() const { |
998 | return params_->slice_reader_cache; |
999 | } |
1000 | |
1001 | // Execution. |
1002 | // |
1003 | // OpKernels can use these eigen devices to carry out their |
1004 | // numerical computation. |
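  //
  // For example, a CPU kernel might evaluate an Eigen expression on the
  // device's thread pool (sketch only):
  //
  //   const Eigen::ThreadPoolDevice& d = ctx->eigen_cpu_device();
  //   output->flat<float>().device(d) = input.flat<float>() * 2.0f;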
1005 | const Eigen::ThreadPoolDevice& eigen_cpu_device() const { |
1006 | return *device()->eigen_cpu_device(); |
1007 | } |
1008 | const Eigen::GpuDevice& eigen_gpu_device() const { |
1009 | return params_->eigen_gpu_device->device(); |
1010 | } |
1011 | #ifdef TENSORFLOW_USE_SYCL |
1012 | const Eigen::SyclDevice& eigen_sycl_device() const { |
1013 | return *device()->eigen_sycl_device(); |
1014 | } |
1015 | #endif |
1016 | template <typename EigenDeviceType> |
1017 | const EigenDeviceType& eigen_device() const; |
1018 | |
1019 | // Error handling. |
1020 | |
  // If the kernel's actual input types match expected_inputs and its output
  // types match expected_outputs, returns OK, else returns INVALID_ARGUMENT
  // with an error message.
  // Recommended for Ops with dynamic signatures, where validation can only
  // be performed at runtime.
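  //
  // For example, a kernel with a dynamic signature might verify at runtime
  // (the types shown are illustrative only):
  //
  //   OP_REQUIRES_OK(ctx, ctx->MatchSignature({DT_INT32, DT_FLOAT},
  //                                           {DT_FLOAT}));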
1025 | Status MatchSignature(const DataTypeSlice expected_inputs, |
1026 | const DataTypeSlice expected_outputs); |
1027 | |
1028 | // An OpKernel should call SetStatus() if Compute() encounters an |
1029 | // error. |
1030 | void SetStatus(const Status& status); |
1031 | const Status& status() const { return status_; } |
1032 | |
1033 | // Cancellation. |
1034 | // |
1035 | // EXPERIMENTAL. See the implementation in tensorflow::TensorQueue for an |
1036 | // example of how to use this API. |
1037 | CancellationManager* cancellation_manager() const { |
1038 | return params_->cancellation_manager; |
1039 | } |
1040 | |
1041 | // Other accessors. |
1042 | |
1043 | // For control flow. |
1044 | FrameAndIter frame_iter() const { return params_->frame_iter; } |
1045 | bool is_input_dead() const { return params_->is_input_dead; } |
1046 | bool* is_output_dead() { return &is_output_dead_; } |
1047 | |
1048 | // May be used, e.g., to get GPU handles, etc. |
1049 | // TODO(tucker): Add example usage. |
1050 | DeviceBase* device() const { return params_->device; } |
1051 | |
1052 | // Retrieve list of referenced tensors in out_vector. Once this is |
1053 | // called, it is not legal to reference any more tensors. Should |
1054 | // not be called from Op kernels. |
1055 | void retrieve_accessed_tensors(TensorReferenceVector* out_vector); |
1056 | |
1057 | // Per-step container for use by white-listed internal ops. |
1058 | ScopedStepContainer* step_container() const { |
1059 | return params_->step_container; |
1060 | } |
1061 | |
1062 | // Helper routines for the OP_REQUIRES macros |
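  //
  // These are normally reached through the OP_REQUIRES macros declared later
  // in this header, e.g. (the condition and message are illustrative only):
  //
  //   OP_REQUIRES(context, input.dims() == 2,
  //               errors::InvalidArgument("input must be a matrix"));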
1063 | void CtxFailure(const Status& s); |
1064 | void CtxFailureWithWarning(const Status& s); |
1065 | void CtxFailure(const char* file, int line, const Status& s); |
1066 | void CtxFailureWithWarning(const char* file, int line, const Status& s); |
1067 | |
1068 | // Unrecommended functions: these are functions that have some |
1069 | // current uses but are not recommended for use, and may go away at |
1070 | // some future major version release. |
1071 | // |
1072 | // The following functions all have versions that return Status |
1073 | // to capture error conditions, and are strongly preferred. |
1074 | Tensor* mutable_output(int index); |
1075 | void set_output(int index, const Tensor& tensor); |
1076 | mutex* input_ref_mutex(int index); |
1077 | void set_output_ref(int index, mutex* mu, Tensor* tensor_for_ref); |
1078 | TensorValue release_output(int index); |
1079 | |
1080 | bool track_allocations() const { return params_->track_allocations; } |
1081 | |
1082 | // Records temp memory allocation. Tensor object is recorded to identify the |
1083 | // case where temp memory is used as output memory. |
1084 | void record_temp_memory_allocation(int64 size, const Tensor& t) |
1085 | LOCKS_EXCLUDED(stats_mu_); |
1086 | |
  // Returns recorded size of temporary memory.
1088 | int64 temp_memory_allocated() const LOCKS_EXCLUDED(stats_mu_); |
1089 | |
  // Records persistent memory allocation; size can be negative, indicating
  // deallocation.
1092 | void record_persistent_memory_allocation(int64 size, int64 alloc_id = -1) |
1093 | LOCKS_EXCLUDED(stats_mu_); |
1094 | |
1095 | // Returns recorded size and ids of persistent memory. |
1096 | int64 persistent_memory_allocated() const LOCKS_EXCLUDED(stats_mu_); |
1097 | |
1098 | std::vector<int64> persistent_alloc_ids() const LOCKS_EXCLUDED(stats_mu_); |
1099 | |
1100 | // Resets counters for temp and persistent memory and recorded ids. |
1101 | void clear_recorded_memory() LOCKS_EXCLUDED(stats_mu_); |
1102 | |
1103 | bool input_is_ref(int index) const; |
1104 | |
1105 | private: |
1106 | Allocator* get_allocator(AllocatorAttributes attr); |
1107 | |
1108 | // Internal method to add a tensor's buffer to the list of buffers |
1109 | // referenced during the execution of the Op, so that GPUs may |
1110 | // accurately track the memory that may not be reused until the Op |
1111 | // execution completes. |
1112 | void record_tensor_reference(const Tensor& tensor); |
1113 | void really_record_tensor_reference(const Tensor& tensor); |
1114 | |
1115 | // Internal common method used when allocating tensor memory |
1116 | Status allocate_tensor(DataType type, const TensorShape& shape, |
1117 | Tensor* out_tensor, |
1118 | AllocatorAttributes allocator_attr) { |
1119 | return allocate_tensor(type, shape, out_tensor, allocator_attr, |
1120 | AllocationAttributes()); |
1121 | } |
1122 | |
1123 | Status allocate_tensor(DataType type, const TensorShape& shape, |
1124 | Tensor* out_tensor, AllocatorAttributes allocator_attr, |
1125 | const AllocationAttributes& allocation_attr); |
1126 | |
1127 | // This is called by PersistentTensor::AccessTensor whenever the |
1128 | // wrapped tensor is retrieved, to ensure the runtime knows that the |
1129 | // Tensor is being accessed within an Op. This is necessary for |
1130 | // memory safety of devices like GPUs that queue Ops for |
1131 | // asynchronous execution after the Compute() method completes. |
1132 | friend class PersistentTensor; |
1133 | void NotifyUseOfPersistentTensor(const Tensor& tensor); |
1134 | |
1135 | Status status_; |
1136 | friend class CollectiveExecutor; // for access to params_ |
1137 | Params* params_; // not owned |
1138 | mutable mutex mu_; // mutable so const accessors can acquire the lock |
1139 | gtl::InlinedVector<WrappedAllocator, 4> wrapped_allocators_ GUARDED_BY(mu_); |
1140 | gtl::InlinedVector<TensorValue, 4> outputs_; |
1141 | |
1142 | // Constructed only if <params->record_tensor_accesses>. |
1143 | ManualConstructor<UniqueTensorReferences> referenced_tensors_ GUARDED_BY(mu_); |
1144 | |
1145 | bool is_output_dead_ = false; |
1146 | |
1147 | // The following data members are only used when allocation tracking is |
1148 | // enabled. |
1149 | mutable mutex stats_mu_; |
1150 | int64 temp_memory_allocated_ GUARDED_BY(stats_mu_); |
1151 | int64 persistent_memory_allocated_ GUARDED_BY(stats_mu_); |
1152 | std::unique_ptr<gtl::InlinedVector<std::pair<const void*, int64>, 2>> |
1153 | temp_tensor_buffer_and_size_ GUARDED_BY(stats_mu_); |
1154 | std::unique_ptr<gtl::InlinedVector<int64, 2>> persistent_alloc_ids_ |
1155 | GUARDED_BY(stats_mu_); |
1156 | |
1157 | TF_DISALLOW_COPY_AND_ASSIGN(OpKernelContext); |
1158 | }; |
1159 | |
1160 | // Register your OpKernel by specifying the Op's name, the device the |
1161 | // kernel runs on, any type attr constraints for this kernel, any |
1162 | // host-memory args, and the class to instantiate. Examples: |
1163 | // |
1164 | // // A kernel that supports all types. |
1165 | // REGISTER_KERNEL_BUILDER(Name("Save").Device(DEVICE_CPU), SaveOp); |
1166 | // |
1167 | // // The following are equivalent ways of specifying that the kernel only |
1168 | // // works if the "T" type attr is set to DT_FLOAT. |
1169 | // REGISTER_KERNEL_BUILDER( |
1170 | // Name("Sub").Device(DEVICE_CPU).TypeConstraint<float>("T"), |
1171 | // SubOp<float>); |
1172 | // // (You would then repeat this for every type supported by "Sub".) |
1173 | // |
1174 | // // This form allows you to specify a list of types as the constraint. |
1175 | // REGISTER_KERNEL_BUILDER(Name("Sub") |
1176 | // .Device(DEVICE_CPU) |
1177 | // .TypeConstraint("T", {DT_FLOAT}), |
1178 | // SubOp<float>); |
1179 | // |
1180 | // // A kernel that expects one of the input tensors in host memory. |
1181 | // REGISTER_KERNEL_BUILDER( |
1182 | // Name("Reshape").Device(DEVICE_GPU).HostMemory("shape"), ReshapeOp); |
1183 | // |
1184 | // See kernel_def_builder for details. |
1185 | |
// Instantiates an OpKernel that has been registered.  Returns nullptr
// if no kernel is registered for that device / input signature
// combination (and sets a NOT_FOUND *status), or if there is an error
// in construction (and sets an INVALID_ARGUMENT *status).  Otherwise,
// the caller takes ownership of the returned pointer.
1191 | // EXPECTED USAGE: unique_ptr<OpKernel> op = CreateOpKernel(...); |
1192 | // REQUIRES: def has all attrs specified (e.g. using AddDefaultsToNodeDef()). |
1193 | std::unique_ptr<OpKernel> CreateOpKernel(DeviceType device_type, |
1194 | DeviceBase* device, |
1195 | Allocator* allocator, |
1196 | const NodeDef& def, |
1197 | int graph_def_version, Status* status); |
1198 | Status CreateOpKernel(DeviceType device_type, DeviceBase* device, |
1199 | Allocator* allocator, FunctionLibraryRuntime* flib, |
1200 | const NodeDef& def, int graph_def_version, |
1201 | OpKernel** kernel); |
1202 | |
1203 | // Returns into 'device_types' the subset of prioritized_types that this |
1204 | // binary has registered for the given NodeDef. |
1205 | // |
1206 | // REQUIRES: * 'device_types' is not nullptr. |
1207 | // * def has all attrs specified (e.g. using AddDefaultsToNodeDef()). |
1208 | Status SupportedDeviceTypesForNode( |
1209 | const std::vector<DeviceType>& prioritized_types, const NodeDef& def, |
1210 | DeviceTypeVector* device_types); |
1211 | |
1212 | // Returns a message with a description of the kernels registered for op |
1213 | // `op_name`. |
1214 | string KernelsRegisteredForOp(StringPiece op_name); |
1215 | |
1216 | // Call once after Op registration has completed. |
1217 | Status ValidateKernelRegistrations(const OpRegistryInterface& op_registry); |
1218 | |
1219 | // ----------------------------------------------------------------------------- |
1220 | // OpKernel registration implementation follows, please ignore. |
1221 | |
1222 | // Allow the REGISTER_KERNEL_BUILDER(Name("op_name").Device(...)...) syntax. |
1223 | namespace register_kernel { |
1224 | |
1225 | class Name : public KernelDefBuilder { |
1226 | public: |
  // With selective registration, kernels whose implementation class is not
  // needed by any selected op are disabled by the SHOULD_REGISTER_OP_KERNEL
  // call in REGISTER_KERNEL_BUILDER_UNIQ. However, a kernel for an unused op
  // that shares an implementation class with a used kernel would slip through
  // that check.
  //
  // This class stops such a registration by changing the name of the kernel
  // for the unused op to one that is ignored by
  // OpKernelRegistrar::InitInternal. Note that this check alone is not
  // sufficient - the compiler can't evaluate the entire KernelDefBuilder at
  // compilation time, so it doesn't actually reduce code size.
1237 | explicit Name(const char* op) |
      : KernelDefBuilder(SHOULD_REGISTER_OP(op) ? op : "_no_register") {}
1239 | }; |
1240 | |
1241 | namespace system { |
1242 | |
1243 | class Name : public KernelDefBuilder { |
1244 | public: |
1245 | // For system kernels, we ignore selective registration and |
1246 | // unconditionally register the kernel. |
1247 | explicit Name(const char* op) : KernelDefBuilder(op) {} |
1248 | }; |
1249 | |
1250 | } // namespace system |
1251 | |
1252 | } // namespace register_kernel |
1253 | |
1254 | #define REGISTER_KERNEL_BUILDER(kernel_builder, ...) \ |
1255 | REGISTER_KERNEL_BUILDER_UNIQ_HELPER(__COUNTER__, kernel_builder, __VA_ARGS__) |
1256 | |
1257 | #define REGISTER_KERNEL_BUILDER_UNIQ_HELPER(ctr, kernel_builder, ...) \ |
1258 | REGISTER_KERNEL_BUILDER_UNIQ(ctr, kernel_builder, __VA_ARGS__) |
1259 | |
1260 | #define REGISTER_KERNEL_BUILDER_UNIQ(ctr, kernel_builder, ...) \ |
1261 | constexpr bool should_register_##ctr##__flag = \ |
1262 | SHOULD_REGISTER_OP_KERNEL(#__VA_ARGS__); \ |
1263 | static ::tensorflow::kernel_factory::OpKernelRegistrar \ |
1264 | registrar__body__##ctr##__object( \ |
1265 | should_register_##ctr##__flag \ |
1266 | ? ::tensorflow::register_kernel::kernel_builder.Build() \ |
1267 | : nullptr, \ |
1268 | #__VA_ARGS__, \ |
1269 | [](::tensorflow::OpKernelConstruction* context) \ |
1270 | -> ::tensorflow::OpKernel* { \ |
1271 | return new __VA_ARGS__(context); \ |
1272 | }); |
1273 | |
// The `REGISTER_SYSTEM_KERNEL_BUILDER()` macro behaves like
// `REGISTER_KERNEL_BUILDER()`, except that the kernel is registered
// unconditionally, even when selective registration is used.
1277 | #define REGISTER_SYSTEM_KERNEL_BUILDER(kernel_builder, ...) \ |
1278 | REGISTER_SYSTEM_KERNEL_BUILDER_UNIQ_HELPER(__COUNTER__, kernel_builder, \ |
1279 | __VA_ARGS__) |
1280 | |
1281 | #define REGISTER_SYSTEM_KERNEL_BUILDER_UNIQ_HELPER(ctr, kernel_builder, ...) \ |
1282 | REGISTER_SYSTEM_KERNEL_BUILDER_UNIQ(ctr, kernel_builder, __VA_ARGS__) |
1283 | |
1284 | #define REGISTER_SYSTEM_KERNEL_BUILDER_UNIQ(ctr, kernel_builder, ...) \ |
1285 | static ::tensorflow::kernel_factory::OpKernelRegistrar \ |
1286 | registrar__body__##ctr##__object( \ |
1287 | ::tensorflow::register_kernel::system::kernel_builder.Build(), \ |
1288 | #__VA_ARGS__, \ |
1289 | [](::tensorflow::OpKernelConstruction* context) \ |
1290 | -> ::tensorflow::OpKernel* { \ |
1291 | return new __VA_ARGS__(context); \ |
1292 | }); |
1293 | |
1294 | void* GlobalKernelRegistry(); |
1295 | |
// If node_def has a corresponding kernel registered on device_type, returns
// OK and fills in the kernel def and kernel_class_name. <def> and
// <kernel_class_name> may be null.
1299 | Status FindKernelDef(const DeviceType& device_type, const NodeDef& node_def, |
1300 | const KernelDef** def, string* kernel_class_name); |
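
// Illustrative (sketch) use, assuming a NodeDef `node_def` is in scope and we
// only care whether a GPU kernel exists:
//
//   const KernelDef* kernel_def = nullptr;
//   string kernel_class_name;
//   Status s = FindKernelDef(DeviceType(DEVICE_GPU), node_def, &kernel_def,
//                            &kernel_class_name);
//   if (s.ok()) { /* a GPU kernel is registered for node_def */ }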
1301 | |
1302 | // Writes a list of all registered kernels to LOG(INFO), to help users debug |
1303 | // missing kernel errors. |
1304 | void LogAllRegisteredKernels(); |
1305 | |
1306 | namespace kernel_factory { |
1307 | |
1308 | class OpKernelRegistrar { |
1309 | public: |
1310 | typedef OpKernel* (*Factory)(OpKernelConstruction*); |
1311 | |
1312 | OpKernelRegistrar(const KernelDef* kernel_def, StringPiece kernel_class_name, |
1313 | Factory factory) { |
1314 | // Perform the check in the header to allow compile-time optimization |
1315 | // to a no-op, allowing the linker to remove the kernel symbols. |
1316 | if (kernel_def != nullptr) { |
1317 | InitInternal(kernel_def, kernel_class_name, factory); |
1318 | } |
1319 | } |
1320 | |
1321 | private: |
1322 | void InitInternal(const KernelDef* kernel_def, StringPiece kernel_class_name, |
1323 | Factory factory); |
1324 | }; |
1325 | |
1326 | } // namespace kernel_factory |
1327 | |
1328 | // ----------------------------------------------------------------------------- |
1329 | // Template and inline method implementations, please ignore |
1330 | |
1331 | template <class T> |
1332 | Status OpKernelConstruction::GetAttr(StringPiece attr_name, T* value) const { |
1333 | return GetNodeAttr(def(), attr_name, value); |
1334 | } |
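
// A typical use of GetAttr from a kernel constructor (a sketch; MyOp and the
// int member num_elements_ are hypothetical, and the op is assumed to declare
// an int attr named "N"):
//
//   explicit MyOp(OpKernelConstruction* context) : OpKernel(context) {
//     OP_REQUIRES_OK(context, context->GetAttr("N", &num_elements_));
//   }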
1335 | |
1336 | inline DataType OpKernelContext::input_dtype(int index) const { |
1337 | DCHECK_GE(index, 0); |
1338 | DCHECK_LT(index, num_inputs()); |
1339 | const TensorValue& value((*params_->inputs)[index]); |
1340 | if (value.is_ref()) { |
1341 | return MakeRefType(value->dtype()); |
1342 | } else { |
1343 | return value->dtype(); |
1344 | } |
1345 | } |
1346 | |
1347 | inline MemoryType OpKernelContext::input_memory_type(int index) const { |
1348 | DCHECK_GE(index, 0); |
1349 | DCHECK_LT(index, num_inputs()); |
1350 | return op_kernel().input_memory_types()[index]; |
1351 | } |
1352 | |
1353 | inline DataType OpKernelContext::expected_output_dtype(int index) const { |
1354 | DCHECK_GE(index, 0); |
1355 | DCHECK_LT(index, num_outputs()); |
1356 | return params_->op_kernel->output_type(index); |
1357 | } |
1358 | |
1359 | inline MemoryType OpKernelContext::output_memory_type(int index) const { |
1360 | DCHECK_GE(index, 0); |
1361 | DCHECK_LT(index, num_outputs()); |
1362 | return op_kernel().output_memory_types()[index]; |
1363 | } |
1364 | |
1365 | inline bool OpKernelContext::input_is_ref(int index) const { |
1366 | const TensorValue& value((*params_->inputs)[index]); |
1367 | return value.is_ref(); |
1368 | } |
1369 | |
1370 | inline void OpKernelContext::record_tensor_reference(const Tensor& tensor) { |
1371 | DCHECK_EQ(params_->device->RequiresRecordingAccessedTensors(), |
1372 | params_->record_tensor_accesses); |
1373 | if (params_->record_tensor_accesses) { |
1374 | really_record_tensor_reference(tensor); |
1375 | } |
1376 | } |
1377 | |
1378 | inline void OpKernelContext::retrieve_accessed_tensors( |
1379 | TensorReferenceVector* out_vector) { |
1380 | if (params_->record_tensor_accesses) { |
1381 | mutex_lock l(mu_); |
1382 | referenced_tensors_->FreezeAndReturnReferences(out_vector); |
1383 | } |
1384 | } |
1385 | |
1386 | // no input if tensor == nullptr. |
1387 | inline bool OpKernelContext::has_input(int index) const { |
1388 | DCHECK_GE(index, 0); |
1389 | DCHECK_LT(index, num_inputs()); |
1390 | return (*params_->inputs)[index].tensor != nullptr; |
1391 | } |
1392 | |
1393 | inline mutex* OpKernelContext::input_ref_mutex(int index) { |
1394 | DCHECK_GE(index, 0); |
1395 | DCHECK_LT(index, num_inputs()); |
1396 | DCHECK(input_is_ref(index)); |
1397 | return (*params_->inputs)[index].mutex_if_ref; |
1398 | } |
1399 | |
1400 | inline void OpKernelContext::NotifyUseOfPersistentTensor(const Tensor& t) { |
1401 | if (t.IsInitialized()) { |
1402 | record_tensor_reference(t); |
1403 | } |
1404 | } |
1405 | |
1406 | inline Tensor* OpKernelContext::mutable_output(int index) { |
1407 | DCHECK_GE(index, 0); |
1408 | DCHECK_LT(index, num_outputs()); |
1409 | // No need to record_tensor_reference since the output must already |
1410 | // have been set by a call that did so. |
1411 | return outputs_[index].tensor; |
1412 | } |
1413 | |
1414 | inline TensorValue OpKernelContext::release_output(int index) { |
1415 | DCHECK_GE(index, 0); |
1416 | DCHECK_LT(index, num_outputs()); |
1417 | TensorValue value = outputs_[index]; |
1418 | outputs_[index] = TensorValue(); |
1419 | return value; |
1420 | } |
1421 | |
1422 | inline Status OpKernelContext::forward_input_or_allocate_output( |
1423 | gtl::ArraySlice<int> candidate_input_indices, int output_index, |
1424 | const TensorShape& output_shape, Tensor** output) { |
1425 | for (int input_index : candidate_input_indices) { |
1426 | if (forward_input_to_output_with_shape(input_index, output_index, |
1427 | output_shape, output)) { |
1428 | return Status::OK(); |
1429 | } |
1430 | } |
1431 | return allocate_output(output_index, output_shape, output); |
1432 | } |
1433 | |
1434 | inline Status OpKernelContext::forward_input_or_allocate_output( |
1435 | gtl::ArraySlice<StringPiece> candidate_input_names, StringPiece output_name, |
1436 | const TensorShape& output_shape, Tensor** output) { |
1437 | for (const StringPiece& input_name : candidate_input_names) { |
1438 | if (forward_input_to_output_with_shape(input_name, output_name, |
1439 | output_shape, output) |
1440 | .ok()) { |
1441 | return Status::OK(); |
1442 | } |
1443 | } |
1444 | return allocate_output(output_name, output_shape, output); |
1445 | } |
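
// Sketch of how a unary, in-place-friendly kernel might use the index-based
// overload above: try to reuse input 0's buffer for output 0, and fall back
// to a fresh allocation when forwarding is not possible.
//
//   void Compute(OpKernelContext* context) override {
//     const Tensor& input = context->input(0);
//     Tensor* output = nullptr;
//     OP_REQUIRES_OK(context,
//                    context->forward_input_or_allocate_output(
//                        {0}, 0, input.shape(), &output));
//     // ... write the result into *output ...
//   }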
1446 | |
1447 | template <typename T> |
1448 | T* OpKernelContext::op_device_context() { |
1449 | static_assert(std::is_base_of<DeviceContext, T>::value, |
                "T is not a subclass of DeviceContext");
1451 | return static_cast<T*>(op_device_context()); |
1452 | } |
1453 | |
1454 | template <typename T> |
1455 | T* OpKernelContext::input_device_context(int index) { |
1456 | DCHECK_GE(index, 0); |
1457 | DCHECK_LT(index, params_->input_device_contexts->size()); |
1458 | static_assert(std::is_base_of<DeviceContext, T>::value, |
                "T is not a subclass of DeviceContext");
1460 | return static_cast<T*>((*params_->input_device_contexts)[index]); |
1461 | } |
1462 | |
1463 | inline DeviceContext* OpKernelContext::input_device_context(int index) { |
1464 | DCHECK_GE(index, 0); |
1465 | DCHECK_LT(index, params_->input_device_contexts->size()); |
1466 | return (*params_->input_device_contexts)[index]; |
1467 | } |
1468 | |
1469 | inline const Tensor& OpInputList::operator[](int i) const { |
1470 | DCHECK_GE(i, 0); |
1471 | DCHECK_LT(i, stop_ - start_); |
1472 | return ctx_->input(start_ + i); |
1473 | } |
1474 | |
1475 | inline mutex* OpMutableInputList::ref_mutex(int i) { |
1476 | DCHECK_GE(i, 0); |
1477 | DCHECK_LT(i, stop_ - start_); |
1478 | return ctx_->input_ref_mutex(start_ + i); |
1479 | } |
1480 | |
1481 | inline Tensor OpMutableInputList::at(int i, bool lock_held) { |
1482 | DCHECK_GE(i, 0); |
1483 | DCHECK_LT(i, stop_ - start_); |
1484 | return ctx_->mutable_input(start_ + i, lock_held); |
1485 | } |
1486 | |
1487 | inline Tensor* OpOutputList::operator[](int i) { |
1488 | DCHECK_GE(i, 0); |
1489 | DCHECK_LT(i, stop_ - start_); |
1490 | return ctx_->mutable_output(start_ + i); |
1491 | } |
1492 | |
1493 | inline bool OpOutputList::required(int i) const { |
1494 | DCHECK_GE(i, 0); |
1495 | DCHECK_LT(i, stop_ - start_); |
1496 | return ctx_->output_required(start_ + i); |
1497 | } |
1498 | |
1499 | inline DataType OpOutputList::expected_output_dtype(int i) const { |
1500 | DCHECK_GE(i, 0); |
1501 | DCHECK_LT(i, stop_ - start_); |
1502 | return ctx_->expected_output_dtype(start_ + i); |
1503 | } |
1504 | |
1505 | inline Status OpOutputList::allocate(int i, const TensorShape& shape, |
1506 | Tensor** output) { |
1507 | DCHECK_GE(i, 0); |
1508 | DCHECK_LT(i, stop_ - start_); |
1509 | return ctx_->allocate_output(start_ + i, shape, output); |
1510 | } |
1511 | |
1512 | inline void OpOutputList::set(int i, const Tensor& tensor) { |
1513 | DCHECK_GE(i, 0); |
1514 | DCHECK_LT(i, stop_ - start_); |
1515 | ctx_->set_output(start_ + i, tensor); |
1516 | } |
1517 | |
1518 | inline void OpOutputList::set_ref(int i, mutex* mu, Tensor* tensor_for_ref) { |
1519 | DCHECK_GE(i, 0); |
1520 | DCHECK_LT(i, stop_ - start_); |
  ctx_->set_output_ref(start_ + i, mu, tensor_for_ref);
1522 | } |
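
// Sketch of typical OpOutputList use in a kernel whose op declares a list
// output named "outputs" (the names and the `shape` variable here are
// hypothetical):
//
//   OpOutputList outputs;
//   OP_REQUIRES_OK(context, context->output_list("outputs", &outputs));
//   for (int i = 0; i < outputs.size(); ++i) {
//     Tensor* out = nullptr;
//     OP_REQUIRES_OK(context, outputs.allocate(i, shape, &out));
//   }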
1523 | |
1524 | // Convenience macros for asserting and handling exceptional conditions. |
1525 | // Analogous to the CHECK* macros provided by logging.h. |
1526 | // |
1527 | // Example use: |
//  void Compute(OpKernelContext* context) {
1529 | // OP_REQUIRES(context, context->num_inputs() == 2, |
1530 | // errors::InvalidArgument("FooOp requires 2 arguments")); |
1531 | // ... |
1532 | // Status status = SomeUncertainMethod(); |
1533 | // OP_REQUIRES_OK(context, status); |
1534 | // ... |
1535 | // } |
1536 | |
1537 | #define OP_REQUIRES(CTX, EXP, STATUS) \ |
1538 | do { \ |
1539 | if (!TF_PREDICT_TRUE(EXP)) { \ |
1540 | (CTX)->CtxFailure(__FILE__, __LINE__, (STATUS)); \ |
1541 | return; \ |
1542 | } \ |
1543 | } while (0) |
1544 | |
1545 | #define OP_REQUIRES_OK(CTX, ...) \ |
1546 | do { \ |
1547 | ::tensorflow::Status _s(__VA_ARGS__); \ |
1548 | if (!TF_PREDICT_TRUE(_s.ok())) { \ |
1549 | (CTX)->CtxFailureWithWarning(__FILE__, __LINE__, _s); \ |
1550 | return; \ |
1551 | } \ |
1552 | } while (0) |
1553 | |
1554 | #define OP_REQUIRES_ASYNC(CTX, EXP, STATUS, CALLBACK) \ |
1555 | do { \ |
1556 | if (!TF_PREDICT_TRUE(EXP)) { \ |
1557 | (CTX)->CtxFailure(__FILE__, __LINE__, (STATUS)); \ |
1558 | (CALLBACK)(); \ |
1559 | return; \ |
1560 | } \ |
1561 | } while (0) |
1562 | |
1563 | #define OP_REQUIRES_OK_ASYNC(CTX, STATUS, CALLBACK) \ |
1564 | do { \ |
1565 | ::tensorflow::Status _s(STATUS); \ |
1566 | if (!TF_PREDICT_TRUE(_s.ok())) { \ |
1567 | (CTX)->CtxFailureWithWarning(__FILE__, __LINE__, _s); \ |
1568 | (CALLBACK)(); \ |
1569 | return; \ |
1570 | } \ |
1571 | } while (0) |
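
// The *_ASYNC variants are intended for AsyncOpKernel::ComputeAsync, where the
// `done` callback must run on every exit path. A sketch (FooAsyncOp is
// hypothetical):
//
//   void ComputeAsync(OpKernelContext* context, DoneCallback done) override {
//     OP_REQUIRES_ASYNC(
//         context, context->num_inputs() == 2,
//         errors::InvalidArgument("FooAsyncOp requires 2 inputs"), done);
//     // ...
//     done();
//   }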
1572 | |
1573 | } // namespace tensorflow |
1574 | |
1575 | #endif // TENSORFLOW_FRAMEWORK_OP_KERNEL_H_ |
1576 | |