/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_FRAMEWORK_OP_KERNEL_H_
#define TENSORFLOW_FRAMEWORK_OP_KERNEL_H_

#include <functional>

#include <utility>
#include <vector>
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/cancellation.h"
#include "tensorflow/core/framework/control_flow.h"
#include "tensorflow/core/framework/device_base.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op.h"  // TODO(b/62899350): Remove
#include "tensorflow/core/framework/rendezvous.h"
#include "tensorflow/core/framework/selective_registration.h"
#include "tensorflow/core/framework/session_state.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"  // TODO(b/62899350): Remove
#include "tensorflow/core/framework/tracking_allocator.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/framework/unique_tensor_references.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/gtl/manual_constructor.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/platform/types.h"

namespace Eigen {
struct ThreadPoolDevice;
struct GpuDevice;
struct SyclDevice;
}  // end namespace Eigen

namespace tensorflow {

namespace checkpoint {
class TensorSliceReaderCacheWrapper;
}  // namespace checkpoint

class AsyncOpKernel;
class CallFrameInterface;
class FunctionLibraryRuntime;
class OpKernelConstruction;  // declared below
class OpKernelContext;  // declared below
class OpRegistryInterface;
class ResourceMgr;
class ScopedStepContainer;
class CollectiveExecutor;
class StepStatsCollector;

class OpKernel {
 public:
  // OpKernel won't be instantiated by the scheduler, so you may perform
  // expensive initialization in the descendant's constructor.
  explicit OpKernel(OpKernelConstruction* context);

  // Specialized constructor that enables the descendant to provide a different
  // `NodeDef` value. For example, this constructor can be used to provide a
  // stripped-down `NodeDef` that does not contain the full set of attrs (such
  // as tensor values) if the descendant stores them in a different form.
  explicit OpKernel(OpKernelConstruction* context,
                    std::unique_ptr<const NodeDef> node_def);

  virtual ~OpKernel();

  // An OpKernel's computation can be either synchronous or
  // asynchronous. All OpKernel Compute() methods must be thread-safe as they
  // may be called concurrently (e.g. by multiple executions of the same graph
  // concurrently).
  //
  // Most OpKernels should compute synchronously. They should
  // subclass OpKernel and override the Compute() method and have it
  // return after completing the supplied work.
  //
  // A few special kernels might need to be asynchronous to bound the
  // number of threads (e.g., network receive operations). These
  // kernels must subclass AsyncOpKernel and override
  // AsyncOpKernel::ComputeAsync().
  //
  // In both cases, implementations of Compute() and ComputeAsync()
  // get inputs and write outputs through the given OpKernelContext
  // and return a status via context->SetStatus(). They must be
  // thread-safe. (See the example sketch after this class definition.)

  // Synchronous compute.
  //
  // "context" is guaranteed to be alive until Compute() returns.
  virtual void Compute(OpKernelContext* context) = 0;

  // Returns nullptr iff this op kernel is synchronous.
  virtual AsyncOpKernel* AsAsync() { return nullptr; }

  // Returns true iff this op kernel is considered "expensive". The
  // runtime may use this flag to optimize graph execution for example
  // to "inline" inexpensive kernels.
  virtual bool IsExpensive() { return expensive_; }

  // Accessors.
  const NodeDef& def() const { return *def_; }
  const string& name() const;              // Same as def().name()
  const string& type_string() const;       // Same as def().op()
  const string& requested_device() const;  // Same as def().device()
  bool is_internal() const { return is_internal_; }

  int num_inputs() const { return input_types_.size(); }
  DataType input_type(int i) const { return input_types_[i]; }
  const DataTypeVector& input_types() const { return input_types_; }
  const MemoryTypeVector& input_memory_types() const {
    return input_memory_types_;
  }
  const string& requested_input(int i) const;  // Same as def().input(i)

  int num_outputs() const { return output_types_.size(); }
  DataType output_type(int o) const { return output_types_[o]; }
  const DataTypeVector& output_types() const { return output_types_; }
  const MemoryTypeVector& output_memory_types() const {
    return output_memory_types_;
  }

  Status InputRange(StringPiece input_name, int* start, int* stop) const;
  Status OutputRange(StringPiece output_name, int* start, int* stop) const;

  // We allow legacy scalars within Google up until GraphDef version 6.
  // TODO(irving): Remove when we can drop support for GraphDef version 5.
  bool allow_legacy_scalars() const {
#if defined(PLATFORM_GOOGLE) || defined(PLATFORM_GOOGLE_ANDROID)
    return graph_def_version_ < 6;
#else
    return false;
#endif
  }

  // Allow either scalars or (if allowing legacy scalars) shape (1,).
  bool IsLegacyScalar(const TensorShape& shape) const {
    return shape.dims() == 0 || (allow_legacy_scalars() && shape.dims() == 1 &&
                                 shape.dim_size(0) == 1);
  }

  // Allow rank 1 or (if allowing legacy scalars) rank 0.
  bool IsLegacyVector(const TensorShape& shape) const {
    return shape.dims() == 1 || (allow_legacy_scalars() && shape.dims() == 0);
  }

  // Turn a shape Tensor into a TensorShape
  // TODO(irving): Move to TensorShapeUtils once !allow_legacy_scalars
  Status MakeShape(const Tensor& shape, TensorShape* out) const;

 private:
  const std::unique_ptr<const NodeDef> def_;
  const DataTypeVector input_types_;
  const MemoryTypeVector input_memory_types_;
  const DataTypeVector output_types_;
  const MemoryTypeVector output_memory_types_;
  const int graph_def_version_;
  const bool is_internal_;  // True if this is an internal operation
  NameRangeMap input_name_map_;
  NameRangeMap output_name_map_;
  bool expensive_;

  TF_DISALLOW_COPY_AND_ASSIGN(OpKernel);
};
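
// Example (illustrative sketch only, not part of the TensorFlow API surface):
// a minimal synchronous kernel that forwards its single input to its single
// output. The class name "PassThroughOp" is hypothetical and exists only for
// this comment.
//
//   class PassThroughOp : public OpKernel {
//    public:
//     explicit PassThroughOp(OpKernelConstruction* context)
//         : OpKernel(context) {}
//
//     void Compute(OpKernelContext* context) override {
//       // Grab the input and hand it back unchanged as output 0.
//       const Tensor& input = context->input(0);
//       context->set_output(0, input);
//     }
//   };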

class AsyncOpKernel : public OpKernel {
 public:
  using OpKernel::OpKernel;  // Lift OpKernel constructors.

  // Asynchronous compute.
  //
  // Implementations of ComputeAsync() must run "done" to signal the
  // completion of the computation. "context" is guaranteed to be
  // alive until the "done" callback starts.
  typedef std::function<void()> DoneCallback;
  virtual void ComputeAsync(OpKernelContext* context, DoneCallback done) = 0;

  AsyncOpKernel* AsAsync() final { return this; }

  void Compute(OpKernelContext* context) final;

  bool IsExpensive() override { return true; }
};
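
// Example (illustrative sketch only): an asynchronous kernel must invoke the
// "done" callback exactly once when its work finishes, even on error. The
// class name and the StartNonBlockingWork() helper below are hypothetical and
// only stand in for whatever non-blocking mechanism the kernel actually uses.
//
//   class RecvLikeOp : public AsyncOpKernel {
//    public:
//     explicit RecvLikeOp(OpKernelConstruction* context)
//         : AsyncOpKernel(context) {}
//
//     void ComputeAsync(OpKernelContext* context, DoneCallback done) override {
//       // Kick off non-blocking work; signal completion from the callback
//       // rather than blocking a compute thread.
//       StartNonBlockingWork(context, [context, done](const Status& s) {
//         if (!s.ok()) context->SetStatus(s);
//         done();
//       });
//     }
//   };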

// Wraps a tensor that is held by an Op across calls to Compute(). For
// memory safety when using asynchronous devices like GPUs, the system
// must be notified when a Tensor is used inside an Op execution. The
// wrapper ensures that all uses of the Tensor are tracked, because in
// order to retrieve the Tensor the caller must use AccessTensor which
// notifies the context.
class PersistentTensor {
 public:
  PersistentTensor() {}
  explicit PersistentTensor(const Tensor& tensor) : tensor_(tensor) {}

  // Caller does not own the returned Tensor*.
  Tensor* AccessTensor(OpKernelConstruction* context);
  // Caller does not own the returned Tensor*.
  Tensor* AccessTensor(OpKernelContext* context);

  // The check for initialization does not need to access the
  // underlying tensor buffer.
  bool IsInitialized() const { return tensor_.IsInitialized(); }

  int64 NumElements() const { return tensor_.NumElements(); }

  int64 AllocatedBytes() const { return tensor_.AllocatedBytes(); }

 private:
  Tensor tensor_;
};
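
// Example (illustrative sketch only): a kernel that allocates state once and
// reuses it across invocations holds it as a PersistentTensor and goes
// through AccessTensor() on every use. The class, member, shape, and dtype
// below are hypothetical.
//
//   class StatefulOp : public OpKernel {
//    public:
//     explicit StatefulOp(OpKernelConstruction* context) : OpKernel(context) {
//       OP_REQUIRES_OK(context,
//                      context->allocate_persistent(
//                          DT_FLOAT, TensorShape({64}), &state_, nullptr));
//     }
//
//     void Compute(OpKernelContext* context) override {
//       // AccessTensor() notifies the runtime that the buffer is in use.
//       Tensor* state = state_.AccessTensor(context);
//       // ... read or update *state ...
//     }
//
//    private:
//     PersistentTensor state_;
//   };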

class OpKernelConstruction {
 public:
  OpKernelConstruction(DeviceType device_type, DeviceBase* device,
                       Allocator* allocator, const NodeDef* node_def,
                       const OpDef* op_def, FunctionLibraryRuntime* flib,
                       const DataTypeSlice& input_types,
                       const MemoryTypeSlice& input_memory_types,
                       const DataTypeSlice& output_types,
                       const MemoryTypeSlice& output_memory_types,
                       int graph_def_version, Status* status);

  Env* env() const { return device_->env(); }

  // Allocation of tensors during kernel construction:
  //
  // It is legal to temporarily allocate scratch tensor storage during
  // Op kernel construction. Scratch tensors should be allocated using
  // allocate_temp below. Some kernels need to keep tensors in between
  // invocations. If such a Tensor is allocated during kernel
  // construction this must be done using allocate_persistent, and the
  // Op may only store the returned PersistentTensor object. When the
  // Tensor is needed in a subsequent invocation, it can be retrieved
  // from the PersistentTensor using the AccessTensor method. This
  // ensures that the system is made aware of any use of the tensor's
  // allocated memory, which is needed for correctness on asynchronous
  // devices such as GPUs.

  // Allocates a temporary Tensor of the specified type and shape. The
  // Tensor must not be used after kernel construction is
  // complete. See comment above.
  Status allocate_temp(DataType type, const TensorShape& shape,
                       Tensor* out_temp);

  // Allocates a Tensor of the specified type and shape which the Op
  // plans to maintain as persistent state. out_persistent holds the
  // PersistentTensor which is the object the caller should store. For
  // convenience, if out_tensor is non-null then it will be filled in
  // with a Tensor* pointing to the newly-allocated tensor which the
  // caller can use instead of calling
  // out_persistent->AccessTensor. The caller does not own out_tensor
  // and should not keep a copy of it. See comment above.
  Status allocate_persistent(DataType type, const TensorShape& shape,
                             PersistentTensor* out_persistent,
                             Tensor** out_tensor);

  // User-supplied configuration of this operation.
  const NodeDef& def() const { return *def_; }

  // For inspecting the inputs to this operation.
  int num_inputs() const { return input_types_.size(); }
  DataType input_type(int i) const { return input_types_[i]; }
  const DataTypeSlice& input_types() const { return input_types_; }
  const MemoryTypeSlice& input_memory_types() const {
    return input_memory_types_;
  }

  // For inspecting the outputs expected from this operation.
  int num_outputs() const { return output_types_.size(); }
  DataType output_type(int i) const { return output_types_[i]; }
  const DataTypeSlice& output_types() const { return output_types_; }
  const MemoryTypeSlice& output_memory_types() const {
    return output_memory_types_;
  }

  // If expected_inputs == input_types() and expected_outputs == output_types(),
  // returns OK, else returns INVALID_ARGUMENT with an error message.
  // Recommended for Ops with dynamic signatures.
  Status MatchSignature(const DataTypeSlice expected_inputs,
                        const DataTypeSlice expected_outputs);

  // For recording configuration errors during construction.
  void SetStatus(const Status& status);
  const Status& status() const { return *status_; }

  // Look up the attr with name attr_name and set *value to its value. If no
  // attr with attr_name is found in def(), or the attr does not have
  // a matching type, a non-ok status will be returned.
  template <class T>
  Status GetAttr(StringPiece attr_name, T* value) const;

  // Return true if the attr_name is defined in def().
  bool HasAttr(StringPiece attr_name) const;
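
  // Example (illustrative sketch only): attrs are typically read in the
  // kernel constructor with GetAttr(), wrapped in OP_REQUIRES_OK so that a
  // missing or mistyped attr fails construction. The attr name "axis" and
  // the member below are hypothetical.
  //
  //   explicit MyOp(OpKernelConstruction* context) : OpKernel(context) {
  //     OP_REQUIRES_OK(context, context->GetAttr("axis", &axis_));
  //   }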

  // Return the device type.
  const DeviceType& device_type() const { return device_type_; }

  // If not nullptr, the kernel can instantiate functions defined in
  // the library. E.g.,
  // CHECK_NOTNULL(function_library())->Instantiate("Foo", ...).
  FunctionLibraryRuntime* function_library() const { return flib_; }

  // The GraphDef version whose behavior we should follow.
  int graph_def_version() const { return graph_def_version_; }

  // Helper routines for the OP_REQUIRES macros
  void CtxFailure(const Status& s);
  void CtxFailureWithWarning(const Status& s);
  void CtxFailure(const char* file, int line, const Status& s);
  void CtxFailureWithWarning(const char* file, int line, const Status& s);

  // Unrecommended functions: these are functions that have some
  // current uses but are not recommended for use, and may go away at
  // some future major version release.

  // May be used, e.g., to get GPU handles, etc.
  //
  // Currently only used to call MakeTensorFromProto() for
  // implementing ConstantOp for every device. See comments
  // on Device::MakeTensorFromProto for longer-term replacement
  // ideas.
  DeviceBase* device() const { return device_; }

 private:
  const DeviceType device_type_;
  DeviceBase* const device_;
  Allocator* allocator_;
  const NodeDef* def_;
  const OpDef* op_def_;
  FunctionLibraryRuntime* flib_;
  DataTypeSlice input_types_;
  MemoryTypeSlice input_memory_types_;
  DataTypeSlice output_types_;
  MemoryTypeSlice output_memory_types_;
  const int graph_def_version_;
  Status* status_;

  // Allow op_def_ across from OpKernel, but not from subclasses.
  // TODO(irving): Remove protos from this header entirely.
  friend class OpKernel;

  TF_DISALLOW_COPY_AND_ASSIGN(OpKernelConstruction);
};

// TODO(mrry): Consider converting to a random_access_iterator, and upgrading
// tensorflow::gtl::iterator_range to make the below container classes
// unnecessary.
template <typename ListType, typename ElementType>
class OpArgIterator {
 public:
  typedef OpArgIterator<ListType, ElementType> ME;
  OpArgIterator(const ListType* list, int i) : list_(list), i_(i) {}
  bool operator==(const ME& rhs) {
    DCHECK(list_ == rhs.list_);
    return i_ == rhs.i_;
  }
  bool operator!=(const ME& rhs) {
    DCHECK(list_ == rhs.list_);
    return i_ != rhs.i_;
  }
  void operator++() { ++i_; }
  ElementType& operator*() { return (*list_)[i_]; }

 private:
  const ListType* const list_;
  int i_;
};

// Utility class for representing a list of immutable input tensors
// that are passed to the op as a single named argument.
class OpInputList {
 public:
  typedef OpArgIterator<OpInputList, const Tensor&> Iterator;
  OpInputList() : ctx_(nullptr), start_(0), stop_(0) {}
  OpInputList(OpKernelContext* ctx, int start, int stop)
      : ctx_(ctx), start_(start), stop_(stop) {}
  OpInputList& operator=(const OpInputList& other) = default;
  const Tensor& operator[](int i) const;
  int size() const { return stop_ - start_; }
  Iterator begin() const { return Iterator(this, 0); }
  Iterator end() const { return Iterator(this, size()); }

 private:
  OpKernelContext* ctx_;  // not owned
  int start_;
  int stop_;
};

// Utility class for representing a list of mutable ("ref") input tensors
// that are passed to the op as a single named argument.
class OpMutableInputList {
 public:
  typedef OpArgIterator<OpMutableInputList, Tensor*> Iterator;
  OpMutableInputList(OpKernelContext* ctx, int start, int stop)
      : ctx_(ctx), start_(start), stop_(stop) {}
  OpMutableInputList() : ctx_(nullptr), start_(0), stop_(0) {}
  OpMutableInputList& operator=(const OpMutableInputList& other) = default;
  Tensor at(int i, bool lock_held);
  mutex* ref_mutex(int i);
  int size() const { return stop_ - start_; }
  Iterator begin() const { return Iterator(this, 0); }
  Iterator end() const { return Iterator(this, size()); }

 private:
  OpKernelContext* ctx_;  // not owned
  int start_;
  int stop_;
};

// Utility class for representing a list of output tensors that are
// grouped as a single named output.
class OpOutputList {
 public:
  typedef OpArgIterator<OpOutputList, const Tensor*> Iterator;
  OpOutputList() : ctx_(nullptr), start_(0), stop_(0) {}
  OpOutputList(OpKernelContext* ctx, int start, int stop)
      : ctx_(ctx), start_(start), stop_(stop) {}
  OpOutputList& operator=(const OpOutputList& other) = default;
  Tensor* operator[](int i);
  bool required(int i) const;
  DataType expected_output_dtype(int i) const;
  Status allocate(int i, const TensorShape& shape, Tensor** output);
  void set(int i, const Tensor& tensor);
  void set_ref(int i, mutex* mu, Tensor* tensor_for_ref);
  int size() const { return stop_ - start_; }
  Iterator begin() const { return Iterator(this, 0); }
  Iterator end() const { return Iterator(this, size()); }

 private:
  OpKernelContext* ctx_;  // not owned
  int start_;
  int stop_;
};

// Holds a tensor or tensor reference. For tensor references, we need
// a mutex to prevent concurrent access to the tensor.
struct TensorValue {
  TensorValue() : mutex_if_ref(nullptr), tensor(nullptr) {}
  TensorValue(Tensor* t)  // NOLINT(runtime/explicit)
      : mutex_if_ref(nullptr), tensor(t) {}
  TensorValue(mutex* mu, Tensor* t) : mutex_if_ref(mu), tensor(t) {}
  Tensor* operator->() const { return tensor; }
  bool is_ref() const { return mutex_if_ref != nullptr; }

  mutex* mutex_if_ref;  // nullptr if not a ref, != nullptr if a ref
  Tensor* tensor;
};

class OpKernelContext {
 public:
  // The first element of a WrappedAllocator is a "base" Allocator and
  // the second element is that Allocator wrapped by a
  // TrackingAllocator.
  typedef std::pair<Allocator*, TrackingAllocator*> WrappedAllocator;

  // TODO(zhifengc): Do some cleanup of Params.
  // The Params struct is passed in to initialize an OpKernelContext,
  // and must outlive the OpKernelContext. (A construction sketch follows
  // the struct definition below.)
  struct Params {
    ~Params() { delete eigen_gpu_device; }

    // The step being executed.
    int64 step_id = 0;

    // The op kernel being computed.
    OpKernel* op_kernel = nullptr;

    // The device on which the kernel is running.
    DeviceBase* device = nullptr;

    // The Eigen GPU device wrapper, which may include a per-op
    // wrapped allocator. The concrete type of this object depends on
    // the type of this->device, so eigen_gpu_device can't be an
    // inline member and must be heap allocated. However, we don't
    // want to allocate a new eigen_gpu_device for every Op that is
    // executed. Instead this member is allocated on first use using
    // ensure_eigen_gpu_device, and then if the Params structure is
    // re-used for subsequent Ops, the eigen_gpu_device is
    // ReInitialized in the OpKernelContext constructor. Unlike the
    // other pointers in Params, this one is owned by Params.
    PerOpGpuDevice* eigen_gpu_device = nullptr;

    inline void ensure_eigen_gpu_device() {
      DCHECK(device);
      if (nullptr == eigen_gpu_device) {
        // Surprisingly, MakeGpuDevice will return nullptr if the
        // device is not a GPU device. This is ok, since those devices
        // will never use eigen_gpu_device. It seems better to have
        // ensure_eigen_gpu_device fall through and regenerate the
        // nullptr every time an OpKernelContext is instantiated, than
        // to do an unnecessary allocation of a dummy eigen GPU
        // device for CPU device Ops.
        eigen_gpu_device = device->MakeGpuDevice();
      }
    }

    bool track_allocations = false;
    bool log_memory = false;
    bool record_tensor_accesses = false;

    // Array indexed by output number for this node.
    const AllocatorAttributes* output_attr_array = nullptr;

    // Shared resources accessible by this op kernel invocation.
    ResourceMgr* resource_manager = nullptr;

    // Per-step resources accessible by this op kernel invocation should be
    // stored in this container.
    ScopedStepContainer* step_container = nullptr;

    // Mechanism used by this op kernel invocation to communicate with
    // computations running on other devices.
    Rendezvous* rendezvous = nullptr;

    // Mechanism for executing a collective op that needs to coordinate
    // with parallel instances running on other devices.
    CollectiveExecutor* collective_executor = nullptr;

    // The session state for this op.
    SessionState* session_state = nullptr;

    // The tensor store for this op.
    TensorStore* tensor_store = nullptr;

    // Mechanism used by this op kernel invocation to register a callback
    // for its cancellation.
    CancellationManager* cancellation_manager = nullptr;

    // Inputs to this op kernel.
    const gtl::InlinedVector<TensorValue, 4>* inputs = nullptr;
    bool is_input_dead = false;

    const gtl::InlinedVector<AllocatorAttributes, 4>* input_alloc_attrs =
        nullptr;

    // Device contexts.
    const gtl::InlinedVector<DeviceContext*, 4>* input_device_contexts =
        nullptr;
    DeviceContext* op_device_context = nullptr;

    // Control-flow op supports.
    FrameAndIter frame_iter;

    // Function call supports.
    CallFrameInterface* call_frame = nullptr;
    FunctionLibraryRuntime* function_library = nullptr;
    std::function<void(std::function<void()>)>* runner = nullptr;
    StepStatsCollector* stats_collector = nullptr;

    // TensorSliceReaderCache support.
    checkpoint::TensorSliceReaderCacheWrapper* slice_reader_cache = nullptr;

    // Support for forwarding reservations (used by ScopedAllocator).
    static const int kNeverForward = -2;
    static const int kNoReservation = -1;
    // Values in [0,...) represent reservations for the indexed output.
    const int* forward_from_array = nullptr;
  };
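
  // Example (illustrative sketch only): callers such as executors or tests
  // typically fill in a Params struct and then build an OpKernelContext from
  // it for each kernel invocation. The local variable names below are
  // hypothetical.
  //
  //   OpKernelContext::Params params;
  //   params.device = device;
  //   params.op_kernel = kernel;
  //   params.inputs = &inputs;                  // InlinedVector<TensorValue, 4>
  //   params.output_attr_array = output_attrs;  // one entry per output
  //   OpKernelContext context(&params);         // params must outlive context
  //   kernel->Compute(&context);
  //   Status s = context.status();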

  // params must outlive the OpKernelContext.
  explicit OpKernelContext(Params* params);
  OpKernelContext(Params* params, int noutputs);
  ~OpKernelContext();

  Env* env() const { return params_->device->env(); }

  int64 step_id() const { return params_->step_id; }

  const OpKernel& op_kernel() const { return *params_->op_kernel; }

  // Input/output signature.

  int num_inputs() const { return params_->inputs->size(); }
  DataType input_dtype(int index) const;
  Status input_dtype(StringPiece name, DataType* dtype) const;
  MemoryType input_memory_type(int index) const;

  int num_outputs() const { return outputs_.size(); }
  DataType expected_output_dtype(int index) const;
  MemoryType output_memory_type(int index) const;

  // Input

  // Returns an immutable input tensor. May only be used for non-Ref
  // inputs. For Ref inputs use mutable_input below.
  // REQUIRES: !IsRefType(input_dtype(index))
  // TODO(mrry): Convert this to return Status.
  const Tensor& input(int index);

  // Returns the named immutable input tensor in "tensor", as defined
  // in the OpDef. May only be used for non-Ref inputs. For Ref inputs
  // use mutable_input below.
  // REQUIRES: !IsRefType(input_dtype(index))
  // REQUIRES: the named input must not be a list.
  Status input(StringPiece name, const Tensor** tensor);

  // Returns the named list-valued immutable input in "list", as
  // defined in the OpDef. If the named input is not list-valued,
  // returns a one-element list. May only be used for non-Ref
  // inputs. For Ref inputs use mutable_input below.
  // REQUIRES: !IsRefType(input_dtype(index))
  Status input_list(StringPiece name, OpInputList* list);

  // For mutable inputs, use the following together to make sure there
  // is no concurrent access to mutable_input(), e.g.:
  // {
  //   Tensor& t = context->mutable_input(index);
  //   mutex_lock lock(*context->input_ref_mutex(index));
  //   // modify the values in t
  // }
  // REQUIRES: IsRefType(input_dtype(index))
  Status input_ref_mutex(StringPiece name, mutex** out_mutex);

  // Returns a mutable input tensor. Must be used to access Ref
  // inputs. REQUIRES: IsRefType(input_dtype(index)). The caller may
  // modify the values stored in the Tensor buffer, and modifications
  // will be visible to other Ops reading the same ref tensor. If
  // !lock_held the input mutex will be acquired before returning the
  // Tensor.
  // TODO(mrry): Convert this to return Status.
  Tensor mutable_input(int index, bool lock_held);

  // Returns the named mutable input tensor in "tensor", as defined in
  // the OpDef. Must be used to access Ref inputs. The values stored
  // in the Tensor buffer may be modified, and modifications will be
  // visible to other Ops reading the same ref tensor. If !lock_held
  // the input mutex will be acquired before returning the Tensor.
  // REQUIRES: the named input must not be a list.
  // REQUIRES: the named input must be a ref tensor.
  Status mutable_input(StringPiece name, Tensor* tensor, bool lock_held);

  // Returns the named list-valued mutable input in "list", as defined
  // in the OpDef. If the named input is not list-valued, returns a
  // one-element list. Must be used to access Ref inputs. The values
  // stored in the Tensor buffer may be modified, and modifications
  // will be visible to other Ops reading the same ref tensor.
  // REQUIRES: the named input must be a ref tensor.
  Status mutable_input_list(StringPiece name, OpMutableInputList* list);

  // Replace the corresponding Ref Input to use the storage buffer
  // used by tensor. If !lock_held the input mutex will be acquired
  // before returning the Tensor.
  // REQUIRES: IsRefType(input_dtype(index)).
  void replace_ref_input(int index, const Tensor& tensor, bool lock_held);

  // Replace the corresponding named Ref Input to use the storage
  // buffer used by tensor. If !lock_held the input mutex will be
  // acquired before returning the Tensor.
  // REQUIRES: IsRefType(input_dtype(index)).
  Status replace_ref_input(StringPiece name, const Tensor& tensor,
                           bool lock_held);

  // Deletes the Tensor object used as the Ref Input at
  // input_index. This is not usually necessary and should be used
  // with caution. If !lock_held the input mutex will be acquired
  // before returning the Tensor.
  // REQUIRES: IsRefType(input_dtype(input_index)).
  void delete_ref_input(int input_index, bool lock_held);

  // Return true if there is input at the given index. An operator has no
  // input at index if its tensor is null. This is primarily used by the
  // merge operator.
  // TODO(mrry): Convert this to return Status.
  bool has_input(int index) const;

  // Returns true if all inputs are the same shape, otherwise sets the
  // status to a non-OK value and returns false.
  // Usage: if (!context->ValidateInputsAreSameShape(this)) return;
  bool ValidateInputsAreSameShape(OpKernel* op);

  // Input to output forwarding.

  // Set the output Ref Tensor at output_index to be an alias of the
  // input Ref Tensor at input_index.
  // REQUIRES: IsRefType(input_dtype(input_index)).
  // REQUIRES: IsRefType(output_dtype(output_index)).
  void forward_ref_input_to_ref_output(int input_index, int output_index);

  // Returns true and sets *output to an alias of input[input_index],
  // reshaped to output_shape, that is safe to use for in-place computation.
  // Returns false if input[input_index] has a refcount greater than one, if
  // its type does not match the expected output type of output[output_index],
  // or if the number of elements in input[input_index] does not equal the
  // number of elements in output_shape.
  bool forward_input_to_output_with_shape(int input_index, int output_index,
                                          const TensorShape& output_shape,
                                          Tensor** output) TF_MUST_USE_RESULT;
  Status forward_input_to_output_with_shape(StringPiece input_name,
                                            StringPiece output_name,
                                            const TensorShape& output_shape,
                                            Tensor** output) TF_MUST_USE_RESULT;

  // Returns a pointer to a Tensor aliasing the underlying buffer backing
  // input[input_index] iff
  //   * input[input_index] is not a ref,
  //   * the data type, shape, memory type, and allocator attributes of
  //     input[input_index] are compatible with those given in dtype, shape,
  //     memory_type, and attr,
  //   * refcount on the underlying buffer is one.
  //   * Either there is no forwarding reservation for either input_index
  //     or output_index or the specified input is reserved for the specified
  //     output. More precisely:
  //
  //     These cases mean neither input nor output has a reservation:
  //        forward_from_array = nullptr
  //     OR (input_index is not in forward_from_array AND
  //         (output_index == kNoReservation OR
  //          forward_from_array[output_index] == kNoReservation))
  //
  //     This case means that input_index is reserved for output_index:
  //        forward_from_array[output_index] == input_index
  //
  //     This case means the output is reserved to always be allocated,
  //     never assigned a forwarded input:
  //        forward_from_array[output_index] == kNeverForward
  //
  // Otherwise returns nullptr.
  // NOTE: For Cuda kernels that read inputs using the __ldg() intrinsic,
  // forwarding is only safe if there are no reads via __ldg() after writes
  // to the same address.
  std::unique_ptr<Tensor> forward_input(
      int input_index, int output_index, DataType output_dtype,
      const TensorShape& output_shape, MemoryType output_memory_type,
      const AllocatorAttributes& output_attr) TF_MUST_USE_RESULT;

  // Tries to forward one of the inputs given in input_indices to
  // output[output_index]. If none of the given inputs can be forwarded, calls
  // allocate_output() to allocate a new output buffer.
  Status forward_input_or_allocate_output(
      gtl::ArraySlice<int> candidate_input_indices, int output_index,
      const TensorShape& output_shape, Tensor** output) TF_MUST_USE_RESULT;
  Status forward_input_or_allocate_output(
      gtl::ArraySlice<StringPiece> candidate_input_names,
      StringPiece output_name, const TensorShape& output_shape,
      Tensor** output) TF_MUST_USE_RESULT;

  // Tries to reuse one of the inputs given in input_indices as a temporary.
  // If none of the given inputs can be forwarded, calls
  // allocate_temp() to allocate a new temporary buffer.
  Status forward_input_or_allocate_temp(
      gtl::ArraySlice<int> candidate_input_indices, DataType type,
      const TensorShape& shape, const AllocatorAttributes& allocator_attr,
      Tensor* out_temp) TF_MUST_USE_RESULT;

  Status forward_input_or_allocate_temp(
      gtl::ArraySlice<int> candidate_input_indices, DataType type,
      const TensorShape& shape, Tensor* out_temp) TF_MUST_USE_RESULT {
    return forward_input_or_allocate_temp(candidate_input_indices, type, shape,
                                          AllocatorAttributes(), out_temp);
  }
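
  // Example (illustrative sketch only): an element-wise kernel can try to
  // reuse its input buffer for the output and fall back to a fresh
  // allocation when forwarding is not possible. The index values below are
  // hypothetical.
  //
  //   Tensor* output = nullptr;
  //   OP_REQUIRES_OK(context,
  //                  context->forward_input_or_allocate_output(
  //                      {0}, 0, context->input(0).shape(), &output));
  //   // Write results into *output; it may alias input 0.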

  // Output

  // Returns the named list-valued output in "list", as defined in the OpDef.
  // If the named output is not list-valued, returns a one-element list.
  Status output_list(StringPiece name, OpOutputList* list);

  // If output_required(index) returns true, the OpKernel's Compute() method
  // should call allocate_output(index, ...), set_output(index, ...),
  // set_output_ref(index, ...), or set the status to a non-ok value.
  // If it returns false, it may output, but is not required to do so.
  // TODO(mrry): Convert this to return Status, and implement a string
  // name version.
  bool output_required(int index) const {
    return true;  // TODO(josh11b): implement
  }

  // Allocation of tensors during kernel execution inside the Compute
  // method:
  //
  // There are three methods to allocate Tensors when an Op kernel
  // executes.
  //
  // 1) allocate_persistent. This is only needed for Tensors that will
  // be stored by the Op between invocations, and it *must* be used
  // for those Tensors. The call returns a PersistentTensor, and that
  // is the only object the Op is allowed to hold on to between
  // invocations. When the Tensor is needed in a subsequent
  // invocation, it can be retrieved from the PersistentTensor using
  // the AccessTensor method. This ensures that the system is made
  // aware of any use of the tensor's allocated memory, which is
  // needed for correctness on asynchronous devices such as GPUs.
  //
  // 2) allocate_output. This should be used to allocate any tensor
  // that is going to be used as an output from the Op at the end of
  // the current execution. The caller indicates which output the
  // Tensor will be assigned to, and the call returns the
  // newly-allocated Tensor. The Tensor can subsequently be assigned
  // to during kernel execution, and will be used as the designated
  // output when the kernel execution completes.
  //
  // 3) allocate_temp. This should be used to allocate any scratch
  // storage that is needed while the kernel is executing, and will
  // not be retained by the Op.
  //
  // In some cases a Tensor needs to be used as an output even though
  // it was previously allocated elsewhere. The Tensor may have been
  // passed as an input, or stored in a PersistentTensor during a
  // previous kernel execution, or allocated earlier in the kernel
  // execution at a time when it was not known which output it would
  // be assigned to. In this case the kernel can use set_output or
  // set_output_ref to indicate that the tensor should be used as the
  // designated output. It is legal to use any previously-allocated
  // Tensor as an argument to set_output or set_output_ref, including
  // Tensors allocated via allocate_temp. There may be a performance
  // penalty to using a Tensor that was not allocated using
  // allocate_output. This is because allocate_output uses the
  // AllocatorAttributes stored in output_attr_array for the
  // designated output. In some cases, using the wrong attributes may
  // cause an extra copy of the Tensor's buffer.
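  //
  // Example (illustrative sketch only): a typical Compute() body allocates
  // its output buffer up front and fills it in; the shapes, dtype, and index
  // values below are hypothetical.
  //
  //   void Compute(OpKernelContext* context) override {
  //     const Tensor& input = context->input(0);
  //     Tensor* output = nullptr;
  //     OP_REQUIRES_OK(context,
  //                    context->allocate_output(0, input.shape(), &output));
  //     Tensor scratch;
  //     OP_REQUIRES_OK(context,
  //                    context->allocate_temp(DT_FLOAT, input.shape(), &scratch));
  //     // ... compute into scratch, then write the result into *output ...
  //   }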

  // Allocates output for the specified output index with shape.
  // OpKernelContext retains ownership of the returned pointer. See
  // comment above.
  //
  // If memory allocation fails, returns an error status.
  //
  // REQUIRES: !IsRefType(expected_output_dtype(index))
  Status allocate_output(int index, const TensorShape& shape,
                         Tensor** tensor) TF_MUST_USE_RESULT;
  Status allocate_output(StringPiece name, const TensorShape& shape,
                         Tensor** tensor) TF_MUST_USE_RESULT;
  // The following methods use the supplied attributes instead of
  // those in output_attr_array. The caller is responsible for
  // ensuring that the attributes are "compatible" with the
  // output_attr_array, e.g. the tensor is allocated on the correct
  // device. See comment above.
  Status allocate_output(int index, const TensorShape& shape, Tensor** tensor,
                         AllocatorAttributes attr) TF_MUST_USE_RESULT;
  Status allocate_output(StringPiece name, const TensorShape& shape,
                         Tensor** tensor,
                         AllocatorAttributes attr) TF_MUST_USE_RESULT;

  // Allocates a temporary Tensor of the specified type and
  // shape. Devices such as GPUs that enqueue Ops for lazy execution
  // may retain references to the temporary tensors after the Op's
  // Compute method has run. See comment above.
  Status allocate_temp(DataType type, const TensorShape& shape,
                       Tensor* out_temp, AllocatorAttributes allocator_attr,
                       const AllocationAttributes& allocation_attr);
  Status allocate_temp(DataType type, const TensorShape& shape,
                       Tensor* out_temp, AllocatorAttributes allocator_attr) {
    return allocate_temp(type, shape, out_temp, allocator_attr,
                         AllocationAttributes());
  }
  Status allocate_temp(DataType type, const TensorShape& shape,
                       Tensor* out_temp) {
    return allocate_temp(type, shape, out_temp, AllocatorAttributes());
  }

  // Allocates a Tensor of the specified type and shape which the Op
  // plans to maintain as persistent state. out_persistent holds the
  // PersistentTensor which is the object the caller should store. For
  // convenience, if out_tensor is non-null then it will be filled in
  // with a Tensor* pointing to the newly-allocated tensor which the
  // caller can use instead of calling
  // out_persistent->AccessTensor. The caller does not own out_tensor
  // and should not keep a copy of it. See comment above.
  Status allocate_persistent(DataType type, const TensorShape& shape,
                             PersistentTensor* out_persistent,
                             Tensor** out_tensor, AllocatorAttributes attr);
  Status allocate_persistent(DataType type, const TensorShape& shape,
                             PersistentTensor* out_persistent,
                             Tensor** out_tensor) {
    return allocate_persistent(type, shape, out_persistent, out_tensor,
                               AllocatorAttributes());
  }

  // Copies a tensor (allocated by the caller) to the specified output
  // index. REQUIRES: !IsRefType(expected_output_dtype(index))
  // REQUIRES: 'tensor' must have the same MemoryType as
  // output_memory_types[index]. See comment above.
  Status set_output(StringPiece name, const Tensor& tensor);

  // To output a reference. Caller retains ownership of mu and tensor_for_ref,
  // and they must outlive all uses within the step. See comment above.
  // REQUIRES: IsRefType(expected_output_dtype(index))
  Status set_output_ref(StringPiece name, mutex* mu, Tensor* tensor_for_ref);

  // Returns nullptr if allocate_output() or set_output() have not been called.
  Status mutable_output(StringPiece name, Tensor** tensor);

  // Transfers ownership of an output tensor to the caller.
  // NOTE: For non-reference outputs, the caller takes responsibility
  // for deletion. For reference outputs, the caller does NOT take
  // responsibility for deletion.
  Status release_output(StringPiece name, TensorValue* value);

  // Records device specific state about how the input tensors were
  // computed.
  //
  // If using the templated function, the type must be a subclass
  // of DeviceContext.
  //
  // Get the DeviceContext used for the index input. Returns nullptr
  // if no DeviceContext was provided.
  template <typename T>
  T* input_device_context(int index);
  DeviceContext* input_device_context(int index);

  // Return the DeviceContext that should be used for this Op.
  //
  // If using the templated function, the type must be a subclass
  // of DeviceContext.
  //
  // Returns nullptr if the device did not provide one.
  template <typename T>
  T* op_device_context();
  DeviceContext* op_device_context() {
    DeviceContext* ret = params_->op_device_context;
    if (ret == nullptr) {
      auto* dev_info = device()->tensorflow_gpu_device_info();
      if (dev_info) ret = dev_info->default_context;
    }
    return ret;
  }

  AllocatorAttributes input_alloc_attr(int index) const {
    if (params_->input_alloc_attrs == nullptr) {
      return AllocatorAttributes();
    } else {
      DCHECK_GE(index, 0);
      DCHECK_LT(index, params_->input_alloc_attrs->size());
      return (*params_->input_alloc_attrs)[index];
    }
  }

  AllocatorAttributes output_alloc_attr(int index) const {
    return params_->output_attr_array[index];
  }

  gtl::InlinedVector<WrappedAllocator, 4> wrapped_allocators() const {
    mutex_lock lock(mu_);
    gtl::InlinedVector<WrappedAllocator, 4> retrieved = wrapped_allocators_;
    return retrieved;
  }

  // Communication.
  //
  // An op kernel communicates with outside environment through
  // Rendezvous Send() and Recv().
  Rendezvous* rendezvous() const { return params_->rendezvous; }

  CollectiveExecutor* collective_executor() const {
    return params_->collective_executor;
  }

  // An op kernel can access the session state it belongs to.
  SessionState* session_state() const { return params_->session_state; }

  // An op kernel can access the tensor store of the run it belongs to.
  TensorStore* tensor_store() const { return params_->tensor_store; }

  // Function call support.
  //
  // If this kernel invocation is within a function execution,
  // call_frame() returns the call frame for the function call.
  CallFrameInterface* call_frame() const { return params_->call_frame; }

  // If not nullptr, the kernel may call functions defined in the
  // library. E.g., CHECK_NOTNULL(function_library())->Run("Foo", ...).
  FunctionLibraryRuntime* function_library() const {
    return params_->function_library;
  }

  std::function<void(std::function<void()>)>* runner() const {
    return params_->runner;
  }
  StepStatsCollector* stats_collector() const {
    return params_->stats_collector;
  }

  // Shared resources accessible to this kernel.
  ResourceMgr* resource_manager() const { return params_->resource_manager; }

  checkpoint::TensorSliceReaderCacheWrapper* slice_reader_cache() const {
    return params_->slice_reader_cache;
  }

  // Execution.
  //
  // OpKernels can use these eigen devices to carry out their
  // numerical computation.
  const Eigen::ThreadPoolDevice& eigen_cpu_device() const {
    return *device()->eigen_cpu_device();
  }
  const Eigen::GpuDevice& eigen_gpu_device() const {
    return params_->eigen_gpu_device->device();
  }
#ifdef TENSORFLOW_USE_SYCL
  const Eigen::SyclDevice& eigen_sycl_device() const {
    return *device()->eigen_sycl_device();
  }
#endif
  template <typename EigenDeviceType>
  const EigenDeviceType& eigen_device() const;
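
  // Example (illustrative sketch only): device-templated kernels typically
  // evaluate Eigen expressions on the context's device so the same functor
  // runs on CPU and GPU. The Device template parameter and tensor names
  // below are hypothetical.
  //
  //   auto in = input.flat<float>();
  //   auto out = output->flat<float>();
  //   out.device(context->eigen_device<Device>()) = in * in;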

  // Error handling.

  // If expected_inputs == inputs() and expected_outputs == output_types(),
  // returns OK, else returns INVALID_ARGUMENT with an error message.
  // Recommended for Ops with dynamic signatures, where validation can only
  // be performed at runtime.
  Status MatchSignature(const DataTypeSlice expected_inputs,
                        const DataTypeSlice expected_outputs);

  // An OpKernel should call SetStatus() if Compute() encounters an
  // error.
  void SetStatus(const Status& status);
  const Status& status() const { return status_; }

  // Cancellation.
  //
  // EXPERIMENTAL. See the implementation in tensorflow::TensorQueue for an
  // example of how to use this API.
  CancellationManager* cancellation_manager() const {
    return params_->cancellation_manager;
  }
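
  // Example (illustrative sketch only, assuming the CancellationManager API
  // from cancellation.h): a long-running kernel can register a callback so
  // that pending work is aborted when the step is cancelled. AbortPendingWork()
  // is a hypothetical helper.
  //
  //   CancellationManager* cm = context->cancellation_manager();
  //   CancellationToken token = cm->get_cancellation_token();
  //   bool already_cancelled =
  //       !cm->RegisterCallback(token, [this]() { AbortPendingWork(); });
  //   // ... do work ...
  //   cm->DeregisterCallback(token);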

  // Other accessors.

  // For control flow.
  FrameAndIter frame_iter() const { return params_->frame_iter; }
  bool is_input_dead() const { return params_->is_input_dead; }
  bool* is_output_dead() { return &is_output_dead_; }

  // May be used, e.g., to get GPU handles, etc.
  // TODO(tucker): Add example usage.
  DeviceBase* device() const { return params_->device; }

  // Retrieve list of referenced tensors in out_vector. Once this is
  // called, it is not legal to reference any more tensors. Should
  // not be called from Op kernels.
  void retrieve_accessed_tensors(TensorReferenceVector* out_vector);

  // Per-step container for use by white-listed internal ops.
  ScopedStepContainer* step_container() const {
    return params_->step_container;
  }

  // Helper routines for the OP_REQUIRES macros
  void CtxFailure(const Status& s);
  void CtxFailureWithWarning(const Status& s);
  void CtxFailure(const char* file, int line, const Status& s);
  void CtxFailureWithWarning(const char* file, int line, const Status& s);

  // Unrecommended functions: these are functions that have some
  // current uses but are not recommended for use, and may go away at
  // some future major version release.
  //
  // The following functions all have versions that return Status
  // to capture error conditions, and are strongly preferred.
  Tensor* mutable_output(int index);
  void set_output(int index, const Tensor& tensor);
  mutex* input_ref_mutex(int index);
  void set_output_ref(int index, mutex* mu, Tensor* tensor_for_ref);
  TensorValue release_output(int index);

  bool track_allocations() const { return params_->track_allocations; }

  // Records temp memory allocation. Tensor object is recorded to identify the
  // case where temp memory is used as output memory.
  void record_temp_memory_allocation(int64 size, const Tensor& t)
      LOCKS_EXCLUDED(stats_mu_);

  // Returns recorded size of temporary memory.
  int64 temp_memory_allocated() const LOCKS_EXCLUDED(stats_mu_);

  // Records persistent memory allocation; a negative size indicates
  // deallocation.
  void record_persistent_memory_allocation(int64 size, int64 alloc_id = -1)
      LOCKS_EXCLUDED(stats_mu_);

  // Returns recorded size and ids of persistent memory.
  int64 persistent_memory_allocated() const LOCKS_EXCLUDED(stats_mu_);

  std::vector<int64> persistent_alloc_ids() const LOCKS_EXCLUDED(stats_mu_);

  // Resets counters for temp and persistent memory and recorded ids.
  void clear_recorded_memory() LOCKS_EXCLUDED(stats_mu_);

  bool input_is_ref(int index) const;

 private:
  Allocator* get_allocator(AllocatorAttributes attr);

  // Internal method to add a tensor's buffer to the list of buffers
  // referenced during the execution of the Op, so that GPUs may
  // accurately track the memory that may not be reused until the Op
  // execution completes.
  void record_tensor_reference(const Tensor& tensor);
  void really_record_tensor_reference(const Tensor& tensor);

  // Internal common method used when allocating tensor memory
  Status allocate_tensor(DataType type, const TensorShape& shape,
                         Tensor* out_tensor,
                         AllocatorAttributes allocator_attr) {
    return allocate_tensor(type, shape, out_tensor, allocator_attr,
                           AllocationAttributes());
  }

  Status allocate_tensor(DataType type, const TensorShape& shape,
                         Tensor* out_tensor, AllocatorAttributes allocator_attr,
                         const AllocationAttributes& allocation_attr);

  // This is called by PersistentTensor::AccessTensor whenever the
  // wrapped tensor is retrieved, to ensure the runtime knows that the
  // Tensor is being accessed within an Op. This is necessary for
  // memory safety of devices like GPUs that queue Ops for
  // asynchronous execution after the Compute() method completes.
  friend class PersistentTensor;
  void NotifyUseOfPersistentTensor(const Tensor& tensor);

  Status status_;
  friend class CollectiveExecutor;  // for access to params_
  Params* params_;                  // not owned
  mutable mutex mu_;  // mutable so const accessors can acquire the lock
  gtl::InlinedVector<WrappedAllocator, 4> wrapped_allocators_ GUARDED_BY(mu_);
  gtl::InlinedVector<TensorValue, 4> outputs_;

  // Constructed only if <params->record_tensor_accesses>.
  ManualConstructor<UniqueTensorReferences> referenced_tensors_ GUARDED_BY(mu_);

  bool is_output_dead_ = false;

  // The following data members are only used when allocation tracking is
  // enabled.
  mutable mutex stats_mu_;
  int64 temp_memory_allocated_ GUARDED_BY(stats_mu_);
  int64 persistent_memory_allocated_ GUARDED_BY(stats_mu_);
  std::unique_ptr<gtl::InlinedVector<std::pair<const void*, int64>, 2>>
      temp_tensor_buffer_and_size_ GUARDED_BY(stats_mu_);
  std::unique_ptr<gtl::InlinedVector<int64, 2>> persistent_alloc_ids_
      GUARDED_BY(stats_mu_);

  TF_DISALLOW_COPY_AND_ASSIGN(OpKernelContext);
};

// Register your OpKernel by specifying the Op's name, the device the
// kernel runs on, any type attr constraints for this kernel, any
// host-memory args, and the class to instantiate. Examples:
//
//  // A kernel that supports all types.
//  REGISTER_KERNEL_BUILDER(Name("Save").Device(DEVICE_CPU), SaveOp);
//
//  // The following are equivalent ways of specifying that the kernel only
//  // works if the "T" type attr is set to DT_FLOAT.
//  REGISTER_KERNEL_BUILDER(
//      Name("Sub").Device(DEVICE_CPU).TypeConstraint<float>("T"),
//      SubOp<float>);
//  // (You would then repeat this for every type supported by "Sub".)
//
//  // This form allows you to specify a list of types as the constraint.
//  REGISTER_KERNEL_BUILDER(Name("Sub")
//                              .Device(DEVICE_CPU)
//                              .TypeConstraint("T", {DT_FLOAT}),
//                          SubOp<float>);
//
//  // A kernel that expects one of the input tensors in host memory.
//  REGISTER_KERNEL_BUILDER(
//      Name("Reshape").Device(DEVICE_GPU).HostMemory("shape"), ReshapeOp);
//
// See kernel_def_builder for details.

// Instantiates an OpKernel that has been registered. Returns nullptr if no
// kernel was registered for this op with a matching device and input
// signature (and sets a NOT_FOUND *status), or if there is an error during
// construction (and sets an INVALID_ARGUMENT *status). Otherwise, the caller
// takes ownership of the returned pointer.
// EXPECTED USAGE: unique_ptr<OpKernel> op = CreateOpKernel(...);
// REQUIRES: def has all attrs specified (e.g. using AddDefaultsToNodeDef()).
std::unique_ptr<OpKernel> CreateOpKernel(DeviceType device_type,
                                         DeviceBase* device,
                                         Allocator* allocator,
                                         const NodeDef& def,
                                         int graph_def_version, Status* status);
Status CreateOpKernel(DeviceType device_type, DeviceBase* device,
                      Allocator* allocator, FunctionLibraryRuntime* flib,
                      const NodeDef& def, int graph_def_version,
                      OpKernel** kernel);

// Returns into 'device_types' the subset of prioritized_types that this
// binary has registered for the given NodeDef.
//
// REQUIRES: * 'device_types' is not nullptr.
//           * def has all attrs specified (e.g. using AddDefaultsToNodeDef()).
Status SupportedDeviceTypesForNode(
    const std::vector<DeviceType>& prioritized_types, const NodeDef& def,
    DeviceTypeVector* device_types);

// Returns a message with a description of the kernels registered for op
// `op_name`.
string KernelsRegisteredForOp(StringPiece op_name);

// Call once after Op registration has completed.
Status ValidateKernelRegistrations(const OpRegistryInterface& op_registry);

// -----------------------------------------------------------------------------
// OpKernel registration implementation follows, please ignore.

// Allow the REGISTER_KERNEL_BUILDER(Name("op_name").Device(...)...) syntax.
namespace register_kernel {

class Name : public KernelDefBuilder {
 public:
  // With selective registration, kernels whose implementation class is not
  // used by any op are disabled with the SHOULD_REGISTER_OP_KERNEL call in
  // REGISTER_KERNEL_BUILDER_UNIQ. However, an unused kernel that shares an
  // implementation class with a used kernel would get through that mechanism.
  //
  // This mechanism stops that registration by changing the name of the kernel
  // for the unused op to one that is ignored by
  // OpKernelRegistrar::InitInternal. Note that this method alone is
  // not sufficient - the compiler can't evaluate the entire KernelDefBuilder at
  // compilation time, so this method doesn't actually reduce code size.
  explicit Name(const char* op)
      : KernelDefBuilder(SHOULD_REGISTER_OP(op) ? op : "_no_register") {}
};

namespace system {

class Name : public KernelDefBuilder {
 public:
  // For system kernels, we ignore selective registration and
  // unconditionally register the kernel.
  explicit Name(const char* op) : KernelDefBuilder(op) {}
};

}  // namespace system

}  // namespace register_kernel

#define REGISTER_KERNEL_BUILDER(kernel_builder, ...) \
  REGISTER_KERNEL_BUILDER_UNIQ_HELPER(__COUNTER__, kernel_builder, __VA_ARGS__)

#define REGISTER_KERNEL_BUILDER_UNIQ_HELPER(ctr, kernel_builder, ...) \
  REGISTER_KERNEL_BUILDER_UNIQ(ctr, kernel_builder, __VA_ARGS__)

#define REGISTER_KERNEL_BUILDER_UNIQ(ctr, kernel_builder, ...)         \
  constexpr bool should_register_##ctr##__flag =                       \
      SHOULD_REGISTER_OP_KERNEL(#__VA_ARGS__);                         \
  static ::tensorflow::kernel_factory::OpKernelRegistrar               \
      registrar__body__##ctr##__object(                                \
          should_register_##ctr##__flag                                \
              ? ::tensorflow::register_kernel::kernel_builder.Build()  \
              : nullptr,                                               \
          #__VA_ARGS__,                                                \
          [](::tensorflow::OpKernelConstruction* context)              \
              -> ::tensorflow::OpKernel* {                             \
            return new __VA_ARGS__(context);                           \
          });

// The `REGISTER_SYSTEM_KERNEL_BUILDER()` macro acts as
// `REGISTER_KERNEL_BUILDER()` except that the kernel is registered
// unconditionally even when selective registration is used.
#define REGISTER_SYSTEM_KERNEL_BUILDER(kernel_builder, ...)               \
  REGISTER_SYSTEM_KERNEL_BUILDER_UNIQ_HELPER(__COUNTER__, kernel_builder, \
                                             __VA_ARGS__)

#define REGISTER_SYSTEM_KERNEL_BUILDER_UNIQ_HELPER(ctr, kernel_builder, ...) \
  REGISTER_SYSTEM_KERNEL_BUILDER_UNIQ(ctr, kernel_builder, __VA_ARGS__)

#define REGISTER_SYSTEM_KERNEL_BUILDER_UNIQ(ctr, kernel_builder, ...)     \
  static ::tensorflow::kernel_factory::OpKernelRegistrar                  \
      registrar__body__##ctr##__object(                                   \
          ::tensorflow::register_kernel::system::kernel_builder.Build(),  \
          #__VA_ARGS__,                                                   \
          [](::tensorflow::OpKernelConstruction* context)                 \
              -> ::tensorflow::OpKernel* {                                \
            return new __VA_ARGS__(context);                              \
          });

void* GlobalKernelRegistry();

// If node_def has a corresponding kernel registered on device_type,
// returns OK and fills in the kernel def and kernel_class_name. <def> and
// <kernel_class_name> may be null.
Status FindKernelDef(const DeviceType& device_type, const NodeDef& node_def,
                     const KernelDef** def, string* kernel_class_name);
1301
1302// Writes a list of all registered kernels to LOG(INFO), to help users debug
1303// missing kernel errors.
1304void LogAllRegisteredKernels();
1305
1306namespace kernel_factory {
1307
1308class OpKernelRegistrar {
1309 public:
1310 typedef OpKernel* (*Factory)(OpKernelConstruction*);
1311
1312 OpKernelRegistrar(const KernelDef* kernel_def, StringPiece kernel_class_name,
1313 Factory factory) {
1314 // Perform the check in the header to allow compile-time optimization
1315 // to a no-op, allowing the linker to remove the kernel symbols.
1316 if (kernel_def != nullptr) {
1317 InitInternal(kernel_def, kernel_class_name, factory);
1318 }
1319 }
1320
1321 private:
1322 void InitInternal(const KernelDef* kernel_def, StringPiece kernel_class_name,
1323 Factory factory);
1324};

}  // namespace kernel_factory

// -----------------------------------------------------------------------------
// Template and inline method implementations, please ignore

template <class T>
Status OpKernelConstruction::GetAttr(StringPiece attr_name, T* value) const {
  return GetNodeAttr(def(), attr_name, value);
}
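
// Example use in an OpKernel constructor (an illustrative sketch; the attr
// name "axis" is hypothetical):
//
//   int64 axis;
//   OP_REQUIRES_OK(context, context->GetAttr("axis", &axis));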

inline DataType OpKernelContext::input_dtype(int index) const {
  DCHECK_GE(index, 0);
  DCHECK_LT(index, num_inputs());
  const TensorValue& value((*params_->inputs)[index]);
  if (value.is_ref()) {
    return MakeRefType(value->dtype());
  } else {
    return value->dtype();
  }
}

inline MemoryType OpKernelContext::input_memory_type(int index) const {
  DCHECK_GE(index, 0);
  DCHECK_LT(index, num_inputs());
  return op_kernel().input_memory_types()[index];
}

inline DataType OpKernelContext::expected_output_dtype(int index) const {
  DCHECK_GE(index, 0);
  DCHECK_LT(index, num_outputs());
  return params_->op_kernel->output_type(index);
}

inline MemoryType OpKernelContext::output_memory_type(int index) const {
  DCHECK_GE(index, 0);
  DCHECK_LT(index, num_outputs());
  return op_kernel().output_memory_types()[index];
}

inline bool OpKernelContext::input_is_ref(int index) const {
  const TensorValue& value((*params_->inputs)[index]);
  return value.is_ref();
}

inline void OpKernelContext::record_tensor_reference(const Tensor& tensor) {
  DCHECK_EQ(params_->device->RequiresRecordingAccessedTensors(),
            params_->record_tensor_accesses);
  if (params_->record_tensor_accesses) {
    really_record_tensor_reference(tensor);
  }
}

inline void OpKernelContext::retrieve_accessed_tensors(
    TensorReferenceVector* out_vector) {
  if (params_->record_tensor_accesses) {
    mutex_lock l(mu_);
    referenced_tensors_->FreezeAndReturnReferences(out_vector);
  }
}

// There is no input at `index` if the corresponding tensor pointer is nullptr.
inline bool OpKernelContext::has_input(int index) const {
  DCHECK_GE(index, 0);
  DCHECK_LT(index, num_inputs());
  return (*params_->inputs)[index].tensor != nullptr;
}

inline mutex* OpKernelContext::input_ref_mutex(int index) {
  DCHECK_GE(index, 0);
  DCHECK_LT(index, num_inputs());
  DCHECK(input_is_ref(index));
  return (*params_->inputs)[index].mutex_if_ref;
}
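
// Example use inside Compute() (an illustrative sketch): read a ref-typed
// input (input 0 here) while holding its mutex.
//
//   mutex_lock l(*ctx->input_ref_mutex(0));
//   Tensor t = ctx->mutable_input(0, /* lock_held= */ true);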

inline void OpKernelContext::NotifyUseOfPersistentTensor(const Tensor& t) {
  if (t.IsInitialized()) {
    record_tensor_reference(t);
  }
}

inline Tensor* OpKernelContext::mutable_output(int index) {
  DCHECK_GE(index, 0);
  DCHECK_LT(index, num_outputs());
  // No need to record_tensor_reference since the output must already
  // have been set by a call that did so.
  return outputs_[index].tensor;
}

inline TensorValue OpKernelContext::release_output(int index) {
  DCHECK_GE(index, 0);
  DCHECK_LT(index, num_outputs());
  TensorValue value = outputs_[index];
  outputs_[index] = TensorValue();
  return value;
}

inline Status OpKernelContext::forward_input_or_allocate_output(
    gtl::ArraySlice<int> candidate_input_indices, int output_index,
    const TensorShape& output_shape, Tensor** output) {
  for (int input_index : candidate_input_indices) {
    if (forward_input_to_output_with_shape(input_index, output_index,
                                           output_shape, output)) {
      return Status::OK();
    }
  }
  return allocate_output(output_index, output_shape, output);
}

inline Status OpKernelContext::forward_input_or_allocate_output(
    gtl::ArraySlice<StringPiece> candidate_input_names, StringPiece output_name,
    const TensorShape& output_shape, Tensor** output) {
  for (const StringPiece& input_name : candidate_input_names) {
    if (forward_input_to_output_with_shape(input_name, output_name,
                                           output_shape, output)
            .ok()) {
      return Status::OK();
    }
  }
  return allocate_output(output_name, output_shape, output);
}
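
// Example use inside Compute() (an illustrative sketch): try to reuse the
// buffer of input 0 for output 0, falling back to a fresh allocation if the
// input cannot be forwarded.
//
//   Tensor* output = nullptr;
//   OP_REQUIRES_OK(ctx, ctx->forward_input_or_allocate_output(
//                           {0}, 0, ctx->input(0).shape(), &output));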

template <typename T>
T* OpKernelContext::op_device_context() {
  static_assert(std::is_base_of<DeviceContext, T>::value,
                "T is not a subclass of DeviceContext");
  return static_cast<T*>(op_device_context());
}

template <typename T>
T* OpKernelContext::input_device_context(int index) {
  DCHECK_GE(index, 0);
  DCHECK_LT(index, params_->input_device_contexts->size());
  static_assert(std::is_base_of<DeviceContext, T>::value,
                "T is not a subclass of DeviceContext");
  return static_cast<T*>((*params_->input_device_contexts)[index]);
}

inline DeviceContext* OpKernelContext::input_device_context(int index) {
  DCHECK_GE(index, 0);
  DCHECK_LT(index, params_->input_device_contexts->size());
  return (*params_->input_device_contexts)[index];
}

inline const Tensor& OpInputList::operator[](int i) const {
  DCHECK_GE(i, 0);
  DCHECK_LT(i, stop_ - start_);
  return ctx_->input(start_ + i);
}
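
// Example use inside Compute() (an illustrative sketch; the list-typed input
// name "values" is hypothetical):
//
//   OpInputList values;
//   OP_REQUIRES_OK(ctx, ctx->input_list("values", &values));
//   for (int i = 0; i < values.size(); ++i) {
//     const Tensor& t = values[i];
//     ...
//   }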

inline mutex* OpMutableInputList::ref_mutex(int i) {
  DCHECK_GE(i, 0);
  DCHECK_LT(i, stop_ - start_);
  return ctx_->input_ref_mutex(start_ + i);
}

inline Tensor OpMutableInputList::at(int i, bool lock_held) {
  DCHECK_GE(i, 0);
  DCHECK_LT(i, stop_ - start_);
  return ctx_->mutable_input(start_ + i, lock_held);
}

inline Tensor* OpOutputList::operator[](int i) {
  DCHECK_GE(i, 0);
  DCHECK_LT(i, stop_ - start_);
  return ctx_->mutable_output(start_ + i);
}

inline bool OpOutputList::required(int i) const {
  DCHECK_GE(i, 0);
  DCHECK_LT(i, stop_ - start_);
  return ctx_->output_required(start_ + i);
}

inline DataType OpOutputList::expected_output_dtype(int i) const {
  DCHECK_GE(i, 0);
  DCHECK_LT(i, stop_ - start_);
  return ctx_->expected_output_dtype(start_ + i);
}

inline Status OpOutputList::allocate(int i, const TensorShape& shape,
                                     Tensor** output) {
  DCHECK_GE(i, 0);
  DCHECK_LT(i, stop_ - start_);
  return ctx_->allocate_output(start_ + i, shape, output);
}

inline void OpOutputList::set(int i, const Tensor& tensor) {
  DCHECK_GE(i, 0);
  DCHECK_LT(i, stop_ - start_);
  ctx_->set_output(start_ + i, tensor);
}

inline void OpOutputList::set_ref(int i, mutex* mu, Tensor* tensor_for_ref) {
  DCHECK_GE(i, 0);
  DCHECK_LT(i, stop_ - start_);
  ctx_->set_output_ref(start_ + i, mu, tensor_for_ref);
}
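
// Example use inside Compute() (an illustrative sketch; the list-typed output
// name "outputs" is hypothetical):
//
//   OpOutputList outputs;
//   OP_REQUIRES_OK(ctx, ctx->output_list("outputs", &outputs));
//   Tensor* out = nullptr;
//   OP_REQUIRES_OK(ctx, outputs.allocate(0, TensorShape({2}), &out));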

// Convenience macros for asserting and handling exceptional conditions.
// Analogous to the CHECK* macros provided by logging.h.
//
// Example use:
//  void Compute(OpKernelContext* context) {
//    OP_REQUIRES(context, context->num_inputs() == 2,
//                errors::InvalidArgument("FooOp requires 2 arguments"));
//    ...
//    Status status = SomeUncertainMethod();
//    OP_REQUIRES_OK(context, status);
//    ...
//  }

#define OP_REQUIRES(CTX, EXP, STATUS)                   \
  do {                                                  \
    if (!TF_PREDICT_TRUE(EXP)) {                        \
      (CTX)->CtxFailure(__FILE__, __LINE__, (STATUS));  \
      return;                                           \
    }                                                   \
  } while (0)

#define OP_REQUIRES_OK(CTX, ...)                            \
  do {                                                      \
    ::tensorflow::Status _s(__VA_ARGS__);                   \
    if (!TF_PREDICT_TRUE(_s.ok())) {                        \
      (CTX)->CtxFailureWithWarning(__FILE__, __LINE__, _s); \
      return;                                               \
    }                                                       \
  } while (0)

#define OP_REQUIRES_ASYNC(CTX, EXP, STATUS, CALLBACK)   \
  do {                                                  \
    if (!TF_PREDICT_TRUE(EXP)) {                        \
      (CTX)->CtxFailure(__FILE__, __LINE__, (STATUS));  \
      (CALLBACK)();                                     \
      return;                                           \
    }                                                   \
  } while (0)

#define OP_REQUIRES_OK_ASYNC(CTX, STATUS, CALLBACK)         \
  do {                                                      \
    ::tensorflow::Status _s(STATUS);                        \
    if (!TF_PREDICT_TRUE(_s.ok())) {                        \
      (CTX)->CtxFailureWithWarning(__FILE__, __LINE__, _s); \
      (CALLBACK)();                                         \
      return;                                               \
    }                                                       \
  } while (0)
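
// Example use in an AsyncOpKernel (an illustrative sketch):
//
//  void ComputeAsync(OpKernelContext* context, DoneCallback done) override {
//    OP_REQUIRES_ASYNC(context, context->num_inputs() == 1,
//                      errors::InvalidArgument("FooAsyncOp requires 1 input"),
//                      done);
//    ...
//    done();
//  }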

}  // namespace tensorflow

#endif  // TENSORFLOW_FRAMEWORK_OP_KERNEL_H_