1 | /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "tensorflow/c/eager/c_api.h" |
17 | |
18 | #include <algorithm> |
19 | #include <cstddef> |
20 | #include <memory> |
21 | #include <string> |
22 | #include <vector> |
23 | |
24 | #include "tensorflow/c/c_api.h" |
25 | #include "tensorflow/c/c_api_internal.h" |
26 | #include "tensorflow/c/eager/c_api_internal.h" |
27 | #include "tensorflow/c/eager/runtime.h" |
28 | #ifdef TENSORFLOW_EAGER_USE_XLA |
29 | #include "tensorflow/compiler/tf2xla/xla_op_registry.h" |
30 | #endif // TENSORFLOW_EAGER_USE_XLA |
31 | #include "tensorflow/core/common_runtime/copy_tensor.h" |
32 | #include "tensorflow/core/common_runtime/device_factory.h" |
33 | #include "tensorflow/core/common_runtime/device_mgr.h" |
34 | #include "tensorflow/core/common_runtime/device_set.h" |
35 | #include "tensorflow/core/common_runtime/eager/copy_to_device_node.h" |
36 | #include "tensorflow/core/common_runtime/eager/execute.h" |
37 | #include "tensorflow/core/common_runtime/eager/execute_node.h" |
38 | #include "tensorflow/core/common_runtime/function.h" |
39 | #include "tensorflow/core/common_runtime/rendezvous_mgr.h" |
40 | #include "tensorflow/core/framework/node_def_util.h" |
41 | #include "tensorflow/core/framework/rendezvous.h" |
42 | #include "tensorflow/core/framework/tensor_shape.pb.h" |
43 | #include "tensorflow/core/framework/types.h" |
44 | #include "tensorflow/core/lib/core/refcount.h" |
45 | #include "tensorflow/core/lib/gtl/flatmap.h" |
46 | #include "tensorflow/core/lib/gtl/map_util.h" |
47 | #include "tensorflow/core/lib/gtl/stl_util.h" |
48 | #include "tensorflow/core/platform/env.h" |
49 | #include "tensorflow/core/platform/mutex.h" |
50 | #include "tensorflow/core/platform/thread_annotations.h" |
51 | #include "tensorflow/core/public/version.h" |
52 | |
53 | using tensorflow::int64; |
54 | using tensorflow::string; |
55 | |
56 | namespace { |
57 | bool IsCPU(const tensorflow::Device* d) { |
58 | return d == nullptr || d->tensorflow_gpu_device_info() == nullptr; |
59 | } |
60 | |
61 | bool IsXLA(const tensorflow::Device* d) { |
62 | if (d == nullptr) return false; |
63 | const auto& device_type = d->attributes().device_type(); |
  return device_type.find("XLA") != std::string::npos;
65 | } |
66 | |
67 | string DeviceName(const tensorflow::Device* d) { |
68 | return (d == nullptr) ? "cpu:0" : d->name(); |
69 | } |
70 | |
71 | #ifdef TENSORFLOW_EAGER_USE_XLA |
72 | std::atomic_int_fast64_t func_id_generator(0); |
73 | #endif // TENSORFLOW_EAGER_USE_XLA |
74 | |
75 | } // namespace |
76 | |
77 | extern "C" { |
78 | |
79 | TFE_ContextOptions* TFE_NewContextOptions() { return new TFE_ContextOptions; } |
80 | |
81 | void TFE_ContextOptionsSetConfig(TFE_ContextOptions* options, const void* proto, |
82 | size_t proto_len, TF_Status* status) { |
83 | TF_SetConfig(&options->session_options, proto, proto_len, status); |
84 | } |
85 | |
86 | void TFE_ContextOptionsSetAsync(TFE_ContextOptions* options, |
87 | unsigned char async) { |
88 | options->async = async; |
89 | } |
90 | void TFE_ContextOptionsSetDevicePlacementPolicy( |
91 | TFE_ContextOptions* options, TFE_ContextDevicePlacementPolicy policy) { |
92 | options->policy = policy; |
93 | } |
94 | |
95 | TF_CAPI_EXPORT extern void TFE_ContextSetAsyncForThread(TFE_Context* ctx, |
96 | unsigned char async, |
97 | TF_Status* status) { |
98 | status->status = ctx->context.SetAsyncForThread(async); |
99 | } |
100 | |
101 | void TFE_DeleteContextOptions(TFE_ContextOptions* options) { delete options; } |
102 | |
103 | TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) { |
104 | std::vector<tensorflow::Device*> devices; |
105 | status->status = tensorflow::DeviceFactory::AddDevices( |
      opts->session_options.options, "/job:localhost/replica:0/task:0",
107 | &devices); |
108 | if (!status->status.ok()) { |
109 | return nullptr; |
110 | } |
111 | std::unique_ptr<tensorflow::DeviceMgr> device_mgr( |
112 | new tensorflow::DeviceMgr(devices)); |
113 | tensorflow::Rendezvous* r = |
114 | new tensorflow::IntraProcessRendezvous(device_mgr.get()); |
115 | return new TFE_Context(opts->session_options.options, opts->policy, |
116 | opts->async, std::move(device_mgr), r); |
117 | } |
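
// Illustrative usage sketch (an assumption for documentation, not part of the
// API surface itself): creating and tearing down an eager context. Error
// checking between calls is elided for brevity.
//
//   TF_Status* status = TF_NewStatus();
//   TFE_ContextOptions* opts = TFE_NewContextOptions();
//   TFE_ContextOptionsSetAsync(opts, 0);
//   TFE_Context* ctx = TFE_NewContext(opts, status);
//   TFE_DeleteContextOptions(opts);
//   // ... build and execute ops against `ctx` ...
//   TFE_DeleteContext(ctx, status);
//   TF_DeleteStatus(status);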
118 | |
119 | void TFE_DeleteContext(TFE_Context* ctx, TF_Status* status) { |
120 | delete ctx; |
121 | } |
122 | |
123 | TF_DeviceList* TFE_ContextListDevices(TFE_Context* ctx, TF_Status* status) { |
124 | TF_DeviceList* list = new TF_DeviceList; |
125 | ctx->context.device_mgr()->ListDeviceAttributes(&list->response); |
126 | return list; |
127 | } |
128 | |
129 | void TFE_ContextClearCaches(TFE_Context* ctx) { ctx->context.ClearCaches(); } |
130 | |
131 | void TFE_ContextSetThreadLocalDevicePlacementPolicy( |
132 | TFE_Context* ctx, TFE_ContextDevicePlacementPolicy policy) { |
133 | ctx->context.SetThreadLocalDevicePlacementPolicy( |
134 | static_cast<tensorflow::ContextDevicePlacementPolicy>(policy)); |
135 | } |
136 | |
// Note: this function looks up a thread-local policy, so it must be called
// from the appropriate client thread. In particular, in async mode it may not
// be safe to call this function from the asynchronous EagerExecutor threads.
140 | extern TFE_ContextDevicePlacementPolicy TFE_ContextGetDevicePlacementPolicy( |
141 | TFE_Context* ctx) { |
142 | return static_cast<TFE_ContextDevicePlacementPolicy>( |
143 | ctx->context.GetDevicePlacementPolicy()); |
144 | } |
145 | |
146 | void TFE_ContextAsyncWait(TFE_Context* ctx, TF_Status* status) { |
147 | status->status = ctx->context.AsyncWait(); |
148 | } |
149 | |
150 | void TFE_ContextGetStatus(TFE_Context* ctx, TF_Status* status) { |
151 | status->status = ctx->context.GetStatus(); |
152 | } |
153 | |
154 | void TFE_ContextAsyncClearError(TFE_Context* ctx) { |
155 | ctx->context.ClearAsyncError(); |
156 | } |
157 | |
158 | TFE_TensorHandle* TFE_NewTensorHandle(TF_Tensor* t, TF_Status* status) { |
159 | tensorflow::Tensor tensor; |
160 | status->status = tensorflow::TF_TensorToTensor(t, &tensor); |
161 | if (!status->status.ok()) return nullptr; |
162 | return new TFE_TensorHandle(tensor, nullptr, nullptr); |
163 | } |
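
// Illustrative sketch (assumed caller-side usage, error checks elided):
// wrapping a freshly allocated scalar TF_Tensor in a TFE_TensorHandle. The
// value written below is arbitrary.
//
//   TF_Tensor* t = TF_AllocateTensor(TF_FLOAT, /*dims=*/nullptr,
//                                    /*num_dims=*/0, sizeof(float));
//   *static_cast<float*>(TF_TensorData(t)) = 42.0f;
//   TFE_TensorHandle* h = TFE_NewTensorHandle(t, status);
//   TF_DeleteTensor(t);  // The handle keeps the underlying buffer alive.
//   // ... use `h` ...
//   TFE_DeleteTensorHandle(h);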
164 | |
165 | void TFE_DeleteTensorHandle(TFE_TensorHandle* h) { |
166 | DCHECK(h); |
167 | if (h->handle) { |
168 | h->handle->Unref(); |
169 | } |
170 | delete h; |
171 | } |
172 | |
173 | TF_DataType TFE_TensorHandleDataType(TFE_TensorHandle* h) { |
174 | return static_cast<TF_DataType>(h->handle->dtype); |
175 | } |
176 | |
177 | int TFE_TensorHandleNumDims(TFE_TensorHandle* h, TF_Status* status) { |
178 | const tensorflow::Tensor* t = nullptr; |
179 | status->status = h->handle->Tensor(&t); |
180 | return t == nullptr ? 0 : t->dims(); |
181 | } |
182 | |
183 | int64_t TFE_TensorHandleDim(TFE_TensorHandle* h, int dim_index, |
184 | TF_Status* status) { |
185 | const tensorflow::Tensor* t = nullptr; |
186 | status->status = h->handle->Tensor(&t); |
187 | return t == nullptr ? 0 : t->dim_size(dim_index); |
188 | } |
189 | |
190 | const char* TFE_TensorHandleDeviceName(TFE_TensorHandle* h, TF_Status* status) { |
191 | tensorflow::Device* d = nullptr; |
192 | status->status = h->handle->OpDevice(&d); |
193 | return (d == nullptr) ? "/job:localhost/replica:0/task:0/device:CPU:0" |
194 | : d->name().c_str(); |
195 | } |
196 | |
197 | TF_Tensor* TFE_TensorHandleResolve(TFE_TensorHandle* h, TF_Status* status) { |
198 | // TODO(agarwal): move this implementation inside TFE_TensorHandle. |
199 | tensorflow::Device* d = nullptr; |
200 | tensorflow::Device* op_device = nullptr; |
201 | const tensorflow::Tensor* t = nullptr; |
202 | status->status = h->handle->TensorAndDevice(&t, &d, &op_device); |
203 | if (!status->status.ok()) return nullptr; |
204 | tensorflow::TensorHandle* h_cpu = nullptr; |
205 | if (!IsCPU(d)) { |
206 | status->status = h->handle->CopyToDevice( |
207 | h->handle->Context(), h->handle->Context()->HostCPU(), &h_cpu); |
208 | if (!status->status.ok()) { |
209 | return nullptr; |
210 | } |
211 | status->status = h_cpu->TensorAndDevice(&t, &d, &op_device); |
212 | if (!status->status.ok()) { |
213 | h_cpu->Unref(); |
214 | return nullptr; |
215 | } |
216 | } |
217 | TF_Tensor* retval = tensorflow::TF_TensorFromTensor(*t, status); |
218 | if (h_cpu != nullptr) { |
219 | h_cpu->Unref(); |
220 | } |
221 | return retval; |
222 | } |
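
// Illustrative sketch (assumed caller-side usage): reading a handle's contents
// back on the host. Resolve copies device-resident tensors to CPU first, so
// this can be expensive for large tensors.
//
//   TF_Tensor* resolved = TFE_TensorHandleResolve(h, status);
//   if (TF_GetCode(status) == TF_OK) {
//     float value = *static_cast<float*>(TF_TensorData(resolved));
//     TF_DeleteTensor(resolved);
//   }
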
223 | } // extern "C" |
224 | |
225 | extern "C" { |
226 | |
227 | TFE_Op* TFE_NewOp(TFE_Context* ctx, const char* op_or_function_name, |
228 | TF_Status* status) { |
229 | const char* name = op_or_function_name; // Shorthand |
230 | const tensorflow::AttrTypeMap* types; |
231 | status->status = tensorflow::AttrTypeMapForOp(name, &types); |
232 | if (status->status.ok()) return new TFE_Op(ctx, name, types); |
233 | if (TF_GetCode(status) == TF_NOT_FOUND) { |
234 | if (ctx->context.FindFunctionByName(name)) { |
235 | status->status = tensorflow::Status::OK(); |
236 | return new TFE_Op(ctx, name, nullptr); |
237 | } |
238 | } |
239 | return nullptr; |
240 | } |
241 | |
242 | void TFE_DeleteOp(TFE_Op* op) { delete op; } |
243 | |
244 | void TFE_OpSetDevice(TFE_Op* op, const char* device_name, TF_Status* status) { |
245 | tensorflow::Device* d = nullptr; |
246 | if (device_name != nullptr && strlen(device_name) > 0) { |
247 | status->status = op->ctx->context.FindDeviceByName(device_name, &d); |
248 | } |
249 | op->device = d; |
250 | } |
251 | |
252 | const char* TFE_OpGetDevice(TFE_Op* op, TF_Status* status) { |
253 | tensorflow::Device* device = |
254 | (op->device == nullptr) ? op->ctx->context.HostCPU() : op->device; |
255 | return device->name().c_str(); |
256 | } |
257 | |
258 | void TFE_OpSetXLACompilation(TFE_Op* op, unsigned char enable) { |
259 | op->use_xla = enable; |
260 | #ifndef TENSORFLOW_EAGER_USE_XLA |
261 | LOG(WARNING) << "This call is a no-op, as the TensorFlow library is not " |
                  "built with XLA support.";
263 | #endif // TENSORFLOW_EAGER_USE_XLA |
264 | } |
265 | |
266 | void TFE_OpAddInput(TFE_Op* op, TFE_TensorHandle* h, TF_Status* status) { |
267 | h->handle->Ref(); |
268 | op->inputs.push_back(h->handle); |
269 | op->attrs.NumInputs(op->inputs.size()); |
270 | } |
271 | |
272 | TF_AttrType TFE_OpGetAttrType(TFE_Op* op, const char* attr_name, |
273 | unsigned char* is_list, TF_Status* status) { |
274 | TF_AttrType ret; |
275 | if (op->is_function()) { |
276 | status->status = tensorflow::errors::Unimplemented( |
277 | "TODO(apassos): Support for attributes for TensorFlow functions is not " |
        "ready yet.");
279 | return TF_ATTR_INT; // The compiler requires that we return something. |
280 | } |
281 | status->status = |
282 | tensorflow::AttrTypeByName(*op->attr_types, attr_name, &ret, is_list); |
283 | return ret; |
284 | } |
285 | |
286 | TF_AttrType TFE_OpNameGetAttrType(TFE_Context* ctx, |
287 | const char* op_or_function_name, |
288 | const char* attr_name, unsigned char* is_list, |
289 | TF_Status* status) { |
290 | TF_AttrType ret; |
291 | TFE_Op* op = TFE_NewOp(ctx, op_or_function_name, status); |
292 | if (!status->status.ok()) { |
293 | return TF_ATTR_INT; // Same dummy return as TFE_OpGetAttrType. |
294 | } |
295 | ret = TFE_OpGetAttrType(op, attr_name, is_list, status); |
296 | TFE_DeleteOp(op); |
297 | return ret; |
298 | } |
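
// Illustrative sketch (assumed caller-side usage): querying an attribute's
// type by op name before constructing the op. "MatMul"/"transpose_a" are just
// example values; for them this should report TF_ATTR_BOOL with is_list == 0.
//
//   unsigned char is_list = 0;
//   TF_AttrType type =
//       TFE_OpNameGetAttrType(ctx, "MatMul", "transpose_a", &is_list, status);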
299 | |
300 | void TFE_OpSetAttrString(TFE_Op* op, const char* attr_name, const char* value) { |
301 | op->attrs.Set(attr_name, value); |
302 | } |
303 | |
304 | void TFE_OpSetAttrInt(TFE_Op* op, const char* attr_name, int64_t value) { |
305 | op->attrs.Set(attr_name, static_cast<int64>(value)); |
306 | } |
307 | |
308 | void TFE_OpSetAttrFloat(TFE_Op* op, const char* attr_name, float value) { |
309 | op->attrs.Set(attr_name, value); |
310 | } |
311 | |
312 | void TFE_OpSetAttrBool(TFE_Op* op, const char* attr_name, unsigned char value) { |
  op->attrs.Set(attr_name, value != 0);
314 | } |
315 | |
316 | void TFE_OpSetAttrType(TFE_Op* op, const char* attr_name, TF_DataType value) { |
317 | op->attrs.Set(attr_name, static_cast<tensorflow::DataType>(value)); |
318 | } |
319 | |
320 | void TFE_OpSetAttrShape(TFE_Op* op, const char* attr_name, const int64_t* dims, |
321 | const int num_dims, TF_Status* out_status) { |
322 | if (num_dims > tensorflow::TensorShape::MaxDimensions()) { |
323 | TF_SetStatus(out_status, TF_INVALID_ARGUMENT, |
324 | tensorflow::strings::StrCat( |
                     "Value specified for `", attr_name, "` has ", num_dims,
                     " dimensions which is over the limit of ",
                     tensorflow::TensorShape::MaxDimensions(), ".")
328 | .c_str()); |
329 | return; |
330 | } |
331 | tensorflow::TensorShapeProto proto; |
332 | if (num_dims < 0) { |
333 | proto.set_unknown_rank(true); |
334 | } else { |
335 | for (int d = 0; d < num_dims; ++d) { |
336 | proto.add_dim()->set_size(dims[d]); |
337 | } |
338 | } |
339 | op->attrs.Set(attr_name, proto); |
340 | } |
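
// Illustrative sketch (assumed caller-side usage; the attr name "shape" is
// hypothetical): setting a partially known shape. A dimension of -1 means the
// dimension is unknown; passing num_dims < 0 means the rank is unknown.
//
//   const int64_t dims[] = {-1, 28, 28, 3};
//   TFE_OpSetAttrShape(op, "shape", dims, 4, status);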
341 | |
342 | void TFE_OpSetAttrFunction(TFE_Op* op, const char* attr_name, |
343 | const TFE_Op* value) { |
344 | tensorflow::AttrValue attr_value; |
345 | tensorflow::NameAttrList* func = attr_value.mutable_func(); |
346 | func->set_name(value->name); |
347 | value->attrs.FillAttrValueMap(func->mutable_attr()); |
348 | op->attrs.Set(attr_name, attr_value); |
349 | } |
350 | |
351 | #define TFE_OP_SET_ATTR_LIST(fn, type) \ |
352 | void fn(TFE_Op* op, const char* attr_name, const type* values, \ |
353 | int num_values) { \ |
354 | op->attrs.Set(attr_name, tensorflow::gtl::ArraySlice<const type>( \ |
355 | values, num_values)); \ |
356 | } |
357 | TFE_OP_SET_ATTR_LIST(TFE_OpSetAttrStringList, char*) |
358 | TFE_OP_SET_ATTR_LIST(TFE_OpSetAttrFloatList, float) |
359 | #undef TFE_OP_SET_ATTR_LIST |
360 | |
361 | void TFE_OpSetAttrIntList(TFE_Op* op, const char* attr_name, |
362 | const int64_t* values, int num_values) { |
363 | op->attrs.Set(attr_name, |
364 | tensorflow::gtl::ArraySlice<const int64>( |
365 | reinterpret_cast<const int64*>(values), num_values)); |
366 | } |
367 | |
368 | void TFE_OpSetAttrTypeList(TFE_Op* op, const char* attr_name, |
369 | const TF_DataType* values, int num_values) { |
370 | op->attrs.Set( |
371 | attr_name, |
372 | tensorflow::gtl::ArraySlice<const tensorflow::DataType>( |
373 | reinterpret_cast<const tensorflow::DataType*>(values), num_values)); |
374 | } |
375 | |
376 | void TFE_OpSetAttrBoolList(TFE_Op* op, const char* attr_name, |
377 | const unsigned char* values, int num_values) { |
378 | std::unique_ptr<bool[]> b(new bool[num_values]); |
379 | for (int i = 0; i < num_values; ++i) { |
380 | b[i] = values[i]; |
381 | } |
382 | op->attrs.Set(attr_name, |
383 | tensorflow::gtl::ArraySlice<const bool>(b.get(), num_values)); |
384 | } |
385 | |
386 | void TFE_OpSetAttrShapeList(TFE_Op* op, const char* attr_name, |
387 | const int64_t** dims, const int* num_dims, |
388 | int num_values, TF_Status* out_status) { |
389 | std::unique_ptr<tensorflow::TensorShapeProto[]> proto( |
390 | new tensorflow::TensorShapeProto[num_values]); |
391 | for (int i = 0; i < num_values; ++i) { |
392 | const auto num_dims_i = num_dims[i]; |
393 | |
394 | if (num_dims_i > tensorflow::TensorShape::MaxDimensions()) { |
395 | TF_SetStatus(out_status, TF_INVALID_ARGUMENT, |
396 | tensorflow::strings::StrCat( |
                       "Value specified for `", attr_name, "` has ", num_dims_i,
                       " dimensions which is over the limit of ",
                       tensorflow::TensorShape::MaxDimensions(), ".")
400 | .c_str()); |
401 | return; |
402 | } |
403 | if (num_dims_i < 0) { |
404 | proto[i].set_unknown_rank(true); |
405 | } else { |
406 | const int64_t* dims_i = dims[i]; |
407 | auto proto_i = &proto[i]; |
408 | for (int d = 0; d < num_dims_i; ++d) { |
409 | proto_i->add_dim()->set_size(dims_i[d]); |
410 | } |
411 | } |
412 | } |
413 | op->attrs.Set(attr_name, |
414 | tensorflow::gtl::ArraySlice<tensorflow::TensorShapeProto>( |
415 | proto.get(), num_values)); |
416 | } |
417 | |
418 | void TFE_OpSetAttrFunctionList(TFE_Op* op, const char* attr_name, |
419 | const TFE_Op** value, int num_values) { |
420 | std::unique_ptr<tensorflow::NameAttrList[]> funcs( |
421 | new tensorflow::NameAttrList[num_values]); |
422 | for (int i = 0; i < num_values; i++) { |
423 | funcs[i].set_name(value[i]->name); |
424 | value[i]->attrs.FillAttrValueMap(funcs[i].mutable_attr()); |
425 | } |
426 | op->attrs.Set(attr_name, |
427 | tensorflow::gtl::ArraySlice<const tensorflow::NameAttrList>( |
428 | funcs.get(), num_values)); |
429 | } |
430 | } // extern "C" |
431 | |
432 | namespace { |
433 | |
434 | // Initializes the step stats if needed. |
435 | void MaybeInitializeStepStats(tensorflow::StepStats* step_stats, |
436 | tensorflow::EagerContext* ctx) { |
437 | // Lazily initialize the RunMetadata with information about all devices if |
438 | // this is the first call. |
439 | while (step_stats->dev_stats_size() < ctx->devices()->size()) { |
440 | int device_idx = step_stats->dev_stats_size(); |
441 | auto* dev_stats = step_stats->add_dev_stats(); |
442 | dev_stats->set_device(ctx->devices()->at(device_idx)->name()); |
443 | } |
444 | } |
445 | |
446 | int StepStatsDeviceIndex(tensorflow::StepStats* step_stats, |
447 | tensorflow::EagerContext* ctx, |
448 | tensorflow::Device* device) { |
449 | // Find the current device's index. |
450 | if (device == nullptr) { |
451 | device = ctx->HostCPU(); |
452 | } |
453 | for (int i = 0; i < ctx->devices()->size(); ++i) { |
454 | if (ctx->devices()->at(i) == device || |
455 | ctx->devices()->at(i)->name() == device->name()) { |
456 | return i; |
457 | } |
458 | } |
459 | // TODO(apassos) do not fall back to host CPU if device is unknown. |
460 | return 0; |
461 | } |
462 | |
463 | tensorflow::Status ValidateInputTypeAndPlacement( |
464 | tensorflow::EagerContext* ctx, tensorflow::Device* op_device, TFE_Op* op, |
465 | const tensorflow::OpKernel* kernel, tensorflow::RunMetadata* run_metadata) { |
466 | tensorflow::Device* host_device = ctx->HostCPU(); |
467 | const tensorflow::MemoryTypeVector& memtypes = kernel->input_memory_types(); |
468 | if (memtypes.size() != op->inputs.size()) { |
469 | return tensorflow::errors::InvalidArgument( |
        "expected ", memtypes.size(), " inputs, got ", op->inputs.size());
471 | } |
472 | for (int i = 0; i < op->inputs.size(); ++i) { |
473 | const tensorflow::Device* expected_device = |
474 | memtypes[i] == tensorflow::HOST_MEMORY ? host_device : op_device; |
475 | tensorflow::TensorHandle* handle = op->inputs[i]; |
476 | tensorflow::Device* handle_device = nullptr; |
477 | TF_RETURN_IF_ERROR(handle->Device(&handle_device)); |
478 | const tensorflow::Device* actual_device = |
479 | handle_device == nullptr ? host_device : handle_device; |
480 | if (expected_device != actual_device) { |
481 | switch (ctx->GetDevicePlacementPolicy()) { |
482 | case tensorflow::DEVICE_PLACEMENT_SILENT_FOR_INT32: |
483 | // TODO(xpan): See if we could bubble python related error up |
484 | // to python level. |
485 | if (handle->dtype == tensorflow::DT_INT32) { |
486 | // Note: enabling silent copies of int32 tensors to match behavior |
487 | // of graph mode. |
488 | break; |
489 | } |
490 | TF_FALLTHROUGH_INTENDED; |
491 | case tensorflow::DEVICE_PLACEMENT_EXPLICIT: |
492 | return tensorflow::errors::InvalidArgument( |
493 | "Tensors on conflicting devices:" |
              " cannot compute ",
              op->name, " as input #", i, " was expected to be on ",
              expected_device->name(), " but is actually on ",
              actual_device->name(), " (operation running on ",
              op_device->name(), ")",
              " Tensors can be copied explicitly using .gpu() or .cpu() "
              "methods,"
              " or transparently copied by using tf.enable_eager_execution("
              "device_policy=tfe.DEVICE_PLACEMENT_SILENT). Copying tensors "
              "between devices"
              " may slow down your model");
505 | case tensorflow::DEVICE_PLACEMENT_WARN: |
506 | LOG(WARNING) << "before computing " << op->name << " input #" << i |
507 | << " was expected to be on " << expected_device->name() |
508 | << " but is actually on " << actual_device->name() |
509 | << " (operation running on " << op_device->name() |
510 | << "). This triggers a copy which can be a performance " |
                          "bottleneck.";
512 | break; |
513 | case tensorflow::DEVICE_PLACEMENT_SILENT: // Do nothing. |
514 | break; |
515 | } |
516 | // We are only here if the policy is warn or silent copies, so we should |
517 | // trigger a copy. |
518 | auto pre_time = tensorflow::Env::Default()->NowMicros(); |
519 | tensorflow::TensorHandle* copied_tensor = nullptr; |
520 | tensorflow::Status status = tensorflow::EagerCopyToDevice( |
521 | handle, ctx, expected_device->name().c_str(), &copied_tensor); |
522 | if (run_metadata != nullptr) { |
523 | auto* step_stats = run_metadata->mutable_step_stats(); |
524 | MaybeInitializeStepStats(step_stats, ctx); |
525 | // Record the sending on the source device for now. |
526 | int device_idx = StepStatsDeviceIndex(step_stats, ctx, handle_device); |
527 | auto* dev_stats = step_stats->mutable_dev_stats(device_idx); |
528 | auto* node_stats = dev_stats->add_node_stats(); |
        node_stats->set_node_name("_Send");
530 | node_stats->set_all_start_micros(pre_time); |
531 | node_stats->set_op_end_rel_micros( |
532 | tensorflow::Env::Default()->NowMicros() - pre_time); |
533 | } |
534 | if (!status.ok()) { |
535 | if (copied_tensor != nullptr) copied_tensor->Unref(); |
536 | return tensorflow::errors::Internal( |
            "Failed copying input tensor from ", actual_device->name(), " to ",
            expected_device->name(), " in order to run ", op->name, ": ",
539 | status.error_message()); |
540 | } |
541 | handle->Unref(); |
542 | handle = copied_tensor; |
543 | op->inputs[i] = copied_tensor; |
544 | } |
545 | if (handle->dtype != kernel->input_type(i)) { |
546 | return tensorflow::errors::InvalidArgument( |
          "cannot compute ", op->name, " as input #", i,
          " was expected to be a ",
          tensorflow::DataTypeString(kernel->input_type(i)),
          " tensor but is a ", tensorflow::DataTypeString(handle->dtype),
          " tensor");
552 | } |
553 | } |
554 | return tensorflow::Status::OK(); |
555 | } |
556 | |
557 | tensorflow::Device* SelectDevice(const tensorflow::NodeDef& ndef, |
558 | TFE_Context* ctx, TF_Status* status) { |
559 | tensorflow::DeviceSet ds; |
560 | for (tensorflow::Device* d : *ctx->context.devices()) { |
561 | ds.AddDevice(d); |
562 | } |
563 | tensorflow::DeviceTypeVector final_devices; |
564 | status->status = tensorflow::SupportedDeviceTypesForNode( |
565 | ds.PrioritizedDeviceTypeList(), ndef, &final_devices); |
566 | if (!status->status.ok()) { |
567 | return nullptr; |
568 | } |
569 | if (final_devices.empty()) { |
570 | status->status = tensorflow::errors::Internal( |
        "Could not find valid device for node ", ndef.DebugString());
572 | return nullptr; |
573 | } |
574 | for (tensorflow::Device* d : *ctx->context.devices()) { |
575 | if (d->device_type() == final_devices[0].type_string()) { |
576 | return d; |
577 | } |
578 | } |
579 | status->status = tensorflow::errors::Unknown( |
      "Could not find a device for node ", ndef.DebugString());
581 | return nullptr; |
582 | } |
583 | |
584 | |
585 | #ifdef TENSORFLOW_EAGER_USE_XLA |
586 | // Synthesizes and returns a wrapper function over `op`, which must be a |
587 | // primitive op (e.g. matmul). |
588 | // |
589 | // The wrapper function conforms to the function signature expected by |
590 | // _XlaLaunchOp, with input params ordered by <constants, (variable) args and |
591 | // resources>. For example, if the op has input params <Const1, Arg2, Const3, |
592 | // Resource4, Arg5>, they will be reordered to <Const1, Const3, Arg2, Arg5, |
593 | // Resource4> as the input params to the synthesized function. |
594 | // |
595 | // It populates `const_input_types`, `arg_input_types` and |
// `op_input_to_func_input` based on the reordering results, so that the caller
// can use them to build an _XlaLaunchOp. On error, it returns nullptr and sets
598 | // `status` accordingly. |
599 | const tensorflow::FunctionDef* OpToFunction( |
600 | TFE_Op* op, std::vector<TF_DataType>* const_input_types, |
601 | std::vector<TF_DataType>* arg_input_types, |
602 | tensorflow::gtl::FlatMap<int, int>* op_input_to_func_input, |
603 | TF_Status* status) { |
604 | DCHECK(!op->is_function()); |
605 | |
606 | tensorflow::FunctionDef fdef; |
607 | |
608 | // Get the OpDef of the op we are trying to encapsulate. |
609 | TFE_Context* ctx = op->ctx; |
610 | const tensorflow::OpRegistrationData* op_data; |
611 | { |
612 | status->status = ctx->context.FindFunctionOpData(op->name, &op_data); |
613 | if (!status->status.ok()) { |
614 | return nullptr; |
615 | } |
616 | } |
617 | const tensorflow::OpDef& op_def = op_data->op_def; |
618 | |
619 | tensorflow::OpDef* signature = fdef.mutable_signature(); |
620 | |
621 | // Handle constant inputs. |
622 | const std::unordered_set<string> const_inputs( |
623 | *tensorflow::XlaOpRegistry::CompileTimeConstantInputs(op->name)); |
624 | |
  // First add placeholders for the input args, so that we can refer to them
  // by position in the next loop. Also tally up the resource inputs.
627 | int num_resource_inputs = 0; |
628 | for (int i = 0; i < op_def.input_arg_size(); ++i) { |
629 | if (op_def.input_arg(i).type() == tensorflow::DT_RESOURCE) { |
630 | ++num_resource_inputs; |
631 | } |
632 | signature->add_input_arg(); |
633 | } |
634 | |
635 | // Now we map the input params from `op_def` to `signature`, where the param |
636 | // ordering for `signature` is: <constants, args, resources>. |
637 | int const_index = 0; |
638 | int arg_index = const_inputs.size(); |
639 | int resource_index = op_def.input_arg_size() - num_resource_inputs; |
640 | for (int i = 0; i < op_def.input_arg_size(); ++i) { |
641 | const tensorflow::OpDef::ArgDef& op_input_arg = op_def.input_arg(i); |
642 | tensorflow::OpDef::ArgDef* func_input_arg = nullptr; |
643 | if (const_inputs.find(op_input_arg.name()) != const_inputs.end()) { |
644 | VLOG(1) << "For const input, mapping op input " << i << " to func input " |
645 | << const_index; |
646 | (*op_input_to_func_input)[i] = const_index; |
647 | func_input_arg = signature->mutable_input_arg(const_index++); |
648 | const_input_types->push_back( |
649 | static_cast<TF_DataType>(op->inputs[i]->dtype)); |
650 | } else if (op_input_arg.type() == tensorflow::DT_RESOURCE) { |
651 | VLOG(1) << "For resource input, mapping op input " << i |
652 | << " to func input " << resource_index; |
653 | (*op_input_to_func_input)[i] = resource_index; |
654 | func_input_arg = signature->mutable_input_arg(resource_index++); |
655 | } else { |
656 | VLOG(1) << "For arg input, mapping op input " << i << " to func input " |
657 | << arg_index; |
658 | (*op_input_to_func_input)[i] = arg_index; |
659 | func_input_arg = signature->mutable_input_arg(arg_index++); |
660 | arg_input_types->push_back( |
661 | static_cast<TF_DataType>(op->inputs[i]->dtype)); |
662 | } |
663 | |
664 | func_input_arg->set_name(op_input_arg.name()); |
665 | func_input_arg->set_type(op->inputs[i]->dtype); |
666 | } |
667 | VLOG(1) << "Added OpDef Inputs: " << fdef.DebugString(); |
668 | |
669 | // Resources args are at the end of the function input params, and we should |
670 | // have iterated over all of them. |
671 | DCHECK_EQ(signature->input_arg_size(), resource_index); |
672 | |
673 | // Make the synthesized function's name unique. |
674 | signature->set_name(tensorflow::strings::StrCat( |
675 | op_def.name(), func_id_generator.fetch_add(1))); |
676 | |
677 | // Add the node def and set its input names to match op_def's names. |
678 | const tensorflow::NodeDef& ndef = op->attrs.BuildNodeDef(); |
679 | DCHECK_EQ(signature->input_arg_size(), ndef.input_size()); |
680 | *fdef.add_node_def() = ndef; |
681 | for (int i = 0; i < op_def.input_arg_size(); ++i) { |
682 | fdef.mutable_node_def(0)->set_input(i, op_def.input_arg(i).name()); |
683 | } |
684 | VLOG(1) << "Added NodeDef: " << fdef.DebugString(); |
685 | |
686 | // Fix the output names and set output types. |
687 | for (int i = 0; i < op_def.output_arg_size(); ++i) { |
688 | tensorflow::OpDef::ArgDef* arg = signature->add_output_arg(); |
689 | const tensorflow::OpDef::ArgDef& op_def_arg = op_def.output_arg(i); |
690 | const string& out_tensor_name = tensorflow::strings::StrCat( |
        ndef.name(), ":", op_def_arg.name(), ":", 0);
692 | arg->set_name(op_def_arg.name()); |
693 | (*fdef.mutable_ret())[op_def_arg.name()] = out_tensor_name; |
694 | const string& type_attr = op_def_arg.type_attr(); |
695 | if (!type_attr.empty()) { |
696 | auto i = ndef.attr().find(type_attr); |
697 | if (i == ndef.attr().end()) { |
698 | status->status = tensorflow::errors::InvalidArgument( |
            tensorflow::strings::StrCat("Could not find attr ", type_attr,
                                        " in NodeDef ", ndef.DebugString()));
701 | return nullptr; |
702 | } |
703 | arg->set_type(i->second.type()); |
704 | } |
705 | } |
706 | VLOG(1) << "Fixed Output names and all types: " << fdef.DebugString(); |
707 | |
708 | status->status = ctx->context.AddFunctionDef(fdef); |
709 | if (!status->status.ok()) return nullptr; |
710 | const auto ret = ctx->context.FindFunctionDef(signature->name()); |
711 | DCHECK(ret != nullptr); |
712 | return ret; |
713 | } |
714 | |
// Builds an _XlaLaunchOp as a wrapper over 'op', so that 'op' can be executed
716 | // via XLA. |
717 | std::unique_ptr<TFE_Op> BuildXlaLaunch(TFE_Op* op, TF_Status* status) { |
718 | VLOG(1) << "Creating _XlaLaunchOp for TFE_Op " << op->name; |
719 | auto launch_op = |
      std::unique_ptr<TFE_Op>(TFE_NewOp(op->ctx, "_XlaLaunch", status));
721 | if (TF_GetCode(status) != TF_OK) return nullptr; |
722 | if (op->device) { |
723 | TFE_OpSetDevice(launch_op.get(), op->device->name().c_str(), status); |
724 | if (TF_GetCode(status) != TF_OK) return nullptr; |
725 | } |
726 | |
727 | const tensorflow::FunctionDef* fdef; |
728 | { |
729 | fdef = op->ctx->context.FindFunctionDef(op->name); |
730 | } |
731 | std::vector<TF_DataType> const_input_types; |
732 | std::vector<TF_DataType> arg_input_types; |
733 | tensorflow::gtl::FlatMap<int, int> op_input_to_func_input; |
734 | if (fdef == nullptr) { |
735 | // See if this is a primitive op, and if so create a function for it, so |
736 | // that _XlaLaunchOp can access it. |
737 | fdef = OpToFunction(op, &const_input_types, &arg_input_types, |
738 | &op_input_to_func_input, status); |
739 | if (!status->status.ok()) return nullptr; |
740 | } else { |
741 | // TODO(hongm): XlaOpRegistry::CompileTimeConstantInputs() does not work for |
742 | // functions, so we need to find another way to handle constant inputs. |
743 | for (int i = const_input_types.size(); |
744 | i < fdef->signature().input_arg_size(); ++i) { |
745 | VLOG(1) << "Adding Targs from input arg " << i; |
746 | const tensorflow::OpDef::ArgDef& arg = fdef->signature().input_arg(i); |
747 | arg_input_types.push_back(static_cast<TF_DataType>(arg.type())); |
748 | } |
749 | } |
750 | DCHECK(fdef != nullptr); |
751 | |
752 | // Copy inputs and their devices. |
753 | // Since input param reordering may have occurred between `op` and `launch_op` |
754 | // via `op_input_to_func_input`, adjust the actual inputs accordingly. |
755 | launch_op->inputs = op->inputs; |
756 | for (tensorflow::TensorHandle* h : launch_op->inputs) { |
757 | h->Ref(); |
758 | } |
759 | if (!op_input_to_func_input.empty()) { |
760 | DCHECK_EQ(op->inputs.size(), op_input_to_func_input.size()); |
761 | for (int i = 0; i < op_input_to_func_input.size(); ++i) { |
762 | VLOG(1) << "mapping op input " << i << " to func input " |
763 | << op_input_to_func_input[i]; |
764 | |
765 | launch_op->inputs[op_input_to_func_input[i]] = op->inputs[i]; |
766 | } |
767 | } |
768 | launch_op->attrs.NumInputs(op->inputs.size()); |
769 | |
  TFE_OpSetAttrTypeList(launch_op.get(), "Tconstants", const_input_types.data(),
771 | const_input_types.size()); |
772 | |
773 | // Set Targs and Nresources attrs. |
  TFE_OpSetAttrTypeList(launch_op.get(), "Targs", arg_input_types.data(),
775 | arg_input_types.size()); |
776 | const int num_resource_inputs = fdef->signature().input_arg_size() - |
777 | const_input_types.size() - |
778 | arg_input_types.size(); |
  TFE_OpSetAttrInt(launch_op.get(), "Nresources", num_resource_inputs);
780 | |
781 | // Set Tresults attr. |
782 | std::vector<TF_DataType> tresults; |
783 | for (const tensorflow::OpDef::ArgDef& arg : fdef->signature().output_arg()) { |
784 | tresults.push_back(static_cast<TF_DataType>(arg.type())); |
785 | } |
  TFE_OpSetAttrTypeList(launch_op.get(), "Tresults", tresults.data(),
787 | tresults.size()); |
788 | |
789 | // Set function attr. |
790 | tensorflow::AttrValue attr_value; |
791 | tensorflow::NameAttrList* func = attr_value.mutable_func(); |
792 | func->set_name(fdef->signature().name()); |
  launch_op->attrs.Set("function", attr_value);
794 | |
795 | return launch_op; |
796 | } |
797 | #endif // TENSORFLOW_EAGER_USE_XLA |
798 | |
799 | } // namespace |
800 | |
801 | extern "C" { |
802 | |
803 | void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals, |
804 | TF_Status* status) { |
805 | TFE_Context* ctx = op->ctx; |
806 | status->status = ctx->context.GetStatus(); |
807 | if (!status->status.ok()) { |
808 | return; |
809 | } |
810 | #ifdef TENSORFLOW_EAGER_USE_XLA |
811 | std::unique_ptr<TFE_Op> xla_launch_op; |
  if (op->use_xla && op->name != "_XlaLaunch") {
813 | xla_launch_op = BuildXlaLaunch(op, status); |
814 | if (!status->status.ok()) { |
815 | return; |
816 | } |
817 | op = xla_launch_op.get(); |
818 | } |
819 | #endif // TENSORFLOW_EAGER_USE_XLA |
  // Ensure that all resource-touching ops run on the device that holds the
  // resource, regardless of anything else that has been specified. This is
  // identical to the graph-mode behavior.
823 | for (int i = 0; i < op->inputs.size(); ++i) { |
824 | tensorflow::Device* input_op_device = nullptr; |
825 | status->status = op->inputs[i]->OpDevice(&input_op_device); |
826 | if (!status->status.ok()) return; |
827 | VLOG(2) << "for op " << op->name << " input " << i << " " |
828 | << tensorflow::DataTypeString(op->inputs[i]->dtype) << " " |
829 | << (input_op_device == nullptr ? "cpu" : input_op_device->name()) |
830 | << " " << (op->device == nullptr ? "cpu" : op->device->name()); |
831 | if (op->inputs[i]->dtype == tensorflow::DT_RESOURCE && |
832 | (input_op_device != op->device || input_op_device == nullptr)) { |
833 | tensorflow::Device* d = |
834 | input_op_device == nullptr ? ctx->context.HostCPU() : input_op_device; |
835 | VLOG(1) << "Changing device of operation " << op->name << " to " |
836 | << d->name() << " because input #" << i |
              << " is a resource in this device.";
838 | op->device = d; |
839 | } |
840 | } |
841 | tensorflow::Device* device = op->device; |
842 | |
843 | tensorflow::Fprint128 cache_key = |
844 | op->attrs.CacheKey(device == nullptr ? "unspecified" : device->name()); |
845 | tensorflow::KernelAndDevice* kernel = ctx->context.GetCachedKernel(cache_key); |
846 | if (kernel == nullptr) { |
847 | const tensorflow::NodeDef& ndef = op->attrs.BuildNodeDef(); |
848 | if (device == nullptr) { |
849 | device = SelectDevice(ndef, ctx, status); |
850 | if (!status->status.ok()) { |
851 | return; |
852 | } |
853 | } |
854 | CHECK(device != nullptr); |
855 | if (ctx->context.LogDevicePlacement()) { |
856 | LOG(INFO) << "Executing op " << ndef.op() << " in device " |
857 | << device->name(); |
858 | } |
859 | kernel = new tensorflow::KernelAndDevice(ctx->context.GetRendezvous()); |
    // Knowledge of the implementation of Init (and, in turn,
    // FunctionLibraryRuntime::CreateKernel) tells us that ctx->func_lib_def
    // will be accessed, so grab on to the lock.
    // See the WARNING comment in Execute (before kernel->Run); it would be
    // nice to rework this to avoid the subtlety.
865 | tensorflow::tf_shared_lock l(*ctx->context.FunctionsMu()); |
866 | status->status = tensorflow::KernelAndDevice::Init( |
867 | ndef, ctx->context.func_lib(device), kernel); |
868 | if (!status->status.ok()) { |
869 | delete kernel; |
870 | return; |
871 | } |
872 | // Update output_dtypes inside `kernel`. |
873 | const tensorflow::OpDef* op_def = nullptr; |
874 | const tensorflow::FunctionDef* function_def = |
875 | ctx->context.FuncLibDef()->Find(ndef.op()); |
876 | if (function_def != nullptr) { |
877 | op_def = &(function_def->signature()); |
878 | } |
879 | if (op_def == nullptr) { |
880 | status->status = OpDefForOp(ndef.op().c_str(), &op_def); |
881 | if (!status->status.ok()) { |
882 | return; |
883 | } |
884 | } |
885 | tensorflow::DataTypeVector input_dtypes; |
886 | status->status = InOutTypesForNode(ndef, *op_def, &input_dtypes, |
887 | kernel->mutable_output_dtypes()); |
888 | if (!status->status.ok()) { |
889 | return; |
890 | } |
891 | ctx->context.AddKernelToCache(cache_key, kernel); |
892 | } |
893 | const tensorflow::DataTypeVector& output_dtypes = kernel->output_dtypes(); |
894 | const int output_dtypes_size = output_dtypes.size(); |
895 | if (output_dtypes_size > *num_retvals) { |
896 | TF_SetStatus(status, TF_INVALID_ARGUMENT, |
                 tensorflow::strings::StrCat("Expecting ", output_dtypes.size(),
                                             " outputs, but *num_retvals is ",
899 | *num_retvals) |
900 | .c_str()); |
901 | return; |
902 | } |
903 | *num_retvals = output_dtypes_size; |
904 | if (device == nullptr) { |
905 | // TODO(apassos) debug how the assignment below might return a different |
906 | // device from the one requested above. |
907 | device = kernel->device(); |
908 | } |
909 | status->status = ValidateInputTypeAndPlacement( |
910 | &ctx->context, device, op, kernel->kernel(), |
911 | ctx->context.ShouldStoreMetadata() ? ctx->context.RunMetadataProto() |
912 | : nullptr); |
913 | if (!status->status.ok()) return; |
914 | std::unique_ptr<tensorflow::NodeExecStats> maybe_stats; |
915 | if (ctx->context.ShouldStoreMetadata()) { |
916 | maybe_stats.reset(new tensorflow::NodeExecStats); |
917 | maybe_stats->set_node_name(op->name); |
918 | maybe_stats->set_all_start_micros(tensorflow::Env::Default()->NowMicros()); |
919 | maybe_stats->set_op_start_rel_micros(0); |
920 | maybe_stats->set_scheduled_micros(tensorflow::Env::Default()->NowMicros()); |
921 | // TODO(apassos) track referenced tensors |
922 | } |
923 | if (ctx->context.Async()) { |
    // Note that in async mode, the execution order guarantees that all input
    // handles are ready before the kernel executes.
926 | // TODO(agarwal): Consider executing "cheap" kernels inline for performance. |
927 | tensorflow::gtl::InlinedVector<tensorflow::TensorHandle*, 2> handle_retvals( |
928 | *num_retvals); |
929 | tensorflow::uint64 id = op->ctx->context.NextId(); |
930 | for (int i = 0; i < *num_retvals; ++i) { |
931 | tensorflow::TensorHandle* h = |
932 | new tensorflow::TensorHandle(id, output_dtypes[i], &op->ctx->context); |
933 | retvals[i] = new TFE_TensorHandle(h); |
934 | handle_retvals[i] = h; |
935 | } |
936 | tensorflow::EagerNode* node = new tensorflow::ExecuteNode( |
937 | id, &op->ctx->context, op->device, op->inputs, kernel, |
938 | maybe_stats.release(), output_dtypes, handle_retvals); |
939 | ctx->context.ExecutorAdd(node); |
940 | } else { |
    // EagerExecute checks whether retvals[i] is nullptr to decide whether it
    // needs to allocate it.
943 | std::vector<tensorflow::TensorHandle*> handle_retvals(*num_retvals, |
944 | nullptr); |
945 | status->status = tensorflow::EagerExecute( |
946 | &op->ctx->context, op->device, op->inputs, kernel, maybe_stats.get(), |
947 | handle_retvals.data(), *num_retvals); |
948 | for (int i = 0; i < *num_retvals; ++i) { |
949 | retvals[i] = new TFE_TensorHandle(handle_retvals[i]); |
950 | } |
951 | } |
952 | } |
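
// Illustrative end-to-end sketch (assumed caller-side usage, error checks
// elided): executing a single MatMul eagerly. `ctx` and the 2x2 float input
// handle `m` are assumed to have been created as in the earlier sketches.
//
//   TFE_Op* matmul = TFE_NewOp(ctx, "MatMul", status);
//   TFE_OpAddInput(matmul, m, status);
//   TFE_OpAddInput(matmul, m, status);
//   TFE_OpSetAttrType(matmul, "T", TFE_TensorHandleDataType(m));
//   TFE_OpSetAttrBool(matmul, "transpose_a", 0);
//   TFE_OpSetAttrBool(matmul, "transpose_b", 0);
//   TFE_TensorHandle* retvals[1] = {nullptr};
//   int num_retvals = 1;
//   TFE_Execute(matmul, &retvals[0], &num_retvals, status);
//   TFE_DeleteOp(matmul);
//   // retvals[0] now holds the product; resolve or copy it as needed.
//   TFE_DeleteTensorHandle(retvals[0]);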
953 | |
954 | TFE_TensorHandle* TFE_TensorHandleCopyToDevice(TFE_TensorHandle* h, |
955 | TFE_Context* ctx, |
956 | const char* device_name, |
957 | TF_Status* status) { |
958 | tensorflow::TensorHandle* handle; |
959 | status->status = tensorflow::EagerCopyToDevice(h->handle, &ctx->context, |
960 | device_name, &handle); |
961 | if (status->status.ok()) { |
962 | return new TFE_TensorHandle(handle); |
963 | } |
964 | return nullptr; |
965 | } |
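
// Illustrative sketch (assumed caller-side usage): explicitly copying a handle
// to another device. The device string must name a device known to the
// context; the GPU name below is only an example and requires a GPU device.
//
//   TFE_TensorHandle* h_gpu = TFE_TensorHandleCopyToDevice(
//       h, ctx, "/job:localhost/replica:0/task:0/device:GPU:0", status);
//   if (TF_GetCode(status) == TF_OK) {
//     // ... use h_gpu ...
//     TFE_DeleteTensorHandle(h_gpu);
//   }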
966 | |
967 | void TFE_ContextAddFunctionDef(TFE_Context* ctx, |
968 | const char* serialized_function_def, size_t size, |
969 | TF_Status* status) { |
970 | tensorflow::FunctionDef function_def; |
971 | if (!function_def.ParseFromArray(serialized_function_def, size)) { |
972 | status->status = |
        tensorflow::errors::InvalidArgument("Invalid FunctionDef proto");
974 | return; |
975 | } |
976 | status->status = ctx->context.AddFunctionDef(function_def); |
977 | } |
978 | |
979 | void TFE_ContextAddFunction(TFE_Context* ctx, TF_Function* function, |
980 | TF_Status* status) { |
981 | status->status = ctx->context.AddFunctionDef(function->fdef); |
982 | } |
983 | |
984 | void TFE_ContextEnableRunMetadata(TFE_Context* ctx) { |
985 | ctx->context.SetShouldStoreMetadata(true); |
986 | } |
987 | |
988 | void TFE_ContextDisableRunMetadata(TFE_Context* ctx) { |
989 | ctx->context.SetShouldStoreMetadata(false); |
990 | } |
991 | |
992 | } // extern "C" |
993 | |
994 | TFE_TensorHandle* TFE_NewTensorHandle(const tensorflow::Tensor& t) { |
995 | return new TFE_TensorHandle(t, nullptr, nullptr); |
996 | } |
997 | |
998 | const tensorflow::Tensor* TFE_TensorHandleUnderlyingTensorInHostMemory( |
999 | TFE_TensorHandle* h, TF_Status* status) { |
1000 | tensorflow::Device* d = nullptr; |
1001 | tensorflow::Device* op_device = nullptr; |
1002 | const tensorflow::Tensor* t = nullptr; |
1003 | status->status = h->handle->TensorAndDevice(&t, &d, &op_device); |
1004 | if (!status->status.ok()) return nullptr; |
1005 | if (d != nullptr) { |
1006 | status->status = tensorflow::errors::FailedPrecondition( |
1007 | "TFE_TensorHandle is placed in device (not host) memory. Cannot return " |
        "a tensorflow::Tensor");
1009 | return nullptr; |
1010 | } |
1011 | return t; |
1012 | } |
1013 | |
1014 | void TFE_ContextExportRunMetadata(TFE_Context* ctx, TF_Buffer* buf, |
1015 | TF_Status* status) { |
1016 | TFE_ContextAsyncWait(ctx, status); |
1017 | if (!status->status.ok()) return; |
1018 | tensorflow::mutex_lock ml(*ctx->context.MetadataMu()); |
1019 | status->status = MessageToBuffer(*ctx->context.RunMetadataProto(), buf); |
1020 | ctx->context.RunMetadataProto()->Clear(); |
1021 | } |
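
// Illustrative sketch (assumed caller-side usage): capturing step stats for
// ops executed while run metadata collection is enabled.
//
//   TFE_ContextEnableRunMetadata(ctx);
//   // ... execute one or more ops ...
//   TF_Buffer* buf = TF_NewBuffer();
//   TFE_ContextExportRunMetadata(ctx, buf, status);
//   // `buf` now holds a serialized tensorflow.RunMetadata proto.
//   TF_DeleteBuffer(buf);
//   TFE_ContextDisableRunMetadata(ctx);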
1022 | |
1023 | namespace { |
1024 | TFE_Op* GetFunc(TFE_Context* ctx, const tensorflow::NameAttrList& func, |
1025 | TF_Status* status) { |
1026 | TFE_Op* func_op = TFE_NewOp(ctx, func.name().data(), status); |
1027 | for (const auto& attr : func.attr()) { |
1028 | if (TF_GetCode(status) != TF_OK) return nullptr; |
1029 | SetOpAttrValueScalar(ctx, func_op, attr.second, attr.first.data(), status); |
1030 | if (TF_GetCode(status) != TF_OK) return nullptr; |
1031 | } |
1032 | return func_op; |
1033 | } |
1034 | } // namespace |
1035 | |
1036 | namespace tensorflow { |
1037 | void SetOpAttrValueScalar(TFE_Context* ctx, TFE_Op* op, |
1038 | const tensorflow::AttrValue& default_value, |
1039 | const char* attr_name, TF_Status* status) { |
1040 | switch (default_value.value_case()) { |
1041 | case tensorflow::AttrValue::kS: |
1042 | TFE_OpSetAttrString(op, attr_name, default_value.s().data()); |
1043 | break; |
1044 | case tensorflow::AttrValue::kI: |
1045 | TFE_OpSetAttrInt(op, attr_name, static_cast<int64_t>(default_value.i())); |
1046 | break; |
1047 | case tensorflow::AttrValue::kF: |
1048 | TFE_OpSetAttrFloat(op, attr_name, default_value.f()); |
1049 | break; |
1050 | case tensorflow::AttrValue::kB: |
1051 | TFE_OpSetAttrBool(op, attr_name, default_value.b()); |
1052 | break; |
1053 | case tensorflow::AttrValue::kType: |
1054 | TFE_OpSetAttrType(op, attr_name, |
1055 | static_cast<TF_DataType>(default_value.type())); |
1056 | break; |
1057 | case tensorflow::AttrValue::kShape: { |
1058 | const auto& tensor_shape = default_value.shape(); |
1059 | if (tensor_shape.unknown_rank()) { |
1060 | TFE_OpSetAttrShape(op, attr_name, nullptr, -1, status); |
1061 | } else { |
1062 | const auto num_dims = tensor_shape.dim_size(); |
1063 | std::unique_ptr<int64_t[]> dims(new int64_t[num_dims]); |
1064 | for (int i = 0; i < num_dims; ++i) { |
1065 | dims[i] = tensor_shape.dim(i).size(); |
1066 | } |
1067 | TFE_OpSetAttrShape(op, attr_name, dims.get(), num_dims, status); |
1068 | } |
1069 | } break; |
1070 | case tensorflow::AttrValue::kFunc: { |
1071 | const auto func_op = GetFunc(ctx, default_value.func(), status); |
1072 | if (TF_GetCode(status) != TF_OK) return; |
1073 | // TODO(nareshmodi): TFE_OpSetAttrFunction and TFE_OpSetAttrFunctionList |
      // require a TFE_Op* and just convert it internally to a NameAttrList, so
1075 | // consider adding an overload to the C API to make this case easier. |
1076 | TFE_OpSetAttrFunction(op, attr_name, func_op); |
1077 | } break; |
1078 | case tensorflow::AttrValue::kList: |
1079 | TF_FALLTHROUGH_INTENDED; |
1080 | case tensorflow::AttrValue::kTensor: |
1081 | TF_FALLTHROUGH_INTENDED; |
1082 | case tensorflow::AttrValue::kPlaceholder: |
1083 | TF_FALLTHROUGH_INTENDED; |
1084 | case tensorflow::AttrValue::VALUE_NOT_SET: |
1085 | TF_SetStatus( |
1086 | status, TF_UNIMPLEMENTED, |
          tensorflow::strings::StrCat("Unable to set attr from default value: ",
1088 | default_value.DebugString()) |
1089 | .data()); |
1090 | } |
1091 | } |
1092 | } // namespace tensorflow |
1093 | |
1094 | |
1095 | TFE_Op::~TFE_Op() { |
1096 | for (tensorflow::TensorHandle* h : inputs) { |
1097 | h->Unref(); |
1098 | } |
1099 | } |
1100 | |