1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "tensorflow/c/c_api.h" |
17 | |
18 | #include <algorithm> |
19 | #include <limits> |
20 | #include <memory> |
21 | #include <vector> |
22 | |
23 | #ifndef __ANDROID__ |
24 | #include "tensorflow/cc/framework/gradients.h" |
25 | #include "tensorflow/cc/framework/ops.h" |
26 | #include "tensorflow/cc/framework/scope_internal.h" |
27 | #include "tensorflow/cc/ops/while_loop.h" |
28 | #include "tensorflow/cc/saved_model/loader.h" |
29 | #include "tensorflow/core/framework/op_gen_lib.h" |
30 | #endif |
31 | #include "tensorflow/c/c_api_internal.h" |
32 | #include "tensorflow/core/common_runtime/device_mgr.h" |
33 | #include "tensorflow/core/common_runtime/eval_const_tensor.h" |
34 | #include "tensorflow/core/common_runtime/shape_refiner.h" |
35 | #include "tensorflow/core/framework/allocation_description.pb.h" |
36 | #include "tensorflow/core/framework/log_memory.h" |
37 | #include "tensorflow/core/framework/node_def_util.h" |
38 | #include "tensorflow/core/framework/op_kernel.h" |
39 | #include "tensorflow/core/framework/partial_tensor_shape.h" |
40 | #include "tensorflow/core/framework/tensor.h" |
41 | #include "tensorflow/core/framework/tensor_shape.h" |
42 | #include "tensorflow/core/framework/tensor_shape.pb.h" |
43 | #include "tensorflow/core/framework/types.h" |
44 | #include "tensorflow/core/framework/versions.pb.h" |
45 | #include "tensorflow/core/graph/graph.h" |
46 | #include "tensorflow/core/graph/graph_constructor.h" |
47 | #include "tensorflow/core/graph/node_builder.h" |
48 | #include "tensorflow/core/lib/core/coding.h" |
49 | #include "tensorflow/core/lib/core/errors.h" |
50 | #include "tensorflow/core/lib/core/status.h" |
51 | #include "tensorflow/core/lib/core/stringpiece.h" |
52 | #include "tensorflow/core/lib/gtl/array_slice.h" |
53 | #include "tensorflow/core/lib/strings/strcat.h" |
54 | #include "tensorflow/core/platform/mem.h" |
55 | #include "tensorflow/core/platform/mutex.h" |
56 | #include "tensorflow/core/platform/protobuf.h" |
57 | #include "tensorflow/core/platform/thread_annotations.h" |
58 | #include "tensorflow/core/platform/types.h" |
59 | #include "tensorflow/core/public/session.h" |
60 | #include "tensorflow/core/public/version.h" |
61 | |
62 | // The implementation below is at the top level instead of the |
63 | // brain namespace because we are defining 'extern "C"' functions. |
64 | using tensorflow::AllocationDescription; |
65 | using tensorflow::DataType; |
66 | using tensorflow::ExtendSessionGraphHelper; |
67 | using tensorflow::Graph; |
68 | using tensorflow::GraphDef; |
69 | using tensorflow::mutex_lock; |
70 | using tensorflow::NameRangeMap; |
71 | using tensorflow::NameRangesForNode; |
72 | using tensorflow::NewSession; |
73 | using tensorflow::Node; |
74 | using tensorflow::NodeBuilder; |
75 | using tensorflow::NodeDef; |
76 | using tensorflow::OpDef; |
77 | using tensorflow::OpRegistry; |
78 | using tensorflow::OutputTensor; |
79 | using tensorflow::PartialTensorShape; |
80 | using tensorflow::RunMetadata; |
81 | using tensorflow::RunOptions; |
82 | using tensorflow::Session; |
83 | using tensorflow::Status; |
84 | using tensorflow::string; |
85 | using tensorflow::Tensor; |
86 | using tensorflow::TensorBuffer; |
87 | using tensorflow::TensorId; |
88 | using tensorflow::TensorShape; |
89 | using tensorflow::TensorShapeProto; |
90 | using tensorflow::VersionDef; |
91 | using tensorflow::error::Code; |
92 | using tensorflow::errors::FailedPrecondition; |
93 | using tensorflow::errors::InvalidArgument; |
94 | using tensorflow::gtl::ArraySlice; |
95 | using tensorflow::strings::StrCat; |
96 | |
97 | extern "C" { |
98 | |
99 | // -------------------------------------------------------------------------- |
// Returns the compile-time TensorFlow version string (TF_VERSION_STRING).
const char* TF_Version() { return TF_VERSION_STRING; }
101 | |
102 | // -------------------------------------------------------------------------- |
103 | size_t TF_DataTypeSize(TF_DataType dt) { |
104 | return static_cast<size_t>( |
105 | tensorflow::DataTypeSize(static_cast<DataType>(dt))); |
106 | } |
107 | |
108 | // -------------------------------------------------------------------------- |
109 | |
// Allocates a status object; initial state is OK.
TF_Status* TF_NewStatus() { return new TF_Status; }

// Frees a status object previously returned by TF_NewStatus.
void TF_DeleteStatus(TF_Status* s) { delete s; }
113 | |
114 | void TF_SetStatus(TF_Status* s, TF_Code code, const char* msg) { |
115 | if (code == TF_OK) { |
116 | s->status = Status::OK(); |
117 | return; |
118 | } |
119 | s->status = Status(static_cast<Code>(code), tensorflow::StringPiece(msg)); |
120 | } |
121 | |
// Returns the error code stored in `s` (TF_OK when no error).
TF_Code TF_GetCode(const TF_Status* s) {
  return static_cast<TF_Code>(s->status.code());
}

// Returns the error message in `s`.  The pointer is only valid until the
// status is mutated or deleted (it aliases the internal string).
const char* TF_Message(const TF_Status* s) {
  return s->status.error_message().c_str();
}
129 | |
130 | // -------------------------------------------------------------------------- |
131 | |
132 | namespace { |
// A TensorBuffer backed by caller-provided memory together with the
// deallocator that releases it.  There is no constructor: TF_NewTensor and
// friends fill in the public fields directly before first use.
class TF_ManagedBuffer : public TensorBuffer {
 public:
  void* data_;
  size_t len_;
  // Invoked as (*deallocator_)(data_, len_, deallocator_arg_) on destruction.
  void (*deallocator_)(void* data, size_t len, void* arg);
  void* deallocator_arg_;

  ~TF_ManagedBuffer() override {
    (*deallocator_)(data_, len_, deallocator_arg_);
  }

  void* data() const override { return data_; }
  size_t size() const override { return len_; }
  TensorBuffer* root_buffer() override { return this; }
  void FillAllocationDescription(AllocationDescription* proto) const override {
    tensorflow::int64 rb = size();
    proto->set_requested_bytes(rb);
    proto->set_allocator_name(tensorflow::cpu_allocator()->Name());
  }

  // Prevents input forwarding from mutating this buffer.
  bool OwnsMemory() const override { return false; }
};
156 | |
157 | void* allocate_tensor(const char* operation, size_t len) { |
158 | void* data = |
159 | tensorflow::cpu_allocator()->AllocateRaw(EIGEN_MAX_ALIGN_BYTES, len); |
160 | if (tensorflow::LogMemory::IsEnabled() && data != nullptr) { |
161 | tensorflow::LogMemory::RecordRawAllocation( |
162 | operation, tensorflow::LogMemory::EXTERNAL_TENSOR_ALLOCATION_STEP_ID, |
163 | len, data, tensorflow::cpu_allocator()); |
164 | } |
165 | return data; |
166 | } |
167 | |
168 | void deallocate_buffer(void* data, size_t len, void* arg) { |
169 | if (tensorflow::LogMemory::IsEnabled() && data != nullptr) { |
170 | tensorflow::LogMemory::RecordRawDeallocation( |
171 | "TensorFlow C Api" , |
172 | tensorflow::LogMemory::EXTERNAL_TENSOR_ALLOCATION_STEP_ID, data, |
173 | tensorflow::cpu_allocator(), false); |
174 | } |
175 | tensorflow::cpu_allocator()->DeallocateRaw(data); |
176 | } |
177 | |
178 | } // namespace |
179 | |
// Drops this tensor's reference on its buffer; the buffer's deallocator
// runs when the last reference goes away.
TF_Tensor::~TF_Tensor() { buffer->Unref(); }
181 | |
182 | TF_Tensor* TF_AllocateTensor(TF_DataType dtype, const int64_t* dims, |
183 | int num_dims, size_t len) { |
184 | void* data = allocate_tensor("TF_AllocateTensor" , len); |
185 | return TF_NewTensor(dtype, dims, num_dims, data, len, deallocate_buffer, |
186 | nullptr); |
187 | } |
188 | |
// Wraps caller-provided memory in a TF_Tensor.  Normally takes ownership of
// `data` (released via `deallocator`); if the data is under-aligned for a
// memcpy-able dtype it is copied into freshly allocated aligned storage and
// the caller's buffer is released immediately.  Returns nullptr (after
// freeing everything) if `len` is too small for the requested shape.
TF_Tensor* TF_NewTensor(TF_DataType dtype, const int64_t* dims, int num_dims,
                        void* data, size_t len,
                        void (*deallocator)(void* data, size_t len, void* arg),
                        void* deallocator_arg) {
  std::vector<tensorflow::int64> dimvec(num_dims);
  for (int i = 0; i < num_dims; ++i) {
    dimvec[i] = static_cast<tensorflow::int64>(dims[i]);
  }

  TF_ManagedBuffer* buf = new TF_ManagedBuffer;
  buf->len_ = len;
  if (dtype != TF_STRING && dtype != TF_RESOURCE &&
      tensorflow::DataTypeCanUseMemcpy(static_cast<DataType>(dtype)) &&
      reinterpret_cast<intptr_t>(data) % EIGEN_MAX_ALIGN_BYTES != 0) {
    // TF_STRING and TF_RESOURCE tensors have a different representation in
    // TF_Tensor than they do in tensorflow::Tensor. So a copy here is a waste
    // (any alignment requirements will be taken care of by TF_TensorToTensor
    // and TF_TensorFromTensor).
    //
    // Other types have the same representation, so copy only if it is safe to
    // do so.
    buf->data_ = allocate_tensor("TF_NewTensor" , len);
    std::memcpy(buf->data_, data, len);
    buf->deallocator_ = deallocate_buffer;
    buf->deallocator_arg_ = nullptr;
    // Free the original buffer.
    deallocator(data, len, deallocator_arg);
  } else {
    buf->data_ = data;
    buf->deallocator_ = deallocator;
    buf->deallocator_arg_ = deallocator_arg;
  }
  TF_Tensor* ret = new TF_Tensor{dtype, TensorShape(dimvec), buf};
  size_t elem_size = TF_DataTypeSize(dtype);
  // Reject buffers too small for the shape; deleting `ret` unrefs `buf`,
  // which runs the (possibly caller-supplied) deallocator.
  if (elem_size > 0 && len < (elem_size * ret->shape.num_elements())) {
    delete ret;
    return nullptr;
  }
  return ret;
}
229 | |
230 | TF_Tensor* TF_TensorMaybeMove(TF_Tensor* tensor) { |
231 | // It is safe to move the Tensor if and only if we own the unique reference to |
232 | // it. In that case, we might as well not delete and reallocate, but a future |
233 | // implementation might need to do so. |
234 | TensorBuffer* buf = tensor->buffer; |
235 | if (buf->RefCountIsOne() && buf->root_buffer()->RefCountIsOne() && |
236 | buf->OwnsMemory()) { |
237 | return tensor; |
238 | } |
239 | return nullptr; |
240 | } |
241 | |
// Releases the tensor; the buffer is unrefed in ~TF_Tensor.
void TF_DeleteTensor(TF_Tensor* t) { delete t; }

// Trivial accessors over the TF_Tensor struct.
TF_DataType TF_TensorType(const TF_Tensor* t) { return t->dtype; }
int TF_NumDims(const TF_Tensor* t) { return t->shape.dims(); }
int64_t TF_Dim(const TF_Tensor* t, int dim_index) {
  return static_cast<int64_t>(t->shape.dim_size(dim_index));
}
// Byte size / base pointer of the backing buffer.
size_t TF_TensorByteSize(const TF_Tensor* t) { return t->buffer->size(); }
void* TF_TensorData(const TF_Tensor* t) { return t->buffer->data(); }
251 | |
252 | // -------------------------------------------------------------------------- |
253 | size_t TF_StringEncode(const char* src, size_t src_len, char* dst, |
254 | size_t dst_len, TF_Status* status) { |
255 | const size_t sz = TF_StringEncodedSize(src_len); |
256 | if (sz < src_len) { |
257 | status->status = InvalidArgument("src string is too large to encode" ); |
258 | return 0; |
259 | } |
260 | if (dst_len < sz) { |
261 | status->status = |
262 | InvalidArgument("dst_len (" , dst_len, ") too small to encode a " , |
263 | src_len, "-byte string" ); |
264 | return 0; |
265 | } |
266 | dst = tensorflow::core::EncodeVarint64(dst, src_len); |
267 | memcpy(dst, src, src_len); |
268 | return sz; |
269 | } |
270 | |
// Decodes one length-prefixed string from [src, src+src_len).  On success
// points *dst at the payload (inside `src`, no copy) and sets *dst_len to
// its length.  Note: does not verify the payload itself fits in src_len;
// callers bound-check against their own limits.
static Status TF_StringDecode_Impl(const char* src, size_t src_len,
                                   const char** dst, size_t* dst_len) {
  tensorflow::uint64 len64 = 0;
  const char* p = tensorflow::core::GetVarint64Ptr(src, src + src_len, &len64);
  if (p == nullptr) {
    return InvalidArgument("invalid string encoding or truncated src buffer" );
  }
  // Guard 32-bit builds: the encoded length may not fit in size_t.
  if (len64 > std::numeric_limits<size_t>::max()) {
    return InvalidArgument("encoded string is " , len64,
                           "-bytes, which is too large for this architecture" );
  }
  *dst = p;
  *dst_len = static_cast<size_t>(len64);
  return Status::OK();
}
286 | |
287 | size_t TF_StringDecode(const char* src, size_t src_len, const char** dst, |
288 | size_t* dst_len, TF_Status* status) { |
289 | status->status = TF_StringDecode_Impl(src, src_len, dst, dst_len); |
290 | if (!status->status.ok()) return 0; |
291 | return static_cast<size_t>(*dst - src) + *dst_len; |
292 | } |
293 | |
294 | size_t TF_StringEncodedSize(size_t len) { |
295 | return static_cast<size_t>(tensorflow::core::VarintLength(len)) + len; |
296 | } |
297 | |
298 | // -------------------------------------------------------------------------- |
// Lifetime management for TF_SessionOptions.
TF_SessionOptions* TF_NewSessionOptions() { return new TF_SessionOptions; }
void TF_DeleteSessionOptions(TF_SessionOptions* opt) { delete opt; }

// Sets the target string (e.g. "local" or a grpc address) used by
// NewSession; the string is copied.
void TF_SetTarget(TF_SessionOptions* options, const char* target) {
  options->options.target = target;
}
305 | |
306 | void TF_SetConfig(TF_SessionOptions* options, const void* proto, |
307 | size_t proto_len, TF_Status* status) { |
308 | if (!options->options.config.ParseFromArray(proto, proto_len)) { |
309 | status->status = InvalidArgument("Unparseable ConfigProto" ); |
310 | } |
311 | } |
312 | // -------------------------------------------------------------------------- |
// Returns an empty buffer: no data, zero length, no deallocator.
TF_Buffer* TF_NewBuffer() { return new TF_Buffer{nullptr, 0, nullptr}; }
314 | |
315 | TF_Buffer* TF_NewBufferFromString(const void* proto, size_t proto_len) { |
316 | void* copy = tensorflow::port::Malloc(proto_len); |
317 | memcpy(copy, proto, proto_len); |
318 | |
319 | TF_Buffer* buf = new TF_Buffer; |
320 | buf->data = copy; |
321 | buf->length = proto_len; |
322 | buf->data_deallocator = [](void* data, size_t length) { |
323 | tensorflow::port::Free(data); |
324 | }; |
325 | return buf; |
326 | } |
327 | |
328 | void TF_DeleteBuffer(TF_Buffer* buffer) { |
329 | if (buffer->data_deallocator != nullptr) { |
330 | (*buffer->data_deallocator)(const_cast<void*>(buffer->data), |
331 | buffer->length); |
332 | } |
333 | delete buffer; |
334 | } |
335 | |
// Returns a shallow copy of `buffer`; the caller must not free the copied
// `data` pointer (ownership stays with `buffer`).
TF_Buffer TF_GetBuffer(TF_Buffer* buffer) { return *buffer; }
337 | |
338 | // -------------------------------------------------------------------------- |
339 | |
340 | TF_DeprecatedSession* TF_NewDeprecatedSession(const TF_SessionOptions* opt, |
341 | TF_Status* status) { |
342 | Session* session; |
343 | status->status = NewSession(opt->options, &session); |
344 | if (status->status.ok()) { |
345 | return new TF_DeprecatedSession({session}); |
346 | } else { |
347 | DCHECK_EQ(nullptr, session); |
348 | return nullptr; |
349 | } |
350 | } |
351 | |
// Closes the underlying Session; `status` receives the result.
void TF_CloseDeprecatedSession(TF_DeprecatedSession* s, TF_Status* status) {
  status->status = s->session->Close();
}
355 | |
// Destroys the wrapper and the owned Session.  Always reports OK.
void TF_DeleteDeprecatedSession(TF_DeprecatedSession* s, TF_Status* status) {
  status->status = Status::OK();
  delete s->session;
  delete s;
}
361 | |
362 | void TF_ExtendGraph(TF_DeprecatedSession* s, const void* proto, |
363 | size_t proto_len, TF_Status* status) { |
364 | GraphDef g; |
365 | if (!tensorflow::ParseProtoUnlimited(&g, proto, proto_len)) { |
366 | status->status = InvalidArgument("Invalid GraphDef" ); |
367 | return; |
368 | } |
369 | status->status = s->session->Extend(g); |
370 | } |
371 | |
372 | static void DeleteArray(void* data, size_t size, void* arg) { |
373 | DCHECK_EQ(data, arg); |
374 | delete[] reinterpret_cast<char*>(arg); |
375 | } |
376 | |
377 | } // end extern "C" |
378 | |
379 | namespace tensorflow { |
380 | namespace { |
381 | |
382 | // Reset helper for converting character arrays to string vectors. |
383 | void TF_Reset_Helper(const TF_SessionOptions* opt, const char** containers, |
384 | int ncontainers, TF_Status* status) { |
385 | std::vector<string> container_names(ncontainers); |
386 | for (int i = 0; i < ncontainers; ++i) { |
387 | container_names[i] = containers[i]; |
388 | } |
389 | |
390 | status->status = Reset(opt->options, container_names); |
391 | } |
392 | |
393 | // This traverses the specified nodes in topological order to verify there are |
394 | // no cycles. Starting with inputless nodes, it visits nodes whose inputs have |
395 | // all been visited, and counts the total number of visited nodes. If there is a |
396 | // cycle, nodes in the cycle will never be visited, and the visited count will |
397 | // be less than the total node count. |
398 | Status ValidateNoCycles(const Graph& g) { |
399 | // TODO(nolivia): check this on a subset of the graph instead of all of it. |
400 | // A node is ready when all of its inputs have been visited. |
401 | std::vector<const Node*> ready; |
402 | std::vector<int> pending_count(g.num_node_ids(), 0); |
403 | |
404 | for (int i = 0; i < g.num_node_ids(); ++i) { |
405 | const Node* n = g.FindNodeId(i); |
406 | if (n == nullptr) continue; |
407 | pending_count[i] = n->in_edges().size(); |
408 | if (n->IsMerge()) { |
409 | // While-loop cycles are legal cycles so we manually adjust the |
410 | // pending_count to make sure that the loop is visited. |
411 | for (const Edge* e : n->in_edges()) { |
412 | if (!e->IsControlEdge() && e->src()->IsNextIteration()) { |
413 | pending_count[i]--; |
414 | } |
415 | } |
416 | } |
417 | if (pending_count[i] == 0) { |
418 | ready.push_back(n); |
419 | } |
420 | } |
421 | |
422 | int processed = 0; |
423 | while (!ready.empty()) { |
424 | const Node* node = ready.back(); |
425 | ready.pop_back(); |
426 | ++processed; |
427 | |
428 | for (const Edge* out : node->out_edges()) { |
429 | const int output_id = out->dst()->id(); |
430 | pending_count[output_id]--; |
431 | if (pending_count[output_id] == 0) { |
432 | ready.push_back(out->dst()); |
433 | } |
434 | } |
435 | } |
436 | |
437 | if (processed < g.num_nodes()) { |
438 | std::vector<string> nodes_in_cycle; |
439 | for (int i = 0; i < pending_count.size() && nodes_in_cycle.size() < 3; |
440 | ++i) { |
441 | if (pending_count[i] != 0) { |
442 | nodes_in_cycle.push_back(g.FindNodeId(i)->name()); |
443 | } |
444 | } |
445 | return errors::InvalidArgument( |
446 | "Graph is invalid, contains a cycle with " , g.num_nodes() - processed, |
447 | " nodes, including: " , str_util::Join(nodes_in_cycle, ", " )); |
448 | } |
449 | return Status::OK(); |
450 | } |
451 | } // namespace |
452 | } // namespace tensorflow |
453 | |
454 | extern "C" { |
455 | |
// C entry point: resets resource containers; see TF_Reset_Helper.
void TF_Reset(const TF_SessionOptions* opt, const char** containers,
              int ncontainers, TF_Status* status) {
  tensorflow::TF_Reset_Helper(opt, containers, ncontainers, status);
}
460 | |
461 | } // end extern "C" |
462 | |
463 | namespace tensorflow { |
464 | |
465 | Status TF_TensorToTensor(const TF_Tensor* src, Tensor* dst) { |
466 | if (src->dtype == TF_RESOURCE) { |
467 | if (src->shape.dims() != 0) { |
468 | return InvalidArgument( |
469 | "Malformed TF_RESOURCE tensor: expected a scalar, got a tensor with " |
470 | "shape " , |
471 | src->shape.DebugString()); |
472 | } |
473 | *dst = Tensor(DT_RESOURCE, src->shape); |
474 | if (!dst->scalar<ResourceHandle>()().ParseFromString( |
475 | string(static_cast<const char*>(TF_TensorData(src)), |
476 | TF_TensorByteSize(src)))) { |
477 | return InvalidArgument( |
478 | "Malformed TF_RESOUCE tensor: unable to parse resource handle" ); |
479 | } |
480 | return Status::OK(); |
481 | } |
482 | if (src->dtype != TF_STRING) { |
483 | *dst = TensorCApi::MakeTensor(src->dtype, src->shape, src->buffer); |
484 | return Status::OK(); |
485 | } |
486 | // TF_STRING tensors require copying since Tensor class expects a sequence of |
487 | // string objects. |
488 | const tensorflow::int64 num_elements = src->shape.num_elements(); |
489 | const char* input = reinterpret_cast<const char*>(TF_TensorData(src)); |
490 | const size_t src_size = TF_TensorByteSize(src); |
491 | if (static_cast<tensorflow::int64>(src_size / sizeof(tensorflow::uint64)) < |
492 | num_elements) { |
493 | return InvalidArgument( |
494 | "Malformed TF_STRING tensor; too short to hold number of elements" ); |
495 | } |
496 | const char* data_start = input + sizeof(tensorflow::uint64) * num_elements; |
497 | const char* limit = input + src_size; |
498 | |
499 | *dst = Tensor(static_cast<DataType>(src->dtype), src->shape); |
500 | auto dstarray = dst->flat<string>(); |
501 | for (tensorflow::int64 i = 0; i < num_elements; ++i) { |
502 | tensorflow::uint64 offset = |
503 | reinterpret_cast<const tensorflow::uint64*>(input)[i]; |
504 | if (static_cast<ptrdiff_t>(offset) >= (limit - data_start)) { |
505 | return InvalidArgument("Malformed TF_STRING tensor; element " , i, |
506 | " out of range" ); |
507 | } |
508 | size_t len; |
509 | const char* p; |
510 | const char* srcp = data_start + offset; |
511 | Status status = TF_StringDecode_Impl(srcp, limit - srcp, &p, &len); |
512 | if (!status.ok()) return status; |
513 | dstarray(i).assign(p, len); |
514 | } |
515 | return Status::OK(); |
516 | } |
517 | |
518 | // Create an empty tensor of type 'dtype'. 'shape' can be arbitrary, but has to |
519 | // result in a zero-sized tensor. |
520 | static TF_Tensor* EmptyTensor(TF_DataType dtype, const TensorShape& shape) { |
521 | static char empty; |
522 | tensorflow::int64 nelems = 1; |
523 | std::vector<tensorflow::int64> dims; |
524 | for (int i = 0; i < shape.dims(); ++i) { |
525 | dims.push_back(shape.dim_size(i)); |
526 | nelems *= shape.dim_size(i); |
527 | } |
528 | CHECK_EQ(nelems, 0); |
529 | static_assert(sizeof(int64_t) == sizeof(tensorflow::int64), |
530 | "64-bit int types should match in size" ); |
531 | return TF_NewTensor(dtype, reinterpret_cast<const int64_t*>(dims.data()), |
532 | shape.dims(), reinterpret_cast<void*>(&empty), 0, |
533 | [](void*, size_t, void*) {}, nullptr); |
534 | } |
535 | |
// Non-static for testing.
//
// Converts a tensorflow::Tensor into a C-API TF_Tensor.
// - Uninitialized tensors fail with FailedPrecondition.
// - Zero-element tensors get a dummy buffer (EmptyTensor).
// - DT_RESOURCE (scalar only) is serialized into the buffer bytes.
// - DT_STRING is flattened to [uint64 offset table | varint-prefixed
//   strings]; everything else shares the existing buffer with a Ref().
// Returns nullptr with `status` set on any failure.
TF_Tensor* TF_TensorFromTensor(const tensorflow::Tensor& src,
                               TF_Status* status) {
  if (!src.IsInitialized()) {
    status->status = FailedPrecondition(
        "attempt to use a tensor with an uninitialized value" );
    return nullptr;
  }
  if (src.NumElements() == 0) {
    return EmptyTensor(static_cast<TF_DataType>(src.dtype()), src.shape());
  }
  if (src.dtype() == DT_RESOURCE) {
    if (src.shape().dims() != 0) {
      status->status = InvalidArgument(
          "Unexpected non-scalar DT_RESOURCE tensor seen (shape: " ,
          src.shape().DebugString(),
          "). Please file a bug at " 
          "https://github.com/tensorflow/tensorflow/issues/new, " 
          "ideally with a " 
          "short code snippet that reproduces this error." );
      return nullptr;
    }
    const string str = src.scalar<ResourceHandle>()().SerializeAsString();
    TF_Tensor* t = TF_AllocateTensor(TF_RESOURCE, {}, 0, str.size());
    std::memcpy(TF_TensorData(t), str.c_str(), str.size());
    return t;
  }
  if (src.dtype() != DT_STRING) {
    // Zero-copy path: share the buffer; the TF_Tensor holds its own ref.
    TensorBuffer* buf = TensorCApi::Buffer(src);
    buf->Ref();
    return new TF_Tensor{static_cast<TF_DataType>(src.dtype()), src.shape(),
                         buf};
  }
  // DT_STRING tensors require a copying since TF_Tensor.buffer expects a flatly
  // encoded sequence of strings.

  // Compute bytes needed for encoding.
  size_t size = 0;
  const auto& srcarray = src.flat<string>();
  for (int i = 0; i < srcarray.size(); ++i) {
    const string& s = srcarray(i);
    // uint64 starting_offset, TF_StringEncode-d string.
    size += sizeof(tensorflow::uint64) + TF_StringEncodedSize(s.size());
  }

  // Encode all strings.
  char* base = new char[size];
  // Offset table occupies the first srcarray.size() uint64 slots.
  char* data_start = base + sizeof(tensorflow::uint64) * srcarray.size();
  char* dst = data_start;  // Where next string is encoded.
  size_t dst_len = size - static_cast<size_t>(data_start - base);
  tensorflow::uint64* offsets = reinterpret_cast<tensorflow::uint64*>(base);
  for (int i = 0; i < srcarray.size(); ++i) {
    *offsets = (dst - data_start);
    offsets++;
    const string& s = srcarray(i);
    size_t consumed = TF_StringEncode(s.data(), s.size(), dst, dst_len, status);
    if (!status->status.ok()) {
      status->status = InvalidArgument(
          "invalid string tensor encoding (string #" , i, " of " ,
          srcarray.size(), "): " , status->status.error_message());
      delete[] base;
      return nullptr;
    }
    dst += consumed;
    dst_len -= consumed;
  }
  // Sanity check: the writes must exactly fill the computed buffer.
  if (dst != base + size) {
    status->status = InvalidArgument(
        "invalid string tensor encoding (decoded " , (dst - base),
        " bytes, but the tensor is encoded in " , size, " bytes" );
    delete[] base;
    return nullptr;
  }

  auto dims = src.shape().dim_sizes();
  std::vector<tensorflow::int64> dimvec(dims.size());
  for (size_t i = 0; i < dims.size(); ++i) {
    dimvec[i] = dims[i];
  }
  static_assert(sizeof(int64_t) == sizeof(tensorflow::int64),
                "64-bit int types should match in size" );
  // `base` is owned by the new tensor and freed via DeleteArray.
  return TF_NewTensor(TF_STRING,
                      reinterpret_cast<const int64_t*>(dimvec.data()),
                      dimvec.size(), base, size, DeleteArray, base);
}
621 | |
622 | Status MessageToBuffer(const tensorflow::protobuf::Message& in, |
623 | TF_Buffer* out) { |
624 | if (out->data != nullptr) { |
625 | return InvalidArgument("Passing non-empty TF_Buffer is invalid." ); |
626 | } |
627 | const size_t proto_size = in.ByteSizeLong(); |
628 | void* buf = tensorflow::port::Malloc(proto_size); |
629 | if (buf == nullptr) { |
630 | return tensorflow::errors::ResourceExhausted( |
631 | "Failed to allocate memory to serialize message of type '" , |
632 | in.GetTypeName(), "' and size " , proto_size); |
633 | } |
634 | in.SerializeToArray(buf, proto_size); |
635 | out->data = buf; |
636 | out->length = proto_size; |
637 | out->data_deallocator = [](void* data, size_t length) { |
638 | tensorflow::port::Free(data); |
639 | }; |
640 | return Status::OK(); |
641 | } |
642 | |
643 | void RecordMutation(TF_Graph* graph, const TF_Operation& op, |
644 | const char* mutation_type) { |
645 | // If any session has already run this node_id, mark this session as |
646 | // unrunnable. |
647 | for (auto it : graph->sessions) { |
648 | mutex_lock session_lock(it.first->mu); |
649 | if (it.first->last_num_graph_nodes > op.node.id()) { |
650 | it.second = strings::StrCat( |
651 | "Operation '" , op.node.DebugString(), "' was changed by " , |
652 | mutation_type, |
653 | " after it was run by a session. This mutation will have no effect, " |
654 | "and will trigger an error in the future. Either don't modify " |
655 | "nodes after running them or create a new session." ); |
656 | } |
657 | } |
658 | } |
659 | |
660 | namespace { |
661 | |
662 | // Helper method that creates a shape handle for a shape described by dims. |
663 | tensorflow::shape_inference::ShapeHandle ShapeHandleFromDims( |
664 | tensorflow::shape_inference::InferenceContext* ic, int num_dims, |
665 | const int64_t* dims) { |
666 | if (num_dims != -1) { |
667 | std::vector<tensorflow::shape_inference::DimensionHandle> dim_vec; |
668 | dim_vec.reserve(num_dims); |
669 | for (int i = 0; i < num_dims; ++i) { |
670 | dim_vec.push_back(ic->MakeDim(dims[i])); |
671 | } |
672 | return ic->MakeShape(dim_vec); |
673 | } else { |
674 | return ic->UnknownShape(); |
675 | } |
676 | } |
677 | |
678 | } // namespace |
679 | |
// Records, on the shape-inference context of `output`'s node, the
// shapes/dtypes carried by a resource/variant handle produced at
// `output.index`.  ranks[i] == -1 denotes an unknown shape (see
// ShapeHandleFromDims).  Fails if the node has no inference context.
void TF_GraphSetOutputHandleShapesAndTypes(TF_Graph* graph, TF_Output output,
                                           int num_shapes_and_types,
                                           const int64_t** shapes,
                                           const int* ranks,
                                           const TF_DataType* types,
                                           TF_Status* status) {
  Node* node = &output.oper->node;

  mutex_lock l(graph->mu);
  tensorflow::shape_inference::InferenceContext* ic =
      graph->refiner.GetContext(node);
  if (ic == nullptr) {
    status->status =
        InvalidArgument("Node " , node->name(), " was not found in the graph" );
    return;
  }

  auto shape_and_type_vec =
      std::vector<tensorflow::shape_inference::ShapeAndType>(
          num_shapes_and_types);
  for (int i = 0; i < num_shapes_and_types; ++i) {
    tensorflow::shape_inference::ShapeHandle shape_handle =
        ShapeHandleFromDims(ic, ranks[i], shapes[i]);
    shape_and_type_vec[i] = tensorflow::shape_inference::ShapeAndType(
        shape_handle, static_cast<DataType>(types[i]));
  }

  ic->set_output_handle_shapes_and_types(output.index, shape_and_type_vec);
}
709 | |
710 | // Helpers for loading a TensorFlow plugin (a .so file). |
711 | Status LoadLibrary(const char* library_filename, void** result, |
712 | const void** buf, size_t* len); |
713 | |
// TODO(josh11b,mrry): Change Session to be able to use a Graph*
// directly, instead of requiring us to serialize to a GraphDef and
// call Session::Extend().
//
// Pushes any nodes added to session->graph since the last extension into the
// underlying Session via Extend().  Returns true on success; on failure sets
// `status` and returns false.  graph->mu is locked/unlocked manually (not
// RAII) because it must be released before the potentially slow Extend()
// call; every exit path below must unlock it exactly once.
bool ExtendSessionGraphHelper(TF_Session* session, TF_Status* status) {
  if (session->graph != nullptr) {
    // Take the graph lock before the session lock to avoid deadlock. This is
    // safe since session->graph does not change.
    session->graph->mu.lock();
    mutex_lock session_lock(session->mu);
    const Graph& graph = session->graph->graph;

    // Surface (once) any warning recorded by RecordMutation for this session.
    const string& mutation_warning = session->graph->sessions[session];
    if (!mutation_warning.empty()) {
      // TODO(b/74949947): turn this back into an error status
      LOG(WARNING) << mutation_warning;
      session->graph->sessions[session].clear();
    }

    const auto num_nodes = graph.num_node_ids();
    if (session->last_num_graph_nodes < num_nodes) {
      status->status = tensorflow::ValidateNoCycles(session->graph->graph);
      if (!status->status.ok()) {
        session->graph->mu.unlock();
        return false;
      }

      GraphDef graph_def;
      *graph_def.mutable_versions() = graph.versions();
      // Fill graph_def with nodes with ids in the range
      // [session->last_num_graph_nodes, num_nodes), that is the nodes
      // added since the last TF_SessionRun() call.
      for (auto id = session->last_num_graph_nodes; id < num_nodes; ++id) {
        Node* const node = graph.FindNodeId(id);
        if (node != nullptr && node->IsOp()) {
          NodeDef* const node_def = graph_def.add_node();
          *node_def = node->def();
        }
      }
      *graph_def.mutable_library() = graph.flib_def().ToProto();
      // Release the graph lock before calling into the Session.
      session->graph->mu.unlock();
      status->status = session->session->Extend(graph_def);
      if (!status->status.ok()) {
        // Contract is we always delete input_values[i].
        return false;
      }
      // Note: session->session is not modified if Extend() fails, so
      // we only set last_num_graph_nodes if it succeeds.
      session->last_num_graph_nodes = num_nodes;
    } else {
      session->graph->mu.unlock();
    }
  }
  return true;
}
768 | |
769 | } // namespace tensorflow |
770 | |
771 | static void TF_Run_Setup(int noutputs, TF_Tensor** c_outputs, |
772 | TF_Status* status) { |
773 | status->status = Status::OK(); |
774 | for (int i = 0; i < noutputs; ++i) { |
775 | c_outputs[i] = nullptr; |
776 | } |
777 | } |
778 | |
779 | static bool TF_Run_Inputs(TF_Tensor* const* c_inputs, |
780 | std::vector<std::pair<string, Tensor>>* input_pairs, |
781 | TF_Status* status) { |
782 | const int ninputs = input_pairs->size(); |
783 | for (int i = 0; i < ninputs; ++i) { |
784 | status->status = TF_TensorToTensor(c_inputs[i], &(*input_pairs)[i].second); |
785 | if (!status->status.ok()) return false; |
786 | } |
787 | return true; |
788 | } |
789 | |
// Shared implementation of TF_Run (handle == nullptr) and TF_PRun
// (handle != nullptr): runs the session and converts the resulting
// tensorflow::Tensor outputs into TF_Tensors stored in c_outputs[].
//
// On failure, sets `status` and returns early; callers pre-null c_outputs
// via TF_Run_Setup so untouched slots remain valid.
static void TF_Run_Helper(
    Session* session, const char* handle, const TF_Buffer* run_options,
    // Input tensors
    const std::vector<std::pair<string, Tensor>>& input_pairs,
    // Output tensors
    const std::vector<string>& output_tensor_names, TF_Tensor** c_outputs,
    // Target nodes
    const std::vector<string>& target_oper_names, TF_Buffer* run_metadata,
    TF_Status* status) {
  const int noutputs = output_tensor_names.size();
  std::vector<Tensor> outputs(noutputs);
  Status result;

  if (handle == nullptr) {
    RunOptions run_options_proto;
    // An unparseable RunOptions buffer is a caller error.
    if (run_options != nullptr && !run_options_proto.ParseFromArray(
                                      run_options->data, run_options->length)) {
      status->status = InvalidArgument("Unparseable RunOptions proto" );
      return;
    }
    // run_metadata is an out-parameter; it must arrive empty.
    if (run_metadata != nullptr && run_metadata->data != nullptr) {
      status->status =
          InvalidArgument("Passing non-empty run_metadata is invalid." );
      return;
    }

    RunMetadata run_metadata_proto;
    result = session->Run(run_options_proto, input_pairs, output_tensor_names,
                          target_oper_names, &outputs, &run_metadata_proto);

    // Serialize back to upstream client, who now owns the new buffer
    if (run_metadata != nullptr) {
      status->status = MessageToBuffer(run_metadata_proto, run_metadata);
      if (!status->status.ok()) return;
    }
  } else {
    // NOTE(zongheng): PRun does not support RunOptions yet.
    result = session->PRun(handle, input_pairs, output_tensor_names, &outputs);
  }
  if (!result.ok()) {
    status->status = result;
    return;
  }

  // Store results in c_outputs[]
  for (int i = 0; i < noutputs; ++i) {
    const Tensor& src = outputs[i];
    if (!src.IsInitialized() || src.NumElements() == 0) {
      // Uninitialized/empty tensors carry no data buffer to share; hand back
      // a lightweight empty TF_Tensor with matching dtype/shape instead.
      c_outputs[i] =
          EmptyTensor(static_cast<TF_DataType>(src.dtype()), src.shape());
      continue;
    }
    c_outputs[i] = TF_TensorFromTensor(src, status);
    if (!status->status.ok()) return;
  }
}
846 | |
847 | extern "C" { |
848 | |
849 | void TF_Run(TF_DeprecatedSession* s, const TF_Buffer* run_options, |
850 | // Input tensors |
851 | const char** c_input_names, TF_Tensor** c_inputs, int ninputs, |
852 | // Output tensors |
853 | const char** c_output_names, TF_Tensor** c_outputs, int noutputs, |
854 | // Target nodes |
855 | const char** c_target_oper_names, int ntargets, |
856 | TF_Buffer* run_metadata, TF_Status* status) { |
857 | TF_Run_Setup(noutputs, c_outputs, status); |
858 | std::vector<std::pair<string, Tensor>> input_pairs(ninputs); |
859 | if (!TF_Run_Inputs(c_inputs, &input_pairs, status)) return; |
860 | for (int i = 0; i < ninputs; ++i) { |
861 | input_pairs[i].first = c_input_names[i]; |
862 | } |
863 | std::vector<string> output_names(noutputs); |
864 | for (int i = 0; i < noutputs; ++i) { |
865 | output_names[i] = c_output_names[i]; |
866 | } |
867 | std::vector<string> target_oper_names(ntargets); |
868 | for (int i = 0; i < ntargets; ++i) { |
869 | target_oper_names[i] = c_target_oper_names[i]; |
870 | } |
871 | TF_Run_Helper(s->session, nullptr, run_options, input_pairs, output_names, |
872 | c_outputs, target_oper_names, run_metadata, status); |
873 | } |
874 | |
875 | void TF_PRunSetup(TF_DeprecatedSession* s, |
876 | // Input names |
877 | const char** c_input_names, int ninputs, |
878 | // Output names |
879 | const char** c_output_names, int noutputs, |
880 | // Target nodes |
881 | const char** c_target_oper_names, int ntargets, |
882 | const char** handle, TF_Status* status) { |
883 | *handle = nullptr; |
884 | |
885 | std::vector<string> input_names(ninputs); |
886 | std::vector<string> output_names(noutputs); |
887 | std::vector<string> target_oper_names(ntargets); |
888 | for (int i = 0; i < ninputs; ++i) { |
889 | input_names[i] = c_input_names[i]; |
890 | } |
891 | for (int i = 0; i < noutputs; ++i) { |
892 | output_names[i] = c_output_names[i]; |
893 | } |
894 | for (int i = 0; i < ntargets; ++i) { |
895 | target_oper_names[i] = c_target_oper_names[i]; |
896 | } |
897 | string new_handle; |
898 | status->status = s->session->PRunSetup(input_names, output_names, |
899 | target_oper_names, &new_handle); |
900 | if (status->status.ok()) { |
901 | char* buf = new char[new_handle.size() + 1]; |
902 | memcpy(buf, new_handle.c_str(), new_handle.size() + 1); |
903 | *handle = buf; |
904 | } |
905 | } |
906 | |
907 | void TF_PRun(TF_DeprecatedSession* s, const char* handle, |
908 | // Input tensors |
909 | const char** c_input_names, TF_Tensor** c_inputs, int ninputs, |
910 | // Output tensors |
911 | const char** c_output_names, TF_Tensor** c_outputs, int noutputs, |
912 | // Target nodes |
913 | const char** c_target_oper_names, int ntargets, |
914 | TF_Status* status) { |
915 | TF_Run_Setup(noutputs, c_outputs, status); |
916 | std::vector<std::pair<string, Tensor>> input_pairs(ninputs); |
917 | if (!TF_Run_Inputs(c_inputs, &input_pairs, status)) return; |
918 | for (int i = 0; i < ninputs; ++i) { |
919 | input_pairs[i].first = c_input_names[i]; |
920 | } |
921 | |
922 | std::vector<string> output_names(noutputs); |
923 | for (int i = 0; i < noutputs; ++i) { |
924 | output_names[i] = c_output_names[i]; |
925 | } |
926 | std::vector<string> target_oper_names(ntargets); |
927 | for (int i = 0; i < ntargets; ++i) { |
928 | target_oper_names[i] = c_target_oper_names[i]; |
929 | } |
930 | TF_Run_Helper(s->session, handle, nullptr, input_pairs, output_names, |
931 | c_outputs, target_oper_names, nullptr, status); |
932 | } |
933 | |
934 | TF_Library* TF_LoadLibrary(const char* library_filename, TF_Status* status) { |
935 | TF_Library* lib_handle = new TF_Library; |
936 | status->status = tensorflow::LoadLibrary( |
937 | library_filename, &lib_handle->lib_handle, &lib_handle->op_list.data, |
938 | &lib_handle->op_list.length); |
939 | if (!status->status.ok()) { |
940 | delete lib_handle; |
941 | return nullptr; |
942 | } |
943 | return lib_handle; |
944 | } |
945 | |
946 | TF_Buffer TF_GetOpList(TF_Library* lib_handle) { return lib_handle->op_list; } |
947 | |
948 | void TF_DeleteLibraryHandle(TF_Library* lib_handle) { |
949 | tensorflow::port::Free(const_cast<void*>(lib_handle->op_list.data)); |
950 | delete lib_handle; |
951 | } |
952 | |
953 | TF_Buffer* TF_GetAllOpList() { |
954 | std::vector<tensorflow::OpDef> op_defs; |
955 | tensorflow::OpRegistry::Global()->GetRegisteredOps(&op_defs); |
956 | tensorflow::OpList op_list; |
957 | for (const auto& op : op_defs) { |
958 | *(op_list.add_op()) = op; |
959 | } |
960 | TF_Buffer* ret = TF_NewBuffer(); |
961 | TF_CHECK_OK(MessageToBuffer(op_list, ret)); |
962 | return ret; |
963 | } |
964 | |
965 | // -------------------------------------------------------------------------- |
966 | // ListDevices & SessionListDevices API |
967 | |
968 | void TF_DeleteDeviceList(TF_DeviceList* s) { delete s; } |
969 | |
970 | TF_DeviceList* TF_SessionListDevices(TF_Session* session, TF_Status* status) { |
971 | TF_DeviceList* response = new TF_DeviceList; |
972 | status->status = session->session->ListDevices(&response->response); |
973 | return response; |
974 | } |
975 | |
976 | TF_DeviceList* TF_DeprecatedSessionListDevices(TF_DeprecatedSession* session, |
977 | TF_Status* status) { |
978 | TF_DeviceList* response = new TF_DeviceList; |
979 | status->status = session->session->ListDevices(&response->response); |
980 | return response; |
981 | } |
982 | |
983 | int TF_DeviceListCount(const TF_DeviceList* list) { |
984 | return list->response.size(); |
985 | } |
986 | |
// Generates an accessor `method_name(list, index, status)` for one field of
// a device entry in a TF_DeviceList. The generated function validates that
// `list` is non-null and `index` is in range, then returns
// `list->response[index].accessor`; on error it sets `status` to
// InvalidArgument and returns `err_val`.
#define TF_DEVICELIST_METHOD(return_type, method_name, accessor, err_val) \
  return_type method_name(const TF_DeviceList* list, const int index,     \
                          TF_Status* status) {                            \
    if (list == nullptr) {                                                \
      status->status = InvalidArgument("list is null!");                  \
      return err_val;                                                     \
    }                                                                     \
    if (index < 0 || index >= list->response.size()) {                    \
      status->status = InvalidArgument("index out of bounds");            \
      return err_val;                                                     \
    }                                                                     \
    status->status = Status::OK();                                        \
    return list->response[index].accessor;                                \
  }

// Concrete accessors: device name, device type, and memory limit in bytes.
// The returned C strings point into `list` and stay valid only while the
// list is alive.
TF_DEVICELIST_METHOD(const char*, TF_DeviceListName, name().c_str(), nullptr);
TF_DEVICELIST_METHOD(const char*, TF_DeviceListType, device_type().c_str(),
                     nullptr);
TF_DEVICELIST_METHOD(int64_t, TF_DeviceListMemoryBytes, memory_limit(), -1);

#undef TF_DEVICELIST_METHOD
1008 | |
1009 | } // end extern "C" |
1010 | |
1011 | // -------------------------------------------------------------------------- |
1012 | // New Graph and Session API |
1013 | |
1014 | // Helper functions ----------------------------------------------------------- |
1015 | |
1016 | namespace { |
1017 | |
1018 | TF_Operation* ToOperation(Node* node) { |
1019 | return static_cast<TF_Operation*>(static_cast<void*>(node)); |
1020 | } |
1021 | |
1022 | string OutputName(const TF_Output& output) { |
1023 | return StrCat(output.oper->node.name(), ":" , output.index); |
1024 | } |
1025 | |
1026 | const tensorflow::AttrValue* GetAttrValue(TF_Operation* oper, |
1027 | const char* attr_name, |
1028 | TF_Status* status) { |
1029 | const tensorflow::AttrValue* attr = oper->node.attrs().Find(attr_name); |
1030 | if (attr == nullptr) { |
1031 | status->status = InvalidArgument("Operation '" , oper->node.name(), |
1032 | "' has no attr named '" , attr_name, "'." ); |
1033 | } |
1034 | return attr; |
1035 | } |
1036 | |
1037 | TensorId ToTensorId(const TF_Output& output) { |
1038 | return TensorId(output.oper->node.name(), output.index); |
1039 | } |
1040 | |
1041 | #ifndef __ANDROID__ |
1042 | std::vector<tensorflow::Output> OutputsFromTFOutputs(TF_Output* tf_outputs, |
1043 | int n) { |
1044 | std::vector<tensorflow::Output> outputs(n); |
1045 | for (int i = 0; i < n; ++i) { |
1046 | outputs[i] = |
1047 | tensorflow::Output(&tf_outputs[i].oper->node, tf_outputs[i].index); |
1048 | } |
1049 | return outputs; |
1050 | } |
1051 | |
1052 | void TFOutputsFromOutputs(const std::vector<tensorflow::Output>& outputs, |
1053 | TF_Output* tf_outputs) { |
1054 | for (int i = 0; i < outputs.size(); i++) { |
1055 | tf_outputs[i].oper = ToOperation(outputs[i].node()); |
1056 | tf_outputs[i].index = outputs[i].index(); |
1057 | } |
1058 | } |
1059 | #endif // __ANDROID__ |
1060 | |
1061 | } // namespace |
1062 | |
1063 | // Shape functions ----------------------------------------------------------- |
1064 | |
1065 | void TF_GraphSetTensorShape(TF_Graph* graph, TF_Output output, |
1066 | const int64_t* dims, const int num_dims, |
1067 | TF_Status* status) { |
1068 | Node* node = &output.oper->node; |
1069 | |
1070 | mutex_lock l(graph->mu); |
1071 | tensorflow::shape_inference::InferenceContext* ic = |
1072 | graph->refiner.GetContext(node); |
1073 | if (ic == nullptr) { |
1074 | status->status = |
1075 | InvalidArgument("Node " , node->name(), " was not found in the graph" ); |
1076 | return; |
1077 | } |
1078 | tensorflow::shape_inference::ShapeHandle new_shape = |
1079 | tensorflow::ShapeHandleFromDims(ic, num_dims, dims); |
1080 | status->status = graph->refiner.SetShape(node, output.index, new_shape); |
1081 | } |
1082 | |
1083 | int TF_GraphGetTensorNumDims(TF_Graph* graph, TF_Output output, |
1084 | TF_Status* status) { |
1085 | Node* node = &output.oper->node; |
1086 | |
1087 | mutex_lock l(graph->mu); |
1088 | tensorflow::shape_inference::InferenceContext* ic = |
1089 | graph->refiner.GetContext(node); |
1090 | if (ic == nullptr) { |
1091 | status->status = |
1092 | InvalidArgument("Node " , node->name(), " was not found in the graph" ); |
1093 | return -1; |
1094 | } |
1095 | |
1096 | tensorflow::shape_inference::ShapeHandle shape = ic->output(output.index); |
1097 | |
1098 | // Unknown rank means the number of dimensions is -1. |
1099 | if (!ic->RankKnown(shape)) { |
1100 | return -1; |
1101 | } |
1102 | |
1103 | return ic->Rank(shape); |
1104 | } |
1105 | |
1106 | void TF_GraphGetTensorShape(TF_Graph* graph, TF_Output output, int64_t* dims, |
1107 | int num_dims, TF_Status* status) { |
1108 | Node* node = &output.oper->node; |
1109 | |
1110 | mutex_lock l(graph->mu); |
1111 | tensorflow::shape_inference::InferenceContext* ic = |
1112 | graph->refiner.GetContext(node); |
1113 | if (ic == nullptr) { |
1114 | status->status = |
1115 | InvalidArgument("Node " , node->name(), " was not found in the graph" ); |
1116 | return; |
1117 | } |
1118 | |
1119 | tensorflow::shape_inference::ShapeHandle shape = ic->output(output.index); |
1120 | |
1121 | int rank = -1; |
1122 | if (ic->RankKnown(shape)) { |
1123 | rank = ic->Rank(shape); |
1124 | } |
1125 | |
1126 | if (num_dims != rank) { |
1127 | status->status = InvalidArgument("Expected rank is " , num_dims, |
1128 | " but actual rank is " , rank); |
1129 | return; |
1130 | } |
1131 | |
1132 | if (num_dims == 0) { |
1133 | // Output shape is a scalar. |
1134 | return; |
1135 | } |
1136 | |
1137 | // Rank is greater than 0, so fill in the values, if known, and |
1138 | // -1 for unknown values. |
1139 | for (int i = 0; i < num_dims; ++i) { |
1140 | auto dim = ic->Dim(shape, i); |
1141 | tensorflow::int64 value = -1; |
1142 | if (ic->ValueKnown(dim)) { |
1143 | value = ic->Value(dim); |
1144 | } |
1145 | dims[i] = value; |
1146 | } |
1147 | } |
1148 | |
1149 | // TF_OperationDescription functions ------------------------------------------ |
1150 | |
1151 | extern "C" { |
1152 | |
1153 | static TF_OperationDescription* TF_NewOperationLocked(TF_Graph* graph, |
1154 | const char* op_type, |
1155 | const char* oper_name) |
1156 | EXCLUSIVE_LOCKS_REQUIRED(graph->mu) { |
1157 | return new TF_OperationDescription(graph, op_type, oper_name); |
1158 | } |
1159 | |
1160 | TF_OperationDescription* TF_NewOperation(TF_Graph* graph, const char* op_type, |
1161 | const char* oper_name) { |
1162 | mutex_lock l(graph->mu); |
1163 | return TF_NewOperationLocked(graph, op_type, oper_name); |
1164 | } |
1165 | |
1166 | void TF_SetDevice(TF_OperationDescription* desc, const char* device) { |
1167 | desc->node_builder.Device(device); |
1168 | } |
1169 | |
1170 | void TF_AddInput(TF_OperationDescription* desc, TF_Output input) { |
1171 | desc->node_builder.Input(&input.oper->node, input.index); |
1172 | } |
1173 | |
1174 | void TF_AddInputList(TF_OperationDescription* desc, const TF_Output* inputs, |
1175 | int num_inputs) { |
1176 | std::vector<NodeBuilder::NodeOut> input_list; |
1177 | input_list.reserve(num_inputs); |
1178 | for (int i = 0; i < num_inputs; ++i) { |
1179 | input_list.emplace_back(&inputs[i].oper->node, inputs[i].index); |
1180 | } |
1181 | desc->node_builder.Input(input_list); |
1182 | } |
1183 | |
1184 | void TF_AddControlInput(TF_OperationDescription* desc, TF_Operation* input) { |
1185 | desc->node_builder.ControlInput(&input->node); |
1186 | } |
1187 | |
1188 | void TF_ColocateWith(TF_OperationDescription* desc, TF_Operation* op) { |
1189 | desc->colocation_constraints.emplace( |
1190 | StrCat(tensorflow::kColocationGroupPrefix, op->node.name())); |
1191 | } |
1192 | |
1193 | void TF_SetAttrString(TF_OperationDescription* desc, const char* attr_name, |
1194 | const void* value, size_t length) { |
1195 | tensorflow::StringPiece s(static_cast<const char*>(value), length); |
1196 | desc->node_builder.Attr(attr_name, s); |
1197 | } |
1198 | |
1199 | void TF_SetAttrStringList(TF_OperationDescription* desc, const char* attr_name, |
1200 | const void* const* values, const size_t* lengths, |
1201 | int num_values) { |
1202 | if (strcmp(attr_name, tensorflow::kColocationAttrName) == 0) { |
1203 | desc->colocation_constraints.clear(); |
1204 | for (int i = 0; i < num_values; ++i) { |
1205 | desc->colocation_constraints.emplace(static_cast<const char*>(values[i]), |
1206 | lengths[i]); |
1207 | } |
1208 | } else { |
1209 | std::vector<tensorflow::StringPiece> v; |
1210 | v.reserve(num_values); |
1211 | for (int i = 0; i < num_values; ++i) { |
1212 | v.emplace_back(static_cast<const char*>(values[i]), lengths[i]); |
1213 | } |
1214 | desc->node_builder.Attr(attr_name, v); |
1215 | } |
1216 | } |
1217 | |
1218 | void TF_SetAttrInt(TF_OperationDescription* desc, const char* attr_name, |
1219 | int64_t value) { |
1220 | static_assert(sizeof(int64_t) == sizeof(tensorflow::int64), |
1221 | "64-bit int types should match in size" ); |
1222 | desc->node_builder.Attr(attr_name, static_cast<tensorflow::int64>(value)); |
1223 | } |
1224 | |
1225 | void TF_SetAttrIntList(TF_OperationDescription* desc, const char* attr_name, |
1226 | const int64_t* values, int num_values) { |
1227 | static_assert(sizeof(int64_t) == sizeof(tensorflow::int64), |
1228 | "64-bit int types should match in size" ); |
1229 | desc->node_builder.Attr( |
1230 | attr_name, |
1231 | ArraySlice<const tensorflow::int64>( |
1232 | reinterpret_cast<const tensorflow::int64*>(values), num_values)); |
1233 | } |
1234 | |
1235 | void TF_SetAttrFloat(TF_OperationDescription* desc, const char* attr_name, |
1236 | float value) { |
1237 | desc->node_builder.Attr(attr_name, value); |
1238 | } |
1239 | |
1240 | void TF_SetAttrFloatList(TF_OperationDescription* desc, const char* attr_name, |
1241 | const float* values, int num_values) { |
1242 | desc->node_builder.Attr(attr_name, |
1243 | ArraySlice<const float>(values, num_values)); |
1244 | } |
1245 | |
1246 | void TF_SetAttrBool(TF_OperationDescription* desc, const char* attr_name, |
1247 | unsigned char value) { |
1248 | desc->node_builder.Attr(attr_name, static_cast<bool>(value)); |
1249 | } |
1250 | |
1251 | void TF_SetAttrBoolList(TF_OperationDescription* desc, const char* attr_name, |
1252 | const unsigned char* values, int num_values) { |
1253 | std::unique_ptr<bool[]> b(new bool[num_values]); |
1254 | for (int i = 0; i < num_values; ++i) { |
1255 | b[i] = values[i]; |
1256 | } |
1257 | desc->node_builder.Attr(attr_name, |
1258 | ArraySlice<const bool>(b.get(), num_values)); |
1259 | } |
1260 | |
1261 | void TF_SetAttrType(TF_OperationDescription* desc, const char* attr_name, |
1262 | TF_DataType value) { |
1263 | desc->node_builder.Attr(attr_name, static_cast<DataType>(value)); |
1264 | } |
1265 | |
1266 | void TF_SetAttrTypeList(TF_OperationDescription* desc, const char* attr_name, |
1267 | const TF_DataType* values, int num_values) { |
1268 | desc->node_builder.Attr( |
1269 | attr_name, ArraySlice<const DataType>( |
1270 | reinterpret_cast<const DataType*>(values), num_values)); |
1271 | } |
1272 | |
1273 | void TF_SetAttrFuncName(TF_OperationDescription* desc, const char* attr_name, |
1274 | const char* value, size_t length) { |
1275 | tensorflow::NameAttrList func_name; |
1276 | func_name.set_name(std::string(value, value + length)); |
1277 | desc->node_builder.Attr(attr_name, func_name); |
1278 | } |
1279 | |
1280 | void TF_SetAttrShape(TF_OperationDescription* desc, const char* attr_name, |
1281 | const int64_t* dims, int num_dims) { |
1282 | PartialTensorShape shape; |
1283 | if (num_dims >= 0) { |
1284 | static_assert(sizeof(int64_t) == sizeof(tensorflow::int64), |
1285 | "64-bit int types should match in size" ); |
1286 | shape = PartialTensorShape(ArraySlice<tensorflow::int64>( |
1287 | reinterpret_cast<const tensorflow::int64*>(dims), num_dims)); |
1288 | } |
1289 | desc->node_builder.Attr(attr_name, shape); |
1290 | } |
1291 | |
1292 | void TF_SetAttrShapeList(TF_OperationDescription* desc, const char* attr_name, |
1293 | const int64_t* const* dims, const int* num_dims, |
1294 | int num_shapes) { |
1295 | std::vector<PartialTensorShape> shapes; |
1296 | shapes.reserve(num_shapes); |
1297 | for (int i = 0; i < num_shapes; ++i) { |
1298 | if (num_dims[i] < 0) { |
1299 | shapes.emplace_back(); |
1300 | } else { |
1301 | static_assert(sizeof(int64_t) == sizeof(tensorflow::int64), |
1302 | "64-bit int types should match in size" ); |
1303 | shapes.emplace_back(ArraySlice<tensorflow::int64>( |
1304 | reinterpret_cast<const tensorflow::int64*>(dims[i]), num_dims[i])); |
1305 | } |
1306 | } |
1307 | desc->node_builder.Attr(attr_name, shapes); |
1308 | } |
1309 | |
1310 | void TF_SetAttrTensorShapeProto(TF_OperationDescription* desc, |
1311 | const char* attr_name, const void* proto, |
1312 | size_t proto_len, TF_Status* status) { |
1313 | // shape.ParseFromArray takes an int as length, this function takes size_t, |
1314 | // make sure there is no information loss. |
1315 | if (proto_len > std::numeric_limits<int>::max()) { |
1316 | status->status = InvalidArgument( |
1317 | "proto_len (" , proto_len, |
1318 | " bytes) is too large to be parsed by the protocol buffer library" ); |
1319 | return; |
1320 | } |
1321 | TensorShapeProto shape; |
1322 | if (shape.ParseFromArray(proto, static_cast<int>(proto_len))) { |
1323 | desc->node_builder.Attr(attr_name, shape); |
1324 | status->status = Status::OK(); |
1325 | } else { |
1326 | status->status = InvalidArgument("Unparseable TensorShapeProto" ); |
1327 | } |
1328 | } |
1329 | |
1330 | void TF_SetAttrTensorShapeProtoList(TF_OperationDescription* desc, |
1331 | const char* attr_name, |
1332 | const void* const* protos, |
1333 | const size_t* proto_lens, int num_shapes, |
1334 | TF_Status* status) { |
1335 | std::vector<TensorShapeProto> shapes; |
1336 | shapes.resize(num_shapes); |
1337 | for (int i = 0; i < num_shapes; ++i) { |
1338 | if (proto_lens[i] > std::numeric_limits<int>::max()) { |
1339 | status->status = InvalidArgument( |
1340 | "length of element " , i, " in the list (" , proto_lens[i], |
1341 | " bytes) is too large to be parsed by the protocol buffer library" ); |
1342 | return; |
1343 | } |
1344 | if (!shapes[i].ParseFromArray(protos[i], static_cast<int>(proto_lens[i]))) { |
1345 | status->status = |
1346 | InvalidArgument("Unparseable TensorShapeProto at index " , i); |
1347 | return; |
1348 | } |
1349 | } |
1350 | desc->node_builder.Attr(attr_name, shapes); |
1351 | status->status = Status::OK(); |
1352 | } |
1353 | |
1354 | void TF_SetAttrTensor(TF_OperationDescription* desc, const char* attr_name, |
1355 | TF_Tensor* value, TF_Status* status) { |
1356 | Tensor t; |
1357 | status->status = TF_TensorToTensor(value, &t); |
1358 | if (status->status.ok()) desc->node_builder.Attr(attr_name, t); |
1359 | } |
1360 | |
1361 | void TF_SetAttrTensorList(TF_OperationDescription* desc, const char* attr_name, |
1362 | TF_Tensor* const* values, int num_values, |
1363 | TF_Status* status) { |
1364 | status->status = Status::OK(); |
1365 | std::vector<Tensor> t; |
1366 | t.reserve(num_values); |
1367 | |
1368 | for (int i = 0; i < num_values && status->status.ok(); ++i) { |
1369 | Tensor v; |
1370 | status->status = TF_TensorToTensor(values[i], &v); |
1371 | t.emplace_back(v); |
1372 | } |
1373 | |
1374 | if (status->status.ok()) desc->node_builder.Attr(attr_name, t); |
1375 | } |
1376 | |
1377 | void TF_SetAttrValueProto(TF_OperationDescription* desc, const char* attr_name, |
1378 | const void* proto, size_t proto_len, |
1379 | TF_Status* status) { |
1380 | tensorflow::AttrValue attr_value; |
1381 | if (!attr_value.ParseFromArray(proto, proto_len)) { |
1382 | status->status = InvalidArgument("Unparseable AttrValue proto" ); |
1383 | return; |
1384 | } |
1385 | |
1386 | if (strcmp(attr_name, tensorflow::kColocationAttrName) == 0) { |
1387 | if (attr_value.value_case() != tensorflow::AttrValue::kList && |
1388 | attr_value.value_case() != tensorflow::AttrValue::VALUE_NOT_SET) { |
1389 | status->status = |
1390 | InvalidArgument("Expected \"list\" field for \"" , |
1391 | tensorflow::kColocationAttrName, "\" attribute" ); |
1392 | return; |
1393 | } |
1394 | desc->colocation_constraints.clear(); |
1395 | for (const string& location : attr_value.list().s()) { |
1396 | desc->colocation_constraints.insert(location); |
1397 | } |
1398 | } else { |
1399 | desc->node_builder.Attr(attr_name, attr_value); |
1400 | } |
1401 | |
1402 | status->status = Status::OK(); |
1403 | } |
1404 | |
// Finalizes `desc` into a node in its graph; the graph mutex must be held.
//
// On success returns the new operation; on failure sets `status` and returns
// nullptr. In both cases `desc` is consumed (deleted) before returning.
static TF_Operation* TF_FinishOperationLocked(TF_OperationDescription* desc,
                                              TF_Status* status)
    EXCLUSIVE_LOCKS_REQUIRED(desc->graph->mu) {
  Node* ret = nullptr;

  // Reject duplicate node names before building anything.
  if (desc->graph->name_map.count(desc->node_builder.node_name())) {
    status->status = InvalidArgument("Duplicate node name in graph: '" ,
                                     desc->node_builder.node_name(), "'" );
  } else {
    // Colocation constraints accumulated on the description (via
    // TF_ColocateWith and the attr setters) become the node's colocation
    // attr, in the set's sorted order.
    if (!desc->colocation_constraints.empty()) {
      desc->node_builder.Attr(
          tensorflow::kColocationAttrName,
          std::vector<string>(desc->colocation_constraints.begin(),
                              desc->colocation_constraints.end()));
    }
    status->status = desc->node_builder.Finalize(&desc->graph->graph, &ret);

    if (status->status.ok()) {
      // Run shape inference function for newly added node.
      status->status = desc->graph->refiner.AddNode(ret);
    }
    if (status->status.ok()) {
      // Add the node to the name-to-node mapping.
      desc->graph->name_map[ret->name()] = ret;
    } else if (ret != nullptr) {
      // Shape inference failed after the node was created: remove it again
      // so the graph and name_map stay consistent.
      desc->graph->graph.RemoveNode(ret);
      ret = nullptr;
    }
  }

  delete desc;

  return ToOperation(ret);
}
1439 | |
1440 | TF_Operation* TF_FinishOperation(TF_OperationDescription* desc, |
1441 | TF_Status* status) { |
1442 | mutex_lock l(desc->graph->mu); |
1443 | return TF_FinishOperationLocked(desc, status); |
1444 | } |
1445 | |
1446 | // TF_Operation functions |
1447 | // ---------------------------------------------------------- |
1448 | |
1449 | const char* TF_OperationName(TF_Operation* oper) { |
1450 | return oper->node.name().c_str(); |
1451 | } |
1452 | |
1453 | const char* TF_OperationOpType(TF_Operation* oper) { |
1454 | return oper->node.type_string().c_str(); |
1455 | } |
1456 | |
1457 | const char* TF_OperationDevice(TF_Operation* oper) { |
1458 | return oper->node.requested_device().c_str(); |
1459 | } |
1460 | |
1461 | int TF_OperationNumOutputs(TF_Operation* oper) { |
1462 | return oper->node.num_outputs(); |
1463 | } |
1464 | |
1465 | TF_DataType TF_OperationOutputType(TF_Output oper_out) { |
1466 | return static_cast<TF_DataType>( |
1467 | oper_out.oper->node.output_type(oper_out.index)); |
1468 | } |
1469 | |
1470 | int TF_OperationOutputListLength(TF_Operation* oper, const char* arg_name, |
1471 | TF_Status* status) { |
1472 | NameRangeMap name_ranges; |
1473 | status->status = |
1474 | NameRangesForNode(oper->node, oper->node.op_def(), nullptr, &name_ranges); |
1475 | if (!status->status.ok()) return -1; |
1476 | auto iter = name_ranges.find(arg_name); |
1477 | if (iter == name_ranges.end()) { |
1478 | status->status = InvalidArgument("Input arg '" , arg_name, "' not found" ); |
1479 | return -1; |
1480 | } |
1481 | return iter->second.second - iter->second.first; |
1482 | } |
1483 | |
1484 | int TF_OperationNumInputs(TF_Operation* oper) { |
1485 | return oper->node.num_inputs(); |
1486 | } |
1487 | |
1488 | TF_DataType TF_OperationInputType(TF_Input oper_in) { |
1489 | return static_cast<TF_DataType>(oper_in.oper->node.input_type(oper_in.index)); |
1490 | } |
1491 | |
1492 | int TF_OperationInputListLength(TF_Operation* oper, const char* arg_name, |
1493 | TF_Status* status) { |
1494 | NameRangeMap name_ranges; |
1495 | status->status = |
1496 | NameRangesForNode(oper->node, oper->node.op_def(), &name_ranges, nullptr); |
1497 | if (!status->status.ok()) return -1; |
1498 | auto iter = name_ranges.find(arg_name); |
1499 | if (iter == name_ranges.end()) { |
1500 | status->status = InvalidArgument("Input arg '" , arg_name, "' not found" ); |
1501 | return -1; |
1502 | } |
1503 | return iter->second.second - iter->second.first; |
1504 | } |
1505 | |
1506 | TF_Output TF_OperationInput(TF_Input oper_in) { |
1507 | const tensorflow::Edge* edge; |
1508 | Status s = oper_in.oper->node.input_edge(oper_in.index, &edge); |
1509 | if (!s.ok()) { |
1510 | return {nullptr, -1}; |
1511 | } |
1512 | |
1513 | return {ToOperation(edge->src()), edge->src_output()}; |
1514 | } |
1515 | |
1516 | int TF_OperationOutputNumConsumers(TF_Output oper_out) { |
1517 | int count = 0; |
1518 | for (const auto* edge : oper_out.oper->node.out_edges()) { |
1519 | if (edge->src_output() == oper_out.index) { |
1520 | ++count; |
1521 | } |
1522 | } |
1523 | return count; |
1524 | } |
1525 | |
1526 | int TF_OperationOutputConsumers(TF_Output oper_out, TF_Input* consumers, |
1527 | int max_consumers) { |
1528 | int count = 0; |
1529 | for (const auto* edge : oper_out.oper->node.out_edges()) { |
1530 | if (edge->src_output() == oper_out.index) { |
1531 | if (count < max_consumers) { |
1532 | consumers[count] = {ToOperation(edge->dst()), edge->dst_input()}; |
1533 | } |
1534 | ++count; |
1535 | } |
1536 | } |
1537 | return count; |
1538 | } |
1539 | |
1540 | int TF_OperationNumControlInputs(TF_Operation* oper) { |
1541 | int count = 0; |
1542 | for (const auto* edge : oper->node.in_edges()) { |
1543 | if (edge->IsControlEdge() && !edge->src()->IsSource()) { |
1544 | ++count; |
1545 | } |
1546 | } |
1547 | return count; |
1548 | } |
1549 | |
1550 | int TF_OperationGetControlInputs(TF_Operation* oper, |
1551 | TF_Operation** control_inputs, |
1552 | int max_control_inputs) { |
1553 | int count = 0; |
1554 | for (const auto* edge : oper->node.in_edges()) { |
1555 | if (edge->IsControlEdge() && !edge->src()->IsSource()) { |
1556 | if (count < max_control_inputs) { |
1557 | control_inputs[count] = ToOperation(edge->src()); |
1558 | } |
1559 | ++count; |
1560 | } |
1561 | } |
1562 | return count; |
1563 | } |
1564 | |
1565 | int TF_OperationNumControlOutputs(TF_Operation* oper) { |
1566 | int count = 0; |
1567 | for (const auto* edge : oper->node.out_edges()) { |
1568 | if (edge->IsControlEdge() && !edge->dst()->IsSink()) { |
1569 | ++count; |
1570 | } |
1571 | } |
1572 | return count; |
1573 | } |
1574 | |
1575 | int TF_OperationGetControlOutputs(TF_Operation* oper, |
1576 | TF_Operation** control_outputs, |
1577 | int max_control_outputs) { |
1578 | int count = 0; |
1579 | for (const auto* edge : oper->node.out_edges()) { |
1580 | if (edge->IsControlEdge() && !edge->dst()->IsSink()) { |
1581 | if (count < max_control_outputs) { |
1582 | control_outputs[count] = ToOperation(edge->dst()); |
1583 | } |
1584 | ++count; |
1585 | } |
1586 | } |
1587 | return count; |
1588 | } |
1589 | |
1590 | TF_AttrMetadata TF_OperationGetAttrMetadata(TF_Operation* oper, |
1591 | const char* attr_name, |
1592 | TF_Status* status) { |
1593 | TF_AttrMetadata metadata; |
1594 | const auto* attr = GetAttrValue(oper, attr_name, status); |
1595 | if (!status->status.ok()) return metadata; |
1596 | switch (attr->value_case()) { |
1597 | #define SINGLE_CASE(kK, attr_type, size_expr) \ |
1598 | case tensorflow::AttrValue::kK: \ |
1599 | metadata.is_list = 0; \ |
1600 | metadata.list_size = -1; \ |
1601 | metadata.type = attr_type; \ |
1602 | metadata.total_size = size_expr; \ |
1603 | break; |
1604 | |
1605 | SINGLE_CASE(kS, TF_ATTR_STRING, attr->s().length()); |
1606 | SINGLE_CASE(kI, TF_ATTR_INT, -1); |
1607 | SINGLE_CASE(kF, TF_ATTR_FLOAT, -1); |
1608 | SINGLE_CASE(kB, TF_ATTR_BOOL, -1); |
1609 | SINGLE_CASE(kType, TF_ATTR_TYPE, -1); |
1610 | SINGLE_CASE(kShape, TF_ATTR_SHAPE, |
1611 | attr->shape().unknown_rank() ? -1 : attr->shape().dim_size()); |
1612 | SINGLE_CASE(kTensor, TF_ATTR_TENSOR, -1); |
1613 | #undef SINGLE_CASE |
1614 | |
1615 | case tensorflow::AttrValue::kList: |
1616 | metadata.is_list = 1; |
1617 | metadata.list_size = 0; |
1618 | metadata.total_size = -1; |
1619 | #define LIST_CASE(field, attr_type, ...) \ |
1620 | if (attr->list().field##_size() > 0) { \ |
1621 | metadata.type = attr_type; \ |
1622 | metadata.list_size = attr->list().field##_size(); \ |
1623 | __VA_ARGS__; \ |
1624 | break; \ |
1625 | } |
1626 | |
1627 | LIST_CASE(s, TF_ATTR_STRING, metadata.total_size = 0; |
1628 | for (int i = 0; i < attr->list().s_size(); |
1629 | ++i) { metadata.total_size += attr->list().s(i).size(); }); |
1630 | LIST_CASE(i, TF_ATTR_INT); |
1631 | LIST_CASE(f, TF_ATTR_FLOAT); |
1632 | LIST_CASE(b, TF_ATTR_BOOL); |
1633 | LIST_CASE(type, TF_ATTR_TYPE); |
1634 | LIST_CASE(shape, TF_ATTR_SHAPE, metadata.total_size = 0; |
1635 | for (int i = 0; i < attr->list().shape_size(); ++i) { |
1636 | const auto& s = attr->list().shape(i); |
1637 | metadata.total_size += s.unknown_rank() ? 0 : s.dim_size(); |
1638 | }); |
1639 | LIST_CASE(tensor, TF_ATTR_TENSOR); |
1640 | LIST_CASE(tensor, TF_ATTR_FUNC); |
1641 | #undef LIST_CASE |
1642 | // All lists empty, determine the type from the OpDef. |
1643 | if (metadata.list_size == 0) { |
1644 | for (int i = 0; i < oper->node.op_def().attr_size(); ++i) { |
1645 | const auto& a = oper->node.op_def().attr(i); |
1646 | if (a.name().compare(attr_name) != 0) continue; |
1647 | const string& typestr = a.type(); |
1648 | if (typestr == "list(string)" ) { |
1649 | metadata.type = TF_ATTR_STRING; |
1650 | } else if (typestr == "list(int)" ) { |
1651 | metadata.type = TF_ATTR_INT; |
1652 | } else if (typestr == "list(float)" ) { |
1653 | metadata.type = TF_ATTR_FLOAT; |
1654 | } else if (typestr == "list(bool)" ) { |
1655 | metadata.type = TF_ATTR_BOOL; |
1656 | } else if (typestr == "list(type)" ) { |
1657 | metadata.type = TF_ATTR_TYPE; |
1658 | } else if (typestr == "list(shape)" ) { |
1659 | metadata.type = TF_ATTR_SHAPE; |
1660 | } else if (typestr == "list(tensor)" ) { |
1661 | metadata.type = TF_ATTR_TENSOR; |
1662 | } else if (typestr == "list(func)" ) { |
1663 | metadata.type = TF_ATTR_FUNC; |
1664 | } else { |
1665 | status->status = InvalidArgument( |
1666 | "Attribute '" , attr_name, |
1667 | "' has an empty value of an unrecognized type '" , typestr, "'" ); |
1668 | return metadata; |
1669 | } |
1670 | } |
1671 | } |
1672 | break; |
1673 | |
1674 | case tensorflow::AttrValue::kPlaceholder: |
1675 | metadata.is_list = 0; |
1676 | metadata.list_size = -1; |
1677 | metadata.type = TF_ATTR_PLACEHOLDER; |
1678 | metadata.total_size = -1; |
1679 | break; |
1680 | |
1681 | case tensorflow::AttrValue::kFunc: |
1682 | metadata.is_list = 0; |
1683 | metadata.list_size = -1; |
1684 | metadata.type = TF_ATTR_FUNC; |
1685 | metadata.total_size = -1; |
1686 | break; |
1687 | |
1688 | case tensorflow::AttrValue::VALUE_NOT_SET: |
1689 | status->status = |
1690 | InvalidArgument("Attribute '" , attr_name, "' has no value set" ); |
1691 | break; |
1692 | } |
1693 | return metadata; |
1694 | } |
1695 | |
1696 | void TF_OperationGetAttrString(TF_Operation* oper, const char* attr_name, |
1697 | void* value, size_t max_length, |
1698 | TF_Status* status) { |
1699 | const auto* attr = GetAttrValue(oper, attr_name, status); |
1700 | if (!status->status.ok()) return; |
1701 | if (attr->value_case() != tensorflow::AttrValue::kS) { |
1702 | status->status = |
1703 | InvalidArgument("Attribute '" , attr_name, "' is not a string" ); |
1704 | return; |
1705 | } |
1706 | if (max_length <= 0) { |
1707 | return; |
1708 | } |
1709 | const auto& s = attr->s(); |
1710 | std::memcpy(value, s.data(), std::min<size_t>(s.length(), max_length)); |
1711 | } |
1712 | |
1713 | void TF_OperationGetAttrStringList(TF_Operation* oper, const char* attr_name, |
1714 | void** values, size_t* lengths, |
1715 | int max_values, void* storage, |
1716 | size_t storage_size, TF_Status* status) { |
1717 | const auto* attr = GetAttrValue(oper, attr_name, status); |
1718 | if (!status->status.ok()) return; |
1719 | if (attr->value_case() != tensorflow::AttrValue::kList) { |
1720 | status->status = |
1721 | InvalidArgument("Value for '" , attr_name, "' is not a list" ); |
1722 | return; |
1723 | } |
1724 | const auto len = std::min(max_values, attr->list().s_size()); |
1725 | char* p = static_cast<char*>(storage); |
1726 | for (int i = 0; i < len; ++i) { |
1727 | const string& s = attr->list().s(i); |
1728 | values[i] = p; |
1729 | lengths[i] = s.size(); |
1730 | if ((p + s.size()) > (static_cast<char*>(storage) + storage_size)) { |
1731 | status->status = InvalidArgument( |
1732 | "Not enough storage to hold the requested list of strings" ); |
1733 | return; |
1734 | } |
1735 | memcpy(values[i], s.data(), s.size()); |
1736 | p += s.size(); |
1737 | } |
1738 | } |
1739 | |
// Defines a pair of attr getters:
//   `func`       — reads a scalar attr of C++ type `cpp_type` and stores it
//                  into `*value` after casting to the C-API type `c_type`.
//   `func##List` — reads a list attr, copying up to `max_values` elements
//                  (cast to `c_type`) into `values`; `list_field` names the
//                  AttrValue.ListValue field holding the elements.
#define DEFINE_GETATTR(func, c_type, cpp_type, list_field)                   \
  void func(TF_Operation* oper, const char* attr_name, c_type* value,        \
            TF_Status* status) {                                             \
    cpp_type v;                                                              \
    status->status =                                                         \
        tensorflow::GetNodeAttr(oper->node.attrs(), attr_name, &v);          \
    *value = static_cast<c_type>(v);                                         \
  }                                                                          \
  void func##List(TF_Operation* oper, const char* attr_name, c_type* values, \
                  int max_values, TF_Status* status) {                       \
    const auto* attr = GetAttrValue(oper, attr_name, status);                \
    if (!status->status.ok()) return;                                        \
    if (attr->value_case() != tensorflow::AttrValue::kList) {                \
      status->status =                                                       \
          InvalidArgument("Value for '", attr_name, "' is not a list.");     \
      return;                                                                \
    }                                                                        \
    const auto len = std::min(max_values, attr->list().list_field##_size()); \
    for (int i = 0; i < len; ++i) {                                          \
      values[i] = static_cast<c_type>(attr->list().list_field(i));           \
    }                                                                        \
  }
// Instantiations: scalar and list getters for int64, float, bool and dtype
// attrs.
DEFINE_GETATTR(TF_OperationGetAttrInt, int64_t, tensorflow::int64, i);
DEFINE_GETATTR(TF_OperationGetAttrFloat, float, float, f);
DEFINE_GETATTR(TF_OperationGetAttrBool, unsigned char, bool, b);
DEFINE_GETATTR(TF_OperationGetAttrType, TF_DataType, DataType, type);
#undef DEFINE_GETATTR
1767 | |
1768 | void TF_OperationGetAttrShape(TF_Operation* oper, const char* attr_name, |
1769 | int64_t* value, int num_dims, TF_Status* status) { |
1770 | PartialTensorShape shape; |
1771 | status->status = |
1772 | tensorflow::GetNodeAttr(oper->node.attrs(), attr_name, &shape); |
1773 | if (!status->status.ok()) return; |
1774 | auto len = std::min(shape.dims(), num_dims); |
1775 | for (int i = 0; i < len; ++i) { |
1776 | value[i] = shape.dim_size(i); |
1777 | } |
1778 | } |
1779 | |
// Reads a list(shape) attr. For each of the first `max_values` shapes,
// `num_dims[i]` receives its rank (-1 for unknown rank) and `values[i]`
// points into the caller-provided `storage` array at that shape's dimension
// sizes. `storage_size` is measured in int64 elements; the call fails if the
// dims don't fit.
void TF_OperationGetAttrShapeList(TF_Operation* oper, const char* attr_name,
                                  int64_t** values, int* num_dims,
                                  int max_values, int64_t* storage,
                                  int storage_size, TF_Status* status) {
  std::vector<PartialTensorShape> shapes;
  status->status =
      tensorflow::GetNodeAttr(oper->node.attrs(), attr_name, &shapes);
  if (!status->status.ok()) return;
  auto len = std::min(static_cast<int>(shapes.size()), max_values);
  int64_t* p = storage;
  int storage_left = storage_size;
  for (int i = 0; i < len; ++i) {
    // shapes[i].dims() == -1 for shapes with an unknown rank.
    int64_t n = shapes[i].dims();
    num_dims[i] = n;
    values[i] = p;
    if (n < 0) {
      // Unknown rank: no dims are written, and `values[i]` points at the
      // current cursor (zero usable elements).
      continue;
    }
    if (storage_left < n) {
      status->status = InvalidArgument(
          "Not enough storage to hold the requested list of shapes" );
      return;
    }
    storage_left -= n;
    for (int j = 0; j < n; ++j, ++p) {
      *p = shapes[i].dim_size(j);
    }
  }
}
1810 | |
1811 | void TF_OperationGetAttrTensorShapeProto(TF_Operation* oper, |
1812 | const char* attr_name, |
1813 | TF_Buffer* value, TF_Status* status) { |
1814 | const auto* attr = GetAttrValue(oper, attr_name, status); |
1815 | if (!status->status.ok()) return; |
1816 | if (attr->value_case() != tensorflow::AttrValue::kShape) { |
1817 | status->status = |
1818 | InvalidArgument("Value for '" , attr_name, "' is not a shape." ); |
1819 | return; |
1820 | } |
1821 | status->status = MessageToBuffer(attr->shape(), value); |
1822 | } |
1823 | |
// Serializes up to `max_values` TensorShapeProtos from a list(shape) attr
// into freshly allocated TF_Buffers, which the caller owns on success. On
// failure all buffers allocated by this call are freed before returning.
void TF_OperationGetAttrTensorShapeProtoList(TF_Operation* oper,
                                             const char* attr_name,
                                             TF_Buffer** values, int max_values,
                                             TF_Status* status) {
  const auto* attr = GetAttrValue(oper, attr_name, status);
  if (!status->status.ok()) return;
  if (attr->value_case() != tensorflow::AttrValue::kList) {
    status->status =
        InvalidArgument("Value for '" , attr_name, "' is not a list" );
    return;
  }
  const auto len = std::min(max_values, attr->list().shape_size());
  for (int i = 0; i < len; ++i) {
    values[i] = TF_NewBuffer();
    status->status = MessageToBuffer(attr->list().shape(i), values[i]);
    if (!status->status.ok()) {
      // Delete everything allocated so far, the operation has failed.
      for (int j = 0; j <= i; ++j) {
        TF_DeleteBuffer(values[j]);
      }
      return;
    }
  }
}
1848 | |
1849 | void TF_OperationGetAttrTensor(TF_Operation* oper, const char* attr_name, |
1850 | TF_Tensor** value, TF_Status* status) { |
1851 | *value = nullptr; |
1852 | Tensor t; |
1853 | status->status = tensorflow::GetNodeAttr(oper->node.attrs(), attr_name, &t); |
1854 | if (!status->status.ok()) return; |
1855 | *value = TF_TensorFromTensor(t, status); |
1856 | } |
1857 | |
1858 | void TF_OperationGetAttrTensorList(TF_Operation* oper, const char* attr_name, |
1859 | TF_Tensor** values, int max_values, |
1860 | TF_Status* status) { |
1861 | std::vector<Tensor> ts; |
1862 | status->status = tensorflow::GetNodeAttr(oper->node.attrs(), attr_name, &ts); |
1863 | if (!status->status.ok()) return; |
1864 | const auto len = std::min(max_values, static_cast<int>(ts.size())); |
1865 | for (int i = 0; i < len; ++i) { |
1866 | values[i] = TF_TensorFromTensor(ts[i], status); |
1867 | } |
1868 | } |
1869 | |
1870 | void TF_OperationGetAttrValueProto(TF_Operation* oper, const char* attr_name, |
1871 | TF_Buffer* output_attr_value, |
1872 | TF_Status* status) { |
1873 | const auto* attr = GetAttrValue(oper, attr_name, status); |
1874 | if (!status->status.ok()) return; |
1875 | status->status = MessageToBuffer(*attr, output_attr_value); |
1876 | } |
1877 | |
// Serializes the operation's NodeDef proto into `output_node_def`.
void TF_OperationToNodeDef(TF_Operation* oper, TF_Buffer* output_node_def,
                           TF_Status* status) {
  status->status = MessageToBuffer(oper->node.def(), output_node_def);
}
1882 | |
1883 | // TF_Graph functions --------------------------------------------------------- |
1884 | |
// Builds an empty graph backed by the global op registry; the shape refiner
// is seeded with the graph's producer version. `parent`/`parent_inputs` stay
// null except for cond/body subgraphs created by TF_NewWhile().
TF_Graph::TF_Graph()
    : graph(tensorflow::OpRegistry::Global()),
      refiner(graph.versions().producer(), graph.op_registry()),
      delete_requested(false),
      parent(nullptr),
      parent_inputs(nullptr) {}
1891 | |
1892 | TF_Graph* TF_NewGraph() { return new TF_Graph; } |
1893 | |
1894 | void TF_DeleteGraph(TF_Graph* g) { |
1895 | g->mu.lock(); |
1896 | g->delete_requested = true; |
1897 | const bool del = g->sessions.empty(); |
1898 | g->mu.unlock(); |
1899 | if (del) delete g; |
1900 | } |
1901 | |
1902 | TF_Operation* TF_GraphOperationByName(TF_Graph* graph, const char* oper_name) { |
1903 | mutex_lock l(graph->mu); |
1904 | auto iter = graph->name_map.find(oper_name); |
1905 | if (iter == graph->name_map.end()) { |
1906 | return nullptr; |
1907 | } else { |
1908 | return ToOperation(iter->second); |
1909 | } |
1910 | } |
1911 | |
1912 | TF_Operation* TF_GraphNextOperation(TF_Graph* graph, size_t* pos) { |
1913 | if (*pos == 0) { |
1914 | // Advance past the first sentinel nodes in every graph (the source & sink). |
1915 | *pos += 2; |
1916 | } else { |
1917 | // Advance to the next node. |
1918 | *pos += 1; |
1919 | } |
1920 | |
1921 | mutex_lock l(graph->mu); |
1922 | while (*pos < static_cast<size_t>(graph->graph.num_node_ids())) { |
1923 | Node* node = graph->graph.FindNodeId(*pos); |
1924 | // FindNodeId() returns nullptr for nodes that have been deleted. |
1925 | // We aren't currently allowing nodes to be deleted, but it is safer |
1926 | // to still check. |
1927 | if (node != nullptr) return ToOperation(node); |
1928 | *pos += 1; |
1929 | } |
1930 | |
1931 | // No more nodes. |
1932 | return nullptr; |
1933 | } |
1934 | |
// Snapshots the graph into a GraphDef while holding the lock, then
// serializes it into `output_graph_def` outside the lock.
void TF_GraphToGraphDef(TF_Graph* graph, TF_Buffer* output_graph_def,
                        TF_Status* status) {
  GraphDef def;
  {
    mutex_lock l(graph->mu);
    graph->graph.ToGraphDef(&def);
  }
  status->status = MessageToBuffer(def, output_graph_def);
}
1944 | |
// Serializes the OpDef registered under `op_name` (from the graph's op
// registry) into `output_op_def`.
void TF_GraphGetOpDef(TF_Graph* graph, const char* op_name,
                      TF_Buffer* output_op_def, TF_Status* status) {
  const OpDef* op_def;
  {
    mutex_lock l(graph->mu);
    status->status = graph->graph.op_registry()->LookUpOpDef(op_name, &op_def);
    if (!status->status.ok()) return;
  }
  // The registry owns `op_def`; serialization happens outside the lock.
  status->status = MessageToBuffer(*op_def, output_op_def);
}
1955 | |
// Serializes the graph's VersionDef (producer/min-consumer versions) into
// `output_version_def`; copied under the lock, serialized outside it.
void TF_GraphVersions(TF_Graph* graph, TF_Buffer* output_version_def,
                      TF_Status* status) {
  VersionDef versions;
  {
    mutex_lock l(graph->mu);
    versions = graph->graph.versions();
  }
  status->status = MessageToBuffer(versions, output_version_def);
}
1965 | |
// Caller owns the result; release with TF_DeleteImportGraphDefOptions().
TF_ImportGraphDefOptions* TF_NewImportGraphDefOptions() {
  return new TF_ImportGraphDefOptions;
}
// Frees options created by TF_NewImportGraphDefOptions().
void TF_DeleteImportGraphDefOptions(TF_ImportGraphDefOptions* opts) {
  delete opts;
}
// Sets the name prefix prepended to all imported node names. The string is
// copied into the underlying options (string assignment).
void TF_ImportGraphDefOptionsSetPrefix(TF_ImportGraphDefOptions* opts,
                                       const char* prefix) {
  opts->opts.prefix = prefix;
}
1976 | |
// Nonzero `uniquify_names` makes the importer rename imported nodes that
// collide with existing names.
void TF_ImportGraphDefOptionsSetUniquifyNames(TF_ImportGraphDefOptions* opts,
                                              unsigned char uniquify_names) {
  opts->opts.uniquify_names = uniquify_names;
}
1981 | |
// Nonzero `uniquify_prefix` makes the importer uniquify the prefix itself if
// it collides with an existing name.
void TF_ImportGraphDefOptionsSetUniquifyPrefix(TF_ImportGraphDefOptions* opts,
                                               unsigned char uniquify_prefix) {
  opts->opts.uniquify_prefix = uniquify_prefix;
}
1986 | |
// Maps output `src_index` of node `src_name` inside the GraphDef being
// imported to the existing graph tensor `dst`.
void TF_ImportGraphDefOptionsAddInputMapping(TF_ImportGraphDefOptions* opts,
                                             const char* src_name,
                                             int src_index, TF_Output dst) {
  // Copy src_name into tensor_id_data: the TensorId key only references the
  // string, so the backing storage must outlive the ImportGraphDef call.
  opts->tensor_id_data.push_back(src_name);
  const string& src_name_str = opts->tensor_id_data.back();
  // We don't need to store dst's name in tensor_id_data, since `dst` must
  // outlive the ImportGraphDef call.
  opts->opts.input_map[TensorId(src_name_str, src_index)] = ToTensorId(dst);
}
1996 | |
1997 | void TF_ImportGraphDefOptionsRemapControlDependency( |
1998 | TF_ImportGraphDefOptions* opts, const char* src_name, TF_Operation* dst) { |
1999 | opts->opts.input_map[TensorId(src_name, tensorflow::Graph::kControlSlot)] = |
2000 | TensorId(dst->node.name(), tensorflow::Graph::kControlSlot); |
2001 | } |
2002 | |
2003 | extern void TF_ImportGraphDefOptionsAddControlDependency( |
2004 | TF_ImportGraphDefOptions* opts, TF_Operation* oper) { |
2005 | opts->opts.control_dependencies.push_back(oper->node.name()); |
2006 | } |
2007 | |
// Requests that output `index` of imported node `oper_name` be surfaced in
// the import results (see TF_ImportGraphDefResultsReturnOutputs).
void TF_ImportGraphDefOptionsAddReturnOutput(TF_ImportGraphDefOptions* opts,
                                             const char* oper_name, int index) {
  // Copy the name into tensor_id_data so the TensorId's reference stays
  // valid until the ImportGraphDef call.
  opts->tensor_id_data.push_back(oper_name);
  const string& oper_name_str = opts->tensor_id_data.back();
  opts->opts.return_tensors.emplace_back(oper_name_str, index);
}
2014 | |
// Number of return outputs registered via
// TF_ImportGraphDefOptionsAddReturnOutput().
int TF_ImportGraphDefOptionsNumReturnOutputs(
    const TF_ImportGraphDefOptions* opts) {
  return opts->opts.return_tensors.size();
}
2019 | |
// Requests that imported node `oper_name` be surfaced in the import results
// (see TF_ImportGraphDefResultsReturnOperations). The name is copied.
void TF_ImportGraphDefOptionsAddReturnOperation(TF_ImportGraphDefOptions* opts,
                                                const char* oper_name) {
  opts->opts.return_nodes.push_back(oper_name);
}
2024 | |
// Number of return operations registered via
// TF_ImportGraphDefOptionsAddReturnOperation().
int TF_ImportGraphDefOptionsNumReturnOperations(
    const TF_ImportGraphDefOptions* opts) {
  return opts->opts.return_nodes.size();
}
2029 | |
// Exposes the requested return outputs. The array is owned by `results` and
// stays valid until TF_DeleteImportGraphDefResults().
void TF_ImportGraphDefResultsReturnOutputs(TF_ImportGraphDefResults* results,
                                           int* num_outputs,
                                           TF_Output** outputs) {
  *num_outputs = results->return_tensors.size();
  *outputs = results->return_tensors.data();
}
2036 | |
// Exposes the requested return operations. The array is owned by `results`
// and stays valid until TF_DeleteImportGraphDefResults().
void TF_ImportGraphDefResultsReturnOperations(TF_ImportGraphDefResults* results,
                                              int* num_opers,
                                              TF_Operation*** opers) {
  *num_opers = results->return_nodes.size();
  *opers = results->return_nodes.data();
}
2043 | |
// Exposes input mappings that referenced nodes/outputs absent from the
// imported GraphDef. The parallel name/index arrays are owned by `results`
// and stay valid until TF_DeleteImportGraphDefResults().
void TF_ImportGraphDefResultsMissingUnusedInputMappings(
    TF_ImportGraphDefResults* results, int* num_missing_unused_input_mappings,
    const char*** src_names, int** src_indexes) {
  *num_missing_unused_input_mappings = results->missing_unused_key_names.size();
  *src_names = results->missing_unused_key_names.data();
  *src_indexes = results->missing_unused_key_indexes.data();
}
2051 | |
// Frees results returned by TF_GraphImportGraphDefWithResults(). Deleting
// nullptr is a no-op.
void TF_DeleteImportGraphDefResults(TF_ImportGraphDefResults* results) {
  delete results;
}
2055 | |
2056 | static void GraphImportGraphDefLocked(TF_Graph* graph, const GraphDef& def, |
2057 | const TF_ImportGraphDefOptions* opts, |
2058 | TF_ImportGraphDefResults* tf_results, |
2059 | TF_Status* status) |
2060 | EXCLUSIVE_LOCKS_REQUIRED(graph->mu) { |
2061 | const int last_node_id = graph->graph.num_node_ids(); |
2062 | tensorflow::ImportGraphDefResults results; |
2063 | status->status = tensorflow::ImportGraphDef(opts->opts, def, &graph->graph, |
2064 | &graph->refiner, &results); |
2065 | if (!status->status.ok()) return; |
2066 | |
2067 | // Add new nodes to name_map |
2068 | for (int i = last_node_id; i < graph->graph.num_node_ids(); ++i) { |
2069 | auto* node = graph->graph.FindNodeId(i); |
2070 | if (node != nullptr) graph->name_map[node->name()] = node; |
2071 | } |
2072 | |
2073 | // Populate return_tensors |
2074 | DCHECK(tf_results->return_tensors.empty()); |
2075 | tf_results->return_tensors.resize(results.return_tensors.size()); |
2076 | for (int i = 0; i < results.return_tensors.size(); ++i) { |
2077 | tf_results->return_tensors[i].oper = |
2078 | ToOperation(results.return_tensors[i].first); |
2079 | tf_results->return_tensors[i].index = results.return_tensors[i].second; |
2080 | } |
2081 | |
2082 | // Populate return_nodes |
2083 | DCHECK(tf_results->return_nodes.empty()); |
2084 | tf_results->return_nodes.resize(results.return_nodes.size()); |
2085 | for (int i = 0; i < results.return_nodes.size(); ++i) { |
2086 | tf_results->return_nodes[i] = ToOperation(results.return_nodes[i]); |
2087 | } |
2088 | |
2089 | // Populate missing unused map keys |
2090 | DCHECK(tf_results->missing_unused_key_names.empty()); |
2091 | DCHECK(tf_results->missing_unused_key_indexes.empty()); |
2092 | DCHECK(tf_results->missing_unused_key_names_data.empty()); |
2093 | |
2094 | size_t size = results.missing_unused_input_map_keys.size(); |
2095 | tf_results->missing_unused_key_names.resize(size); |
2096 | tf_results->missing_unused_key_indexes.resize(size); |
2097 | |
2098 | for (int i = 0; i < size; ++i) { |
2099 | TensorId id = results.missing_unused_input_map_keys[i]; |
2100 | tf_results->missing_unused_key_names_data.push_back(id.first.ToString()); |
2101 | tf_results->missing_unused_key_names[i] = |
2102 | tf_results->missing_unused_key_names_data.back().c_str(); |
2103 | tf_results->missing_unused_key_indexes[i] = id.second; |
2104 | } |
2105 | } |
2106 | |
2107 | TF_ImportGraphDefResults* TF_GraphImportGraphDefWithResults( |
2108 | TF_Graph* graph, const TF_Buffer* graph_def, |
2109 | const TF_ImportGraphDefOptions* options, TF_Status* status) { |
2110 | GraphDef def; |
2111 | if (!def.ParseFromArray(graph_def->data, graph_def->length)) { |
2112 | status->status = InvalidArgument("Invalid GraphDef" ); |
2113 | return nullptr; |
2114 | } |
2115 | auto results = new TF_ImportGraphDefResults(); |
2116 | mutex_lock l(graph->mu); |
2117 | GraphImportGraphDefLocked(graph, def, options, results, status); |
2118 | if (!status->status.ok()) { |
2119 | delete results; |
2120 | return nullptr; |
2121 | } |
2122 | return results; |
2123 | } |
2124 | |
2125 | void TF_GraphImportGraphDefWithReturnOutputs( |
2126 | TF_Graph* graph, const TF_Buffer* graph_def, |
2127 | const TF_ImportGraphDefOptions* options, TF_Output* return_outputs, |
2128 | int num_return_outputs, TF_Status* status) { |
2129 | if (num_return_outputs != options->opts.return_tensors.size()) { |
2130 | status->status = InvalidArgument("Expected 'num_return_outputs' to be " , |
2131 | options->opts.return_tensors.size(), |
2132 | ", got " , num_return_outputs); |
2133 | return; |
2134 | } |
2135 | if (num_return_outputs > 0 && return_outputs == nullptr) { |
2136 | status->status = InvalidArgument( |
2137 | "'return_outputs' must be preallocated to length " , num_return_outputs); |
2138 | return; |
2139 | } |
2140 | GraphDef def; |
2141 | if (!def.ParseFromArray(graph_def->data, graph_def->length)) { |
2142 | status->status = InvalidArgument("Invalid GraphDef" ); |
2143 | return; |
2144 | } |
2145 | TF_ImportGraphDefResults results; |
2146 | mutex_lock l(graph->mu); |
2147 | GraphImportGraphDefLocked(graph, def, options, &results, status); |
2148 | DCHECK_EQ(results.return_tensors.size(), num_return_outputs); |
2149 | memcpy(return_outputs, results.return_tensors.data(), |
2150 | num_return_outputs * sizeof(TF_Output)); |
2151 | } |
2152 | |
// Convenience wrapper that imports `graph_def` and discards the results
// object. On failure TF_GraphImportGraphDefWithResults() returns nullptr and
// deleting nullptr is a no-op, so no extra error handling is needed here.
void TF_GraphImportGraphDef(TF_Graph* graph, const TF_Buffer* graph_def,
                            const TF_ImportGraphDefOptions* options,
                            TF_Status* status) {
  TF_ImportGraphDefResults* results =
      TF_GraphImportGraphDefWithResults(graph, graph_def, options, status);
  TF_DeleteImportGraphDefResults(results);
}
2160 | |
2161 | // While loop functions ------------------------------------------------------- |
2162 | |
2163 | namespace { |
2164 | |
2165 | #ifndef __ANDROID__ |
2166 | |
2167 | // Creates a placeholder representing an input to the cond or body graph. |
2168 | // TODO(skyewm): remove these from final graph |
2169 | bool CreateInput(const TF_Output& parent_input, TF_Graph* g, const char* name, |
2170 | TF_Output* input, TF_Status* status) { |
2171 | TF_OperationDescription* desc = TF_NewOperation(g, "Placeholder" , name); |
2172 | TF_SetAttrType(desc, "dtype" , TF_OperationOutputType(parent_input)); |
2173 | // TODO(skyewm): set placeholder shape |
2174 | TF_Operation* oper = TF_FinishOperation(desc, status); |
2175 | if (!status->status.ok()) return false; |
2176 | *input = {oper, 0}; |
2177 | return true; |
2178 | } |
2179 | |
2180 | // Copies `src_graph` into `dst_graph`. Any node in `src_graph` with input |
2181 | // `src_inputs[i]` will have that input replaced with `dst_inputs[i]`. `prefix` |
2182 | // will be prepended to copied node names. `control_deps` are nodes in |
2183 | // `dst_graph` that the copied `src_graph` nodes will have control dependencies |
2184 | // on. `return_nodes` are nodes in `src_graph`, and the new corresponding nodes |
2185 | // in `dst_graph` will be returned. `return_nodes` must be non-null. |
Status CopyGraph(Graph* src_graph, Graph* dst_graph,
                 tensorflow::ShapeRefiner* dst_refiner,
                 const TF_Output* src_inputs,
                 const std::vector<tensorflow::Output>& dst_inputs,
                 const string& prefix,
                 const std::vector<tensorflow::Operation>& control_deps,
                 const TF_Output* nodes_to_return, int nreturn_nodes,
                 std::vector<tensorflow::Output>* return_nodes) {
  DCHECK(return_nodes != nullptr);
  // The copy is done by serializing src_graph and re-importing it into
  // dst_graph with an input map rather than by cloning nodes directly.
  GraphDef gdef;
  src_graph->ToGraphDef(&gdef);

  tensorflow::ImportGraphDefOptions opts;
  opts.prefix = prefix;

  // Rewire each src-graph input to the corresponding dst-graph tensor, so
  // copied nodes consume `dst_inputs` instead of the original inputs.
  for (int i = 0; i < dst_inputs.size(); ++i) {
    opts.input_map[ToTensorId(src_inputs[i])] =
        TensorId(dst_inputs[i].node()->name(), dst_inputs[i].index());
  }
  // Skip copying nodes whose outputs are all mapped away (presumably the
  // input placeholders created by CreateInput()).
  opts.skip_mapped_nodes = true;

  for (const tensorflow::Operation& op : control_deps) {
    opts.control_dependencies.push_back(op.node()->name());
  }

  // Ask the importer for the dst-graph counterparts of `nodes_to_return`.
  for (int i = 0; i < nreturn_nodes; ++i) {
    opts.return_tensors.push_back(ToTensorId(nodes_to_return[i]));
  }

  // TODO(skyewm): change to OutputTensor
  tensorflow::ImportGraphDefResults results;
  TF_RETURN_IF_ERROR(
      ImportGraphDef(opts, gdef, dst_graph, dst_refiner, &results));

  for (const auto& pair : results.return_tensors) {
    return_nodes->emplace_back(pair.first, pair.second);
  }
  return Status::OK();
}
2225 | |
2226 | bool ValidateConstWhileParams(const TF_WhileParams& params, TF_Status* s) { |
2227 | if (params.cond_graph == nullptr || params.body_graph == nullptr || |
2228 | params.cond_graph->parent == nullptr || |
2229 | params.cond_graph->parent != params.body_graph->parent || |
2230 | params.cond_graph->parent_inputs != params.body_graph->parent_inputs || |
2231 | params.ninputs <= 0 || params.cond_inputs == nullptr || |
2232 | params.body_inputs == nullptr || params.body_outputs == nullptr) { |
2233 | s->status = InvalidArgument( |
2234 | "TF_WhileParams must be created by successful TF_NewWhile() call" ); |
2235 | return false; |
2236 | } |
2237 | return true; |
2238 | } |
2239 | |
2240 | bool ValidateInputWhileParams(const TF_WhileParams& params, TF_Status* s) { |
2241 | if (params.cond_output.oper == nullptr) { |
2242 | s->status = InvalidArgument("TF_WhileParams `cond_output` field isn't set" ); |
2243 | return false; |
2244 | } |
2245 | for (int i = 0; i < params.ninputs; ++i) { |
2246 | if (params.body_outputs[i].oper == nullptr) { |
2247 | s->status = InvalidArgument("TF_WhileParams `body_outputs[" , i, "]` " , |
2248 | "field isn't set" ); |
2249 | return false; |
2250 | } |
2251 | } |
2252 | if (params.name == nullptr) { |
2253 | s->status = InvalidArgument("TF_WhileParams `name` field is null" ); |
2254 | return false; |
2255 | } |
2256 | return true; |
2257 | } |
2258 | |
2259 | #endif // __ANDROID__ |
2260 | |
// Releases the cond/body subgraphs and the input/output arrays allocated by
// TF_NewWhile().
void FreeWhileResources(const TF_WhileParams* params) {
  TF_DeleteGraph(params->cond_graph);
  TF_DeleteGraph(params->body_graph);
  delete[] params->cond_inputs;
  delete[] params->body_inputs;
  delete[] params->body_outputs;
}
2268 | |
// Zeroed TF_WhileParams used as the error return value of TF_NewWhile().
TF_WhileParams EmptyWhileParams() {
  return {0, nullptr, nullptr, {nullptr, 0},
          nullptr, nullptr, nullptr, nullptr};
}
2273 | |
2274 | } // namespace |
2275 | |
// Allocates the cond/body subgraphs and per-input placeholders for a new
// while loop over `inputs`. On success the caller fills in cond_output,
// body_outputs and name, then calls TF_FinishWhile() (or TF_AbortWhile() to
// abandon), which frees the resources allocated here.
TF_WhileParams TF_NewWhile(TF_Graph* g, TF_Output* inputs, int ninputs,
                           TF_Status* status) {
#ifdef __ANDROID__
  status->status = tensorflow::errors::Unimplemented(
      "Creating while loops is not supported in Android. File a bug at "
      "https://github.com/tensorflow/tensorflow/issues if this feature is "
      "important to you" );
  return EmptyWhileParams();
#else
  if (ninputs == 0) {
    status->status =
        InvalidArgument("TF_NewWhile() must be passed at least one input" );
    return EmptyWhileParams();
  }

  // Link both subgraphs back to the parent graph and its inputs; these
  // fields are what ValidateConstWhileParams() later checks.
  TF_Graph* cond_graph = TF_NewGraph();
  TF_Graph* body_graph = TF_NewGraph();
  cond_graph->parent = g;
  cond_graph->parent_inputs = inputs;
  body_graph->parent = g;
  body_graph->parent_inputs = inputs;

  TF_Output* cond_inputs = new TF_Output[ninputs];
  TF_Output cond_output = {nullptr, -1};
  TF_Output* body_inputs = new TF_Output[ninputs];
  TF_Output* body_outputs = new TF_Output[ninputs];
  // body_outputs entries start unset; the caller must populate them before
  // TF_FinishWhile() (checked by ValidateInputWhileParams()).
  for (int i = 0; i < ninputs; ++i) body_outputs[i] = {nullptr, -1};
  const char* name = nullptr;

  // One placeholder per loop variable in each subgraph.
  for (int i = 0; i < ninputs; ++i) {
    // TODO(skyewm): prefix names with underscore (requires some plumbing)
    if (!CreateInput(inputs[i], cond_graph, StrCat("cond_input" , i).c_str(),
                     &cond_inputs[i], status)) {
      break;
    }
    if (!CreateInput(inputs[i], body_graph, StrCat("body_input" , i).c_str(),
                     &body_inputs[i], status)) {
      break;
    }
  }

  TF_WhileParams params = {ninputs, cond_graph, cond_inputs, cond_output,
                           body_graph, body_inputs, body_outputs, name};

  // If any placeholder creation failed, release everything allocated above.
  if (!status->status.ok()) {
    FreeWhileResources(&params);
    return EmptyWhileParams();
  }
  return params;
#endif  // __ANDROID__
}
2327 | |
2328 | #ifndef __ANDROID__ |
2329 | namespace { |
2330 | |
2331 | // TODO(skyewm): make nodes in while loop unfetchable like in Python version |
// Stitches the caller-built cond/body subgraphs into the parent graph via
// BuildWhileLoop(), writing the loop's outputs (one per input) to `outputs`.
// Does NOT free `params` — TF_FinishWhile() handles that.
void TF_FinishWhileHelper(const TF_WhileParams* params, TF_Status* status,
                          TF_Output* outputs) {
  if (!ValidateInputWhileParams(*params, status)) return;

  TF_Graph* parent = params->cond_graph->parent;
  TF_Output* parent_inputs = params->cond_graph->parent_inputs;
  int num_loop_vars = params->ninputs;

  // Held for the whole function: the scope below operates on the parent
  // graph's Graph/ShapeRefiner.
  mutex_lock l(parent->mu);

  // 'cond_fn' copies the cond graph into the parent graph.
  tensorflow::ops::CondGraphBuilderFn cond_fn =
      [params, parent](const tensorflow::Scope& scope,
                       const std::vector<tensorflow::Output>& inputs,
                       tensorflow::Output* output) {
        DCHECK_EQ(scope.graph(), &parent->graph);
        std::vector<tensorflow::Output> cond_output;
        TF_RETURN_IF_ERROR(CopyGraph(
            &params->cond_graph->graph, &parent->graph, &parent->refiner,
            params->cond_inputs, inputs, scope.impl()->name(),
            scope.impl()->control_deps(), &params->cond_output,
            /* nreturn_nodes */ 1, &cond_output));
        *output = cond_output[0];
        return Status::OK();
      };

  // 'body_fn' copies the body graph into the parent graph.
  tensorflow::ops::BodyGraphBuilderFn body_fn =
      [params, parent, num_loop_vars](
          const tensorflow::Scope& scope,
          const std::vector<tensorflow::Output>& inputs,
          std::vector<tensorflow::Output>* outputs) {
        DCHECK_EQ(scope.graph(), &parent->graph);
        TF_RETURN_IF_ERROR(
            CopyGraph(&params->body_graph->graph, &parent->graph,
                      &parent->refiner, params->body_inputs, inputs,
                      scope.impl()->name(), scope.impl()->control_deps(),
                      params->body_outputs, num_loop_vars, outputs));
        return Status::OK();
      };

  // Create the while loop using an internal scope.
  tensorflow::Scope scope =
      NewInternalScope(&parent->graph, &status->status, &parent->refiner)
          .NewSubScope(params->name);

  // Remember where the new nodes start so we can register them below.
  const int first_new_node_id = parent->graph.num_node_ids();

  tensorflow::OutputList loop_outputs;
  status->status = tensorflow::ops::BuildWhileLoop(
      scope, OutputsFromTFOutputs(parent_inputs, num_loop_vars), cond_fn,
      body_fn, params->name, &loop_outputs);

  // Update name_map with newly-created ops.
  // TODO(skyewm): right now BuildWhileLoop() may alter the graph if it returns
  // a bad status. Once we fix this, we may want to return early instead of
  // executing the following code.
  for (int i = first_new_node_id; i < parent->graph.num_node_ids(); ++i) {
    Node* new_node = parent->graph.FindNodeId(i);
    if (new_node == nullptr) continue;
    parent->name_map[new_node->name()] = new_node;
  }

  // Populate 'outputs'.
  DCHECK_LE(loop_outputs.size(), num_loop_vars);
  for (int i = 0; i < loop_outputs.size(); ++i) {
    outputs[i] = {ToOperation(loop_outputs[i].node()), loop_outputs[i].index()};
  }
}
2401 | |
2402 | } // namespace |
2403 | #endif // __ANDROID__ |
2404 | |
// Completes a while loop started with TF_NewWhile(): validates `params`,
// builds the loop in the parent graph, and frees the resources TF_NewWhile()
// allocated (even when loop construction itself fails).
void TF_FinishWhile(const TF_WhileParams* params, TF_Status* status,
                    TF_Output* outputs) {
#ifdef __ANDROID__
  status->status = tensorflow::errors::Unimplemented(
      "Creating while loops is not supported in Android. File a bug at "
      "https://github.com/tensorflow/tensorflow/issues if this feature is "
      "important to you" );
#else
  // If it appears the caller created or modified `params`, don't free resources
  if (!ValidateConstWhileParams(*params, status)) return;
  TF_FinishWhileHelper(params, status, outputs);
  FreeWhileResources(params);
#endif  // __ANDROID__
}
2419 | |
2420 | void TF_AbortWhile(const TF_WhileParams* params) { FreeWhileResources(params); } |
2421 | |
// Adds nodes to `g` computing symbolic partial derivatives of the `ny`
// outputs `y` with respect to the `nx` inputs `x`, writing the resulting
// gradient outputs into `dy` (which the caller must size for `nx` entries —
// TODO confirm against header contract). `dx`, if non-null, supplies `ny`
// initial gradient values, one per entry of `y`.
void TF_AddGradients(TF_Graph* g, TF_Output* y, int ny, TF_Output* x, int nx,
                     TF_Output* dx, TF_Status* status, TF_Output* dy) {
#ifdef __ANDROID__
  status->status = tensorflow::errors::Unimplemented(
      "Adding gradients is not supported in Android. File a bug at "
      "https://github.com/tensorflow/tensorflow/issues if this feature is "
      "important to you" );
#else
  std::vector<tensorflow::Output> y_arg = OutputsFromTFOutputs(y, ny);
  std::vector<tensorflow::Output> x_arg = OutputsFromTFOutputs(x, nx);
  std::vector<tensorflow::Output> dy_arg;

  {
    // We need to hold on to the lock while we have a scope that uses TF_Graph.
    mutex_lock graph_lock(g->mu);

    // Node ids at or above this value were created by the gradient builder
    // below and must be added to g->name_map afterwards.
    const int first_new_node_id = g->graph.num_node_ids();

    // Build gradient ops under a "gradients" name scope that shares the
    // graph's shape refiner, so new nodes receive shape information.
    tensorflow::Scope scope =
        NewInternalScope(&g->graph, &status->status, &g->refiner)
            .NewSubScope("gradients" );

    if (dx != nullptr) {
      // `dx` holds one initial gradient per entry of `y`, hence `ny` here.
      std::vector<tensorflow::Output> dx_arg = OutputsFromTFOutputs(dx, ny);
      status->status =
          AddSymbolicGradients(scope, y_arg, x_arg, dx_arg, &dy_arg);
    } else {
      status->status = AddSymbolicGradients(scope, y_arg, x_arg, &dy_arg);
    }

    // Update g->name_map with the name_map from the scope, which will contain
    // the new gradient ops.
    for (int i = first_new_node_id; i < g->graph.num_node_ids(); ++i) {
      Node* n = g->graph.FindNodeId(i);
      if (n == nullptr) continue;
      g->name_map[n->name()] = n;
    }
  }

  // Unpack the results from dy_arg into the caller's `dy` array.
  TFOutputsFromOutputs(dy_arg, dy);
#endif  // __ANDROID__
}
2465 | |
2466 | // TF_Session functions ---------------------------------------------- |
2467 | |
// Wraps an owned tensorflow::Session together with the (possibly null)
// TF_Graph it runs. `last_num_graph_nodes` starts at 0 so the first run can
// push the full graph; `extend_before_run` enables that graph-extension step.
TF_Session::TF_Session(tensorflow::Session* s, TF_Graph* g)
    : session(s), graph(g), last_num_graph_nodes(0), extend_before_run(true) {}
2470 | |
2471 | TF_Session* TF_NewSession(TF_Graph* graph, const TF_SessionOptions* opt, |
2472 | TF_Status* status) { |
2473 | Session* session; |
2474 | status->status = NewSession(opt->options, &session); |
2475 | if (status->status.ok()) { |
2476 | TF_Session* new_session = new TF_Session(session, graph); |
2477 | if (graph != nullptr) { |
2478 | mutex_lock l(graph->mu); |
2479 | graph->sessions[new_session] = "" ; |
2480 | } |
2481 | return new_session; |
2482 | } else { |
2483 | DCHECK_EQ(nullptr, session); |
2484 | return nullptr; |
2485 | } |
2486 | } |
2487 | |
// Loads a SavedModel from `export_dir` into `graph` (which must be empty) and
// returns a new TF_Session backed by the bundle's session. If
// `meta_graph_def` is non-null, also serializes the bundle's MetaGraphDef
// into it. Returns nullptr and sets `status` on any failure.
TF_Session* TF_LoadSessionFromSavedModel(
    const TF_SessionOptions* session_options, const TF_Buffer* run_options,
    const char* export_dir, const char* const* tags, int tags_len,
    TF_Graph* graph, TF_Buffer* meta_graph_def, TF_Status* status) {
// TODO(ashankar): Remove the __ANDROID__ guard. This will require ensuring that
// the tensorflow/cc/saved_model:loader build target is Android friendly.
#ifdef __ANDROID__
  status->status = tensorflow::errors::Unimplemented(
      "Loading a SavedModel is not supported in Android. File a bug at "
      "https://github.com/tensorflow/tensorflow/issues if this feature is "
      "important to you" );
  return nullptr;
#else
  mutex_lock l(graph->mu);
  // The import below assumes a fresh graph; reject graphs that already
  // contain named nodes.
  if (!graph->name_map.empty()) {
    status->status = InvalidArgument("Graph is non-empty." );
    return nullptr;
  }

  // `run_options`, when supplied, must be a serialized RunOptions proto.
  RunOptions run_options_proto;
  if (run_options != nullptr && !run_options_proto.ParseFromArray(
                                    run_options->data, run_options->length)) {
    status->status = InvalidArgument("Unparseable RunOptions proto" );
    return nullptr;
  }

  std::unordered_set<string> tag_set;
  for (int i = 0; i < tags_len; i++) {
    tag_set.insert(string(tags[i]));
  }

  tensorflow::SavedModelBundle bundle;
  status->status =
      tensorflow::LoadSavedModel(session_options->options, run_options_proto,
                                 export_dir, tag_set, &bundle);
  if (!status->status.ok()) return nullptr;

  // Create a TF_Graph from the MetaGraphDef. This is safe as long as Session
  // extends using GraphDefs. The Graph instance is different, but equivalent
  // to the one used to create the session.
  //
  // TODO(jhseu): When Session is modified to take Graphs instead of
  // GraphDefs, return the Graph generated in LoadSavedModel().
  TF_ImportGraphDefOptions* import_opts = TF_NewImportGraphDefOptions();
  TF_ImportGraphDefResults results;
  GraphImportGraphDefLocked(graph, bundle.meta_graph_def.graph_def(),
                            import_opts, &results, status);
  TF_DeleteImportGraphDefOptions(import_opts);
  if (TF_GetCode(status) != TF_OK) return nullptr;

  // Optionally hand the caller a serialized copy of the MetaGraphDef.
  if (meta_graph_def != nullptr) {
    status->status = MessageToBuffer(bundle.meta_graph_def, meta_graph_def);
    if (!status->status.ok()) return nullptr;
  }

  // Transfer ownership of the bundle's session to the returned TF_Session and
  // register it with the graph (graph->mu is still held here).
  TF_Session* session = new TF_Session(bundle.session.release(), graph);

  graph->sessions[session] = "" ;
  session->last_num_graph_nodes = graph->graph.num_node_ids();
  return session;
#endif  // __ANDROID__
}
2550 | |
// Closes the underlying tensorflow::Session, releasing its resources. The
// TF_Session wrapper itself must still be freed with TF_DeleteSession.
void TF_CloseSession(TF_Session* s, TF_Status* status) {
  status->status = s->session->Close();
}
2554 | |
2555 | void TF_DeleteSession(TF_Session* s, TF_Status* status) { |
2556 | status->status = Status::OK(); |
2557 | TF_Graph* const graph = s->graph; |
2558 | if (graph != nullptr) { |
2559 | graph->mu.lock(); |
2560 | graph->sessions.erase(s); |
2561 | const bool del = graph->delete_requested && graph->sessions.empty(); |
2562 | graph->mu.unlock(); |
2563 | if (del) delete graph; |
2564 | } |
2565 | delete s->session; |
2566 | delete s; |
2567 | } |
2568 | |
2569 | void TF_SessionRun(TF_Session* session, const TF_Buffer* run_options, |
2570 | const TF_Output* inputs, TF_Tensor* const* input_values, |
2571 | int ninputs, const TF_Output* outputs, |
2572 | TF_Tensor** output_values, int noutputs, |
2573 | const TF_Operation* const* target_opers, int ntargets, |
2574 | TF_Buffer* run_metadata, TF_Status* status) { |
2575 | // TODO(josh11b,mrry): Change Session to be able to use a Graph* |
2576 | // directly, instead of requiring us to serialize to a GraphDef and |
2577 | // call Session::Extend(). |
2578 | if (session->extend_before_run && |
2579 | !ExtendSessionGraphHelper(session, status)) { |
2580 | return; |
2581 | } |
2582 | |
2583 | TF_Run_Setup(noutputs, output_values, status); |
2584 | |
2585 | // Convert from TF_Output and TF_Tensor to a string and Tensor. |
2586 | std::vector<std::pair<string, Tensor>> input_pairs(ninputs); |
2587 | if (!TF_Run_Inputs(input_values, &input_pairs, status)) return; |
2588 | for (int i = 0; i < ninputs; ++i) { |
2589 | input_pairs[i].first = OutputName(inputs[i]); |
2590 | } |
2591 | |
2592 | // Convert from TF_Output to string names. |
2593 | std::vector<string> output_names(noutputs); |
2594 | for (int i = 0; i < noutputs; ++i) { |
2595 | output_names[i] = OutputName(outputs[i]); |
2596 | } |
2597 | |
2598 | // Convert from TF_Operation* to string names. |
2599 | std::vector<string> target_names(ntargets); |
2600 | for (int i = 0; i < ntargets; ++i) { |
2601 | target_names[i] = target_opers[i]->node.name(); |
2602 | } |
2603 | |
2604 | // Actually run. |
2605 | TF_Run_Helper(session->session, nullptr, run_options, input_pairs, |
2606 | output_names, output_values, target_names, run_metadata, |
2607 | status); |
2608 | } |
2609 | |
2610 | void TF_SessionPRunSetup(TF_Session* session, const TF_Output* inputs, |
2611 | int ninputs, const TF_Output* outputs, int noutputs, |
2612 | const TF_Operation* const* target_opers, int ntargets, |
2613 | const char** handle, TF_Status* status) { |
2614 | *handle = nullptr; |
2615 | |
2616 | if (session->extend_before_run && |
2617 | !ExtendSessionGraphHelper(session, status)) { |
2618 | return; |
2619 | } |
2620 | |
2621 | std::vector<string> input_names(ninputs); |
2622 | for (int i = 0; i < ninputs; ++i) { |
2623 | input_names[i] = OutputName(inputs[i]); |
2624 | } |
2625 | |
2626 | std::vector<string> output_names(noutputs); |
2627 | for (int i = 0; i < noutputs; ++i) { |
2628 | output_names[i] = OutputName(outputs[i]); |
2629 | } |
2630 | |
2631 | std::vector<string> target_names(ntargets); |
2632 | for (int i = 0; i < ntargets; ++i) { |
2633 | target_names[i] = target_opers[i]->node.name(); |
2634 | } |
2635 | |
2636 | string new_handle; |
2637 | status->status = session->session->PRunSetup(input_names, output_names, |
2638 | target_names, &new_handle); |
2639 | if (status->status.ok()) { |
2640 | char* buf = new char[new_handle.size() + 1]; |
2641 | memcpy(buf, new_handle.c_str(), new_handle.size() + 1); |
2642 | *handle = buf; |
2643 | } |
2644 | } |
2645 | |
// Frees a partial-run handle allocated by TF_SessionPRunSetup.
void TF_DeletePRunHandle(const char* handle) {
  delete[] handle;
  // TODO(suharshs): Free up any resources held by the partial run state.
}
2650 | |
2651 | void TF_SessionPRun(TF_Session* session, const char* handle, |
2652 | const TF_Output* inputs, TF_Tensor* const* input_values, |
2653 | int ninputs, const TF_Output* outputs, |
2654 | TF_Tensor** output_values, int noutputs, |
2655 | const TF_Operation* const* target_opers, int ntargets, |
2656 | TF_Status* status) { |
2657 | // TODO(josh11b,mrry): Change Session to be able to use a Graph* |
2658 | // directly, instead of requiring us to serialize to a GraphDef and |
2659 | // call Session::Extend(). |
2660 | if (session->extend_before_run && |
2661 | !ExtendSessionGraphHelper(session, status)) { |
2662 | return; |
2663 | } |
2664 | |
2665 | TF_Run_Setup(noutputs, output_values, status); |
2666 | |
2667 | // Convert from TF_Output and TF_Tensor to a string and Tensor. |
2668 | std::vector<std::pair<string, Tensor>> input_pairs(ninputs); |
2669 | if (!TF_Run_Inputs(input_values, &input_pairs, status)) return; |
2670 | for (int i = 0; i < ninputs; ++i) { |
2671 | input_pairs[i].first = OutputName(inputs[i]); |
2672 | } |
2673 | |
2674 | // Convert from TF_Output to string names. |
2675 | std::vector<string> output_names(noutputs); |
2676 | for (int i = 0; i < noutputs; ++i) { |
2677 | output_names[i] = OutputName(outputs[i]); |
2678 | } |
2679 | |
2680 | // Convert from TF_Operation* to string names. |
2681 | std::vector<string> target_names(ntargets); |
2682 | for (int i = 0; i < ntargets; ++i) { |
2683 | target_names[i] = target_opers[i]->node.name(); |
2684 | } |
2685 | |
2686 | TF_Run_Helper(session->session, handle, nullptr, input_pairs, output_names, |
2687 | output_values, target_names, nullptr, status); |
2688 | } |
2689 | |
2690 | unsigned char TF_TryEvaluateConstant(TF_Graph* graph, TF_Output output, |
2691 | TF_Tensor** result, TF_Status* status) { |
2692 | *result = nullptr; |
2693 | mutex_lock l(graph->mu); |
2694 | OutputTensor tensor(&output.oper->node, output.index); |
2695 | bool evaluated; |
2696 | Tensor result_tensor; |
2697 | status->status = EvaluateConstantTensor( |
2698 | tensor, graph->refiner, *graph->graph.op_registry(), |
2699 | graph->graph.versions().producer(), &evaluated, &result_tensor); |
2700 | if (evaluated) { |
2701 | DCHECK(status->status.ok()); |
2702 | *result = TF_TensorFromTensor(result_tensor, status); |
2703 | if (!status->status.ok()) evaluated = false; |
2704 | } |
2705 | return evaluated; |
2706 | } |
2707 | |
2708 | TF_ApiDefMap* TF_NewApiDefMap(TF_Buffer* op_list_buffer, TF_Status* status) { |
2709 | tensorflow::OpList op_list; |
2710 | if (!op_list.ParseFromArray(op_list_buffer->data, op_list_buffer->length)) { |
2711 | status->status = InvalidArgument("Unparseable OpList" ); |
2712 | return nullptr; |
2713 | } |
2714 | status->status = Status::OK(); |
2715 | return new TF_ApiDefMap(op_list); |
2716 | } |
2717 | |
2718 | void TF_DeleteApiDefMap(TF_ApiDefMap* apimap) { delete apimap; } |
2719 | |
2720 | void TF_ApiDefMapPut(TF_ApiDefMap* api_def_map, const char* text, |
2721 | size_t text_len, TF_Status* status) { |
2722 | #ifdef __ANDROID__ |
2723 | status->status = tensorflow::errors::Unimplemented( |
2724 | "ApiDefMap is not supported in Android." ); |
2725 | #else |
2726 | mutex_lock l(api_def_map->lock); |
2727 | if (api_def_map->update_docs_called) { |
2728 | status->status = FailedPrecondition( |
2729 | "TF_ApiDefMapPut cannot be called after TF_ApiDefMapGet has been " |
2730 | "called." ); |
2731 | return; |
2732 | } |
2733 | string api_def_text(text, text_len); |
2734 | status->status = api_def_map->api_def_map.LoadApiDef(api_def_text); |
2735 | #endif // __ANDROID__ |
2736 | } |
2737 | |
2738 | TF_Buffer* TF_ApiDefMapGet(TF_ApiDefMap* api_def_map, const char* name, |
2739 | size_t name_len, TF_Status* status) { |
2740 | #ifdef __ANDROID__ |
2741 | status->status = tensorflow::errors::Unimplemented( |
2742 | "ApiDefMap is not supported in Android." ); |
2743 | return nullptr; |
2744 | #else |
2745 | mutex_lock l(api_def_map->lock); |
2746 | if (!api_def_map->update_docs_called) { |
2747 | api_def_map->api_def_map.UpdateDocs(); |
2748 | api_def_map->update_docs_called = true; |
2749 | } |
2750 | string name_str(name, name_len); |
2751 | const auto* api_def = api_def_map->api_def_map.GetApiDef(name_str); |
2752 | |
2753 | TF_Buffer* ret = TF_NewBuffer(); |
2754 | status->status = MessageToBuffer(*api_def, ret); |
2755 | return ret; |
2756 | #endif // __ANDROID__ |
2757 | } |
2758 | } // end extern "C" |
2759 | |