1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15#include <fcntl.h>
16#include <stdint.h>
17#include <stdio.h>
18#include <stdlib.h>
19#include <sys/mman.h>
20#include <sys/stat.h>
21#include <sys/types.h>
22#include <unistd.h>
23
24#include "tensorflow/contrib/lite/allocation.h"
25#include "tensorflow/contrib/lite/builtin_op_data.h"
26#include "tensorflow/contrib/lite/error_reporter.h"
27#include "tensorflow/contrib/lite/model.h"
28#include "tensorflow/contrib/lite/nnapi_delegate.h"
29#include "tensorflow/contrib/lite/version.h"
30
31namespace tflite {
32
33namespace {
34// Ensure that ErrorReporter is non-null.
35ErrorReporter* ValidateErrorReporter(ErrorReporter* e) {
36 return e ? e : DefaultErrorReporter();
37}
38} // namespace
39
// Name handed to the interpreter for tensors whose schema entry has no name.
const char* kEmptyTensorName{""};
41
42TfLiteStatus ConvertTensorType(TensorType tensor_type, TfLiteType* type,
43 ErrorReporter* error_reporter) {
44 switch (tensor_type) {
45 case TensorType_FLOAT32:
46 *type = kTfLiteFloat32;
47 break;
48 case TensorType_INT32:
49 *type = kTfLiteInt32;
50 break;
51 case TensorType_UINT8:
52 *type = kTfLiteUInt8;
53 break;
54 case TensorType_INT64:
55 *type = kTfLiteInt64;
56 break;
57 case TensorType_STRING:
58 *type = kTfLiteString;
59 break;
60 default:
61 error_reporter->Report("Unimplemented data type %s (%d) in tensor\n",
62 EnumNameTensorType(tensor_type), tensor_type);
63 return kTfLiteError;
64 }
65 return kTfLiteOk;
66}
67
68// Loads a model from `filename`. If `mmap_file` is true then use mmap,
69// otherwise make a copy of the model in a buffer.
70std::unique_ptr<Allocation> GetAllocationFromFile(const char* filename,
71 bool mmap_file,
72 ErrorReporter* error_reporter,
73 bool use_nnapi) {
74 std::unique_ptr<Allocation> allocation;
75 if (mmap_file) {
76 if (use_nnapi && NNAPIExists())
77 allocation.reset(new NNAPIAllocation(filename, error_reporter));
78 else
79 allocation.reset(new MMAPAllocation(filename, error_reporter));
80 } else {
81 allocation.reset(new FileCopyAllocation(filename, error_reporter));
82 }
83 return allocation;
84}
85
86std::unique_ptr<FlatBufferModel> FlatBufferModel::BuildFromFile(
87 const char* filename, ErrorReporter* error_reporter) {
88 error_reporter = ValidateErrorReporter(error_reporter);
89
90 std::unique_ptr<FlatBufferModel> model;
91 auto allocation = GetAllocationFromFile(filename, /*mmap_file=*/true,
92 error_reporter, /*use_nnapi=*/true);
93 model.reset(new FlatBufferModel(allocation.release(), error_reporter));
94 if (!model->initialized()) model.reset();
95 return model;
96}
97
98std::unique_ptr<FlatBufferModel> FlatBufferModel::VerifyAndBuildFromFile(
99 const char* filename, TfLiteVerifier* verifier,
100 ErrorReporter* error_reporter) {
101 error_reporter = ValidateErrorReporter(error_reporter);
102
103 std::unique_ptr<FlatBufferModel> model;
104 auto allocation = GetAllocationFromFile(filename, /*mmap_file=*/true,
105 error_reporter, /*use_nnapi=*/true);
106 if (verifier &&
107 !verifier->Verify(static_cast<const char*>(allocation->base()),
108 allocation->bytes(), error_reporter)) {
109 return model;
110 }
111 model.reset(new FlatBufferModel(allocation.release(), error_reporter));
112 if (!model->initialized()) model.reset();
113 return model;
114}
115
116std::unique_ptr<FlatBufferModel> FlatBufferModel::BuildFromBuffer(
117 const char* buffer, size_t buffer_size, ErrorReporter* error_reporter) {
118 error_reporter = ValidateErrorReporter(error_reporter);
119
120 std::unique_ptr<FlatBufferModel> model;
121 Allocation* allocation =
122 new MemoryAllocation(buffer, buffer_size, error_reporter);
123 model.reset(new FlatBufferModel(allocation, error_reporter));
124 if (!model->initialized()) model.reset();
125 return model;
126}
127
128std::unique_ptr<FlatBufferModel> FlatBufferModel::BuildFromModel(
129 const tflite::Model* model_spec, ErrorReporter* error_reporter) {
130 error_reporter = ValidateErrorReporter(error_reporter);
131
132 std::unique_ptr<FlatBufferModel> model;
133 model.reset(new FlatBufferModel(model_spec, error_reporter));
134 if (!model->initialized()) model.reset();
135 return model;
136}
137
138bool FlatBufferModel::CheckModelIdentifier() const {
139 if (!tflite::ModelBufferHasIdentifier(allocation_->base())) {
140 const char* ident = flatbuffers::GetBufferIdentifier(allocation_->base());
141 error_reporter_->Report(
142 "Model provided has model identifier '%c%c%c%c', should be '%s'\n",
143 ident[0], ident[1], ident[2], ident[3], tflite::ModelIdentifier());
144 return false;
145 }
146 return true;
147}
148
149FlatBufferModel::FlatBufferModel(const Model* model,
150 ErrorReporter* error_reporter)
151 : error_reporter_(ValidateErrorReporter(error_reporter)) {
152 model_ = model;
153}
154
155FlatBufferModel::FlatBufferModel(Allocation* allocation,
156 ErrorReporter* error_reporter)
157 : error_reporter_(ValidateErrorReporter(error_reporter)) {
158 allocation_ = allocation;
159 if (!allocation_->valid() || !CheckModelIdentifier()) return;
160
161 model_ = ::tflite::GetModel(allocation_->base());
162}
163
164FlatBufferModel::~FlatBufferModel() { delete allocation_; }
165
// Creates a builder for `model`, resolving operators through `op_resolver`.
// The builder keeps only borrowed pointers from `model`: the parsed Model and
// the model's Allocation (which FlatBufferModel's destructor deletes). So
// `model` must outlive both the builder and any interpreter that uses
// read-only tensor data from that allocation. A null reporter on `model` is
// replaced by the default reporter.
InterpreterBuilder::InterpreterBuilder(const FlatBufferModel& model,
                                       const OpResolver& op_resolver)
    : model_(model.GetModel()),
      op_resolver_(op_resolver),
      error_reporter_(ValidateErrorReporter(model.error_reporter())),
      allocation_(model.allocation()) {}
172
// Creates a builder directly over a parsed `model` (caller-owned; it must
// outlive the builder). allocation_ is not set here and keeps its in-class
// default (NOTE(review): presumably nullptr; confirm in model.h). That default
// is what gets passed as the allocation for read-only tensors. A null
// `error_reporter` is replaced by the default reporter.
InterpreterBuilder::InterpreterBuilder(const ::tflite::Model* model,
                                       const OpResolver& op_resolver,
                                       ErrorReporter* error_reporter)
    : model_(model),
      op_resolver_(op_resolver),
      error_reporter_(ValidateErrorReporter(error_reporter)) {}
179
180TfLiteStatus InterpreterBuilder::BuildLocalIndexToRegistrationMapping() {
181 TfLiteStatus status = kTfLiteOk;
182 auto opcodes = model_->operator_codes();
183 for (const OperatorCode* opcode : *opcodes) {
184 TfLiteRegistration* registration = nullptr;
185 auto builtin_code = opcode->builtin_code();
186 if (builtin_code > BuiltinOperator_MAX ||
187 builtin_code < BuiltinOperator_MIN) {
188 error_reporter_->Report(
189 "Op builtin_code out or range: %d. Are you using old TFLite binary "
190 "with newer model?",
191 builtin_code);
192 status = kTfLiteError;
193 } else if (builtin_code != BuiltinOperator_CUSTOM) {
194 flatbuffer_op_index_to_registration_types_.push_back(builtin_code);
195 registration = op_resolver_.FindOp(builtin_code);
196 if (registration == nullptr) {
197 error_reporter_->Report("Didn't find op for builtin opcode '%s'\n",
198 EnumNameBuiltinOperator(builtin_code));
199 status = kTfLiteError;
200 }
201 } else if (!opcode->custom_code()) {
202 error_reporter_->Report(
203 "Operator with CUSTOM builtin_code has no custom_code.\n");
204 status = kTfLiteError;
205 } else {
206 const char* name = opcode->custom_code()->c_str();
207 registration = op_resolver_.FindOp(name);
208 flatbuffer_op_index_to_registration_types_.push_back(
209 BuiltinOperator_CUSTOM);
210 if (registration == nullptr) {
211 error_reporter_->Report("Didn't find custom op for name '%s'\n", name);
212 status = kTfLiteError;
213 }
214 }
215 flatbuffer_op_index_to_registration_.push_back(registration);
216 }
217 return status;
218}
219
220namespace {
// Copies a flatbuffer integer vector (anything with Length() and Get(i)) into
// a std::vector<int>. A null `flat_array` yields an empty vector. Optional
// fields such as a tensor's shape or an operator's inputs may be absent, and
// dereferencing them would crash.
template <class T>
std::vector<int> FlatBufferIntArrayToVector(T* flat_array) {
  std::vector<int> ret;
  if (flat_array == nullptr) return ret;
  // Index with the vector's own (unsigned) length type to avoid a
  // signed/unsigned comparison.
  using Index = decltype(flat_array->Length());
  const Index length = flat_array->Length();
  ret.reserve(length);
  for (Index i = 0; i < length; ++i) {
    ret.push_back(flat_array->Get(i));
  }
  return ret;
}
229
230// Copies the contents from the flatbuffer int vector `flatbuffer` into the
231// int array `buffer`. `flat_vector` and `buffer` represent the same
232// configuration operation for a given operation.
233void FlatBufferIntVectorToArray(int max_size_of_buffer,
234 const flatbuffers::Vector<int32_t>* flat_vector,
235 int* buffer, ErrorReporter* error_reporter) {
236 if (!flat_vector) {
237 error_reporter->Report("Input array not provided for operation.\n");
238 } else {
239 int num_dimensions = flat_vector->Length();
240 if (num_dimensions > max_size_of_buffer / sizeof(int)) {
241 error_reporter->Report(
242 "Found too many dimensions in the operation's input array.\n");
243 } else {
244 for (int i = 0; i < num_dimensions; ++i) {
245 buffer[i] = flat_vector->Get(i);
246 }
247 }
248 }
249}
250
// Allocates a zero-filled T on the C heap. T must be POD because no
// constructor runs. The C side of the Interpreter takes ownership of these
// blocks and releases them with free(), which is why C allocation is used
// instead of new.
//
// Zero-filling matters: ParseOpData returns the params struct even when the
// operator has no options table. With plain malloc every field would then be
// uninitialized garbage.
template <class T>
T* MallocPOD() {
  static_assert(std::is_pod<T>::value, "Builtin data structure must be POD.");
  return static_cast<T*>(calloc(1, sizeof(T)));
}
260
261// Parse the appropriate data out of the op.
262//
263// This handles builtin data explicitly as there are flatbuffer schemas.
264//
265// Returns memory that must be feed.
266//
267// TODO(nupurgarg): Pass in void ** and return TfLiteStatus to ensure program
268// crashes if error reporter is called.
269void* ParseOpData(const Operator* op, BuiltinOperator op_type,
270 ErrorReporter* error_reporter) {
271 auto parse_padding = [](Padding padding) {
272 switch (padding) {
273 case Padding_SAME:
274 return kTfLitePaddingSame;
275 case Padding_VALID:
276 return kTfLitePaddingValid;
277 }
278 return kTfLitePaddingUnknown;
279 };
280 auto parse_activation = [](ActivationFunctionType activation) {
281 switch (activation) {
282 case ActivationFunctionType_NONE:
283 return kTfLiteActNone;
284 case ActivationFunctionType_RELU:
285 return kTfLiteActRelu;
286 case ActivationFunctionType_RELU_N1_TO_1:
287 return kTfLiteActRelu1;
288 case ActivationFunctionType_RELU6:
289 return kTfLiteActRelu6;
290 case ActivationFunctionType_TANH:
291 return kTfLiteActTanh;
292 case ActivationFunctionType_SIGN_BIT:
293 return kTfLiteActSignBit;
294 }
295 return kTfLiteActNone;
296 };
297 auto parseLSHProjectionType = [](LSHProjectionType type) {
298 switch (type) {
299 case LSHProjectionType_SPARSE:
300 return kTfLiteLshProjectionSparse;
301 case LSHProjectionType_DENSE:
302 return kTfLiteLshProjectionDense;
303 default:
304 return kTfLiteLshProjectionUnknown;
305 }
306 };
307 auto parseCombinerType = [](CombinerType type) {
308 switch (type) {
309 case CombinerType_MEAN:
310 return kTfLiteCombinerTypeMean;
311 case CombinerType_SQRTN:
312 return kTfLiteCombinerTypeSqrtn;
313 case CombinerType_SUM:
314 default:
315 return kTfLiteCombinerTypeSum;
316 }
317 };
318
319 void* builtin_data = nullptr;
320 switch (op_type) {
321 case BuiltinOperator_CALL:
322 // TODO(aselle): Implement call in BuiltinOptions, but nullptrs are
323 // ok for now, since there is no call implementation either.
324 break;
325 case BuiltinOperator_CUSTOM:
326 break;
327 case BuiltinOperator_CONV_2D: {
328 TfLiteConvParams* params = MallocPOD<TfLiteConvParams>();
329 if (auto* conv_params = op->builtin_options_as_Conv2DOptions()) {
330 params->padding = parse_padding(conv_params->padding());
331 params->stride_width = conv_params->stride_w();
332 params->stride_height = conv_params->stride_h();
333 params->activation =
334 parse_activation(conv_params->fused_activation_function());
335 }
336 builtin_data = reinterpret_cast<void*>(params);
337 break;
338 }
339 case BuiltinOperator_TANH:
340 case BuiltinOperator_LOGISTIC:
341 case BuiltinOperator_RELU:
342 case BuiltinOperator_RELU_N1_TO_1:
343 case BuiltinOperator_RELU6:
344 case BuiltinOperator_CONCAT_EMBEDDINGS:
345 case BuiltinOperator_EXP:
346 case BuiltinOperator_TOPK_V2:
347 case BuiltinOperator_LOG_SOFTMAX:
348 case BuiltinOperator_DEQUANTIZE:
349 case BuiltinOperator_PRELU:
350 break;
351 case BuiltinOperator_CAST: {
352 TfLiteCastParams* params = MallocPOD<TfLiteCastParams>();
353 if (auto* schema_params = op->builtin_options_as_CastOptions()) {
354 auto in_status =
355 ConvertTensorType(schema_params->in_data_type(),
356 &params->in_data_type, error_reporter);
357 auto out_status =
358 ConvertTensorType(schema_params->out_data_type(),
359 &params->out_data_type, error_reporter);
360 if (in_status != kTfLiteOk || out_status != kTfLiteOk) {
361 break;
362 }
363 }
364 builtin_data = reinterpret_cast<void*>(params);
365 break;
366 }
367 case BuiltinOperator_LSH_PROJECTION: {
368 TfLiteLSHProjectionParams* params =
369 MallocPOD<TfLiteLSHProjectionParams>();
370 if (auto* lshParams = op->builtin_options_as_LSHProjectionOptions()) {
371 params->type = parseLSHProjectionType(lshParams->type());
372 }
373 builtin_data = reinterpret_cast<void*>(params);
374 break;
375 }
376 case BuiltinOperator_AVERAGE_POOL_2D:
377 case BuiltinOperator_MAX_POOL_2D:
378 case BuiltinOperator_L2_POOL_2D: {
379 TfLitePoolParams* params = MallocPOD<TfLitePoolParams>();
380 if (auto* pool_params = op->builtin_options_as_Pool2DOptions()) {
381 params->padding = parse_padding(pool_params->padding());
382 params->stride_width = pool_params->stride_w();
383 params->stride_height = pool_params->stride_h();
384 params->filter_width = pool_params->filter_width();
385 params->filter_height = pool_params->filter_height();
386 params->activation =
387 parse_activation(pool_params->fused_activation_function());
388 }
389 builtin_data = reinterpret_cast<void*>(params);
390 break;
391 }
392 case BuiltinOperator_DEPTHWISE_CONV_2D: {
393 TfLiteDepthwiseConvParams* params =
394 MallocPOD<TfLiteDepthwiseConvParams>();
395 if (auto* conv_params = op->builtin_options_as_DepthwiseConv2DOptions()) {
396 params->padding = parse_padding(conv_params->padding());
397 params->stride_width = conv_params->stride_w();
398 params->stride_height = conv_params->stride_h();
399 params->depth_multiplier = conv_params->depth_multiplier();
400 params->activation =
401 parse_activation(conv_params->fused_activation_function());
402 }
403 builtin_data = reinterpret_cast<void*>(params);
404 break;
405 }
406 case BuiltinOperator_SVDF: {
407 TfLiteSVDFParams* params = MallocPOD<TfLiteSVDFParams>();
408 if (auto* svdf_params = op->builtin_options_as_SVDFOptions()) {
409 params->rank = svdf_params->rank();
410 params->activation =
411 parse_activation(svdf_params->fused_activation_function());
412 }
413 builtin_data = reinterpret_cast<void*>(params);
414 break;
415 }
416 case BuiltinOperator_BIDIRECTIONAL_SEQUENCE_RNN:
417 case BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_RNN: {
418 TfLiteSequenceRNNParams* params = MallocPOD<TfLiteSequenceRNNParams>();
419 if (auto* sequence_rnn_params =
420 op->builtin_options_as_SequenceRNNOptions()) {
421 params->activation =
422 parse_activation(sequence_rnn_params->fused_activation_function());
423 params->time_major = sequence_rnn_params->time_major();
424 }
425 builtin_data = reinterpret_cast<void*>(params);
426 break;
427 }
428 case BuiltinOperator_RNN: {
429 TfLiteRNNParams* params = MallocPOD<TfLiteRNNParams>();
430 if (auto* rnn_params = op->builtin_options_as_RNNOptions()) {
431 params->activation =
432 parse_activation(rnn_params->fused_activation_function());
433 }
434 builtin_data = reinterpret_cast<void*>(params);
435 break;
436 }
437 case BuiltinOperator_EMBEDDING_LOOKUP:
438 // no-op.
439 break;
440 case BuiltinOperator_EMBEDDING_LOOKUP_SPARSE: {
441 TfLiteEmbeddingLookupSparseParams* params =
442 MallocPOD<TfLiteEmbeddingLookupSparseParams>();
443 if (auto* embedding_params =
444 op->builtin_options_as_EmbeddingLookupSparseOptions()) {
445 params->combiner = parseCombinerType(embedding_params->combiner());
446 }
447 builtin_data = reinterpret_cast<void*>(params);
448 break;
449 }
450 case BuiltinOperator_FULLY_CONNECTED: {
451 TfLiteFullyConnectedParams* params =
452 MallocPOD<TfLiteFullyConnectedParams>();
453 if (auto* fully_connected_params =
454 op->builtin_options_as_FullyConnectedOptions()) {
455 params->activation = parse_activation(
456 fully_connected_params->fused_activation_function());
457 }
458 builtin_data = reinterpret_cast<void*>(params);
459 break;
460 }
461 case BuiltinOperator_HASHTABLE_LOOKUP:
462 // no-op.
463 break;
464 case BuiltinOperator_SOFTMAX: {
465 TfLiteSoftmaxParams* params = MallocPOD<TfLiteSoftmaxParams>();
466 if (auto* softmax_params = op->builtin_options_as_SoftmaxOptions()) {
467 params->beta = softmax_params->beta();
468 }
469 builtin_data = reinterpret_cast<void*>(params);
470 break;
471 }
472 case BuiltinOperator_CONCATENATION: {
473 TfLiteConcatenationParams* params =
474 MallocPOD<TfLiteConcatenationParams>();
475 if (auto* concatenation_params =
476 op->builtin_options_as_ConcatenationOptions()) {
477 params->activation =
478 parse_activation(concatenation_params->fused_activation_function());
479 params->axis = concatenation_params->axis();
480 }
481 builtin_data = reinterpret_cast<void*>(params);
482 break;
483 }
484 case BuiltinOperator_MUL: {
485 auto* params = MallocPOD<TfLiteMulParams>();
486 if (auto* schema_params = op->builtin_options_as_MulOptions()) {
487 params->activation =
488 parse_activation(schema_params->fused_activation_function());
489 }
490 builtin_data = reinterpret_cast<void*>(params);
491 break;
492 }
493 case BuiltinOperator_ADD: {
494 auto* params = MallocPOD<TfLiteAddParams>();
495 if (auto* schema_params = op->builtin_options_as_AddOptions()) {
496 params->activation =
497 parse_activation(schema_params->fused_activation_function());
498 }
499 builtin_data = reinterpret_cast<void*>(params);
500 break;
501 }
502 case BuiltinOperator_DIV: {
503 auto* params = MallocPOD<TfLiteDivParams>();
504 if (auto* schema_params = op->builtin_options_as_DivOptions()) {
505 params->activation =
506 parse_activation(schema_params->fused_activation_function());
507 }
508 builtin_data = reinterpret_cast<void*>(params);
509 break;
510 }
511 case BuiltinOperator_SUB: {
512 auto* params = MallocPOD<TfLiteSubParams>();
513 if (auto* schema_params = op->builtin_options_as_SubOptions()) {
514 params->activation =
515 parse_activation(schema_params->fused_activation_function());
516 }
517 builtin_data = reinterpret_cast<void*>(params);
518 break;
519 }
520 case BuiltinOperator_L2_NORMALIZATION: {
521 auto* params = MallocPOD<TfLiteL2NormParams>();
522 if (auto* schema_params = op->builtin_options_as_L2NormOptions()) {
523 params->activation =
524 parse_activation(schema_params->fused_activation_function());
525 }
526 builtin_data = reinterpret_cast<void*>(params);
527 break;
528 }
529 case BuiltinOperator_LOCAL_RESPONSE_NORMALIZATION: {
530 auto* params = MallocPOD<TfLiteLocalResponseNormParams>();
531 if (auto* schema_params =
532 op->builtin_options_as_LocalResponseNormalizationOptions()) {
533 params->radius = schema_params->radius();
534 params->bias = schema_params->bias();
535 params->alpha = schema_params->alpha();
536 params->beta = schema_params->beta();
537 }
538 builtin_data = reinterpret_cast<void*>(params);
539 break;
540 }
541 case BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM:
542 case BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM:
543 case BuiltinOperator_LSTM: {
544 TfLiteLSTMParams* params = MallocPOD<TfLiteLSTMParams>();
545 if (auto* lstm_params = op->builtin_options_as_LSTMOptions()) {
546 params->activation =
547 parse_activation(lstm_params->fused_activation_function());
548 params->cell_clip = lstm_params->cell_clip();
549 params->proj_clip = lstm_params->proj_clip();
550 }
551 builtin_data = reinterpret_cast<void*>(params);
552 break;
553 }
554 case BuiltinOperator_RESIZE_BILINEAR: {
555 auto* params = MallocPOD<TfLiteResizeBilinearParams>();
556 if (auto* schema_params =
557 op->builtin_options_as_ResizeBilinearOptions()) {
558 params->align_corners = schema_params->align_corners();
559 }
560 builtin_data = reinterpret_cast<void*>(params);
561 break;
562 }
563 case BuiltinOperator_PAD: {
564 break;
565 }
566 case BuiltinOperator_RESHAPE: {
567 auto* params = MallocPOD<TfLiteReshapeParams>();
568 if (auto* schema_params = op->builtin_options_as_ReshapeOptions()) {
569 auto* new_shape = schema_params->new_shape();
570 FlatBufferIntVectorToArray(sizeof(params->shape), new_shape,
571 params->shape, error_reporter);
572 params->num_dimensions = new_shape->Length();
573 }
574 builtin_data = reinterpret_cast<void*>(params);
575 break;
576 }
577 case BuiltinOperator_SKIP_GRAM: {
578 TfLiteSkipGramParams* params = MallocPOD<TfLiteSkipGramParams>();
579 if (auto* skip_gram_params = op->builtin_options_as_SkipGramOptions()) {
580 params->ngram_size = skip_gram_params->ngram_size();
581 params->max_skip_size = skip_gram_params->max_skip_size();
582 params->include_all_ngrams = skip_gram_params->include_all_ngrams();
583 }
584 builtin_data = reinterpret_cast<void*>(params);
585 break;
586 }
587 case BuiltinOperator_SPACE_TO_DEPTH: {
588 auto* params = MallocPOD<TfLiteSpaceToDepthParams>();
589 if (auto* schema_params = op->builtin_options_as_SpaceToDepthOptions()) {
590 params->block_size = schema_params->block_size();
591 }
592 builtin_data = reinterpret_cast<void*>(params);
593 break;
594 }
595 case BuiltinOperator_GATHER: {
596 TfLiteGatherParams* params = MallocPOD<TfLiteGatherParams>();
597 params->axis = 0;
598 if (auto* gather_params = op->builtin_options_as_GatherOptions()) {
599 params->axis = gather_params->axis();
600 }
601
602 builtin_data = reinterpret_cast<void*>(params);
603 break;
604 }
605 case BuiltinOperator_SPACE_TO_BATCH_ND: {
606 break;
607 }
608 case BuiltinOperator_BATCH_TO_SPACE_ND: {
609 break;
610 }
611 case BuiltinOperator_TRANSPOSE: {
612 break;
613 }
614 case BuiltinOperator_MEAN: {
615 auto* params = MallocPOD<TfLiteMeanParams>();
616 if (auto* schema_params = op->builtin_options_as_MeanOptions()) {
617 params->keep_dims = schema_params->keep_dims();
618 }
619 builtin_data = reinterpret_cast<void*>(params);
620 break;
621 }
622 case BuiltinOperator_SPLIT: {
623 auto* params = MallocPOD<TfLiteSplitParams>();
624 if (auto* schema_params = op->builtin_options_as_SplitOptions()) {
625 params->num_splits = schema_params->num_splits();
626 }
627 builtin_data = reinterpret_cast<void*>(params);
628 break;
629 }
630 case BuiltinOperator_SQUEEZE: {
631 auto* params = MallocPOD<TfLiteSqueezeParams>();
632 if (auto* schema_params = op->builtin_options_as_SqueezeOptions()) {
633 const auto& squeeze_dims = schema_params->squeeze_dims();
634 FlatBufferIntVectorToArray(sizeof(params->squeeze_dims), squeeze_dims,
635 params->squeeze_dims, error_reporter);
636 params->num_squeeze_dims = squeeze_dims->Length();
637 }
638 builtin_data = reinterpret_cast<void*>(params);
639 break;
640 }
641 case BuiltinOperator_STRIDED_SLICE: {
642 auto* params = MallocPOD<TfLiteStridedSliceParams>();
643 if (auto* schema_params = op->builtin_options_as_StridedSliceOptions()) {
644 params->begin_mask = schema_params->begin_mask();
645 params->end_mask = schema_params->end_mask();
646 params->ellipsis_mask = schema_params->ellipsis_mask();
647 params->new_axis_mask = schema_params->new_axis_mask();
648 params->shrink_axis_mask = schema_params->shrink_axis_mask();
649 }
650 builtin_data = reinterpret_cast<void*>(params);
651 break;
652 }
653 case BuiltinOperator_MAXIMUM: {
654 break;
655 }
656 case BuiltinOperator_DELEGATE: {
657 // TODO(ycling): Revisit when supporting saving delegated models.
658 error_reporter->Report("DELEGATE op shouldn't exist in model.");
659 break;
660 }
661 }
662 return builtin_data;
663}
664
665} // namespace
666
667TfLiteStatus InterpreterBuilder::ParseNodes(
668 const flatbuffers::Vector<flatbuffers::Offset<Operator>>* operators,
669 Interpreter* interpreter) {
670 TfLiteStatus status = kTfLiteOk;
671 for (int i = 0; i < operators->Length(); ++i) {
672 const auto* op = operators->Get(i);
673 int index = op->opcode_index();
674 if (index < 0 || index >= flatbuffer_op_index_to_registration_.size()) {
675 error_reporter_->Report("Missing registration for opcode_index %d\n",
676 index);
677 status = kTfLiteError;
678 continue;
679 }
680 const TfLiteRegistration* reg =
681 flatbuffer_op_index_to_registration_[op->opcode_index()];
682 if (reg == nullptr) {
683 error_reporter_->Report("Skipping op for opcode_index %d\n", index);
684 status = kTfLiteError;
685 continue;
686 }
687
688 auto op_type =
689 flatbuffer_op_index_to_registration_types_[op->opcode_index()];
690 if (op_type != BuiltinOperator_CUSTOM && op->custom_options()) {
691 error_reporter_->Report(
692 "Found builtin operator %s with custom options.\n",
693 EnumNameBuiltinOperator(op_type));
694 }
695 if (op->custom_options()) {
696 interpreter->AddNodeWithParameters(
697 FlatBufferIntArrayToVector(op->inputs()),
698 FlatBufferIntArrayToVector(op->outputs()),
699 reinterpret_cast<const char*>(op->custom_options()->data()),
700 op->custom_options()->size(), nullptr, reg);
701 } else {
702 interpreter->AddNodeWithParameters(
703 FlatBufferIntArrayToVector(op->inputs()),
704 FlatBufferIntArrayToVector(op->outputs()), nullptr, 0,
705 ParseOpData(op, op_type, error_reporter_), reg);
706 }
707 }
708
709 return status;
710}
711
712TfLiteStatus InterpreterBuilder::ParseTensors(
713 const flatbuffers::Vector<flatbuffers::Offset<Buffer>>* buffers,
714 const flatbuffers::Vector<flatbuffers::Offset<Tensor>>* tensors,
715 Interpreter* interpreter) {
716 TfLiteStatus status = kTfLiteOk;
717
718 // A little helper to get the names of inputs and outputs. Note that they
719 // must outlive the interpreter.
720 auto get_name = [](const tflite::Tensor* t) -> const char* {
721 auto name = t->name();
722 if (name) return name->c_str();
723 return kEmptyTensorName;
724 };
725
726 for (int i = 0; i < tensors->Length(); ++i) {
727 const auto* tensor = tensors->Get(i);
728 std::vector<int> dims = FlatBufferIntArrayToVector(tensor->shape());
729
730 TfLiteQuantizationParams quantization;
731 quantization.scale = 0;
732 quantization.zero_point = 0;
733 auto* q_params = tensor->quantization();
734 if (q_params) {
735 // Note that the schema could hold per-channel quantization parameters
736 // but we really only support one value for the whole tensor.
737 // TODO(aselle): This breaks as well if these are nullptr's.
738 // TODO(aselle): This assumes non per-channel quantization.
739
740 if (q_params->scale()) {
741 if (q_params->scale()->size() != 1) {
742 error_reporter_->Report(
743 "QuantizationParam has %d scale values (only 1 is supported).",
744 q_params->scale()->size());
745 return kTfLiteError;
746 }
747 quantization.scale = q_params->scale()->Get(0);
748 }
749
750 if (q_params->zero_point()) {
751 if (q_params->zero_point()->size() != 1) {
752 error_reporter_->Report(
753 "QuantizationParam has %d zero_point values"
754 " (only 1 is supported).",
755 q_params->zero_point()->size());
756 return kTfLiteError;
757 }
758 quantization.zero_point = q_params->zero_point()->Get(0);
759 }
760 }
761
762 TfLiteType type;
763 if (ConvertTensorType(tensor->type(), &type, error_reporter_) !=
764 kTfLiteOk) {
765 status = kTfLiteError;
766 continue;
767 }
768 auto get_readonly_data = [&](const char** buffer_data,
769 size_t* buffer_size) {
770 // TODO(aselle): Check what happens if we have an unspecified size
771 // constant.
772 *buffer_data = nullptr;
773 if (tensor->buffer() == 0) return kTfLiteOk;
774 if (tensor->buffer() >= buffers->size()) {
775 error_reporter_->Report(
776 "Tensor %d specifies out of range buffer %d (only %d buffers).\n",
777 i, tensor->buffer(), buffers->size());
778 return kTfLiteError;
779 }
780 if (auto* buffer = (*buffers)[tensor->buffer()]) {
781 if (auto* array = buffer->data()) {
782 if (size_t size = array->size()) {
783 *buffer_size = size;
784 *buffer_data = reinterpret_cast<const char*>(array->data());
785 return kTfLiteOk;
786 }
787 }
788 }
789 return kTfLiteOk;
790 };
791 size_t buffer_size = 0;
792 const char* buffer_ptr;
793 TF_LITE_ENSURE_STATUS(get_readonly_data(&buffer_ptr, &buffer_size));
794
795 if (buffer_ptr) {
796 if (interpreter->SetTensorParametersReadOnly(
797 i, type, get_name(tensor), dims, quantization, buffer_ptr,
798 buffer_size, allocation_) != kTfLiteOk) {
799 error_reporter_->Report("Tensor %d is invalidly specified in schema.\n",
800 i);
801 status = kTfLiteError;
802 }
803 } else {
804 if (interpreter->SetTensorParametersReadWrite(
805 i, type, get_name(tensor), dims, quantization) != kTfLiteOk) {
806 error_reporter_->Report("Tensor %d is invalidly specified in schema.\n",
807 i);
808 status = kTfLiteError;
809 }
810 }
811 }
812
813 return status;
814}
815
816TfLiteStatus InterpreterBuilder::operator()(
817 std::unique_ptr<Interpreter>* interpreter) {
818 return operator()(interpreter, /*num_threads=*/-1);
819}
820
821TfLiteStatus InterpreterBuilder::operator()(
822 std::unique_ptr<Interpreter>* interpreter, int num_threads) {
823 if (!interpreter) {
824 error_reporter_->Report(
825 "Null output pointer passed to InterpreterBuilder.");
826 return kTfLiteError;
827 }
828
829 // Safe exit by deleting partially created interpreter, to reduce verbosity
830 // on error conditions. Use by return cleanup_on_error();
831 auto cleanup_and_error = [&interpreter]() {
832 interpreter->reset();
833 return kTfLiteError;
834 };
835
836 if (!model_) {
837 error_reporter_->Report("Null pointer passed in as model.");
838 return cleanup_and_error();
839 }
840
841 if (model_->version() != TFLITE_SCHEMA_VERSION) {
842 error_reporter_->Report(
843 "Model provided is schema version %d not equal "
844 "to supported version %d.\n",
845 model_->version(), TFLITE_SCHEMA_VERSION);
846 return cleanup_and_error();
847 }
848
849 if (BuildLocalIndexToRegistrationMapping() != kTfLiteOk) {
850 error_reporter_->Report("Registration failed.\n");
851 return cleanup_and_error();
852 }
853
854 // Flatbuffer model schemas define a list of opcodes independent of the graph.
855 // We first map those to registrations. This reduces string lookups for custom
856 // ops since we only do it once per custom op rather than once per custom op
857 // invocation in the model graph.
858 // Construct interpreter with correct number of tensors and operators.
859 auto* subgraphs = model_->subgraphs();
860 auto* buffers = model_->buffers();
861 if (subgraphs->size() != 1) {
862 error_reporter_->Report("Only 1 subgraph is currently supported.\n");
863 return cleanup_and_error();
864 }
865 const tflite::SubGraph* subgraph = (*subgraphs)[0];
866 auto operators = subgraph->operators();
867 auto tensors = subgraph->tensors();
868 if (!operators || !tensors || !buffers) {
869 error_reporter_->Report(
870 "Did not get operators, tensors, or buffers in input flat buffer.\n");
871 return cleanup_and_error();
872 }
873 interpreter->reset(new Interpreter(error_reporter_));
874 if ((**interpreter).AddTensors(tensors->Length()) != kTfLiteOk) {
875 return cleanup_and_error();
876 }
877 // Set num threads
878 (**interpreter).SetNumThreads(num_threads);
879 // Parse inputs/outputs
880 (**interpreter).SetInputs(FlatBufferIntArrayToVector(subgraph->inputs()));
881 (**interpreter).SetOutputs(FlatBufferIntArrayToVector(subgraph->outputs()));
882
883 // Finally setup nodes and tensors
884 if (ParseNodes(operators, interpreter->get()) != kTfLiteOk)
885 return cleanup_and_error();
886 if (ParseTensors(buffers, tensors, interpreter->get()) != kTfLiteOk)
887 return cleanup_and_error();
888
889 return kTfLiteOk;
890}
891
892} // namespace tflite
893