| 1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
| 2 | |
| 3 | Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | you may not use this file except in compliance with the License. |
| 5 | You may obtain a copy of the License at |
| 6 | |
| 7 | http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | |
| 9 | Unless required by applicable law or agreed to in writing, software |
| 10 | distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | See the License for the specific language governing permissions and |
| 13 | limitations under the License. |
| 14 | ==============================================================================*/ |
| 15 | |
| 16 | #ifndef TENSORFLOW_PUBLIC_SESSION_H_ |
| 17 | #define TENSORFLOW_PUBLIC_SESSION_H_ |
| 18 | |
#include <string>
#include <utility>
#include <vector>

#include "tensorflow/core/framework/device_attributes.pb.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/protobuf/config.pb.h"
#include "tensorflow/core/public/session_options.h"
| 29 | |
| 30 | namespace tensorflow { |
// Forward declaration only: this header refers to DeviceMgr solely through
// pointers (Session::LocalDeviceManager), so its definition is not needed.
class DeviceMgr;
| 32 | |
| 33 | /// \brief A Session instance lets a caller drive a TensorFlow graph |
| 34 | /// computation. |
| 35 | /// |
| 36 | /// When a Session is created with a given target, a new Session object |
| 37 | /// is bound to the universe of resources specified by that target. |
| 38 | /// Those resources are available to this session to perform |
| 39 | /// computation described in the GraphDef. After extending the session |
| 40 | /// with a graph, the caller uses the Run() API to perform the |
| 41 | /// computation and potentially fetch outputs as Tensors. |
| 42 | /// |
| 43 | /// Example: |
| 44 | /// |
| 45 | /// ```c++ |
| 46 | /// |
| 47 | /// tensorflow::GraphDef graph; |
| 48 | /// // ... Create or load graph into "graph". |
| 49 | /// |
| 50 | /// // This example uses the default options which connects |
| 51 | /// // to a local runtime. |
| 52 | /// tensorflow::SessionOptions options; |
| 53 | /// std::unique_ptr<tensorflow::Session> |
| 54 | /// session(tensorflow::NewSession(options)); |
| 55 | /// |
| 56 | /// // Create the session with this graph. |
| 57 | /// tensorflow::Status s = session->Create(graph); |
| 58 | /// if (!s.ok()) { ... } |
| 59 | /// |
| 60 | /// // Run the graph and fetch the first output of the "output" |
| 61 | /// // operation, and also run to but do not return anything |
| 62 | /// // for the "update_state" operation. |
| 63 | /// std::vector<tensorflow::Tensor> outputs; |
| 64 | /// s = session->Run({}, {"output:0"}, {"update_state"}, &outputs); |
| 65 | /// if (!s.ok()) { ... } |
| 66 | /// |
| 67 | /// // Map the output as a flattened float tensor, and do something |
| 68 | /// // with it. |
| 69 | /// auto output_tensor = outputs[0].flat<float>(); |
| 70 | /// if (output_tensor(0) > 0.5) { ... } |
| 71 | /// |
| 72 | /// // Close the session to release the resources associated with |
| 73 | /// // this session. |
| 74 | /// session->Close(); |
| 75 | /// |
| 76 | /// ``` |
| 77 | /// |
| 78 | /// A Session allows concurrent calls to Run(), though a Session must |
| 79 | /// be created / extended by a single thread. |
| 80 | /// |
| 81 | /// Only one thread must call Close(), and Close() must only be called |
| 82 | /// after all other calls to Run() have returned. |
| 83 | class Session { |
| 84 | public: |
| 85 | Session(); |
| 86 | virtual ~Session(); |
| 87 | |
| 88 | /// \brief Create the graph to be used for the session. |
| 89 | /// |
| 90 | /// Returns an error if this session has already been created with a |
| 91 | /// graph. To re-use the session with a different graph, the caller |
| 92 | /// must Close() the session first. |
| 93 | virtual Status Create(const GraphDef& graph) = 0; |
| 94 | |
| 95 | /// \brief Adds operations to the graph that is already registered with the |
| 96 | /// Session. |
| 97 | /// |
| 98 | /// The names of new operations in "graph" must not exist in the |
| 99 | /// graph that is already registered. |
| 100 | virtual Status Extend(const GraphDef& graph) = 0; |
| 101 | |
| 102 | /// \brief Runs the graph with the provided input tensors and fills |
| 103 | /// `outputs` for the endpoints specified in `output_tensor_names`. |
| 104 | /// Runs to but does not return Tensors for the nodes in |
| 105 | /// `target_node_names`. |
| 106 | /// |
| 107 | /// The order of tensors in `outputs` will match the order provided |
| 108 | /// by `output_tensor_names`. |
| 109 | /// |
| 110 | /// If `Run` returns `OK()`, then `outputs->size()` will be equal to |
| 111 | /// `output_tensor_names.size()`. If `Run` does not return `OK()`, the |
| 112 | /// state of `outputs` is undefined. |
| 113 | /// |
| 114 | /// REQUIRES: The name of each Tensor of the input or output must |
| 115 | /// match a "Tensor endpoint" in the `GraphDef` passed to `Create()`. |
| 116 | /// |
| 117 | /// REQUIRES: At least one of `output_tensor_names` and |
| 118 | /// `target_node_names` must be non-empty. |
| 119 | /// |
| 120 | /// REQUIRES: outputs is not nullptr if `output_tensor_names` is non-empty. |
| 121 | virtual Status Run(const std::vector<std::pair<string, Tensor> >& inputs, |
| 122 | const std::vector<string>& output_tensor_names, |
| 123 | const std::vector<string>& target_node_names, |
| 124 | std::vector<Tensor>* outputs) = 0; |
| 125 | |
| 126 | /// \brief Implementations which support `RunOptions`. |
| 127 | // |
| 128 | /// NOTE: This API is still experimental and may change. |
| 129 | virtual Status Create(const RunOptions& run_options, const GraphDef& graph) { |
| 130 | return errors::Unimplemented( |
| 131 | "Create(const RunOptions& run_options, const GraphDef& graph) is not " |
| 132 | "supported for this session." ); |
| 133 | } |
| 134 | virtual Status Extend(const RunOptions& run_options, const GraphDef& graph) { |
| 135 | return errors::Unimplemented( |
| 136 | "Extend(const RunOptions& run_options, const GraphDef& graph) is not " |
| 137 | "supported for this session." ); |
| 138 | } |
| 139 | virtual Status Close(const RunOptions& run_options) { |
| 140 | return errors::Unimplemented( |
| 141 | "Close(const RunOptions& run_options) is not supported for this " |
| 142 | "session." ); |
| 143 | } |
| 144 | |
| 145 | /// \brief Like `Run`, but allows users to pass in a `RunOptions` proto and |
| 146 | /// to retrieve non-Tensor metadata output via a `RunMetadata` proto for this |
| 147 | /// step. `run_metadata` may be nullptr, in which case any metadata output is |
| 148 | /// discarded. |
| 149 | /// NOTE: This API is still experimental and may change. |
| 150 | virtual Status Run(const RunOptions& run_options, |
| 151 | const std::vector<std::pair<string, Tensor> >& inputs, |
| 152 | const std::vector<string>& output_tensor_names, |
| 153 | const std::vector<string>& target_node_names, |
| 154 | std::vector<Tensor>* outputs, RunMetadata* run_metadata); |
| 155 | |
| 156 | /// \brief Sets up a graph for partial execution. All future feeds and |
| 157 | /// fetches are specified by `input_names` and `output_names`. Returns |
| 158 | /// `handle` that can be used to perform a sequence of partial feeds and |
| 159 | /// fetches. |
| 160 | /// NOTE: This API is still experimental and may change. |
| 161 | virtual Status PRunSetup(const std::vector<string>& input_names, |
| 162 | const std::vector<string>& output_names, |
| 163 | const std::vector<string>& target_nodes, |
| 164 | string* handle); |
| 165 | |
| 166 | /// \brief Continues the pending execution specified by `handle` with the |
| 167 | /// provided input tensors and fills `outputs` for the endpoints specified |
| 168 | /// in `output_names`. |
| 169 | /// NOTE: This API is still experimental and may change. |
| 170 | virtual Status PRun(const string& handle, |
| 171 | const std::vector<std::pair<string, Tensor> >& inputs, |
| 172 | const std::vector<string>& output_names, |
| 173 | std::vector<Tensor>* outputs); |
| 174 | |
| 175 | /// \brief List devices in the session. |
| 176 | /// |
| 177 | /// Retrieves the list of available devices within the session, and populates |
| 178 | /// *response. This API is optional. If it is unimplemented, Status will |
| 179 | /// return a corresponding error message, and *response will be unmodified. |
| 180 | virtual Status ListDevices(std::vector<DeviceAttributes>* response) = 0; |
| 181 | |
| 182 | /// \brief Closes this session. |
| 183 | /// |
| 184 | /// Closing a session releases the resources used by this session |
| 185 | /// on the TensorFlow runtime (specified during session creation by |
| 186 | /// the `SessionOptions::target` field). |
| 187 | virtual Status Close() = 0; |
| 188 | |
| 189 | // NOTE(ashankar): As of July 2017, this method was added to facilitate some |
| 190 | // experimentation. Reconsider/re-evaluate after September 2017. |
| 191 | // |
| 192 | // Sets `*output` to the `DeviceMgr` that owns accessible devices in the |
| 193 | // address-space of the caller. |
| 194 | virtual Status LocalDeviceManager(const DeviceMgr** output) { |
| 195 | return errors::Unimplemented( |
| 196 | "LocalDeviceManager is not supported for this session." ); |
| 197 | } |
| 198 | |
| 199 | /// \brief A handle to a subgraph, created with `Session::MakeCallable()`. |
| 200 | typedef int64 CallableHandle; |
| 201 | |
| 202 | /// \brief Creates a `handle` for invoking the subgraph defined by |
| 203 | /// `callable_options`. |
| 204 | /// NOTE: This API is still experimental and may change. |
| 205 | virtual Status MakeCallable(const CallableOptions& callable_options, |
| 206 | CallableHandle* out_handle) { |
| 207 | return errors::Unimplemented( |
| 208 | "MakeCallable is not supported for this session." ); |
| 209 | } |
| 210 | |
| 211 | /// \brief Invokes the subgraph named by `handle` with the given options and |
| 212 | /// input tensors. |
| 213 | /// |
| 214 | /// The order of tensors in `feed_tensors` must and `fetch_tensors` will |
| 215 | /// match the order of names in `CallableOptions::feed()` and |
| 216 | /// `CallableOptions::fetch()` when this subgraph was created. |
| 217 | /// NOTE: This API is still experimental and may change. |
| 218 | virtual Status RunCallable(CallableHandle handle, |
| 219 | const std::vector<Tensor>& feed_tensors, |
| 220 | std::vector<Tensor>* fetch_tensors, |
| 221 | RunMetadata* run_metadata) { |
| 222 | return errors::Unimplemented( |
| 223 | "RunCallable is not supported for this session." ); |
| 224 | } |
| 225 | |
| 226 | /// \brief Releases resources associated with the given `handle` in this |
| 227 | /// session. |
| 228 | /// NOTE: This API is still experimental and may change. |
| 229 | virtual Status ReleaseCallable(CallableHandle handle) { |
| 230 | return errors::Unimplemented( |
| 231 | "ReleaseCallable is not supported for this session." ); |
| 232 | } |
| 233 | }; |
| 234 | |
| 235 | /// \brief Create a new session with the given options. |
| 236 | /// |
| 237 | /// If session creation succeeds, the new `Session` will be stored in |
| 238 | /// `*out_session`, the caller will take ownership of the returned |
| 239 | /// `*out_session`, and this function will return `OK()`. Otherwise, this |
| 240 | /// function will return an error status. |
| 241 | Status NewSession(const SessionOptions& options, Session** out_session); |
| 242 | |
| 243 | /// \brief Resets resource containers associated with a target. |
| 244 | /// |
| 245 | /// Reset() allows misbehaving or slow sessions to be aborted and closed, and |
| 246 | /// causes their resources eventually to be released. Reset() does not wait |
| 247 | /// for the computations in old sessions to cease; it merely starts the |
| 248 | /// process of tearing them down. However, if a new session is started after |
| 249 | /// a Reset(), the new session is isolated from changes that old sessions |
| 250 | /// (started prior to the Reset()) may continue to make to resources, provided |
| 251 | /// all those resources are in containers listed in "containers". |
| 252 | /// |
| 253 | /// Old sessions may continue to have side-effects on resources not in |
| 254 | /// containers listed in "containers", and thus may affect future |
| 255 | /// sessions' results in ways that are hard to predict. Thus, if well-defined |
| 256 | /// behavior is desired, it is recommended that all containers be listed in |
| 257 | /// "containers". |
| 258 | /// |
| 259 | /// `containers` is a vector of string representation of resource container |
| 260 | /// names. When a resource container is reset, the resources held by the |
| 261 | /// container will be released. In particular, all Variables in the container |
| 262 | /// will become undefined. If the "containers" vector is empty, the default |
| 263 | /// container is assumed. If the "containers" vector is non-empty, the |
| 264 | /// default container should be listed explicitly. |
| 265 | /// |
| 266 | /// If Reset succeeds, this function will return `OK()`. Otherwise, this |
| 267 | /// function will return an error status. |
| 268 | Status Reset(const SessionOptions& options, |
| 269 | const std::vector<string>& containers); |
| 270 | |
| 271 | /// \brief Create a new session with the given options. |
| 272 | /// |
| 273 | /// If a new `Session` object could not be created, this function will |
| 274 | /// return nullptr. |
| 275 | /// |
| 276 | /// *Strongly prefer* the version of NewSession that returns Status, |
| 277 | /// which contains more helpful error information. |
| 278 | Session* NewSession(const SessionOptions& options); |
| 279 | |
| 280 | } // end namespace tensorflow |
| 281 | |
| 282 | #endif // TENSORFLOW_PUBLIC_SESSION_H_ |
| 283 | |