| 1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
| 2 | |
| 3 | Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | you may not use this file except in compliance with the License. |
| 5 | You may obtain a copy of the License at |
| 6 | |
| 7 | http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | |
| 9 | Unless required by applicable law or agreed to in writing, software |
| 10 | distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | See the License for the specific language governing permissions and |
| 13 | limitations under the License. |
| 14 | ==============================================================================*/ |
| 15 | |
| 16 | #include "tensorflow/core/common_runtime/direct_session.h" |
| 17 | |
| 18 | #include <atomic> |
| 19 | #include <string> |
| 20 | #include <vector> |
| 21 | |
| 22 | #include "tensorflow/core/common_runtime/constant_folding.h" |
| 23 | #include "tensorflow/core/common_runtime/debugger_state_interface.h" |
| 24 | #include "tensorflow/core/common_runtime/device_factory.h" |
| 25 | #include "tensorflow/core/common_runtime/executor.h" |
| 26 | #include "tensorflow/core/common_runtime/function.h" |
| 27 | #include "tensorflow/core/common_runtime/graph_optimizer.h" |
| 28 | #include "tensorflow/core/common_runtime/memory_types.h" |
| 29 | #include "tensorflow/core/common_runtime/optimization_registry.h" |
| 30 | #include "tensorflow/core/common_runtime/process_util.h" |
| 31 | #include "tensorflow/core/common_runtime/step_stats_collector.h" |
| 32 | #include "tensorflow/core/framework/function.h" |
| 33 | #include "tensorflow/core/framework/graph.pb_text.h" |
| 34 | #include "tensorflow/core/framework/graph.pb.h" |
| 35 | #include "tensorflow/core/framework/graph_def_util.h" |
| 36 | #include "tensorflow/core/framework/log_memory.h" |
| 37 | #include "tensorflow/core/framework/node_def.pb.h" |
| 38 | #include "tensorflow/core/framework/tensor.h" |
| 39 | #include "tensorflow/core/framework/versions.pb.h" |
| 40 | #include "tensorflow/core/graph/algorithm.h" |
| 41 | #include "tensorflow/core/graph/graph.h" |
| 42 | #include "tensorflow/core/graph/graph_constructor.h" |
| 43 | #include "tensorflow/core/graph/graph_partition.h" |
| 44 | #include "tensorflow/core/graph/subgraph.h" |
| 45 | #include "tensorflow/core/graph/tensor_id.h" |
| 46 | #include "tensorflow/core/lib/core/errors.h" |
| 47 | #include "tensorflow/core/lib/core/notification.h" |
| 48 | #include "tensorflow/core/lib/core/refcount.h" |
| 49 | #include "tensorflow/core/lib/core/status.h" |
| 50 | #include "tensorflow/core/lib/core/threadpool.h" |
| 51 | #include "tensorflow/core/lib/gtl/array_slice.h" |
| 52 | #include "tensorflow/core/lib/gtl/stl_util.h" |
| 53 | #include "tensorflow/core/lib/monitoring/counter.h" |
| 54 | #include "tensorflow/core/lib/strings/numbers.h" |
| 55 | #include "tensorflow/core/lib/strings/str_util.h" |
| 56 | #include "tensorflow/core/lib/strings/strcat.h" |
| 57 | #include "tensorflow/core/platform/cpu_info.h" |
| 58 | #include "tensorflow/core/platform/device_tracer.h" |
| 59 | #include "tensorflow/core/platform/logging.h" |
| 60 | #include "tensorflow/core/platform/mutex.h" |
| 61 | #include "tensorflow/core/platform/types.h" |
| 62 | #include "tensorflow/core/util/device_name_utils.h" |
| 63 | #include "tensorflow/core/util/env_var.h" |
| 64 | |
| 65 | namespace tensorflow { |
| 66 | |
| 67 | namespace { |
| 68 | |
| 69 | auto* direct_session_runs = monitoring::Counter<0>::New( |
| 70 | "/tensorflow/core/direct_session_runs" , |
| 71 | "The number of times DirectSession::Run() has been called." ); |
| 72 | |
| 73 | Status NewThreadPoolFromThreadPoolOptions( |
| 74 | const SessionOptions& options, |
| 75 | const ThreadPoolOptionProto& thread_pool_options, int pool_number, |
| 76 | thread::ThreadPool** pool, bool* owned) { |
| 77 | int32 num_threads = thread_pool_options.num_threads(); |
| 78 | if (num_threads == 0) { |
| 79 | num_threads = NumInterOpThreadsFromSessionOptions(options); |
| 80 | } |
| 81 | const string& name = thread_pool_options.global_name(); |
| 82 | if (name.empty()) { |
| 83 | // Session-local threadpool. |
| 84 | VLOG(1) << "Direct session inter op parallelism threads for pool " |
| 85 | << pool_number << ": " << num_threads; |
| 86 | *pool = new thread::ThreadPool( |
| 87 | options.env, strings::StrCat("Compute" , pool_number), num_threads); |
| 88 | *owned = true; |
| 89 | return Status::OK(); |
| 90 | } |
| 91 | |
| 92 | // Global, named threadpool. |
| 93 | typedef std::pair<int32, thread::ThreadPool*> MapValue; |
| 94 | static std::map<string, MapValue>* global_pool_map = |
| 95 | new std::map<string, MapValue>; |
| 96 | static mutex* mu = new mutex(); |
| 97 | mutex_lock l(*mu); |
| 98 | MapValue* mvalue = &(*global_pool_map)[name]; |
| 99 | if (mvalue->second == nullptr) { |
| 100 | mvalue->first = thread_pool_options.num_threads(); |
| 101 | mvalue->second = new thread::ThreadPool( |
| 102 | options.env, strings::StrCat("Compute" , pool_number), num_threads); |
| 103 | } else { |
| 104 | if (mvalue->first != thread_pool_options.num_threads()) { |
| 105 | return errors::InvalidArgument( |
| 106 | "Pool " , name, |
| 107 | " configured previously with num_threads=" , mvalue->first, |
| 108 | "; cannot re-configure with num_threads=" , |
| 109 | thread_pool_options.num_threads()); |
| 110 | } |
| 111 | } |
| 112 | *owned = false; |
| 113 | *pool = mvalue->second; |
| 114 | return Status::OK(); |
| 115 | } |
| 116 | |
| 117 | thread::ThreadPool* GlobalThreadPool(const SessionOptions& options) { |
| 118 | static thread::ThreadPool* const thread_pool = |
| 119 | NewThreadPoolFromSessionOptions(options); |
| 120 | return thread_pool; |
| 121 | } |
| 122 | |
| 123 | // TODO(vrv): Figure out how to unify the many different functions |
| 124 | // that generate RendezvousKey, since many of them have to be |
| 125 | // consistent with each other. |
| 126 | string GetRendezvousKey(const string& tensor_name, |
| 127 | const DeviceAttributes& device_info, |
| 128 | const FrameAndIter& frame_iter) { |
| 129 | return strings::StrCat(device_info.name(), ";" , |
| 130 | strings::FpToString(device_info.incarnation()), ";" , |
| 131 | device_info.name(), ";" , tensor_name, ";" , |
| 132 | frame_iter.frame_id, ":" , frame_iter.iter_id); |
| 133 | } |
| 134 | |
| 135 | } // namespace |
| 136 | |
| 137 | class DirectSessionFactory : public SessionFactory { |
| 138 | public: |
| 139 | DirectSessionFactory() {} |
| 140 | |
| 141 | bool AcceptsOptions(const SessionOptions& options) override { |
| 142 | return options.target.empty(); |
| 143 | } |
| 144 | |
| 145 | Session* NewSession(const SessionOptions& options) override { |
| 146 | // Must do this before the CPU allocator is created. |
| 147 | if (options.config.graph_options().build_cost_model() > 0) { |
| 148 | EnableCPUAllocatorFullStats(true); |
| 149 | } |
| 150 | std::vector<Device*> devices; |
| 151 | const Status s = DeviceFactory::AddDevices( |
| 152 | options, "/job:localhost/replica:0/task:0" , &devices); |
| 153 | if (!s.ok()) { |
| 154 | LOG(ERROR) << s; |
| 155 | return nullptr; |
| 156 | } |
| 157 | |
| 158 | DirectSession* session = |
| 159 | new DirectSession(options, new DeviceMgr(devices), this); |
| 160 | { |
| 161 | mutex_lock l(sessions_lock_); |
| 162 | sessions_.push_back(session); |
| 163 | } |
| 164 | return session; |
| 165 | } |
| 166 | |
| 167 | Status Reset(const SessionOptions& options, |
| 168 | const std::vector<string>& containers) override { |
| 169 | std::vector<DirectSession*> sessions_to_reset; |
| 170 | { |
| 171 | mutex_lock l(sessions_lock_); |
| 172 | // We create a copy to ensure that we don't have a deadlock when |
| 173 | // session->Close calls the DirectSessionFactory.Deregister, which |
| 174 | // acquires sessions_lock_. |
| 175 | std::swap(sessions_to_reset, sessions_); |
| 176 | } |
| 177 | Status s; |
| 178 | for (auto session : sessions_to_reset) { |
| 179 | s.Update(session->Reset(containers)); |
| 180 | } |
| 181 | // TODO(suharshs): Change the Reset behavior of all SessionFactories so that |
| 182 | // it doesn't close the sessions? |
| 183 | for (auto session : sessions_to_reset) { |
| 184 | s.Update(session->Close()); |
| 185 | } |
| 186 | return s; |
| 187 | } |
| 188 | |
| 189 | void Deregister(const DirectSession* session) { |
| 190 | mutex_lock l(sessions_lock_); |
| 191 | sessions_.erase(std::remove(sessions_.begin(), sessions_.end(), session), |
| 192 | sessions_.end()); |
| 193 | } |
| 194 | |
| 195 | private: |
| 196 | mutex sessions_lock_; |
| 197 | std::vector<DirectSession*> sessions_ GUARDED_BY(sessions_lock_); |
| 198 | }; |
| 199 | |
| 200 | class DirectSessionRegistrar { |
| 201 | public: |
| 202 | DirectSessionRegistrar() { |
| 203 | SessionFactory::Register("DIRECT_SESSION" , new DirectSessionFactory()); |
| 204 | } |
| 205 | }; |
| 206 | static DirectSessionRegistrar registrar; |
| 207 | |
| 208 | std::atomic_int_fast64_t DirectSession::step_id_counter_(1); |
| 209 | |
| 210 | // NOTE: On Android with a single device, there is never |
| 211 | // a risk of an OpKernel blocking indefinitely: |
| 212 | // |
| 213 | // 1) No operations do I/O that depends on other simultaneous kernels, |
| 214 | // |
| 215 | // 2) Recv nodes always complete immediately: The inputs are sent into |
| 216 | // the local rendezvous before we start the executor, so the |
| 217 | // corresponding recvs will not block. |
| 218 | // |
| 219 | // Based on these assumptions, we can use the same thread pool for |
| 220 | // both "non-blocking" and "blocking" OpKernels on Android. |
| 221 | // |
| 222 | // This may change down the road when we add support for multiple |
| 223 | // devices that run concurrently, in which case we will need to |
| 224 | // revisit this decision. |
| 225 | void DirectSession::SchedClosure(thread::ThreadPool* pool, |
| 226 | std::function<void()> c) { |
| 227 | // TODO(sanjay): Get rid of __ANDROID__ path |
| 228 | #ifdef __ANDROID__ |
| 229 | // On Android, there is no implementation of ThreadPool that takes |
| 230 | // std::function, only Closure, which we cannot easily convert. |
| 231 | // |
| 232 | // Instead, we just run the function in-line, which is currently |
| 233 | // safe given the reasoning above. |
| 234 | c(); |
| 235 | #else |
| 236 | pool->Schedule(std::move(c)); |
| 237 | #endif // __ANDROID__ |
| 238 | } |
| 239 | |
| 240 | DirectSession::DirectSession(const SessionOptions& options, |
| 241 | const DeviceMgr* device_mgr, |
| 242 | DirectSessionFactory* const factory) |
| 243 | : options_(options), |
| 244 | device_mgr_(device_mgr), |
| 245 | factory_(factory), |
| 246 | cancellation_manager_(new CancellationManager()), |
| 247 | operation_timeout_in_ms_(options_.config.operation_timeout_in_ms()) { |
| 248 | const int thread_pool_size = |
| 249 | options_.config.session_inter_op_thread_pool_size(); |
| 250 | if (thread_pool_size > 0) { |
| 251 | for (int i = 0; i < thread_pool_size; ++i) { |
| 252 | thread::ThreadPool* pool = nullptr; |
| 253 | bool owned = false; |
| 254 | init_error_.Update(NewThreadPoolFromThreadPoolOptions( |
| 255 | options_, options_.config.session_inter_op_thread_pool(i), i, &pool, |
| 256 | &owned)); |
| 257 | thread_pools_.emplace_back(pool, owned); |
| 258 | } |
| 259 | } else if (options_.config.use_per_session_threads()) { |
| 260 | thread_pools_.emplace_back(NewThreadPoolFromSessionOptions(options_), |
| 261 | true /* owned */); |
| 262 | } else { |
| 263 | thread_pools_.emplace_back(GlobalThreadPool(options), false /* owned */); |
| 264 | } |
| 265 | // The default value of sync_on_finish will be flipped soon and this |
| 266 | // environment variable will be removed as well. |
| 267 | const Status status = |
| 268 | ReadBoolFromEnvVar("TF_SYNC_ON_FINISH" , true, &sync_on_finish_); |
| 269 | if (!status.ok()) { |
| 270 | LOG(ERROR) << status.error_message(); |
| 271 | } |
| 272 | // NOTE(mrry): We do not need to use a unique string for the session |
| 273 | // handle, because DirectSession owns its devices. This may change |
| 274 | // in future versions. |
| 275 | session_handle_ = "direct" ; |
| 276 | int devices_added = 0; |
| 277 | if (options.config.log_device_placement()) { |
| 278 | const string mapping_str = device_mgr_->DeviceMappingString(); |
| 279 | if (mapping_str.empty()) { |
| 280 | printf("Device mapping: no known devices.\n" ); |
| 281 | } else { |
| 282 | printf("Device mapping:\n%s" , mapping_str.c_str()); |
| 283 | } |
| 284 | LOG(INFO) << "Device mapping:\n" << mapping_str; |
| 285 | } |
| 286 | for (auto d : device_mgr_->ListDevices()) { |
| 287 | devices_.push_back(d); |
| 288 | device_set_.AddDevice(d); |
| 289 | d->op_segment()->AddHold(session_handle_); |
| 290 | |
| 291 | // The first device added is special: it is the 'client device' (a |
| 292 | // CPU device) from which we feed and fetch Tensors. |
| 293 | if (devices_added == 0) { |
| 294 | device_set_.set_client_device(d); |
| 295 | } |
| 296 | ++devices_added; |
| 297 | } |
| 298 | } |
| 299 | |
| 300 | DirectSession::~DirectSession() { |
| 301 | if (!closed_) Close().IgnoreError(); |
| 302 | for (auto& it : partial_runs_) { |
| 303 | it.second.reset(nullptr); |
| 304 | } |
| 305 | for (auto& it : executors_) { |
| 306 | it.second.reset(); |
| 307 | } |
| 308 | callables_.clear(); |
| 309 | for (auto d : device_mgr_->ListDevices()) { |
| 310 | d->op_segment()->RemoveHold(session_handle_); |
| 311 | } |
| 312 | for (auto d : device_mgr_->ListDevices()) { |
| 313 | d->ClearResourceMgr(); |
| 314 | } |
| 315 | functions_.clear(); |
| 316 | delete cancellation_manager_; |
| 317 | for (const auto& p_and_owned : thread_pools_) { |
| 318 | if (p_and_owned.second) delete p_and_owned.first; |
| 319 | } |
| 320 | |
| 321 | execution_state_.reset(nullptr); |
| 322 | flib_def_.reset(nullptr); |
| 323 | } |
| 324 | |
| 325 | Status DirectSession::MaybeInitializeExecutionState( |
| 326 | const GraphDef& graph, bool* out_already_initialized) { |
| 327 | // If already initialized, do nothing. |
| 328 | if (flib_def_ && execution_state_) { |
| 329 | *out_already_initialized = true; |
| 330 | return Status::OK(); |
| 331 | } |
| 332 | // Set up the per-session execution state. |
| 333 | // NOTE(mrry): The function library created here will be used for |
| 334 | // all subsequent extensions of the graph. |
| 335 | flib_def_.reset( |
| 336 | new FunctionLibraryDefinition(OpRegistry::Global(), graph.library())); |
| 337 | GraphExecutionStateOptions options; |
| 338 | options.device_set = &device_set_; |
| 339 | options.session_options = &options_; |
| 340 | // TODO(mrry,suharshs): We explicitly copy `graph` so that |
| 341 | // `MakeForBaseGraph()` can take ownership of its |
| 342 | // contents. Previously this happened implicitly in calls to the |
| 343 | // `GraphExecutionState`. Other sessions call |
| 344 | // `MakeForBaseGraph` in such a way that we can destructively read |
| 345 | // the passed-in `GraphDef`. In principle we could do the same here, |
| 346 | // with a wider refactoring; we might revise the direct session so |
| 347 | // that it copies the graph fewer times. |
| 348 | GraphDef temp(graph); |
| 349 | TF_RETURN_IF_ERROR( |
| 350 | GraphExecutionState::MakeForBaseGraph(&temp, options, &execution_state_)); |
| 351 | graph_created_ = true; |
| 352 | *out_already_initialized = false; |
| 353 | return Status::OK(); |
| 354 | } |
| 355 | |
| 356 | Status DirectSession::Create(const GraphDef& graph) { |
| 357 | TF_RETURN_IF_ERROR(init_error_); |
| 358 | if (graph.node_size() > 0) { |
| 359 | mutex_lock l(graph_def_lock_); |
| 360 | if (graph_created_) { |
| 361 | return errors::AlreadyExists( |
| 362 | "A Graph has already been created for this session." ); |
| 363 | } |
| 364 | return ExtendLocked(graph); |
| 365 | } |
| 366 | return Status::OK(); |
| 367 | } |
| 368 | |
| 369 | Status DirectSession::Extend(const GraphDef& graph) { |
| 370 | TF_RETURN_IF_ERROR(CheckNotClosed()); |
| 371 | mutex_lock l(graph_def_lock_); |
| 372 | return ExtendLocked(graph); |
| 373 | } |
| 374 | |
| 375 | Status DirectSession::ExtendLocked(const GraphDef& graph) { |
| 376 | bool already_initialized; |
| 377 | // If this is the first call, we can initialize the execution state |
| 378 | // with `graph` and do not need to call `Extend()`. |
| 379 | TF_RETURN_IF_ERROR( |
| 380 | MaybeInitializeExecutionState(graph, &already_initialized)); |
| 381 | if (already_initialized) { |
| 382 | TF_RETURN_IF_ERROR(flib_def_->AddLibrary(graph.library())); |
| 383 | std::unique_ptr<GraphExecutionState> state; |
| 384 | TF_RETURN_IF_ERROR(execution_state_->Extend(graph, &state)); |
| 385 | execution_state_.swap(state); |
| 386 | } |
| 387 | return Status::OK(); |
| 388 | } |
| 389 | |
| 390 | Status DirectSession::Run(const NamedTensorList& inputs, |
| 391 | const std::vector<string>& output_names, |
| 392 | const std::vector<string>& target_nodes, |
| 393 | std::vector<Tensor>* outputs) { |
| 394 | RunMetadata run_metadata; |
| 395 | return Run(RunOptions(), inputs, output_names, target_nodes, outputs, |
| 396 | &run_metadata); |
| 397 | } |
| 398 | |
| 399 | Status DirectSession::CreateDebuggerState( |
| 400 | const CallableOptions& callable_options, int64 global_step, |
| 401 | int64 session_run_index, int64 executor_step_index, |
| 402 | std::unique_ptr<DebuggerStateInterface>* debugger_state) { |
| 403 | TF_RETURN_IF_ERROR(DebuggerStateRegistry::CreateState( |
| 404 | callable_options.run_options().debug_options(), debugger_state)); |
| 405 | std::vector<string> input_names(callable_options.feed().begin(), |
| 406 | callable_options.feed().end()); |
| 407 | std::vector<string> output_names(callable_options.fetch().begin(), |
| 408 | callable_options.fetch().end()); |
| 409 | std::vector<string> target_names(callable_options.target().begin(), |
| 410 | callable_options.target().end()); |
| 411 | |
| 412 | TF_RETURN_IF_ERROR(debugger_state->get()->PublishDebugMetadata( |
| 413 | global_step, session_run_index, executor_step_index, input_names, |
| 414 | output_names, target_names)); |
| 415 | return Status::OK(); |
| 416 | } |
| 417 | |
| 418 | Status DirectSession::DecorateAndPublishGraphForDebug( |
| 419 | const DebugOptions& debug_options, Graph* graph, Device* device) { |
| 420 | std::unique_ptr<DebugGraphDecoratorInterface> decorator; |
| 421 | TF_RETURN_IF_ERROR( |
| 422 | DebugGraphDecoratorRegistry::CreateDecorator(debug_options, &decorator)); |
| 423 | |
| 424 | TF_RETURN_IF_ERROR(decorator->DecorateGraph(graph, device)); |
| 425 | TF_RETURN_IF_ERROR(decorator->PublishGraph(*graph, device->name())); |
| 426 | return Status::OK(); |
| 427 | } |
| 428 | |
| 429 | Status DirectSession::RunInternal(int64 step_id, const RunOptions& run_options, |
| 430 | CallFrameInterface* call_frame, |
| 431 | ExecutorsAndKeys* executors_and_keys, |
| 432 | RunMetadata* run_metadata) { |
| 433 | const int64 executor_step_count = executors_and_keys->step_count.fetch_add(1); |
| 434 | |
| 435 | std::unique_ptr<DebuggerStateInterface> debugger_state; |
| 436 | if (!run_options.debug_options().debug_tensor_watch_opts().empty()) { |
| 437 | TF_RETURN_IF_ERROR( |
| 438 | CreateDebuggerState(executors_and_keys->callable_options, |
| 439 | run_options.debug_options().global_step(), step_id, |
| 440 | executor_step_count, &debugger_state)); |
| 441 | } |
| 442 | |
| 443 | // Create a run state and start execution. |
| 444 | RunState run_state(step_id, &devices_); |
| 445 | run_state.rendez = new IntraProcessRendezvous(device_mgr_.get()); |
| 446 | |
| 447 | // Start parallel Executors. |
| 448 | const size_t num_executors = executors_and_keys->items.size(); |
| 449 | ExecutorBarrier* barrier = new ExecutorBarrier( |
| 450 | num_executors, run_state.rendez, [&run_state](const Status& ret) { |
| 451 | { |
| 452 | mutex_lock l(run_state.mu_); |
| 453 | run_state.status.Update(ret); |
| 454 | } |
| 455 | run_state.executors_done.Notify(); |
| 456 | }); |
| 457 | |
| 458 | Executor::Args args; |
| 459 | args.step_id = step_id; |
| 460 | args.call_frame = call_frame; |
| 461 | args.rendezvous = run_state.rendez; |
| 462 | CancellationManager step_cancellation_manager; |
| 463 | args.cancellation_manager = &step_cancellation_manager; |
| 464 | args.session_state = &session_state_; |
| 465 | args.tensor_store = &run_state.tensor_store; |
| 466 | args.step_container = &run_state.step_container; |
| 467 | args.sync_on_finish = sync_on_finish_; |
| 468 | |
| 469 | const bool do_trace = (run_options.trace_level() > RunOptions::NO_TRACE); |
| 470 | |
| 471 | bool update_cost_model = false; |
| 472 | if (options_.config.graph_options().build_cost_model() > 0) { |
| 473 | const int64 build_cost_model_every = |
| 474 | options_.config.graph_options().build_cost_model(); |
| 475 | const int64 build_cost_model_after = |
| 476 | options_.config.graph_options().build_cost_model_after(); |
| 477 | int64 measure_step_count = executor_step_count - build_cost_model_after; |
| 478 | if (measure_step_count >= 0) { |
| 479 | update_cost_model = |
| 480 | ((measure_step_count + 1) % build_cost_model_every == 0); |
| 481 | } |
| 482 | } |
| 483 | if (do_trace || update_cost_model || |
| 484 | run_options.report_tensor_allocations_upon_oom()) { |
| 485 | run_state.collector.reset( |
| 486 | new StepStatsCollector(run_metadata->mutable_step_stats())); |
| 487 | args.stats_collector = run_state.collector.get(); |
| 488 | } |
| 489 | |
| 490 | std::unique_ptr<DeviceTracer> tracer; |
| 491 | if (run_options.trace_level() >= RunOptions::HARDWARE_TRACE) { |
| 492 | tracer = CreateDeviceTracer(); |
| 493 | // tracer may be NULL on platforms without accelerators. |
| 494 | if (tracer) { |
| 495 | Status s = tracer->Start(); |
| 496 | if (!s.ok()) { |
| 497 | run_state.executors_done.Notify(); |
| 498 | delete barrier; |
| 499 | return s; |
| 500 | } |
| 501 | } |
| 502 | } |
| 503 | |
| 504 | if (run_options.inter_op_thread_pool() < 0 || |
| 505 | run_options.inter_op_thread_pool() >= thread_pools_.size()) { |
| 506 | run_state.executors_done.Notify(); |
| 507 | delete barrier; |
| 508 | return errors::InvalidArgument("Invalid inter_op_thread_pool: " , |
| 509 | run_options.inter_op_thread_pool()); |
| 510 | } |
| 511 | |
| 512 | // Register this step with session's cancellation manager, so that |
| 513 | // `Session::Close()` will cancel the step. |
| 514 | const CancellationToken cancellation_token = |
| 515 | cancellation_manager_->get_cancellation_token(); |
| 516 | const bool already_cancelled = !cancellation_manager_->RegisterCallback( |
| 517 | cancellation_token, [&step_cancellation_manager]() { |
| 518 | step_cancellation_manager.StartCancel(); |
| 519 | }); |
| 520 | if (already_cancelled) { |
| 521 | // NOTE(mrry): If we don't explicitly notify |
| 522 | // `run_state.executors_done`, the RunState destructor would |
| 523 | // block on this notification. |
| 524 | run_state.executors_done.Notify(); |
| 525 | delete barrier; |
| 526 | return errors::Cancelled("Run call was cancelled" ); |
| 527 | } |
| 528 | |
| 529 | thread::ThreadPool* pool = |
| 530 | thread_pools_[run_options.inter_op_thread_pool()].first; |
| 531 | |
| 532 | Executor::Args::Runner default_runner = [this, |
| 533 | pool](Executor::Args::Closure c) { |
| 534 | SchedClosure(pool, std::move(c)); |
| 535 | }; |
| 536 | for (const auto& item : executors_and_keys->items) { |
| 537 | // TODO(zhengxq): support partial run. |
| 538 | // TODO(zhengxq): if the device picks its own threadpool, we need to assign |
| 539 | // less threads to the main compute pool by default. |
| 540 | thread::ThreadPool* device_thread_pool = |
| 541 | item.device->tensorflow_device_thread_pool(); |
| 542 | if (!device_thread_pool) { |
| 543 | args.runner = default_runner; |
| 544 | } else { |
| 545 | args.runner = [this, device_thread_pool](Executor::Args::Closure c) { |
| 546 | SchedClosure(device_thread_pool, std::move(c)); |
| 547 | }; |
| 548 | } |
| 549 | item.executor->RunAsync(args, barrier->Get()); |
| 550 | } |
| 551 | |
| 552 | WaitForNotification(&run_state, &step_cancellation_manager, |
| 553 | run_options.timeout_in_ms() > 0 |
| 554 | ? run_options.timeout_in_ms() |
| 555 | : operation_timeout_in_ms_); |
| 556 | |
| 557 | if (!cancellation_manager_->DeregisterCallback(cancellation_token)) { |
| 558 | // The step has been cancelled: make sure we don't attempt to receive the |
| 559 | // outputs as this would make it block forever. |
| 560 | mutex_lock l(run_state.mu_); |
| 561 | run_state.status.Update(errors::Cancelled("Run call was cancelled" )); |
| 562 | } |
| 563 | |
| 564 | if (tracer) { |
| 565 | TF_RETURN_IF_ERROR(tracer->Stop()); |
| 566 | TF_RETURN_IF_ERROR(tracer->Collect(args.stats_collector)); |
| 567 | } |
| 568 | |
| 569 | { |
| 570 | mutex_lock l(run_state.mu_); |
| 571 | TF_RETURN_IF_ERROR(run_state.status); |
| 572 | } |
| 573 | |
| 574 | // Save the output tensors of this run we choose to keep. |
| 575 | if (!run_state.tensor_store.empty()) { |
| 576 | TF_RETURN_IF_ERROR(run_state.tensor_store.SaveTensors( |
| 577 | {executors_and_keys->callable_options.fetch().begin(), |
| 578 | executors_and_keys->callable_options.fetch().end()}, |
| 579 | &session_state_)); |
| 580 | } |
| 581 | |
| 582 | if (args.stats_collector) { |
| 583 | args.stats_collector->Finalize(); |
| 584 | } |
| 585 | |
| 586 | // Build and return the cost model as instructed. |
| 587 | if (update_cost_model) { |
| 588 | // Build the cost model |
| 589 | std::unordered_map<string, const Graph*> device_to_graph; |
| 590 | for (const PerPartitionExecutorsAndLib& partition : |
| 591 | executors_and_keys->items) { |
| 592 | const Graph* graph = partition.graph; |
| 593 | const string device = partition.flib->device()->name(); |
| 594 | device_to_graph[device] = graph; |
| 595 | } |
| 596 | |
| 597 | mutex_lock l(executor_lock_); |
| 598 | args.stats_collector->BuildCostModel(&cost_model_manager_, device_to_graph); |
| 599 | |
| 600 | // annotate stats onto cost graph. |
| 601 | CostGraphDef* cost_graph = run_metadata->mutable_cost_graph(); |
| 602 | for (const auto& item : executors_and_keys->items) { |
| 603 | TF_RETURN_IF_ERROR( |
| 604 | cost_model_manager_.AddToCostGraphDef(item.graph, cost_graph)); |
| 605 | } |
| 606 | } |
| 607 | |
| 608 | // If requested via RunOptions, output the partition graphs. |
| 609 | if (run_options.output_partition_graphs()) { |
| 610 | protobuf::RepeatedPtrField<GraphDef>* partition_graph_defs = |
| 611 | run_metadata->mutable_partition_graphs(); |
| 612 | for (const PerPartitionExecutorsAndLib& exec_and_lib : |
| 613 | executors_and_keys->items) { |
| 614 | GraphDef* partition_graph_def = partition_graph_defs->Add(); |
| 615 | exec_and_lib.graph->ToGraphDef(partition_graph_def); |
| 616 | } |
| 617 | } |
| 618 | |
| 619 | return Status::OK(); |
| 620 | } |
| 621 | |
| 622 | Status DirectSession::Run(const RunOptions& run_options, |
| 623 | const NamedTensorList& inputs, |
| 624 | const std::vector<string>& output_names, |
| 625 | const std::vector<string>& target_nodes, |
| 626 | std::vector<Tensor>* outputs, |
| 627 | RunMetadata* run_metadata) { |
| 628 | TF_RETURN_IF_ERROR(CheckNotClosed()); |
| 629 | TF_RETURN_IF_ERROR(CheckGraphCreated("Run()" )); |
| 630 | direct_session_runs->GetCell()->IncrementBy(1); |
| 631 | |
| 632 | // Extract the inputs names for this run of the session. |
| 633 | std::vector<string> input_tensor_names; |
| 634 | input_tensor_names.reserve(inputs.size()); |
| 635 | for (const auto& it : inputs) { |
| 636 | input_tensor_names.push_back(it.first); |
| 637 | } |
| 638 | |
| 639 | // Check if we already have an executor for these arguments. |
| 640 | ExecutorsAndKeys* executors_and_keys; |
| 641 | RunStateArgs run_state_args(run_options.debug_options()); |
| 642 | |
| 643 | TF_RETURN_IF_ERROR(GetOrCreateExecutors(input_tensor_names, output_names, |
| 644 | target_nodes, &executors_and_keys, |
| 645 | &run_state_args)); |
| 646 | |
| 647 | // Configure a call frame for the step, which we use to feed and |
| 648 | // fetch values to and from the executors. |
| 649 | FunctionCallFrame call_frame(executors_and_keys->input_types, |
| 650 | executors_and_keys->output_types); |
| 651 | gtl::InlinedVector<Tensor, 4> feed_args(inputs.size()); |
| 652 | for (const auto& it : inputs) { |
| 653 | if (it.second.dtype() == DT_RESOURCE) { |
| 654 | Tensor tensor_from_handle; |
| 655 | TF_RETURN_IF_ERROR( |
| 656 | ResourceHandleToInputTensor(it.second, &tensor_from_handle)); |
| 657 | feed_args[executors_and_keys->input_name_to_index[it.first]] = |
| 658 | tensor_from_handle; |
| 659 | } else { |
| 660 | feed_args[executors_and_keys->input_name_to_index[it.first]] = it.second; |
| 661 | } |
| 662 | } |
| 663 | const Status s = call_frame.SetArgs(feed_args); |
| 664 | if (errors::IsInternal(s)) { |
| 665 | return errors::InvalidArgument(s.error_message()); |
| 666 | } else if (!s.ok()) { |
| 667 | return s; |
| 668 | } |
| 669 | |
| 670 | const int64 step_id = step_id_counter_.fetch_add(1); |
| 671 | |
| 672 | if (LogMemory::IsEnabled()) { |
| 673 | LogMemory::RecordStep(step_id, run_state_args.handle); |
| 674 | } |
| 675 | |
| 676 | TF_RETURN_IF_ERROR(RunInternal(step_id, run_options, &call_frame, |
| 677 | executors_and_keys, run_metadata)); |
| 678 | |
| 679 | // Receive outputs. |
| 680 | if (outputs) { |
| 681 | std::vector<Tensor> sorted_outputs; |
| 682 | const Status s = call_frame.ConsumeRetvals(&sorted_outputs); |
| 683 | if (errors::IsInternal(s)) { |
| 684 | return errors::InvalidArgument(s.error_message()); |
| 685 | } else if (!s.ok()) { |
| 686 | return s; |
| 687 | } |
| 688 | const bool unique_outputs = |
| 689 | output_names.size() == executors_and_keys->output_name_to_index.size(); |
| 690 | // first_indices[i] = j implies that j is the smallest value for which |
| 691 | // output_names[i] == output_names[j]. |
| 692 | std::vector<int> first_indices; |
| 693 | if (!unique_outputs) { |
| 694 | first_indices.resize(output_names.size()); |
| 695 | for (int i = 0; i < output_names.size(); ++i) { |
| 696 | for (int j = 0; j <= i; ++j) { |
| 697 | if (output_names[i] == output_names[j]) { |
| 698 | first_indices[i] = j; |
| 699 | break; |
| 700 | } |
| 701 | } |
| 702 | } |
| 703 | } |
| 704 | outputs->clear(); |
| 705 | outputs->reserve(sorted_outputs.size()); |
| 706 | for (int i = 0; i < output_names.size(); ++i) { |
| 707 | const string& output_name = output_names[i]; |
| 708 | if (first_indices.empty() || first_indices[i] == i) { |
| 709 | outputs->emplace_back( |
| 710 | std::move(sorted_outputs[executors_and_keys |
| 711 | ->output_name_to_index[output_name]])); |
| 712 | } else { |
| 713 | outputs->push_back((*outputs)[first_indices[i]]); |
| 714 | } |
| 715 | } |
| 716 | } |
| 717 | |
| 718 | return Status::OK(); |
| 719 | } |
| 720 | |
| 721 | Status DirectSession::PRunSetup(const std::vector<string>& input_names, |
| 722 | const std::vector<string>& output_names, |
| 723 | const std::vector<string>& target_nodes, |
| 724 | string* handle) { |
| 725 | TF_RETURN_IF_ERROR(CheckNotClosed()); |
| 726 | TF_RETURN_IF_ERROR(CheckGraphCreated("PRunSetup()" )); |
| 727 | |
| 728 | // RunOptions is not available in PRunSetup, so use thread pool 0. |
| 729 | thread::ThreadPool* pool = thread_pools_[0].first; |
| 730 | |
| 731 | // Check if we already have an executor for these arguments. |
| 732 | ExecutorsAndKeys* executors_and_keys; |
| 733 | // TODO(cais): TFDBG support for partial runs. |
| 734 | DebugOptions debug_options; |
| 735 | RunStateArgs run_state_args(debug_options); |
| 736 | run_state_args.is_partial_run = true; |
| 737 | TF_RETURN_IF_ERROR(GetOrCreateExecutors(input_names, output_names, |
| 738 | target_nodes, &executors_and_keys, |
| 739 | &run_state_args)); |
| 740 | |
| 741 | // Create the run state and save it for future PRun calls. |
| 742 | Executor::Args args; |
| 743 | args.step_id = step_id_counter_.fetch_add(1); |
| 744 | RunState* run_state = |
| 745 | new RunState(input_names, output_names, args.step_id, &devices_); |
| 746 | run_state->rendez = new IntraProcessRendezvous(device_mgr_.get()); |
| 747 | { |
| 748 | mutex_lock l(executor_lock_); |
| 749 | if (!partial_runs_ |
| 750 | .emplace(run_state_args.handle, |
| 751 | std::unique_ptr<RunState>(run_state)) |
| 752 | .second) { |
| 753 | return errors::Internal("The handle '" , run_state_args.handle, |
| 754 | "' created for this partial run is not unique." ); |
| 755 | } |
| 756 | } |
| 757 | |
| 758 | // Start parallel Executors. |
| 759 | const size_t num_executors = executors_and_keys->items.size(); |
| 760 | ExecutorBarrier* barrier = new ExecutorBarrier( |
| 761 | num_executors, run_state->rendez, [run_state](const Status& ret) { |
| 762 | if (!ret.ok()) { |
| 763 | mutex_lock l(run_state->mu_); |
| 764 | run_state->status.Update(ret); |
| 765 | } |
| 766 | run_state->executors_done.Notify(); |
| 767 | }); |
| 768 | |
| 769 | args.rendezvous = run_state->rendez; |
| 770 | args.cancellation_manager = cancellation_manager_; |
| 771 | args.runner = [this, pool](Executor::Args::Closure c) { |
| 772 | SchedClosure(pool, std::move(c)); |
| 773 | }; |
| 774 | args.session_state = &session_state_; |
| 775 | args.tensor_store = &run_state->tensor_store; |
| 776 | args.step_container = &run_state->step_container; |
| 777 | if (LogMemory::IsEnabled()) { |
| 778 | LogMemory::RecordStep(args.step_id, run_state_args.handle); |
| 779 | } |
| 780 | args.sync_on_finish = sync_on_finish_; |
| 781 | |
| 782 | if (options_.config.graph_options().build_cost_model()) { |
| 783 | run_state->collector.reset(new StepStatsCollector(nullptr)); |
| 784 | args.stats_collector = run_state->collector.get(); |
| 785 | } |
| 786 | |
| 787 | for (auto& item : executors_and_keys->items) { |
| 788 | item.executor->RunAsync(args, barrier->Get()); |
| 789 | } |
| 790 | |
| 791 | *handle = run_state_args.handle; |
| 792 | return Status::OK(); |
| 793 | } |
| 794 | |
| 795 | Status DirectSession::PRun(const string& handle, const NamedTensorList& inputs, |
| 796 | const std::vector<string>& output_names, |
| 797 | std::vector<Tensor>* outputs) { |
| 798 | TF_RETURN_IF_ERROR(CheckNotClosed()); |
| 799 | std::vector<string> parts = str_util::Split(handle, ';'); |
| 800 | const string& key = parts[0]; |
| 801 | // Get the executors for this partial run. |
| 802 | ExecutorsAndKeys* executors_and_keys; |
| 803 | RunState* run_state; |
| 804 | { |
| 805 | mutex_lock l(executor_lock_); // could use reader lock |
| 806 | auto exc_it = executors_.find(key); |
| 807 | if (exc_it == executors_.end()) { |
| 808 | return errors::InvalidArgument( |
| 809 | "Must run 'setup' before performing partial runs!" ); |
| 810 | } |
| 811 | executors_and_keys = exc_it->second.get(); |
| 812 | |
| 813 | auto prun_it = partial_runs_.find(handle); |
| 814 | if (prun_it == partial_runs_.end()) { |
| 815 | return errors::InvalidArgument( |
| 816 | "Must run 'setup' before performing partial runs!" ); |
| 817 | } |
| 818 | run_state = prun_it->second.get(); |
| 819 | |
| 820 | // Make sure that this is a new set of feeds that are still pending. |
| 821 | for (const auto& input : inputs) { |
| 822 | auto it = run_state->pending_inputs.find(input.first); |
| 823 | if (it == run_state->pending_inputs.end()) { |
| 824 | return errors::InvalidArgument( |
| 825 | "The feed " , input.first, |
| 826 | " was not specified in partial_run_setup." ); |
| 827 | } else if (it->second) { |
| 828 | return errors::InvalidArgument("The feed " , input.first, |
| 829 | " has already been fed." ); |
| 830 | } |
| 831 | } |
| 832 | // Check that this is a new set of fetches that are still pending. |
| 833 | for (const auto& output : output_names) { |
| 834 | auto it = run_state->pending_outputs.find(output); |
| 835 | if (it == run_state->pending_outputs.end()) { |
| 836 | return errors::InvalidArgument( |
| 837 | "The fetch " , output, " was not specified in partial_run_setup." ); |
| 838 | } else if (it->second) { |
| 839 | return errors::InvalidArgument("The fetch " , output, |
| 840 | " has already been fetched." ); |
| 841 | } |
| 842 | } |
| 843 | } |
| 844 | |
| 845 | // Check that this new set of fetches can be computed from all the |
| 846 | // feeds we have supplied. |
| 847 | TF_RETURN_IF_ERROR( |
| 848 | CheckFetch(inputs, output_names, executors_and_keys, run_state)); |
| 849 | |
| 850 | // Send inputs. |
| 851 | Status s = SendPRunInputs(inputs, executors_and_keys, run_state->rendez); |
| 852 | |
| 853 | // Receive outputs. |
| 854 | if (s.ok()) { |
| 855 | s = RecvPRunOutputs(output_names, executors_and_keys, run_state, outputs); |
| 856 | } |
| 857 | |
| 858 | // Save the output tensors of this run we choose to keep. |
| 859 | if (s.ok()) { |
| 860 | s = run_state->tensor_store.SaveTensors(output_names, &session_state_); |
| 861 | } |
| 862 | |
| 863 | { |
| 864 | mutex_lock l(executor_lock_); |
| 865 | // Delete the run state if there is an error or all fetches are done. |
| 866 | bool done = true; |
| 867 | if (s.ok()) { |
| 868 | { |
| 869 | mutex_lock l(run_state->mu_); |
| 870 | if (!run_state->status.ok()) { |
| 871 | LOG(WARNING) << "An error unrelated to this prun has been detected. " |
| 872 | << run_state->status; |
| 873 | } |
| 874 | } |
| 875 | for (const auto& input : inputs) { |
| 876 | auto it = run_state->pending_inputs.find(input.first); |
| 877 | it->second = true; |
| 878 | } |
| 879 | for (const auto& name : output_names) { |
| 880 | auto it = run_state->pending_outputs.find(name); |
| 881 | it->second = true; |
| 882 | } |
| 883 | done = run_state->PendingDone(); |
| 884 | } |
| 885 | if (done) { |
| 886 | WaitForNotification(run_state, cancellation_manager_, |
| 887 | operation_timeout_in_ms_); |
| 888 | partial_runs_.erase(handle); |
| 889 | } |
| 890 | } |
| 891 | |
| 892 | return s; |
| 893 | } |
| 894 | |
| 895 | Status DirectSession::ResourceHandleToInputTensor(const Tensor& resource_tensor, |
| 896 | Tensor* retrieved_tensor) { |
| 897 | if (resource_tensor.dtype() != DT_RESOURCE) { |
| 898 | return errors::InvalidArgument(strings::StrCat( |
| 899 | "ResourceHandleToInputTensor() received non-DT_RESOURCE Tensor: " , |
| 900 | resource_tensor.dtype())); |
| 901 | } |
| 902 | |
| 903 | const ResourceHandle& resource_handle = |
| 904 | resource_tensor.scalar<ResourceHandle>()(); |
| 905 | |
| 906 | if (resource_handle.container() == |
| 907 | SessionState::kTensorHandleResourceTypeName) { |
| 908 | return session_state_.GetTensor(resource_handle.name(), retrieved_tensor); |
| 909 | } else { |
| 910 | return errors::InvalidArgument(strings::StrCat( |
| 911 | "Invalid resource type hash code: " , resource_handle.hash_code(), |
| 912 | "(name: " , resource_handle.name(), |
| 913 | " type: " , resource_handle.maybe_type_name(), |
| 914 | "). Perhaps a resource tensor was being provided as a feed? That is " |
| 915 | "not currently allowed. Please file an issue at " |
| 916 | "https://github.com/tensorflow/tensorflow/issues/new, ideally with a " |
| 917 | "short code snippet that leads to this error message." )); |
| 918 | } |
| 919 | } |
| 920 | |
| 921 | Status DirectSession::SendPRunInputs(const NamedTensorList& inputs, |
| 922 | const ExecutorsAndKeys* executors_and_keys, |
| 923 | IntraProcessRendezvous* rendez) { |
| 924 | Status s; |
| 925 | Rendezvous::ParsedKey parsed; |
| 926 | // Insert the input tensors into the local rendezvous by their |
| 927 | // rendezvous key. |
| 928 | for (const auto& input : inputs) { |
| 929 | auto it = |
| 930 | executors_and_keys->input_name_to_rendezvous_key.find(input.first); |
| 931 | if (it == executors_and_keys->input_name_to_rendezvous_key.end()) { |
| 932 | return errors::Internal("'" , input.first, "' is not a pre-defined feed." ); |
| 933 | } |
| 934 | const string& input_key = it->second; |
| 935 | |
| 936 | s = Rendezvous::ParseKey(input_key, &parsed); |
| 937 | if (!s.ok()) { |
| 938 | rendez->StartAbort(s); |
| 939 | return s; |
| 940 | } |
| 941 | |
| 942 | if (input.second.dtype() == DT_RESOURCE) { |
| 943 | Tensor tensor_from_handle; |
| 944 | s = ResourceHandleToInputTensor(input.second, &tensor_from_handle); |
| 945 | if (s.ok()) { |
| 946 | s = rendez->Send(parsed, Rendezvous::Args(), tensor_from_handle, false); |
| 947 | } |
| 948 | } else { |
| 949 | s = rendez->Send(parsed, Rendezvous::Args(), input.second, false); |
| 950 | } |
| 951 | |
| 952 | if (!s.ok()) { |
| 953 | rendez->StartAbort(s); |
| 954 | return s; |
| 955 | } |
| 956 | } |
| 957 | return Status::OK(); |
| 958 | } |
| 959 | |
| 960 | Status DirectSession::RecvPRunOutputs( |
| 961 | const std::vector<string>& output_names, |
| 962 | const ExecutorsAndKeys* executors_and_keys, RunState* run_state, |
| 963 | std::vector<Tensor>* outputs) { |
| 964 | Status s; |
| 965 | if (!output_names.empty()) { |
| 966 | outputs->resize(output_names.size()); |
| 967 | } |
| 968 | |
| 969 | Rendezvous::ParsedKey parsed; |
| 970 | // Get the outputs from the rendezvous |
| 971 | for (size_t output_offset = 0; output_offset < output_names.size(); |
| 972 | ++output_offset) { |
| 973 | const string& output_name = output_names[output_offset]; |
| 974 | auto it = |
| 975 | executors_and_keys->output_name_to_rendezvous_key.find(output_name); |
| 976 | if (it == executors_and_keys->output_name_to_rendezvous_key.end()) { |
| 977 | return errors::Internal("'" , output_name, |
| 978 | "' is not a pre-defined fetch." ); |
| 979 | } |
| 980 | const string& output_key = it->second; |
| 981 | Tensor output_tensor; |
| 982 | bool is_dead; |
| 983 | IntraProcessRendezvous* rendez = run_state->rendez; |
| 984 | |
| 985 | s = Rendezvous::ParseKey(output_key, &parsed); |
| 986 | if (s.ok()) { |
| 987 | // Fetch data from the Rendezvous. |
| 988 | s = rendez->Recv(parsed, Rendezvous::Args(), &output_tensor, &is_dead, |
| 989 | operation_timeout_in_ms_); |
| 990 | if (is_dead && s.ok()) { |
| 991 | s = errors::InvalidArgument("The tensor returned for " , output_name, |
| 992 | " was not valid." ); |
| 993 | } |
| 994 | } |
| 995 | if (!s.ok()) { |
| 996 | rendez->StartAbort(s); |
| 997 | outputs->clear(); |
| 998 | return s; |
| 999 | } |
| 1000 | |
| 1001 | (*outputs)[output_offset] = output_tensor; |
| 1002 | } |
| 1003 | return Status::OK(); |
| 1004 | } |
| 1005 | |
| 1006 | Status DirectSession::CheckFetch(const NamedTensorList& feeds, |
| 1007 | const std::vector<string>& fetches, |
| 1008 | const ExecutorsAndKeys* executors_and_keys, |
| 1009 | const RunState* run_state) { |
| 1010 | const Graph* graph = executors_and_keys->graph.get(); |
| 1011 | const NameNodeMap* name_to_node = &executors_and_keys->name_to_node; |
| 1012 | |
| 1013 | // Build the set of pending feeds that we haven't seen. |
| 1014 | std::unordered_set<TensorId, TensorId::Hasher> pending_feeds; |
| 1015 | { |
| 1016 | mutex_lock l(executor_lock_); |
| 1017 | for (const auto& input : run_state->pending_inputs) { |
| 1018 | // Skip if the feed has already been fed. |
| 1019 | if (input.second) continue; |
| 1020 | TensorId id(ParseTensorName(input.first)); |
| 1021 | auto it = name_to_node->find(id.first); |
| 1022 | if (it == name_to_node->end()) { |
| 1023 | return errors::NotFound("Feed " , input.first, ": not found" ); |
| 1024 | } |
| 1025 | pending_feeds.insert(id); |
| 1026 | } |
| 1027 | } |
| 1028 | for (const auto& it : feeds) { |
| 1029 | TensorId id(ParseTensorName(it.first)); |
| 1030 | pending_feeds.erase(id); |
| 1031 | } |
| 1032 | |
| 1033 | // Initialize the stack with the fetch nodes. |
| 1034 | std::vector<const Node*> stack; |
| 1035 | for (const string& fetch : fetches) { |
| 1036 | TensorId id(ParseTensorName(fetch)); |
| 1037 | auto it = name_to_node->find(id.first); |
| 1038 | if (it == name_to_node->end()) { |
| 1039 | return errors::NotFound("Fetch " , fetch, ": not found" ); |
| 1040 | } |
| 1041 | stack.push_back(it->second); |
| 1042 | } |
| 1043 | |
| 1044 | // Any tensor needed for fetches can't be in pending_feeds. |
| 1045 | std::vector<bool> visited(graph->num_node_ids(), false); |
| 1046 | while (!stack.empty()) { |
| 1047 | const Node* n = stack.back(); |
| 1048 | stack.pop_back(); |
| 1049 | |
| 1050 | for (const Edge* in_edge : n->in_edges()) { |
| 1051 | const Node* in_node = in_edge->src(); |
| 1052 | if (pending_feeds.count({in_node->name(), in_edge->src_output()}) > 0) { |
| 1053 | return errors::InvalidArgument("Fetch " , in_node->name(), ":" , |
| 1054 | in_edge->src_output(), |
| 1055 | " can't be computed from the feeds" |
| 1056 | " that have been fed so far." ); |
| 1057 | } |
| 1058 | if (!visited[in_node->id()]) { |
| 1059 | visited[in_node->id()] = true; |
| 1060 | stack.push_back(in_node); |
| 1061 | } |
| 1062 | } |
| 1063 | } |
| 1064 | return Status::OK(); |
| 1065 | } |
| 1066 | |
| 1067 | Status DirectSession::CreateExecutors( |
| 1068 | const CallableOptions& callable_options, |
| 1069 | std::unique_ptr<ExecutorsAndKeys>* out_executors_and_keys, |
| 1070 | std::unique_ptr<FunctionInfo>* out_func_info, |
| 1071 | RunStateArgs* run_state_args) { |
| 1072 | BuildGraphOptions options; |
| 1073 | options.callable_options = callable_options; |
| 1074 | options.use_function_convention = !run_state_args->is_partial_run; |
| 1075 | |
| 1076 | std::unique_ptr<FunctionInfo> func_info(new FunctionInfo); |
| 1077 | std::unique_ptr<ExecutorsAndKeys> ek(new ExecutorsAndKeys); |
| 1078 | |
| 1079 | ek->callable_options = callable_options; |
| 1080 | |
| 1081 | std::unordered_map<string, std::unique_ptr<Graph>> graphs; |
| 1082 | TF_RETURN_IF_ERROR(CreateGraphs(options, &graphs, &func_info->flib_def, |
| 1083 | run_state_args, &ek->input_types, |
| 1084 | &ek->output_types)); |
| 1085 | |
| 1086 | if (run_state_args->is_partial_run) { |
| 1087 | ek->graph = std::move(run_state_args->graph); |
| 1088 | std::unordered_set<StringPiece, StringPieceHasher> names; |
| 1089 | for (const string& input : callable_options.feed()) { |
| 1090 | TensorId id(ParseTensorName(input)); |
| 1091 | names.emplace(id.first); |
| 1092 | } |
| 1093 | for (const string& output : callable_options.fetch()) { |
| 1094 | TensorId id(ParseTensorName(output)); |
| 1095 | names.emplace(id.first); |
| 1096 | } |
| 1097 | for (Node* n : ek->graph->nodes()) { |
| 1098 | if (names.count(n->name()) > 0) { |
| 1099 | ek->name_to_node.insert({n->name(), n}); |
| 1100 | } |
| 1101 | } |
| 1102 | } |
| 1103 | ek->items.reserve(graphs.size()); |
| 1104 | const auto& optimizer_opts = |
| 1105 | options_.config.graph_options().optimizer_options(); |
| 1106 | |
| 1107 | int graph_def_version; |
| 1108 | { |
| 1109 | mutex_lock l(graph_def_lock_); |
| 1110 | graph_def_version = |
| 1111 | execution_state_->original_graph_def().versions().producer(); |
| 1112 | } |
| 1113 | func_info->proc_flr.reset(new ProcessFunctionLibraryRuntime( |
| 1114 | device_mgr_.get(), options_.env, graph_def_version, |
| 1115 | func_info->flib_def.get(), optimizer_opts, thread_pools_[0].first)); |
| 1116 | |
| 1117 | GraphOptimizer optimizer(optimizer_opts); |
| 1118 | for (auto iter = graphs.begin(); iter != graphs.end(); ++iter) { |
| 1119 | const string& partition_name = iter->first; |
| 1120 | std::unique_ptr<Graph>& partition_graph = iter->second; |
| 1121 | |
| 1122 | Device* device; |
| 1123 | TF_RETURN_IF_ERROR(device_mgr_->LookupDevice(partition_name, &device)); |
| 1124 | |
| 1125 | ek->items.resize(ek->items.size() + 1); |
| 1126 | auto* item = &(ek->items.back()); |
| 1127 | auto lib = func_info->proc_flr->GetFLR(partition_name); |
| 1128 | if (lib == nullptr) { |
| 1129 | return errors::Internal("Could not find device: " , partition_name); |
| 1130 | } |
| 1131 | item->flib = lib; |
| 1132 | |
| 1133 | LocalExecutorParams params; |
| 1134 | params.device = device; |
| 1135 | params.function_library = lib; |
| 1136 | auto opseg = device->op_segment(); |
| 1137 | params.create_kernel = [this, lib, opseg](const NodeDef& ndef, |
| 1138 | OpKernel** kernel) { |
| 1139 | // We do not share the kernel via the OpSegment if the node is |
| 1140 | // stateless, or a function. |
| 1141 | // NOTE(mrry): We must not share function kernels (implemented |
| 1142 | // using `CallOp`) between subgraphs, because `CallOp::handle_` |
| 1143 | // is tied to a particular subgraph. Even if the function itself |
| 1144 | // is stateful, the `CallOp` that invokes it is not. |
| 1145 | if (!lib->IsStateful(ndef.op()) || |
| 1146 | lib->GetFunctionLibraryDefinition()->Find(ndef.op()) != nullptr) { |
| 1147 | return lib->CreateKernel(ndef, kernel); |
| 1148 | } |
| 1149 | auto create_fn = [lib, &ndef](OpKernel** kernel) { |
| 1150 | return lib->CreateKernel(ndef, kernel); |
| 1151 | }; |
| 1152 | // Kernels created for subgraph nodes need to be cached. On |
| 1153 | // cache miss, create_fn() is invoked to create a kernel based |
| 1154 | // on the function library here + global op registry. |
| 1155 | return opseg->FindOrCreate(session_handle_, ndef.name(), kernel, |
| 1156 | create_fn); |
| 1157 | }; |
| 1158 | params.delete_kernel = [lib](OpKernel* kernel) { |
| 1159 | // If the node is stateful, opseg owns it. Otherwise, delete it. |
| 1160 | if (kernel && !lib->IsStateful(kernel->type_string())) { |
| 1161 | delete kernel; |
| 1162 | } |
| 1163 | }; |
| 1164 | params.node_outputs_cb = node_outputs_callback_; |
| 1165 | |
| 1166 | optimizer.Optimize(lib, options_.env, device, &iter->second, |
| 1167 | /*shape_map=*/nullptr); |
| 1168 | |
| 1169 | // EXPERIMENTAL: tfdbg inserts debug nodes in the graph. |
| 1170 | const DebugOptions& debug_options = |
| 1171 | options.callable_options.run_options().debug_options(); |
| 1172 | if (!debug_options.debug_tensor_watch_opts().empty()) { |
| 1173 | TF_RETURN_IF_ERROR(DecorateAndPublishGraphForDebug( |
| 1174 | debug_options, partition_graph.get(), params.device)); |
| 1175 | } |
| 1176 | |
| 1177 | TF_RETURN_IF_ERROR(EnsureMemoryTypes(DeviceType(device->device_type()), |
| 1178 | device->name(), |
| 1179 | partition_graph.get())); |
| 1180 | // NewLocalExecutor takes ownership of partition_graph. |
| 1181 | item->graph = partition_graph.get(); |
| 1182 | item->executor = nullptr; |
| 1183 | item->device = device; |
| 1184 | Executor* executor; |
| 1185 | TF_RETURN_IF_ERROR( |
| 1186 | NewLocalExecutor(params, std::move(partition_graph), &executor)); |
| 1187 | item->executor.reset(executor); |
| 1188 | } |
| 1189 | |
| 1190 | // Cache the mapping from input/output names to graph elements to |
| 1191 | // avoid recomputing it every time. |
| 1192 | if (!run_state_args->is_partial_run) { |
| 1193 | // For regular `Run()`, we use the function calling convention, and so |
| 1194 | // maintain a mapping from input/output names to |
| 1195 | // argument/return-value ordinal index. |
| 1196 | for (int i = 0; i < callable_options.feed().size(); ++i) { |
| 1197 | const string& input = callable_options.feed(i); |
| 1198 | ek->input_name_to_index[input] = i; |
| 1199 | } |
| 1200 | for (int i = 0; i < callable_options.fetch().size(); ++i) { |
| 1201 | const string& output = callable_options.fetch(i); |
| 1202 | ek->output_name_to_index[output] = i; |
| 1203 | } |
| 1204 | } else { |
| 1205 | // For `PRun()`, we use the rendezvous calling convention, and so |
| 1206 | // maintain a mapping from input/output names to rendezvous keys. |
| 1207 | // |
| 1208 | // We always use the first device as the device name portion of the |
| 1209 | // key, even if we're feeding another graph. |
| 1210 | for (int i = 0; i < callable_options.feed().size(); ++i) { |
| 1211 | const string& input = callable_options.feed(i); |
| 1212 | ek->input_name_to_rendezvous_key[input] = GetRendezvousKey( |
| 1213 | input, device_set_.client_device()->attributes(), FrameAndIter(0, 0)); |
| 1214 | } |
| 1215 | for (int i = 0; i < callable_options.fetch().size(); ++i) { |
| 1216 | const string& output = callable_options.fetch(i); |
| 1217 | ek->output_name_to_rendezvous_key[output] = |
| 1218 | GetRendezvousKey(output, device_set_.client_device()->attributes(), |
| 1219 | FrameAndIter(0, 0)); |
| 1220 | } |
| 1221 | } |
| 1222 | |
| 1223 | *out_executors_and_keys = std::move(ek); |
| 1224 | *out_func_info = std::move(func_info); |
| 1225 | return Status::OK(); |
| 1226 | } |
| 1227 | |
| 1228 | Status DirectSession::GetOrCreateExecutors( |
| 1229 | gtl::ArraySlice<string> inputs, gtl::ArraySlice<string> outputs, |
| 1230 | gtl::ArraySlice<string> target_nodes, ExecutorsAndKeys** executors_and_keys, |
| 1231 | RunStateArgs* run_state_args) { |
| 1232 | int64 handle_name_counter_value = -1; |
| 1233 | if (LogMemory::IsEnabled() || run_state_args->is_partial_run) { |
| 1234 | handle_name_counter_value = handle_name_counter_.fetch_add(1); |
| 1235 | } |
| 1236 | |
| 1237 | string debug_tensor_watches_summary; |
| 1238 | if (!run_state_args->debug_options.debug_tensor_watch_opts().empty()) { |
| 1239 | debug_tensor_watches_summary = SummarizeDebugTensorWatches( |
| 1240 | run_state_args->debug_options.debug_tensor_watch_opts()); |
| 1241 | } |
| 1242 | |
| 1243 | // Fast lookup path, no sorting. |
| 1244 | const string key = strings::StrCat( |
| 1245 | str_util::Join(inputs, "," ), "->" , str_util::Join(outputs, "," ), "/" , |
| 1246 | str_util::Join(target_nodes, "," ), "/" , run_state_args->is_partial_run, |
| 1247 | "/" , debug_tensor_watches_summary); |
| 1248 | // Set the handle, if it's needed to log memory or for partial run. |
| 1249 | if (handle_name_counter_value >= 0) { |
| 1250 | run_state_args->handle = |
| 1251 | strings::StrCat(key, ";" , handle_name_counter_value); |
| 1252 | } |
| 1253 | |
| 1254 | // See if we already have the executors for this run. |
| 1255 | { |
| 1256 | mutex_lock l(executor_lock_); // could use reader lock |
| 1257 | auto it = executors_.find(key); |
| 1258 | if (it != executors_.end()) { |
| 1259 | *executors_and_keys = it->second.get(); |
| 1260 | return Status::OK(); |
| 1261 | } |
| 1262 | } |
| 1263 | |
| 1264 | // Slow lookup path, the unsorted key missed the cache. |
| 1265 | // Sort the inputs and outputs, and look up with the sorted key in case an |
| 1266 | // earlier call used a different order of inputs and outputs. |
| 1267 | // |
| 1268 | // We could consider some other signature instead of sorting that |
| 1269 | // preserves the same property to avoid the sort in the future. |
| 1270 | std::vector<string> inputs_sorted(inputs.begin(), inputs.end()); |
| 1271 | std::sort(inputs_sorted.begin(), inputs_sorted.end()); |
| 1272 | std::vector<string> outputs_sorted(outputs.begin(), outputs.end()); |
| 1273 | std::sort(outputs_sorted.begin(), outputs_sorted.end()); |
| 1274 | std::vector<string> tn_sorted(target_nodes.begin(), target_nodes.end()); |
| 1275 | std::sort(tn_sorted.begin(), tn_sorted.end()); |
| 1276 | |
| 1277 | const string sorted_key = strings::StrCat( |
| 1278 | str_util::Join(inputs_sorted, "," ), "->" , |
| 1279 | str_util::Join(outputs_sorted, "," ), "/" , str_util::Join(tn_sorted, "," ), |
| 1280 | "/" , run_state_args->is_partial_run, "/" , debug_tensor_watches_summary); |
| 1281 | // Set the handle, if its needed to log memory or for partial run. |
| 1282 | if (handle_name_counter_value >= 0) { |
| 1283 | run_state_args->handle = |
| 1284 | strings::StrCat(sorted_key, ";" , handle_name_counter_value); |
| 1285 | } |
| 1286 | |
| 1287 | // See if we already have the executors for this run. |
| 1288 | { |
| 1289 | mutex_lock l(executor_lock_); |
| 1290 | auto it = executors_.find(sorted_key); |
| 1291 | if (it != executors_.end()) { |
| 1292 | *executors_and_keys = it->second.get(); |
| 1293 | // Insert this under the original key. |
| 1294 | executors_.emplace(key, it->second); |
| 1295 | return Status::OK(); |
| 1296 | } |
| 1297 | } |
| 1298 | |
| 1299 | // Nothing found, so create the executors and store in the cache. |
| 1300 | // The executor_lock_ is intentionally released while executors are |
| 1301 | // being created. |
| 1302 | CallableOptions callable_options; |
| 1303 | for (const string& input : inputs_sorted) { |
| 1304 | callable_options.add_feed(input); |
| 1305 | } |
| 1306 | for (const string& output : outputs_sorted) { |
| 1307 | callable_options.add_fetch(output); |
| 1308 | } |
| 1309 | for (const string& target : tn_sorted) { |
| 1310 | callable_options.add_target(target); |
| 1311 | } |
| 1312 | *callable_options.mutable_run_options()->mutable_debug_options() = |
| 1313 | run_state_args->debug_options; |
| 1314 | std::unique_ptr<ExecutorsAndKeys> ek; |
| 1315 | std::unique_ptr<FunctionInfo> func_info; |
| 1316 | TF_RETURN_IF_ERROR( |
| 1317 | CreateExecutors(callable_options, &ek, &func_info, run_state_args)); |
| 1318 | |
| 1319 | // Reacquire the lock, try to insert into the map. |
| 1320 | mutex_lock l(executor_lock_); |
| 1321 | functions_.push_back(std::move(func_info)); |
| 1322 | |
| 1323 | // Another thread may have created the entry before us, in which case we will |
| 1324 | // reuse the already created one. |
| 1325 | auto insert_result = executors_.emplace( |
| 1326 | sorted_key, std::shared_ptr<ExecutorsAndKeys>(std::move(ek))); |
| 1327 | // Insert the value under the original key, so the fast path lookup will work |
| 1328 | // if the user uses the same order of inputs, outputs, and targets again. |
| 1329 | executors_.emplace(key, insert_result.first->second); |
| 1330 | *executors_and_keys = insert_result.first->second.get(); |
| 1331 | |
| 1332 | return Status::OK(); |
| 1333 | } |
| 1334 | |
| 1335 | Status DirectSession::CreateGraphs( |
| 1336 | const BuildGraphOptions& subgraph_options, |
| 1337 | std::unordered_map<string, std::unique_ptr<Graph>>* outputs, |
| 1338 | std::unique_ptr<FunctionLibraryDefinition>* flib_def, |
| 1339 | RunStateArgs* run_state_args, DataTypeVector* input_types, |
| 1340 | DataTypeVector* output_types) { |
| 1341 | mutex_lock l(graph_def_lock_); |
| 1342 | std::unique_ptr<ClientGraph> client_graph; |
| 1343 | |
| 1344 | std::unique_ptr<GraphExecutionState> temp_exec_state_holder; |
| 1345 | GraphExecutionState* execution_state = nullptr; |
| 1346 | if (options_.config.graph_options().place_pruned_graph()) { |
| 1347 | // Because we are placing pruned graphs, we need to create a |
| 1348 | // new GraphExecutionState for every new unseen graph, |
| 1349 | // and then place it. |
| 1350 | GraphExecutionStateOptions prune_options; |
| 1351 | prune_options.device_set = &device_set_; |
| 1352 | prune_options.session_options = &options_; |
| 1353 | prune_options.stateful_placements = stateful_placements_; |
| 1354 | TF_RETURN_IF_ERROR(GraphExecutionState::MakeForPrunedGraph( |
| 1355 | execution_state_->original_graph_def().library(), prune_options, |
| 1356 | execution_state_->original_graph_def(), subgraph_options, |
| 1357 | &temp_exec_state_holder, &client_graph)); |
| 1358 | execution_state = temp_exec_state_holder.get(); |
| 1359 | } else { |
| 1360 | execution_state = execution_state_.get(); |
| 1361 | TF_RETURN_IF_ERROR( |
| 1362 | execution_state->BuildGraph(subgraph_options, &client_graph)); |
| 1363 | } |
| 1364 | |
| 1365 | if (subgraph_options.callable_options.feed_size() != |
| 1366 | client_graph->feed_types.size()) { |
| 1367 | return errors::Internal( |
| 1368 | "Graph pruning failed: requested number of feed endpoints = " , |
| 1369 | subgraph_options.callable_options.feed_size(), |
| 1370 | " versus number of pruned feed endpoints = " , |
| 1371 | client_graph->feed_types.size()); |
| 1372 | } |
| 1373 | if (subgraph_options.callable_options.fetch_size() != |
| 1374 | client_graph->fetch_types.size()) { |
| 1375 | return errors::Internal( |
| 1376 | "Graph pruning failed: requested number of fetch endpoints = " , |
| 1377 | subgraph_options.callable_options.fetch_size(), |
| 1378 | " versus number of pruned fetch endpoints = " , |
| 1379 | client_graph->fetch_types.size()); |
| 1380 | } |
| 1381 | |
| 1382 | auto current_stateful_placements = execution_state->GetStatefulPlacements(); |
| 1383 | // Update our current state based on the execution_state's |
| 1384 | // placements. If there are any mismatches for a node, |
| 1385 | // we should fail, as this should never happen. |
| 1386 | for (auto placement_pair : current_stateful_placements) { |
| 1387 | const string& node_name = placement_pair.first; |
| 1388 | const string& placement = placement_pair.second; |
| 1389 | auto iter = stateful_placements_.find(node_name); |
| 1390 | if (iter == stateful_placements_.end()) { |
| 1391 | stateful_placements_.insert(std::make_pair(node_name, placement)); |
| 1392 | } else if (iter->second != placement) { |
| 1393 | return errors::Internal( |
| 1394 | "Stateful placement mismatch. " |
| 1395 | "Current assignment of " , |
| 1396 | node_name, " to " , iter->second, " does not match " , placement); |
| 1397 | } |
| 1398 | } |
| 1399 | |
| 1400 | stateful_placements_ = execution_state->GetStatefulPlacements(); |
| 1401 | |
| 1402 | // Remember the graph in run state if this is a partial run. |
| 1403 | if (run_state_args->is_partial_run) { |
| 1404 | run_state_args->graph.reset(new Graph(flib_def_.get())); |
| 1405 | CopyGraph(*execution_state->full_graph(), run_state_args->graph.get()); |
| 1406 | } |
| 1407 | |
| 1408 | // Partition the graph across devices. |
| 1409 | PartitionOptions popts; |
| 1410 | popts.node_to_loc = [](const Node* node) { |
| 1411 | return node->assigned_device_name(); |
| 1412 | }; |
| 1413 | popts.new_name = [this](const string& prefix) { |
| 1414 | return strings::StrCat(prefix, "/_" , edge_name_counter_.fetch_add(1)); |
| 1415 | }; |
| 1416 | popts.get_incarnation = [](const string& name) { |
| 1417 | // The direct session does not have changing incarnation numbers. |
| 1418 | // Just return '1'. |
| 1419 | return 1; |
| 1420 | }; |
| 1421 | popts.flib_def = &client_graph->graph.flib_def(); |
| 1422 | popts.control_flow_added = false; |
| 1423 | |
| 1424 | std::unordered_map<string, GraphDef> partitions; |
| 1425 | TF_RETURN_IF_ERROR(Partition(popts, &client_graph->graph, &partitions)); |
| 1426 | |
| 1427 | std::vector<string> device_names; |
| 1428 | for (auto device : devices_) { |
| 1429 | // Extract the LocalName from the device. |
| 1430 | device_names.push_back(DeviceNameUtils::LocalName(device->name())); |
| 1431 | } |
| 1432 | |
| 1433 | // Check for valid partitions. |
| 1434 | for (const auto& partition : partitions) { |
| 1435 | const string local_partition_name = |
| 1436 | DeviceNameUtils::LocalName(partition.first); |
| 1437 | if (std::count(device_names.begin(), device_names.end(), |
| 1438 | local_partition_name) == 0) { |
| 1439 | return errors::InvalidArgument( |
| 1440 | "Creating a partition for " , local_partition_name, |
| 1441 | " which doesn't exist in the list of available devices. Available " |
| 1442 | "devices: " , |
| 1443 | str_util::Join(device_names, "," )); |
| 1444 | } |
| 1445 | } |
| 1446 | |
| 1447 | for (const auto& partition : partitions) { |
| 1448 | std::unique_ptr<Graph> device_graph( |
| 1449 | new Graph(client_graph->flib_def.get())); |
| 1450 | GraphConstructorOptions device_opts; |
| 1451 | // There are internal operations (e.g., send/recv) that we now allow. |
| 1452 | device_opts.allow_internal_ops = true; |
| 1453 | device_opts.expect_device_spec = true; |
| 1454 | TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(device_opts, partition.second, |
| 1455 | device_graph.get())); |
| 1456 | outputs->emplace(partition.first, std::move(device_graph)); |
| 1457 | } |
| 1458 | |
| 1459 | GraphOptimizationPassOptions optimization_options; |
| 1460 | optimization_options.session_options = &options_; |
| 1461 | optimization_options.flib_def = client_graph->flib_def.get(); |
| 1462 | optimization_options.partition_graphs = outputs; |
| 1463 | TF_RETURN_IF_ERROR(OptimizationPassRegistry::Global()->RunGrouping( |
| 1464 | OptimizationPassRegistry::POST_PARTITIONING, optimization_options)); |
| 1465 | |
| 1466 | Status s; |
| 1467 | for (auto& partition : *outputs) { |
| 1468 | const string& partition_name = partition.first; |
| 1469 | std::unique_ptr<Graph>* graph = &partition.second; |
| 1470 | |
| 1471 | VLOG(2) << "Created " << DebugString(graph->get()) << " for " |
| 1472 | << partition_name; |
| 1473 | |
| 1474 | // Give the device an opportunity to rewrite its subgraph. |
| 1475 | Device* d; |
| 1476 | s = device_mgr_->LookupDevice(partition_name, &d); |
| 1477 | if (!s.ok()) break; |
| 1478 | s = d->MaybeRewriteGraph(graph); |
| 1479 | if (!s.ok()) { |
| 1480 | break; |
| 1481 | } |
| 1482 | } |
| 1483 | *flib_def = std::move(client_graph->flib_def); |
| 1484 | std::swap(*input_types, client_graph->feed_types); |
| 1485 | std::swap(*output_types, client_graph->fetch_types); |
| 1486 | return s; |
| 1487 | } |
| 1488 | |
| 1489 | ::tensorflow::Status DirectSession::ListDevices( |
| 1490 | std::vector<DeviceAttributes>* response) { |
| 1491 | response->clear(); |
| 1492 | response->reserve(devices_.size()); |
| 1493 | for (Device* d : devices_) { |
| 1494 | const DeviceAttributes& attrs = d->attributes(); |
| 1495 | response->emplace_back(attrs); |
| 1496 | } |
| 1497 | return ::tensorflow::Status::OK(); |
| 1498 | } |
| 1499 | |
| 1500 | ::tensorflow::Status DirectSession::Reset( |
| 1501 | const std::vector<string>& containers) { |
| 1502 | device_mgr_->ClearContainers(containers); |
| 1503 | return ::tensorflow::Status::OK(); |
| 1504 | } |
| 1505 | |
| 1506 | ::tensorflow::Status DirectSession::Close() { |
| 1507 | cancellation_manager_->StartCancel(); |
| 1508 | { |
| 1509 | mutex_lock l(closed_lock_); |
| 1510 | if (closed_) return ::tensorflow::Status::OK(); |
| 1511 | closed_ = true; |
| 1512 | } |
| 1513 | if (factory_ != nullptr) factory_->Deregister(this); |
| 1514 | return ::tensorflow::Status::OK(); |
| 1515 | } |
| 1516 | |
| 1517 | DirectSession::RunState::RunState( |
| 1518 | const std::vector<string>& pending_input_names, |
| 1519 | const std::vector<string>& pending_output_names, int64 step_id, |
| 1520 | const std::vector<Device*>* devices) |
| 1521 | : step_container(step_id, [devices](const string& name) { |
| 1522 | for (auto d : *devices) { |
| 1523 | if (!d->resource_manager()->Cleanup(name).ok()) { |
| 1524 | // Do nothing... |
| 1525 | } |
| 1526 | } |
| 1527 | }) { |
| 1528 | // Initially all the feeds and fetches are pending. |
| 1529 | for (auto& name : pending_input_names) { |
| 1530 | pending_inputs[name] = false; |
| 1531 | } |
| 1532 | for (auto& name : pending_output_names) { |
| 1533 | pending_outputs[name] = false; |
| 1534 | } |
| 1535 | } |
| 1536 | |
| 1537 | DirectSession::RunState::RunState(int64 step_id, |
| 1538 | const std::vector<Device*>* devices) |
| 1539 | : RunState({}, {}, step_id, devices) {} |
| 1540 | |
| 1541 | DirectSession::RunState::~RunState() { |
| 1542 | if (rendez != nullptr) { |
| 1543 | if (!executors_done.HasBeenNotified()) { |
| 1544 | rendez->StartAbort(errors::Cancelled("PRun cancellation" )); |
| 1545 | executors_done.WaitForNotification(); |
| 1546 | } |
| 1547 | rendez->Unref(); |
| 1548 | } |
| 1549 | } |
| 1550 | |
| 1551 | bool DirectSession::RunState::PendingDone() const { |
| 1552 | for (const auto& it : pending_inputs) { |
| 1553 | if (!it.second) return false; |
| 1554 | } |
| 1555 | for (const auto& it : pending_outputs) { |
| 1556 | if (!it.second) return false; |
| 1557 | } |
| 1558 | return true; |
| 1559 | } |
| 1560 | |
| 1561 | void DirectSession::WaitForNotification(RunState* run_state, |
| 1562 | CancellationManager* cm, |
| 1563 | int64 timeout_in_ms) { |
| 1564 | const Status status = |
| 1565 | WaitForNotification(&run_state->executors_done, timeout_in_ms); |
| 1566 | if (!status.ok()) { |
| 1567 | { |
| 1568 | mutex_lock l(run_state->mu_); |
| 1569 | run_state->status.Update(status); |
| 1570 | } |
| 1571 | cm->StartCancel(); |
| 1572 | // We must wait for the executors to complete, because they have borrowed |
| 1573 | // references to `cm` and other per-step state. After this notification, it |
| 1574 | // is safe to clean up the step. |
| 1575 | run_state->executors_done.WaitForNotification(); |
| 1576 | } |
| 1577 | } |
| 1578 | |
| 1579 | ::tensorflow::Status DirectSession::WaitForNotification( |
| 1580 | Notification* notification, int64 timeout_in_ms) { |
| 1581 | if (timeout_in_ms > 0) { |
| 1582 | const int64 timeout_in_us = timeout_in_ms * 1000; |
| 1583 | const bool notified = |
| 1584 | WaitForNotificationWithTimeout(notification, timeout_in_us); |
| 1585 | if (!notified) { |
| 1586 | return Status(error::DEADLINE_EXCEEDED, |
| 1587 | "Timed out waiting for notification" ); |
| 1588 | } |
| 1589 | } else { |
| 1590 | notification->WaitForNotification(); |
| 1591 | } |
| 1592 | return Status::OK(); |
| 1593 | } |
| 1594 | |
| 1595 | Status DirectSession::MakeCallable(const CallableOptions& callable_options, |
| 1596 | CallableHandle* out_handle) { |
| 1597 | TF_RETURN_IF_ERROR(CheckNotClosed()); |
| 1598 | TF_RETURN_IF_ERROR(CheckGraphCreated("MakeCallable()" )); |
| 1599 | |
| 1600 | if (!callable_options.run_options() |
| 1601 | .debug_options() |
| 1602 | .debug_tensor_watch_opts() |
| 1603 | .empty()) { |
| 1604 | return errors::Unimplemented( |
| 1605 | "Debug options are not currently supported via the C++ MakeCallable " |
| 1606 | "interface." ); |
| 1607 | } |
| 1608 | |
| 1609 | std::unique_ptr<ExecutorsAndKeys> ek; |
| 1610 | std::unique_ptr<FunctionInfo> func_info; |
| 1611 | RunStateArgs run_state_args(callable_options.run_options().debug_options()); |
| 1612 | TF_RETURN_IF_ERROR( |
| 1613 | CreateExecutors(callable_options, &ek, &func_info, &run_state_args)); |
| 1614 | { |
| 1615 | mutex_lock l(callables_lock_); |
| 1616 | *out_handle = next_callable_handle_++; |
| 1617 | callables_[*out_handle] = {std::move(ek), std::move(func_info)}; |
| 1618 | } |
| 1619 | return Status::OK(); |
| 1620 | } |
| 1621 | |
| 1622 | class DirectSession::RunCallableCallFrame : public CallFrameInterface { |
| 1623 | public: |
| 1624 | RunCallableCallFrame(DirectSession* session, |
| 1625 | ExecutorsAndKeys* executors_and_keys, |
| 1626 | const std::vector<Tensor>* feed_tensors, |
| 1627 | std::vector<Tensor>* fetch_tensors) |
| 1628 | : session_(session), |
| 1629 | executors_and_keys_(executors_and_keys), |
| 1630 | feed_tensors_(feed_tensors), |
| 1631 | fetch_tensors_(fetch_tensors) {} |
| 1632 | |
| 1633 | size_t num_args() const override { |
| 1634 | return executors_and_keys_->input_types.size(); |
| 1635 | } |
| 1636 | size_t num_retvals() const override { |
| 1637 | return executors_and_keys_->output_types.size(); |
| 1638 | } |
| 1639 | |
| 1640 | Status GetArg(int index, Tensor* val) const override { |
| 1641 | if (index > feed_tensors_->size()) { |
| 1642 | return errors::Internal("Args index out of bounds: " , index); |
| 1643 | } else if (executors_and_keys_->input_types[index] == DT_RESOURCE) { |
| 1644 | TF_RETURN_IF_ERROR( |
| 1645 | session_->ResourceHandleToInputTensor((*feed_tensors_)[index], val)); |
| 1646 | } else { |
| 1647 | *val = (*feed_tensors_)[index]; |
| 1648 | } |
| 1649 | return Status::OK(); |
| 1650 | } |
| 1651 | |
| 1652 | Status SetRetval(int index, const Tensor& val) override { |
| 1653 | if (index > fetch_tensors_->size()) { |
| 1654 | return errors::Internal("RetVal index out of bounds: " , index); |
| 1655 | } |
| 1656 | (*fetch_tensors_)[index] = val; |
| 1657 | return Status::OK(); |
| 1658 | } |
| 1659 | |
| 1660 | private: |
| 1661 | DirectSession* const session_; // Not owned. |
| 1662 | ExecutorsAndKeys* const executors_and_keys_; // Not owned. |
| 1663 | const std::vector<Tensor>* const feed_tensors_; // Not owned. |
| 1664 | std::vector<Tensor>* const fetch_tensors_; // Not owned. |
| 1665 | }; |
| 1666 | |
| 1667 | ::tensorflow::Status DirectSession::RunCallable( |
| 1668 | CallableHandle handle, const std::vector<Tensor>& feed_tensors, |
| 1669 | std::vector<Tensor>* fetch_tensors, RunMetadata* run_metadata) { |
| 1670 | TF_RETURN_IF_ERROR(CheckNotClosed()); |
| 1671 | TF_RETURN_IF_ERROR(CheckGraphCreated("RunCallable()" )); |
| 1672 | direct_session_runs->GetCell()->IncrementBy(1); |
| 1673 | |
| 1674 | // Check if we already have an executor for these arguments. |
| 1675 | std::shared_ptr<ExecutorsAndKeys> executors_and_keys; |
| 1676 | const int64 step_id = step_id_counter_.fetch_add(1); |
| 1677 | |
| 1678 | { |
| 1679 | tf_shared_lock l(callables_lock_); |
| 1680 | if (handle >= next_callable_handle_) { |
| 1681 | return errors::InvalidArgument("No such callable handle: " , handle); |
| 1682 | } |
| 1683 | executors_and_keys = callables_[handle].executors_and_keys; |
| 1684 | } |
| 1685 | |
| 1686 | if (!executors_and_keys) { |
| 1687 | return errors::InvalidArgument( |
| 1688 | "Attempted to run callable after handle was released: " , handle); |
| 1689 | } |
| 1690 | |
| 1691 | // NOTE(mrry): Debug options are not currently supported in the |
| 1692 | // callable interface. |
| 1693 | DebugOptions debug_options; |
| 1694 | RunStateArgs run_state_args(debug_options); |
| 1695 | |
| 1696 | // Configure a call frame for the step, which we use to feed and |
| 1697 | // fetch values to and from the executors. |
| 1698 | if (feed_tensors.size() != executors_and_keys->input_types.size()) { |
| 1699 | return errors::InvalidArgument( |
| 1700 | "Expected " , executors_and_keys->input_types.size(), |
| 1701 | " feed tensors, but got " , feed_tensors.size()); |
| 1702 | } |
| 1703 | if (fetch_tensors != nullptr) { |
| 1704 | fetch_tensors->resize(executors_and_keys->output_types.size()); |
| 1705 | } else if (!executors_and_keys->output_types.empty()) { |
| 1706 | return errors::InvalidArgument( |
| 1707 | "`fetch_tensors` must be provided when the callable has one or more " |
| 1708 | "outputs." ); |
| 1709 | } |
| 1710 | |
| 1711 | // A specialized CallFrame implementation that takes advantage of the |
| 1712 | // optimized RunCallable interface. |
| 1713 | |
| 1714 | RunCallableCallFrame call_frame(this, executors_and_keys.get(), &feed_tensors, |
| 1715 | fetch_tensors); |
| 1716 | |
| 1717 | if (LogMemory::IsEnabled()) { |
| 1718 | LogMemory::RecordStep(step_id, run_state_args.handle); |
| 1719 | } |
| 1720 | |
| 1721 | TF_RETURN_IF_ERROR( |
| 1722 | RunInternal(step_id, executors_and_keys->callable_options.run_options(), |
| 1723 | &call_frame, executors_and_keys.get(), run_metadata)); |
| 1724 | |
| 1725 | return Status::OK(); |
| 1726 | } |
| 1727 | |
| 1728 | ::tensorflow::Status DirectSession::ReleaseCallable(CallableHandle handle) { |
| 1729 | mutex_lock l(callables_lock_); |
| 1730 | if (handle >= next_callable_handle_) { |
| 1731 | return errors::InvalidArgument("No such callable handle: " , handle); |
| 1732 | } |
| 1733 | callables_.erase(handle); |
| 1734 | return Status::OK(); |
| 1735 | } |
| 1736 | |
| 1737 | DirectSession::Callable::~Callable() { |
| 1738 | // We must delete the fields in this order, because the destructor |
| 1739 | // of `executors_and_keys` will call into an object owned by |
| 1740 | // `function_info` (in particular, when deleting a kernel, it relies |
| 1741 | // on the `FunctionLibraryRuntime` to know if the kernel is stateful |
| 1742 | // or not). |
| 1743 | executors_and_keys.reset(); |
| 1744 | function_info.reset(); |
| 1745 | } |
| 1746 | |
| 1747 | } // namespace tensorflow |
| 1748 | |