/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/c/eager/c_api.h"

#include <algorithm>
#include <cstddef>
#include <memory>
#include <string>
#include <vector>

#include "tensorflow/c/c_api.h"
#include "tensorflow/c/c_api_internal.h"
#include "tensorflow/c/eager/c_api_internal.h"
#include "tensorflow/c/eager/runtime.h"
#ifdef TENSORFLOW_EAGER_USE_XLA
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#endif  // TENSORFLOW_EAGER_USE_XLA
#include "tensorflow/core/common_runtime/copy_tensor.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/device_set.h"
#include "tensorflow/core/common_runtime/eager/copy_to_device_node.h"
#include "tensorflow/core/common_runtime/eager/execute.h"
#include "tensorflow/core/common_runtime/eager/execute_node.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/rendezvous_mgr.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/rendezvous.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/gtl/stl_util.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/public/version.h"

using tensorflow::int64;
using tensorflow::string;

namespace {
bool IsCPU(const tensorflow::Device* d) {
  return d == nullptr || d->tensorflow_gpu_device_info() == nullptr;
}

bool IsXLA(const tensorflow::Device* d) {
  if (d == nullptr) return false;
  const auto& device_type = d->attributes().device_type();
  return device_type.find("XLA") != std::string::npos;
}

string DeviceName(const tensorflow::Device* d) {
  return (d == nullptr) ? "cpu:0" : d->name();
}

#ifdef TENSORFLOW_EAGER_USE_XLA
std::atomic_int_fast64_t func_id_generator(0);
#endif  // TENSORFLOW_EAGER_USE_XLA

}  // namespace

extern "C" {

TFE_ContextOptions* TFE_NewContextOptions() { return new TFE_ContextOptions; }

void TFE_ContextOptionsSetConfig(TFE_ContextOptions* options, const void* proto,
                                 size_t proto_len, TF_Status* status) {
  TF_SetConfig(&options->session_options, proto, proto_len, status);
}

void TFE_ContextOptionsSetAsync(TFE_ContextOptions* options,
                                unsigned char async) {
  options->async = async;
}

void TFE_ContextOptionsSetDevicePlacementPolicy(
    TFE_ContextOptions* options, TFE_ContextDevicePlacementPolicy policy) {
  options->policy = policy;
}

TF_CAPI_EXPORT extern void TFE_ContextSetAsyncForThread(TFE_Context* ctx,
                                                        unsigned char async,
                                                        TF_Status* status) {
  status->status = ctx->context.SetAsyncForThread(async);
}

void TFE_DeleteContextOptions(TFE_ContextOptions* options) { delete options; }

TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) {
  std::vector<tensorflow::Device*> devices;
  status->status = tensorflow::DeviceFactory::AddDevices(
      opts->session_options.options, "/job:localhost/replica:0/task:0",
      &devices);
  if (!status->status.ok()) {
    return nullptr;
  }
  std::unique_ptr<tensorflow::DeviceMgr> device_mgr(
      new tensorflow::DeviceMgr(devices));
  tensorflow::Rendezvous* r =
      new tensorflow::IntraProcessRendezvous(device_mgr.get());
  return new TFE_Context(opts->session_options.options, opts->policy,
                         opts->async, std::move(device_mgr), r);
}
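
// A minimal usage sketch for the context functions above (illustrative
// comment only, using the public C API declared in c_api.h and eager/c_api.h):
//
//   TF_Status* status = TF_NewStatus();
//   TFE_ContextOptions* opts = TFE_NewContextOptions();
//   TFE_ContextOptionsSetAsync(opts, 0);  // synchronous execution
//   TFE_Context* ctx = TFE_NewContext(opts, status);
//   TFE_DeleteContextOptions(opts);
//   if (TF_GetCode(status) != TF_OK) { /* handle the error */ }
//   // ... execute ops ...
//   TFE_DeleteContext(ctx, status);
//   TF_DeleteStatus(status);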

void TFE_DeleteContext(TFE_Context* ctx, TF_Status* status) {
  delete ctx;
}

TF_DeviceList* TFE_ContextListDevices(TFE_Context* ctx, TF_Status* status) {
  TF_DeviceList* list = new TF_DeviceList;
  ctx->context.device_mgr()->ListDeviceAttributes(&list->response);
  return list;
}

void TFE_ContextClearCaches(TFE_Context* ctx) { ctx->context.ClearCaches(); }

void TFE_ContextSetThreadLocalDevicePlacementPolicy(
    TFE_Context* ctx, TFE_ContextDevicePlacementPolicy policy) {
  ctx->context.SetThreadLocalDevicePlacementPolicy(
      static_cast<tensorflow::ContextDevicePlacementPolicy>(policy));
}

// Note: this function looks up a thread local policy. So it should be called
// in the appropriate client thread. In particular, in async mode, it may not
// be safe to call this function from the async EagerExecutor threads.
extern TFE_ContextDevicePlacementPolicy TFE_ContextGetDevicePlacementPolicy(
    TFE_Context* ctx) {
  return static_cast<TFE_ContextDevicePlacementPolicy>(
      ctx->context.GetDevicePlacementPolicy());
}

void TFE_ContextAsyncWait(TFE_Context* ctx, TF_Status* status) {
  status->status = ctx->context.AsyncWait();
}

void TFE_ContextGetStatus(TFE_Context* ctx, TF_Status* status) {
  status->status = ctx->context.GetStatus();
}

void TFE_ContextAsyncClearError(TFE_Context* ctx) {
  ctx->context.ClearAsyncError();
}

TFE_TensorHandle* TFE_NewTensorHandle(TF_Tensor* t, TF_Status* status) {
  tensorflow::Tensor tensor;
  status->status = tensorflow::TF_TensorToTensor(t, &tensor);
  if (!status->status.ok()) return nullptr;
  return new TFE_TensorHandle(tensor, nullptr, nullptr);
}
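
// Illustrative sketch (comment only) of building a handle from host data;
// TF_AllocateTensor and TF_TensorData come from tensorflow/c/c_api.h. The
// TF_Tensor can be deleted afterwards since the handle holds its own
// reference to the underlying buffer:
//
//   int64_t dims[] = {2, 2};
//   TF_Tensor* t = TF_AllocateTensor(TF_FLOAT, dims, 2, 4 * sizeof(float));
//   float* data = static_cast<float*>(TF_TensorData(t));
//   for (int i = 0; i < 4; ++i) data[i] = 1.0f;
//   TFE_TensorHandle* h = TFE_NewTensorHandle(t, status);
//   TF_DeleteTensor(t);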

void TFE_DeleteTensorHandle(TFE_TensorHandle* h) {
  DCHECK(h);
  if (h->handle) {
    h->handle->Unref();
  }
  delete h;
}

TF_DataType TFE_TensorHandleDataType(TFE_TensorHandle* h) {
  return static_cast<TF_DataType>(h->handle->dtype);
}

int TFE_TensorHandleNumDims(TFE_TensorHandle* h, TF_Status* status) {
  const tensorflow::Tensor* t = nullptr;
  status->status = h->handle->Tensor(&t);
  return t == nullptr ? 0 : t->dims();
}

int64_t TFE_TensorHandleDim(TFE_TensorHandle* h, int dim_index,
                            TF_Status* status) {
  const tensorflow::Tensor* t = nullptr;
  status->status = h->handle->Tensor(&t);
  return t == nullptr ? 0 : t->dim_size(dim_index);
}

const char* TFE_TensorHandleDeviceName(TFE_TensorHandle* h, TF_Status* status) {
  tensorflow::Device* d = nullptr;
  status->status = h->handle->OpDevice(&d);
  return (d == nullptr) ? "/job:localhost/replica:0/task:0/device:CPU:0"
                        : d->name().c_str();
}

TF_Tensor* TFE_TensorHandleResolve(TFE_TensorHandle* h, TF_Status* status) {
  // TODO(agarwal): move this implementation inside TFE_TensorHandle.
  tensorflow::Device* d = nullptr;
  tensorflow::Device* op_device = nullptr;
  const tensorflow::Tensor* t = nullptr;
  status->status = h->handle->TensorAndDevice(&t, &d, &op_device);
  if (!status->status.ok()) return nullptr;
  tensorflow::TensorHandle* h_cpu = nullptr;
  if (!IsCPU(d)) {
    status->status = h->handle->CopyToDevice(
        h->handle->Context(), h->handle->Context()->HostCPU(), &h_cpu);
    if (!status->status.ok()) {
      return nullptr;
    }
    status->status = h_cpu->TensorAndDevice(&t, &d, &op_device);
    if (!status->status.ok()) {
      h_cpu->Unref();
      return nullptr;
    }
  }
  TF_Tensor* retval = tensorflow::TF_TensorFromTensor(*t, status);
  if (h_cpu != nullptr) {
    h_cpu->Unref();
  }
  return retval;
}
}  // extern "C"

extern "C" {

TFE_Op* TFE_NewOp(TFE_Context* ctx, const char* op_or_function_name,
                  TF_Status* status) {
  const char* name = op_or_function_name;  // Shorthand
  const tensorflow::AttrTypeMap* types;
  status->status = tensorflow::AttrTypeMapForOp(name, &types);
  if (status->status.ok()) return new TFE_Op(ctx, name, types);
  if (TF_GetCode(status) == TF_NOT_FOUND) {
    if (ctx->context.FindFunctionByName(name)) {
      status->status = tensorflow::Status::OK();
      return new TFE_Op(ctx, name, nullptr);
    }
  }
  return nullptr;
}
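
// Illustrative sketch (comment only): the same entry point resolves both
// primitive op names and the names of functions previously registered via
// TFE_ContextAddFunctionDef (defined further below). "MyRegisteredFn" is a
// hypothetical function name:
//
//   TFE_Op* matmul = TFE_NewOp(ctx, "MatMul", status);        // primitive op
//   TFE_Op* call = TFE_NewOp(ctx, "MyRegisteredFn", status);  // function call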

void TFE_DeleteOp(TFE_Op* op) { delete op; }

void TFE_OpSetDevice(TFE_Op* op, const char* device_name, TF_Status* status) {
  tensorflow::Device* d = nullptr;
  if (device_name != nullptr && strlen(device_name) > 0) {
    status->status = op->ctx->context.FindDeviceByName(device_name, &d);
  }
  op->device = d;
}

const char* TFE_OpGetDevice(TFE_Op* op, TF_Status* status) {
  tensorflow::Device* device =
      (op->device == nullptr) ? op->ctx->context.HostCPU() : op->device;
  return device->name().c_str();
}

void TFE_OpSetXLACompilation(TFE_Op* op, unsigned char enable) {
  op->use_xla = enable;
#ifndef TENSORFLOW_EAGER_USE_XLA
  LOG(WARNING) << "This call is a no-op, as the TensorFlow library is not "
                  "built with XLA support.";
#endif  // TENSORFLOW_EAGER_USE_XLA
}

void TFE_OpAddInput(TFE_Op* op, TFE_TensorHandle* h, TF_Status* status) {
  h->handle->Ref();
  op->inputs.push_back(h->handle);
  op->attrs.NumInputs(op->inputs.size());
}

TF_AttrType TFE_OpGetAttrType(TFE_Op* op, const char* attr_name,
                              unsigned char* is_list, TF_Status* status) {
  TF_AttrType ret;
  if (op->is_function()) {
    status->status = tensorflow::errors::Unimplemented(
        "TODO(apassos): Support for attributes for TensorFlow functions is not "
        "ready yet.");
    return TF_ATTR_INT;  // The compiler requires that we return something.
  }
  status->status =
      tensorflow::AttrTypeByName(*op->attr_types, attr_name, &ret, is_list);
  return ret;
}

TF_AttrType TFE_OpNameGetAttrType(TFE_Context* ctx,
                                  const char* op_or_function_name,
                                  const char* attr_name, unsigned char* is_list,
                                  TF_Status* status) {
  TF_AttrType ret;
  TFE_Op* op = TFE_NewOp(ctx, op_or_function_name, status);
  if (!status->status.ok()) {
    return TF_ATTR_INT;  // Same dummy return as TFE_OpGetAttrType.
  }
  ret = TFE_OpGetAttrType(op, attr_name, is_list, status);
  TFE_DeleteOp(op);
  return ret;
}

void TFE_OpSetAttrString(TFE_Op* op, const char* attr_name, const char* value) {
  op->attrs.Set(attr_name, value);
}

void TFE_OpSetAttrInt(TFE_Op* op, const char* attr_name, int64_t value) {
  op->attrs.Set(attr_name, static_cast<int64>(value));
}

void TFE_OpSetAttrFloat(TFE_Op* op, const char* attr_name, float value) {
  op->attrs.Set(attr_name, value);
}

void TFE_OpSetAttrBool(TFE_Op* op, const char* attr_name, unsigned char value) {
  op->attrs.Set(attr_name, (value == 0) ? false : true);
}

void TFE_OpSetAttrType(TFE_Op* op, const char* attr_name, TF_DataType value) {
  op->attrs.Set(attr_name, static_cast<tensorflow::DataType>(value));
}

void TFE_OpSetAttrShape(TFE_Op* op, const char* attr_name, const int64_t* dims,
                        const int num_dims, TF_Status* out_status) {
  if (num_dims > tensorflow::TensorShape::MaxDimensions()) {
    TF_SetStatus(out_status, TF_INVALID_ARGUMENT,
                 tensorflow::strings::StrCat(
                     "Value specified for `", attr_name, "` has ", num_dims,
                     " dimensions which is over the limit of ",
                     tensorflow::TensorShape::MaxDimensions(), ".")
                     .c_str());
    return;
  }
  tensorflow::TensorShapeProto proto;
  if (num_dims < 0) {
    proto.set_unknown_rank(true);
  } else {
    for (int d = 0; d < num_dims; ++d) {
      proto.add_dim()->set_size(dims[d]);
    }
  }
  op->attrs.Set(attr_name, proto);
}
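
// Illustrative sketch (comment only): a negative num_dims records an
// unknown-rank shape, otherwise num_dims entries are read from `dims`, with
// -1 marking an individually unknown dimension. "shape" is a placeholder
// attribute name:
//
//   int64_t dims[] = {1, -1, 3};
//   TFE_OpSetAttrShape(op, "shape", dims, 3, status);      // rank 3, dim 1 unknown
//   TFE_OpSetAttrShape(op, "shape", nullptr, -1, status);  // unknown rank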

void TFE_OpSetAttrFunction(TFE_Op* op, const char* attr_name,
                           const TFE_Op* value) {
  tensorflow::AttrValue attr_value;
  tensorflow::NameAttrList* func = attr_value.mutable_func();
  func->set_name(value->name);
  value->attrs.FillAttrValueMap(func->mutable_attr());
  op->attrs.Set(attr_name, attr_value);
}

#define TFE_OP_SET_ATTR_LIST(fn, type)                                \
  void fn(TFE_Op* op, const char* attr_name, const type* values,     \
          int num_values) {                                           \
    op->attrs.Set(attr_name, tensorflow::gtl::ArraySlice<const type>( \
                                 values, num_values));                \
  }
TFE_OP_SET_ATTR_LIST(TFE_OpSetAttrStringList, char*)
TFE_OP_SET_ATTR_LIST(TFE_OpSetAttrFloatList, float)
#undef TFE_OP_SET_ATTR_LIST

void TFE_OpSetAttrIntList(TFE_Op* op, const char* attr_name,
                          const int64_t* values, int num_values) {
  op->attrs.Set(attr_name,
                tensorflow::gtl::ArraySlice<const int64>(
                    reinterpret_cast<const int64*>(values), num_values));
}

void TFE_OpSetAttrTypeList(TFE_Op* op, const char* attr_name,
                           const TF_DataType* values, int num_values) {
  op->attrs.Set(
      attr_name,
      tensorflow::gtl::ArraySlice<const tensorflow::DataType>(
          reinterpret_cast<const tensorflow::DataType*>(values), num_values));
}

void TFE_OpSetAttrBoolList(TFE_Op* op, const char* attr_name,
                           const unsigned char* values, int num_values) {
  std::unique_ptr<bool[]> b(new bool[num_values]);
  for (int i = 0; i < num_values; ++i) {
    b[i] = values[i];
  }
  op->attrs.Set(attr_name,
                tensorflow::gtl::ArraySlice<const bool>(b.get(), num_values));
}

void TFE_OpSetAttrShapeList(TFE_Op* op, const char* attr_name,
                            const int64_t** dims, const int* num_dims,
                            int num_values, TF_Status* out_status) {
  std::unique_ptr<tensorflow::TensorShapeProto[]> proto(
      new tensorflow::TensorShapeProto[num_values]);
  for (int i = 0; i < num_values; ++i) {
    const auto num_dims_i = num_dims[i];

    if (num_dims_i > tensorflow::TensorShape::MaxDimensions()) {
      TF_SetStatus(out_status, TF_INVALID_ARGUMENT,
                   tensorflow::strings::StrCat(
                       "Value specified for `", attr_name, "` has ", num_dims_i,
                       " dimensions which is over the limit of ",
                       tensorflow::TensorShape::MaxDimensions(), ".")
                       .c_str());
      return;
    }
    if (num_dims_i < 0) {
      proto[i].set_unknown_rank(true);
    } else {
      const int64_t* dims_i = dims[i];
      auto proto_i = &proto[i];
      for (int d = 0; d < num_dims_i; ++d) {
        proto_i->add_dim()->set_size(dims_i[d]);
      }
    }
  }
  op->attrs.Set(attr_name,
                tensorflow::gtl::ArraySlice<tensorflow::TensorShapeProto>(
                    proto.get(), num_values));
}

void TFE_OpSetAttrFunctionList(TFE_Op* op, const char* attr_name,
                               const TFE_Op** value, int num_values) {
  std::unique_ptr<tensorflow::NameAttrList[]> funcs(
      new tensorflow::NameAttrList[num_values]);
  for (int i = 0; i < num_values; i++) {
    funcs[i].set_name(value[i]->name);
    value[i]->attrs.FillAttrValueMap(funcs[i].mutable_attr());
  }
  op->attrs.Set(attr_name,
                tensorflow::gtl::ArraySlice<const tensorflow::NameAttrList>(
                    funcs.get(), num_values));
}
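
// Illustrative sketch (comment only) of the list-valued setters above; the
// same pattern is used for the _XlaLaunch attributes built later in this file:
//
//   TF_DataType types[] = {TF_FLOAT, TF_INT32};
//   TFE_OpSetAttrTypeList(op, "Targs", types, 2);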
}  // extern "C"

namespace {

// Initializes the step stats if needed.
void MaybeInitializeStepStats(tensorflow::StepStats* step_stats,
                              tensorflow::EagerContext* ctx) {
  // Lazily initialize the RunMetadata with information about all devices if
  // this is the first call.
  while (step_stats->dev_stats_size() < ctx->devices()->size()) {
    int device_idx = step_stats->dev_stats_size();
    auto* dev_stats = step_stats->add_dev_stats();
    dev_stats->set_device(ctx->devices()->at(device_idx)->name());
  }
}

int StepStatsDeviceIndex(tensorflow::StepStats* step_stats,
                         tensorflow::EagerContext* ctx,
                         tensorflow::Device* device) {
  // Find the current device's index.
  if (device == nullptr) {
    device = ctx->HostCPU();
  }
  for (int i = 0; i < ctx->devices()->size(); ++i) {
    if (ctx->devices()->at(i) == device ||
        ctx->devices()->at(i)->name() == device->name()) {
      return i;
    }
  }
  // TODO(apassos) do not fall back to host CPU if device is unknown.
  return 0;
}

tensorflow::Status ValidateInputTypeAndPlacement(
    tensorflow::EagerContext* ctx, tensorflow::Device* op_device, TFE_Op* op,
    const tensorflow::OpKernel* kernel, tensorflow::RunMetadata* run_metadata) {
  tensorflow::Device* host_device = ctx->HostCPU();
  const tensorflow::MemoryTypeVector& memtypes = kernel->input_memory_types();
  if (memtypes.size() != op->inputs.size()) {
    return tensorflow::errors::InvalidArgument(
        "expected ", memtypes.size(), " inputs, got ", op->inputs.size());
  }
  for (int i = 0; i < op->inputs.size(); ++i) {
    const tensorflow::Device* expected_device =
        memtypes[i] == tensorflow::HOST_MEMORY ? host_device : op_device;
    tensorflow::TensorHandle* handle = op->inputs[i];
    tensorflow::Device* handle_device = nullptr;
    TF_RETURN_IF_ERROR(handle->Device(&handle_device));
    const tensorflow::Device* actual_device =
        handle_device == nullptr ? host_device : handle_device;
    if (expected_device != actual_device) {
      switch (ctx->GetDevicePlacementPolicy()) {
        case tensorflow::DEVICE_PLACEMENT_SILENT_FOR_INT32:
          // TODO(xpan): See if we could bubble python related error up
          // to python level.
          if (handle->dtype == tensorflow::DT_INT32) {
            // Note: enabling silent copies of int32 tensors to match behavior
            // of graph mode.
            break;
          }
          TF_FALLTHROUGH_INTENDED;
        case tensorflow::DEVICE_PLACEMENT_EXPLICIT:
          return tensorflow::errors::InvalidArgument(
              "Tensors on conflicting devices:"
              " cannot compute ",
              op->name, " as input #", i, " was expected to be on ",
              expected_device->name(), " but is actually on ",
              actual_device->name(), " (operation running on ",
              op_device->name(), ")",
              " Tensors can be copied explicitly using .gpu() or .cpu() "
              "methods,"
              " or transparently copied by using tf.enable_eager_execution("
              "device_policy=tfe.DEVICE_PLACEMENT_SILENT). Copying tensors "
              "between devices"
              " may slow down your model");
        case tensorflow::DEVICE_PLACEMENT_WARN:
          LOG(WARNING) << "before computing " << op->name << " input #" << i
                       << " was expected to be on " << expected_device->name()
                       << " but is actually on " << actual_device->name()
                       << " (operation running on " << op_device->name()
                       << "). This triggers a copy which can be a performance "
                          "bottleneck.";
          break;
        case tensorflow::DEVICE_PLACEMENT_SILENT:  // Do nothing.
          break;
      }
      // We are only here if the policy is warn or silent copies, so we should
      // trigger a copy.
      auto pre_time = tensorflow::Env::Default()->NowMicros();
      tensorflow::TensorHandle* copied_tensor = nullptr;
      tensorflow::Status status = tensorflow::EagerCopyToDevice(
          handle, ctx, expected_device->name().c_str(), &copied_tensor);
      if (run_metadata != nullptr) {
        auto* step_stats = run_metadata->mutable_step_stats();
        MaybeInitializeStepStats(step_stats, ctx);
        // Record the sending on the source device for now.
        int device_idx = StepStatsDeviceIndex(step_stats, ctx, handle_device);
        auto* dev_stats = step_stats->mutable_dev_stats(device_idx);
        auto* node_stats = dev_stats->add_node_stats();
        node_stats->set_node_name("_Send");
        node_stats->set_all_start_micros(pre_time);
        node_stats->set_op_end_rel_micros(
            tensorflow::Env::Default()->NowMicros() - pre_time);
      }
      if (!status.ok()) {
        if (copied_tensor != nullptr) copied_tensor->Unref();
        return tensorflow::errors::Internal(
            "Failed copying input tensor from ", actual_device->name(), " to ",
            expected_device->name(), " in order to run ", op->name, ": ",
            status.error_message());
      }
      handle->Unref();
      handle = copied_tensor;
      op->inputs[i] = copied_tensor;
    }
    if (handle->dtype != kernel->input_type(i)) {
      return tensorflow::errors::InvalidArgument(
          "cannot compute ", op->name, " as input #", i,
          " was expected to be a ",
          tensorflow::DataTypeString(kernel->input_type(i)),
          " tensor but is a ", tensorflow::DataTypeString(handle->dtype),
          " tensor");
    }
  }
  return tensorflow::Status::OK();
}

tensorflow::Device* SelectDevice(const tensorflow::NodeDef& ndef,
                                 TFE_Context* ctx, TF_Status* status) {
  tensorflow::DeviceSet ds;
  for (tensorflow::Device* d : *ctx->context.devices()) {
    ds.AddDevice(d);
  }
  tensorflow::DeviceTypeVector final_devices;
  status->status = tensorflow::SupportedDeviceTypesForNode(
      ds.PrioritizedDeviceTypeList(), ndef, &final_devices);
  if (!status->status.ok()) {
    return nullptr;
  }
  if (final_devices.empty()) {
    status->status = tensorflow::errors::Internal(
        "Could not find valid device for node ", ndef.DebugString());
    return nullptr;
  }
  for (tensorflow::Device* d : *ctx->context.devices()) {
    if (d->device_type() == final_devices[0].type_string()) {
      return d;
    }
  }
  status->status = tensorflow::errors::Unknown(
      "Could not find a device for node ", ndef.DebugString());
  return nullptr;
}

#ifdef TENSORFLOW_EAGER_USE_XLA
// Synthesizes and returns a wrapper function over `op`, which must be a
// primitive op (e.g. matmul).
//
// The wrapper function conforms to the function signature expected by
// _XlaLaunchOp, with input params ordered by <constants, (variable) args and
// resources>. For example, if the op has input params <Const1, Arg2, Const3,
// Resource4, Arg5>, they will be reordered to <Const1, Const3, Arg2, Arg5,
// Resource4> as the input params to the synthesized function.
//
// It populates `const_input_types`, `arg_input_types` and
// `op_input_to_func_input` based on the reordering results, so that the caller
// can use them to build an _XlaLaunchOp. On error, it returns NULL and sets
// `status` accordingly.
const tensorflow::FunctionDef* OpToFunction(
    TFE_Op* op, std::vector<TF_DataType>* const_input_types,
    std::vector<TF_DataType>* arg_input_types,
    tensorflow::gtl::FlatMap<int, int>* op_input_to_func_input,
    TF_Status* status) {
  DCHECK(!op->is_function());

  tensorflow::FunctionDef fdef;

  // Get the OpDef of the op we are trying to encapsulate.
  TFE_Context* ctx = op->ctx;
  const tensorflow::OpRegistrationData* op_data;
  {
    status->status = ctx->context.FindFunctionOpData(op->name, &op_data);
    if (!status->status.ok()) {
      return nullptr;
    }
  }
  const tensorflow::OpDef& op_def = op_data->op_def;

  tensorflow::OpDef* signature = fdef.mutable_signature();

  // Handle constant inputs.
  const std::unordered_set<string> const_inputs(
      *tensorflow::XlaOpRegistry::CompileTimeConstantInputs(op->name));

  // First add placeholders for the input args, so that we can refer to them by
  // position in the next loop. Also tally up the resource inputs.
  int num_resource_inputs = 0;
  for (int i = 0; i < op_def.input_arg_size(); ++i) {
    if (op_def.input_arg(i).type() == tensorflow::DT_RESOURCE) {
      ++num_resource_inputs;
    }
    signature->add_input_arg();
  }

  // Now we map the input params from `op_def` to `signature`, where the param
  // ordering for `signature` is: <constants, args, resources>.
  int const_index = 0;
  int arg_index = const_inputs.size();
  int resource_index = op_def.input_arg_size() - num_resource_inputs;
  for (int i = 0; i < op_def.input_arg_size(); ++i) {
    const tensorflow::OpDef::ArgDef& op_input_arg = op_def.input_arg(i);
    tensorflow::OpDef::ArgDef* func_input_arg = nullptr;
    if (const_inputs.find(op_input_arg.name()) != const_inputs.end()) {
      VLOG(1) << "For const input, mapping op input " << i << " to func input "
              << const_index;
      (*op_input_to_func_input)[i] = const_index;
      func_input_arg = signature->mutable_input_arg(const_index++);
      const_input_types->push_back(
          static_cast<TF_DataType>(op->inputs[i]->dtype));
    } else if (op_input_arg.type() == tensorflow::DT_RESOURCE) {
      VLOG(1) << "For resource input, mapping op input " << i
              << " to func input " << resource_index;
      (*op_input_to_func_input)[i] = resource_index;
      func_input_arg = signature->mutable_input_arg(resource_index++);
    } else {
      VLOG(1) << "For arg input, mapping op input " << i << " to func input "
              << arg_index;
      (*op_input_to_func_input)[i] = arg_index;
      func_input_arg = signature->mutable_input_arg(arg_index++);
      arg_input_types->push_back(
          static_cast<TF_DataType>(op->inputs[i]->dtype));
    }

    func_input_arg->set_name(op_input_arg.name());
    func_input_arg->set_type(op->inputs[i]->dtype);
  }
  VLOG(1) << "Added OpDef Inputs: " << fdef.DebugString();

  // Resource args are at the end of the function input params, and we should
  // have iterated over all of them.
  DCHECK_EQ(signature->input_arg_size(), resource_index);

  // Make the synthesized function's name unique.
  signature->set_name(tensorflow::strings::StrCat(
      op_def.name(), func_id_generator.fetch_add(1)));

  // Add the node def and set its input names to match op_def's names.
  const tensorflow::NodeDef& ndef = op->attrs.BuildNodeDef();
  DCHECK_EQ(signature->input_arg_size(), ndef.input_size());
  *fdef.add_node_def() = ndef;
  for (int i = 0; i < op_def.input_arg_size(); ++i) {
    fdef.mutable_node_def(0)->set_input(i, op_def.input_arg(i).name());
  }
  VLOG(1) << "Added NodeDef: " << fdef.DebugString();

  // Fix the output names and set output types.
  for (int i = 0; i < op_def.output_arg_size(); ++i) {
    tensorflow::OpDef::ArgDef* arg = signature->add_output_arg();
    const tensorflow::OpDef::ArgDef& op_def_arg = op_def.output_arg(i);
    const string& out_tensor_name = tensorflow::strings::StrCat(
        ndef.name(), ":", op_def_arg.name(), ":", 0);
    arg->set_name(op_def_arg.name());
    (*fdef.mutable_ret())[op_def_arg.name()] = out_tensor_name;
    const string& type_attr = op_def_arg.type_attr();
    if (!type_attr.empty()) {
      auto i = ndef.attr().find(type_attr);
      if (i == ndef.attr().end()) {
        status->status = tensorflow::errors::InvalidArgument(
            tensorflow::strings::StrCat("Could not find attr ", type_attr,
                                        " in NodeDef ", ndef.DebugString()));
        return nullptr;
      }
      arg->set_type(i->second.type());
    }
  }
  VLOG(1) << "Fixed Output names and all types: " << fdef.DebugString();

  status->status = ctx->context.AddFunctionDef(fdef);
  if (!status->status.ok()) return nullptr;
  const auto ret = ctx->context.FindFunctionDef(signature->name());
  DCHECK(ret != nullptr);
  return ret;
}

// Builds an _XlaLaunchOp as a wrapper over 'op', so that 'op' can be executed
// via XLA.
std::unique_ptr<TFE_Op> BuildXlaLaunch(TFE_Op* op, TF_Status* status) {
  VLOG(1) << "Creating _XlaLaunchOp for TFE_Op " << op->name;
  auto launch_op =
      std::unique_ptr<TFE_Op>(TFE_NewOp(op->ctx, "_XlaLaunch", status));
  if (TF_GetCode(status) != TF_OK) return nullptr;
  if (op->device) {
    TFE_OpSetDevice(launch_op.get(), op->device->name().c_str(), status);
    if (TF_GetCode(status) != TF_OK) return nullptr;
  }

  const tensorflow::FunctionDef* fdef;
  {
    fdef = op->ctx->context.FindFunctionDef(op->name);
  }
  std::vector<TF_DataType> const_input_types;
  std::vector<TF_DataType> arg_input_types;
  tensorflow::gtl::FlatMap<int, int> op_input_to_func_input;
  if (fdef == nullptr) {
    // See if this is a primitive op, and if so create a function for it, so
    // that _XlaLaunchOp can access it.
    fdef = OpToFunction(op, &const_input_types, &arg_input_types,
                        &op_input_to_func_input, status);
    if (!status->status.ok()) return nullptr;
  } else {
    // TODO(hongm): XlaOpRegistry::CompileTimeConstantInputs() does not work
    // for functions, so we need to find another way to handle constant inputs.
    for (int i = const_input_types.size();
         i < fdef->signature().input_arg_size(); ++i) {
      VLOG(1) << "Adding Targs from input arg " << i;
      const tensorflow::OpDef::ArgDef& arg = fdef->signature().input_arg(i);
      arg_input_types.push_back(static_cast<TF_DataType>(arg.type()));
    }
  }
  DCHECK(fdef != nullptr);

  // Copy inputs and their devices.
  // Since input param reordering may have occurred between `op` and
  // `launch_op` via `op_input_to_func_input`, adjust the actual inputs
  // accordingly.
  launch_op->inputs = op->inputs;
  for (tensorflow::TensorHandle* h : launch_op->inputs) {
    h->Ref();
  }
  if (!op_input_to_func_input.empty()) {
    DCHECK_EQ(op->inputs.size(), op_input_to_func_input.size());
    for (int i = 0; i < op_input_to_func_input.size(); ++i) {
      VLOG(1) << "mapping op input " << i << " to func input "
              << op_input_to_func_input[i];

      launch_op->inputs[op_input_to_func_input[i]] = op->inputs[i];
    }
  }
  launch_op->attrs.NumInputs(op->inputs.size());

  TFE_OpSetAttrTypeList(launch_op.get(), "Tconstants",
                        const_input_types.data(), const_input_types.size());

  // Set Targs and Nresources attrs.
  TFE_OpSetAttrTypeList(launch_op.get(), "Targs", arg_input_types.data(),
                        arg_input_types.size());
  const int num_resource_inputs = fdef->signature().input_arg_size() -
                                  const_input_types.size() -
                                  arg_input_types.size();
  TFE_OpSetAttrInt(launch_op.get(), "Nresources", num_resource_inputs);

  // Set Tresults attr.
  std::vector<TF_DataType> tresults;
  for (const tensorflow::OpDef::ArgDef& arg : fdef->signature().output_arg()) {
    tresults.push_back(static_cast<TF_DataType>(arg.type()));
  }
  TFE_OpSetAttrTypeList(launch_op.get(), "Tresults", tresults.data(),
                        tresults.size());

  // Set function attr.
  tensorflow::AttrValue attr_value;
  tensorflow::NameAttrList* func = attr_value.mutable_func();
  func->set_name(fdef->signature().name());
  launch_op->attrs.Set("function", attr_value);

  return launch_op;
}
#endif  // TENSORFLOW_EAGER_USE_XLA

}  // namespace

extern "C" {

void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals,
                 TF_Status* status) {
  TFE_Context* ctx = op->ctx;
  status->status = ctx->context.GetStatus();
  if (!status->status.ok()) {
    return;
  }
#ifdef TENSORFLOW_EAGER_USE_XLA
  std::unique_ptr<TFE_Op> xla_launch_op;
  if (op->use_xla && op->name != "_XlaLaunch") {
    xla_launch_op = BuildXlaLaunch(op, status);
    if (!status->status.ok()) {
      return;
    }
    op = xla_launch_op.get();
  }
#endif  // TENSORFLOW_EAGER_USE_XLA
  // Ensure all resource-touching ops run on the device holding the resource,
  // regardless of anything else that has been specified. This is identical to
  // the graph-mode behavior.
  for (int i = 0; i < op->inputs.size(); ++i) {
    tensorflow::Device* input_op_device = nullptr;
    status->status = op->inputs[i]->OpDevice(&input_op_device);
    if (!status->status.ok()) return;
    VLOG(2) << "for op " << op->name << " input " << i << " "
            << tensorflow::DataTypeString(op->inputs[i]->dtype) << " "
            << (input_op_device == nullptr ? "cpu" : input_op_device->name())
            << " " << (op->device == nullptr ? "cpu" : op->device->name());
    if (op->inputs[i]->dtype == tensorflow::DT_RESOURCE &&
        (input_op_device != op->device || input_op_device == nullptr)) {
      tensorflow::Device* d =
          input_op_device == nullptr ? ctx->context.HostCPU() : input_op_device;
      VLOG(1) << "Changing device of operation " << op->name << " to "
              << d->name() << " because input #" << i
              << " is a resource on that device.";
      op->device = d;
    }
  }
  tensorflow::Device* device = op->device;

  tensorflow::Fprint128 cache_key =
      op->attrs.CacheKey(device == nullptr ? "unspecified" : device->name());
  tensorflow::KernelAndDevice* kernel = ctx->context.GetCachedKernel(cache_key);
  if (kernel == nullptr) {
    const tensorflow::NodeDef& ndef = op->attrs.BuildNodeDef();
    if (device == nullptr) {
      device = SelectDevice(ndef, ctx, status);
      if (!status->status.ok()) {
        return;
      }
    }
    CHECK(device != nullptr);
    if (ctx->context.LogDevicePlacement()) {
      LOG(INFO) << "Executing op " << ndef.op() << " in device "
                << device->name();
    }
    kernel = new tensorflow::KernelAndDevice(ctx->context.GetRendezvous());
    // Knowledge of the implementation of Init (and, in turn,
    // FunctionLibraryRuntime::CreateKernel) tells us that ctx->func_lib_def
    // will be accessed, so grab on to the lock.
    // See WARNING comment in Execute (before kernel->Run) - would be nice to
    // rework to avoid this subtlety.
    tensorflow::tf_shared_lock l(*ctx->context.FunctionsMu());
    status->status = tensorflow::KernelAndDevice::Init(
        ndef, ctx->context.func_lib(device), kernel);
    if (!status->status.ok()) {
      delete kernel;
      return;
    }
    // Update output_dtypes inside `kernel`.
    const tensorflow::OpDef* op_def = nullptr;
    const tensorflow::FunctionDef* function_def =
        ctx->context.FuncLibDef()->Find(ndef.op());
    if (function_def != nullptr) {
      op_def = &(function_def->signature());
    }
    if (op_def == nullptr) {
      status->status = OpDefForOp(ndef.op().c_str(), &op_def);
      if (!status->status.ok()) {
        return;
      }
    }
    tensorflow::DataTypeVector input_dtypes;
    status->status = InOutTypesForNode(ndef, *op_def, &input_dtypes,
                                       kernel->mutable_output_dtypes());
    if (!status->status.ok()) {
      return;
    }
    ctx->context.AddKernelToCache(cache_key, kernel);
  }
  const tensorflow::DataTypeVector& output_dtypes = kernel->output_dtypes();
  const int output_dtypes_size = output_dtypes.size();
  if (output_dtypes_size > *num_retvals) {
    TF_SetStatus(status, TF_INVALID_ARGUMENT,
                 tensorflow::strings::StrCat("Expecting ", output_dtypes.size(),
                                             " outputs, but *num_retvals is ",
                                             *num_retvals)
                     .c_str());
    return;
  }
  *num_retvals = output_dtypes_size;
  if (device == nullptr) {
    // TODO(apassos) debug how the assignment below might return a different
    // device from the one requested above.
    device = kernel->device();
  }
  status->status = ValidateInputTypeAndPlacement(
      &ctx->context, device, op, kernel->kernel(),
      ctx->context.ShouldStoreMetadata() ? ctx->context.RunMetadataProto()
                                         : nullptr);
  if (!status->status.ok()) return;
  std::unique_ptr<tensorflow::NodeExecStats> maybe_stats;
  if (ctx->context.ShouldStoreMetadata()) {
    maybe_stats.reset(new tensorflow::NodeExecStats);
    maybe_stats->set_node_name(op->name);
    maybe_stats->set_all_start_micros(tensorflow::Env::Default()->NowMicros());
    maybe_stats->set_op_start_rel_micros(0);
    maybe_stats->set_scheduled_micros(tensorflow::Env::Default()->NowMicros());
    // TODO(apassos) track referenced tensors
  }
  if (ctx->context.Async()) {
    // Note that for async mode, execution order will make sure that all
    // input handles are ready before executing them.
    // TODO(agarwal): Consider executing "cheap" kernels inline for performance.
    tensorflow::gtl::InlinedVector<tensorflow::TensorHandle*, 2> handle_retvals(
        *num_retvals);
    tensorflow::uint64 id = op->ctx->context.NextId();
    for (int i = 0; i < *num_retvals; ++i) {
      tensorflow::TensorHandle* h =
          new tensorflow::TensorHandle(id, output_dtypes[i], &op->ctx->context);
      retvals[i] = new TFE_TensorHandle(h);
      handle_retvals[i] = h;
    }
    tensorflow::EagerNode* node = new tensorflow::ExecuteNode(
        id, &op->ctx->context, op->device, op->inputs, kernel,
        maybe_stats.release(), output_dtypes, handle_retvals);
    ctx->context.ExecutorAdd(node);
  } else {
    // Execute checks whether retvals[i] is nullptr to decide if it needs to
    // allocate the handle itself.
    std::vector<tensorflow::TensorHandle*> handle_retvals(*num_retvals,
                                                          nullptr);
    status->status = tensorflow::EagerExecute(
        &op->ctx->context, op->device, op->inputs, kernel, maybe_stats.get(),
        handle_retvals.data(), *num_retvals);
    for (int i = 0; i < *num_retvals; ++i) {
      retvals[i] = new TFE_TensorHandle(handle_retvals[i]);
    }
  }
}
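
// Illustrative end-to-end sketch (comment only) of running a single op through
// TFE_Execute; `ctx`, `a` and `b` are assumed to be a valid context and two
// float matrix handles created as shown earlier in this file:
//
//   TFE_Op* matmul = TFE_NewOp(ctx, "MatMul", status);
//   TFE_OpSetAttrType(matmul, "T", TF_FLOAT);
//   TFE_OpAddInput(matmul, a, status);
//   TFE_OpAddInput(matmul, b, status);
//   TFE_TensorHandle* retvals[1] = {nullptr};
//   int num_retvals = 1;
//   TFE_Execute(matmul, &retvals[0], &num_retvals, status);
//   TFE_DeleteOp(matmul);
//   TF_Tensor* result = TFE_TensorHandleResolve(retvals[0], status);
//   // ... read the output via TF_TensorData(result) ...
//   TF_DeleteTensor(result);
//   TFE_DeleteTensorHandle(retvals[0]);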

TFE_TensorHandle* TFE_TensorHandleCopyToDevice(TFE_TensorHandle* h,
                                               TFE_Context* ctx,
                                               const char* device_name,
                                               TF_Status* status) {
  tensorflow::TensorHandle* handle;
  status->status = tensorflow::EagerCopyToDevice(h->handle, &ctx->context,
                                                 device_name, &handle);
  if (status->status.ok()) {
    return new TFE_TensorHandle(handle);
  }
  return nullptr;
}

void TFE_ContextAddFunctionDef(TFE_Context* ctx,
                               const char* serialized_function_def, size_t size,
                               TF_Status* status) {
  tensorflow::FunctionDef function_def;
  if (!function_def.ParseFromArray(serialized_function_def, size)) {
    status->status =
        tensorflow::errors::InvalidArgument("Invalid FunctionDef proto");
    return;
  }
  status->status = ctx->context.AddFunctionDef(function_def);
}
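
// Illustrative sketch (comment only), assuming `fdef` is a
// tensorflow::FunctionDef already built by the caller:
//
//   string serialized = fdef.SerializeAsString();
//   TFE_ContextAddFunctionDef(ctx, serialized.data(), serialized.size(),
//                             status);
//   // The function can then be invoked by name via TFE_NewOp + TFE_Execute.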

void TFE_ContextAddFunction(TFE_Context* ctx, TF_Function* function,
                            TF_Status* status) {
  status->status = ctx->context.AddFunctionDef(function->fdef);
}

void TFE_ContextEnableRunMetadata(TFE_Context* ctx) {
  ctx->context.SetShouldStoreMetadata(true);
}

void TFE_ContextDisableRunMetadata(TFE_Context* ctx) {
  ctx->context.SetShouldStoreMetadata(false);
}

}  // extern "C"

TFE_TensorHandle* TFE_NewTensorHandle(const tensorflow::Tensor& t) {
  return new TFE_TensorHandle(t, nullptr, nullptr);
}

const tensorflow::Tensor* TFE_TensorHandleUnderlyingTensorInHostMemory(
    TFE_TensorHandle* h, TF_Status* status) {
  tensorflow::Device* d = nullptr;
  tensorflow::Device* op_device = nullptr;
  const tensorflow::Tensor* t = nullptr;
  status->status = h->handle->TensorAndDevice(&t, &d, &op_device);
  if (!status->status.ok()) return nullptr;
  if (d != nullptr) {
    status->status = tensorflow::errors::FailedPrecondition(
        "TFE_TensorHandle is placed in device (not host) memory. Cannot return "
        "a tensorflow::Tensor");
    return nullptr;
  }
  return t;
}

void TFE_ContextExportRunMetadata(TFE_Context* ctx, TF_Buffer* buf,
                                  TF_Status* status) {
  TFE_ContextAsyncWait(ctx, status);
  if (!status->status.ok()) return;
  tensorflow::mutex_lock ml(*ctx->context.MetadataMu());
  status->status = MessageToBuffer(*ctx->context.RunMetadataProto(), buf);
  ctx->context.RunMetadataProto()->Clear();
}
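
// Illustrative sketch (comment only) of collecting step stats around a few
// eager executions:
//
//   TFE_ContextEnableRunMetadata(ctx);
//   // ... run some ops with TFE_Execute ...
//   TF_Buffer* buf = TF_NewBuffer();
//   TFE_ContextExportRunMetadata(ctx, buf, status);
//   // buf->data now holds a serialized tensorflow.RunMetadata proto.
//   TF_DeleteBuffer(buf);
//   TFE_ContextDisableRunMetadata(ctx);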

namespace {
TFE_Op* GetFunc(TFE_Context* ctx, const tensorflow::NameAttrList& func,
                TF_Status* status) {
  TFE_Op* func_op = TFE_NewOp(ctx, func.name().data(), status);
  for (const auto& attr : func.attr()) {
    if (TF_GetCode(status) != TF_OK) return nullptr;
    SetOpAttrValueScalar(ctx, func_op, attr.second, attr.first.data(), status);
    if (TF_GetCode(status) != TF_OK) return nullptr;
  }
  return func_op;
}
}  // namespace

namespace tensorflow {
void SetOpAttrValueScalar(TFE_Context* ctx, TFE_Op* op,
                          const tensorflow::AttrValue& default_value,
                          const char* attr_name, TF_Status* status) {
  switch (default_value.value_case()) {
    case tensorflow::AttrValue::kS:
      TFE_OpSetAttrString(op, attr_name, default_value.s().data());
      break;
    case tensorflow::AttrValue::kI:
      TFE_OpSetAttrInt(op, attr_name, static_cast<int64_t>(default_value.i()));
      break;
    case tensorflow::AttrValue::kF:
      TFE_OpSetAttrFloat(op, attr_name, default_value.f());
      break;
    case tensorflow::AttrValue::kB:
      TFE_OpSetAttrBool(op, attr_name, default_value.b());
      break;
    case tensorflow::AttrValue::kType:
      TFE_OpSetAttrType(op, attr_name,
                        static_cast<TF_DataType>(default_value.type()));
      break;
    case tensorflow::AttrValue::kShape: {
      const auto& tensor_shape = default_value.shape();
      if (tensor_shape.unknown_rank()) {
        TFE_OpSetAttrShape(op, attr_name, nullptr, -1, status);
      } else {
        const auto num_dims = tensor_shape.dim_size();
        std::unique_ptr<int64_t[]> dims(new int64_t[num_dims]);
        for (int i = 0; i < num_dims; ++i) {
          dims[i] = tensor_shape.dim(i).size();
        }
        TFE_OpSetAttrShape(op, attr_name, dims.get(), num_dims, status);
      }
    } break;
    case tensorflow::AttrValue::kFunc: {
      const auto func_op = GetFunc(ctx, default_value.func(), status);
      if (TF_GetCode(status) != TF_OK) return;
      // TODO(nareshmodi): TFE_OpSetAttrFunction and TFE_OpSetAttrFunctionList
      // require a TFE_Op* and just convert it internally to a NameAttrList, so
      // consider adding an overload to the C API to make this case easier.
      TFE_OpSetAttrFunction(op, attr_name, func_op);
    } break;
    case tensorflow::AttrValue::kList:
      TF_FALLTHROUGH_INTENDED;
    case tensorflow::AttrValue::kTensor:
      TF_FALLTHROUGH_INTENDED;
    case tensorflow::AttrValue::kPlaceholder:
      TF_FALLTHROUGH_INTENDED;
    case tensorflow::AttrValue::VALUE_NOT_SET:
      TF_SetStatus(
          status, TF_UNIMPLEMENTED,
          tensorflow::strings::StrCat("Unable to set attr from default value: ",
                                      default_value.DebugString())
              .data());
  }
}
}  // namespace tensorflow

TFE_Op::~TFE_Op() {
  for (tensorflow::TensorHandle* h : inputs) {
    h->Unref();
  }
}