1 | /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "tensorflow/contrib/lite/interpreter.h" |
17 | #include <cassert> |
18 | #include <cstdarg> |
19 | #include <cstdint> |
20 | #include <cstring> |
21 | #include "tensorflow/contrib/lite/arena_planner.h" |
22 | #include "tensorflow/contrib/lite/context.h" |
23 | #include "tensorflow/contrib/lite/error_reporter.h" |
24 | #include "tensorflow/contrib/lite/graph_info.h" |
25 | #include "tensorflow/contrib/lite/kernels/eigen_support.h" |
26 | #include "tensorflow/contrib/lite/kernels/gemm_support.h" |
27 | #include "tensorflow/contrib/lite/memory_planner.h" |
28 | #include "tensorflow/contrib/lite/nnapi_delegate.h" |
29 | #include "tensorflow/contrib/lite/schema/schema_generated.h" |
30 | #include "tensorflow/contrib/lite/util.h" |
31 | |
32 | namespace tflite { |
33 | |
34 | namespace { |
35 | |
// Stub function which returns kTfLiteError when the function is forbidden.
// We're registering this single function for several different context hooks
// to save compiled binary size. Please note the restrictions:
// * The type of the first parameter has to be `TfLiteContext*`.
// * All parameters must be trivially destructible. (E.g. no C++ classes.)
41 | TfLiteStatus ForbiddenContextFunction(TfLiteContext* context, ...) { |
  context->ReportError(
      context, "The function is forbidden if not called from a delegate.");
44 | return kTfLiteError; |
45 | } |
46 | |
// Set `func` to ForbiddenContextFunction, cast to a compatible function
// pointer type.
48 | template <typename FunctionType> |
49 | void SetForbiddenContextFunction(FunctionType* func) { |
50 | *func = reinterpret_cast<FunctionType>(ForbiddenContextFunction); |
51 | } |
52 | |
53 | } // namespace |
54 | |
55 | // A trivial implementation of GraphInfo around the Interpreter. |
56 | // NOTE: this interpreter info represents the subset of the |
// graph that is executed according to the execution plan. Thus,
58 | // the indices are execution plan indices rather than raw node |
59 | // indices. |
60 | class InterpreterInfo : public GraphInfo { |
61 | public: |
62 | explicit InterpreterInfo(Interpreter* interpreter) |
63 | : interpreter_(interpreter) {} |
64 | |
65 | size_t num_tensors() const override { return interpreter_->tensors_size(); } |
66 | TfLiteTensor* tensor(size_t index) override { |
67 | return interpreter_->tensor(index); |
68 | } |
69 | size_t num_nodes() const override { |
70 | return interpreter_->execution_plan().size(); |
71 | } |
72 | const TfLiteNode& node(size_t index) const override { |
73 | int node_index = interpreter_->execution_plan()[index]; |
74 | return interpreter_->node_and_registration(node_index)->first; |
75 | } |
76 | const std::vector<int>& inputs() const override { |
77 | return interpreter_->inputs(); |
78 | } |
79 | const std::vector<int>& outputs() const override { |
80 | return interpreter_->outputs(); |
81 | } |
82 | |
83 | public: |
84 | Interpreter* interpreter_; |
85 | }; |
86 | |
87 | Interpreter::Interpreter(ErrorReporter* error_reporter) |
88 | : error_reporter_(error_reporter ? error_reporter |
89 | : DefaultErrorReporter()) { |
90 | context_.impl_ = static_cast<void*>(this); |
91 | context_.ResizeTensor = ResizeTensor; |
92 | context_.ReportError = ReportError; |
93 | context_.AddTensors = AddTensors; |
94 | context_.tensors = nullptr; |
95 | context_.tensors_size = 0; |
96 | context_.eigen_context = nullptr; |
97 | context_.gemm_context = nullptr; |
98 | context_.recommended_num_threads = -1; |
99 | |
  // Invalid to call these except from a TfLiteDelegate.
101 | SetForbiddenContextFunction(&context_.GetNodeAndRegistration); |
102 | SetForbiddenContextFunction(&context_.ReplaceSubgraphsWithDelegateKernels); |
103 | SetForbiddenContextFunction(&context_.GetExecutionPlan); |
104 | |
  // Reserve some space for the tensors and nodes to avoid excessive resizing.
106 | tensors_.reserve(kTensorsReservedCapacity); |
107 | nodes_and_registration_.reserve(kTensorsReservedCapacity); |
108 | next_execution_plan_index_to_prepare_ = 0; |
109 | UseNNAPI(false); |
110 | } |
111 | |
112 | Interpreter::~Interpreter() { |
113 | for (auto& nodeAndReg : nodes_and_registration_) { |
114 | TfLiteNode& node = nodeAndReg.first; |
115 | TfLiteIntArrayFree(node.inputs); |
116 | TfLiteIntArrayFree(node.outputs); |
117 | TfLiteIntArrayFree(node.temporaries); |
118 | if (node.builtin_data) free(node.builtin_data); |
119 | OpFree(nodeAndReg.second, node.user_data); |
120 | node.builtin_data = nullptr; |
121 | } |
122 | |
123 | for (int i = 0; i < context_.tensors_size; i++) { |
124 | TfLiteTensor* tensor = &context_.tensors[i]; |
125 | if (tensor->buffer_handle != kTfLiteNullBufferHandle) { |
126 | tensor->delegate->FreeBufferHandle(tensor->delegate, |
127 | &tensor->buffer_handle); |
128 | } |
129 | TfLiteTensorFree(tensor); |
130 | } |
131 | } |
132 | |
133 | TfLiteStatus Interpreter::ReplaceSubgraphsWithDelegateKernels( |
134 | TfLiteContext* context, TfLiteRegistration registration, |
135 | const TfLiteIntArray* nodes_to_replace, TfLiteDelegate* delegate) { |
136 | return static_cast<Interpreter*>(context->impl_) |
137 | ->ReplaceSubgraphsWithDelegateKernels(registration, nodes_to_replace, |
138 | delegate); |
139 | } |
140 | |
141 | namespace { |
142 | |
143 | // Copy a std::vector<int> to an existing TfLiteIntArray. |
// This is a low-level data manipulation function, and it's the caller's
// responsibility to ensure the TfLiteIntArray is large enough.
146 | void CopyVectorToTfLiteIntArray(const std::vector<int>& vec, |
147 | TfLiteIntArray* arr) { |
148 | arr->size = vec.size(); |
149 | memcpy(arr->data, vec.data(), sizeof(int) * arr->size); |
150 | } |
151 | |
// This function allocates a contiguous memory block that contains a
// TfLiteDelegateParams followed by several TfLiteIntArrays. Calling `free`
// on the returned TfLiteDelegateParams* frees the whole allocation at once.
156 | // |
157 | // +-----------------------------------+ |
158 | // | TfLiteDelegateParams | |
159 | // | TfLiteDelegate* delegate; | |
160 | // | TfLiteIntArray* nodes_to_replace; |--\ |
161 | // | TfLiteIntArray* input_tensors; |--+--\ |
162 | // | TfLiteIntArray* output_tensors; |--+--+--\ |
163 | // +-----------------------------------+ | | | |
164 | // | TfLiteIntArray (variable size) |<-/ | | |
165 | // +-----------------------------------+ | | |
166 | // | TfLiteIntArray (variable size) |<----/ | |
167 | // +-----------------------------------+ | |
168 | // | TfLiteIntArray (variable size) |<-------/ |
169 | // +-----------------------------------+ |
170 | TfLiteDelegateParams* CreateDelegateParams(TfLiteDelegate* delegate, |
171 | const Subgraph& subgraph) { |
172 | // Step 1: Calculate the allocation size. |
173 | int allocation_size = sizeof(TfLiteDelegateParams); |
174 | |
175 | int nodes_to_replace_size = |
176 | TfLiteIntArrayGetSizeInBytes(subgraph.nodes.size()); |
177 | allocation_size += nodes_to_replace_size; |
178 | |
179 | int input_tensors_size = |
180 | TfLiteIntArrayGetSizeInBytes(subgraph.input_tensors.size()); |
181 | allocation_size += input_tensors_size; |
182 | |
183 | int output_tensors_size = |
184 | TfLiteIntArrayGetSizeInBytes(subgraph.output_tensors.size()); |
185 | allocation_size += output_tensors_size; |
186 | |
187 | // Step 2: Allocate the memory. |
  // Use `char*` to conveniently step through the allocated space by bytes.
189 | char* allocation = reinterpret_cast<char*>(malloc(allocation_size)); |
190 | |
  // Step 3: Fill all data structures.
192 | TfLiteDelegateParams* params = |
193 | reinterpret_cast<TfLiteDelegateParams*>(allocation); |
194 | params->delegate = delegate; |
195 | allocation += sizeof(TfLiteDelegateParams); |
196 | |
197 | params->nodes_to_replace = reinterpret_cast<TfLiteIntArray*>(allocation); |
198 | CopyVectorToTfLiteIntArray(subgraph.nodes, params->nodes_to_replace); |
199 | allocation += nodes_to_replace_size; |
200 | |
201 | params->input_tensors = reinterpret_cast<TfLiteIntArray*>(allocation); |
202 | CopyVectorToTfLiteIntArray(subgraph.input_tensors, params->input_tensors); |
203 | allocation += input_tensors_size; |
204 | |
205 | params->output_tensors = reinterpret_cast<TfLiteIntArray*>(allocation); |
206 | CopyVectorToTfLiteIntArray(subgraph.output_tensors, params->output_tensors); |
207 | allocation += output_tensors_size; |
208 | |
209 | return params; |
210 | } |
211 | |
212 | } // namespace |
213 | |
214 | TfLiteStatus Interpreter::ReplaceSubgraphsWithDelegateKernels( |
215 | TfLiteRegistration registration, const TfLiteIntArray* nodes_to_replace, |
216 | TfLiteDelegate* delegate) { |
217 | // Annotate the registration as DELEGATE op. |
218 | registration.builtin_code = BuiltinOperator_DELEGATE; |
219 | |
  // Analyze the graph to find all independent subgraphs that are either
  // fully handled by this delegate or fully not handled by it.
222 | InterpreterInfo info(this); |
223 | std::vector<Subgraph> subgraphs; |
224 | PartitionGraphIntoIndependentSubgraphs(&info, nodes_to_replace, &subgraphs); |
225 | |
226 | execution_plan_.clear(); |
227 | for (auto& subgraph : subgraphs) { |
    // Subgraphs claimed by the delegate should have a "macro" op created; the
    // other subgraphs (kTfNonPartition) just have their nodes added back to
    // the execution plan.
231 | switch (subgraph.type) { |
232 | case Subgraph::kTfNonPartition: |
233 | for (auto it = subgraph.nodes.begin(); it != subgraph.nodes.end(); |
234 | ++it) { |
235 | execution_plan_.push_back(*it); |
236 | } |
237 | break; |
238 | case Subgraph::kTfPartition: { |
239 | int node_index; |
240 | |
241 | TfLiteDelegateParams* params = CreateDelegateParams(delegate, subgraph); |
242 | AddNodeWithParameters(subgraph.input_tensors, subgraph.output_tensors, |
243 | nullptr, 0, params, ®istration, &node_index); |
244 | |
        // Initialize the output tensors' delegate-related fields.
246 | for (int tensor_index : subgraph.output_tensors) { |
247 | TfLiteTensor* tensor = &tensors_[tensor_index]; |
248 | TF_LITE_ENSURE_EQ(&context_, tensor->delegate, nullptr); |
249 | TF_LITE_ENSURE_EQ(&context_, tensor->buffer_handle, |
250 | kTfLiteNullBufferHandle); |
251 | // buffer_handle will be filled in delegate's `Prepare` |
252 | // function. |
253 | tensor->delegate = delegate; |
254 | } |
255 | |
256 | // Associate the node with the delegate. |
257 | TfLiteNode* node = &nodes_and_registration_[node_index].first; |
258 | node->delegate = delegate; |
259 | } break; |
260 | case Subgraph::kTfUnexplored: |
261 | return kTfLiteError; |
262 | break; |
263 | } |
264 | } |
265 | return kTfLiteOk; |
266 | } |
267 | |
// Gets a TfLiteIntArray* representing the execution plan. The interpreter
// owns this memory and it is only guaranteed to exist during the invocation
// of the delegate's Prepare.
271 | TfLiteStatus Interpreter::GetExecutionPlan(TfLiteIntArray** execution_plan) { |
272 | // TODO(aselle): Do not make a copy here |
273 | plan_cache_.reset(TfLiteIntArrayCreate(execution_plan_.size())); |
274 | *execution_plan = plan_cache_.get(); |
275 | static_assert(sizeof(plan_cache_->data[0]) == sizeof(execution_plan_[0]), |
276 | "TfLiteIntArray and execution_plan do not contain same type." ); |
277 | std::memcpy(plan_cache_->data, execution_plan_.data(), |
278 | sizeof(plan_cache_->data[0]) * execution_plan_.size()); |
279 | return kTfLiteOk; |
280 | } |
281 | |
282 | // WARNING: This is an experimental interface that is subject to change. |
283 | // Entry point for C node plugin API to get the execution plan |
284 | TfLiteStatus Interpreter::GetExecutionPlan(struct TfLiteContext* context, |
285 | TfLiteIntArray** execution_plan) { |
286 | return static_cast<Interpreter*>(context->impl_) |
287 | ->GetExecutionPlan(execution_plan); |
288 | } |
289 | |
290 | TfLiteStatus Interpreter::SetInputs(std::vector<int> inputs) { |
291 | TF_LITE_ENSURE_OK(&context_, |
292 | CheckTensorIndices("inputs" , inputs.data(), inputs.size())); |
293 | inputs_ = std::move(inputs); |
294 | return kTfLiteOk; |
295 | } |
296 | |
297 | TfLiteStatus Interpreter::SetOutputs(std::vector<int> outputs) { |
298 | TF_LITE_ENSURE_OK( |
299 | &context_, CheckTensorIndices("outputs" , outputs.data(), outputs.size())); |
300 | outputs_ = std::move(outputs); |
301 | return kTfLiteOk; |
302 | } |
303 | |
304 | TfLiteStatus Interpreter::CheckTensorIndices(const char* label, |
305 | const int* indices, int length) { |
306 | // Making sure kOptionalTensor is not re-defined to something other than -1. |
  static_assert(kOptionalTensor == -1, "kOptionalTensor should be defined -1");
308 | |
309 | for (int i = 0; i < length; i++) { |
310 | int index = indices[i]; |
311 | if (index < kOptionalTensor || index >= context_.tensors_size) { |
312 | ReportError(&context_, "Invalid tensor index %d in %s\n" , index, label); |
313 | consistent_ = false; |
314 | return kTfLiteError; |
315 | } |
316 | } |
317 | return kTfLiteOk; |
318 | } |
319 | |
320 | TfLiteStatus Interpreter::BytesRequired(TfLiteType type, const int* dims, |
321 | int dims_size, size_t* bytes) { |
322 | // TODO(aselle): Check for overflow here using overflow.h in TensorFlow |
323 | // MultiplyWithoutOverflow. |
324 | TF_LITE_ENSURE(&context_, bytes != nullptr); |
325 | size_t count = 1; |
326 | for (int k = 0; k < dims_size; k++) count *= dims[k]; |
327 | switch (type) { |
328 | case kTfLiteFloat32: |
329 | *bytes = sizeof(float) * count; |
330 | break; |
331 | case kTfLiteInt32: |
332 | *bytes = sizeof(int32_t) * count; |
333 | break; |
334 | case kTfLiteUInt8: |
335 | *bytes = sizeof(uint8_t) * count; |
336 | break; |
337 | case kTfLiteInt64: |
338 | *bytes = sizeof(int64_t) * count; |
339 | break; |
340 | default: |
341 | ReportError(&context_, |
342 | "Only float32, int32, int64, uint8 supported currently." ); |
343 | return kTfLiteError; |
344 | } |
345 | return kTfLiteOk; |
346 | } |
347 | |
348 | TfLiteStatus Interpreter::AllocateTensors() { |
349 | next_execution_plan_index_to_prepare_ = 0; |
350 | if (memory_planner_) { |
351 | TF_LITE_ENSURE_STATUS(memory_planner_->ResetAllocations()); |
352 | } |
353 | |
354 | if (!consistent_) { |
355 | ReportError(&context_, "AllocateTensors() called on inconsistent model." ); |
356 | return kTfLiteError; |
357 | } |
358 | |
359 | TF_LITE_ENSURE_STATUS(PrepareOpsAndTensors()); |
360 | if (state_ == kStateUninvokable) { |
361 | state_ = kStateInvokable; |
362 | } |
363 | TF_LITE_ENSURE(&context_, state_ == kStateInvokable || |
364 | state_ == kStateInvokableAndImmutable); |
365 | return kTfLiteOk; |
366 | } |
367 | |
368 | TfLiteStatus Interpreter::AddNodeWithParameters( |
369 | const std::vector<int>& inputs, const std::vector<int>& outputs, |
370 | const char* init_data, size_t init_data_size, void* builtin_data, |
371 | const TfLiteRegistration* registration, int* node_index) { |
372 | if (state_ == kStateInvokableAndImmutable) { |
373 | ReportError(&context_, |
374 | "AddNodeWithParameters is disallowed when graph is immutable." ); |
375 | return kTfLiteError; |
376 | } |
377 | state_ = kStateUninvokable; |
378 | |
379 | std::unique_ptr<void, decltype(free)*> builtin_data_deleter(builtin_data, |
380 | free); |
381 | |
382 | TF_LITE_ENSURE_OK(&context_, CheckTensorIndices("node inputs" , inputs.data(), |
383 | inputs.size())); |
384 | TF_LITE_ENSURE_OK( |
385 | &context_, |
386 | CheckTensorIndices("node outputs" , outputs.data(), outputs.size())); |
387 | |
388 | int new_node_index = nodes_and_registration_.size(); |
389 | if (node_index) *node_index = new_node_index; |
390 | nodes_and_registration_.resize(nodes_and_registration_.size() + 1); |
391 | auto& node_and_reg = nodes_and_registration_.back(); |
392 | TfLiteNode& node = node_and_reg.first; |
393 | if (node.inputs) TfLiteIntArrayFree(node.inputs); |
394 | if (node.outputs) TfLiteIntArrayFree(node.outputs); |
395 | if (node.temporaries) TfLiteIntArrayFree(node.temporaries); |
396 | |
397 | // NOTE, here we are not using move semantics yet, since our internal |
398 | // representation isn't std::vector, but in the future we would like to avoid |
399 | // copies, so we want the interface to take r-value references now. |
400 | node.inputs = ConvertVectorToTfLiteIntArray(inputs); |
401 | node.outputs = ConvertVectorToTfLiteIntArray(outputs); |
402 | node.temporaries = TfLiteIntArrayCreate(0); |
403 | if (init_data) { |
404 | node.user_data = OpInit(*registration, init_data, init_data_size); |
405 | } else { |
406 | node.user_data = |
407 | OpInit(*registration, |
408 | reinterpret_cast<const char*>(builtin_data_deleter.get()), 0); |
409 | } |
410 | |
411 | node.builtin_data = builtin_data_deleter.release(); |
  // TODO(ycling): Fill `custom_initial_data` and `custom_initial_data_size`
  // properly for nodes generated by ReplaceSubgraphsWithDelegateKernels.
414 | |
415 | if (registration->builtin_code == BuiltinOperator_CUSTOM) { |
416 | // When it's a CUSTOM op, the `custom_options` field in the Flatbuffer |
417 | // `Operator` table is passed in. |
418 | node.custom_initial_data = init_data; |
419 | node.custom_initial_data_size = init_data_size; |
420 | } else { |
421 | node.custom_initial_data = nullptr; |
422 | node.custom_initial_data_size = 0; |
423 | } |
424 | |
425 | node.delegate = nullptr; |
426 | node_and_reg.second = *registration; |
427 | execution_plan_.push_back(new_node_index); |
428 | return kTfLiteOk; |
429 | } |
430 | |
431 | TfLiteStatus Interpreter::ResizeInputTensor(int tensor_index, |
432 | const std::vector<int>& dims) { |
433 | if (state_ == kStateInvokableAndImmutable) { |
434 | ReportError(&context_, |
435 | "ResizeInputTensor is disallowed when graph is immutable." ); |
436 | return kTfLiteError; |
437 | } |
438 | state_ = kStateUninvokable; |
439 | |
440 | // TODO(aselle): All bounds checks can be implemented as one-sided bounds |
441 | // checks by casting to unsigned for efficiency. Profile before doing this. |
442 | TF_LITE_ENSURE(&context_, |
443 | tensor_index < context_.tensors_size && tensor_index >= 0); |
444 | TfLiteIntArray* dims_lite = ConvertVectorToTfLiteIntArray(dims); |
445 | return ResizeTensorImpl(&context_.tensors[tensor_index], dims_lite); |
446 | } |
447 | |
448 | // Returns true if at least one tensor in the given list is kTfLiteDynamic. |
449 | bool HasDynamicTensor(const TfLiteContext& context, |
450 | const TfLiteIntArray* tensors) { |
451 | for (int i = 0; i < tensors->size; ++i) { |
452 | const TfLiteTensor& tensor = context.tensors[tensors->data[i]]; |
453 | if (tensor.allocation_type == kTfLiteDynamic) { |
454 | return true; |
455 | } |
456 | } |
457 | return false; |
458 | } |
459 | |
460 | TfLiteStatus Interpreter::PrepareOpsStartingAt( |
461 | int first_execution_plan_index, int* last_execution_plan_index_prepared) { |
462 | for (int execution_plan_index = first_execution_plan_index; |
463 | execution_plan_index < execution_plan_.size(); execution_plan_index++) { |
464 | int node_index = execution_plan_[execution_plan_index]; |
465 | TfLiteNode& node = nodes_and_registration_[node_index].first; |
466 | const TfLiteRegistration& registration = |
467 | nodes_and_registration_[node_index].second; |
468 | EnsureTensorsVectorCapacity(); |
469 | if (OpPrepare(registration, &node) == kTfLiteError) { |
470 | return kTfLiteError; |
471 | } |
472 | |
473 | *last_execution_plan_index_prepared = execution_plan_index; |
474 | |
475 | // Discontinue if the node has dynamic outputs. Note that we don't |
476 | // stop for dynamic temporary tensors since they won't affect the |
477 | // sizes of other tensors in the graph. |
478 | if (HasDynamicTensor(context_, node.outputs)) { |
479 | break; |
480 | } |
481 | } |
482 | return kTfLiteOk; |
483 | } |
484 | |
485 | TfLiteStatus Interpreter::PrepareOpsAndTensors() { |
486 | if (!memory_planner_) { |
487 | memory_planner_.reset(new ArenaPlanner( |
488 | &context_, std::unique_ptr<GraphInfo>(new InterpreterInfo(this)))); |
489 | memory_planner_->PlanAllocations(); |
490 | } |
491 | |
492 | int last_exec_plan_index_prepared = 0; |
493 | |
494 | TF_LITE_ENSURE_STATUS(PrepareOpsStartingAt( |
495 | next_execution_plan_index_to_prepare_, &last_exec_plan_index_prepared)); |
496 | TF_LITE_ENSURE_STATUS(memory_planner_->ExecuteAllocations( |
497 | next_execution_plan_index_to_prepare_, last_exec_plan_index_prepared)); |
498 | |
499 | next_execution_plan_index_to_prepare_ = last_exec_plan_index_prepared + 1; |
500 | return kTfLiteOk; |
501 | } |
502 | |
503 | TfLiteStatus Interpreter::Invoke() { |
504 | if (!consistent_) { |
505 | ReportError(&context_, "Invoke called on model that is not consistent." ); |
506 | return kTfLiteError; |
507 | } |
508 | if (state_ == kStateUninvokable) { |
509 | ReportError(&context_, "Invoke called on model that is not ready." ); |
510 | return kTfLiteError; |
511 | } |
512 | |
513 | TfLiteStatus status = kTfLiteOk; |
514 | if (nnapi_delegate_) { |
515 | if (next_execution_plan_index_to_prepare_ == execution_plan_.size()) { |
516 | TF_LITE_ENSURE_OK(&context_, nnapi_delegate_->Invoke(this)); |
517 | return kTfLiteOk; |
518 | } else { |
519 | // TODO(aselle): In the future, we would like this to be an |
520 | // automatic tflite CPU fallback. |
521 | ReportError(&context_, |
522 | "NNAPI was requested, but dependent sized tensors " |
523 | "being used.\n" ); |
524 | return kTfLiteError; |
525 | } |
526 | } |
527 | |
528 | // Invocations are always done in node order. |
529 | // Note that calling Invoke repeatedly will cause the original memory plan to |
530 | // be reused, unless either ResizeInputTensor() or AllocateTensors() has been |
531 | // called. |
532 | // TODO(b/71913981): we should force recalculation in the presence of dynamic |
533 | // tensors, because they may have new value which in turn may affect shapes |
534 | // and allocations. |
535 | for (int execution_plan_index = 0; |
536 | execution_plan_index < execution_plan_.size(); execution_plan_index++) { |
537 | if (execution_plan_index == next_execution_plan_index_to_prepare_) { |
538 | TF_LITE_ENSURE_STATUS(PrepareOpsAndTensors()); |
539 | TF_LITE_ENSURE(&context_, next_execution_plan_index_to_prepare_ >= |
540 | execution_plan_index); |
541 | } |
542 | int node_index = execution_plan_[execution_plan_index]; |
543 | TfLiteNode& node = nodes_and_registration_[node_index].first; |
544 | const TfLiteRegistration& registration = |
545 | nodes_and_registration_[node_index].second; |
546 | |
    // TODO(ycling): This is an extra loop through inputs to check if the data
    // needs to be copied from a delegate buffer to raw memory, which is often
    // not needed. We may want to cache this in Prepare to know whether it
    // needs to be done for a node or not.
551 | for (int i = 0; i < node.inputs->size; ++i) { |
552 | int tensor_index = node.inputs->data[i]; |
553 | if (tensor_index == kOptionalTensor) { |
554 | continue; |
555 | } |
556 | TfLiteTensor* tensor = &tensors_[tensor_index]; |
557 | if (tensor->delegate && tensor->delegate != node.delegate && |
558 | tensor->data_is_stale) { |
559 | EnsureTensorDataIsReadable(tensor_index); |
560 | } |
561 | } |
562 | |
563 | EnsureTensorsVectorCapacity(); |
564 | if (OpInvoke(registration, &node) == kTfLiteError) { |
565 | status = kTfLiteError; |
566 | } |
567 | } |
568 | |
569 | return status; |
570 | } |
571 | |
572 | TfLiteStatus Interpreter::ResizeTensor(TfLiteContext* context, |
573 | TfLiteTensor* tensor, |
574 | TfLiteIntArray* new_size) { |
575 | // Note here that context->impl_ is recovering the this pointer for an |
576 | // instance of Interpreter to call into the member function ResizeTensorImpl |
577 | // (this function is static). |
578 | return static_cast<Interpreter*>(context->impl_) |
579 | ->ResizeTensorImpl(tensor, new_size); |
580 | } |
581 | |
582 | void Interpreter::ReportErrorImpl(const char* format, va_list args) { |
583 | error_reporter_->Report(format, args); |
584 | } |
585 | |
586 | void Interpreter::ReportError(TfLiteContext* context, const char* format, ...) { |
587 | va_list args; |
588 | va_start(args, format); |
589 | auto* f = static_cast<Interpreter*>(context->impl_); |
590 | // Note here that context->impl_ is recovering the this pointer for an |
591 | // instance of Interpreter to call into the member function ReportErrorImpl |
592 | // (this function is static). |
593 | f->ReportErrorImpl(format, args); |
594 | va_end(args); |
595 | } |
596 | |
597 | TfLiteStatus Interpreter::AddTensors(int tensors_to_add, |
598 | int* first_new_tensor_index) { |
599 | int base_index = tensors_.size(); |
600 | if (first_new_tensor_index) *first_new_tensor_index = base_index; |
601 | tensors_.resize(tensors_.size() + tensors_to_add); |
602 | for (int i = base_index; i < tensors_.size(); i++) { |
603 | memset(&tensors_[i], 0, sizeof(tensors_[i])); |
604 | tensors_[i].buffer_handle = kTfLiteNullBufferHandle; |
605 | } |
606 | context_.tensors = tensors_.data(); |
607 | context_.tensors_size = tensors_.size(); |
608 | return kTfLiteOk; |
609 | } |
610 | |
611 | TfLiteStatus Interpreter::AddTensors(TfLiteContext* context, int tensors_to_add, |
612 | int* first_new_tensor_index) { |
613 | // Note here that context->impl_ is recovering the this pointer for an |
614 | // instance of Interpreter to call into the member function AddTensors |
615 | // (this function is static). |
616 | return static_cast<Interpreter*>(context->impl_) |
617 | ->AddTensors(tensors_to_add, first_new_tensor_index); |
618 | } |
619 | |
620 | TfLiteStatus Interpreter::GetNodeAndRegistration( |
621 | int node_index, TfLiteNode** node, TfLiteRegistration** registration) { |
622 | TF_LITE_ENSURE(&context_, node_index < nodes_size() && node_index >= 0); |
623 | TF_LITE_ENSURE(&context_, node != nullptr && registration != nullptr); |
624 | *node = &nodes_and_registration_[node_index].first; |
625 | *registration = &nodes_and_registration_[node_index].second; |
626 | return kTfLiteOk; |
627 | } |
628 | |
629 | TfLiteStatus Interpreter::GetNodeAndRegistration( |
630 | struct TfLiteContext* context, int node_index, TfLiteNode** node, |
631 | TfLiteRegistration** registration) { |
632 | return static_cast<Interpreter*>(context->impl_) |
633 | ->GetNodeAndRegistration(node_index, node, registration); |
634 | } |
635 | |
636 | TfLiteStatus Interpreter::SetTensorParametersReadOnly( |
637 | int tensor_index, TfLiteType type, const char* name, const int rank, |
638 | const int* dims, TfLiteQuantizationParams quantization, const char* buffer, |
639 | size_t bytes, const Allocation* allocation) { |
640 | if (state_ == kStateInvokableAndImmutable) { |
641 | ReportError( |
642 | &context_, |
643 | "SetTensorParametersReadOnly is disallowed when graph is immutable." ); |
644 | return kTfLiteError; |
645 | } |
646 | |
647 | TF_LITE_ENSURE(&context_, |
648 | tensor_index < context_.tensors_size && tensor_index >= 0); |
649 | // For most tensors we know exactly how much memory is necessary so we can |
650 | // ensure the buffer is large enough. However, we need to skip string tensors |
651 | // because their sizes change with the contents of the individual strings. |
652 | if (type != kTfLiteString) { |
653 | size_t required_bytes; |
654 | TF_LITE_ENSURE_OK(&context_, |
655 | BytesRequired(type, dims, rank, &required_bytes)); |
656 | TF_LITE_ENSURE_EQ(&context_, required_bytes, bytes); |
657 | } |
658 | |
659 | TfLiteTensor& tensor = context_.tensors[tensor_index]; |
660 | if (type == tensor.type && |
661 | EqualArrayAndTfLiteIntArray(tensor.dims, rank, dims)) { |
662 | // Fast path which does not invalidate the invokable property. |
663 | TfLiteTensorDataFree(&tensor); |
664 | tensor.data.raw = const_cast<char*>(buffer); |
665 | if (!tensor.dims) tensor.dims = ConvertArrayToTfLiteIntArray(rank, dims); |
666 | tensor.params = quantization; |
667 | tensor.allocation_type = kTfLiteMmapRo; |
668 | tensor.allocation = allocation; |
669 | } else { |
670 | state_ = kStateUninvokable; |
671 | TfLiteTensorReset(type, name, ConvertArrayToTfLiteIntArray(rank, dims), |
672 | quantization, const_cast<char*>(buffer), bytes, |
673 | kTfLiteMmapRo, allocation, &tensor); |
674 | } |
675 | return kTfLiteOk; |
676 | } |
677 | |
// Set type/name/dims/quantization for the tensor at `tensor_index`. This
// variant lets the interpreter manage the tensor's memory: non-string tensors
// are allocated in the memory arena, while string tensors are allocated
// dynamically because their size depends on their contents.
682 | TfLiteStatus Interpreter::SetTensorParametersReadWrite( |
683 | int tensor_index, TfLiteType type, const char* name, const int rank, |
684 | const int* dims, TfLiteQuantizationParams quantization) { |
685 | if (state_ == kStateInvokableAndImmutable) { |
686 | ReportError( |
687 | &context_, |
688 | "SetTensorParametersReadWrite is disallowed when graph is immutable." ); |
689 | return kTfLiteError; |
690 | } |
691 | TF_LITE_ENSURE(&context_, |
692 | tensor_index < context_.tensors_size && tensor_index >= 0); |
693 | size_t required_bytes = 0; |
694 | if (type != kTfLiteString) { |
695 | // These types will be allocated in our arena so we need to record how |
696 | // many bytes we will need based on the dimensions. String tensors are |
697 | // allocated dynamically and we can't know ahead of time how much space |
698 | // they will require. |
699 | TF_LITE_ENSURE_OK(&context_, |
700 | BytesRequired(type, dims, rank, &required_bytes)); |
701 | } |
702 | TfLiteTensorReset(type, name, ConvertArrayToTfLiteIntArray(rank, dims), |
703 | quantization, |
704 | /*buffer=*/nullptr, required_bytes, |
705 | type == kTfLiteString ? kTfLiteDynamic : kTfLiteArenaRw, |
706 | nullptr, &context_.tensors[tensor_index]); |
707 | return kTfLiteOk; |
708 | } |
709 | |
710 | TfLiteStatus Interpreter::SetExecutionPlan(const std::vector<int>& new_plan) { |
711 | for (int node_index : new_plan) { |
712 | TF_LITE_ENSURE(&context_, node_index >= 0 && node_index < nodes_size()); |
713 | } |
714 | execution_plan_ = new_plan; |
715 | return kTfLiteOk; |
716 | } |
717 | |
718 | TfLiteStatus Interpreter::ResizeTensorImpl(TfLiteTensor* tensor, |
719 | TfLiteIntArray* new_size) { |
720 | // Note that in theory we could resize kTfLiteArenaRwPersistent tensors too. |
721 | if (tensor->allocation_type == kTfLiteArenaRw || |
722 | tensor->allocation_type == kTfLiteDynamic) { |
723 | if (tensor->type != kTfLiteString) { |
724 | size_t bytesRequired; |
725 | TfLiteStatus status = BytesRequired(tensor->type, new_size->data, |
726 | new_size->size, &bytesRequired); |
727 | if (status != kTfLiteOk) { |
728 | TfLiteIntArrayFree(new_size); |
729 | return kTfLiteError; |
730 | } |
731 | |
732 | // Realloc space for kTfLiteDynamic tensors. |
733 | TfLiteTensorRealloc(bytesRequired, tensor); |
734 | tensor->bytes = bytesRequired; |
735 | } |
736 | if (tensor->dims) TfLiteIntArrayFree(tensor->dims); |
737 | tensor->dims = new_size; |
738 | |
739 | if (tensor->allocation_type != kTfLiteDynamic) { |
740 | tensor->data.raw = nullptr; |
741 | } |
742 | } else { |
743 | // kTfLiteMmapRo tensors are stored in the flatbuffer and are therefore |
744 | // of fixed size. |
745 | TfLiteIntArrayFree(new_size); |
746 | ReportError(&context_, "Attempting to resize a fixed-size tensor." ); |
747 | return kTfLiteError; |
748 | } |
749 | return kTfLiteOk; |
750 | } |
751 | |
752 | void Interpreter::UseNNAPI(bool enable) { |
753 | // TODO(aselle): This is a workaround for finding if NNAPI exists. |
754 | // We also need to make sure getLibraryHandle() is renamed to be NNAPI |
755 | // prefixed. |
756 | if (!NNAPIExists()) enable = false; |
757 | if (!enable) { |
758 | nnapi_delegate_.reset(); |
759 | } else if (!nnapi_delegate_) { |
760 | nnapi_delegate_.reset(new NNAPIDelegate); |
761 | } |
762 | } |
763 | |
764 | void Interpreter::SetNumThreads(int num_threads) { |
765 | context_.recommended_num_threads = num_threads; |
766 | |
767 | // TODO(ahentz): find a way to avoid this. It causes gemmlowp and eigen to |
768 | // be required in order to compile the framework. |
769 | gemm_support::SetNumThreads(&context_, num_threads); |
770 | eigen_support::SetNumThreads(&context_, num_threads); |
771 | } |
772 | |
773 | TfLiteStatus Interpreter::ModifyGraphWithDelegate(TfLiteDelegate* delegate, |
774 | bool allow_dynamic_tensors) { |
775 | if (!allow_dynamic_tensors) { |
776 | int last_execution_plan_index_prepared; |
777 | TF_LITE_ENSURE_OK(&context_, PrepareOpsStartingAt( |
778 | 0, &last_execution_plan_index_prepared)); |
779 | |
780 | bool has_dynamic_tensors = true; |
781 | // Dynamic tensors exist if not all nodes can be prepared. |
782 | if (last_execution_plan_index_prepared + 1 == execution_plan_.size()) { |
783 | // If all the nodes can be prepared, check if the last node has dynamic |
784 | // tensors. |
785 | int node_index = execution_plan_[last_execution_plan_index_prepared]; |
786 | TfLiteNode& node = nodes_and_registration_[node_index].first; |
787 | if (!HasDynamicTensor(context_, node.outputs)) { |
788 | has_dynamic_tensors = false; |
789 | } |
790 | } |
791 | if (has_dynamic_tensors) { |
792 | ReportError(&context_, "Attempting to resize a fixed-size tensor." ); |
793 | return kTfLiteError; |
794 | } |
795 | } |
796 | |
797 | // TODO(aselle): Consider if it is worth storing pointers to delegates. |
798 | // Setup additional context interface. |
799 | context_.GetNodeAndRegistration = GetNodeAndRegistration; |
800 | context_.ReplaceSubgraphsWithDelegateKernels = |
801 | ReplaceSubgraphsWithDelegateKernels; |
802 | context_.GetExecutionPlan = GetExecutionPlan; |
803 | |
804 | TfLiteStatus status = delegate->Prepare(&context_, delegate); |
805 | |
806 | // Remove additional context info. |
807 | SetForbiddenContextFunction(&context_.GetNodeAndRegistration); |
808 | SetForbiddenContextFunction(&context_.ReplaceSubgraphsWithDelegateKernels); |
809 | SetForbiddenContextFunction(&context_.GetExecutionPlan); |
810 | |
811 | TF_LITE_ENSURE_OK(&context_, status); |
812 | |
813 | if (!allow_dynamic_tensors) { |
814 | TF_LITE_ENSURE_OK(&context_, AllocateTensors()); |
815 | TF_LITE_ENSURE(&context_, state_ == kStateInvokable || |
816 | state_ == kStateInvokableAndImmutable); |
817 | // After using a delegate which doesn't support dynamic tensors, make the |
818 | // entire graph immutable. |
819 | state_ = kStateInvokableAndImmutable; |
820 | } |
821 | |
822 | return status; |
823 | } |
824 | |
825 | TfLiteStatus Interpreter::SetBufferHandle(int tensor_index, |
826 | TfLiteBufferHandle buffer_handle, |
827 | TfLiteDelegate* delegate) { |
  TF_LITE_ENSURE(&context_, tensor_index >= 0 && tensor_index < tensors_size());
829 | TfLiteTensor* tensor = &tensors_[tensor_index]; |
830 | |
831 | TF_LITE_ENSURE(&context_, |
832 | tensor->delegate == nullptr || tensor->delegate == delegate); |
833 | tensor->delegate = delegate; |
834 | if (tensor->buffer_handle != kTfLiteNullBufferHandle) { |
835 | TF_LITE_ENSURE(&context_, tensor->delegate->FreeBufferHandle != nullptr); |
836 | tensor->delegate->FreeBufferHandle(tensor->delegate, |
837 | &tensor->buffer_handle); |
838 | } |
839 | tensor->buffer_handle = buffer_handle; |
840 | |
841 | return kTfLiteOk; |
842 | } |
843 | |
844 | TfLiteStatus Interpreter::GetBufferHandle(int tensor_index, |
845 | TfLiteBufferHandle* buffer_handle, |
846 | TfLiteDelegate** delegate) { |
  TF_LITE_ENSURE(&context_, tensor_index >= 0 && tensor_index < tensors_size());
848 | TfLiteTensor* tensor = &tensors_[tensor_index]; |
849 | |
850 | *delegate = tensor->delegate; |
851 | *buffer_handle = tensor->buffer_handle; |
852 | |
853 | return kTfLiteOk; |
854 | } |
855 | |
856 | } // namespace tflite |
857 | |