/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/common_runtime/direct_session.h"

#include <algorithm>
#include <atomic>
#include <map>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

22#include "tensorflow/core/common_runtime/constant_folding.h"
23#include "tensorflow/core/common_runtime/debugger_state_interface.h"
24#include "tensorflow/core/common_runtime/device_factory.h"
25#include "tensorflow/core/common_runtime/executor.h"
26#include "tensorflow/core/common_runtime/function.h"
27#include "tensorflow/core/common_runtime/graph_optimizer.h"
28#include "tensorflow/core/common_runtime/memory_types.h"
29#include "tensorflow/core/common_runtime/optimization_registry.h"
30#include "tensorflow/core/common_runtime/process_util.h"
31#include "tensorflow/core/common_runtime/step_stats_collector.h"
32#include "tensorflow/core/framework/function.h"
33#include "tensorflow/core/framework/graph.pb_text.h"
34#include "tensorflow/core/framework/graph.pb.h"
35#include "tensorflow/core/framework/graph_def_util.h"
36#include "tensorflow/core/framework/log_memory.h"
37#include "tensorflow/core/framework/node_def.pb.h"
38#include "tensorflow/core/framework/tensor.h"
39#include "tensorflow/core/framework/versions.pb.h"
40#include "tensorflow/core/graph/algorithm.h"
41#include "tensorflow/core/graph/graph.h"
42#include "tensorflow/core/graph/graph_constructor.h"
43#include "tensorflow/core/graph/graph_partition.h"
44#include "tensorflow/core/graph/subgraph.h"
45#include "tensorflow/core/graph/tensor_id.h"
46#include "tensorflow/core/lib/core/errors.h"
47#include "tensorflow/core/lib/core/notification.h"
48#include "tensorflow/core/lib/core/refcount.h"
49#include "tensorflow/core/lib/core/status.h"
50#include "tensorflow/core/lib/core/threadpool.h"
51#include "tensorflow/core/lib/gtl/array_slice.h"
52#include "tensorflow/core/lib/gtl/stl_util.h"
53#include "tensorflow/core/lib/monitoring/counter.h"
54#include "tensorflow/core/lib/strings/numbers.h"
55#include "tensorflow/core/lib/strings/str_util.h"
56#include "tensorflow/core/lib/strings/strcat.h"
57#include "tensorflow/core/platform/cpu_info.h"
58#include "tensorflow/core/platform/device_tracer.h"
59#include "tensorflow/core/platform/logging.h"
60#include "tensorflow/core/platform/mutex.h"
61#include "tensorflow/core/platform/types.h"
62#include "tensorflow/core/util/device_name_utils.h"
63#include "tensorflow/core/util/env_var.h"
64
namespace tensorflow {

namespace {

auto* direct_session_runs = monitoring::Counter<0>::New(
    "/tensorflow/core/direct_session_runs",
    "The number of times DirectSession::Run() has been called.");

Status NewThreadPoolFromThreadPoolOptions(
    const SessionOptions& options,
    const ThreadPoolOptionProto& thread_pool_options, int pool_number,
    thread::ThreadPool** pool, bool* owned) {
  int32 num_threads = thread_pool_options.num_threads();
  if (num_threads == 0) {
    num_threads = NumInterOpThreadsFromSessionOptions(options);
  }
  const string& name = thread_pool_options.global_name();
  if (name.empty()) {
    // Session-local threadpool.
    VLOG(1) << "Direct session inter op parallelism threads for pool "
            << pool_number << ": " << num_threads;
    *pool = new thread::ThreadPool(
        options.env, strings::StrCat("Compute", pool_number), num_threads);
    *owned = true;
    return Status::OK();
  }

  // Global, named threadpool.
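  // The map, its mutex, and the pools stored in it are intentionally never
  // destroyed: global pools live for the lifetime of the process and are
  // shared by all sessions that name them.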
  typedef std::pair<int32, thread::ThreadPool*> MapValue;
  static std::map<string, MapValue>* global_pool_map =
      new std::map<string, MapValue>;
  static mutex* mu = new mutex();
  mutex_lock l(*mu);
  MapValue* mvalue = &(*global_pool_map)[name];
  if (mvalue->second == nullptr) {
    mvalue->first = thread_pool_options.num_threads();
    mvalue->second = new thread::ThreadPool(
        options.env, strings::StrCat("Compute", pool_number), num_threads);
  } else {
    if (mvalue->first != thread_pool_options.num_threads()) {
      return errors::InvalidArgument(
          "Pool ", name,
          " configured previously with num_threads=", mvalue->first,
          "; cannot re-configure with num_threads=",
          thread_pool_options.num_threads());
    }
  }
  *owned = false;
  *pool = mvalue->second;
  return Status::OK();
}

thread::ThreadPool* GlobalThreadPool(const SessionOptions& options) {
  static thread::ThreadPool* const thread_pool =
      NewThreadPoolFromSessionOptions(options);
  return thread_pool;
}

// TODO(vrv): Figure out how to unify the many different functions
// that generate RendezvousKey, since many of them have to be
// consistent with each other.
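// The generated key has the form
// "<device>;<incarnation>;<device>;<tensor_name>;<frame_id>:<iter_id>",
// which is the layout that Rendezvous::ParseKey() expects.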
string GetRendezvousKey(const string& tensor_name,
                        const DeviceAttributes& device_info,
                        const FrameAndIter& frame_iter) {
  return strings::StrCat(device_info.name(), ";",
                         strings::FpToString(device_info.incarnation()), ";",
                         device_info.name(), ";", tensor_name, ";",
                         frame_iter.frame_id, ":", frame_iter.iter_id);
}

}  // namespace

class DirectSessionFactory : public SessionFactory {
 public:
  DirectSessionFactory() {}

  bool AcceptsOptions(const SessionOptions& options) override {
    return options.target.empty();
  }

  Session* NewSession(const SessionOptions& options) override {
    // Must do this before the CPU allocator is created.
    if (options.config.graph_options().build_cost_model() > 0) {
      EnableCPUAllocatorFullStats(true);
    }
    std::vector<Device*> devices;
    const Status s = DeviceFactory::AddDevices(
        options, "/job:localhost/replica:0/task:0", &devices);
    if (!s.ok()) {
      LOG(ERROR) << s;
      return nullptr;
    }

    DirectSession* session =
        new DirectSession(options, new DeviceMgr(devices), this);
    {
      mutex_lock l(sessions_lock_);
      sessions_.push_back(session);
    }
    return session;
  }

  Status Reset(const SessionOptions& options,
               const std::vector<string>& containers) override {
    std::vector<DirectSession*> sessions_to_reset;
    {
      mutex_lock l(sessions_lock_);
      // We create a copy to ensure that we don't have a deadlock when
      // session->Close calls the DirectSessionFactory.Deregister, which
      // acquires sessions_lock_.
      std::swap(sessions_to_reset, sessions_);
    }
    Status s;
    for (auto session : sessions_to_reset) {
      s.Update(session->Reset(containers));
    }
    // TODO(suharshs): Change the Reset behavior of all SessionFactories so that
    // it doesn't close the sessions?
    for (auto session : sessions_to_reset) {
      s.Update(session->Close());
    }
    return s;
  }

  void Deregister(const DirectSession* session) {
    mutex_lock l(sessions_lock_);
    sessions_.erase(std::remove(sessions_.begin(), sessions_.end(), session),
                    sessions_.end());
  }

 private:
  mutex sessions_lock_;
  std::vector<DirectSession*> sessions_ GUARDED_BY(sessions_lock_);
};

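// Registers the factory under the name "DIRECT_SESSION" at static
// initialization time; SessionOptions with an empty target are routed to
// DirectSession through this factory.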
class DirectSessionRegistrar {
 public:
  DirectSessionRegistrar() {
    SessionFactory::Register("DIRECT_SESSION", new DirectSessionFactory());
  }
};
static DirectSessionRegistrar registrar;

std::atomic_int_fast64_t DirectSession::step_id_counter_(1);

// NOTE: On Android with a single device, there is never
// a risk of an OpKernel blocking indefinitely:
//
// 1) No operations do I/O that depends on other simultaneous kernels,
//
// 2) Recv nodes always complete immediately: The inputs are sent into
//    the local rendezvous before we start the executor, so the
//    corresponding recvs will not block.
//
// Based on these assumptions, we can use the same thread pool for
// both "non-blocking" and "blocking" OpKernels on Android.
//
// This may change down the road when we add support for multiple
// devices that run concurrently, in which case we will need to
// revisit this decision.
void DirectSession::SchedClosure(thread::ThreadPool* pool,
                                 std::function<void()> c) {
// TODO(sanjay): Get rid of __ANDROID__ path
#ifdef __ANDROID__
  // On Android, there is no implementation of ThreadPool that takes
  // std::function, only Closure, which we cannot easily convert.
  //
  // Instead, we just run the function in-line, which is currently
  // safe given the reasoning above.
  c();
#else
  pool->Schedule(std::move(c));
#endif  // __ANDROID__
}

DirectSession::DirectSession(const SessionOptions& options,
                             const DeviceMgr* device_mgr,
                             DirectSessionFactory* const factory)
    : options_(options),
      device_mgr_(device_mgr),
      factory_(factory),
      cancellation_manager_(new CancellationManager()),
      operation_timeout_in_ms_(options_.config.operation_timeout_in_ms()) {
  const int thread_pool_size =
      options_.config.session_inter_op_thread_pool_size();
  if (thread_pool_size > 0) {
    for (int i = 0; i < thread_pool_size; ++i) {
      thread::ThreadPool* pool = nullptr;
      bool owned = false;
      init_error_.Update(NewThreadPoolFromThreadPoolOptions(
          options_, options_.config.session_inter_op_thread_pool(i), i, &pool,
          &owned));
      thread_pools_.emplace_back(pool, owned);
    }
  } else if (options_.config.use_per_session_threads()) {
    thread_pools_.emplace_back(NewThreadPoolFromSessionOptions(options_),
                               true /* owned */);
  } else {
    thread_pools_.emplace_back(GlobalThreadPool(options), false /* owned */);
  }
  // The default value of sync_on_finish will be flipped soon and this
  // environment variable will be removed as well.
  const Status status =
      ReadBoolFromEnvVar("TF_SYNC_ON_FINISH", true, &sync_on_finish_);
  if (!status.ok()) {
    LOG(ERROR) << status.error_message();
  }
  // NOTE(mrry): We do not need to use a unique string for the session
  // handle, because DirectSession owns its devices. This may change
  // in future versions.
  session_handle_ = "direct";
  int devices_added = 0;
  if (options.config.log_device_placement()) {
    const string mapping_str = device_mgr_->DeviceMappingString();
    if (mapping_str.empty()) {
      printf("Device mapping: no known devices.\n");
    } else {
      printf("Device mapping:\n%s", mapping_str.c_str());
    }
    LOG(INFO) << "Device mapping:\n" << mapping_str;
  }
  for (auto d : device_mgr_->ListDevices()) {
    devices_.push_back(d);
    device_set_.AddDevice(d);
    d->op_segment()->AddHold(session_handle_);

    // The first device added is special: it is the 'client device' (a
    // CPU device) from which we feed and fetch Tensors.
    if (devices_added == 0) {
      device_set_.set_client_device(d);
    }
    ++devices_added;
  }
}

DirectSession::~DirectSession() {
  if (!closed_) Close().IgnoreError();
  for (auto& it : partial_runs_) {
    it.second.reset(nullptr);
  }
  for (auto& it : executors_) {
    it.second.reset();
  }
  callables_.clear();
  for (auto d : device_mgr_->ListDevices()) {
    d->op_segment()->RemoveHold(session_handle_);
  }
  for (auto d : device_mgr_->ListDevices()) {
    d->ClearResourceMgr();
  }
  functions_.clear();
  delete cancellation_manager_;
  for (const auto& p_and_owned : thread_pools_) {
    if (p_and_owned.second) delete p_and_owned.first;
  }

  execution_state_.reset(nullptr);
  flib_def_.reset(nullptr);
}

Status DirectSession::MaybeInitializeExecutionState(
    const GraphDef& graph, bool* out_already_initialized) {
  // If already initialized, do nothing.
  if (flib_def_ && execution_state_) {
    *out_already_initialized = true;
    return Status::OK();
  }
  // Set up the per-session execution state.
  // NOTE(mrry): The function library created here will be used for
  // all subsequent extensions of the graph.
  flib_def_.reset(
      new FunctionLibraryDefinition(OpRegistry::Global(), graph.library()));
  GraphExecutionStateOptions options;
  options.device_set = &device_set_;
  options.session_options = &options_;
  // TODO(mrry,suharshs): We explicitly copy `graph` so that
  // `MakeForBaseGraph()` can take ownership of its
  // contents. Previously this happened implicitly in calls to the
  // `GraphExecutionState`. Other sessions call
  // `MakeForBaseGraph` in such a way that we can destructively read
  // the passed-in `GraphDef`. In principle we could do the same here,
  // with a wider refactoring; we might revise the direct session so
  // that it copies the graph fewer times.
  GraphDef temp(graph);
  TF_RETURN_IF_ERROR(
      GraphExecutionState::MakeForBaseGraph(&temp, options, &execution_state_));
  graph_created_ = true;
  *out_already_initialized = false;
  return Status::OK();
}

Status DirectSession::Create(const GraphDef& graph) {
  TF_RETURN_IF_ERROR(init_error_);
  if (graph.node_size() > 0) {
    mutex_lock l(graph_def_lock_);
    if (graph_created_) {
      return errors::AlreadyExists(
          "A Graph has already been created for this session.");
    }
    return ExtendLocked(graph);
  }
  return Status::OK();
}

Status DirectSession::Extend(const GraphDef& graph) {
  TF_RETURN_IF_ERROR(CheckNotClosed());
  mutex_lock l(graph_def_lock_);
  return ExtendLocked(graph);
}

Status DirectSession::ExtendLocked(const GraphDef& graph) {
  bool already_initialized;
  // If this is the first call, we can initialize the execution state
  // with `graph` and do not need to call `Extend()`.
  TF_RETURN_IF_ERROR(
      MaybeInitializeExecutionState(graph, &already_initialized));
  if (already_initialized) {
    TF_RETURN_IF_ERROR(flib_def_->AddLibrary(graph.library()));
    std::unique_ptr<GraphExecutionState> state;
    TF_RETURN_IF_ERROR(execution_state_->Extend(graph, &state));
    execution_state_.swap(state);
  }
  return Status::OK();
}

Status DirectSession::Run(const NamedTensorList& inputs,
                          const std::vector<string>& output_names,
                          const std::vector<string>& target_nodes,
                          std::vector<Tensor>* outputs) {
  RunMetadata run_metadata;
  return Run(RunOptions(), inputs, output_names, target_nodes, outputs,
             &run_metadata);
}

Status DirectSession::CreateDebuggerState(
    const CallableOptions& callable_options, int64 global_step,
    int64 session_run_index, int64 executor_step_index,
    std::unique_ptr<DebuggerStateInterface>* debugger_state) {
  TF_RETURN_IF_ERROR(DebuggerStateRegistry::CreateState(
      callable_options.run_options().debug_options(), debugger_state));
  std::vector<string> input_names(callable_options.feed().begin(),
                                  callable_options.feed().end());
  std::vector<string> output_names(callable_options.fetch().begin(),
                                   callable_options.fetch().end());
  std::vector<string> target_names(callable_options.target().begin(),
                                   callable_options.target().end());

  TF_RETURN_IF_ERROR(debugger_state->get()->PublishDebugMetadata(
      global_step, session_run_index, executor_step_index, input_names,
      output_names, target_names));
  return Status::OK();
}

Status DirectSession::DecorateAndPublishGraphForDebug(
    const DebugOptions& debug_options, Graph* graph, Device* device) {
  std::unique_ptr<DebugGraphDecoratorInterface> decorator;
  TF_RETURN_IF_ERROR(
      DebugGraphDecoratorRegistry::CreateDecorator(debug_options, &decorator));

  TF_RETURN_IF_ERROR(decorator->DecorateGraph(graph, device));
  TF_RETURN_IF_ERROR(decorator->PublishGraph(*graph, device->name()));
  return Status::OK();
}

Status DirectSession::RunInternal(int64 step_id, const RunOptions& run_options,
                                  CallFrameInterface* call_frame,
                                  ExecutorsAndKeys* executors_and_keys,
                                  RunMetadata* run_metadata) {
  const int64 executor_step_count = executors_and_keys->step_count.fetch_add(1);

  std::unique_ptr<DebuggerStateInterface> debugger_state;
  if (!run_options.debug_options().debug_tensor_watch_opts().empty()) {
    TF_RETURN_IF_ERROR(
        CreateDebuggerState(executors_and_keys->callable_options,
                            run_options.debug_options().global_step(), step_id,
                            executor_step_count, &debugger_state));
  }

  // Create a run state and start execution.
  RunState run_state(step_id, &devices_);
  run_state.rendez = new IntraProcessRendezvous(device_mgr_.get());

  // Start parallel Executors.
  const size_t num_executors = executors_and_keys->items.size();
  ExecutorBarrier* barrier = new ExecutorBarrier(
      num_executors, run_state.rendez, [&run_state](const Status& ret) {
        {
          mutex_lock l(run_state.mu_);
          run_state.status.Update(ret);
        }
        run_state.executors_done.Notify();
      });
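  // The barrier runs this callback and then frees itself once the last
  // executor finishes; on the early-return paths below, where no executor has
  // been started, it must be deleted manually.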

  Executor::Args args;
  args.step_id = step_id;
  args.call_frame = call_frame;
  args.rendezvous = run_state.rendez;
  CancellationManager step_cancellation_manager;
  args.cancellation_manager = &step_cancellation_manager;
  args.session_state = &session_state_;
  args.tensor_store = &run_state.tensor_store;
  args.step_container = &run_state.step_container;
  args.sync_on_finish = sync_on_finish_;

  const bool do_trace = (run_options.trace_level() > RunOptions::NO_TRACE);

  bool update_cost_model = false;
  if (options_.config.graph_options().build_cost_model() > 0) {
    const int64 build_cost_model_every =
        options_.config.graph_options().build_cost_model();
    const int64 build_cost_model_after =
        options_.config.graph_options().build_cost_model_after();
    int64 measure_step_count = executor_step_count - build_cost_model_after;
    if (measure_step_count >= 0) {
      update_cost_model =
          ((measure_step_count + 1) % build_cost_model_every == 0);
    }
  }
  if (do_trace || update_cost_model ||
      run_options.report_tensor_allocations_upon_oom()) {
    run_state.collector.reset(
        new StepStatsCollector(run_metadata->mutable_step_stats()));
    args.stats_collector = run_state.collector.get();
  }

  std::unique_ptr<DeviceTracer> tracer;
  if (run_options.trace_level() >= RunOptions::HARDWARE_TRACE) {
    tracer = CreateDeviceTracer();
    // tracer may be NULL on platforms without accelerators.
    if (tracer) {
      Status s = tracer->Start();
      if (!s.ok()) {
        run_state.executors_done.Notify();
        delete barrier;
        return s;
      }
    }
  }

  if (run_options.inter_op_thread_pool() < 0 ||
      run_options.inter_op_thread_pool() >= thread_pools_.size()) {
    run_state.executors_done.Notify();
    delete barrier;
    return errors::InvalidArgument("Invalid inter_op_thread_pool: ",
                                   run_options.inter_op_thread_pool());
  }

  // Register this step with session's cancellation manager, so that
  // `Session::Close()` will cancel the step.
  const CancellationToken cancellation_token =
      cancellation_manager_->get_cancellation_token();
  const bool already_cancelled = !cancellation_manager_->RegisterCallback(
      cancellation_token, [&step_cancellation_manager]() {
        step_cancellation_manager.StartCancel();
      });
  if (already_cancelled) {
    // NOTE(mrry): If we don't explicitly notify
    // `run_state.executors_done`, the RunState destructor would
    // block on this notification.
    run_state.executors_done.Notify();
    delete barrier;
    return errors::Cancelled("Run call was cancelled");
  }

  thread::ThreadPool* pool =
      thread_pools_[run_options.inter_op_thread_pool()].first;

  Executor::Args::Runner default_runner = [this,
                                           pool](Executor::Args::Closure c) {
    SchedClosure(pool, std::move(c));
  };
  for (const auto& item : executors_and_keys->items) {
    // TODO(zhengxq): support partial run.
    // TODO(zhengxq): if the device picks its own threadpool, we need to assign
    // fewer threads to the main compute pool by default.
    thread::ThreadPool* device_thread_pool =
        item.device->tensorflow_device_thread_pool();
    if (!device_thread_pool) {
      args.runner = default_runner;
    } else {
      args.runner = [this, device_thread_pool](Executor::Args::Closure c) {
        SchedClosure(device_thread_pool, std::move(c));
      };
    }
    item.executor->RunAsync(args, barrier->Get());
  }

  WaitForNotification(&run_state, &step_cancellation_manager,
                      run_options.timeout_in_ms() > 0
                          ? run_options.timeout_in_ms()
                          : operation_timeout_in_ms_);

  if (!cancellation_manager_->DeregisterCallback(cancellation_token)) {
    // The step has been cancelled: make sure we don't attempt to receive the
    // outputs as this would make it block forever.
    mutex_lock l(run_state.mu_);
    run_state.status.Update(errors::Cancelled("Run call was cancelled"));
  }

  if (tracer) {
    TF_RETURN_IF_ERROR(tracer->Stop());
    TF_RETURN_IF_ERROR(tracer->Collect(args.stats_collector));
  }

  {
    mutex_lock l(run_state.mu_);
    TF_RETURN_IF_ERROR(run_state.status);
  }

  // Save the output tensors of this run we choose to keep.
  if (!run_state.tensor_store.empty()) {
    TF_RETURN_IF_ERROR(run_state.tensor_store.SaveTensors(
        {executors_and_keys->callable_options.fetch().begin(),
         executors_and_keys->callable_options.fetch().end()},
        &session_state_));
  }

  if (args.stats_collector) {
    args.stats_collector->Finalize();
  }

  // Build and return the cost model as instructed.
  if (update_cost_model) {
    // Build the cost model
    std::unordered_map<string, const Graph*> device_to_graph;
    for (const PerPartitionExecutorsAndLib& partition :
         executors_and_keys->items) {
      const Graph* graph = partition.graph;
      const string device = partition.flib->device()->name();
      device_to_graph[device] = graph;
    }

    mutex_lock l(executor_lock_);
    args.stats_collector->BuildCostModel(&cost_model_manager_, device_to_graph);

    // annotate stats onto cost graph.
    CostGraphDef* cost_graph = run_metadata->mutable_cost_graph();
    for (const auto& item : executors_and_keys->items) {
      TF_RETURN_IF_ERROR(
          cost_model_manager_.AddToCostGraphDef(item.graph, cost_graph));
    }
  }

  // If requested via RunOptions, output the partition graphs.
  if (run_options.output_partition_graphs()) {
    protobuf::RepeatedPtrField<GraphDef>* partition_graph_defs =
        run_metadata->mutable_partition_graphs();
    for (const PerPartitionExecutorsAndLib& exec_and_lib :
         executors_and_keys->items) {
      GraphDef* partition_graph_def = partition_graph_defs->Add();
      exec_and_lib.graph->ToGraphDef(partition_graph_def);
    }
  }

  return Status::OK();
}

Status DirectSession::Run(const RunOptions& run_options,
                          const NamedTensorList& inputs,
                          const std::vector<string>& output_names,
                          const std::vector<string>& target_nodes,
                          std::vector<Tensor>* outputs,
                          RunMetadata* run_metadata) {
  TF_RETURN_IF_ERROR(CheckNotClosed());
  TF_RETURN_IF_ERROR(CheckGraphCreated("Run()"));
  direct_session_runs->GetCell()->IncrementBy(1);

  // Extract the input names for this run of the session.
  std::vector<string> input_tensor_names;
  input_tensor_names.reserve(inputs.size());
  for (const auto& it : inputs) {
    input_tensor_names.push_back(it.first);
  }

  // Check if we already have an executor for these arguments.
  ExecutorsAndKeys* executors_and_keys;
  RunStateArgs run_state_args(run_options.debug_options());

  TF_RETURN_IF_ERROR(GetOrCreateExecutors(input_tensor_names, output_names,
                                          target_nodes, &executors_and_keys,
                                          &run_state_args));

  // Configure a call frame for the step, which we use to feed and
  // fetch values to and from the executors.
  FunctionCallFrame call_frame(executors_and_keys->input_types,
                               executors_and_keys->output_types);
  gtl::InlinedVector<Tensor, 4> feed_args(inputs.size());
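  // Feeds of type DT_RESOURCE are treated as handles to tensors stored in the
  // session state and are resolved to the underlying tensors (see
  // ResourceHandleToInputTensor) before being placed in the call frame.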
  for (const auto& it : inputs) {
    if (it.second.dtype() == DT_RESOURCE) {
      Tensor tensor_from_handle;
      TF_RETURN_IF_ERROR(
          ResourceHandleToInputTensor(it.second, &tensor_from_handle));
      feed_args[executors_and_keys->input_name_to_index[it.first]] =
          tensor_from_handle;
    } else {
      feed_args[executors_and_keys->input_name_to_index[it.first]] = it.second;
    }
  }
  const Status s = call_frame.SetArgs(feed_args);
  if (errors::IsInternal(s)) {
    return errors::InvalidArgument(s.error_message());
  } else if (!s.ok()) {
    return s;
  }

  const int64 step_id = step_id_counter_.fetch_add(1);

  if (LogMemory::IsEnabled()) {
    LogMemory::RecordStep(step_id, run_state_args.handle);
  }

  TF_RETURN_IF_ERROR(RunInternal(step_id, run_options, &call_frame,
                                 executors_and_keys, run_metadata));

  // Receive outputs.
  if (outputs) {
    std::vector<Tensor> sorted_outputs;
    const Status s = call_frame.ConsumeRetvals(&sorted_outputs);
    if (errors::IsInternal(s)) {
      return errors::InvalidArgument(s.error_message());
    } else if (!s.ok()) {
      return s;
    }
    const bool unique_outputs =
        output_names.size() == executors_and_keys->output_name_to_index.size();
    // first_indices[i] = j implies that j is the smallest value for which
    // output_names[i] == output_names[j].
    std::vector<int> first_indices;
    if (!unique_outputs) {
      first_indices.resize(output_names.size());
      for (int i = 0; i < output_names.size(); ++i) {
        for (int j = 0; j <= i; ++j) {
          if (output_names[i] == output_names[j]) {
            first_indices[i] = j;
            break;
          }
        }
      }
    }
    outputs->clear();
    outputs->reserve(sorted_outputs.size());
    for (int i = 0; i < output_names.size(); ++i) {
      const string& output_name = output_names[i];
      if (first_indices.empty() || first_indices[i] == i) {
        outputs->emplace_back(
            std::move(sorted_outputs[executors_and_keys
                                         ->output_name_to_index[output_name]]));
      } else {
        outputs->push_back((*outputs)[first_indices[i]]);
      }
    }
  }

  return Status::OK();
}

Status DirectSession::PRunSetup(const std::vector<string>& input_names,
                                const std::vector<string>& output_names,
                                const std::vector<string>& target_nodes,
                                string* handle) {
  TF_RETURN_IF_ERROR(CheckNotClosed());
  TF_RETURN_IF_ERROR(CheckGraphCreated("PRunSetup()"));

  // RunOptions is not available in PRunSetup, so use thread pool 0.
  thread::ThreadPool* pool = thread_pools_[0].first;

  // Check if we already have an executor for these arguments.
  ExecutorsAndKeys* executors_and_keys;
  // TODO(cais): TFDBG support for partial runs.
  DebugOptions debug_options;
  RunStateArgs run_state_args(debug_options);
  run_state_args.is_partial_run = true;
  TF_RETURN_IF_ERROR(GetOrCreateExecutors(input_names, output_names,
                                          target_nodes, &executors_and_keys,
                                          &run_state_args));

  // Create the run state and save it for future PRun calls.
  Executor::Args args;
  args.step_id = step_id_counter_.fetch_add(1);
  RunState* run_state =
      new RunState(input_names, output_names, args.step_id, &devices_);
  run_state->rendez = new IntraProcessRendezvous(device_mgr_.get());
  {
    mutex_lock l(executor_lock_);
    if (!partial_runs_
             .emplace(run_state_args.handle,
                      std::unique_ptr<RunState>(run_state))
             .second) {
      return errors::Internal("The handle '", run_state_args.handle,
                              "' created for this partial run is not unique.");
    }
  }

  // Start parallel Executors.
  const size_t num_executors = executors_and_keys->items.size();
  ExecutorBarrier* barrier = new ExecutorBarrier(
      num_executors, run_state->rendez, [run_state](const Status& ret) {
        if (!ret.ok()) {
          mutex_lock l(run_state->mu_);
          run_state->status.Update(ret);
        }
        run_state->executors_done.Notify();
      });

  args.rendezvous = run_state->rendez;
  args.cancellation_manager = cancellation_manager_;
  args.runner = [this, pool](Executor::Args::Closure c) {
    SchedClosure(pool, std::move(c));
  };
  args.session_state = &session_state_;
  args.tensor_store = &run_state->tensor_store;
  args.step_container = &run_state->step_container;
  if (LogMemory::IsEnabled()) {
    LogMemory::RecordStep(args.step_id, run_state_args.handle);
  }
  args.sync_on_finish = sync_on_finish_;

  if (options_.config.graph_options().build_cost_model()) {
    run_state->collector.reset(new StepStatsCollector(nullptr));
    args.stats_collector = run_state->collector.get();
  }

  for (auto& item : executors_and_keys->items) {
    item.executor->RunAsync(args, barrier->Get());
  }

  *handle = run_state_args.handle;
  return Status::OK();
}

Status DirectSession::PRun(const string& handle, const NamedTensorList& inputs,
                           const std::vector<string>& output_names,
                           std::vector<Tensor>* outputs) {
  TF_RETURN_IF_ERROR(CheckNotClosed());
  std::vector<string> parts = str_util::Split(handle, ';');
  const string& key = parts[0];
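  // The handle produced by PRunSetup() has the form
  // "<executors cache key>;<counter>", so the prefix before the first ';'
  // identifies the cached executors for this partial run.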
  // Get the executors for this partial run.
  ExecutorsAndKeys* executors_and_keys;
  RunState* run_state;
  {
    mutex_lock l(executor_lock_);  // could use reader lock
    auto exc_it = executors_.find(key);
    if (exc_it == executors_.end()) {
      return errors::InvalidArgument(
          "Must run 'setup' before performing partial runs!");
    }
    executors_and_keys = exc_it->second.get();

    auto prun_it = partial_runs_.find(handle);
    if (prun_it == partial_runs_.end()) {
      return errors::InvalidArgument(
          "Must run 'setup' before performing partial runs!");
    }
    run_state = prun_it->second.get();

    // Make sure that this is a new set of feeds that are still pending.
    for (const auto& input : inputs) {
      auto it = run_state->pending_inputs.find(input.first);
      if (it == run_state->pending_inputs.end()) {
        return errors::InvalidArgument(
            "The feed ", input.first,
            " was not specified in partial_run_setup.");
      } else if (it->second) {
        return errors::InvalidArgument("The feed ", input.first,
                                       " has already been fed.");
      }
    }
    // Check that this is a new set of fetches that are still pending.
    for (const auto& output : output_names) {
      auto it = run_state->pending_outputs.find(output);
      if (it == run_state->pending_outputs.end()) {
        return errors::InvalidArgument(
            "The fetch ", output, " was not specified in partial_run_setup.");
      } else if (it->second) {
        return errors::InvalidArgument("The fetch ", output,
                                       " has already been fetched.");
      }
    }
  }

  // Check that this new set of fetches can be computed from all the
  // feeds we have supplied.
  TF_RETURN_IF_ERROR(
      CheckFetch(inputs, output_names, executors_and_keys, run_state));

  // Send inputs.
  Status s = SendPRunInputs(inputs, executors_and_keys, run_state->rendez);

  // Receive outputs.
  if (s.ok()) {
    s = RecvPRunOutputs(output_names, executors_and_keys, run_state, outputs);
  }

  // Save the output tensors of this run we choose to keep.
  if (s.ok()) {
    s = run_state->tensor_store.SaveTensors(output_names, &session_state_);
  }

  {
    mutex_lock l(executor_lock_);
    // Delete the run state if there is an error or all fetches are done.
    bool done = true;
    if (s.ok()) {
      {
        mutex_lock l(run_state->mu_);
        if (!run_state->status.ok()) {
          LOG(WARNING) << "An error unrelated to this prun has been detected. "
                       << run_state->status;
        }
      }
      for (const auto& input : inputs) {
        auto it = run_state->pending_inputs.find(input.first);
        it->second = true;
      }
      for (const auto& name : output_names) {
        auto it = run_state->pending_outputs.find(name);
        it->second = true;
      }
      done = run_state->PendingDone();
    }
    if (done) {
      WaitForNotification(run_state, cancellation_manager_,
                          operation_timeout_in_ms_);
      partial_runs_.erase(handle);
    }
  }

  return s;
}

Status DirectSession::ResourceHandleToInputTensor(const Tensor& resource_tensor,
                                                  Tensor* retrieved_tensor) {
  if (resource_tensor.dtype() != DT_RESOURCE) {
    return errors::InvalidArgument(strings::StrCat(
        "ResourceHandleToInputTensor() received non-DT_RESOURCE Tensor: ",
        resource_tensor.dtype()));
  }

  const ResourceHandle& resource_handle =
      resource_tensor.scalar<ResourceHandle>()();

  if (resource_handle.container() ==
      SessionState::kTensorHandleResourceTypeName) {
    return session_state_.GetTensor(resource_handle.name(), retrieved_tensor);
  } else {
    return errors::InvalidArgument(strings::StrCat(
        "Invalid resource type hash code: ", resource_handle.hash_code(),
        "(name: ", resource_handle.name(),
        " type: ", resource_handle.maybe_type_name(),
        "). Perhaps a resource tensor was being provided as a feed? That is "
        "not currently allowed. Please file an issue at "
        "https://github.com/tensorflow/tensorflow/issues/new, ideally with a "
        "short code snippet that leads to this error message."));
  }
}

Status DirectSession::SendPRunInputs(const NamedTensorList& inputs,
                                     const ExecutorsAndKeys* executors_and_keys,
                                     IntraProcessRendezvous* rendez) {
  Status s;
  Rendezvous::ParsedKey parsed;
  // Insert the input tensors into the local rendezvous by their
  // rendezvous key.
  for (const auto& input : inputs) {
    auto it =
        executors_and_keys->input_name_to_rendezvous_key.find(input.first);
    if (it == executors_and_keys->input_name_to_rendezvous_key.end()) {
      return errors::Internal("'", input.first, "' is not a pre-defined feed.");
    }
    const string& input_key = it->second;

    s = Rendezvous::ParseKey(input_key, &parsed);
    if (!s.ok()) {
      rendez->StartAbort(s);
      return s;
    }

    if (input.second.dtype() == DT_RESOURCE) {
      Tensor tensor_from_handle;
      s = ResourceHandleToInputTensor(input.second, &tensor_from_handle);
      if (s.ok()) {
        s = rendez->Send(parsed, Rendezvous::Args(), tensor_from_handle, false);
      }
    } else {
      s = rendez->Send(parsed, Rendezvous::Args(), input.second, false);
    }

    if (!s.ok()) {
      rendez->StartAbort(s);
      return s;
    }
  }
  return Status::OK();
}

Status DirectSession::RecvPRunOutputs(
    const std::vector<string>& output_names,
    const ExecutorsAndKeys* executors_and_keys, RunState* run_state,
    std::vector<Tensor>* outputs) {
  Status s;
  if (!output_names.empty()) {
    outputs->resize(output_names.size());
  }

  Rendezvous::ParsedKey parsed;
  // Get the outputs from the rendezvous
  for (size_t output_offset = 0; output_offset < output_names.size();
       ++output_offset) {
    const string& output_name = output_names[output_offset];
    auto it =
        executors_and_keys->output_name_to_rendezvous_key.find(output_name);
    if (it == executors_and_keys->output_name_to_rendezvous_key.end()) {
      return errors::Internal("'", output_name,
                              "' is not a pre-defined fetch.");
    }
    const string& output_key = it->second;
    Tensor output_tensor;
    bool is_dead;
    IntraProcessRendezvous* rendez = run_state->rendez;

    s = Rendezvous::ParseKey(output_key, &parsed);
    if (s.ok()) {
      // Fetch data from the Rendezvous.
      s = rendez->Recv(parsed, Rendezvous::Args(), &output_tensor, &is_dead,
                       operation_timeout_in_ms_);
      if (is_dead && s.ok()) {
        s = errors::InvalidArgument("The tensor returned for ", output_name,
                                    " was not valid.");
      }
    }
    if (!s.ok()) {
      rendez->StartAbort(s);
      outputs->clear();
      return s;
    }

    (*outputs)[output_offset] = output_tensor;
  }
  return Status::OK();
}

Status DirectSession::CheckFetch(const NamedTensorList& feeds,
                                 const std::vector<string>& fetches,
                                 const ExecutorsAndKeys* executors_and_keys,
                                 const RunState* run_state) {
  const Graph* graph = executors_and_keys->graph.get();
  const NameNodeMap* name_to_node = &executors_and_keys->name_to_node;

  // Build the set of pending feeds that we haven't seen.
  std::unordered_set<TensorId, TensorId::Hasher> pending_feeds;
  {
    mutex_lock l(executor_lock_);
    for (const auto& input : run_state->pending_inputs) {
      // Skip if the feed has already been fed.
      if (input.second) continue;
      TensorId id(ParseTensorName(input.first));
      auto it = name_to_node->find(id.first);
      if (it == name_to_node->end()) {
        return errors::NotFound("Feed ", input.first, ": not found");
      }
      pending_feeds.insert(id);
    }
  }
  for (const auto& it : feeds) {
    TensorId id(ParseTensorName(it.first));
    pending_feeds.erase(id);
  }

  // Initialize the stack with the fetch nodes.
  std::vector<const Node*> stack;
  for (const string& fetch : fetches) {
    TensorId id(ParseTensorName(fetch));
    auto it = name_to_node->find(id.first);
    if (it == name_to_node->end()) {
      return errors::NotFound("Fetch ", fetch, ": not found");
    }
    stack.push_back(it->second);
  }

  // Any tensor needed for fetches can't be in pending_feeds.
  std::vector<bool> visited(graph->num_node_ids(), false);
  while (!stack.empty()) {
    const Node* n = stack.back();
    stack.pop_back();

    for (const Edge* in_edge : n->in_edges()) {
      const Node* in_node = in_edge->src();
      if (pending_feeds.count({in_node->name(), in_edge->src_output()}) > 0) {
        return errors::InvalidArgument("Fetch ", in_node->name(), ":",
                                       in_edge->src_output(),
                                       " can't be computed from the feeds"
                                       " that have been fed so far.");
      }
      if (!visited[in_node->id()]) {
        visited[in_node->id()] = true;
        stack.push_back(in_node);
      }
    }
  }
  return Status::OK();
}

Status DirectSession::CreateExecutors(
    const CallableOptions& callable_options,
    std::unique_ptr<ExecutorsAndKeys>* out_executors_and_keys,
    std::unique_ptr<FunctionInfo>* out_func_info,
    RunStateArgs* run_state_args) {
  BuildGraphOptions options;
  options.callable_options = callable_options;
  options.use_function_convention = !run_state_args->is_partial_run;
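  // With the function calling convention, feeds and fetches are rewritten into
  // _Arg/_Retval nodes and flow through the FunctionCallFrame; partial runs
  // instead use rendezvous send/recv keyed by GetRendezvousKey() (see the
  // name-to-rendezvous-key mappings built at the end of this function).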

  std::unique_ptr<FunctionInfo> func_info(new FunctionInfo);
  std::unique_ptr<ExecutorsAndKeys> ek(new ExecutorsAndKeys);

  ek->callable_options = callable_options;

  std::unordered_map<string, std::unique_ptr<Graph>> graphs;
  TF_RETURN_IF_ERROR(CreateGraphs(options, &graphs, &func_info->flib_def,
                                  run_state_args, &ek->input_types,
                                  &ek->output_types));

  if (run_state_args->is_partial_run) {
    ek->graph = std::move(run_state_args->graph);
    std::unordered_set<StringPiece, StringPieceHasher> names;
    for (const string& input : callable_options.feed()) {
      TensorId id(ParseTensorName(input));
      names.emplace(id.first);
    }
    for (const string& output : callable_options.fetch()) {
      TensorId id(ParseTensorName(output));
      names.emplace(id.first);
    }
    for (Node* n : ek->graph->nodes()) {
      if (names.count(n->name()) > 0) {
        ek->name_to_node.insert({n->name(), n});
      }
    }
  }
  ek->items.reserve(graphs.size());
  const auto& optimizer_opts =
      options_.config.graph_options().optimizer_options();

  int graph_def_version;
  {
    mutex_lock l(graph_def_lock_);
    graph_def_version =
        execution_state_->original_graph_def().versions().producer();
  }
  func_info->proc_flr.reset(new ProcessFunctionLibraryRuntime(
      device_mgr_.get(), options_.env, graph_def_version,
      func_info->flib_def.get(), optimizer_opts, thread_pools_[0].first));

  GraphOptimizer optimizer(optimizer_opts);
  for (auto iter = graphs.begin(); iter != graphs.end(); ++iter) {
    const string& partition_name = iter->first;
    std::unique_ptr<Graph>& partition_graph = iter->second;

    Device* device;
    TF_RETURN_IF_ERROR(device_mgr_->LookupDevice(partition_name, &device));

    ek->items.resize(ek->items.size() + 1);
    auto* item = &(ek->items.back());
    auto lib = func_info->proc_flr->GetFLR(partition_name);
    if (lib == nullptr) {
      return errors::Internal("Could not find device: ", partition_name);
    }
    item->flib = lib;

    LocalExecutorParams params;
    params.device = device;
    params.function_library = lib;
    auto opseg = device->op_segment();
    params.create_kernel = [this, lib, opseg](const NodeDef& ndef,
                                              OpKernel** kernel) {
      // We do not share the kernel via the OpSegment if the node is
      // stateless, or a function.
      // NOTE(mrry): We must not share function kernels (implemented
      // using `CallOp`) between subgraphs, because `CallOp::handle_`
      // is tied to a particular subgraph. Even if the function itself
      // is stateful, the `CallOp` that invokes it is not.
      if (!lib->IsStateful(ndef.op()) ||
          lib->GetFunctionLibraryDefinition()->Find(ndef.op()) != nullptr) {
        return lib->CreateKernel(ndef, kernel);
      }
      auto create_fn = [lib, &ndef](OpKernel** kernel) {
        return lib->CreateKernel(ndef, kernel);
      };
      // Kernels created for subgraph nodes need to be cached. On
      // cache miss, create_fn() is invoked to create a kernel based
      // on the function library here + global op registry.
      return opseg->FindOrCreate(session_handle_, ndef.name(), kernel,
                                 create_fn);
    };
    params.delete_kernel = [lib](OpKernel* kernel) {
      // If the node is stateful, opseg owns it. Otherwise, delete it.
      if (kernel && !lib->IsStateful(kernel->type_string())) {
        delete kernel;
      }
    };
    params.node_outputs_cb = node_outputs_callback_;

    optimizer.Optimize(lib, options_.env, device, &iter->second,
                       /*shape_map=*/nullptr);

    // EXPERIMENTAL: tfdbg inserts debug nodes in the graph.
    const DebugOptions& debug_options =
        options.callable_options.run_options().debug_options();
    if (!debug_options.debug_tensor_watch_opts().empty()) {
      TF_RETURN_IF_ERROR(DecorateAndPublishGraphForDebug(
          debug_options, partition_graph.get(), params.device));
    }

    TF_RETURN_IF_ERROR(EnsureMemoryTypes(DeviceType(device->device_type()),
                                         device->name(),
                                         partition_graph.get()));
    // NewLocalExecutor takes ownership of partition_graph.
    item->graph = partition_graph.get();
    item->executor = nullptr;
    item->device = device;
    Executor* executor;
    TF_RETURN_IF_ERROR(
        NewLocalExecutor(params, std::move(partition_graph), &executor));
    item->executor.reset(executor);
  }

  // Cache the mapping from input/output names to graph elements to
  // avoid recomputing it every time.
  if (!run_state_args->is_partial_run) {
    // For regular `Run()`, we use the function calling convention, and so
    // maintain a mapping from input/output names to
    // argument/return-value ordinal index.
    for (int i = 0; i < callable_options.feed().size(); ++i) {
      const string& input = callable_options.feed(i);
      ek->input_name_to_index[input] = i;
    }
    for (int i = 0; i < callable_options.fetch().size(); ++i) {
      const string& output = callable_options.fetch(i);
      ek->output_name_to_index[output] = i;
    }
  } else {
    // For `PRun()`, we use the rendezvous calling convention, and so
    // maintain a mapping from input/output names to rendezvous keys.
    //
    // We always use the first device as the device name portion of the
    // key, even if we're feeding another graph.
    for (int i = 0; i < callable_options.feed().size(); ++i) {
      const string& input = callable_options.feed(i);
      ek->input_name_to_rendezvous_key[input] = GetRendezvousKey(
          input, device_set_.client_device()->attributes(), FrameAndIter(0, 0));
    }
    for (int i = 0; i < callable_options.fetch().size(); ++i) {
      const string& output = callable_options.fetch(i);
      ek->output_name_to_rendezvous_key[output] =
          GetRendezvousKey(output, device_set_.client_device()->attributes(),
                           FrameAndIter(0, 0));
    }
  }

  *out_executors_and_keys = std::move(ek);
  *out_func_info = std::move(func_info);
  return Status::OK();
}

Status DirectSession::GetOrCreateExecutors(
    gtl::ArraySlice<string> inputs, gtl::ArraySlice<string> outputs,
    gtl::ArraySlice<string> target_nodes, ExecutorsAndKeys** executors_and_keys,
    RunStateArgs* run_state_args) {
  int64 handle_name_counter_value = -1;
  if (LogMemory::IsEnabled() || run_state_args->is_partial_run) {
    handle_name_counter_value = handle_name_counter_.fetch_add(1);
  }

  string debug_tensor_watches_summary;
  if (!run_state_args->debug_options.debug_tensor_watch_opts().empty()) {
    debug_tensor_watches_summary = SummarizeDebugTensorWatches(
        run_state_args->debug_options.debug_tensor_watch_opts());
  }

  // Fast lookup path, no sorting.
  const string key = strings::StrCat(
      str_util::Join(inputs, ","), "->", str_util::Join(outputs, ","), "/",
      str_util::Join(target_nodes, ","), "/", run_state_args->is_partial_run,
      "/", debug_tensor_watches_summary);
  // Set the handle, if it's needed to log memory or for partial run.
  if (handle_name_counter_value >= 0) {
    run_state_args->handle =
        strings::StrCat(key, ";", handle_name_counter_value);
  }

  // See if we already have the executors for this run.
  {
    mutex_lock l(executor_lock_);  // could use reader lock
    auto it = executors_.find(key);
    if (it != executors_.end()) {
      *executors_and_keys = it->second.get();
      return Status::OK();
    }
  }

  // Slow lookup path, the unsorted key missed the cache.
  // Sort the inputs and outputs, and look up with the sorted key in case an
  // earlier call used a different order of inputs and outputs.
  //
  // We could consider some other signature instead of sorting that
  // preserves the same property to avoid the sort in the future.
  std::vector<string> inputs_sorted(inputs.begin(), inputs.end());
  std::sort(inputs_sorted.begin(), inputs_sorted.end());
  std::vector<string> outputs_sorted(outputs.begin(), outputs.end());
  std::sort(outputs_sorted.begin(), outputs_sorted.end());
  std::vector<string> tn_sorted(target_nodes.begin(), target_nodes.end());
  std::sort(tn_sorted.begin(), tn_sorted.end());

  const string sorted_key = strings::StrCat(
      str_util::Join(inputs_sorted, ","), "->",
      str_util::Join(outputs_sorted, ","), "/", str_util::Join(tn_sorted, ","),
      "/", run_state_args->is_partial_run, "/", debug_tensor_watches_summary);
  // Set the handle, if it's needed to log memory or for partial run.
  if (handle_name_counter_value >= 0) {
    run_state_args->handle =
        strings::StrCat(sorted_key, ";", handle_name_counter_value);
  }

  // See if we already have the executors for this run.
  {
    mutex_lock l(executor_lock_);
    auto it = executors_.find(sorted_key);
    if (it != executors_.end()) {
      *executors_and_keys = it->second.get();
      // Insert this under the original key.
      executors_.emplace(key, it->second);
      return Status::OK();
    }
  }

  // Nothing found, so create the executors and store in the cache.
  // The executor_lock_ is intentionally released while executors are
  // being created.
  CallableOptions callable_options;
  for (const string& input : inputs_sorted) {
    callable_options.add_feed(input);
  }
  for (const string& output : outputs_sorted) {
    callable_options.add_fetch(output);
  }
  for (const string& target : tn_sorted) {
    callable_options.add_target(target);
  }
  *callable_options.mutable_run_options()->mutable_debug_options() =
      run_state_args->debug_options;
  std::unique_ptr<ExecutorsAndKeys> ek;
  std::unique_ptr<FunctionInfo> func_info;
  TF_RETURN_IF_ERROR(
      CreateExecutors(callable_options, &ek, &func_info, run_state_args));

  // Reacquire the lock, try to insert into the map.
  mutex_lock l(executor_lock_);
  functions_.push_back(std::move(func_info));

  // Another thread may have created the entry before us, in which case we will
  // reuse the already created one.
  auto insert_result = executors_.emplace(
      sorted_key, std::shared_ptr<ExecutorsAndKeys>(std::move(ek)));
  // Insert the value under the original key, so the fast path lookup will work
  // if the user uses the same order of inputs, outputs, and targets again.
  executors_.emplace(key, insert_result.first->second);
  *executors_and_keys = insert_result.first->second.get();

  return Status::OK();
}

Status DirectSession::CreateGraphs(
    const BuildGraphOptions& subgraph_options,
    std::unordered_map<string, std::unique_ptr<Graph>>* outputs,
    std::unique_ptr<FunctionLibraryDefinition>* flib_def,
    RunStateArgs* run_state_args, DataTypeVector* input_types,
    DataTypeVector* output_types) {
  mutex_lock l(graph_def_lock_);
  std::unique_ptr<ClientGraph> client_graph;

  std::unique_ptr<GraphExecutionState> temp_exec_state_holder;
  GraphExecutionState* execution_state = nullptr;
  if (options_.config.graph_options().place_pruned_graph()) {
    // Because we are placing pruned graphs, we need to create a
    // new GraphExecutionState for every new unseen graph,
    // and then place it.
    GraphExecutionStateOptions prune_options;
    prune_options.device_set = &device_set_;
    prune_options.session_options = &options_;
    prune_options.stateful_placements = stateful_placements_;
    TF_RETURN_IF_ERROR(GraphExecutionState::MakeForPrunedGraph(
        execution_state_->original_graph_def().library(), prune_options,
        execution_state_->original_graph_def(), subgraph_options,
        &temp_exec_state_holder, &client_graph));
    execution_state = temp_exec_state_holder.get();
  } else {
    execution_state = execution_state_.get();
    TF_RETURN_IF_ERROR(
        execution_state->BuildGraph(subgraph_options, &client_graph));
  }

  if (subgraph_options.callable_options.feed_size() !=
      client_graph->feed_types.size()) {
    return errors::Internal(
        "Graph pruning failed: requested number of feed endpoints = ",
        subgraph_options.callable_options.feed_size(),
        " versus number of pruned feed endpoints = ",
        client_graph->feed_types.size());
  }
  if (subgraph_options.callable_options.fetch_size() !=
      client_graph->fetch_types.size()) {
    return errors::Internal(
        "Graph pruning failed: requested number of fetch endpoints = ",
        subgraph_options.callable_options.fetch_size(),
        " versus number of pruned fetch endpoints = ",
        client_graph->fetch_types.size());
  }

  auto current_stateful_placements = execution_state->GetStatefulPlacements();
  // Update our current state based on the execution_state's
  // placements. If there are any mismatches for a node,
  // we should fail, as this should never happen.
  for (auto placement_pair : current_stateful_placements) {
    const string& node_name = placement_pair.first;
    const string& placement = placement_pair.second;
    auto iter = stateful_placements_.find(node_name);
    if (iter == stateful_placements_.end()) {
      stateful_placements_.insert(std::make_pair(node_name, placement));
    } else if (iter->second != placement) {
      return errors::Internal(
          "Stateful placement mismatch. "
          "Current assignment of ",
          node_name, " to ", iter->second, " does not match ", placement);
    }
  }

  stateful_placements_ = execution_state->GetStatefulPlacements();

  // Remember the graph in run state if this is a partial run.
  if (run_state_args->is_partial_run) {
    run_state_args->graph.reset(new Graph(flib_def_.get()));
    CopyGraph(*execution_state->full_graph(), run_state_args->graph.get());
  }

  // Partition the graph across devices.
  PartitionOptions popts;
  popts.node_to_loc = [](const Node* node) {
    return node->assigned_device_name();
  };
  popts.new_name = [this](const string& prefix) {
    return strings::StrCat(prefix, "/_", edge_name_counter_.fetch_add(1));
  };
  popts.get_incarnation = [](const string& name) {
    // The direct session does not have changing incarnation numbers.
    // Just return '1'.
    return 1;
  };
  popts.flib_def = &client_graph->graph.flib_def();
  popts.control_flow_added = false;

  std::unordered_map<string, GraphDef> partitions;
  TF_RETURN_IF_ERROR(Partition(popts, &client_graph->graph, &partitions));

  std::vector<string> device_names;
  for (auto device : devices_) {
    // Extract the LocalName from the device.
    device_names.push_back(DeviceNameUtils::LocalName(device->name()));
  }

  // Check for valid partitions.
  for (const auto& partition : partitions) {
    const string local_partition_name =
        DeviceNameUtils::LocalName(partition.first);
    if (std::count(device_names.begin(), device_names.end(),
                   local_partition_name) == 0) {
      return errors::InvalidArgument(
          "Creating a partition for ", local_partition_name,
          " which doesn't exist in the list of available devices. Available "
          "devices: ",
          str_util::Join(device_names, ","));
    }
  }

  for (const auto& partition : partitions) {
    std::unique_ptr<Graph> device_graph(
        new Graph(client_graph->flib_def.get()));
    GraphConstructorOptions device_opts;
    // There are internal operations (e.g., send/recv) that we now allow.
    device_opts.allow_internal_ops = true;
    device_opts.expect_device_spec = true;
    TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(device_opts, partition.second,
                                              device_graph.get()));
    outputs->emplace(partition.first, std::move(device_graph));
  }

  GraphOptimizationPassOptions optimization_options;
  optimization_options.session_options = &options_;
  optimization_options.flib_def = client_graph->flib_def.get();
  optimization_options.partition_graphs = outputs;
  TF_RETURN_IF_ERROR(OptimizationPassRegistry::Global()->RunGrouping(
      OptimizationPassRegistry::POST_PARTITIONING, optimization_options));

  Status s;
  for (auto& partition : *outputs) {
    const string& partition_name = partition.first;
    std::unique_ptr<Graph>* graph = &partition.second;

    VLOG(2) << "Created " << DebugString(graph->get()) << " for "
            << partition_name;

    // Give the device an opportunity to rewrite its subgraph.
    Device* d;
    s = device_mgr_->LookupDevice(partition_name, &d);
    if (!s.ok()) break;
    s = d->MaybeRewriteGraph(graph);
    if (!s.ok()) {
      break;
    }
  }
  *flib_def = std::move(client_graph->flib_def);
  std::swap(*input_types, client_graph->feed_types);
  std::swap(*output_types, client_graph->fetch_types);
  return s;
}

1489::tensorflow::Status DirectSession::ListDevices(
1490 std::vector<DeviceAttributes>* response) {
1491 response->clear();
1492 response->reserve(devices_.size());
1493 for (Device* d : devices_) {
1494 const DeviceAttributes& attrs = d->attributes();
1495 response->emplace_back(attrs);
1496 }
1497 return ::tensorflow::Status::OK();
1498}
1499
1500::tensorflow::Status DirectSession::Reset(
1501 const std::vector<string>& containers) {
1502 device_mgr_->ClearContainers(containers);
1503 return ::tensorflow::Status::OK();
1504}
1505
1506::tensorflow::Status DirectSession::Close() {
1507 cancellation_manager_->StartCancel();
1508 {
1509 mutex_lock l(closed_lock_);
1510 if (closed_) return ::tensorflow::Status::OK();
1511 closed_ = true;
1512 }
1513 if (factory_ != nullptr) factory_->Deregister(this);
1514 return ::tensorflow::Status::OK();
1515}
1516
DirectSession::RunState::RunState(
    const std::vector<string>& pending_input_names,
    const std::vector<string>& pending_output_names, int64 step_id,
    const std::vector<Device*>* devices)
    : step_container(step_id, [devices](const string& name) {
        for (auto d : *devices) {
          if (!d->resource_manager()->Cleanup(name).ok()) {
            // Do nothing...
          }
        }
      }) {
  // Initially all the feeds and fetches are pending.
  for (auto& name : pending_input_names) {
    pending_inputs[name] = false;
  }
  for (auto& name : pending_output_names) {
    pending_outputs[name] = false;
  }
}

DirectSession::RunState::RunState(int64 step_id,
                                  const std::vector<Device*>* devices)
    : RunState({}, {}, step_id, devices) {}

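// If executors are still running when the RunState is destroyed, abort the
// rendezvous so they unblock, and wait for them to finish before dropping the
// final reference.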
DirectSession::RunState::~RunState() {
  if (rendez != nullptr) {
    if (!executors_done.HasBeenNotified()) {
      rendez->StartAbort(errors::Cancelled("PRun cancellation"));
      executors_done.WaitForNotification();
    }
    rendez->Unref();
  }
}

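// Returns true once every pending feed has been provided and every pending
// fetch has been consumed.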
bool DirectSession::RunState::PendingDone() const {
  for (const auto& it : pending_inputs) {
    if (!it.second) return false;
  }
  for (const auto& it : pending_outputs) {
    if (!it.second) return false;
  }
  return true;
}

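// Waits for the step's executors to finish, subject to the optional timeout.
// If the wait fails, the step is cancelled and we still block until the
// executors have drained, since they borrow per-step state.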
void DirectSession::WaitForNotification(RunState* run_state,
                                        CancellationManager* cm,
                                        int64 timeout_in_ms) {
  const Status status =
      WaitForNotification(&run_state->executors_done, timeout_in_ms);
  if (!status.ok()) {
    {
      mutex_lock l(run_state->mu_);
      run_state->status.Update(status);
    }
    cm->StartCancel();
    // We must wait for the executors to complete, because they have borrowed
    // references to `cm` and other per-step state. After this notification, it
    // is safe to clean up the step.
    run_state->executors_done.WaitForNotification();
  }
}

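// Blocks until `notification` fires. A positive `timeout_in_ms` bounds the
// wait and yields DEADLINE_EXCEEDED if the notification does not arrive in
// time; a non-positive value waits indefinitely.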
::tensorflow::Status DirectSession::WaitForNotification(
    Notification* notification, int64 timeout_in_ms) {
  if (timeout_in_ms > 0) {
    const int64 timeout_in_us = timeout_in_ms * 1000;
    const bool notified =
        WaitForNotificationWithTimeout(notification, timeout_in_us);
    if (!notified) {
      return Status(error::DEADLINE_EXCEEDED,
                    "Timed out waiting for notification");
    }
  } else {
    notification->WaitForNotification();
  }
  return Status::OK();
}

Status DirectSession::MakeCallable(const CallableOptions& callable_options,
                                   CallableHandle* out_handle) {
  TF_RETURN_IF_ERROR(CheckNotClosed());
  TF_RETURN_IF_ERROR(CheckGraphCreated("MakeCallable()"));

  if (!callable_options.run_options()
           .debug_options()
           .debug_tensor_watch_opts()
           .empty()) {
    return errors::Unimplemented(
        "Debug options are not currently supported via the C++ MakeCallable "
        "interface.");
  }

  std::unique_ptr<ExecutorsAndKeys> ek;
  std::unique_ptr<FunctionInfo> func_info;
  RunStateArgs run_state_args(callable_options.run_options().debug_options());
  TF_RETURN_IF_ERROR(
      CreateExecutors(callable_options, &ek, &func_info, &run_state_args));
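  // Publish the executors under a fresh handle. Handles increase
  // monotonically and are never reused, so RunCallable() can tell an unknown
  // handle apart from one that has been released.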
  {
    mutex_lock l(callables_lock_);
    *out_handle = next_callable_handle_++;
    callables_[*out_handle] = {std::move(ek), std::move(func_info)};
  }
  return Status::OK();
}

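// A CallFrame that serves feed tensors from, and stores fetch tensors into,
// the vectors supplied by the caller of RunCallable().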
class DirectSession::RunCallableCallFrame : public CallFrameInterface {
 public:
  RunCallableCallFrame(DirectSession* session,
                       ExecutorsAndKeys* executors_and_keys,
                       const std::vector<Tensor>* feed_tensors,
                       std::vector<Tensor>* fetch_tensors)
      : session_(session),
        executors_and_keys_(executors_and_keys),
        feed_tensors_(feed_tensors),
        fetch_tensors_(fetch_tensors) {}

  size_t num_args() const override {
    return executors_and_keys_->input_types.size();
  }
  size_t num_retvals() const override {
    return executors_and_keys_->output_types.size();
  }

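  // Feeds are served out of `feed_tensors_`; a DT_RESOURCE feed is first
  // resolved to the tensor that the resource handle refers to.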
  Status GetArg(int index, Tensor* val) const override {
    if (index >= feed_tensors_->size()) {
      return errors::Internal("Args index out of bounds: ", index);
    } else if (executors_and_keys_->input_types[index] == DT_RESOURCE) {
      TF_RETURN_IF_ERROR(
          session_->ResourceHandleToInputTensor((*feed_tensors_)[index], val));
    } else {
      *val = (*feed_tensors_)[index];
    }
    return Status::OK();
  }

  Status SetRetval(int index, const Tensor& val) override {
    if (index >= fetch_tensors_->size()) {
      return errors::Internal("RetVal index out of bounds: ", index);
    }
    (*fetch_tensors_)[index] = val;
    return Status::OK();
  }

 private:
  DirectSession* const session_;                   // Not owned.
  ExecutorsAndKeys* const executors_and_keys_;     // Not owned.
  const std::vector<Tensor>* const feed_tensors_;  // Not owned.
  std::vector<Tensor>* const fetch_tensors_;       // Not owned.
};

::tensorflow::Status DirectSession::RunCallable(
    CallableHandle handle, const std::vector<Tensor>& feed_tensors,
    std::vector<Tensor>* fetch_tensors, RunMetadata* run_metadata) {
  TF_RETURN_IF_ERROR(CheckNotClosed());
  TF_RETURN_IF_ERROR(CheckGraphCreated("RunCallable()"));
  direct_session_runs->GetCell()->IncrementBy(1);

  // Look up the executors that were created for this callable handle.
  std::shared_ptr<ExecutorsAndKeys> executors_and_keys;
  const int64 step_id = step_id_counter_.fetch_add(1);

  {
    tf_shared_lock l(callables_lock_);
    if (handle >= next_callable_handle_) {
      return errors::InvalidArgument("No such callable handle: ", handle);
    }
    executors_and_keys = callables_[handle].executors_and_keys;
  }

  if (!executors_and_keys) {
    return errors::InvalidArgument(
        "Attempted to run callable after handle was released: ", handle);
  }

  // NOTE(mrry): Debug options are not currently supported in the
  // callable interface.
  DebugOptions debug_options;
  RunStateArgs run_state_args(debug_options);

  // Configure a call frame for the step, which we use to feed and
  // fetch values to and from the executors.
  if (feed_tensors.size() != executors_and_keys->input_types.size()) {
    return errors::InvalidArgument(
        "Expected ", executors_and_keys->input_types.size(),
        " feed tensors, but got ", feed_tensors.size());
  }
  if (fetch_tensors != nullptr) {
    fetch_tensors->resize(executors_and_keys->output_types.size());
  } else if (!executors_and_keys->output_types.empty()) {
    return errors::InvalidArgument(
        "`fetch_tensors` must be provided when the callable has one or more "
        "outputs.");
  }

  // A specialized CallFrame implementation that takes advantage of the
  // optimized RunCallable interface.
  RunCallableCallFrame call_frame(this, executors_and_keys.get(),
                                  &feed_tensors, fetch_tensors);

  if (LogMemory::IsEnabled()) {
    LogMemory::RecordStep(step_id, run_state_args.handle);
  }

  TF_RETURN_IF_ERROR(
      RunInternal(step_id, executors_and_keys->callable_options.run_options(),
                  &call_frame, executors_and_keys.get(), run_metadata));

  return Status::OK();
}

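// Removes the callable from the session and retires its handle permanently.
// A step already running with these executors keeps them alive through its
// own shared_ptr reference.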
::tensorflow::Status DirectSession::ReleaseCallable(CallableHandle handle) {
  mutex_lock l(callables_lock_);
  if (handle >= next_callable_handle_) {
    return errors::InvalidArgument("No such callable handle: ", handle);
  }
  callables_.erase(handle);
  return Status::OK();
}

DirectSession::Callable::~Callable() {
  // We must delete the fields in this order, because the destructor
  // of `executors_and_keys` will call into an object owned by
  // `function_info` (in particular, when deleting a kernel, it relies
  // on the `FunctionLibraryRuntime` to know if the kernel is stateful
  // or not).
  executors_and_keys.reset();
  function_info.reset();
}

}  // namespace tensorflow