/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_PYTHON_EAGER_PYWRAP_TFE_H_
#define TENSORFLOW_PYTHON_EAGER_PYWRAP_TFE_H_

#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include <Python.h>

typedef tensorflow::gtl::InlinedVector<TFE_TensorHandle*, 4>
    TFE_InputTensorHandles;
typedef tensorflow::gtl::InlinedVector<TFE_TensorHandle*, 2>
    TFE_OutputTensorHandles;
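
// Example (illustrative sketch): both typedefs are inlined vectors, so
// pushing a handful of handles does not heap-allocate. `h0` and `h1` below
// are assumed to be existing TFE_TensorHandle* values.
//
//   TFE_InputTensorHandles inputs;
//   inputs.push_back(h0);
//   inputs.push_back(h1);  // still within the inline capacity of 4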

// Execute a TensorFlow operation.
//
// 'device_name': Name of the device on which to execute the operation, or NULL
//                for automatic selection.
// 'op_name': Name of the TensorFlow op to execute.
// 'inputs': An array of TFE_TensorHandle*'s. These tensors will be provided as
//           input to the operation.
// 'attrs': A Python tuple alternating attribute names and attribute values.
// 'outputs': A pointer to a TFE_OutputTensorHandles in which outputs will be
//            placed. On success, its elements will be filled in and the caller
//            takes ownership of each returned TFE_TensorHandle. 'outputs' MUST
//            be sized to be at least as large as the number of tensors
//            produced by the operation and will be resized to the actual
//            number of tensors produced.
void TFE_Py_Execute(TFE_Context* ctx, const char* device_name,
                    const char* op_name, TFE_InputTensorHandles* inputs,
                    PyObject* attrs, TFE_OutputTensorHandles* outputs,
                    TF_Status* out_status);
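
// A minimal usage sketch of TFE_Py_Execute (illustrative only; `ctx`, `handle`
// and `attrs` are assumed to already exist, and "Identity" is just an example
// single-output op):
//
//   TFE_InputTensorHandles inputs;
//   inputs.push_back(handle);
//   TFE_OutputTensorHandles outputs(1);  // resized to the actual output count
//   TF_Status* status = TF_NewStatus();
//   TFE_Py_Execute(ctx, /*device_name=*/nullptr, "Identity", &inputs, attrs,
//                  &outputs, status);
//   if (TF_GetCode(status) == TF_OK) {
//     // The caller now owns each TFE_TensorHandle* in `outputs`.
//   }
//   TF_DeleteStatus(status);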

// Registers e as the Exception class for handling non-OK Status. Returns
// Py_None if registration succeeds, else raises a TypeError and returns NULL.
//
// This function is not thread-safe.
PyObject* TFE_Py_RegisterExceptionClass(PyObject* e);
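
// Typical call pattern (sketch; `exception_class` is an assumed name for the
// Python exception type being registered):
//
//   PyObject* ok = TFE_Py_RegisterExceptionClass(exception_class);
//   if (ok == nullptr) {
//     // A TypeError is already set; propagate it to the caller.
//   }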

// Registers e as the type of the ResourceVariable class.
// Returns Py_None if registration succeeds, else raises a TypeError and
// returns NULL.
//
// This function is not thread-safe.
PyObject* TFE_Py_RegisterResourceVariableType(PyObject* e);

// Registers e as the Exception to be raised when the conditions of
// TFE_Py_FastPathExecute_C have not been met. When this exception is set, it
// is a signal to the calling code that it should fall back to the safer (and
// more complete) code path.
//
// This function is not thread-safe.
PyObject* TFE_Py_RegisterFallbackExceptionClass(PyObject* e);

// Registers e as the backward_function_getter.
// The registered function creates a backward function (a function that can
// return the gradients of the inputs of an op given the gradients of its
// outputs). The registered function will be passed the following arguments:
//    op_name, attrs, num_inputs, op_inputs, op_outputs
//
// This function is not thread-safe.
PyObject* TFE_Py_RegisterBackwardFunctionGetter(PyObject* e);

// Returns 0 if 'status' is TF_OK. Otherwise, raises an exception (using
// `exception` if not nullptr, else using the class registered via
// TFE_Py_RegisterExceptionClass), and returns -1.
int MaybeRaiseExceptionFromTFStatus(TF_Status* status, PyObject* exception);

// Returns 0 if 'status' is ok. Otherwise, raises an exception (using
// `exception` if not nullptr, else using the class registered via
// TFE_Py_RegisterExceptionClass), and returns -1.
int MaybeRaiseExceptionFromStatus(const tensorflow::Status& status,
                                  PyObject* exception);
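
// Typical use in a wrapper function (sketch; `status` is assumed to have been
// filled in by a preceding C API call):
//
//   if (MaybeRaiseExceptionFromTFStatus(status, nullptr) != 0) {
//     return nullptr;  // A Python exception is now set.
//   }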

// Returns the string associated with the passed-in python object.
char* TFE_GetPythonString(PyObject* o);

// Returns a unique id on each call.
int64_t get_uid();

// Wraps the output of get_uid as a Python Long object. Ownership is passed to
// the caller.
PyObject* TFE_Py_UID();
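
// Ownership sketch: since the reference is passed to the caller, the caller
// must eventually release it.
//
//   PyObject* uid = TFE_Py_UID();
//   // ... use uid ...
//   Py_DECREF(uid);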

// Deleter for Context objects, called from the Capsule that owns it.
void TFE_DeleteContextCapsule(PyObject* context);

// Returns true if o is an instance of EagerTensor, but not a subclass. Else
// returns false.
bool EagerTensor_CheckExact(const PyObject* o);

// Helper function to construct a new EagerTensor from a TFE_TensorHandle.
PyObject* EagerTensorFromHandle(TFE_TensorHandle* handle);

// Extracts the handle inside EagerTensor object `o`. Returns nullptr on error.
TFE_TensorHandle* EagerTensor_Handle(const PyObject* o);
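
// Round-trip sketch (illustrative; `handle` is assumed to be a valid
// TFE_TensorHandle*, EagerTensorFromHandle is assumed to return a new
// reference, and the EagerTensor class to have been initialized via
// TFE_Py_InitEagerTensor, declared below):
//
//   PyObject* t = EagerTensorFromHandle(handle);
//   if (t != nullptr && EagerTensor_CheckExact(t)) {
//     TFE_TensorHandle* same_handle = EagerTensor_Handle(t);
//     // same_handle refers to the tensor wrapped by `t`.
//   }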

// Creates the `EagerTensor` class by subclassing `base_class` and returns the
// newly created type, or nullptr on error.
PyObject* TFE_Py_InitEagerTensor(PyObject* base_class);

// Creates a new tape and adds it to the active set. `persistent` must be a
// PyBool_Type, i.e. either Py_True or Py_False.
PyObject* TFE_Py_TapeSetNew(PyObject* persistent);

// Removes the passed tape from the set of active tapes.
void TFE_Py_TapeSetRemove(PyObject* tape);

// Returns true if the tape stack is empty.
PyObject* TFE_Py_TapeSetIsEmpty();

// Returns true if any active tape should record an operation on the given
// tensors.
PyObject* TFE_Py_TapeSetShouldRecord(PyObject* tensors);

// Adds `tensor` to the set of tensors watched by the active tapes.
void TFE_Py_TapeSetWatch(PyObject* tensor);

// Deletes any trace recorded for the tensor with id `tensor_id` from the
// active tapes.
void TFE_Py_TapeSetDeleteTrace(tensorflow::int64 tensor_id);

// Stops any gradient recording on the current thread.
void TFE_Py_TapeSetStopOnThread();

// Restarts gradient recording on the current thread.
void TFE_Py_TapeSetRestartOnThread();
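
// Sketch of temporarily disabling recording on the current thread:
//
//   TFE_Py_TapeSetStopOnThread();
//   // ... run computations that should not be traced by any active tape ...
//   TFE_Py_TapeSetRestartOnThread();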

// Records an operation in the gradient tape stack. `op_type` is a string for
// the operation type, used in the backprop code. `output_tensors` should be a
// list of python ops.Tensor objects. `input_tensor_ids` should be a list of
// python integers with the ids of the input tensors of the recorded operation.
// `backward_function` should be the function to be called during backprop to,
// given the gradients of the output tensors, produce the gradients of the
// input tensors.
void TFE_Py_TapeSetRecordOperation(PyObject* op_type, PyObject* output_tensors,
                                   PyObject* input_tensor_ids,
                                   PyObject* backward_function);

// Watches the given variable object on the given tape.
void TFE_Py_TapeSetWatchVariable(PyObject* variable);

// Computes a gradient based on information recorded on the tape. `tape` must
// have been produced by TFE_Py_TapeSetNew. `vspace` must be an
// imperative_grad.py:VSpace named tuple. `target` and `sources` must be python
// lists of Tensor objects. `output_gradients` is either None or a python list
// of either Tensor or None, and if not None should have the same length as
// target.
PyObject* TFE_Py_TapeGradient(PyObject* tape, PyObject* vspace,
                              PyObject* target, PyObject* sources,
                              PyObject* output_gradients, TF_Status* status);
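
// End-to-end tape sketch (illustrative only; `vspace`, `target` and `sources`
// are assumed to be set up by the Python-side machinery, and reference
// counting is elided for brevity):
//
//   PyObject* tape = TFE_Py_TapeSetNew(Py_False);  // non-persistent tape
//   // ... watch tensors/variables and record operations ...
//   TF_Status* status = TF_NewStatus();
//   PyObject* grads = TFE_Py_TapeGradient(tape, vspace, target, sources,
//                                         /*output_gradients=*/Py_None,
//                                         status);
//   TFE_Py_TapeSetRemove(tape);
//   TF_DeleteStatus(status);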

// Executes a TensorFlow operation, assuming that all provided inputs are
// correctly formatted (i.e. EagerTensors). If it doesn't find EagerTensors,
// it will simply fail with a NotImplementedError.
//
// The first PyObject* is unused.
// The "args" PyObject* is meant to be a tuple with the following structure:
//  Item 1: The TFE Context
//  Item 2: device_name: Name of the device on which to execute the operation,
//          or NULL for automatic selection.
//  Item 3: op_name: Name of the TensorFlow op to execute.
//  Item 4: name: An optional name for the operation.
//  Item 5: List representing all callbacks to execute after a successful
//          op execution.
//  Item 6 onwards: inputs - This is a list of inputs followed by a list of
//          attrs. It is not necessary for type attrs to be present.
//
// This is named _C since there doesn't seem to be any way to make it visible
// in the SWIG interface without renaming due to the use of the %native
// directive.
PyObject* TFE_Py_FastPathExecute_C(PyObject*, PyObject* args);
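
// Fallback sketch (illustrative; `args` is assumed to be a tuple built as
// described above, and `fallback_exception_class` is a hypothetical handle to
// the object previously passed to TFE_Py_RegisterFallbackExceptionClass):
//
//   PyObject* result = TFE_Py_FastPathExecute_C(nullptr, args);
//   if (result == nullptr &&
//       PyErr_ExceptionMatches(fallback_exception_class)) {
//     PyErr_Clear();
//     // Fall back to the slower but more complete execution path.
//   }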

// Records the gradient for a given op.
PyObject* TFE_Py_RecordGradient(PyObject* op_name, PyObject* inputs,
                                PyObject* attrs, PyObject* results,
                                PyObject* name);

// Returns the set of variables watched by the given tape.
PyObject* TFE_Py_TapeWatchedVariables(PyObject* tape);

// Returns an EagerTensor of dimension [len(`tensor_list`)] containing
// the `slice_dim`'th dimension of each tensor in `tensor_list`. In other
// words, TFE_Py_TensorShapeSlice takes a slice of dimensions of tensors in
// `tensor_list`. For example, if `tensor_list` contains tensors with shapes
// [1, 2, 3], [4, 5], [6, 7, 8, 9], TFE_Py_TensorShapeSlice called with
// `slice_dim` equal to 1 will return [2, 5, 7].
// On error, returns nullptr and sets a python exception.
// REQUIRES: `tensor_list` is a python list of EagerTensors.
// REQUIRES: `slice_dim` is non-negative and smaller than the rank of all
//           tensors in `tensor_list`.
PyObject* TFE_Py_TensorShapeSlice(PyObject* tensor_list, int slice_dim);
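
// Example matching the description above (sketch; `tensor_list` is assumed to
// be a Python list holding EagerTensors of shapes [1, 2, 3], [4, 5] and
// [6, 7, 8, 9]):
//
//   PyObject* dims = TFE_Py_TensorShapeSlice(tensor_list, /*slice_dim=*/1);
//   // On success, `dims` is an EagerTensor containing [2, 5, 7].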

#endif  // TENSORFLOW_PYTHON_EAGER_PYWRAP_TFE_H_