1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #ifndef TENSORFLOW_PUBLIC_SESSION_H_ |
17 | #define TENSORFLOW_PUBLIC_SESSION_H_ |
18 | |
#include <string>
#include <utility>
#include <vector>

#include "tensorflow/core/framework/device_attributes.pb.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/protobuf/config.pb.h"
#include "tensorflow/core/public/session_options.h"
29 | |
30 | namespace tensorflow { |
31 | class DeviceMgr; |
32 | |
33 | /// \brief A Session instance lets a caller drive a TensorFlow graph |
34 | /// computation. |
35 | /// |
36 | /// When a Session is created with a given target, a new Session object |
37 | /// is bound to the universe of resources specified by that target. |
38 | /// Those resources are available to this session to perform |
39 | /// computation described in the GraphDef. After extending the session |
40 | /// with a graph, the caller uses the Run() API to perform the |
41 | /// computation and potentially fetch outputs as Tensors. |
42 | /// |
43 | /// Example: |
44 | /// |
45 | /// ```c++ |
46 | /// |
47 | /// tensorflow::GraphDef graph; |
48 | /// // ... Create or load graph into "graph". |
49 | /// |
50 | /// // This example uses the default options which connects |
51 | /// // to a local runtime. |
52 | /// tensorflow::SessionOptions options; |
53 | /// std::unique_ptr<tensorflow::Session> |
54 | /// session(tensorflow::NewSession(options)); |
55 | /// |
56 | /// // Create the session with this graph. |
57 | /// tensorflow::Status s = session->Create(graph); |
58 | /// if (!s.ok()) { ... } |
59 | /// |
60 | /// // Run the graph and fetch the first output of the "output" |
61 | /// // operation, and also run to but do not return anything |
62 | /// // for the "update_state" operation. |
63 | /// std::vector<tensorflow::Tensor> outputs; |
64 | /// s = session->Run({}, {"output:0"}, {"update_state"}, &outputs); |
65 | /// if (!s.ok()) { ... } |
66 | /// |
67 | /// // Map the output as a flattened float tensor, and do something |
68 | /// // with it. |
69 | /// auto output_tensor = outputs[0].flat<float>(); |
70 | /// if (output_tensor(0) > 0.5) { ... } |
71 | /// |
72 | /// // Close the session to release the resources associated with |
73 | /// // this session. |
74 | /// session->Close(); |
75 | /// |
76 | /// ``` |
77 | /// |
78 | /// A Session allows concurrent calls to Run(), though a Session must |
79 | /// be created / extended by a single thread. |
80 | /// |
81 | /// Only one thread must call Close(), and Close() must only be called |
82 | /// after all other calls to Run() have returned. |
83 | class Session { |
84 | public: |
85 | Session(); |
86 | virtual ~Session(); |
87 | |
88 | /// \brief Create the graph to be used for the session. |
89 | /// |
90 | /// Returns an error if this session has already been created with a |
91 | /// graph. To re-use the session with a different graph, the caller |
92 | /// must Close() the session first. |
93 | virtual Status Create(const GraphDef& graph) = 0; |
94 | |
95 | /// \brief Adds operations to the graph that is already registered with the |
96 | /// Session. |
97 | /// |
98 | /// The names of new operations in "graph" must not exist in the |
99 | /// graph that is already registered. |
100 | virtual Status Extend(const GraphDef& graph) = 0; |
101 | |
102 | /// \brief Runs the graph with the provided input tensors and fills |
103 | /// `outputs` for the endpoints specified in `output_tensor_names`. |
104 | /// Runs to but does not return Tensors for the nodes in |
105 | /// `target_node_names`. |
106 | /// |
107 | /// The order of tensors in `outputs` will match the order provided |
108 | /// by `output_tensor_names`. |
109 | /// |
110 | /// If `Run` returns `OK()`, then `outputs->size()` will be equal to |
111 | /// `output_tensor_names.size()`. If `Run` does not return `OK()`, the |
112 | /// state of `outputs` is undefined. |
113 | /// |
114 | /// REQUIRES: The name of each Tensor of the input or output must |
115 | /// match a "Tensor endpoint" in the `GraphDef` passed to `Create()`. |
116 | /// |
117 | /// REQUIRES: At least one of `output_tensor_names` and |
118 | /// `target_node_names` must be non-empty. |
119 | /// |
120 | /// REQUIRES: outputs is not nullptr if `output_tensor_names` is non-empty. |
121 | virtual Status Run(const std::vector<std::pair<string, Tensor> >& inputs, |
122 | const std::vector<string>& output_tensor_names, |
123 | const std::vector<string>& target_node_names, |
124 | std::vector<Tensor>* outputs) = 0; |
125 | |
126 | /// \brief Implementations which support `RunOptions`. |
127 | // |
128 | /// NOTE: This API is still experimental and may change. |
129 | virtual Status Create(const RunOptions& run_options, const GraphDef& graph) { |
130 | return errors::Unimplemented( |
131 | "Create(const RunOptions& run_options, const GraphDef& graph) is not " |
132 | "supported for this session." ); |
133 | } |
134 | virtual Status Extend(const RunOptions& run_options, const GraphDef& graph) { |
135 | return errors::Unimplemented( |
136 | "Extend(const RunOptions& run_options, const GraphDef& graph) is not " |
137 | "supported for this session." ); |
138 | } |
139 | virtual Status Close(const RunOptions& run_options) { |
140 | return errors::Unimplemented( |
141 | "Close(const RunOptions& run_options) is not supported for this " |
142 | "session." ); |
143 | } |
144 | |
145 | /// \brief Like `Run`, but allows users to pass in a `RunOptions` proto and |
146 | /// to retrieve non-Tensor metadata output via a `RunMetadata` proto for this |
147 | /// step. `run_metadata` may be nullptr, in which case any metadata output is |
148 | /// discarded. |
149 | /// NOTE: This API is still experimental and may change. |
150 | virtual Status Run(const RunOptions& run_options, |
151 | const std::vector<std::pair<string, Tensor> >& inputs, |
152 | const std::vector<string>& output_tensor_names, |
153 | const std::vector<string>& target_node_names, |
154 | std::vector<Tensor>* outputs, RunMetadata* run_metadata); |
155 | |
156 | /// \brief Sets up a graph for partial execution. All future feeds and |
157 | /// fetches are specified by `input_names` and `output_names`. Returns |
158 | /// `handle` that can be used to perform a sequence of partial feeds and |
159 | /// fetches. |
160 | /// NOTE: This API is still experimental and may change. |
161 | virtual Status PRunSetup(const std::vector<string>& input_names, |
162 | const std::vector<string>& output_names, |
163 | const std::vector<string>& target_nodes, |
164 | string* handle); |
165 | |
166 | /// \brief Continues the pending execution specified by `handle` with the |
167 | /// provided input tensors and fills `outputs` for the endpoints specified |
168 | /// in `output_names`. |
169 | /// NOTE: This API is still experimental and may change. |
170 | virtual Status PRun(const string& handle, |
171 | const std::vector<std::pair<string, Tensor> >& inputs, |
172 | const std::vector<string>& output_names, |
173 | std::vector<Tensor>* outputs); |
174 | |
175 | /// \brief List devices in the session. |
176 | /// |
177 | /// Retrieves the list of available devices within the session, and populates |
178 | /// *response. This API is optional. If it is unimplemented, Status will |
179 | /// return a corresponding error message, and *response will be unmodified. |
180 | virtual Status ListDevices(std::vector<DeviceAttributes>* response) = 0; |
181 | |
182 | /// \brief Closes this session. |
183 | /// |
184 | /// Closing a session releases the resources used by this session |
185 | /// on the TensorFlow runtime (specified during session creation by |
186 | /// the `SessionOptions::target` field). |
187 | virtual Status Close() = 0; |
188 | |
189 | // NOTE(ashankar): As of July 2017, this method was added to facilitate some |
190 | // experimentation. Reconsider/re-evaluate after September 2017. |
191 | // |
192 | // Sets `*output` to the `DeviceMgr` that owns accessible devices in the |
193 | // address-space of the caller. |
194 | virtual Status LocalDeviceManager(const DeviceMgr** output) { |
195 | return errors::Unimplemented( |
196 | "LocalDeviceManager is not supported for this session." ); |
197 | } |
198 | |
199 | /// \brief A handle to a subgraph, created with `Session::MakeCallable()`. |
200 | typedef int64 CallableHandle; |
201 | |
202 | /// \brief Creates a `handle` for invoking the subgraph defined by |
203 | /// `callable_options`. |
204 | /// NOTE: This API is still experimental and may change. |
205 | virtual Status MakeCallable(const CallableOptions& callable_options, |
206 | CallableHandle* out_handle) { |
207 | return errors::Unimplemented( |
208 | "MakeCallable is not supported for this session." ); |
209 | } |
210 | |
211 | /// \brief Invokes the subgraph named by `handle` with the given options and |
212 | /// input tensors. |
213 | /// |
214 | /// The order of tensors in `feed_tensors` must and `fetch_tensors` will |
215 | /// match the order of names in `CallableOptions::feed()` and |
216 | /// `CallableOptions::fetch()` when this subgraph was created. |
217 | /// NOTE: This API is still experimental and may change. |
218 | virtual Status RunCallable(CallableHandle handle, |
219 | const std::vector<Tensor>& feed_tensors, |
220 | std::vector<Tensor>* fetch_tensors, |
221 | RunMetadata* run_metadata) { |
222 | return errors::Unimplemented( |
223 | "RunCallable is not supported for this session." ); |
224 | } |
225 | |
226 | /// \brief Releases resources associated with the given `handle` in this |
227 | /// session. |
228 | /// NOTE: This API is still experimental and may change. |
229 | virtual Status ReleaseCallable(CallableHandle handle) { |
230 | return errors::Unimplemented( |
231 | "ReleaseCallable is not supported for this session." ); |
232 | } |
233 | }; |
234 | |
235 | /// \brief Create a new session with the given options. |
236 | /// |
237 | /// If session creation succeeds, the new `Session` will be stored in |
238 | /// `*out_session`, the caller will take ownership of the returned |
239 | /// `*out_session`, and this function will return `OK()`. Otherwise, this |
240 | /// function will return an error status. |
241 | Status NewSession(const SessionOptions& options, Session** out_session); |
242 | |
243 | /// \brief Resets resource containers associated with a target. |
244 | /// |
245 | /// Reset() allows misbehaving or slow sessions to be aborted and closed, and |
246 | /// causes their resources eventually to be released. Reset() does not wait |
247 | /// for the computations in old sessions to cease; it merely starts the |
248 | /// process of tearing them down. However, if a new session is started after |
249 | /// a Reset(), the new session is isolated from changes that old sessions |
250 | /// (started prior to the Reset()) may continue to make to resources, provided |
251 | /// all those resources are in containers listed in "containers". |
252 | /// |
253 | /// Old sessions may continue to have side-effects on resources not in |
254 | /// containers listed in "containers", and thus may affect future |
255 | /// sessions' results in ways that are hard to predict. Thus, if well-defined |
256 | /// behavior is desired, it is recommended that all containers be listed in |
257 | /// "containers". |
258 | /// |
259 | /// `containers` is a vector of string representation of resource container |
260 | /// names. When a resource container is reset, the resources held by the |
261 | /// container will be released. In particular, all Variables in the container |
262 | /// will become undefined. If the "containers" vector is empty, the default |
263 | /// container is assumed. If the "containers" vector is non-empty, the |
264 | /// default container should be listed explicitly. |
265 | /// |
266 | /// If Reset succeeds, this function will return `OK()`. Otherwise, this |
267 | /// function will return an error status. |
268 | Status Reset(const SessionOptions& options, |
269 | const std::vector<string>& containers); |
270 | |
271 | /// \brief Create a new session with the given options. |
272 | /// |
273 | /// If a new `Session` object could not be created, this function will |
274 | /// return nullptr. |
275 | /// |
276 | /// *Strongly prefer* the version of NewSession that returns Status, |
277 | /// which contains more helpful error information. |
278 | Session* NewSession(const SessionOptions& options); |
279 | |
280 | } // end namespace tensorflow |
281 | |
282 | #endif // TENSORFLOW_PUBLIC_SESSION_H_ |
283 | |