1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#ifndef TENSORFLOW_PYTHON_EAGER_PYWRAP_TFE_H_
17#define TENSORFLOW_PYTHON_EAGER_PYWRAP_TFE_H_
18
19#include "tensorflow/c/eager/c_api.h"
20#include "tensorflow/core/lib/core/status.h"
21#include "tensorflow/core/lib/gtl/inlined_vector.h"
22#include <Python.h>
23
24typedef tensorflow::gtl::InlinedVector<TFE_TensorHandle*, 4>
25 TFE_InputTensorHandles;
26typedef tensorflow::gtl::InlinedVector<TFE_TensorHandle*, 2>
27 TFE_OutputTensorHandles;
28
29// Execute a TensorFlow operation.
30//
31// 'device_name': Name of the device on which to execute the operation, or NULL
32// for automatic selection.
33// 'op_name': Name of the TensorFlow op to execute.
// 'inputs': A TFE_InputTensorHandles holding the TFE_TensorHandle*'s that
//           will be provided as input to the operation.
36// 'attrs': A Python tuple alternating names and attr values.
// 'outputs': A pointer to a TFE_OutputTensorHandles in which outputs will be
//            placed. On success, its elements will be filled in and the
//            caller takes ownership of each returned TFE_TensorHandle.
//            'outputs' MUST be sized to be at least as large as the number
//            of tensors produced by the operation and will be resized to
//            the actual number of tensors produced.
43void TFE_Py_Execute(TFE_Context* ctx, const char* device_name,
44 const char* op_name, TFE_InputTensorHandles* inputs,
45 PyObject* attrs, TFE_OutputTensorHandles* outputs,
46 TF_Status* out_status);
47
// Registers `e` as the Exception class used to report a non-OK Status.
// Returns Py_None if registration succeeds, else raises a TypeError and
// returns NULL.
50//
51// This function is not thread-safe.
52PyObject* TFE_Py_RegisterExceptionClass(PyObject* e);
53
// Registers `e` as the type of the ResourceVariable class.
// Returns Py_None if registration succeeds, else raises a TypeError and
// returns NULL.
57//
58// This function is not thread-safe.
59PyObject* TFE_Py_RegisterResourceVariableType(PyObject* e);
60
61// Registers e as the Exception to be raised when the conditions of
62// TFE_Py_FastPathExecute_C have not been met. When this exception is set, it
63// is a signal to the calling code that it should fall back to the safer (and
64// more complete) code path.
65//
66// This function is not thread-safe.
67PyObject* TFE_Py_RegisterFallbackExceptionClass(PyObject* e);
68
// Registers `e` as the backward_function_getter.
// The registered function creates a backward function (a function that can
// return the gradients of an op's inputs given the gradients of its outputs).
// The registered function will be passed the following arguments:
//   op_name, attrs, num_inputs, op_inputs, op_outputs
74//
75// This function is not thread-safe.
76PyObject* TFE_Py_RegisterBackwardFunctionGetter(PyObject* e);
77
78// Returns 0 if 'status' is TF_OK. Otherwise, raises an exception (using
79// `exception` if not nullptr, else using the class registered via
80// TFE_Py_RegisterExceptionClass), and returns -1.
81int MaybeRaiseExceptionFromTFStatus(TF_Status* status, PyObject* exception);
82
83// Returns 0 if 'status' is ok. Otherwise, raises an exception (using
84// `exception` if not nullptr, else using the class registered via
85// TFE_Py_RegisterExceptionClass), and returns -1.
86int MaybeRaiseExceptionFromStatus(const tensorflow::Status& status,
87 PyObject* exception);
88
89// Returns the string associated with the passed-in python object.
90char* TFE_GetPythonString(PyObject* o);
91
92// Returns a unique id on each call.
93int64_t get_uid();
94
95// Wraps the output of get_uid as a Python Long object. Ownership is passed to
96// the caller.
97PyObject* TFE_Py_UID();
98
99// Deleter for Context objects, called from the Capsule that owns it.
100void TFE_DeleteContextCapsule(PyObject* context);
101
102// Returns true if o is an instance of EagerTensor, but not a subclass. Else
103// returns false.
104bool EagerTensor_CheckExact(const PyObject* o);
105
106// Helper function to construct a new EagerTensor from a TFE_TensorHandle.
107PyObject* EagerTensorFromHandle(TFE_TensorHandle* handle);
108
109// Extracts the handle inside EagerTensor object `o`. Returns nullptr on error.
110TFE_TensorHandle* EagerTensor_Handle(const PyObject* o);
111
112// Creates the `EagerTensor` class by subclassing `base_class` and returns the
113// newly created type, or nullptr on error.
114PyObject* TFE_Py_InitEagerTensor(PyObject* base_class);
115
// Creates a new tape and adds it to the active set. `persistent` must be a
// Python bool (PyBool_Type), i.e. either Py_True or Py_False.
118PyObject* TFE_Py_TapeSetNew(PyObject* persistent);
119
120// Removes the passed tape from the set of active tapes.
121void TFE_Py_TapeSetRemove(PyObject* tape);
122
123// Returns true if the tape stack is empty.
124PyObject* TFE_Py_TapeSetIsEmpty();
125
126PyObject* TFE_Py_TapeSetShouldRecord(PyObject* tensors);
127void TFE_Py_TapeSetWatch(PyObject* tensor);
128void TFE_Py_TapeSetDeleteTrace(tensorflow::int64 tensor_id);
129
130// Stops any gradient recording on the current thread.
131void TFE_Py_TapeSetStopOnThread();
132
133// Restarts gradient recording on the current thread.
134void TFE_Py_TapeSetRestartOnThread();
135
// Records an operation in the gradient tape stack. `op_type` is a string for
// the operation type, used in the backprop code. `output_tensors` should be a
// list of python ops.Tensor objects. `input_tensor_ids` should be a list of
// python integers with the ids of the input tensors of the recorded
// operation. `backward_function` should be the function to be called during
// backprop which, given the gradients of the output tensors, produces the
// gradients of the input tensors.
143void TFE_Py_TapeSetRecordOperation(PyObject* op_type, PyObject* output_tensors,
144 PyObject* input_tensor_ids,
145 PyObject* backward_function);
146
147// Watches the given variable object on the given tape.
148void TFE_Py_TapeSetWatchVariable(PyObject* variable);
149
// Computes a gradient based on information recorded on the tape. `tape` must
// have been produced by TFE_Py_TapeSetNew. `vspace` must be an
// imperative_grad.py:VSpace named tuple. `target` and `sources` must be python
// lists of Tensor objects. `output_gradients` is either None or a python list
// whose elements are each a Tensor or None; if not None, it should have the
// same length as `target`.
156PyObject* TFE_Py_TapeGradient(PyObject* tape, PyObject* vspace,
157 PyObject* target, PyObject* sources,
158 PyObject* output_gradients, TF_Status* status);
159
160// Execute a tensorflow operation assuming that all provided inputs are
161// correctly formatted (i.e. EagerTensors). If it doesn't find EagerTensors,
162// it will simply fail with a NotImplementedError.
163//
164// The first PyObject* is unused.
165// The "args" PyObject* is meant to be a tuple with the following structure:
166// Item 1: The TFE Context
167// Item 2: device_name: Name of the device on which to execute the operation,
168// or NULL for automatic selection.
169// Item 3: op_name: Name of the TensorFlow op to execute.
170// Item 4: name: An optional name for the operation.
171// Item 5: List representing all callbacks to execute after successful
172// op execute.
173// Item 6 onwards: inputs - This is a list of inputs followed by a list of
174// attrs. It is not necessary for type attrs to be present.
175//
176// This is named _C since there doesn't seem to be any way to make it visible
177// in the SWIG interface without renaming due to the use of the %native
178// directive.
179PyObject* TFE_Py_FastPathExecute_C(PyObject*, PyObject* args);
180
181// Record the gradient for a given op.
182PyObject* TFE_Py_RecordGradient(PyObject* op_name, PyObject* inputs,
183 PyObject* attrs, PyObject* results,
184 PyObject* name);
185
186// Returns the set of variables watched by the given tape.
187PyObject* TFE_Py_TapeWatchedVariables(PyObject* tape);
188
189// Returns an EagerTensor of dimension [len(`tensor_list`)] containing
190// the `slice_dim`'th dimension of each tensor in `tensor_list`. In other words,
191// TFE_Py_TensorShapeSlice takes a slice of dimensions of tensors in
// `tensor_list`. For example, if `tensor_list` contains tensors with shapes
// [1, 2, 3], [4, 5], [6, 7, 8, 9], TFE_Py_TensorShapeSlice called with
// `slice_dim` equal to 1 will return [2, 5, 7].
195// On error, returns nullptr and sets python exception.
196// REQUIRES: `tensor_list` is a python list of EagerTensors
197// REQUIRES: `slice_dim` is non-negative and smaller than the rank of all
198// tensors in `tensor_list`.
199PyObject* TFE_Py_TensorShapeSlice(PyObject* tensor_list, int slice_dim);
200
201#endif // TENSORFLOW_PYTHON_EAGER_PYWRAP_TFE_H_
202