/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/contrib/lite/interpreter.h"
#include <cassert>
#include <cstdarg>
#include <cstdint>
#include <cstring>
#include "tensorflow/contrib/lite/arena_planner.h"
#include "tensorflow/contrib/lite/context.h"
#include "tensorflow/contrib/lite/error_reporter.h"
#include "tensorflow/contrib/lite/graph_info.h"
#include "tensorflow/contrib/lite/kernels/eigen_support.h"
#include "tensorflow/contrib/lite/kernels/gemm_support.h"
#include "tensorflow/contrib/lite/memory_planner.h"
#include "tensorflow/contrib/lite/nnapi_delegate.h"
#include "tensorflow/contrib/lite/schema/schema_generated.h"
#include "tensorflow/contrib/lite/util.h"

namespace tflite {

namespace {

// Stub function that reports an error and returns kTfLiteError. It is
// installed for context functions that are forbidden outside of delegate
// preparation. We register this one function as several different function
// pointers to save compiled binary size. Note the restrictions:
// * The type of the first parameter has to be `TfLiteContext*`.
// * All parameters must be trivially destructible (e.g. no C++ classes).
TfLiteStatus ForbiddenContextFunction(TfLiteContext* context, ...) {
  context->ReportError(
      context, "The function is forbidden if not called from a delegate.");
  return kTfLiteError;
}

// Set a context function pointer to ForbiddenContextFunction, cast to a
// compatible function pointer type.
template <typename FunctionType>
void SetForbiddenContextFunction(FunctionType* func) {
  *func = reinterpret_cast<FunctionType>(ForbiddenContextFunction);
}

}  // namespace

// A trivial implementation of GraphInfo around the Interpreter.
// NOTE: this interpreter info represents the subset of the
// graph that is executed according to execution plan. Thus,
// the indices are execution plan indices rather than raw node
// indices.
class InterpreterInfo : public GraphInfo {
 public:
  explicit InterpreterInfo(Interpreter* interpreter)
      : interpreter_(interpreter) {}

  size_t num_tensors() const override { return interpreter_->tensors_size(); }
  TfLiteTensor* tensor(size_t index) override {
    return interpreter_->tensor(index);
  }
  size_t num_nodes() const override {
    return interpreter_->execution_plan().size();
  }
  const TfLiteNode& node(size_t index) const override {
    int node_index = interpreter_->execution_plan()[index];
    return interpreter_->node_and_registration(node_index)->first;
  }
  const std::vector<int>& inputs() const override {
    return interpreter_->inputs();
  }
  const std::vector<int>& outputs() const override {
    return interpreter_->outputs();
  }

 public:
  Interpreter* interpreter_;
};
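
// For illustration: if execution_plan() were {2, 0}, then node(0) would return
// the interpreter's raw node 2 and node(1) raw node 0, and num_nodes() would
// be 2 even if the interpreter holds additional nodes outside the plan.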

Interpreter::Interpreter(ErrorReporter* error_reporter)
    : error_reporter_(error_reporter ? error_reporter
                                     : DefaultErrorReporter()) {
  context_.impl_ = static_cast<void*>(this);
  context_.ResizeTensor = ResizeTensor;
  context_.ReportError = ReportError;
  context_.AddTensors = AddTensors;
  context_.tensors = nullptr;
  context_.tensors_size = 0;
  context_.eigen_context = nullptr;
  context_.gemm_context = nullptr;
  context_.recommended_num_threads = -1;

  // Invalid to call these except from within a TfLiteDelegate.
  SetForbiddenContextFunction(&context_.GetNodeAndRegistration);
  SetForbiddenContextFunction(&context_.ReplaceSubgraphsWithDelegateKernels);
  SetForbiddenContextFunction(&context_.GetExecutionPlan);

  // Reserve some space for the tensors and nodes to avoid excessive resizing.
  tensors_.reserve(kTensorsReservedCapacity);
  nodes_and_registration_.reserve(kTensorsReservedCapacity);
  next_execution_plan_index_to_prepare_ = 0;
  UseNNAPI(false);
}
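
// For illustration, a minimal standalone use of this class could look like the
// sketch below (most clients construct the interpreter through
// InterpreterBuilder instead of calling these methods directly):
//
//   Interpreter interpreter;            // uses DefaultErrorReporter()
//   int base = 0;
//   interpreter.AddTensors(3, &base);   // tensors 0..2
//   interpreter.SetInputs({0, 1});
//   interpreter.SetOutputs({2});
//   // ... SetTensorParametersReadWrite(), AddNodeWithParameters(), ...
//   // interpreter.AllocateTensors(); interpreter.Invoke();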

Interpreter::~Interpreter() {
  for (auto& nodeAndReg : nodes_and_registration_) {
    TfLiteNode& node = nodeAndReg.first;
    TfLiteIntArrayFree(node.inputs);
    TfLiteIntArrayFree(node.outputs);
    TfLiteIntArrayFree(node.temporaries);
    if (node.builtin_data) free(node.builtin_data);
    OpFree(nodeAndReg.second, node.user_data);
    node.builtin_data = nullptr;
  }

  for (int i = 0; i < context_.tensors_size; i++) {
    TfLiteTensor* tensor = &context_.tensors[i];
    if (tensor->buffer_handle != kTfLiteNullBufferHandle) {
      tensor->delegate->FreeBufferHandle(tensor->delegate,
                                         &tensor->buffer_handle);
    }
    TfLiteTensorFree(tensor);
  }
}

TfLiteStatus Interpreter::ReplaceSubgraphsWithDelegateKernels(
    TfLiteContext* context, TfLiteRegistration registration,
    const TfLiteIntArray* nodes_to_replace, TfLiteDelegate* delegate) {
  return static_cast<Interpreter*>(context->impl_)
      ->ReplaceSubgraphsWithDelegateKernels(registration, nodes_to_replace,
                                            delegate);
}

namespace {

// Copy a std::vector<int> to an existing TfLiteIntArray.
// This is a low-level data manipulation function, and it's the caller's
// responsibility to ensure the TfLiteIntArray is large enough.
void CopyVectorToTfLiteIntArray(const std::vector<int>& vec,
                                TfLiteIntArray* arr) {
  arr->size = vec.size();
  memcpy(arr->data, vec.data(), sizeof(int) * arr->size);
}

// This function allocates a contiguous block of memory that contains a
// TfLiteDelegateParams followed by several TfLiteIntArrays. Calling `free` on
// the returned TfLiteDelegateParams* releases the whole allocation at once.
//
// +-----------------------------------+
// | TfLiteDelegateParams              |
// | TfLiteDelegate* delegate;         |
// | TfLiteIntArray* nodes_to_replace; |--\
// | TfLiteIntArray* input_tensors;    |--+--\
// | TfLiteIntArray* output_tensors;   |--+--+--\
// +-----------------------------------+  |  |  |
// | TfLiteIntArray (variable size)    |<-/  |  |
// +-----------------------------------+     |  |
// | TfLiteIntArray (variable size)    |<----/  |
// +-----------------------------------+        |
// | TfLiteIntArray (variable size)    |<-------/
// +-----------------------------------+
TfLiteDelegateParams* CreateDelegateParams(TfLiteDelegate* delegate,
                                           const Subgraph& subgraph) {
  // Step 1: Calculate the allocation size.
  int allocation_size = sizeof(TfLiteDelegateParams);

  int nodes_to_replace_size =
      TfLiteIntArrayGetSizeInBytes(subgraph.nodes.size());
  allocation_size += nodes_to_replace_size;

  int input_tensors_size =
      TfLiteIntArrayGetSizeInBytes(subgraph.input_tensors.size());
  allocation_size += input_tensors_size;

  int output_tensors_size =
      TfLiteIntArrayGetSizeInBytes(subgraph.output_tensors.size());
  allocation_size += output_tensors_size;

  // Step 2: Allocate the memory.
  // Use `char*` so we can conveniently step through the allocated space by
  // bytes.
  char* allocation = reinterpret_cast<char*>(malloc(allocation_size));

  // Step 3: Fill all data structures.
  TfLiteDelegateParams* params =
      reinterpret_cast<TfLiteDelegateParams*>(allocation);
  params->delegate = delegate;
  allocation += sizeof(TfLiteDelegateParams);

  params->nodes_to_replace = reinterpret_cast<TfLiteIntArray*>(allocation);
  CopyVectorToTfLiteIntArray(subgraph.nodes, params->nodes_to_replace);
  allocation += nodes_to_replace_size;

  params->input_tensors = reinterpret_cast<TfLiteIntArray*>(allocation);
  CopyVectorToTfLiteIntArray(subgraph.input_tensors, params->input_tensors);
  allocation += input_tensors_size;

  params->output_tensors = reinterpret_cast<TfLiteIntArray*>(allocation);
  CopyVectorToTfLiteIntArray(subgraph.output_tensors, params->output_tensors);
  allocation += output_tensors_size;

  return params;
}
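
// Note on ownership: the pointer returned above is handed to
// AddNodeWithParameters() as `builtin_data`, so the single free() of
// node.builtin_data in ~Interpreter releases the params struct and all three
// embedded TfLiteIntArrays at once.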

}  // namespace

TfLiteStatus Interpreter::ReplaceSubgraphsWithDelegateKernels(
    TfLiteRegistration registration, const TfLiteIntArray* nodes_to_replace,
    TfLiteDelegate* delegate) {
  // Annotate the registration as a DELEGATE op.
  registration.builtin_code = BuiltinOperator_DELEGATE;

  // Analyze the graph to find all independent subgraphs that are either
  // fully not-this-delegate or this-delegate computation.
  InterpreterInfo info(this);
  std::vector<Subgraph> subgraphs;
  PartitionGraphIntoIndependentSubgraphs(&info, nodes_to_replace, &subgraphs);

  execution_plan_.clear();
  for (auto& subgraph : subgraphs) {
    // Subgraphs claimed by the delegate should have a "macro" op created; the
    // other subgraphs (kTfNonPartition) just have their nodes added back to
    // the execution plan.
    switch (subgraph.type) {
      case Subgraph::kTfNonPartition:
        for (auto it = subgraph.nodes.begin(); it != subgraph.nodes.end();
             ++it) {
          execution_plan_.push_back(*it);
        }
        break;
      case Subgraph::kTfPartition: {
        int node_index;

        TfLiteDelegateParams* params = CreateDelegateParams(delegate, subgraph);
        AddNodeWithParameters(subgraph.input_tensors, subgraph.output_tensors,
                              nullptr, 0, params, &registration, &node_index);

        // Initialize the output tensors' delegate-related fields.
        for (int tensor_index : subgraph.output_tensors) {
          TfLiteTensor* tensor = &tensors_[tensor_index];
          TF_LITE_ENSURE_EQ(&context_, tensor->delegate, nullptr);
          TF_LITE_ENSURE_EQ(&context_, tensor->buffer_handle,
                            kTfLiteNullBufferHandle);
          // buffer_handle will be filled in by the delegate's `Prepare`
          // function.
          tensor->delegate = delegate;
        }

        // Associate the node with the delegate.
        TfLiteNode* node = &nodes_and_registration_[node_index].first;
        node->delegate = delegate;
      } break;
      case Subgraph::kTfUnexplored:
        return kTfLiteError;
        break;
    }
  }
  return kTfLiteOk;
}

// Gets a TfLiteIntArray* representing the execution plan. The interpreter owns
// this memory and it is only guaranteed to exist during the invocation of the
// delegate's prepare.
TfLiteStatus Interpreter::GetExecutionPlan(TfLiteIntArray** execution_plan) {
  // TODO(aselle): Do not make a copy here
  plan_cache_.reset(TfLiteIntArrayCreate(execution_plan_.size()));
  *execution_plan = plan_cache_.get();
  static_assert(sizeof(plan_cache_->data[0]) == sizeof(execution_plan_[0]),
                "TfLiteIntArray and execution_plan do not contain same type.");
  std::memcpy(plan_cache_->data, execution_plan_.data(),
              sizeof(plan_cache_->data[0]) * execution_plan_.size());
  return kTfLiteOk;
}

// WARNING: This is an experimental interface that is subject to change.
// Entry point for the C node plugin API to get the execution plan.
TfLiteStatus Interpreter::GetExecutionPlan(struct TfLiteContext* context,
                                           TfLiteIntArray** execution_plan) {
  return static_cast<Interpreter*>(context->impl_)
      ->GetExecutionPlan(execution_plan);
}
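
// For illustration, a delegate's Prepare callback might query the plan like
// this (sketch only; `MyDelegatePrepare` is a hypothetical function):
//
//   TfLiteStatus MyDelegatePrepare(TfLiteContext* context,
//                                  TfLiteDelegate* delegate) {
//     TfLiteIntArray* plan = nullptr;
//     TF_LITE_ENSURE_STATUS(context->GetExecutionPlan(context, &plan));
//     // `plan` is owned by the interpreter and only valid during Prepare.
//     for (int i = 0; i < plan->size; ++i) { /* inspect plan->data[i] */ }
//     return kTfLiteOk;
//   }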

TfLiteStatus Interpreter::SetInputs(std::vector<int> inputs) {
  TF_LITE_ENSURE_OK(&context_,
                    CheckTensorIndices("inputs", inputs.data(), inputs.size()));
  inputs_ = std::move(inputs);
  return kTfLiteOk;
}

TfLiteStatus Interpreter::SetOutputs(std::vector<int> outputs) {
  TF_LITE_ENSURE_OK(
      &context_, CheckTensorIndices("outputs", outputs.data(), outputs.size()));
  outputs_ = std::move(outputs);
  return kTfLiteOk;
}

TfLiteStatus Interpreter::CheckTensorIndices(const char* label,
                                             const int* indices, int length) {
  // Make sure kOptionalTensor is not re-defined to something other than -1.
  static_assert(kOptionalTensor == -1, "kOptionalTensor should be defined -1");

  for (int i = 0; i < length; i++) {
    int index = indices[i];
    if (index < kOptionalTensor || index >= context_.tensors_size) {
      ReportError(&context_, "Invalid tensor index %d in %s\n", index, label);
      consistent_ = false;
      return kTfLiteError;
    }
  }
  return kTfLiteOk;
}

TfLiteStatus Interpreter::BytesRequired(TfLiteType type, const int* dims,
                                        int dims_size, size_t* bytes) {
  // TODO(aselle): Check for overflow here using overflow.h in TensorFlow
  // MultiplyWithoutOverflow.
  TF_LITE_ENSURE(&context_, bytes != nullptr);
  size_t count = 1;
  for (int k = 0; k < dims_size; k++) count *= dims[k];
  switch (type) {
    case kTfLiteFloat32:
      *bytes = sizeof(float) * count;
      break;
    case kTfLiteInt32:
      *bytes = sizeof(int32_t) * count;
      break;
    case kTfLiteUInt8:
      *bytes = sizeof(uint8_t) * count;
      break;
    case kTfLiteInt64:
      *bytes = sizeof(int64_t) * count;
      break;
    default:
      ReportError(&context_,
                  "Only float32, int32, int64, uint8 supported currently.");
      return kTfLiteError;
  }
  return kTfLiteOk;
}
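
// For example, a kTfLiteFloat32 tensor with dims = {2, 3} yields
// count = 2 * 3 = 6 and *bytes = 6 * sizeof(float) = 24.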

TfLiteStatus Interpreter::AllocateTensors() {
  next_execution_plan_index_to_prepare_ = 0;
  if (memory_planner_) {
    TF_LITE_ENSURE_STATUS(memory_planner_->ResetAllocations());
  }

  if (!consistent_) {
    ReportError(&context_, "AllocateTensors() called on inconsistent model.");
    return kTfLiteError;
  }

  TF_LITE_ENSURE_STATUS(PrepareOpsAndTensors());
  if (state_ == kStateUninvokable) {
    state_ = kStateInvokable;
  }
  TF_LITE_ENSURE(&context_, state_ == kStateInvokable ||
                                state_ == kStateInvokableAndImmutable);
  return kTfLiteOk;
}
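
// Typical call order (illustrative): clients finish building or resizing the
// graph, call AllocateTensors() once, then fill input buffers and call
// Invoke(). Invoke() will also lazily prepare any nodes whose shapes only
// become known after a dynamic tensor upstream has been produced.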

TfLiteStatus Interpreter::AddNodeWithParameters(
    const std::vector<int>& inputs, const std::vector<int>& outputs,
    const char* init_data, size_t init_data_size, void* builtin_data,
    const TfLiteRegistration* registration, int* node_index) {
  if (state_ == kStateInvokableAndImmutable) {
    ReportError(&context_,
                "AddNodeWithParameters is disallowed when graph is immutable.");
    return kTfLiteError;
  }
  state_ = kStateUninvokable;

  std::unique_ptr<void, decltype(free)*> builtin_data_deleter(builtin_data,
                                                              free);

  TF_LITE_ENSURE_OK(&context_, CheckTensorIndices("node inputs", inputs.data(),
                                                  inputs.size()));
  TF_LITE_ENSURE_OK(
      &context_,
      CheckTensorIndices("node outputs", outputs.data(), outputs.size()));

  int new_node_index = nodes_and_registration_.size();
  if (node_index) *node_index = new_node_index;
  nodes_and_registration_.resize(nodes_and_registration_.size() + 1);
  auto& node_and_reg = nodes_and_registration_.back();
  TfLiteNode& node = node_and_reg.first;
  if (node.inputs) TfLiteIntArrayFree(node.inputs);
  if (node.outputs) TfLiteIntArrayFree(node.outputs);
  if (node.temporaries) TfLiteIntArrayFree(node.temporaries);

  // NOTE, here we are not using move semantics yet, since our internal
  // representation isn't std::vector, but in the future we would like to avoid
  // copies, so we want the interface to take r-value references now.
  node.inputs = ConvertVectorToTfLiteIntArray(inputs);
  node.outputs = ConvertVectorToTfLiteIntArray(outputs);
  node.temporaries = TfLiteIntArrayCreate(0);
  if (init_data) {
    node.user_data = OpInit(*registration, init_data, init_data_size);
  } else {
    node.user_data =
        OpInit(*registration,
               reinterpret_cast<const char*>(builtin_data_deleter.get()), 0);
  }

  node.builtin_data = builtin_data_deleter.release();
  // TODO(ycling): Fill `custom_initial_data` and `custom_initial_data_size`
  // properly for nodes generated by ReplaceSubgraphsWithDelegateKernels.

  if (registration->builtin_code == BuiltinOperator_CUSTOM) {
    // When it's a CUSTOM op, the `custom_options` field in the Flatbuffer
    // `Operator` table is passed in.
    node.custom_initial_data = init_data;
    node.custom_initial_data_size = init_data_size;
  } else {
    node.custom_initial_data = nullptr;
    node.custom_initial_data_size = 0;
  }

  node.delegate = nullptr;
  node_and_reg.second = *registration;
  execution_plan_.push_back(new_node_index);
  return kTfLiteOk;
}
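
// For illustration, registering a custom op node might look like this (sketch
// only; `my_custom_registration`, `custom_options_buffer` and `options_size`
// are hypothetical):
//
//   int node_index = -1;
//   interpreter.AddNodeWithParameters(
//       /*inputs=*/{0, 1}, /*outputs=*/{2},
//       /*init_data=*/custom_options_buffer, /*init_data_size=*/options_size,
//       /*builtin_data=*/nullptr, &my_custom_registration, &node_index);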

TfLiteStatus Interpreter::ResizeInputTensor(int tensor_index,
                                            const std::vector<int>& dims) {
  if (state_ == kStateInvokableAndImmutable) {
    ReportError(&context_,
                "ResizeInputTensor is disallowed when graph is immutable.");
    return kTfLiteError;
  }
  state_ = kStateUninvokable;

  // TODO(aselle): All bounds checks can be implemented as one-sided bounds
  // checks by casting to unsigned for efficiency. Profile before doing this.
  TF_LITE_ENSURE(&context_,
                 tensor_index < context_.tensors_size && tensor_index >= 0);
  TfLiteIntArray* dims_lite = ConvertVectorToTfLiteIntArray(dims);
  return ResizeTensorImpl(&context_.tensors[tensor_index], dims_lite);
}

// Returns true if at least one tensor in the given list is kTfLiteDynamic.
bool HasDynamicTensor(const TfLiteContext& context,
                      const TfLiteIntArray* tensors) {
  for (int i = 0; i < tensors->size; ++i) {
    const TfLiteTensor& tensor = context.tensors[tensors->data[i]];
    if (tensor.allocation_type == kTfLiteDynamic) {
      return true;
    }
  }
  return false;
}

TfLiteStatus Interpreter::PrepareOpsStartingAt(
    int first_execution_plan_index, int* last_execution_plan_index_prepared) {
  for (int execution_plan_index = first_execution_plan_index;
       execution_plan_index < execution_plan_.size(); execution_plan_index++) {
    int node_index = execution_plan_[execution_plan_index];
    TfLiteNode& node = nodes_and_registration_[node_index].first;
    const TfLiteRegistration& registration =
        nodes_and_registration_[node_index].second;
    EnsureTensorsVectorCapacity();
    if (OpPrepare(registration, &node) == kTfLiteError) {
      return kTfLiteError;
    }

    *last_execution_plan_index_prepared = execution_plan_index;

    // Discontinue if the node has dynamic outputs. Note that we don't
    // stop for dynamic temporary tensors since they won't affect the
    // sizes of other tensors in the graph.
    if (HasDynamicTensor(context_, node.outputs)) {
      break;
    }
  }
  return kTfLiteOk;
}

TfLiteStatus Interpreter::PrepareOpsAndTensors() {
  if (!memory_planner_) {
    memory_planner_.reset(new ArenaPlanner(
        &context_, std::unique_ptr<GraphInfo>(new InterpreterInfo(this))));
    memory_planner_->PlanAllocations();
  }

  int last_exec_plan_index_prepared = 0;

  TF_LITE_ENSURE_STATUS(PrepareOpsStartingAt(
      next_execution_plan_index_to_prepare_, &last_exec_plan_index_prepared));
  TF_LITE_ENSURE_STATUS(memory_planner_->ExecuteAllocations(
      next_execution_plan_index_to_prepare_, last_exec_plan_index_prepared));

  next_execution_plan_index_to_prepare_ = last_exec_plan_index_prepared + 1;
  return kTfLiteOk;
}
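
// Invariant after this function returns kTfLiteOk: every node at execution
// plan positions [0, next_execution_plan_index_to_prepare_) has been prepared
// and its arena-backed tensors have memory assigned; positions at or past the
// cursor are deferred until an upstream dynamic tensor gets a concrete shape
// (see Invoke() below).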

TfLiteStatus Interpreter::Invoke() {
  if (!consistent_) {
    ReportError(&context_, "Invoke called on model that is not consistent.");
    return kTfLiteError;
  }
  if (state_ == kStateUninvokable) {
    ReportError(&context_, "Invoke called on model that is not ready.");
    return kTfLiteError;
  }

  TfLiteStatus status = kTfLiteOk;
  if (nnapi_delegate_) {
    if (next_execution_plan_index_to_prepare_ == execution_plan_.size()) {
      TF_LITE_ENSURE_OK(&context_, nnapi_delegate_->Invoke(this));
      return kTfLiteOk;
    } else {
      // TODO(aselle): In the future, we would like this to be an
      // automatic tflite CPU fallback.
      ReportError(&context_,
                  "NNAPI was requested, but dependent sized tensors "
                  "are being used.\n");
      return kTfLiteError;
    }
  }

  // Invocations are always done in node order.
  // Note that calling Invoke repeatedly will cause the original memory plan to
  // be reused, unless either ResizeInputTensor() or AllocateTensors() has been
  // called.
  // TODO(b/71913981): we should force recalculation in the presence of dynamic
  // tensors, because they may have new values which in turn may affect shapes
  // and allocations.
  for (int execution_plan_index = 0;
       execution_plan_index < execution_plan_.size(); execution_plan_index++) {
    if (execution_plan_index == next_execution_plan_index_to_prepare_) {
      TF_LITE_ENSURE_STATUS(PrepareOpsAndTensors());
      TF_LITE_ENSURE(&context_, next_execution_plan_index_to_prepare_ >=
                                    execution_plan_index);
    }
    int node_index = execution_plan_[execution_plan_index];
    TfLiteNode& node = nodes_and_registration_[node_index].first;
    const TfLiteRegistration& registration =
        nodes_and_registration_[node_index].second;

    // TODO(ycling): This is an extra loop through the inputs to check if the
    // data needs to be copied from a delegate buffer to raw memory, which is
    // often not needed. We may want to cache this in Prepare to know whether
    // this needs to be done for a node or not.
    for (int i = 0; i < node.inputs->size; ++i) {
      int tensor_index = node.inputs->data[i];
      if (tensor_index == kOptionalTensor) {
        continue;
      }
      TfLiteTensor* tensor = &tensors_[tensor_index];
      if (tensor->delegate && tensor->delegate != node.delegate &&
          tensor->data_is_stale) {
        EnsureTensorDataIsReadable(tensor_index);
      }
    }

    EnsureTensorsVectorCapacity();
    if (OpInvoke(registration, &node) == kTfLiteError) {
      status = kTfLiteError;
    }
  }

  return status;
}

TfLiteStatus Interpreter::ResizeTensor(TfLiteContext* context,
                                       TfLiteTensor* tensor,
                                       TfLiteIntArray* new_size) {
  // Note here that context->impl_ is recovering the this pointer for an
  // instance of Interpreter to call into the member function ResizeTensorImpl
  // (this function is static).
  return static_cast<Interpreter*>(context->impl_)
      ->ResizeTensorImpl(tensor, new_size);
}

void Interpreter::ReportErrorImpl(const char* format, va_list args) {
  error_reporter_->Report(format, args);
}

void Interpreter::ReportError(TfLiteContext* context, const char* format, ...) {
  va_list args;
  va_start(args, format);
  auto* f = static_cast<Interpreter*>(context->impl_);
  // Note here that context->impl_ is recovering the this pointer for an
  // instance of Interpreter to call into the member function ReportErrorImpl
  // (this function is static).
  f->ReportErrorImpl(format, args);
  va_end(args);
}

TfLiteStatus Interpreter::AddTensors(int tensors_to_add,
                                     int* first_new_tensor_index) {
  int base_index = tensors_.size();
  if (first_new_tensor_index) *first_new_tensor_index = base_index;
  tensors_.resize(tensors_.size() + tensors_to_add);
  for (int i = base_index; i < tensors_.size(); i++) {
    memset(&tensors_[i], 0, sizeof(tensors_[i]));
    tensors_[i].buffer_handle = kTfLiteNullBufferHandle;
  }
  context_.tensors = tensors_.data();
  context_.tensors_size = tensors_.size();
  return kTfLiteOk;
}

TfLiteStatus Interpreter::AddTensors(TfLiteContext* context, int tensors_to_add,
                                     int* first_new_tensor_index) {
  // Note here that context->impl_ is recovering the this pointer for an
  // instance of Interpreter to call into the member function AddTensors
  // (this function is static).
  return static_cast<Interpreter*>(context->impl_)
      ->AddTensors(tensors_to_add, first_new_tensor_index);
}

TfLiteStatus Interpreter::GetNodeAndRegistration(
    int node_index, TfLiteNode** node, TfLiteRegistration** registration) {
  TF_LITE_ENSURE(&context_, node_index < nodes_size() && node_index >= 0);
  TF_LITE_ENSURE(&context_, node != nullptr && registration != nullptr);
  *node = &nodes_and_registration_[node_index].first;
  *registration = &nodes_and_registration_[node_index].second;
  return kTfLiteOk;
}

TfLiteStatus Interpreter::GetNodeAndRegistration(
    struct TfLiteContext* context, int node_index, TfLiteNode** node,
    TfLiteRegistration** registration) {
  return static_cast<Interpreter*>(context->impl_)
      ->GetNodeAndRegistration(node_index, node, registration);
}

TfLiteStatus Interpreter::SetTensorParametersReadOnly(
    int tensor_index, TfLiteType type, const char* name, const int rank,
    const int* dims, TfLiteQuantizationParams quantization, const char* buffer,
    size_t bytes, const Allocation* allocation) {
  if (state_ == kStateInvokableAndImmutable) {
    ReportError(
        &context_,
        "SetTensorParametersReadOnly is disallowed when graph is immutable.");
    return kTfLiteError;
  }

  TF_LITE_ENSURE(&context_,
                 tensor_index < context_.tensors_size && tensor_index >= 0);
  // For most tensors we know exactly how much memory is necessary so we can
  // ensure the buffer is large enough. However, we need to skip string tensors
  // because their sizes change with the contents of the individual strings.
  if (type != kTfLiteString) {
    size_t required_bytes;
    TF_LITE_ENSURE_OK(&context_,
                      BytesRequired(type, dims, rank, &required_bytes));
    TF_LITE_ENSURE_EQ(&context_, required_bytes, bytes);
  }

  TfLiteTensor& tensor = context_.tensors[tensor_index];
  if (type == tensor.type &&
      EqualArrayAndTfLiteIntArray(tensor.dims, rank, dims)) {
    // Fast path which does not invalidate the invokable property.
    TfLiteTensorDataFree(&tensor);
    tensor.data.raw = const_cast<char*>(buffer);
    if (!tensor.dims) tensor.dims = ConvertArrayToTfLiteIntArray(rank, dims);
    tensor.params = quantization;
    tensor.allocation_type = kTfLiteMmapRo;
    tensor.allocation = allocation;
  } else {
    state_ = kStateUninvokable;
    TfLiteTensorReset(type, name, ConvertArrayToTfLiteIntArray(rank, dims),
                      quantization, const_cast<char*>(buffer), bytes,
                      kTfLiteMmapRo, allocation, &tensor);
  }
  return kTfLiteOk;
}
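
// For illustration, a weights tensor backed by a read-only buffer could be
// registered like this (sketch only; `kWeights` and `weights_blob` are
// hypothetical):
//
//   const int dims[] = {2, 3};
//   TfLiteQuantizationParams quant = {0.0f, 0};  // scale, zero_point
//   interpreter.SetTensorParametersReadOnly(
//       kWeights, kTfLiteFloat32, "weights", /*rank=*/2, dims, quant,
//       weights_blob, /*bytes=*/2 * 3 * sizeof(float), /*allocation=*/nullptr);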

// Set the parameters of the tensor at `tensor_index`, allocating its storage
// internally. Unlike the read-only variant above, no external buffer is
// supplied: non-string tensors are placed in the interpreter's memory arena,
// while string tensors are allocated dynamically as they are written.
TfLiteStatus Interpreter::SetTensorParametersReadWrite(
    int tensor_index, TfLiteType type, const char* name, const int rank,
    const int* dims, TfLiteQuantizationParams quantization) {
  if (state_ == kStateInvokableAndImmutable) {
    ReportError(
        &context_,
        "SetTensorParametersReadWrite is disallowed when graph is immutable.");
    return kTfLiteError;
  }
  TF_LITE_ENSURE(&context_,
                 tensor_index < context_.tensors_size && tensor_index >= 0);
  size_t required_bytes = 0;
  if (type != kTfLiteString) {
    // These types will be allocated in our arena so we need to record how
    // many bytes we will need based on the dimensions. String tensors are
    // allocated dynamically and we can't know ahead of time how much space
    // they will require.
    TF_LITE_ENSURE_OK(&context_,
                      BytesRequired(type, dims, rank, &required_bytes));
  }
  TfLiteTensorReset(type, name, ConvertArrayToTfLiteIntArray(rank, dims),
                    quantization,
                    /*buffer=*/nullptr, required_bytes,
                    type == kTfLiteString ? kTfLiteDynamic : kTfLiteArenaRw,
                    nullptr, &context_.tensors[tensor_index]);
  return kTfLiteOk;
}
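
// For illustration, an arena-backed activation tensor could be declared like
// this (sketch only; `kOutput` is a hypothetical tensor index):
//
//   const int dims[] = {1, 224, 224, 3};
//   interpreter.SetTensorParametersReadWrite(
//       kOutput, kTfLiteFloat32, "output", /*rank=*/4, dims,
//       TfLiteQuantizationParams());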

TfLiteStatus Interpreter::SetExecutionPlan(const std::vector<int>& new_plan) {
  for (int node_index : new_plan) {
    TF_LITE_ENSURE(&context_, node_index >= 0 && node_index < nodes_size());
  }
  execution_plan_ = new_plan;
  return kTfLiteOk;
}

TfLiteStatus Interpreter::ResizeTensorImpl(TfLiteTensor* tensor,
                                           TfLiteIntArray* new_size) {
  // Note that in theory we could resize kTfLiteArenaRwPersistent tensors too.
  if (tensor->allocation_type == kTfLiteArenaRw ||
      tensor->allocation_type == kTfLiteDynamic) {
    if (tensor->type != kTfLiteString) {
      size_t bytesRequired;
      TfLiteStatus status = BytesRequired(tensor->type, new_size->data,
                                          new_size->size, &bytesRequired);
      if (status != kTfLiteOk) {
        TfLiteIntArrayFree(new_size);
        return kTfLiteError;
      }

      // Realloc space for kTfLiteDynamic tensors.
      TfLiteTensorRealloc(bytesRequired, tensor);
      tensor->bytes = bytesRequired;
    }
    if (tensor->dims) TfLiteIntArrayFree(tensor->dims);
    tensor->dims = new_size;

    if (tensor->allocation_type != kTfLiteDynamic) {
      tensor->data.raw = nullptr;
    }
  } else {
    // kTfLiteMmapRo tensors are stored in the flatbuffer and are therefore
    // of fixed size.
    TfLiteIntArrayFree(new_size);
    ReportError(&context_, "Attempting to resize a fixed-size tensor.");
    return kTfLiteError;
  }
  return kTfLiteOk;
}

void Interpreter::UseNNAPI(bool enable) {
  // TODO(aselle): This is a workaround for finding if NNAPI exists.
  // We also need to make sure getLibraryHandle() is renamed to be NNAPI
  // prefixed.
  if (!NNAPIExists()) enable = false;
  if (!enable) {
    nnapi_delegate_.reset();
  } else if (!nnapi_delegate_) {
    nnapi_delegate_.reset(new NNAPIDelegate);
  }
}

void Interpreter::SetNumThreads(int num_threads) {
  context_.recommended_num_threads = num_threads;

  // TODO(ahentz): find a way to avoid this. It causes gemmlowp and eigen to
  // be required in order to compile the framework.
  gemm_support::SetNumThreads(&context_, num_threads);
  eigen_support::SetNumThreads(&context_, num_threads);
}

TfLiteStatus Interpreter::ModifyGraphWithDelegate(TfLiteDelegate* delegate,
                                                  bool allow_dynamic_tensors) {
  if (!allow_dynamic_tensors) {
    int last_execution_plan_index_prepared;
    TF_LITE_ENSURE_OK(&context_, PrepareOpsStartingAt(
                                     0, &last_execution_plan_index_prepared));

    bool has_dynamic_tensors = true;
    // Dynamic tensors exist if not all nodes can be prepared.
    if (last_execution_plan_index_prepared + 1 == execution_plan_.size()) {
      // If all the nodes can be prepared, check if the last node has dynamic
      // tensors.
      int node_index = execution_plan_[last_execution_plan_index_prepared];
      TfLiteNode& node = nodes_and_registration_[node_index].first;
      if (!HasDynamicTensor(context_, node.outputs)) {
        has_dynamic_tensors = false;
      }
    }
    if (has_dynamic_tensors) {
      ReportError(&context_,
                  "Attempting to use a delegate that does not support dynamic "
                  "tensors on a graph that has dynamic-sized tensors.");
      return kTfLiteError;
    }
  }

  // TODO(aselle): Consider if it is worth storing pointers to delegates.
  // Setup additional context interface.
  context_.GetNodeAndRegistration = GetNodeAndRegistration;
  context_.ReplaceSubgraphsWithDelegateKernels =
      ReplaceSubgraphsWithDelegateKernels;
  context_.GetExecutionPlan = GetExecutionPlan;

  TfLiteStatus status = delegate->Prepare(&context_, delegate);

  // Remove additional context info.
  SetForbiddenContextFunction(&context_.GetNodeAndRegistration);
  SetForbiddenContextFunction(&context_.ReplaceSubgraphsWithDelegateKernels);
  SetForbiddenContextFunction(&context_.GetExecutionPlan);

  TF_LITE_ENSURE_OK(&context_, status);

  if (!allow_dynamic_tensors) {
    TF_LITE_ENSURE_OK(&context_, AllocateTensors());
    TF_LITE_ENSURE(&context_, state_ == kStateInvokable ||
                                  state_ == kStateInvokableAndImmutable);
    // After using a delegate which doesn't support dynamic tensors, make the
    // entire graph immutable.
    state_ = kStateInvokableAndImmutable;
  }

  return status;
}
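
// For illustration, a minimal delegate that claims no nodes could be applied
// like this (sketch only; a real delegate would call
// context->ReplaceSubgraphsWithDelegateKernels() from its Prepare):
//
//   TfLiteDelegate noop_delegate = {};
//   noop_delegate.Prepare = [](TfLiteContext* context,
//                              TfLiteDelegate* delegate) { return kTfLiteOk; };
//   interpreter.ModifyGraphWithDelegate(&noop_delegate,
//                                       /*allow_dynamic_tensors=*/false);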

TfLiteStatus Interpreter::SetBufferHandle(int tensor_index,
                                          TfLiteBufferHandle buffer_handle,
                                          TfLiteDelegate* delegate) {
  TF_LITE_ENSURE(&context_, tensor_index < tensors_size());
  TfLiteTensor* tensor = &tensors_[tensor_index];

  TF_LITE_ENSURE(&context_,
                 tensor->delegate == nullptr || tensor->delegate == delegate);
  tensor->delegate = delegate;
  if (tensor->buffer_handle != kTfLiteNullBufferHandle) {
    TF_LITE_ENSURE(&context_, tensor->delegate->FreeBufferHandle != nullptr);
    tensor->delegate->FreeBufferHandle(tensor->delegate,
                                       &tensor->buffer_handle);
  }
  tensor->buffer_handle = buffer_handle;

  return kTfLiteOk;
}
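
// Note: when a tensor already carries a buffer handle, the old handle is
// released through the owning delegate's FreeBufferHandle before the new one
// is recorded. Invoke() above calls EnsureTensorDataIsReadable() whenever a
// stale delegate-backed tensor feeds a node owned by a different delegate.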

TfLiteStatus Interpreter::GetBufferHandle(int tensor_index,
                                          TfLiteBufferHandle* buffer_handle,
                                          TfLiteDelegate** delegate) {
  TF_LITE_ENSURE(&context_, tensor_index < tensors_size());
  TfLiteTensor* tensor = &tensors_[tensor_index];

  *delegate = tensor->delegate;
  *buffer_handle = tensor->buffer_handle;

  return kTfLiteOk;
}

}  // namespace tflite