1/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include "tensorflow/c/c_api.h"
17
18#include <algorithm>
19#include <limits>
20#include <memory>
21#include <vector>
22
23#ifndef __ANDROID__
24#include "tensorflow/cc/framework/gradients.h"
25#include "tensorflow/cc/framework/ops.h"
26#include "tensorflow/cc/framework/scope_internal.h"
27#include "tensorflow/cc/ops/while_loop.h"
28#include "tensorflow/cc/saved_model/loader.h"
29#include "tensorflow/core/framework/op_gen_lib.h"
30#endif
31#include "tensorflow/c/c_api_internal.h"
32#include "tensorflow/core/common_runtime/device_mgr.h"
33#include "tensorflow/core/common_runtime/eval_const_tensor.h"
34#include "tensorflow/core/common_runtime/shape_refiner.h"
35#include "tensorflow/core/framework/allocation_description.pb.h"
36#include "tensorflow/core/framework/log_memory.h"
37#include "tensorflow/core/framework/node_def_util.h"
38#include "tensorflow/core/framework/op_kernel.h"
39#include "tensorflow/core/framework/partial_tensor_shape.h"
40#include "tensorflow/core/framework/tensor.h"
41#include "tensorflow/core/framework/tensor_shape.h"
42#include "tensorflow/core/framework/tensor_shape.pb.h"
43#include "tensorflow/core/framework/types.h"
44#include "tensorflow/core/framework/versions.pb.h"
45#include "tensorflow/core/graph/graph.h"
46#include "tensorflow/core/graph/graph_constructor.h"
47#include "tensorflow/core/graph/node_builder.h"
48#include "tensorflow/core/lib/core/coding.h"
49#include "tensorflow/core/lib/core/errors.h"
50#include "tensorflow/core/lib/core/status.h"
51#include "tensorflow/core/lib/core/stringpiece.h"
52#include "tensorflow/core/lib/gtl/array_slice.h"
53#include "tensorflow/core/lib/strings/strcat.h"
54#include "tensorflow/core/platform/mem.h"
55#include "tensorflow/core/platform/mutex.h"
56#include "tensorflow/core/platform/protobuf.h"
57#include "tensorflow/core/platform/thread_annotations.h"
58#include "tensorflow/core/platform/types.h"
59#include "tensorflow/core/public/session.h"
60#include "tensorflow/core/public/version.h"
61
// The implementation below is at the top level instead of the tensorflow
// namespace because we are defining 'extern "C"' functions.
64using tensorflow::AllocationDescription;
65using tensorflow::DataType;
66using tensorflow::ExtendSessionGraphHelper;
67using tensorflow::Graph;
68using tensorflow::GraphDef;
69using tensorflow::mutex_lock;
70using tensorflow::NameRangeMap;
71using tensorflow::NameRangesForNode;
72using tensorflow::NewSession;
73using tensorflow::Node;
74using tensorflow::NodeBuilder;
75using tensorflow::NodeDef;
76using tensorflow::OpDef;
77using tensorflow::OpRegistry;
78using tensorflow::OutputTensor;
79using tensorflow::PartialTensorShape;
80using tensorflow::RunMetadata;
81using tensorflow::RunOptions;
82using tensorflow::Session;
83using tensorflow::Status;
84using tensorflow::string;
85using tensorflow::Tensor;
86using tensorflow::TensorBuffer;
87using tensorflow::TensorId;
88using tensorflow::TensorShape;
89using tensorflow::TensorShapeProto;
90using tensorflow::VersionDef;
91using tensorflow::error::Code;
92using tensorflow::errors::FailedPrecondition;
93using tensorflow::errors::InvalidArgument;
94using tensorflow::gtl::ArraySlice;
95using tensorflow::strings::StrCat;
96
97extern "C" {
98
99// --------------------------------------------------------------------------
100const char* TF_Version() { return TF_VERSION_STRING; }
101
102// --------------------------------------------------------------------------
103size_t TF_DataTypeSize(TF_DataType dt) {
104 return static_cast<size_t>(
105 tensorflow::DataTypeSize(static_cast<DataType>(dt)));
106}
107
108// --------------------------------------------------------------------------
109
110TF_Status* TF_NewStatus() { return new TF_Status; }
111
112void TF_DeleteStatus(TF_Status* s) { delete s; }
113
114void TF_SetStatus(TF_Status* s, TF_Code code, const char* msg) {
115 if (code == TF_OK) {
116 s->status = Status::OK();
117 return;
118 }
119 s->status = Status(static_cast<Code>(code), tensorflow::StringPiece(msg));
120}
121
122TF_Code TF_GetCode(const TF_Status* s) {
123 return static_cast<TF_Code>(s->status.code());
124}
125
126const char* TF_Message(const TF_Status* s) {
127 return s->status.error_message().c_str();
128}
129
130// --------------------------------------------------------------------------
131
132namespace {
133class TF_ManagedBuffer : public TensorBuffer {
134 public:
135 void* data_;
136 size_t len_;
137 void (*deallocator_)(void* data, size_t len, void* arg);
138 void* deallocator_arg_;
139
140 ~TF_ManagedBuffer() override {
141 (*deallocator_)(data_, len_, deallocator_arg_);
142 }
143
144 void* data() const override { return data_; }
145 size_t size() const override { return len_; }
146 TensorBuffer* root_buffer() override { return this; }
147 void FillAllocationDescription(AllocationDescription* proto) const override {
148 tensorflow::int64 rb = size();
149 proto->set_requested_bytes(rb);
150 proto->set_allocator_name(tensorflow::cpu_allocator()->Name());
151 }
152
153 // Prevents input forwarding from mutating this buffer.
154 bool OwnsMemory() const override { return false; }
155};
156
157void* allocate_tensor(const char* operation, size_t len) {
158 void* data =
159 tensorflow::cpu_allocator()->AllocateRaw(EIGEN_MAX_ALIGN_BYTES, len);
160 if (tensorflow::LogMemory::IsEnabled() && data != nullptr) {
161 tensorflow::LogMemory::RecordRawAllocation(
162 operation, tensorflow::LogMemory::EXTERNAL_TENSOR_ALLOCATION_STEP_ID,
163 len, data, tensorflow::cpu_allocator());
164 }
165 return data;
166}
167
168void deallocate_buffer(void* data, size_t len, void* arg) {
169 if (tensorflow::LogMemory::IsEnabled() && data != nullptr) {
170 tensorflow::LogMemory::RecordRawDeallocation(
171 "TensorFlow C Api",
172 tensorflow::LogMemory::EXTERNAL_TENSOR_ALLOCATION_STEP_ID, data,
173 tensorflow::cpu_allocator(), false);
174 }
175 tensorflow::cpu_allocator()->DeallocateRaw(data);
176}
177
178} // namespace
179
180TF_Tensor::~TF_Tensor() { buffer->Unref(); }
181
182TF_Tensor* TF_AllocateTensor(TF_DataType dtype, const int64_t* dims,
183 int num_dims, size_t len) {
184 void* data = allocate_tensor("TF_AllocateTensor", len);
185 return TF_NewTensor(dtype, dims, num_dims, data, len, deallocate_buffer,
186 nullptr);
187}
188
189TF_Tensor* TF_NewTensor(TF_DataType dtype, const int64_t* dims, int num_dims,
190 void* data, size_t len,
191 void (*deallocator)(void* data, size_t len, void* arg),
192 void* deallocator_arg) {
193 std::vector<tensorflow::int64> dimvec(num_dims);
194 for (int i = 0; i < num_dims; ++i) {
195 dimvec[i] = static_cast<tensorflow::int64>(dims[i]);
196 }
197
198 TF_ManagedBuffer* buf = new TF_ManagedBuffer;
199 buf->len_ = len;
200 if (dtype != TF_STRING && dtype != TF_RESOURCE &&
201 tensorflow::DataTypeCanUseMemcpy(static_cast<DataType>(dtype)) &&
202 reinterpret_cast<intptr_t>(data) % EIGEN_MAX_ALIGN_BYTES != 0) {
203 // TF_STRING and TF_RESOURCE tensors have a different representation in
204 // TF_Tensor than they do in tensorflow::Tensor. So a copy here is a waste
205 // (any alignment requirements will be taken care of by TF_TensorToTensor
206 // and TF_TensorFromTensor).
207 //
208 // Other types have the same representation, so copy only if it is safe to
209 // do so.
210 buf->data_ = allocate_tensor("TF_NewTensor", len);
211 std::memcpy(buf->data_, data, len);
212 buf->deallocator_ = deallocate_buffer;
213 buf->deallocator_arg_ = nullptr;
214 // Free the original buffer.
215 deallocator(data, len, deallocator_arg);
216 } else {
217 buf->data_ = data;
218 buf->deallocator_ = deallocator;
219 buf->deallocator_arg_ = deallocator_arg;
220 }
221 TF_Tensor* ret = new TF_Tensor{dtype, TensorShape(dimvec), buf};
222 size_t elem_size = TF_DataTypeSize(dtype);
223 if (elem_size > 0 && len < (elem_size * ret->shape.num_elements())) {
224 delete ret;
225 return nullptr;
226 }
227 return ret;
228}
229
230TF_Tensor* TF_TensorMaybeMove(TF_Tensor* tensor) {
231 // It is safe to move the Tensor if and only if we own the unique reference to
232 // it. In that case, we might as well not delete and reallocate, but a future
233 // implementation might need to do so.
234 TensorBuffer* buf = tensor->buffer;
235 if (buf->RefCountIsOne() && buf->root_buffer()->RefCountIsOne() &&
236 buf->OwnsMemory()) {
237 return tensor;
238 }
239 return nullptr;
240}
241
242void TF_DeleteTensor(TF_Tensor* t) { delete t; }
243
244TF_DataType TF_TensorType(const TF_Tensor* t) { return t->dtype; }
245int TF_NumDims(const TF_Tensor* t) { return t->shape.dims(); }
246int64_t TF_Dim(const TF_Tensor* t, int dim_index) {
247 return static_cast<int64_t>(t->shape.dim_size(dim_index));
248}
249size_t TF_TensorByteSize(const TF_Tensor* t) { return t->buffer->size(); }
250void* TF_TensorData(const TF_Tensor* t) { return t->buffer->data(); }
251
252// --------------------------------------------------------------------------
253size_t TF_StringEncode(const char* src, size_t src_len, char* dst,
254 size_t dst_len, TF_Status* status) {
255 const size_t sz = TF_StringEncodedSize(src_len);
256 if (sz < src_len) {
257 status->status = InvalidArgument("src string is too large to encode");
258 return 0;
259 }
260 if (dst_len < sz) {
261 status->status =
262 InvalidArgument("dst_len (", dst_len, ") too small to encode a ",
263 src_len, "-byte string");
264 return 0;
265 }
266 dst = tensorflow::core::EncodeVarint64(dst, src_len);
267 memcpy(dst, src, src_len);
268 return sz;
269}
270
271static Status TF_StringDecode_Impl(const char* src, size_t src_len,
272 const char** dst, size_t* dst_len) {
273 tensorflow::uint64 len64 = 0;
274 const char* p = tensorflow::core::GetVarint64Ptr(src, src + src_len, &len64);
275 if (p == nullptr) {
276 return InvalidArgument("invalid string encoding or truncated src buffer");
277 }
278 if (len64 > std::numeric_limits<size_t>::max()) {
279 return InvalidArgument("encoded string is ", len64,
280 "-bytes, which is too large for this architecture");
281 }
282 *dst = p;
283 *dst_len = static_cast<size_t>(len64);
284 return Status::OK();
285}
286
287size_t TF_StringDecode(const char* src, size_t src_len, const char** dst,
288 size_t* dst_len, TF_Status* status) {
289 status->status = TF_StringDecode_Impl(src, src_len, dst, dst_len);
290 if (!status->status.ok()) return 0;
291 return static_cast<size_t>(*dst - src) + *dst_len;
292}
293
294size_t TF_StringEncodedSize(size_t len) {
295 return static_cast<size_t>(tensorflow::core::VarintLength(len)) + len;
296}
297
298// --------------------------------------------------------------------------
299TF_SessionOptions* TF_NewSessionOptions() { return new TF_SessionOptions; }
300void TF_DeleteSessionOptions(TF_SessionOptions* opt) { delete opt; }
301
302void TF_SetTarget(TF_SessionOptions* options, const char* target) {
303 options->options.target = target;
304}
305
306void TF_SetConfig(TF_SessionOptions* options, const void* proto,
307 size_t proto_len, TF_Status* status) {
308 if (!options->options.config.ParseFromArray(proto, proto_len)) {
309 status->status = InvalidArgument("Unparseable ConfigProto");
310 }
311}
312// --------------------------------------------------------------------------
313TF_Buffer* TF_NewBuffer() { return new TF_Buffer{nullptr, 0, nullptr}; }
314
315TF_Buffer* TF_NewBufferFromString(const void* proto, size_t proto_len) {
316 void* copy = tensorflow::port::Malloc(proto_len);
317 memcpy(copy, proto, proto_len);
318
319 TF_Buffer* buf = new TF_Buffer;
320 buf->data = copy;
321 buf->length = proto_len;
322 buf->data_deallocator = [](void* data, size_t length) {
323 tensorflow::port::Free(data);
324 };
325 return buf;
326}
327
328void TF_DeleteBuffer(TF_Buffer* buffer) {
329 if (buffer->data_deallocator != nullptr) {
330 (*buffer->data_deallocator)(const_cast<void*>(buffer->data),
331 buffer->length);
332 }
333 delete buffer;
334}
335
336TF_Buffer TF_GetBuffer(TF_Buffer* buffer) { return *buffer; }
337
338// --------------------------------------------------------------------------
339
340TF_DeprecatedSession* TF_NewDeprecatedSession(const TF_SessionOptions* opt,
341 TF_Status* status) {
342 Session* session;
343 status->status = NewSession(opt->options, &session);
344 if (status->status.ok()) {
345 return new TF_DeprecatedSession({session});
346 } else {
347 DCHECK_EQ(nullptr, session);
348 return nullptr;
349 }
350}
351
352void TF_CloseDeprecatedSession(TF_DeprecatedSession* s, TF_Status* status) {
353 status->status = s->session->Close();
354}
355
356void TF_DeleteDeprecatedSession(TF_DeprecatedSession* s, TF_Status* status) {
357 status->status = Status::OK();
358 delete s->session;
359 delete s;
360}
361
362void TF_ExtendGraph(TF_DeprecatedSession* s, const void* proto,
363 size_t proto_len, TF_Status* status) {
364 GraphDef g;
365 if (!tensorflow::ParseProtoUnlimited(&g, proto, proto_len)) {
366 status->status = InvalidArgument("Invalid GraphDef");
367 return;
368 }
369 status->status = s->session->Extend(g);
370}
371
372static void DeleteArray(void* data, size_t size, void* arg) {
373 DCHECK_EQ(data, arg);
374 delete[] reinterpret_cast<char*>(arg);
375}
376
377} // end extern "C"
378
379namespace tensorflow {
380namespace {
381
382// Reset helper for converting character arrays to string vectors.
383void TF_Reset_Helper(const TF_SessionOptions* opt, const char** containers,
384 int ncontainers, TF_Status* status) {
385 std::vector<string> container_names(ncontainers);
386 for (int i = 0; i < ncontainers; ++i) {
387 container_names[i] = containers[i];
388 }
389
390 status->status = Reset(opt->options, container_names);
391}
392
393// This traverses the specified nodes in topological order to verify there are
394// no cycles. Starting with inputless nodes, it visits nodes whose inputs have
395// all been visited, and counts the total number of visited nodes. If there is a
396// cycle, nodes in the cycle will never be visited, and the visited count will
397// be less than the total node count.
398Status ValidateNoCycles(const Graph& g) {
399 // TODO(nolivia): check this on a subset of the graph instead of all of it.
400 // A node is ready when all of its inputs have been visited.
401 std::vector<const Node*> ready;
402 std::vector<int> pending_count(g.num_node_ids(), 0);
403
404 for (int i = 0; i < g.num_node_ids(); ++i) {
405 const Node* n = g.FindNodeId(i);
406 if (n == nullptr) continue;
407 pending_count[i] = n->in_edges().size();
408 if (n->IsMerge()) {
409 // While-loop cycles are legal cycles so we manually adjust the
410 // pending_count to make sure that the loop is visited.
411 for (const Edge* e : n->in_edges()) {
412 if (!e->IsControlEdge() && e->src()->IsNextIteration()) {
413 pending_count[i]--;
414 }
415 }
416 }
417 if (pending_count[i] == 0) {
418 ready.push_back(n);
419 }
420 }
421
422 int processed = 0;
423 while (!ready.empty()) {
424 const Node* node = ready.back();
425 ready.pop_back();
426 ++processed;
427
428 for (const Edge* out : node->out_edges()) {
429 const int output_id = out->dst()->id();
430 pending_count[output_id]--;
431 if (pending_count[output_id] == 0) {
432 ready.push_back(out->dst());
433 }
434 }
435 }
436
437 if (processed < g.num_nodes()) {
438 std::vector<string> nodes_in_cycle;
439 for (int i = 0; i < pending_count.size() && nodes_in_cycle.size() < 3;
440 ++i) {
441 if (pending_count[i] != 0) {
442 nodes_in_cycle.push_back(g.FindNodeId(i)->name());
443 }
444 }
445 return errors::InvalidArgument(
446 "Graph is invalid, contains a cycle with ", g.num_nodes() - processed,
447 " nodes, including: ", str_util::Join(nodes_in_cycle, ", "));
448 }
449 return Status::OK();
450}
451} // namespace
452} // namespace tensorflow
453
454extern "C" {
455
456void TF_Reset(const TF_SessionOptions* opt, const char** containers,
457 int ncontainers, TF_Status* status) {
458 tensorflow::TF_Reset_Helper(opt, containers, ncontainers, status);
459}
460
461} // end extern "C"
462
463namespace tensorflow {
464
465Status TF_TensorToTensor(const TF_Tensor* src, Tensor* dst) {
466 if (src->dtype == TF_RESOURCE) {
467 if (src->shape.dims() != 0) {
468 return InvalidArgument(
469 "Malformed TF_RESOURCE tensor: expected a scalar, got a tensor with "
470 "shape ",
471 src->shape.DebugString());
472 }
473 *dst = Tensor(DT_RESOURCE, src->shape);
474 if (!dst->scalar<ResourceHandle>()().ParseFromString(
475 string(static_cast<const char*>(TF_TensorData(src)),
476 TF_TensorByteSize(src)))) {
477 return InvalidArgument(
478 "Malformed TF_RESOUCE tensor: unable to parse resource handle");
479 }
480 return Status::OK();
481 }
482 if (src->dtype != TF_STRING) {
483 *dst = TensorCApi::MakeTensor(src->dtype, src->shape, src->buffer);
484 return Status::OK();
485 }
486 // TF_STRING tensors require copying since Tensor class expects a sequence of
487 // string objects.
488 const tensorflow::int64 num_elements = src->shape.num_elements();
489 const char* input = reinterpret_cast<const char*>(TF_TensorData(src));
490 const size_t src_size = TF_TensorByteSize(src);
491 if (static_cast<tensorflow::int64>(src_size / sizeof(tensorflow::uint64)) <
492 num_elements) {
493 return InvalidArgument(
494 "Malformed TF_STRING tensor; too short to hold number of elements");
495 }
496 const char* data_start = input + sizeof(tensorflow::uint64) * num_elements;
497 const char* limit = input + src_size;
498
499 *dst = Tensor(static_cast<DataType>(src->dtype), src->shape);
500 auto dstarray = dst->flat<string>();
501 for (tensorflow::int64 i = 0; i < num_elements; ++i) {
502 tensorflow::uint64 offset =
503 reinterpret_cast<const tensorflow::uint64*>(input)[i];
504 if (static_cast<ptrdiff_t>(offset) >= (limit - data_start)) {
505 return InvalidArgument("Malformed TF_STRING tensor; element ", i,
506 " out of range");
507 }
508 size_t len;
509 const char* p;
510 const char* srcp = data_start + offset;
511 Status status = TF_StringDecode_Impl(srcp, limit - srcp, &p, &len);
512 if (!status.ok()) return status;
513 dstarray(i).assign(p, len);
514 }
515 return Status::OK();
516}
517
518// Create an empty tensor of type 'dtype'. 'shape' can be arbitrary, but has to
519// result in a zero-sized tensor.
520static TF_Tensor* EmptyTensor(TF_DataType dtype, const TensorShape& shape) {
521 static char empty;
522 tensorflow::int64 nelems = 1;
523 std::vector<tensorflow::int64> dims;
524 for (int i = 0; i < shape.dims(); ++i) {
525 dims.push_back(shape.dim_size(i));
526 nelems *= shape.dim_size(i);
527 }
528 CHECK_EQ(nelems, 0);
529 static_assert(sizeof(int64_t) == sizeof(tensorflow::int64),
530 "64-bit int types should match in size");
531 return TF_NewTensor(dtype, reinterpret_cast<const int64_t*>(dims.data()),
532 shape.dims(), reinterpret_cast<void*>(&empty), 0,
533 [](void*, size_t, void*) {}, nullptr);
534}
535
536// Non-static for testing.
537TF_Tensor* TF_TensorFromTensor(const tensorflow::Tensor& src,
538 TF_Status* status) {
539 if (!src.IsInitialized()) {
540 status->status = FailedPrecondition(
541 "attempt to use a tensor with an uninitialized value");
542 return nullptr;
543 }
544 if (src.NumElements() == 0) {
545 return EmptyTensor(static_cast<TF_DataType>(src.dtype()), src.shape());
546 }
547 if (src.dtype() == DT_RESOURCE) {
548 if (src.shape().dims() != 0) {
549 status->status = InvalidArgument(
550 "Unexpected non-scalar DT_RESOURCE tensor seen (shape: ",
551 src.shape().DebugString(),
552 "). Please file a bug at "
553 "https://github.com/tensorflow/tensorflow/issues/new, "
554 "ideally with a "
555 "short code snippet that reproduces this error.");
556 return nullptr;
557 }
558 const string str = src.scalar<ResourceHandle>()().SerializeAsString();
559 TF_Tensor* t = TF_AllocateTensor(TF_RESOURCE, {}, 0, str.size());
560 std::memcpy(TF_TensorData(t), str.c_str(), str.size());
561 return t;
562 }
563 if (src.dtype() != DT_STRING) {
564 TensorBuffer* buf = TensorCApi::Buffer(src);
565 buf->Ref();
566 return new TF_Tensor{static_cast<TF_DataType>(src.dtype()), src.shape(),
567 buf};
568 }
569 // DT_STRING tensors require a copying since TF_Tensor.buffer expects a flatly
570 // encoded sequence of strings.
571
572 // Compute bytes needed for encoding.
573 size_t size = 0;
574 const auto& srcarray = src.flat<string>();
575 for (int i = 0; i < srcarray.size(); ++i) {
576 const string& s = srcarray(i);
577 // uint64 starting_offset, TF_StringEncode-d string.
578 size += sizeof(tensorflow::uint64) + TF_StringEncodedSize(s.size());
579 }
580
581 // Encode all strings.
582 char* base = new char[size];
583 char* data_start = base + sizeof(tensorflow::uint64) * srcarray.size();
584 char* dst = data_start; // Where next string is encoded.
585 size_t dst_len = size - static_cast<size_t>(data_start - base);
586 tensorflow::uint64* offsets = reinterpret_cast<tensorflow::uint64*>(base);
587 for (int i = 0; i < srcarray.size(); ++i) {
588 *offsets = (dst - data_start);
589 offsets++;
590 const string& s = srcarray(i);
591 size_t consumed = TF_StringEncode(s.data(), s.size(), dst, dst_len, status);
592 if (!status->status.ok()) {
593 status->status = InvalidArgument(
594 "invalid string tensor encoding (string #", i, " of ",
595 srcarray.size(), "): ", status->status.error_message());
596 delete[] base;
597 return nullptr;
598 }
599 dst += consumed;
600 dst_len -= consumed;
601 }
602 if (dst != base + size) {
603 status->status = InvalidArgument(
604 "invalid string tensor encoding (decoded ", (dst - base),
605 " bytes, but the tensor is encoded in ", size, " bytes");
606 delete[] base;
607 return nullptr;
608 }
609
610 auto dims = src.shape().dim_sizes();
611 std::vector<tensorflow::int64> dimvec(dims.size());
612 for (size_t i = 0; i < dims.size(); ++i) {
613 dimvec[i] = dims[i];
614 }
615 static_assert(sizeof(int64_t) == sizeof(tensorflow::int64),
616 "64-bit int types should match in size");
617 return TF_NewTensor(TF_STRING,
618 reinterpret_cast<const int64_t*>(dimvec.data()),
619 dimvec.size(), base, size, DeleteArray, base);
620}
621
622Status MessageToBuffer(const tensorflow::protobuf::Message& in,
623 TF_Buffer* out) {
624 if (out->data != nullptr) {
625 return InvalidArgument("Passing non-empty TF_Buffer is invalid.");
626 }
627 const size_t proto_size = in.ByteSizeLong();
628 void* buf = tensorflow::port::Malloc(proto_size);
629 if (buf == nullptr) {
630 return tensorflow::errors::ResourceExhausted(
631 "Failed to allocate memory to serialize message of type '",
632 in.GetTypeName(), "' and size ", proto_size);
633 }
634 in.SerializeToArray(buf, proto_size);
635 out->data = buf;
636 out->length = proto_size;
637 out->data_deallocator = [](void* data, size_t length) {
638 tensorflow::port::Free(data);
639 };
640 return Status::OK();
641}
642
643void RecordMutation(TF_Graph* graph, const TF_Operation& op,
644 const char* mutation_type) {
645 // If any session has already run this node_id, mark this session as
646 // unrunnable.
647 for (auto it : graph->sessions) {
648 mutex_lock session_lock(it.first->mu);
649 if (it.first->last_num_graph_nodes > op.node.id()) {
650 it.second = strings::StrCat(
651 "Operation '", op.node.DebugString(), "' was changed by ",
652 mutation_type,
653 " after it was run by a session. This mutation will have no effect, "
654 "and will trigger an error in the future. Either don't modify "
655 "nodes after running them or create a new session.");
656 }
657 }
658}
659
660namespace {
661
662// Helper method that creates a shape handle for a shape described by dims.
663tensorflow::shape_inference::ShapeHandle ShapeHandleFromDims(
664 tensorflow::shape_inference::InferenceContext* ic, int num_dims,
665 const int64_t* dims) {
666 if (num_dims != -1) {
667 std::vector<tensorflow::shape_inference::DimensionHandle> dim_vec;
668 dim_vec.reserve(num_dims);
669 for (int i = 0; i < num_dims; ++i) {
670 dim_vec.push_back(ic->MakeDim(dims[i]));
671 }
672 return ic->MakeShape(dim_vec);
673 } else {
674 return ic->UnknownShape();
675 }
676}
677
678} // namespace
679
680void TF_GraphSetOutputHandleShapesAndTypes(TF_Graph* graph, TF_Output output,
681 int num_shapes_and_types,
682 const int64_t** shapes,
683 const int* ranks,
684 const TF_DataType* types,
685 TF_Status* status) {
686 Node* node = &output.oper->node;
687
688 mutex_lock l(graph->mu);
689 tensorflow::shape_inference::InferenceContext* ic =
690 graph->refiner.GetContext(node);
691 if (ic == nullptr) {
692 status->status =
693 InvalidArgument("Node ", node->name(), " was not found in the graph");
694 return;
695 }
696
697 auto shape_and_type_vec =
698 std::vector<tensorflow::shape_inference::ShapeAndType>(
699 num_shapes_and_types);
700 for (int i = 0; i < num_shapes_and_types; ++i) {
701 tensorflow::shape_inference::ShapeHandle shape_handle =
702 ShapeHandleFromDims(ic, ranks[i], shapes[i]);
703 shape_and_type_vec[i] = tensorflow::shape_inference::ShapeAndType(
704 shape_handle, static_cast<DataType>(types[i]));
705 }
706
707 ic->set_output_handle_shapes_and_types(output.index, shape_and_type_vec);
708}
709
// Helpers for loading a TensorFlow plugin (a .so file).
// Declared here; the definition lives elsewhere. TF_LoadLibrary passes:
//  - the library filename,
//  - a slot for the library handle (`result`),
//  - the op list buffer's data pointer (`buf`) and length (`len`) to fill.
Status LoadLibrary(const char* library_filename, void** result,
                   const void** buf, size_t* len);
713
// TODO(josh11b,mrry): Change Session to be able to use a Graph*
// directly, instead of requiring us to serialize to a GraphDef and
// call Session::Extend().
//
// Sends the nodes added to session->graph since the last successful call
// (node ids in [session->last_num_graph_nodes, graph.num_node_ids())) to
// session->session via Session::Extend(). Before that it:
//  - logs and clears any mutation warning stored for this session in
//    graph->sessions, and
//  - checks the whole graph with ValidateNoCycles.
//
// Returns false with the error in `status` if cycle validation or Extend()
// fails, and true otherwise. `status` is written only when there are new
// nodes; when session->graph is null nothing happens and true is returned.
//
// Locking: graph->mu is acquired first and released by hand on every path.
// On the Extend() path it is released before calling into the session.
// session->mu is acquired second through a mutex_lock and held until return.
bool ExtendSessionGraphHelper(TF_Session* session, TF_Status* status) {
  if (session->graph != nullptr) {
    // Take the graph lock before the session lock to avoid deadlock. This is
    // safe since session->graph does not change.
    session->graph->mu.lock();
    mutex_lock session_lock(session->mu);
    const Graph& graph = session->graph->graph;

    // Warning stored for this session (e.g. by RecordMutation) after nodes it
    // had already run were modified.
    const string& mutation_warning = session->graph->sessions[session];
    if (!mutation_warning.empty()) {
      // TODO(b/74949947): turn this back into an error status
      LOG(WARNING) << mutation_warning;
      session->graph->sessions[session].clear();
    }

    const auto num_nodes = graph.num_node_ids();
    // Only do work if node ids were added since the last successful Extend().
    if (session->last_num_graph_nodes < num_nodes) {
      status->status = tensorflow::ValidateNoCycles(session->graph->graph);
      if (!status->status.ok()) {
        session->graph->mu.unlock();
        return false;
      }

      GraphDef graph_def;
      *graph_def.mutable_versions() = graph.versions();
      // Fill graph_def with nodes with ids in the range
      // [session->last_num_graph_nodes, num_nodes), that is the nodes
      // added since the last TF_SessionRun() call.
      for (auto id = session->last_num_graph_nodes; id < num_nodes; ++id) {
        Node* const node = graph.FindNodeId(id);
        // Skip ids with no node and nodes that are not ops.
        if (node != nullptr && node->IsOp()) {
          NodeDef* const node_def = graph_def.add_node();
          *node_def = node->def();
        }
      }
      // The full function library is included every time.
      *graph_def.mutable_library() = graph.flib_def().ToProto();
      // graph_def is self-contained now, so drop the graph lock before
      // calling into the session.
      session->graph->mu.unlock();
      status->status = session->session->Extend(graph_def);
      if (!status->status.ok()) {
        // last_num_graph_nodes is left unchanged, so the same range of nodes
        // is sent again on the next call.
        return false;
      }
      // Note: session->session is not modified if Extend() fails, so
      // we only set last_num_graph_nodes if it succeeds.
      session->last_num_graph_nodes = num_nodes;
    } else {
      session->graph->mu.unlock();
    }
  }
  return true;
}
768
769} // namespace tensorflow
770
771static void TF_Run_Setup(int noutputs, TF_Tensor** c_outputs,
772 TF_Status* status) {
773 status->status = Status::OK();
774 for (int i = 0; i < noutputs; ++i) {
775 c_outputs[i] = nullptr;
776 }
777}
778
779static bool TF_Run_Inputs(TF_Tensor* const* c_inputs,
780 std::vector<std::pair<string, Tensor>>* input_pairs,
781 TF_Status* status) {
782 const int ninputs = input_pairs->size();
783 for (int i = 0; i < ninputs; ++i) {
784 status->status = TF_TensorToTensor(c_inputs[i], &(*input_pairs)[i].second);
785 if (!status->status.ok()) return false;
786 }
787 return true;
788}
789
790static void TF_Run_Helper(
791 Session* session, const char* handle, const TF_Buffer* run_options,
792 // Input tensors
793 const std::vector<std::pair<string, Tensor>>& input_pairs,
794 // Output tensors
795 const std::vector<string>& output_tensor_names, TF_Tensor** c_outputs,
796 // Target nodes
797 const std::vector<string>& target_oper_names, TF_Buffer* run_metadata,
798 TF_Status* status) {
799 const int noutputs = output_tensor_names.size();
800 std::vector<Tensor> outputs(noutputs);
801 Status result;
802
803 if (handle == nullptr) {
804 RunOptions run_options_proto;
805 if (run_options != nullptr && !run_options_proto.ParseFromArray(
806 run_options->data, run_options->length)) {
807 status->status = InvalidArgument("Unparseable RunOptions proto");
808 return;
809 }
810 if (run_metadata != nullptr && run_metadata->data != nullptr) {
811 status->status =
812 InvalidArgument("Passing non-empty run_metadata is invalid.");
813 return;
814 }
815
816 RunMetadata run_metadata_proto;
817 result = session->Run(run_options_proto, input_pairs, output_tensor_names,
818 target_oper_names, &outputs, &run_metadata_proto);
819
820 // Serialize back to upstream client, who now owns the new buffer
821 if (run_metadata != nullptr) {
822 status->status = MessageToBuffer(run_metadata_proto, run_metadata);
823 if (!status->status.ok()) return;
824 }
825 } else {
826 // NOTE(zongheng): PRun does not support RunOptions yet.
827 result = session->PRun(handle, input_pairs, output_tensor_names, &outputs);
828 }
829 if (!result.ok()) {
830 status->status = result;
831 return;
832 }
833
834 // Store results in c_outputs[]
835 for (int i = 0; i < noutputs; ++i) {
836 const Tensor& src = outputs[i];
837 if (!src.IsInitialized() || src.NumElements() == 0) {
838 c_outputs[i] =
839 EmptyTensor(static_cast<TF_DataType>(src.dtype()), src.shape());
840 continue;
841 }
842 c_outputs[i] = TF_TensorFromTensor(src, status);
843 if (!status->status.ok()) return;
844 }
845}
846
847extern "C" {
848
849void TF_Run(TF_DeprecatedSession* s, const TF_Buffer* run_options,
850 // Input tensors
851 const char** c_input_names, TF_Tensor** c_inputs, int ninputs,
852 // Output tensors
853 const char** c_output_names, TF_Tensor** c_outputs, int noutputs,
854 // Target nodes
855 const char** c_target_oper_names, int ntargets,
856 TF_Buffer* run_metadata, TF_Status* status) {
857 TF_Run_Setup(noutputs, c_outputs, status);
858 std::vector<std::pair<string, Tensor>> input_pairs(ninputs);
859 if (!TF_Run_Inputs(c_inputs, &input_pairs, status)) return;
860 for (int i = 0; i < ninputs; ++i) {
861 input_pairs[i].first = c_input_names[i];
862 }
863 std::vector<string> output_names(noutputs);
864 for (int i = 0; i < noutputs; ++i) {
865 output_names[i] = c_output_names[i];
866 }
867 std::vector<string> target_oper_names(ntargets);
868 for (int i = 0; i < ntargets; ++i) {
869 target_oper_names[i] = c_target_oper_names[i];
870 }
871 TF_Run_Helper(s->session, nullptr, run_options, input_pairs, output_names,
872 c_outputs, target_oper_names, run_metadata, status);
873}
874
875void TF_PRunSetup(TF_DeprecatedSession* s,
876 // Input names
877 const char** c_input_names, int ninputs,
878 // Output names
879 const char** c_output_names, int noutputs,
880 // Target nodes
881 const char** c_target_oper_names, int ntargets,
882 const char** handle, TF_Status* status) {
883 *handle = nullptr;
884
885 std::vector<string> input_names(ninputs);
886 std::vector<string> output_names(noutputs);
887 std::vector<string> target_oper_names(ntargets);
888 for (int i = 0; i < ninputs; ++i) {
889 input_names[i] = c_input_names[i];
890 }
891 for (int i = 0; i < noutputs; ++i) {
892 output_names[i] = c_output_names[i];
893 }
894 for (int i = 0; i < ntargets; ++i) {
895 target_oper_names[i] = c_target_oper_names[i];
896 }
897 string new_handle;
898 status->status = s->session->PRunSetup(input_names, output_names,
899 target_oper_names, &new_handle);
900 if (status->status.ok()) {
901 char* buf = new char[new_handle.size() + 1];
902 memcpy(buf, new_handle.c_str(), new_handle.size() + 1);
903 *handle = buf;
904 }
905}
906
907void TF_PRun(TF_DeprecatedSession* s, const char* handle,
908 // Input tensors
909 const char** c_input_names, TF_Tensor** c_inputs, int ninputs,
910 // Output tensors
911 const char** c_output_names, TF_Tensor** c_outputs, int noutputs,
912 // Target nodes
913 const char** c_target_oper_names, int ntargets,
914 TF_Status* status) {
915 TF_Run_Setup(noutputs, c_outputs, status);
916 std::vector<std::pair<string, Tensor>> input_pairs(ninputs);
917 if (!TF_Run_Inputs(c_inputs, &input_pairs, status)) return;
918 for (int i = 0; i < ninputs; ++i) {
919 input_pairs[i].first = c_input_names[i];
920 }
921
922 std::vector<string> output_names(noutputs);
923 for (int i = 0; i < noutputs; ++i) {
924 output_names[i] = c_output_names[i];
925 }
926 std::vector<string> target_oper_names(ntargets);
927 for (int i = 0; i < ntargets; ++i) {
928 target_oper_names[i] = c_target_oper_names[i];
929 }
930 TF_Run_Helper(s->session, handle, nullptr, input_pairs, output_names,
931 c_outputs, target_oper_names, nullptr, status);
932}
933
934TF_Library* TF_LoadLibrary(const char* library_filename, TF_Status* status) {
935 TF_Library* lib_handle = new TF_Library;
936 status->status = tensorflow::LoadLibrary(
937 library_filename, &lib_handle->lib_handle, &lib_handle->op_list.data,
938 &lib_handle->op_list.length);
939 if (!status->status.ok()) {
940 delete lib_handle;
941 return nullptr;
942 }
943 return lib_handle;
944}
945
946TF_Buffer TF_GetOpList(TF_Library* lib_handle) { return lib_handle->op_list; }
947
948void TF_DeleteLibraryHandle(TF_Library* lib_handle) {
949 tensorflow::port::Free(const_cast<void*>(lib_handle->op_list.data));
950 delete lib_handle;
951}
952
953TF_Buffer* TF_GetAllOpList() {
954 std::vector<tensorflow::OpDef> op_defs;
955 tensorflow::OpRegistry::Global()->GetRegisteredOps(&op_defs);
956 tensorflow::OpList op_list;
957 for (const auto& op : op_defs) {
958 *(op_list.add_op()) = op;
959 }
960 TF_Buffer* ret = TF_NewBuffer();
961 TF_CHECK_OK(MessageToBuffer(op_list, ret));
962 return ret;
963}
964
965// --------------------------------------------------------------------------
966// ListDevices & SessionListDevices API
967
968void TF_DeleteDeviceList(TF_DeviceList* s) { delete s; }
969
970TF_DeviceList* TF_SessionListDevices(TF_Session* session, TF_Status* status) {
971 TF_DeviceList* response = new TF_DeviceList;
972 status->status = session->session->ListDevices(&response->response);
973 return response;
974}
975
976TF_DeviceList* TF_DeprecatedSessionListDevices(TF_DeprecatedSession* session,
977 TF_Status* status) {
978 TF_DeviceList* response = new TF_DeviceList;
979 status->status = session->session->ListDevices(&response->response);
980 return response;
981}
982
983int TF_DeviceListCount(const TF_DeviceList* list) {
984 return list->response.size();
985}
986
// Defines an accessor
//   return_type method_name(const TF_DeviceList* list, int index,
//                           TF_Status* status)
// that returns `list->response[index].accessor`.
//
// Behavior:
//   - A null `list` or an out-of-range `index` sets an InvalidArgument
//     status and returns `err_val`.
//   - Otherwise the status is set to OK.
//
// For the string accessors below, the returned pointer refers to storage
// inside `list`, so it is only valid while `list` is alive.
// Comments must stay outside the macro body: a `//` comment on a
// continuation line would swallow the following line.
#define TF_DEVICELIST_METHOD(return_type, method_name, accessor, err_val) \
  return_type method_name(const TF_DeviceList* list, const int index,     \
                          TF_Status* status) {                            \
    if (list == nullptr) {                                                \
      status->status = InvalidArgument("list is null!");                  \
      return err_val;                                                     \
    }                                                                     \
    if (index < 0 || index >= list->response.size()) {                    \
      status->status = InvalidArgument("index out of bounds");            \
      return err_val;                                                     \
    }                                                                     \
    status->status = Status::OK();                                        \
    return list->response[index].accessor;                                \
  }

// Device name, device type, and memory limit (in bytes) of entry `index`.
TF_DEVICELIST_METHOD(const char*, TF_DeviceListName, name().c_str(), nullptr);
TF_DEVICELIST_METHOD(const char*, TF_DeviceListType, device_type().c_str(),
                     nullptr);
TF_DEVICELIST_METHOD(int64_t, TF_DeviceListMemoryBytes, memory_limit(), -1);

#undef TF_DEVICELIST_METHOD
1008
1009} // end extern "C"
1010
1011// --------------------------------------------------------------------------
1012// New Graph and Session API
1013
1014// Helper functions -----------------------------------------------------------
1015
1016namespace {
1017
1018TF_Operation* ToOperation(Node* node) {
1019 return static_cast<TF_Operation*>(static_cast<void*>(node));
1020}
1021
1022string OutputName(const TF_Output& output) {
1023 return StrCat(output.oper->node.name(), ":", output.index);
1024}
1025
1026const tensorflow::AttrValue* GetAttrValue(TF_Operation* oper,
1027 const char* attr_name,
1028 TF_Status* status) {
1029 const tensorflow::AttrValue* attr = oper->node.attrs().Find(attr_name);
1030 if (attr == nullptr) {
1031 status->status = InvalidArgument("Operation '", oper->node.name(),
1032 "' has no attr named '", attr_name, "'.");
1033 }
1034 return attr;
1035}
1036
1037TensorId ToTensorId(const TF_Output& output) {
1038 return TensorId(output.oper->node.name(), output.index);
1039}
1040
1041#ifndef __ANDROID__
1042std::vector<tensorflow::Output> OutputsFromTFOutputs(TF_Output* tf_outputs,
1043 int n) {
1044 std::vector<tensorflow::Output> outputs(n);
1045 for (int i = 0; i < n; ++i) {
1046 outputs[i] =
1047 tensorflow::Output(&tf_outputs[i].oper->node, tf_outputs[i].index);
1048 }
1049 return outputs;
1050}
1051
1052void TFOutputsFromOutputs(const std::vector<tensorflow::Output>& outputs,
1053 TF_Output* tf_outputs) {
1054 for (int i = 0; i < outputs.size(); i++) {
1055 tf_outputs[i].oper = ToOperation(outputs[i].node());
1056 tf_outputs[i].index = outputs[i].index();
1057 }
1058}
1059#endif // __ANDROID__
1060
1061} // namespace
1062
1063// Shape functions -----------------------------------------------------------
1064
1065void TF_GraphSetTensorShape(TF_Graph* graph, TF_Output output,
1066 const int64_t* dims, const int num_dims,
1067 TF_Status* status) {
1068 Node* node = &output.oper->node;
1069
1070 mutex_lock l(graph->mu);
1071 tensorflow::shape_inference::InferenceContext* ic =
1072 graph->refiner.GetContext(node);
1073 if (ic == nullptr) {
1074 status->status =
1075 InvalidArgument("Node ", node->name(), " was not found in the graph");
1076 return;
1077 }
1078 tensorflow::shape_inference::ShapeHandle new_shape =
1079 tensorflow::ShapeHandleFromDims(ic, num_dims, dims);
1080 status->status = graph->refiner.SetShape(node, output.index, new_shape);
1081}
1082
1083int TF_GraphGetTensorNumDims(TF_Graph* graph, TF_Output output,
1084 TF_Status* status) {
1085 Node* node = &output.oper->node;
1086
1087 mutex_lock l(graph->mu);
1088 tensorflow::shape_inference::InferenceContext* ic =
1089 graph->refiner.GetContext(node);
1090 if (ic == nullptr) {
1091 status->status =
1092 InvalidArgument("Node ", node->name(), " was not found in the graph");
1093 return -1;
1094 }
1095
1096 tensorflow::shape_inference::ShapeHandle shape = ic->output(output.index);
1097
1098 // Unknown rank means the number of dimensions is -1.
1099 if (!ic->RankKnown(shape)) {
1100 return -1;
1101 }
1102
1103 return ic->Rank(shape);
1104}
1105
1106void TF_GraphGetTensorShape(TF_Graph* graph, TF_Output output, int64_t* dims,
1107 int num_dims, TF_Status* status) {
1108 Node* node = &output.oper->node;
1109
1110 mutex_lock l(graph->mu);
1111 tensorflow::shape_inference::InferenceContext* ic =
1112 graph->refiner.GetContext(node);
1113 if (ic == nullptr) {
1114 status->status =
1115 InvalidArgument("Node ", node->name(), " was not found in the graph");
1116 return;
1117 }
1118
1119 tensorflow::shape_inference::ShapeHandle shape = ic->output(output.index);
1120
1121 int rank = -1;
1122 if (ic->RankKnown(shape)) {
1123 rank = ic->Rank(shape);
1124 }
1125
1126 if (num_dims != rank) {
1127 status->status = InvalidArgument("Expected rank is ", num_dims,
1128 " but actual rank is ", rank);
1129 return;
1130 }
1131
1132 if (num_dims == 0) {
1133 // Output shape is a scalar.
1134 return;
1135 }
1136
1137 // Rank is greater than 0, so fill in the values, if known, and
1138 // -1 for unknown values.
1139 for (int i = 0; i < num_dims; ++i) {
1140 auto dim = ic->Dim(shape, i);
1141 tensorflow::int64 value = -1;
1142 if (ic->ValueKnown(dim)) {
1143 value = ic->Value(dim);
1144 }
1145 dims[i] = value;
1146 }
1147}
1148
1149// TF_OperationDescription functions ------------------------------------------
1150
1151extern "C" {
1152
1153static TF_OperationDescription* TF_NewOperationLocked(TF_Graph* graph,
1154 const char* op_type,
1155 const char* oper_name)
1156 EXCLUSIVE_LOCKS_REQUIRED(graph->mu) {
1157 return new TF_OperationDescription(graph, op_type, oper_name);
1158}
1159
1160TF_OperationDescription* TF_NewOperation(TF_Graph* graph, const char* op_type,
1161 const char* oper_name) {
1162 mutex_lock l(graph->mu);
1163 return TF_NewOperationLocked(graph, op_type, oper_name);
1164}
1165
1166void TF_SetDevice(TF_OperationDescription* desc, const char* device) {
1167 desc->node_builder.Device(device);
1168}
1169
1170void TF_AddInput(TF_OperationDescription* desc, TF_Output input) {
1171 desc->node_builder.Input(&input.oper->node, input.index);
1172}
1173
1174void TF_AddInputList(TF_OperationDescription* desc, const TF_Output* inputs,
1175 int num_inputs) {
1176 std::vector<NodeBuilder::NodeOut> input_list;
1177 input_list.reserve(num_inputs);
1178 for (int i = 0; i < num_inputs; ++i) {
1179 input_list.emplace_back(&inputs[i].oper->node, inputs[i].index);
1180 }
1181 desc->node_builder.Input(input_list);
1182}
1183
1184void TF_AddControlInput(TF_OperationDescription* desc, TF_Operation* input) {
1185 desc->node_builder.ControlInput(&input->node);
1186}
1187
1188void TF_ColocateWith(TF_OperationDescription* desc, TF_Operation* op) {
1189 desc->colocation_constraints.emplace(
1190 StrCat(tensorflow::kColocationGroupPrefix, op->node.name()));
1191}
1192
1193void TF_SetAttrString(TF_OperationDescription* desc, const char* attr_name,
1194 const void* value, size_t length) {
1195 tensorflow::StringPiece s(static_cast<const char*>(value), length);
1196 desc->node_builder.Attr(attr_name, s);
1197}
1198
1199void TF_SetAttrStringList(TF_OperationDescription* desc, const char* attr_name,
1200 const void* const* values, const size_t* lengths,
1201 int num_values) {
1202 if (strcmp(attr_name, tensorflow::kColocationAttrName) == 0) {
1203 desc->colocation_constraints.clear();
1204 for (int i = 0; i < num_values; ++i) {
1205 desc->colocation_constraints.emplace(static_cast<const char*>(values[i]),
1206 lengths[i]);
1207 }
1208 } else {
1209 std::vector<tensorflow::StringPiece> v;
1210 v.reserve(num_values);
1211 for (int i = 0; i < num_values; ++i) {
1212 v.emplace_back(static_cast<const char*>(values[i]), lengths[i]);
1213 }
1214 desc->node_builder.Attr(attr_name, v);
1215 }
1216}
1217
1218void TF_SetAttrInt(TF_OperationDescription* desc, const char* attr_name,
1219 int64_t value) {
1220 static_assert(sizeof(int64_t) == sizeof(tensorflow::int64),
1221 "64-bit int types should match in size");
1222 desc->node_builder.Attr(attr_name, static_cast<tensorflow::int64>(value));
1223}
1224
1225void TF_SetAttrIntList(TF_OperationDescription* desc, const char* attr_name,
1226 const int64_t* values, int num_values) {
1227 static_assert(sizeof(int64_t) == sizeof(tensorflow::int64),
1228 "64-bit int types should match in size");
1229 desc->node_builder.Attr(
1230 attr_name,
1231 ArraySlice<const tensorflow::int64>(
1232 reinterpret_cast<const tensorflow::int64*>(values), num_values));
1233}
1234
1235void TF_SetAttrFloat(TF_OperationDescription* desc, const char* attr_name,
1236 float value) {
1237 desc->node_builder.Attr(attr_name, value);
1238}
1239
1240void TF_SetAttrFloatList(TF_OperationDescription* desc, const char* attr_name,
1241 const float* values, int num_values) {
1242 desc->node_builder.Attr(attr_name,
1243 ArraySlice<const float>(values, num_values));
1244}
1245
1246void TF_SetAttrBool(TF_OperationDescription* desc, const char* attr_name,
1247 unsigned char value) {
1248 desc->node_builder.Attr(attr_name, static_cast<bool>(value));
1249}
1250
1251void TF_SetAttrBoolList(TF_OperationDescription* desc, const char* attr_name,
1252 const unsigned char* values, int num_values) {
1253 std::unique_ptr<bool[]> b(new bool[num_values]);
1254 for (int i = 0; i < num_values; ++i) {
1255 b[i] = values[i];
1256 }
1257 desc->node_builder.Attr(attr_name,
1258 ArraySlice<const bool>(b.get(), num_values));
1259}
1260
1261void TF_SetAttrType(TF_OperationDescription* desc, const char* attr_name,
1262 TF_DataType value) {
1263 desc->node_builder.Attr(attr_name, static_cast<DataType>(value));
1264}
1265
1266void TF_SetAttrTypeList(TF_OperationDescription* desc, const char* attr_name,
1267 const TF_DataType* values, int num_values) {
1268 desc->node_builder.Attr(
1269 attr_name, ArraySlice<const DataType>(
1270 reinterpret_cast<const DataType*>(values), num_values));
1271}
1272
1273void TF_SetAttrFuncName(TF_OperationDescription* desc, const char* attr_name,
1274 const char* value, size_t length) {
1275 tensorflow::NameAttrList func_name;
1276 func_name.set_name(std::string(value, value + length));
1277 desc->node_builder.Attr(attr_name, func_name);
1278}
1279
1280void TF_SetAttrShape(TF_OperationDescription* desc, const char* attr_name,
1281 const int64_t* dims, int num_dims) {
1282 PartialTensorShape shape;
1283 if (num_dims >= 0) {
1284 static_assert(sizeof(int64_t) == sizeof(tensorflow::int64),
1285 "64-bit int types should match in size");
1286 shape = PartialTensorShape(ArraySlice<tensorflow::int64>(
1287 reinterpret_cast<const tensorflow::int64*>(dims), num_dims));
1288 }
1289 desc->node_builder.Attr(attr_name, shape);
1290}
1291
1292void TF_SetAttrShapeList(TF_OperationDescription* desc, const char* attr_name,
1293 const int64_t* const* dims, const int* num_dims,
1294 int num_shapes) {
1295 std::vector<PartialTensorShape> shapes;
1296 shapes.reserve(num_shapes);
1297 for (int i = 0; i < num_shapes; ++i) {
1298 if (num_dims[i] < 0) {
1299 shapes.emplace_back();
1300 } else {
1301 static_assert(sizeof(int64_t) == sizeof(tensorflow::int64),
1302 "64-bit int types should match in size");
1303 shapes.emplace_back(ArraySlice<tensorflow::int64>(
1304 reinterpret_cast<const tensorflow::int64*>(dims[i]), num_dims[i]));
1305 }
1306 }
1307 desc->node_builder.Attr(attr_name, shapes);
1308}
1309
1310void TF_SetAttrTensorShapeProto(TF_OperationDescription* desc,
1311 const char* attr_name, const void* proto,
1312 size_t proto_len, TF_Status* status) {
1313 // shape.ParseFromArray takes an int as length, this function takes size_t,
1314 // make sure there is no information loss.
1315 if (proto_len > std::numeric_limits<int>::max()) {
1316 status->status = InvalidArgument(
1317 "proto_len (", proto_len,
1318 " bytes) is too large to be parsed by the protocol buffer library");
1319 return;
1320 }
1321 TensorShapeProto shape;
1322 if (shape.ParseFromArray(proto, static_cast<int>(proto_len))) {
1323 desc->node_builder.Attr(attr_name, shape);
1324 status->status = Status::OK();
1325 } else {
1326 status->status = InvalidArgument("Unparseable TensorShapeProto");
1327 }
1328}
1329
1330void TF_SetAttrTensorShapeProtoList(TF_OperationDescription* desc,
1331 const char* attr_name,
1332 const void* const* protos,
1333 const size_t* proto_lens, int num_shapes,
1334 TF_Status* status) {
1335 std::vector<TensorShapeProto> shapes;
1336 shapes.resize(num_shapes);
1337 for (int i = 0; i < num_shapes; ++i) {
1338 if (proto_lens[i] > std::numeric_limits<int>::max()) {
1339 status->status = InvalidArgument(
1340 "length of element ", i, " in the list (", proto_lens[i],
1341 " bytes) is too large to be parsed by the protocol buffer library");
1342 return;
1343 }
1344 if (!shapes[i].ParseFromArray(protos[i], static_cast<int>(proto_lens[i]))) {
1345 status->status =
1346 InvalidArgument("Unparseable TensorShapeProto at index ", i);
1347 return;
1348 }
1349 }
1350 desc->node_builder.Attr(attr_name, shapes);
1351 status->status = Status::OK();
1352}
1353
1354void TF_SetAttrTensor(TF_OperationDescription* desc, const char* attr_name,
1355 TF_Tensor* value, TF_Status* status) {
1356 Tensor t;
1357 status->status = TF_TensorToTensor(value, &t);
1358 if (status->status.ok()) desc->node_builder.Attr(attr_name, t);
1359}
1360
1361void TF_SetAttrTensorList(TF_OperationDescription* desc, const char* attr_name,
1362 TF_Tensor* const* values, int num_values,
1363 TF_Status* status) {
1364 status->status = Status::OK();
1365 std::vector<Tensor> t;
1366 t.reserve(num_values);
1367
1368 for (int i = 0; i < num_values && status->status.ok(); ++i) {
1369 Tensor v;
1370 status->status = TF_TensorToTensor(values[i], &v);
1371 t.emplace_back(v);
1372 }
1373
1374 if (status->status.ok()) desc->node_builder.Attr(attr_name, t);
1375}
1376
1377void TF_SetAttrValueProto(TF_OperationDescription* desc, const char* attr_name,
1378 const void* proto, size_t proto_len,
1379 TF_Status* status) {
1380 tensorflow::AttrValue attr_value;
1381 if (!attr_value.ParseFromArray(proto, proto_len)) {
1382 status->status = InvalidArgument("Unparseable AttrValue proto");
1383 return;
1384 }
1385
1386 if (strcmp(attr_name, tensorflow::kColocationAttrName) == 0) {
1387 if (attr_value.value_case() != tensorflow::AttrValue::kList &&
1388 attr_value.value_case() != tensorflow::AttrValue::VALUE_NOT_SET) {
1389 status->status =
1390 InvalidArgument("Expected \"list\" field for \"",
1391 tensorflow::kColocationAttrName, "\" attribute");
1392 return;
1393 }
1394 desc->colocation_constraints.clear();
1395 for (const string& location : attr_value.list().s()) {
1396 desc->colocation_constraints.insert(location);
1397 }
1398 } else {
1399 desc->node_builder.Attr(attr_name, attr_value);
1400 }
1401
1402 status->status = Status::OK();
1403}
1404
// Finalizes `desc` into a node of desc->graph. The caller must hold
// desc->graph->mu.
//
// Takes ownership of `desc`: it is deleted on every path.
//
// Returns the new operation, or nullptr with an error in `status` when:
//   - the node name is already in the graph's name map;
//   - NodeBuilder::Finalize fails;
//   - shape inference (ShapeRefiner::AddNode) fails.
// If the node was added to the graph but shape inference failed, it is
// removed again, so the graph is left without the partial node.
static TF_Operation* TF_FinishOperationLocked(TF_OperationDescription* desc,
                                              TF_Status* status)
    EXCLUSIVE_LOCKS_REQUIRED(desc->graph->mu) {
  Node* ret = nullptr;

  if (desc->graph->name_map.count(desc->node_builder.node_name())) {
    status->status = InvalidArgument("Duplicate node name in graph: '",
                                     desc->node_builder.node_name(), "'");
  } else {
    // Colocation constraints are collected separately (as a set) and only
    // written as the colocation attr here, once, at finalization time.
    if (!desc->colocation_constraints.empty()) {
      desc->node_builder.Attr(
          tensorflow::kColocationAttrName,
          std::vector<string>(desc->colocation_constraints.begin(),
                              desc->colocation_constraints.end()));
    }
    status->status = desc->node_builder.Finalize(&desc->graph->graph, &ret);

    if (status->status.ok()) {
      // Run shape inference function for newly added node.
      status->status = desc->graph->refiner.AddNode(ret);
    }
    if (status->status.ok()) {
      // Add the node to the name-to-node mapping.
      desc->graph->name_map[ret->name()] = ret;
    } else if (ret != nullptr) {
      // Finalize succeeded but shape inference failed: undo the insertion.
      desc->graph->graph.RemoveNode(ret);
      ret = nullptr;
    }
  }

  delete desc;

  return ToOperation(ret);
}
1439
1440TF_Operation* TF_FinishOperation(TF_OperationDescription* desc,
1441 TF_Status* status) {
1442 mutex_lock l(desc->graph->mu);
1443 return TF_FinishOperationLocked(desc, status);
1444}
1445
1446// TF_Operation functions
1447// ----------------------------------------------------------
1448
1449const char* TF_OperationName(TF_Operation* oper) {
1450 return oper->node.name().c_str();
1451}
1452
1453const char* TF_OperationOpType(TF_Operation* oper) {
1454 return oper->node.type_string().c_str();
1455}
1456
1457const char* TF_OperationDevice(TF_Operation* oper) {
1458 return oper->node.requested_device().c_str();
1459}
1460
1461int TF_OperationNumOutputs(TF_Operation* oper) {
1462 return oper->node.num_outputs();
1463}
1464
1465TF_DataType TF_OperationOutputType(TF_Output oper_out) {
1466 return static_cast<TF_DataType>(
1467 oper_out.oper->node.output_type(oper_out.index));
1468}
1469
1470int TF_OperationOutputListLength(TF_Operation* oper, const char* arg_name,
1471 TF_Status* status) {
1472 NameRangeMap name_ranges;
1473 status->status =
1474 NameRangesForNode(oper->node, oper->node.op_def(), nullptr, &name_ranges);
1475 if (!status->status.ok()) return -1;
1476 auto iter = name_ranges.find(arg_name);
1477 if (iter == name_ranges.end()) {
1478 status->status = InvalidArgument("Input arg '", arg_name, "' not found");
1479 return -1;
1480 }
1481 return iter->second.second - iter->second.first;
1482}
1483
1484int TF_OperationNumInputs(TF_Operation* oper) {
1485 return oper->node.num_inputs();
1486}
1487
1488TF_DataType TF_OperationInputType(TF_Input oper_in) {
1489 return static_cast<TF_DataType>(oper_in.oper->node.input_type(oper_in.index));
1490}
1491
1492int TF_OperationInputListLength(TF_Operation* oper, const char* arg_name,
1493 TF_Status* status) {
1494 NameRangeMap name_ranges;
1495 status->status =
1496 NameRangesForNode(oper->node, oper->node.op_def(), &name_ranges, nullptr);
1497 if (!status->status.ok()) return -1;
1498 auto iter = name_ranges.find(arg_name);
1499 if (iter == name_ranges.end()) {
1500 status->status = InvalidArgument("Input arg '", arg_name, "' not found");
1501 return -1;
1502 }
1503 return iter->second.second - iter->second.first;
1504}
1505
1506TF_Output TF_OperationInput(TF_Input oper_in) {
1507 const tensorflow::Edge* edge;
1508 Status s = oper_in.oper->node.input_edge(oper_in.index, &edge);
1509 if (!s.ok()) {
1510 return {nullptr, -1};
1511 }
1512
1513 return {ToOperation(edge->src()), edge->src_output()};
1514}
1515
1516int TF_OperationOutputNumConsumers(TF_Output oper_out) {
1517 int count = 0;
1518 for (const auto* edge : oper_out.oper->node.out_edges()) {
1519 if (edge->src_output() == oper_out.index) {
1520 ++count;
1521 }
1522 }
1523 return count;
1524}
1525
1526int TF_OperationOutputConsumers(TF_Output oper_out, TF_Input* consumers,
1527 int max_consumers) {
1528 int count = 0;
1529 for (const auto* edge : oper_out.oper->node.out_edges()) {
1530 if (edge->src_output() == oper_out.index) {
1531 if (count < max_consumers) {
1532 consumers[count] = {ToOperation(edge->dst()), edge->dst_input()};
1533 }
1534 ++count;
1535 }
1536 }
1537 return count;
1538}
1539
1540int TF_OperationNumControlInputs(TF_Operation* oper) {
1541 int count = 0;
1542 for (const auto* edge : oper->node.in_edges()) {
1543 if (edge->IsControlEdge() && !edge->src()->IsSource()) {
1544 ++count;
1545 }
1546 }
1547 return count;
1548}
1549
1550int TF_OperationGetControlInputs(TF_Operation* oper,
1551 TF_Operation** control_inputs,
1552 int max_control_inputs) {
1553 int count = 0;
1554 for (const auto* edge : oper->node.in_edges()) {
1555 if (edge->IsControlEdge() && !edge->src()->IsSource()) {
1556 if (count < max_control_inputs) {
1557 control_inputs[count] = ToOperation(edge->src());
1558 }
1559 ++count;
1560 }
1561 }
1562 return count;
1563}
1564
1565int TF_OperationNumControlOutputs(TF_Operation* oper) {
1566 int count = 0;
1567 for (const auto* edge : oper->node.out_edges()) {
1568 if (edge->IsControlEdge() && !edge->dst()->IsSink()) {
1569 ++count;
1570 }
1571 }
1572 return count;
1573}
1574
1575int TF_OperationGetControlOutputs(TF_Operation* oper,
1576 TF_Operation** control_outputs,
1577 int max_control_outputs) {
1578 int count = 0;
1579 for (const auto* edge : oper->node.out_edges()) {
1580 if (edge->IsControlEdge() && !edge->dst()->IsSink()) {
1581 if (count < max_control_outputs) {
1582 control_outputs[count] = ToOperation(edge->dst());
1583 }
1584 ++count;
1585 }
1586 }
1587 return count;
1588}
1589
1590TF_AttrMetadata TF_OperationGetAttrMetadata(TF_Operation* oper,
1591 const char* attr_name,
1592 TF_Status* status) {
1593 TF_AttrMetadata metadata;
1594 const auto* attr = GetAttrValue(oper, attr_name, status);
1595 if (!status->status.ok()) return metadata;
1596 switch (attr->value_case()) {
1597#define SINGLE_CASE(kK, attr_type, size_expr) \
1598 case tensorflow::AttrValue::kK: \
1599 metadata.is_list = 0; \
1600 metadata.list_size = -1; \
1601 metadata.type = attr_type; \
1602 metadata.total_size = size_expr; \
1603 break;
1604
1605 SINGLE_CASE(kS, TF_ATTR_STRING, attr->s().length());
1606 SINGLE_CASE(kI, TF_ATTR_INT, -1);
1607 SINGLE_CASE(kF, TF_ATTR_FLOAT, -1);
1608 SINGLE_CASE(kB, TF_ATTR_BOOL, -1);
1609 SINGLE_CASE(kType, TF_ATTR_TYPE, -1);
1610 SINGLE_CASE(kShape, TF_ATTR_SHAPE,
1611 attr->shape().unknown_rank() ? -1 : attr->shape().dim_size());
1612 SINGLE_CASE(kTensor, TF_ATTR_TENSOR, -1);
1613#undef SINGLE_CASE
1614
1615 case tensorflow::AttrValue::kList:
1616 metadata.is_list = 1;
1617 metadata.list_size = 0;
1618 metadata.total_size = -1;
1619#define LIST_CASE(field, attr_type, ...) \
1620 if (attr->list().field##_size() > 0) { \
1621 metadata.type = attr_type; \
1622 metadata.list_size = attr->list().field##_size(); \
1623 __VA_ARGS__; \
1624 break; \
1625 }
1626
1627 LIST_CASE(s, TF_ATTR_STRING, metadata.total_size = 0;
1628 for (int i = 0; i < attr->list().s_size();
1629 ++i) { metadata.total_size += attr->list().s(i).size(); });
1630 LIST_CASE(i, TF_ATTR_INT);
1631 LIST_CASE(f, TF_ATTR_FLOAT);
1632 LIST_CASE(b, TF_ATTR_BOOL);
1633 LIST_CASE(type, TF_ATTR_TYPE);
1634 LIST_CASE(shape, TF_ATTR_SHAPE, metadata.total_size = 0;
1635 for (int i = 0; i < attr->list().shape_size(); ++i) {
1636 const auto& s = attr->list().shape(i);
1637 metadata.total_size += s.unknown_rank() ? 0 : s.dim_size();
1638 });
1639 LIST_CASE(tensor, TF_ATTR_TENSOR);
1640 LIST_CASE(tensor, TF_ATTR_FUNC);
1641#undef LIST_CASE
1642 // All lists empty, determine the type from the OpDef.
1643 if (metadata.list_size == 0) {
1644 for (int i = 0; i < oper->node.op_def().attr_size(); ++i) {
1645 const auto& a = oper->node.op_def().attr(i);
1646 if (a.name().compare(attr_name) != 0) continue;
1647 const string& typestr = a.type();
1648 if (typestr == "list(string)") {
1649 metadata.type = TF_ATTR_STRING;
1650 } else if (typestr == "list(int)") {
1651 metadata.type = TF_ATTR_INT;
1652 } else if (typestr == "list(float)") {
1653 metadata.type = TF_ATTR_FLOAT;
1654 } else if (typestr == "list(bool)") {
1655 metadata.type = TF_ATTR_BOOL;
1656 } else if (typestr == "list(type)") {
1657 metadata.type = TF_ATTR_TYPE;
1658 } else if (typestr == "list(shape)") {
1659 metadata.type = TF_ATTR_SHAPE;
1660 } else if (typestr == "list(tensor)") {
1661 metadata.type = TF_ATTR_TENSOR;
1662 } else if (typestr == "list(func)") {
1663 metadata.type = TF_ATTR_FUNC;
1664 } else {
1665 status->status = InvalidArgument(
1666 "Attribute '", attr_name,
1667 "' has an empty value of an unrecognized type '", typestr, "'");
1668 return metadata;
1669 }
1670 }
1671 }
1672 break;
1673
1674 case tensorflow::AttrValue::kPlaceholder:
1675 metadata.is_list = 0;
1676 metadata.list_size = -1;
1677 metadata.type = TF_ATTR_PLACEHOLDER;
1678 metadata.total_size = -1;
1679 break;
1680
1681 case tensorflow::AttrValue::kFunc:
1682 metadata.is_list = 0;
1683 metadata.list_size = -1;
1684 metadata.type = TF_ATTR_FUNC;
1685 metadata.total_size = -1;
1686 break;
1687
1688 case tensorflow::AttrValue::VALUE_NOT_SET:
1689 status->status =
1690 InvalidArgument("Attribute '", attr_name, "' has no value set");
1691 break;
1692 }
1693 return metadata;
1694}
1695
1696void TF_OperationGetAttrString(TF_Operation* oper, const char* attr_name,
1697 void* value, size_t max_length,
1698 TF_Status* status) {
1699 const auto* attr = GetAttrValue(oper, attr_name, status);
1700 if (!status->status.ok()) return;
1701 if (attr->value_case() != tensorflow::AttrValue::kS) {
1702 status->status =
1703 InvalidArgument("Attribute '", attr_name, "' is not a string");
1704 return;
1705 }
1706 if (max_length <= 0) {
1707 return;
1708 }
1709 const auto& s = attr->s();
1710 std::memcpy(value, s.data(), std::min<size_t>(s.length(), max_length));
1711}
1712
1713void TF_OperationGetAttrStringList(TF_Operation* oper, const char* attr_name,
1714 void** values, size_t* lengths,
1715 int max_values, void* storage,
1716 size_t storage_size, TF_Status* status) {
1717 const auto* attr = GetAttrValue(oper, attr_name, status);
1718 if (!status->status.ok()) return;
1719 if (attr->value_case() != tensorflow::AttrValue::kList) {
1720 status->status =
1721 InvalidArgument("Value for '", attr_name, "' is not a list");
1722 return;
1723 }
1724 const auto len = std::min(max_values, attr->list().s_size());
1725 char* p = static_cast<char*>(storage);
1726 for (int i = 0; i < len; ++i) {
1727 const string& s = attr->list().s(i);
1728 values[i] = p;
1729 lengths[i] = s.size();
1730 if ((p + s.size()) > (static_cast<char*>(storage) + storage_size)) {
1731 status->status = InvalidArgument(
1732 "Not enough storage to hold the requested list of strings");
1733 return;
1734 }
1735 memcpy(values[i], s.data(), s.size());
1736 p += s.size();
1737 }
1738}
1739
1740#define DEFINE_GETATTR(func, c_type, cpp_type, list_field) \
1741 void func(TF_Operation* oper, const char* attr_name, c_type* value, \
1742 TF_Status* status) { \
1743 cpp_type v; \
1744 status->status = \
1745 tensorflow::GetNodeAttr(oper->node.attrs(), attr_name, &v); \
1746 *value = static_cast<c_type>(v); \
1747 } \
1748 void func##List(TF_Operation* oper, const char* attr_name, c_type* values, \
1749 int max_values, TF_Status* status) { \
1750 const auto* attr = GetAttrValue(oper, attr_name, status); \
1751 if (!status->status.ok()) return; \
1752 if (attr->value_case() != tensorflow::AttrValue::kList) { \
1753 status->status = \
1754 InvalidArgument("Value for '", attr_name, "' is not a list."); \
1755 return; \
1756 } \
1757 const auto len = std::min(max_values, attr->list().list_field##_size()); \
1758 for (int i = 0; i < len; ++i) { \
1759 values[i] = static_cast<c_type>(attr->list().list_field(i)); \
1760 } \
1761 }
1762DEFINE_GETATTR(TF_OperationGetAttrInt, int64_t, tensorflow::int64, i);
1763DEFINE_GETATTR(TF_OperationGetAttrFloat, float, float, f);
1764DEFINE_GETATTR(TF_OperationGetAttrBool, unsigned char, bool, b);
1765DEFINE_GETATTR(TF_OperationGetAttrType, TF_DataType, DataType, type);
1766#undef DEFINE_GETATTR
1767
1768void TF_OperationGetAttrShape(TF_Operation* oper, const char* attr_name,
1769 int64_t* value, int num_dims, TF_Status* status) {
1770 PartialTensorShape shape;
1771 status->status =
1772 tensorflow::GetNodeAttr(oper->node.attrs(), attr_name, &shape);
1773 if (!status->status.ok()) return;
1774 auto len = std::min(shape.dims(), num_dims);
1775 for (int i = 0; i < len; ++i) {
1776 value[i] = shape.dim_size(i);
1777 }
1778}
1779
1780void TF_OperationGetAttrShapeList(TF_Operation* oper, const char* attr_name,
1781 int64_t** values, int* num_dims,
1782 int max_values, int64_t* storage,
1783 int storage_size, TF_Status* status) {
1784 std::vector<PartialTensorShape> shapes;
1785 status->status =
1786 tensorflow::GetNodeAttr(oper->node.attrs(), attr_name, &shapes);
1787 if (!status->status.ok()) return;
1788 auto len = std::min(static_cast<int>(shapes.size()), max_values);
1789 int64_t* p = storage;
1790 int storage_left = storage_size;
1791 for (int i = 0; i < len; ++i) {
1792 // shapes[i].dims() == -1 for shapes with an unknown rank.
1793 int64_t n = shapes[i].dims();
1794 num_dims[i] = n;
1795 values[i] = p;
1796 if (n < 0) {
1797 continue;
1798 }
1799 if (storage_left < n) {
1800 status->status = InvalidArgument(
1801 "Not enough storage to hold the requested list of shapes");
1802 return;
1803 }
1804 storage_left -= n;
1805 for (int j = 0; j < n; ++j, ++p) {
1806 *p = shapes[i].dim_size(j);
1807 }
1808 }
1809}
1810
1811void TF_OperationGetAttrTensorShapeProto(TF_Operation* oper,
1812 const char* attr_name,
1813 TF_Buffer* value, TF_Status* status) {
1814 const auto* attr = GetAttrValue(oper, attr_name, status);
1815 if (!status->status.ok()) return;
1816 if (attr->value_case() != tensorflow::AttrValue::kShape) {
1817 status->status =
1818 InvalidArgument("Value for '", attr_name, "' is not a shape.");
1819 return;
1820 }
1821 status->status = MessageToBuffer(attr->shape(), value);
1822}
1823
1824void TF_OperationGetAttrTensorShapeProtoList(TF_Operation* oper,
1825 const char* attr_name,
1826 TF_Buffer** values, int max_values,
1827 TF_Status* status) {
1828 const auto* attr = GetAttrValue(oper, attr_name, status);
1829 if (!status->status.ok()) return;
1830 if (attr->value_case() != tensorflow::AttrValue::kList) {
1831 status->status =
1832 InvalidArgument("Value for '", attr_name, "' is not a list");
1833 return;
1834 }
1835 const auto len = std::min(max_values, attr->list().shape_size());
1836 for (int i = 0; i < len; ++i) {
1837 values[i] = TF_NewBuffer();
1838 status->status = MessageToBuffer(attr->list().shape(i), values[i]);
1839 if (!status->status.ok()) {
1840 // Delete everything allocated to far, the operation has failed.
1841 for (int j = 0; j <= i; ++j) {
1842 TF_DeleteBuffer(values[j]);
1843 }
1844 return;
1845 }
1846 }
1847}
1848
1849void TF_OperationGetAttrTensor(TF_Operation* oper, const char* attr_name,
1850 TF_Tensor** value, TF_Status* status) {
1851 *value = nullptr;
1852 Tensor t;
1853 status->status = tensorflow::GetNodeAttr(oper->node.attrs(), attr_name, &t);
1854 if (!status->status.ok()) return;
1855 *value = TF_TensorFromTensor(t, status);
1856}
1857
1858void TF_OperationGetAttrTensorList(TF_Operation* oper, const char* attr_name,
1859 TF_Tensor** values, int max_values,
1860 TF_Status* status) {
1861 std::vector<Tensor> ts;
1862 status->status = tensorflow::GetNodeAttr(oper->node.attrs(), attr_name, &ts);
1863 if (!status->status.ok()) return;
1864 const auto len = std::min(max_values, static_cast<int>(ts.size()));
1865 for (int i = 0; i < len; ++i) {
1866 values[i] = TF_TensorFromTensor(ts[i], status);
1867 }
1868}
1869
1870void TF_OperationGetAttrValueProto(TF_Operation* oper, const char* attr_name,
1871 TF_Buffer* output_attr_value,
1872 TF_Status* status) {
1873 const auto* attr = GetAttrValue(oper, attr_name, status);
1874 if (!status->status.ok()) return;
1875 status->status = MessageToBuffer(*attr, output_attr_value);
1876}
1877
1878void TF_OperationToNodeDef(TF_Operation* oper, TF_Buffer* output_node_def,
1879 TF_Status* status) {
1880 status->status = MessageToBuffer(oper->node.def(), output_node_def);
1881}
1882
1883// TF_Graph functions ---------------------------------------------------------
1884
// Creates a new graph that resolves ops through the global op registry.
// `refiner` is built from `graph`'s producer version and op registry. This
// relies on `graph` being declared before `refiner` in TF_Graph, because
// members are initialized in declaration order.
// `delete_requested` is set by TF_DeleteGraph() when sessions still use the
// graph. `parent`/`parent_inputs` stay null except for the cond/body graphs
// created by TF_NewWhile().
TF_Graph::TF_Graph()
    : graph(tensorflow::OpRegistry::Global()),
      refiner(graph.versions().producer(), graph.op_registry()),
      delete_requested(false),
      parent(nullptr),
      parent_inputs(nullptr) {}
1891
1892TF_Graph* TF_NewGraph() { return new TF_Graph; }
1893
1894void TF_DeleteGraph(TF_Graph* g) {
1895 g->mu.lock();
1896 g->delete_requested = true;
1897 const bool del = g->sessions.empty();
1898 g->mu.unlock();
1899 if (del) delete g;
1900}
1901
1902TF_Operation* TF_GraphOperationByName(TF_Graph* graph, const char* oper_name) {
1903 mutex_lock l(graph->mu);
1904 auto iter = graph->name_map.find(oper_name);
1905 if (iter == graph->name_map.end()) {
1906 return nullptr;
1907 } else {
1908 return ToOperation(iter->second);
1909 }
1910}
1911
1912TF_Operation* TF_GraphNextOperation(TF_Graph* graph, size_t* pos) {
1913 if (*pos == 0) {
1914 // Advance past the first sentinel nodes in every graph (the source & sink).
1915 *pos += 2;
1916 } else {
1917 // Advance to the next node.
1918 *pos += 1;
1919 }
1920
1921 mutex_lock l(graph->mu);
1922 while (*pos < static_cast<size_t>(graph->graph.num_node_ids())) {
1923 Node* node = graph->graph.FindNodeId(*pos);
1924 // FindNodeId() returns nullptr for nodes that have been deleted.
1925 // We aren't currently allowing nodes to be deleted, but it is safer
1926 // to still check.
1927 if (node != nullptr) return ToOperation(node);
1928 *pos += 1;
1929 }
1930
1931 // No more nodes.
1932 return nullptr;
1933}
1934
1935void TF_GraphToGraphDef(TF_Graph* graph, TF_Buffer* output_graph_def,
1936 TF_Status* status) {
1937 GraphDef def;
1938 {
1939 mutex_lock l(graph->mu);
1940 graph->graph.ToGraphDef(&def);
1941 }
1942 status->status = MessageToBuffer(def, output_graph_def);
1943}
1944
1945void TF_GraphGetOpDef(TF_Graph* graph, const char* op_name,
1946 TF_Buffer* output_op_def, TF_Status* status) {
1947 const OpDef* op_def;
1948 {
1949 mutex_lock l(graph->mu);
1950 status->status = graph->graph.op_registry()->LookUpOpDef(op_name, &op_def);
1951 if (!status->status.ok()) return;
1952 }
1953 status->status = MessageToBuffer(*op_def, output_op_def);
1954}
1955
1956void TF_GraphVersions(TF_Graph* graph, TF_Buffer* output_version_def,
1957 TF_Status* status) {
1958 VersionDef versions;
1959 {
1960 mutex_lock l(graph->mu);
1961 versions = graph->graph.versions();
1962 }
1963 status->status = MessageToBuffer(versions, output_version_def);
1964}
1965
1966TF_ImportGraphDefOptions* TF_NewImportGraphDefOptions() {
1967 return new TF_ImportGraphDefOptions;
1968}
1969void TF_DeleteImportGraphDefOptions(TF_ImportGraphDefOptions* opts) {
1970 delete opts;
1971}
1972void TF_ImportGraphDefOptionsSetPrefix(TF_ImportGraphDefOptions* opts,
1973 const char* prefix) {
1974 opts->opts.prefix = prefix;
1975}
1976
1977void TF_ImportGraphDefOptionsSetUniquifyNames(TF_ImportGraphDefOptions* opts,
1978 unsigned char uniquify_names) {
1979 opts->opts.uniquify_names = uniquify_names;
1980}
1981
1982void TF_ImportGraphDefOptionsSetUniquifyPrefix(TF_ImportGraphDefOptions* opts,
1983 unsigned char uniquify_prefix) {
1984 opts->opts.uniquify_prefix = uniquify_prefix;
1985}
1986
1987void TF_ImportGraphDefOptionsAddInputMapping(TF_ImportGraphDefOptions* opts,
1988 const char* src_name,
1989 int src_index, TF_Output dst) {
1990 opts->tensor_id_data.push_back(src_name);
1991 const string& src_name_str = opts->tensor_id_data.back();
1992 // We don't need to store dst's name in tensor_id_data, since `dst` must
1993 // outlive the ImportGraphDef call.
1994 opts->opts.input_map[TensorId(src_name_str, src_index)] = ToTensorId(dst);
1995}
1996
1997void TF_ImportGraphDefOptionsRemapControlDependency(
1998 TF_ImportGraphDefOptions* opts, const char* src_name, TF_Operation* dst) {
1999 opts->opts.input_map[TensorId(src_name, tensorflow::Graph::kControlSlot)] =
2000 TensorId(dst->node.name(), tensorflow::Graph::kControlSlot);
2001}
2002
2003extern void TF_ImportGraphDefOptionsAddControlDependency(
2004 TF_ImportGraphDefOptions* opts, TF_Operation* oper) {
2005 opts->opts.control_dependencies.push_back(oper->node.name());
2006}
2007
2008void TF_ImportGraphDefOptionsAddReturnOutput(TF_ImportGraphDefOptions* opts,
2009 const char* oper_name, int index) {
2010 opts->tensor_id_data.push_back(oper_name);
2011 const string& oper_name_str = opts->tensor_id_data.back();
2012 opts->opts.return_tensors.emplace_back(oper_name_str, index);
2013}
2014
2015int TF_ImportGraphDefOptionsNumReturnOutputs(
2016 const TF_ImportGraphDefOptions* opts) {
2017 return opts->opts.return_tensors.size();
2018}
2019
2020void TF_ImportGraphDefOptionsAddReturnOperation(TF_ImportGraphDefOptions* opts,
2021 const char* oper_name) {
2022 opts->opts.return_nodes.push_back(oper_name);
2023}
2024
2025int TF_ImportGraphDefOptionsNumReturnOperations(
2026 const TF_ImportGraphDefOptions* opts) {
2027 return opts->opts.return_nodes.size();
2028}
2029
2030void TF_ImportGraphDefResultsReturnOutputs(TF_ImportGraphDefResults* results,
2031 int* num_outputs,
2032 TF_Output** outputs) {
2033 *num_outputs = results->return_tensors.size();
2034 *outputs = results->return_tensors.data();
2035}
2036
2037void TF_ImportGraphDefResultsReturnOperations(TF_ImportGraphDefResults* results,
2038 int* num_opers,
2039 TF_Operation*** opers) {
2040 *num_opers = results->return_nodes.size();
2041 *opers = results->return_nodes.data();
2042}
2043
2044void TF_ImportGraphDefResultsMissingUnusedInputMappings(
2045 TF_ImportGraphDefResults* results, int* num_missing_unused_input_mappings,
2046 const char*** src_names, int** src_indexes) {
2047 *num_missing_unused_input_mappings = results->missing_unused_key_names.size();
2048 *src_names = results->missing_unused_key_names.data();
2049 *src_indexes = results->missing_unused_key_indexes.data();
2050}
2051
2052void TF_DeleteImportGraphDefResults(TF_ImportGraphDefResults* results) {
2053 delete results;
2054}
2055
2056static void GraphImportGraphDefLocked(TF_Graph* graph, const GraphDef& def,
2057 const TF_ImportGraphDefOptions* opts,
2058 TF_ImportGraphDefResults* tf_results,
2059 TF_Status* status)
2060 EXCLUSIVE_LOCKS_REQUIRED(graph->mu) {
2061 const int last_node_id = graph->graph.num_node_ids();
2062 tensorflow::ImportGraphDefResults results;
2063 status->status = tensorflow::ImportGraphDef(opts->opts, def, &graph->graph,
2064 &graph->refiner, &results);
2065 if (!status->status.ok()) return;
2066
2067 // Add new nodes to name_map
2068 for (int i = last_node_id; i < graph->graph.num_node_ids(); ++i) {
2069 auto* node = graph->graph.FindNodeId(i);
2070 if (node != nullptr) graph->name_map[node->name()] = node;
2071 }
2072
2073 // Populate return_tensors
2074 DCHECK(tf_results->return_tensors.empty());
2075 tf_results->return_tensors.resize(results.return_tensors.size());
2076 for (int i = 0; i < results.return_tensors.size(); ++i) {
2077 tf_results->return_tensors[i].oper =
2078 ToOperation(results.return_tensors[i].first);
2079 tf_results->return_tensors[i].index = results.return_tensors[i].second;
2080 }
2081
2082 // Populate return_nodes
2083 DCHECK(tf_results->return_nodes.empty());
2084 tf_results->return_nodes.resize(results.return_nodes.size());
2085 for (int i = 0; i < results.return_nodes.size(); ++i) {
2086 tf_results->return_nodes[i] = ToOperation(results.return_nodes[i]);
2087 }
2088
2089 // Populate missing unused map keys
2090 DCHECK(tf_results->missing_unused_key_names.empty());
2091 DCHECK(tf_results->missing_unused_key_indexes.empty());
2092 DCHECK(tf_results->missing_unused_key_names_data.empty());
2093
2094 size_t size = results.missing_unused_input_map_keys.size();
2095 tf_results->missing_unused_key_names.resize(size);
2096 tf_results->missing_unused_key_indexes.resize(size);
2097
2098 for (int i = 0; i < size; ++i) {
2099 TensorId id = results.missing_unused_input_map_keys[i];
2100 tf_results->missing_unused_key_names_data.push_back(id.first.ToString());
2101 tf_results->missing_unused_key_names[i] =
2102 tf_results->missing_unused_key_names_data.back().c_str();
2103 tf_results->missing_unused_key_indexes[i] = id.second;
2104 }
2105}
2106
2107TF_ImportGraphDefResults* TF_GraphImportGraphDefWithResults(
2108 TF_Graph* graph, const TF_Buffer* graph_def,
2109 const TF_ImportGraphDefOptions* options, TF_Status* status) {
2110 GraphDef def;
2111 if (!def.ParseFromArray(graph_def->data, graph_def->length)) {
2112 status->status = InvalidArgument("Invalid GraphDef");
2113 return nullptr;
2114 }
2115 auto results = new TF_ImportGraphDefResults();
2116 mutex_lock l(graph->mu);
2117 GraphImportGraphDefLocked(graph, def, options, results, status);
2118 if (!status->status.ok()) {
2119 delete results;
2120 return nullptr;
2121 }
2122 return results;
2123}
2124
2125void TF_GraphImportGraphDefWithReturnOutputs(
2126 TF_Graph* graph, const TF_Buffer* graph_def,
2127 const TF_ImportGraphDefOptions* options, TF_Output* return_outputs,
2128 int num_return_outputs, TF_Status* status) {
2129 if (num_return_outputs != options->opts.return_tensors.size()) {
2130 status->status = InvalidArgument("Expected 'num_return_outputs' to be ",
2131 options->opts.return_tensors.size(),
2132 ", got ", num_return_outputs);
2133 return;
2134 }
2135 if (num_return_outputs > 0 && return_outputs == nullptr) {
2136 status->status = InvalidArgument(
2137 "'return_outputs' must be preallocated to length ", num_return_outputs);
2138 return;
2139 }
2140 GraphDef def;
2141 if (!def.ParseFromArray(graph_def->data, graph_def->length)) {
2142 status->status = InvalidArgument("Invalid GraphDef");
2143 return;
2144 }
2145 TF_ImportGraphDefResults results;
2146 mutex_lock l(graph->mu);
2147 GraphImportGraphDefLocked(graph, def, options, &results, status);
2148 DCHECK_EQ(results.return_tensors.size(), num_return_outputs);
2149 memcpy(return_outputs, results.return_tensors.data(),
2150 num_return_outputs * sizeof(TF_Output));
2151}
2152
2153void TF_GraphImportGraphDef(TF_Graph* graph, const TF_Buffer* graph_def,
2154 const TF_ImportGraphDefOptions* options,
2155 TF_Status* status) {
2156 TF_ImportGraphDefResults* results =
2157 TF_GraphImportGraphDefWithResults(graph, graph_def, options, status);
2158 TF_DeleteImportGraphDefResults(results);
2159}
2160
2161// While loop functions -------------------------------------------------------
2162
2163namespace {
2164
2165#ifndef __ANDROID__
2166
2167// Creates a placeholder representing an input to the cond or body graph.
2168// TODO(skyewm): remove these from final graph
2169bool CreateInput(const TF_Output& parent_input, TF_Graph* g, const char* name,
2170 TF_Output* input, TF_Status* status) {
2171 TF_OperationDescription* desc = TF_NewOperation(g, "Placeholder", name);
2172 TF_SetAttrType(desc, "dtype", TF_OperationOutputType(parent_input));
2173 // TODO(skyewm): set placeholder shape
2174 TF_Operation* oper = TF_FinishOperation(desc, status);
2175 if (!status->status.ok()) return false;
2176 *input = {oper, 0};
2177 return true;
2178}
2179
2180// Copies `src_graph` into `dst_graph`. Any node in `src_graph` with input
2181// `src_inputs[i]` will have that input replaced with `dst_inputs[i]`. `prefix`
2182// will be prepended to copied node names. `control_deps` are nodes in
2183// `dst_graph` that the copied `src_graph` nodes will have control dependencies
2184// on. `return_nodes` are nodes in `src_graph`, and the new corresponding nodes
2185// in `dst_graph` will be returned. `return_nodes` must be non-null.
2186Status CopyGraph(Graph* src_graph, Graph* dst_graph,
2187 tensorflow::ShapeRefiner* dst_refiner,
2188 const TF_Output* src_inputs,
2189 const std::vector<tensorflow::Output>& dst_inputs,
2190 const string& prefix,
2191 const std::vector<tensorflow::Operation>& control_deps,
2192 const TF_Output* nodes_to_return, int nreturn_nodes,
2193 std::vector<tensorflow::Output>* return_nodes) {
2194 DCHECK(return_nodes != nullptr);
2195 GraphDef gdef;
2196 src_graph->ToGraphDef(&gdef);
2197
2198 tensorflow::ImportGraphDefOptions opts;
2199 opts.prefix = prefix;
2200
2201 for (int i = 0; i < dst_inputs.size(); ++i) {
2202 opts.input_map[ToTensorId(src_inputs[i])] =
2203 TensorId(dst_inputs[i].node()->name(), dst_inputs[i].index());
2204 }
2205 opts.skip_mapped_nodes = true;
2206
2207 for (const tensorflow::Operation& op : control_deps) {
2208 opts.control_dependencies.push_back(op.node()->name());
2209 }
2210
2211 for (int i = 0; i < nreturn_nodes; ++i) {
2212 opts.return_tensors.push_back(ToTensorId(nodes_to_return[i]));
2213 }
2214
2215 // TODO(skyewm): change to OutputTensor
2216 tensorflow::ImportGraphDefResults results;
2217 TF_RETURN_IF_ERROR(
2218 ImportGraphDef(opts, gdef, dst_graph, dst_refiner, &results));
2219
2220 for (const auto& pair : results.return_tensors) {
2221 return_nodes->emplace_back(pair.first, pair.second);
2222 }
2223 return Status::OK();
2224}
2225
2226bool ValidateConstWhileParams(const TF_WhileParams& params, TF_Status* s) {
2227 if (params.cond_graph == nullptr || params.body_graph == nullptr ||
2228 params.cond_graph->parent == nullptr ||
2229 params.cond_graph->parent != params.body_graph->parent ||
2230 params.cond_graph->parent_inputs != params.body_graph->parent_inputs ||
2231 params.ninputs <= 0 || params.cond_inputs == nullptr ||
2232 params.body_inputs == nullptr || params.body_outputs == nullptr) {
2233 s->status = InvalidArgument(
2234 "TF_WhileParams must be created by successful TF_NewWhile() call");
2235 return false;
2236 }
2237 return true;
2238}
2239
2240bool ValidateInputWhileParams(const TF_WhileParams& params, TF_Status* s) {
2241 if (params.cond_output.oper == nullptr) {
2242 s->status = InvalidArgument("TF_WhileParams `cond_output` field isn't set");
2243 return false;
2244 }
2245 for (int i = 0; i < params.ninputs; ++i) {
2246 if (params.body_outputs[i].oper == nullptr) {
2247 s->status = InvalidArgument("TF_WhileParams `body_outputs[", i, "]` ",
2248 "field isn't set");
2249 return false;
2250 }
2251 }
2252 if (params.name == nullptr) {
2253 s->status = InvalidArgument("TF_WhileParams `name` field is null");
2254 return false;
2255 }
2256 return true;
2257}
2258
2259#endif // __ANDROID__
2260
2261void FreeWhileResources(const TF_WhileParams* params) {
2262 TF_DeleteGraph(params->cond_graph);
2263 TF_DeleteGraph(params->body_graph);
2264 delete[] params->cond_inputs;
2265 delete[] params->body_inputs;
2266 delete[] params->body_outputs;
2267}
2268
2269TF_WhileParams EmptyWhileParams() {
2270 return {0, nullptr, nullptr, {nullptr, 0},
2271 nullptr, nullptr, nullptr, nullptr};
2272}
2273
2274} // namespace
2275
2276TF_WhileParams TF_NewWhile(TF_Graph* g, TF_Output* inputs, int ninputs,
2277 TF_Status* status) {
2278#ifdef __ANDROID__
2279 status->status = tensorflow::errors::Unimplemented(
2280 "Creating while loops is not supported in Android. File a bug at "
2281 "https://github.com/tensorflow/tensorflow/issues if this feature is "
2282 "important to you");
2283 return EmptyWhileParams();
2284#else
2285 if (ninputs == 0) {
2286 status->status =
2287 InvalidArgument("TF_NewWhile() must be passed at least one input");
2288 return EmptyWhileParams();
2289 }
2290
2291 TF_Graph* cond_graph = TF_NewGraph();
2292 TF_Graph* body_graph = TF_NewGraph();
2293 cond_graph->parent = g;
2294 cond_graph->parent_inputs = inputs;
2295 body_graph->parent = g;
2296 body_graph->parent_inputs = inputs;
2297
2298 TF_Output* cond_inputs = new TF_Output[ninputs];
2299 TF_Output cond_output = {nullptr, -1};
2300 TF_Output* body_inputs = new TF_Output[ninputs];
2301 TF_Output* body_outputs = new TF_Output[ninputs];
2302 for (int i = 0; i < ninputs; ++i) body_outputs[i] = {nullptr, -1};
2303 const char* name = nullptr;
2304
2305 for (int i = 0; i < ninputs; ++i) {
2306 // TODO(skyewm): prefix names with underscore (requires some plumbing)
2307 if (!CreateInput(inputs[i], cond_graph, StrCat("cond_input", i).c_str(),
2308 &cond_inputs[i], status)) {
2309 break;
2310 }
2311 if (!CreateInput(inputs[i], body_graph, StrCat("body_input", i).c_str(),
2312 &body_inputs[i], status)) {
2313 break;
2314 }
2315 }
2316
2317 TF_WhileParams params = {ninputs, cond_graph, cond_inputs, cond_output,
2318 body_graph, body_inputs, body_outputs, name};
2319
2320 if (!status->status.ok()) {
2321 FreeWhileResources(&params);
2322 return EmptyWhileParams();
2323 }
2324 return params;
2325#endif // __ANDROID__
2326}
2327
2328#ifndef __ANDROID__
2329namespace {
2330
// TODO(skyewm): make nodes in while loop unfetchable like in Python version
//
// Builds the while loop described by `params` into the graph the cond/body
// graphs were created for (`cond_graph->parent`), using the recorded
// `parent_inputs` as the initial loop variables. Writes one entry of
// `outputs` per loop output produced (at most params->ninputs; possibly fewer
// if BuildWhileLoop() fails). Errors are reported through `status`.
// Does not free `params`; TF_FinishWhile() does that afterwards.
void TF_FinishWhileHelper(const TF_WhileParams* params, TF_Status* status,
                          TF_Output* outputs) {
  // Fail early if the caller left cond_output, a body output, or name unset.
  if (!ValidateInputWhileParams(*params, status)) return;

  TF_Graph* parent = params->cond_graph->parent;
  TF_Output* parent_inputs = params->cond_graph->parent_inputs;
  int num_loop_vars = params->ninputs;

  // Held for the rest of the function. It therefore covers BuildWhileLoop()
  // and (presumably) the cond/body builder callbacks it invokes, which add
  // nodes to parent->graph.
  mutex_lock l(parent->mu);

  // 'cond_fn' copies the cond graph into the parent graph.
  tensorflow::ops::CondGraphBuilderFn cond_fn =
      [params, parent](const tensorflow::Scope& scope,
                       const std::vector<tensorflow::Output>& inputs,
                       tensorflow::Output* output) {
        DCHECK_EQ(scope.graph(), &parent->graph);
        std::vector<tensorflow::Output> cond_output;
        TF_RETURN_IF_ERROR(CopyGraph(
            &params->cond_graph->graph, &parent->graph, &parent->refiner,
            params->cond_inputs, inputs, scope.impl()->name(),
            scope.impl()->control_deps(), &params->cond_output,
            /* nreturn_nodes */ 1, &cond_output));
        *output = cond_output[0];
        return Status::OK();
      };

  // 'body_fn' copies the body graph into the parent graph.
  tensorflow::ops::BodyGraphBuilderFn body_fn =
      [params, parent, num_loop_vars](
          const tensorflow::Scope& scope,
          const std::vector<tensorflow::Output>& inputs,
          std::vector<tensorflow::Output>* outputs) {
        DCHECK_EQ(scope.graph(), &parent->graph);
        TF_RETURN_IF_ERROR(
            CopyGraph(&params->body_graph->graph, &parent->graph,
                      &parent->refiner, params->body_inputs, inputs,
                      scope.impl()->name(), scope.impl()->control_deps(),
                      params->body_outputs, num_loop_vars, outputs));
        return Status::OK();
      };

  // Create the while loop using an internal scope.
  tensorflow::Scope scope =
      NewInternalScope(&parent->graph, &status->status, &parent->refiner)
          .NewSubScope(params->name);

  // Node ids at or above this one are created by BuildWhileLoop() below.
  const int first_new_node_id = parent->graph.num_node_ids();

  tensorflow::OutputList loop_outputs;
  status->status = tensorflow::ops::BuildWhileLoop(
      scope, OutputsFromTFOutputs(parent_inputs, num_loop_vars), cond_fn,
      body_fn, params->name, &loop_outputs);

  // Update name_map with newly-created ops.
  // TODO(skyewm): right now BuildWhileLoop() may alter the graph if it returns
  // a bad status. Once we fix this, we may want to return early instead of
  // executing the following code.
  for (int i = first_new_node_id; i < parent->graph.num_node_ids(); ++i) {
    Node* new_node = parent->graph.FindNodeId(i);
    if (new_node == nullptr) continue;
    parent->name_map[new_node->name()] = new_node;
  }

  // Populate 'outputs'.
  DCHECK_LE(loop_outputs.size(), num_loop_vars);
  for (int i = 0; i < loop_outputs.size(); ++i) {
    outputs[i] = {ToOperation(loop_outputs[i].node()), loop_outputs[i].index()};
  }
}
2401
2402} // namespace
2403#endif // __ANDROID__
2404
2405void TF_FinishWhile(const TF_WhileParams* params, TF_Status* status,
2406 TF_Output* outputs) {
2407#ifdef __ANDROID__
2408 status->status = tensorflow::errors::Unimplemented(
2409 "Creating while loops is not supported in Android. File a bug at "
2410 "https://github.com/tensorflow/tensorflow/issues if this feature is "
2411 "important to you");
2412#else
2413 // If it appears the caller created or modified `params`, don't free resources
2414 if (!ValidateConstWhileParams(*params, status)) return;
2415 TF_FinishWhileHelper(params, status, outputs);
2416 FreeWhileResources(params);
2417#endif // __ANDROID__
2418}
2419
2420void TF_AbortWhile(const TF_WhileParams* params) { FreeWhileResources(params); }
2421
2422void TF_AddGradients(TF_Graph* g, TF_Output* y, int ny, TF_Output* x, int nx,
2423 TF_Output* dx, TF_Status* status, TF_Output* dy) {
2424#ifdef __ANDROID__
2425 status->status = tensorflow::errors::Unimplemented(
2426 "Adding gradients is not supported in Android. File a bug at "
2427 "https://github.com/tensorflow/tensorflow/issues if this feature is "
2428 "important to you");
2429#else
2430 std::vector<tensorflow::Output> y_arg = OutputsFromTFOutputs(y, ny);
2431 std::vector<tensorflow::Output> x_arg = OutputsFromTFOutputs(x, nx);
2432 std::vector<tensorflow::Output> dy_arg;
2433
2434 {
2435 // We need to hold on to the lock while we have a scope that uses TF_Graph.
2436 mutex_lock graph_lock(g->mu);
2437
2438 const int first_new_node_id = g->graph.num_node_ids();
2439
2440 tensorflow::Scope scope =
2441 NewInternalScope(&g->graph, &status->status, &g->refiner)
2442 .NewSubScope("gradients");
2443
2444 if (dx != nullptr) {
2445 std::vector<tensorflow::Output> dx_arg = OutputsFromTFOutputs(dx, ny);
2446 status->status =
2447 AddSymbolicGradients(scope, y_arg, x_arg, dx_arg, &dy_arg);
2448 } else {
2449 status->status = AddSymbolicGradients(scope, y_arg, x_arg, &dy_arg);
2450 }
2451
2452 // Update g->name_map with the name_map from the scope, which will contain
2453 // the new gradient ops.
2454 for (int i = first_new_node_id; i < g->graph.num_node_ids(); ++i) {
2455 Node* n = g->graph.FindNodeId(i);
2456 if (n == nullptr) continue;
2457 g->name_map[n->name()] = n;
2458 }
2459 }
2460
2461 // Unpack the results from grad_outputs_arg.
2462 TFOutputsFromOutputs(dy_arg, dy);
2463#endif // __ANDROID__
2464}
2465
2466// TF_Session functions ----------------------------------------------
2467
// Wraps the session `s` together with the graph `g` it runs. `g` may be null;
// TF_NewSession() checks for that. last_num_graph_nodes starts at 0 and
// extend_before_run at true, presumably so the first run extends the session
// with every node of `g` (see ExtendSessionGraphHelper) — TODO confirm.
TF_Session::TF_Session(tensorflow::Session* s, TF_Graph* g)
    : session(s), graph(g), last_num_graph_nodes(0), extend_before_run(true) {}
2470
2471TF_Session* TF_NewSession(TF_Graph* graph, const TF_SessionOptions* opt,
2472 TF_Status* status) {
2473 Session* session;
2474 status->status = NewSession(opt->options, &session);
2475 if (status->status.ok()) {
2476 TF_Session* new_session = new TF_Session(session, graph);
2477 if (graph != nullptr) {
2478 mutex_lock l(graph->mu);
2479 graph->sessions[new_session] = "";
2480 }
2481 return new_session;
2482 } else {
2483 DCHECK_EQ(nullptr, session);
2484 return nullptr;
2485 }
2486}
2487
2488TF_Session* TF_LoadSessionFromSavedModel(
2489 const TF_SessionOptions* session_options, const TF_Buffer* run_options,
2490 const char* export_dir, const char* const* tags, int tags_len,
2491 TF_Graph* graph, TF_Buffer* meta_graph_def, TF_Status* status) {
2492// TODO(ashankar): Remove the __ANDROID__ guard. This will require ensuring that
2493// the tensorflow/cc/saved_model:loader build target is Android friendly.
2494#ifdef __ANDROID__
2495 status->status = tensorflow::errors::Unimplemented(
2496 "Loading a SavedModel is not supported in Android. File a bug at "
2497 "https://github.com/tensorflow/tensorflow/issues if this feature is "
2498 "important to you");
2499 return nullptr;
2500#else
2501 mutex_lock l(graph->mu);
2502 if (!graph->name_map.empty()) {
2503 status->status = InvalidArgument("Graph is non-empty.");
2504 return nullptr;
2505 }
2506
2507 RunOptions run_options_proto;
2508 if (run_options != nullptr && !run_options_proto.ParseFromArray(
2509 run_options->data, run_options->length)) {
2510 status->status = InvalidArgument("Unparseable RunOptions proto");
2511 return nullptr;
2512 }
2513
2514 std::unordered_set<string> tag_set;
2515 for (int i = 0; i < tags_len; i++) {
2516 tag_set.insert(string(tags[i]));
2517 }
2518
2519 tensorflow::SavedModelBundle bundle;
2520 status->status =
2521 tensorflow::LoadSavedModel(session_options->options, run_options_proto,
2522 export_dir, tag_set, &bundle);
2523 if (!status->status.ok()) return nullptr;
2524
2525 // Create a TF_Graph from the MetaGraphDef. This is safe as long as Session
2526 // extends using GraphDefs. The Graph instance is different, but equivalent
2527 // to the one used to create the session.
2528 //
2529 // TODO(jhseu): When Session is modified to take Graphs instead of
2530 // GraphDefs, return the Graph generated in LoadSavedModel().
2531 TF_ImportGraphDefOptions* import_opts = TF_NewImportGraphDefOptions();
2532 TF_ImportGraphDefResults results;
2533 GraphImportGraphDefLocked(graph, bundle.meta_graph_def.graph_def(),
2534 import_opts, &results, status);
2535 TF_DeleteImportGraphDefOptions(import_opts);
2536 if (TF_GetCode(status) != TF_OK) return nullptr;
2537
2538 if (meta_graph_def != nullptr) {
2539 status->status = MessageToBuffer(bundle.meta_graph_def, meta_graph_def);
2540 if (!status->status.ok()) return nullptr;
2541 }
2542
2543 TF_Session* session = new TF_Session(bundle.session.release(), graph);
2544
2545 graph->sessions[session] = "";
2546 session->last_num_graph_nodes = graph->graph.num_node_ids();
2547 return session;
2548#endif // __ANDROID__
2549}
2550
2551void TF_CloseSession(TF_Session* s, TF_Status* status) {
2552 status->status = s->session->Close();
2553}
2554
2555void TF_DeleteSession(TF_Session* s, TF_Status* status) {
2556 status->status = Status::OK();
2557 TF_Graph* const graph = s->graph;
2558 if (graph != nullptr) {
2559 graph->mu.lock();
2560 graph->sessions.erase(s);
2561 const bool del = graph->delete_requested && graph->sessions.empty();
2562 graph->mu.unlock();
2563 if (del) delete graph;
2564 }
2565 delete s->session;
2566 delete s;
2567}
2568
2569void TF_SessionRun(TF_Session* session, const TF_Buffer* run_options,
2570 const TF_Output* inputs, TF_Tensor* const* input_values,
2571 int ninputs, const TF_Output* outputs,
2572 TF_Tensor** output_values, int noutputs,
2573 const TF_Operation* const* target_opers, int ntargets,
2574 TF_Buffer* run_metadata, TF_Status* status) {
2575 // TODO(josh11b,mrry): Change Session to be able to use a Graph*
2576 // directly, instead of requiring us to serialize to a GraphDef and
2577 // call Session::Extend().
2578 if (session->extend_before_run &&
2579 !ExtendSessionGraphHelper(session, status)) {
2580 return;
2581 }
2582
2583 TF_Run_Setup(noutputs, output_values, status);
2584
2585 // Convert from TF_Output and TF_Tensor to a string and Tensor.
2586 std::vector<std::pair<string, Tensor>> input_pairs(ninputs);
2587 if (!TF_Run_Inputs(input_values, &input_pairs, status)) return;
2588 for (int i = 0; i < ninputs; ++i) {
2589 input_pairs[i].first = OutputName(inputs[i]);
2590 }
2591
2592 // Convert from TF_Output to string names.
2593 std::vector<string> output_names(noutputs);
2594 for (int i = 0; i < noutputs; ++i) {
2595 output_names[i] = OutputName(outputs[i]);
2596 }
2597
2598 // Convert from TF_Operation* to string names.
2599 std::vector<string> target_names(ntargets);
2600 for (int i = 0; i < ntargets; ++i) {
2601 target_names[i] = target_opers[i]->node.name();
2602 }
2603
2604 // Actually run.
2605 TF_Run_Helper(session->session, nullptr, run_options, input_pairs,
2606 output_names, output_values, target_names, run_metadata,
2607 status);
2608}
2609
2610void TF_SessionPRunSetup(TF_Session* session, const TF_Output* inputs,
2611 int ninputs, const TF_Output* outputs, int noutputs,
2612 const TF_Operation* const* target_opers, int ntargets,
2613 const char** handle, TF_Status* status) {
2614 *handle = nullptr;
2615
2616 if (session->extend_before_run &&
2617 !ExtendSessionGraphHelper(session, status)) {
2618 return;
2619 }
2620
2621 std::vector<string> input_names(ninputs);
2622 for (int i = 0; i < ninputs; ++i) {
2623 input_names[i] = OutputName(inputs[i]);
2624 }
2625
2626 std::vector<string> output_names(noutputs);
2627 for (int i = 0; i < noutputs; ++i) {
2628 output_names[i] = OutputName(outputs[i]);
2629 }
2630
2631 std::vector<string> target_names(ntargets);
2632 for (int i = 0; i < ntargets; ++i) {
2633 target_names[i] = target_opers[i]->node.name();
2634 }
2635
2636 string new_handle;
2637 status->status = session->session->PRunSetup(input_names, output_names,
2638 target_names, &new_handle);
2639 if (status->status.ok()) {
2640 char* buf = new char[new_handle.size() + 1];
2641 memcpy(buf, new_handle.c_str(), new_handle.size() + 1);
2642 *handle = buf;
2643 }
2644}
2645
void TF_DeletePRunHandle(const char* handle) {
  // `handle` was allocated with new[] by TF_SessionPRunSetup().
  delete[] handle;
  // TODO(suharshs): Free up any resources held by the partial run state.
}
2650
2651void TF_SessionPRun(TF_Session* session, const char* handle,
2652 const TF_Output* inputs, TF_Tensor* const* input_values,
2653 int ninputs, const TF_Output* outputs,
2654 TF_Tensor** output_values, int noutputs,
2655 const TF_Operation* const* target_opers, int ntargets,
2656 TF_Status* status) {
2657 // TODO(josh11b,mrry): Change Session to be able to use a Graph*
2658 // directly, instead of requiring us to serialize to a GraphDef and
2659 // call Session::Extend().
2660 if (session->extend_before_run &&
2661 !ExtendSessionGraphHelper(session, status)) {
2662 return;
2663 }
2664
2665 TF_Run_Setup(noutputs, output_values, status);
2666
2667 // Convert from TF_Output and TF_Tensor to a string and Tensor.
2668 std::vector<std::pair<string, Tensor>> input_pairs(ninputs);
2669 if (!TF_Run_Inputs(input_values, &input_pairs, status)) return;
2670 for (int i = 0; i < ninputs; ++i) {
2671 input_pairs[i].first = OutputName(inputs[i]);
2672 }
2673
2674 // Convert from TF_Output to string names.
2675 std::vector<string> output_names(noutputs);
2676 for (int i = 0; i < noutputs; ++i) {
2677 output_names[i] = OutputName(outputs[i]);
2678 }
2679
2680 // Convert from TF_Operation* to string names.
2681 std::vector<string> target_names(ntargets);
2682 for (int i = 0; i < ntargets; ++i) {
2683 target_names[i] = target_opers[i]->node.name();
2684 }
2685
2686 TF_Run_Helper(session->session, handle, nullptr, input_pairs, output_names,
2687 output_values, target_names, nullptr, status);
2688}
2689
2690unsigned char TF_TryEvaluateConstant(TF_Graph* graph, TF_Output output,
2691 TF_Tensor** result, TF_Status* status) {
2692 *result = nullptr;
2693 mutex_lock l(graph->mu);
2694 OutputTensor tensor(&output.oper->node, output.index);
2695 bool evaluated;
2696 Tensor result_tensor;
2697 status->status = EvaluateConstantTensor(
2698 tensor, graph->refiner, *graph->graph.op_registry(),
2699 graph->graph.versions().producer(), &evaluated, &result_tensor);
2700 if (evaluated) {
2701 DCHECK(status->status.ok());
2702 *result = TF_TensorFromTensor(result_tensor, status);
2703 if (!status->status.ok()) evaluated = false;
2704 }
2705 return evaluated;
2706}
2707
2708TF_ApiDefMap* TF_NewApiDefMap(TF_Buffer* op_list_buffer, TF_Status* status) {
2709 tensorflow::OpList op_list;
2710 if (!op_list.ParseFromArray(op_list_buffer->data, op_list_buffer->length)) {
2711 status->status = InvalidArgument("Unparseable OpList");
2712 return nullptr;
2713 }
2714 status->status = Status::OK();
2715 return new TF_ApiDefMap(op_list);
2716}
2717
2718void TF_DeleteApiDefMap(TF_ApiDefMap* apimap) { delete apimap; }
2719
2720void TF_ApiDefMapPut(TF_ApiDefMap* api_def_map, const char* text,
2721 size_t text_len, TF_Status* status) {
2722#ifdef __ANDROID__
2723 status->status = tensorflow::errors::Unimplemented(
2724 "ApiDefMap is not supported in Android.");
2725#else
2726 mutex_lock l(api_def_map->lock);
2727 if (api_def_map->update_docs_called) {
2728 status->status = FailedPrecondition(
2729 "TF_ApiDefMapPut cannot be called after TF_ApiDefMapGet has been "
2730 "called.");
2731 return;
2732 }
2733 string api_def_text(text, text_len);
2734 status->status = api_def_map->api_def_map.LoadApiDef(api_def_text);
2735#endif // __ANDROID__
2736}
2737
2738TF_Buffer* TF_ApiDefMapGet(TF_ApiDefMap* api_def_map, const char* name,
2739 size_t name_len, TF_Status* status) {
2740#ifdef __ANDROID__
2741 status->status = tensorflow::errors::Unimplemented(
2742 "ApiDefMap is not supported in Android.");
2743 return nullptr;
2744#else
2745 mutex_lock l(api_def_map->lock);
2746 if (!api_def_map->update_docs_called) {
2747 api_def_map->api_def_map.UpdateDocs();
2748 api_def_map->update_docs_called = true;
2749 }
2750 string name_str(name, name_len);
2751 const auto* api_def = api_def_map->api_def_map.GetApiDef(name_str);
2752
2753 TF_Buffer* ret = TF_NewBuffer();
2754 status->status = MessageToBuffer(*api_def, ret);
2755 return ret;
2756#endif // __ANDROID__
2757}
2758} // end extern "C"
2759