1 | /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | #include <fcntl.h> |
16 | #include <stdint.h> |
17 | #include <stdio.h> |
18 | #include <stdlib.h> |
19 | #include <sys/mman.h> |
20 | #include <sys/stat.h> |
21 | #include <sys/types.h> |
22 | #include <unistd.h> |
23 | |
24 | #include "tensorflow/contrib/lite/allocation.h" |
25 | #include "tensorflow/contrib/lite/builtin_op_data.h" |
26 | #include "tensorflow/contrib/lite/error_reporter.h" |
27 | #include "tensorflow/contrib/lite/model.h" |
28 | #include "tensorflow/contrib/lite/nnapi_delegate.h" |
29 | #include "tensorflow/contrib/lite/version.h" |
30 | |
31 | namespace tflite { |
32 | |
33 | namespace { |
34 | // Ensure that ErrorReporter is non-null. |
35 | ErrorReporter* ValidateErrorReporter(ErrorReporter* e) { |
36 | return e ? e : DefaultErrorReporter(); |
37 | } |
38 | } // namespace |
39 | |
const char* kEmptyTensorName = "";
41 | |
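// Converts a schema TensorType enum into the equivalent TfLiteType, reporting
// an error and returning kTfLiteError for types that are not supported here.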
42 | TfLiteStatus ConvertTensorType(TensorType tensor_type, TfLiteType* type, |
43 | ErrorReporter* error_reporter) { |
44 | switch (tensor_type) { |
45 | case TensorType_FLOAT32: |
46 | *type = kTfLiteFloat32; |
47 | break; |
48 | case TensorType_INT32: |
49 | *type = kTfLiteInt32; |
50 | break; |
51 | case TensorType_UINT8: |
52 | *type = kTfLiteUInt8; |
53 | break; |
54 | case TensorType_INT64: |
55 | *type = kTfLiteInt64; |
56 | break; |
57 | case TensorType_STRING: |
58 | *type = kTfLiteString; |
59 | break; |
60 | default: |
      error_reporter->Report("Unimplemented data type %s (%d) in tensor\n",
62 | EnumNameTensorType(tensor_type), tensor_type); |
63 | return kTfLiteError; |
64 | } |
65 | return kTfLiteOk; |
66 | } |
67 | |
// Loads a model from `filename`. If `mmap_file` is true, the file is
// memory-mapped; otherwise its contents are copied into a buffer.
70 | std::unique_ptr<Allocation> GetAllocationFromFile(const char* filename, |
71 | bool mmap_file, |
72 | ErrorReporter* error_reporter, |
73 | bool use_nnapi) { |
74 | std::unique_ptr<Allocation> allocation; |
75 | if (mmap_file) { |
76 | if (use_nnapi && NNAPIExists()) |
77 | allocation.reset(new NNAPIAllocation(filename, error_reporter)); |
78 | else |
79 | allocation.reset(new MMAPAllocation(filename, error_reporter)); |
80 | } else { |
81 | allocation.reset(new FileCopyAllocation(filename, error_reporter)); |
82 | } |
83 | return allocation; |
84 | } |
85 | |
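// Loads a model from `filename`, memory-mapping the file, and returns nullptr
// if the file cannot be read or does not carry a valid TFLite identifier.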
86 | std::unique_ptr<FlatBufferModel> FlatBufferModel::BuildFromFile( |
87 | const char* filename, ErrorReporter* error_reporter) { |
88 | error_reporter = ValidateErrorReporter(error_reporter); |
89 | |
90 | std::unique_ptr<FlatBufferModel> model; |
91 | auto allocation = GetAllocationFromFile(filename, /*mmap_file=*/true, |
92 | error_reporter, /*use_nnapi=*/true); |
93 | model.reset(new FlatBufferModel(allocation.release(), error_reporter)); |
94 | if (!model->initialized()) model.reset(); |
95 | return model; |
96 | } |
97 | |
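// Like BuildFromFile, but first runs `verifier` (if provided) over the mapped
// buffer and returns nullptr when verification fails.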
98 | std::unique_ptr<FlatBufferModel> FlatBufferModel::VerifyAndBuildFromFile( |
99 | const char* filename, TfLiteVerifier* verifier, |
100 | ErrorReporter* error_reporter) { |
101 | error_reporter = ValidateErrorReporter(error_reporter); |
102 | |
103 | std::unique_ptr<FlatBufferModel> model; |
104 | auto allocation = GetAllocationFromFile(filename, /*mmap_file=*/true, |
105 | error_reporter, /*use_nnapi=*/true); |
106 | if (verifier && |
107 | !verifier->Verify(static_cast<const char*>(allocation->base()), |
108 | allocation->bytes(), error_reporter)) { |
109 | return model; |
110 | } |
111 | model.reset(new FlatBufferModel(allocation.release(), error_reporter)); |
112 | if (!model->initialized()) model.reset(); |
113 | return model; |
114 | } |
115 | |
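// Builds a model directly from a caller-owned `buffer`. The buffer is wrapped
// by a MemoryAllocation rather than copied, so it should remain valid for as
// long as the returned model is in use.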
116 | std::unique_ptr<FlatBufferModel> FlatBufferModel::BuildFromBuffer( |
117 | const char* buffer, size_t buffer_size, ErrorReporter* error_reporter) { |
118 | error_reporter = ValidateErrorReporter(error_reporter); |
119 | |
120 | std::unique_ptr<FlatBufferModel> model; |
121 | Allocation* allocation = |
122 | new MemoryAllocation(buffer, buffer_size, error_reporter); |
123 | model.reset(new FlatBufferModel(allocation, error_reporter)); |
124 | if (!model->initialized()) model.reset(); |
125 | return model; |
126 | } |
127 | |
128 | std::unique_ptr<FlatBufferModel> FlatBufferModel::BuildFromModel( |
129 | const tflite::Model* model_spec, ErrorReporter* error_reporter) { |
130 | error_reporter = ValidateErrorReporter(error_reporter); |
131 | |
132 | std::unique_ptr<FlatBufferModel> model; |
133 | model.reset(new FlatBufferModel(model_spec, error_reporter)); |
134 | if (!model->initialized()) model.reset(); |
135 | return model; |
136 | } |
137 | |
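// Returns true if the mapped buffer starts with the expected TFLite file
// identifier; otherwise reports the identifier that was found and returns
// false.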
138 | bool FlatBufferModel::CheckModelIdentifier() const { |
139 | if (!tflite::ModelBufferHasIdentifier(allocation_->base())) { |
140 | const char* ident = flatbuffers::GetBufferIdentifier(allocation_->base()); |
141 | error_reporter_->Report( |
        "Model provided has model identifier '%c%c%c%c', should be '%s'\n",
143 | ident[0], ident[1], ident[2], ident[3], tflite::ModelIdentifier()); |
144 | return false; |
145 | } |
146 | return true; |
147 | } |
148 | |
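// Wraps an already-parsed Model owned by the caller; no allocation is taken
// over, so `model` must outlive this FlatBufferModel.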
149 | FlatBufferModel::FlatBufferModel(const Model* model, |
150 | ErrorReporter* error_reporter) |
151 | : error_reporter_(ValidateErrorReporter(error_reporter)) { |
152 | model_ = model; |
153 | } |
154 | |
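// Takes ownership of `allocation` (released in the destructor). model_ stays
// null if the allocation is invalid or the identifier check fails.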
155 | FlatBufferModel::FlatBufferModel(Allocation* allocation, |
156 | ErrorReporter* error_reporter) |
157 | : error_reporter_(ValidateErrorReporter(error_reporter)) { |
158 | allocation_ = allocation; |
159 | if (!allocation_->valid() || !CheckModelIdentifier()) return; |
160 | |
161 | model_ = ::tflite::GetModel(allocation_->base()); |
162 | } |
163 | |
164 | FlatBufferModel::~FlatBufferModel() { delete allocation_; } |
165 | |
166 | InterpreterBuilder::InterpreterBuilder(const FlatBufferModel& model, |
167 | const OpResolver& op_resolver) |
168 | : model_(model.GetModel()), |
169 | op_resolver_(op_resolver), |
170 | error_reporter_(ValidateErrorReporter(model.error_reporter())), |
171 | allocation_(model.allocation()) {} |
172 | |
173 | InterpreterBuilder::InterpreterBuilder(const ::tflite::Model* model, |
174 | const OpResolver& op_resolver, |
175 | ErrorReporter* error_reporter) |
176 | : model_(model), |
177 | op_resolver_(op_resolver), |
178 | error_reporter_(ValidateErrorReporter(error_reporter)) {} |
179 | |
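// Resolves every opcode listed in the model's operator_codes table to a
// TfLiteRegistration via the OpResolver. Unresolvable opcodes are recorded as
// nullptr and the function returns kTfLiteError.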
180 | TfLiteStatus InterpreterBuilder::BuildLocalIndexToRegistrationMapping() { |
181 | TfLiteStatus status = kTfLiteOk; |
182 | auto opcodes = model_->operator_codes(); |
183 | for (const OperatorCode* opcode : *opcodes) { |
184 | TfLiteRegistration* registration = nullptr; |
185 | auto builtin_code = opcode->builtin_code(); |
186 | if (builtin_code > BuiltinOperator_MAX || |
187 | builtin_code < BuiltinOperator_MIN) { |
      error_reporter_->Report(
          "Op builtin_code out of range: %d. Are you using an old TFLite "
          "binary with a newer model?",
          builtin_code);
192 | status = kTfLiteError; |
193 | } else if (builtin_code != BuiltinOperator_CUSTOM) { |
194 | flatbuffer_op_index_to_registration_types_.push_back(builtin_code); |
195 | registration = op_resolver_.FindOp(builtin_code); |
196 | if (registration == nullptr) { |
        error_reporter_->Report("Didn't find op for builtin opcode '%s'\n",
198 | EnumNameBuiltinOperator(builtin_code)); |
199 | status = kTfLiteError; |
200 | } |
201 | } else if (!opcode->custom_code()) { |
202 | error_reporter_->Report( |
          "Operator with CUSTOM builtin_code has no custom_code.\n");
204 | status = kTfLiteError; |
205 | } else { |
206 | const char* name = opcode->custom_code()->c_str(); |
207 | registration = op_resolver_.FindOp(name); |
208 | flatbuffer_op_index_to_registration_types_.push_back( |
209 | BuiltinOperator_CUSTOM); |
210 | if (registration == nullptr) { |
        error_reporter_->Report("Didn't find custom op for name '%s'\n", name);
212 | status = kTfLiteError; |
213 | } |
214 | } |
215 | flatbuffer_op_index_to_registration_.push_back(registration); |
216 | } |
217 | return status; |
218 | } |
219 | |
220 | namespace { |
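// Copies a flatbuffer int array/vector into a std::vector<int>.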
221 | template <class T> |
222 | std::vector<int> FlatBufferIntArrayToVector(T* flat_array) { |
223 | std::vector<int> ret(flat_array->Length()); |
224 | for (int i = 0; i < flat_array->Length(); i++) { |
225 | ret[i] = flat_array->Get(i); |
226 | } |
227 | return ret; |
228 | } |
229 | |
// Copies the contents of the flatbuffer int vector `flat_vector` into the
// int array `buffer`. `flat_vector` and `buffer` represent the same
// configuration data for a given operation.
233 | void FlatBufferIntVectorToArray(int max_size_of_buffer, |
234 | const flatbuffers::Vector<int32_t>* flat_vector, |
235 | int* buffer, ErrorReporter* error_reporter) { |
236 | if (!flat_vector) { |
    error_reporter->Report("Input array not provided for operation.\n");
238 | } else { |
239 | int num_dimensions = flat_vector->Length(); |
240 | if (num_dimensions > max_size_of_buffer / sizeof(int)) { |
241 | error_reporter->Report( |
          "Found too many dimensions in the operation's input array.\n");
243 | } else { |
244 | for (int i = 0; i < num_dimensions; ++i) { |
245 | buffer[i] = flat_vector->Get(i); |
246 | } |
247 | } |
248 | } |
249 | } |
250 | |
// Allocates a structure using C malloc, ensuring that the structure is a POD
// type that doesn't require constructors to run. We do this because the
// Interpreter's C extension part takes ownership of the memory and wants to
// use malloc() and free().
255 | template <class T> |
256 | T* MallocPOD() { |
  static_assert(std::is_pod<T>::value, "Builtin data structure must be POD.");
258 | return static_cast<T*>(malloc(sizeof(T))); |
259 | } |
260 | |
// Parses the appropriate builtin data out of the op.
//
// Builtin data is handled explicitly here because it is described by the
// flatbuffer schemas.
//
// Returns malloc'd memory that the caller must free.
266 | // |
267 | // TODO(nupurgarg): Pass in void ** and return TfLiteStatus to ensure program |
268 | // crashes if error reporter is called. |
269 | void* ParseOpData(const Operator* op, BuiltinOperator op_type, |
270 | ErrorReporter* error_reporter) { |
271 | auto parse_padding = [](Padding padding) { |
272 | switch (padding) { |
273 | case Padding_SAME: |
274 | return kTfLitePaddingSame; |
275 | case Padding_VALID: |
276 | return kTfLitePaddingValid; |
277 | } |
278 | return kTfLitePaddingUnknown; |
279 | }; |
280 | auto parse_activation = [](ActivationFunctionType activation) { |
281 | switch (activation) { |
282 | case ActivationFunctionType_NONE: |
283 | return kTfLiteActNone; |
284 | case ActivationFunctionType_RELU: |
285 | return kTfLiteActRelu; |
286 | case ActivationFunctionType_RELU_N1_TO_1: |
287 | return kTfLiteActRelu1; |
288 | case ActivationFunctionType_RELU6: |
289 | return kTfLiteActRelu6; |
290 | case ActivationFunctionType_TANH: |
291 | return kTfLiteActTanh; |
292 | case ActivationFunctionType_SIGN_BIT: |
293 | return kTfLiteActSignBit; |
294 | } |
295 | return kTfLiteActNone; |
296 | }; |
297 | auto parseLSHProjectionType = [](LSHProjectionType type) { |
298 | switch (type) { |
299 | case LSHProjectionType_SPARSE: |
300 | return kTfLiteLshProjectionSparse; |
301 | case LSHProjectionType_DENSE: |
302 | return kTfLiteLshProjectionDense; |
303 | default: |
304 | return kTfLiteLshProjectionUnknown; |
305 | } |
306 | }; |
307 | auto parseCombinerType = [](CombinerType type) { |
308 | switch (type) { |
309 | case CombinerType_MEAN: |
310 | return kTfLiteCombinerTypeMean; |
311 | case CombinerType_SQRTN: |
312 | return kTfLiteCombinerTypeSqrtn; |
313 | case CombinerType_SUM: |
314 | default: |
315 | return kTfLiteCombinerTypeSum; |
316 | } |
317 | }; |
318 | |
319 | void* builtin_data = nullptr; |
320 | switch (op_type) { |
321 | case BuiltinOperator_CALL: |
322 | // TODO(aselle): Implement call in BuiltinOptions, but nullptrs are |
323 | // ok for now, since there is no call implementation either. |
324 | break; |
325 | case BuiltinOperator_CUSTOM: |
326 | break; |
327 | case BuiltinOperator_CONV_2D: { |
328 | TfLiteConvParams* params = MallocPOD<TfLiteConvParams>(); |
329 | if (auto* conv_params = op->builtin_options_as_Conv2DOptions()) { |
330 | params->padding = parse_padding(conv_params->padding()); |
331 | params->stride_width = conv_params->stride_w(); |
332 | params->stride_height = conv_params->stride_h(); |
333 | params->activation = |
334 | parse_activation(conv_params->fused_activation_function()); |
335 | } |
336 | builtin_data = reinterpret_cast<void*>(params); |
337 | break; |
338 | } |
339 | case BuiltinOperator_TANH: |
340 | case BuiltinOperator_LOGISTIC: |
341 | case BuiltinOperator_RELU: |
342 | case BuiltinOperator_RELU_N1_TO_1: |
343 | case BuiltinOperator_RELU6: |
344 | case BuiltinOperator_CONCAT_EMBEDDINGS: |
345 | case BuiltinOperator_EXP: |
346 | case BuiltinOperator_TOPK_V2: |
347 | case BuiltinOperator_LOG_SOFTMAX: |
348 | case BuiltinOperator_DEQUANTIZE: |
349 | case BuiltinOperator_PRELU: |
350 | break; |
351 | case BuiltinOperator_CAST: { |
352 | TfLiteCastParams* params = MallocPOD<TfLiteCastParams>(); |
353 | if (auto* schema_params = op->builtin_options_as_CastOptions()) { |
354 | auto in_status = |
355 | ConvertTensorType(schema_params->in_data_type(), |
356 | ¶ms->in_data_type, error_reporter); |
357 | auto out_status = |
358 | ConvertTensorType(schema_params->out_data_type(), |
359 | ¶ms->out_data_type, error_reporter); |
360 | if (in_status != kTfLiteOk || out_status != kTfLiteOk) { |
361 | break; |
362 | } |
363 | } |
364 | builtin_data = reinterpret_cast<void*>(params); |
365 | break; |
366 | } |
367 | case BuiltinOperator_LSH_PROJECTION: { |
368 | TfLiteLSHProjectionParams* params = |
369 | MallocPOD<TfLiteLSHProjectionParams>(); |
370 | if (auto* lshParams = op->builtin_options_as_LSHProjectionOptions()) { |
371 | params->type = parseLSHProjectionType(lshParams->type()); |
372 | } |
373 | builtin_data = reinterpret_cast<void*>(params); |
374 | break; |
375 | } |
376 | case BuiltinOperator_AVERAGE_POOL_2D: |
377 | case BuiltinOperator_MAX_POOL_2D: |
378 | case BuiltinOperator_L2_POOL_2D: { |
379 | TfLitePoolParams* params = MallocPOD<TfLitePoolParams>(); |
380 | if (auto* pool_params = op->builtin_options_as_Pool2DOptions()) { |
381 | params->padding = parse_padding(pool_params->padding()); |
382 | params->stride_width = pool_params->stride_w(); |
383 | params->stride_height = pool_params->stride_h(); |
384 | params->filter_width = pool_params->filter_width(); |
385 | params->filter_height = pool_params->filter_height(); |
386 | params->activation = |
387 | parse_activation(pool_params->fused_activation_function()); |
388 | } |
389 | builtin_data = reinterpret_cast<void*>(params); |
390 | break; |
391 | } |
392 | case BuiltinOperator_DEPTHWISE_CONV_2D: { |
393 | TfLiteDepthwiseConvParams* params = |
394 | MallocPOD<TfLiteDepthwiseConvParams>(); |
395 | if (auto* conv_params = op->builtin_options_as_DepthwiseConv2DOptions()) { |
396 | params->padding = parse_padding(conv_params->padding()); |
397 | params->stride_width = conv_params->stride_w(); |
398 | params->stride_height = conv_params->stride_h(); |
399 | params->depth_multiplier = conv_params->depth_multiplier(); |
400 | params->activation = |
401 | parse_activation(conv_params->fused_activation_function()); |
402 | } |
403 | builtin_data = reinterpret_cast<void*>(params); |
404 | break; |
405 | } |
406 | case BuiltinOperator_SVDF: { |
407 | TfLiteSVDFParams* params = MallocPOD<TfLiteSVDFParams>(); |
408 | if (auto* svdf_params = op->builtin_options_as_SVDFOptions()) { |
409 | params->rank = svdf_params->rank(); |
410 | params->activation = |
411 | parse_activation(svdf_params->fused_activation_function()); |
412 | } |
413 | builtin_data = reinterpret_cast<void*>(params); |
414 | break; |
415 | } |
416 | case BuiltinOperator_BIDIRECTIONAL_SEQUENCE_RNN: |
417 | case BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_RNN: { |
418 | TfLiteSequenceRNNParams* params = MallocPOD<TfLiteSequenceRNNParams>(); |
419 | if (auto* sequence_rnn_params = |
420 | op->builtin_options_as_SequenceRNNOptions()) { |
421 | params->activation = |
422 | parse_activation(sequence_rnn_params->fused_activation_function()); |
423 | params->time_major = sequence_rnn_params->time_major(); |
424 | } |
425 | builtin_data = reinterpret_cast<void*>(params); |
426 | break; |
427 | } |
428 | case BuiltinOperator_RNN: { |
429 | TfLiteRNNParams* params = MallocPOD<TfLiteRNNParams>(); |
430 | if (auto* rnn_params = op->builtin_options_as_RNNOptions()) { |
431 | params->activation = |
432 | parse_activation(rnn_params->fused_activation_function()); |
433 | } |
434 | builtin_data = reinterpret_cast<void*>(params); |
435 | break; |
436 | } |
437 | case BuiltinOperator_EMBEDDING_LOOKUP: |
438 | // no-op. |
439 | break; |
440 | case BuiltinOperator_EMBEDDING_LOOKUP_SPARSE: { |
441 | TfLiteEmbeddingLookupSparseParams* params = |
442 | MallocPOD<TfLiteEmbeddingLookupSparseParams>(); |
443 | if (auto* embedding_params = |
444 | op->builtin_options_as_EmbeddingLookupSparseOptions()) { |
445 | params->combiner = parseCombinerType(embedding_params->combiner()); |
446 | } |
447 | builtin_data = reinterpret_cast<void*>(params); |
448 | break; |
449 | } |
450 | case BuiltinOperator_FULLY_CONNECTED: { |
451 | TfLiteFullyConnectedParams* params = |
452 | MallocPOD<TfLiteFullyConnectedParams>(); |
453 | if (auto* fully_connected_params = |
454 | op->builtin_options_as_FullyConnectedOptions()) { |
455 | params->activation = parse_activation( |
456 | fully_connected_params->fused_activation_function()); |
457 | } |
458 | builtin_data = reinterpret_cast<void*>(params); |
459 | break; |
460 | } |
461 | case BuiltinOperator_HASHTABLE_LOOKUP: |
462 | // no-op. |
463 | break; |
464 | case BuiltinOperator_SOFTMAX: { |
465 | TfLiteSoftmaxParams* params = MallocPOD<TfLiteSoftmaxParams>(); |
466 | if (auto* softmax_params = op->builtin_options_as_SoftmaxOptions()) { |
467 | params->beta = softmax_params->beta(); |
468 | } |
469 | builtin_data = reinterpret_cast<void*>(params); |
470 | break; |
471 | } |
472 | case BuiltinOperator_CONCATENATION: { |
473 | TfLiteConcatenationParams* params = |
474 | MallocPOD<TfLiteConcatenationParams>(); |
475 | if (auto* concatenation_params = |
476 | op->builtin_options_as_ConcatenationOptions()) { |
477 | params->activation = |
478 | parse_activation(concatenation_params->fused_activation_function()); |
479 | params->axis = concatenation_params->axis(); |
480 | } |
481 | builtin_data = reinterpret_cast<void*>(params); |
482 | break; |
483 | } |
484 | case BuiltinOperator_MUL: { |
485 | auto* params = MallocPOD<TfLiteMulParams>(); |
486 | if (auto* schema_params = op->builtin_options_as_MulOptions()) { |
487 | params->activation = |
488 | parse_activation(schema_params->fused_activation_function()); |
489 | } |
490 | builtin_data = reinterpret_cast<void*>(params); |
491 | break; |
492 | } |
493 | case BuiltinOperator_ADD: { |
494 | auto* params = MallocPOD<TfLiteAddParams>(); |
495 | if (auto* schema_params = op->builtin_options_as_AddOptions()) { |
496 | params->activation = |
497 | parse_activation(schema_params->fused_activation_function()); |
498 | } |
499 | builtin_data = reinterpret_cast<void*>(params); |
500 | break; |
501 | } |
502 | case BuiltinOperator_DIV: { |
503 | auto* params = MallocPOD<TfLiteDivParams>(); |
504 | if (auto* schema_params = op->builtin_options_as_DivOptions()) { |
505 | params->activation = |
506 | parse_activation(schema_params->fused_activation_function()); |
507 | } |
508 | builtin_data = reinterpret_cast<void*>(params); |
509 | break; |
510 | } |
511 | case BuiltinOperator_SUB: { |
512 | auto* params = MallocPOD<TfLiteSubParams>(); |
513 | if (auto* schema_params = op->builtin_options_as_SubOptions()) { |
514 | params->activation = |
515 | parse_activation(schema_params->fused_activation_function()); |
516 | } |
517 | builtin_data = reinterpret_cast<void*>(params); |
518 | break; |
519 | } |
520 | case BuiltinOperator_L2_NORMALIZATION: { |
521 | auto* params = MallocPOD<TfLiteL2NormParams>(); |
522 | if (auto* schema_params = op->builtin_options_as_L2NormOptions()) { |
523 | params->activation = |
524 | parse_activation(schema_params->fused_activation_function()); |
525 | } |
526 | builtin_data = reinterpret_cast<void*>(params); |
527 | break; |
528 | } |
529 | case BuiltinOperator_LOCAL_RESPONSE_NORMALIZATION: { |
530 | auto* params = MallocPOD<TfLiteLocalResponseNormParams>(); |
531 | if (auto* schema_params = |
532 | op->builtin_options_as_LocalResponseNormalizationOptions()) { |
533 | params->radius = schema_params->radius(); |
534 | params->bias = schema_params->bias(); |
535 | params->alpha = schema_params->alpha(); |
536 | params->beta = schema_params->beta(); |
537 | } |
538 | builtin_data = reinterpret_cast<void*>(params); |
539 | break; |
540 | } |
541 | case BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM: |
542 | case BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM: |
543 | case BuiltinOperator_LSTM: { |
544 | TfLiteLSTMParams* params = MallocPOD<TfLiteLSTMParams>(); |
545 | if (auto* lstm_params = op->builtin_options_as_LSTMOptions()) { |
546 | params->activation = |
547 | parse_activation(lstm_params->fused_activation_function()); |
548 | params->cell_clip = lstm_params->cell_clip(); |
549 | params->proj_clip = lstm_params->proj_clip(); |
550 | } |
551 | builtin_data = reinterpret_cast<void*>(params); |
552 | break; |
553 | } |
554 | case BuiltinOperator_RESIZE_BILINEAR: { |
555 | auto* params = MallocPOD<TfLiteResizeBilinearParams>(); |
556 | if (auto* schema_params = |
557 | op->builtin_options_as_ResizeBilinearOptions()) { |
558 | params->align_corners = schema_params->align_corners(); |
559 | } |
560 | builtin_data = reinterpret_cast<void*>(params); |
561 | break; |
562 | } |
563 | case BuiltinOperator_PAD: { |
564 | break; |
565 | } |
566 | case BuiltinOperator_RESHAPE: { |
567 | auto* params = MallocPOD<TfLiteReshapeParams>(); |
568 | if (auto* schema_params = op->builtin_options_as_ReshapeOptions()) { |
569 | auto* new_shape = schema_params->new_shape(); |
570 | FlatBufferIntVectorToArray(sizeof(params->shape), new_shape, |
571 | params->shape, error_reporter); |
572 | params->num_dimensions = new_shape->Length(); |
573 | } |
574 | builtin_data = reinterpret_cast<void*>(params); |
575 | break; |
576 | } |
577 | case BuiltinOperator_SKIP_GRAM: { |
578 | TfLiteSkipGramParams* params = MallocPOD<TfLiteSkipGramParams>(); |
579 | if (auto* skip_gram_params = op->builtin_options_as_SkipGramOptions()) { |
580 | params->ngram_size = skip_gram_params->ngram_size(); |
581 | params->max_skip_size = skip_gram_params->max_skip_size(); |
582 | params->include_all_ngrams = skip_gram_params->include_all_ngrams(); |
583 | } |
584 | builtin_data = reinterpret_cast<void*>(params); |
585 | break; |
586 | } |
587 | case BuiltinOperator_SPACE_TO_DEPTH: { |
588 | auto* params = MallocPOD<TfLiteSpaceToDepthParams>(); |
589 | if (auto* schema_params = op->builtin_options_as_SpaceToDepthOptions()) { |
590 | params->block_size = schema_params->block_size(); |
591 | } |
592 | builtin_data = reinterpret_cast<void*>(params); |
593 | break; |
594 | } |
595 | case BuiltinOperator_GATHER: { |
596 | TfLiteGatherParams* params = MallocPOD<TfLiteGatherParams>(); |
597 | params->axis = 0; |
598 | if (auto* gather_params = op->builtin_options_as_GatherOptions()) { |
599 | params->axis = gather_params->axis(); |
600 | } |
601 | |
602 | builtin_data = reinterpret_cast<void*>(params); |
603 | break; |
604 | } |
605 | case BuiltinOperator_SPACE_TO_BATCH_ND: { |
606 | break; |
607 | } |
608 | case BuiltinOperator_BATCH_TO_SPACE_ND: { |
609 | break; |
610 | } |
611 | case BuiltinOperator_TRANSPOSE: { |
612 | break; |
613 | } |
614 | case BuiltinOperator_MEAN: { |
615 | auto* params = MallocPOD<TfLiteMeanParams>(); |
616 | if (auto* schema_params = op->builtin_options_as_MeanOptions()) { |
617 | params->keep_dims = schema_params->keep_dims(); |
618 | } |
619 | builtin_data = reinterpret_cast<void*>(params); |
620 | break; |
621 | } |
622 | case BuiltinOperator_SPLIT: { |
623 | auto* params = MallocPOD<TfLiteSplitParams>(); |
624 | if (auto* schema_params = op->builtin_options_as_SplitOptions()) { |
625 | params->num_splits = schema_params->num_splits(); |
626 | } |
627 | builtin_data = reinterpret_cast<void*>(params); |
628 | break; |
629 | } |
630 | case BuiltinOperator_SQUEEZE: { |
631 | auto* params = MallocPOD<TfLiteSqueezeParams>(); |
632 | if (auto* schema_params = op->builtin_options_as_SqueezeOptions()) { |
633 | const auto& squeeze_dims = schema_params->squeeze_dims(); |
634 | FlatBufferIntVectorToArray(sizeof(params->squeeze_dims), squeeze_dims, |
635 | params->squeeze_dims, error_reporter); |
636 | params->num_squeeze_dims = squeeze_dims->Length(); |
637 | } |
638 | builtin_data = reinterpret_cast<void*>(params); |
639 | break; |
640 | } |
641 | case BuiltinOperator_STRIDED_SLICE: { |
642 | auto* params = MallocPOD<TfLiteStridedSliceParams>(); |
643 | if (auto* schema_params = op->builtin_options_as_StridedSliceOptions()) { |
644 | params->begin_mask = schema_params->begin_mask(); |
645 | params->end_mask = schema_params->end_mask(); |
646 | params->ellipsis_mask = schema_params->ellipsis_mask(); |
647 | params->new_axis_mask = schema_params->new_axis_mask(); |
648 | params->shrink_axis_mask = schema_params->shrink_axis_mask(); |
649 | } |
650 | builtin_data = reinterpret_cast<void*>(params); |
651 | break; |
652 | } |
653 | case BuiltinOperator_MAXIMUM: { |
654 | break; |
655 | } |
656 | case BuiltinOperator_DELEGATE: { |
657 | // TODO(ycling): Revisit when supporting saving delegated models. |
      error_reporter->Report("DELEGATE op shouldn't exist in model.");
659 | break; |
660 | } |
661 | } |
662 | return builtin_data; |
663 | } |
664 | |
665 | } // namespace |
666 | |
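// Adds one interpreter node per operator in the subgraph, attaching either the
// operator's raw custom options or the builtin data parsed by ParseOpData,
// along with the registration resolved earlier.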
667 | TfLiteStatus InterpreterBuilder::ParseNodes( |
668 | const flatbuffers::Vector<flatbuffers::Offset<Operator>>* operators, |
669 | Interpreter* interpreter) { |
670 | TfLiteStatus status = kTfLiteOk; |
671 | for (int i = 0; i < operators->Length(); ++i) { |
672 | const auto* op = operators->Get(i); |
673 | int index = op->opcode_index(); |
674 | if (index < 0 || index >= flatbuffer_op_index_to_registration_.size()) { |
      error_reporter_->Report("Missing registration for opcode_index %d\n",
676 | index); |
677 | status = kTfLiteError; |
678 | continue; |
679 | } |
680 | const TfLiteRegistration* reg = |
681 | flatbuffer_op_index_to_registration_[op->opcode_index()]; |
682 | if (reg == nullptr) { |
      error_reporter_->Report("Skipping op for opcode_index %d\n", index);
684 | status = kTfLiteError; |
685 | continue; |
686 | } |
687 | |
688 | auto op_type = |
689 | flatbuffer_op_index_to_registration_types_[op->opcode_index()]; |
690 | if (op_type != BuiltinOperator_CUSTOM && op->custom_options()) { |
691 | error_reporter_->Report( |
          "Found builtin operator %s with custom options.\n",
693 | EnumNameBuiltinOperator(op_type)); |
694 | } |
695 | if (op->custom_options()) { |
696 | interpreter->AddNodeWithParameters( |
697 | FlatBufferIntArrayToVector(op->inputs()), |
698 | FlatBufferIntArrayToVector(op->outputs()), |
699 | reinterpret_cast<const char*>(op->custom_options()->data()), |
700 | op->custom_options()->size(), nullptr, reg); |
701 | } else { |
702 | interpreter->AddNodeWithParameters( |
703 | FlatBufferIntArrayToVector(op->inputs()), |
704 | FlatBufferIntArrayToVector(op->outputs()), nullptr, 0, |
705 | ParseOpData(op, op_type, error_reporter_), reg); |
706 | } |
707 | } |
708 | |
709 | return status; |
710 | } |
711 | |
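// Registers every tensor of the subgraph with the interpreter. Tensors backed
// by a non-empty model buffer become read-only tensors pointing into the
// flatbuffer; all others become read/write tensors.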
712 | TfLiteStatus InterpreterBuilder::ParseTensors( |
713 | const flatbuffers::Vector<flatbuffers::Offset<Buffer>>* buffers, |
714 | const flatbuffers::Vector<flatbuffers::Offset<Tensor>>* tensors, |
715 | Interpreter* interpreter) { |
716 | TfLiteStatus status = kTfLiteOk; |
717 | |
  // A little helper to get a tensor's name. Note that the returned names
  // must outlive the interpreter, so they point into the flatbuffer.
720 | auto get_name = [](const tflite::Tensor* t) -> const char* { |
721 | auto name = t->name(); |
722 | if (name) return name->c_str(); |
723 | return kEmptyTensorName; |
724 | }; |
725 | |
726 | for (int i = 0; i < tensors->Length(); ++i) { |
727 | const auto* tensor = tensors->Get(i); |
728 | std::vector<int> dims = FlatBufferIntArrayToVector(tensor->shape()); |
729 | |
730 | TfLiteQuantizationParams quantization; |
731 | quantization.scale = 0; |
732 | quantization.zero_point = 0; |
733 | auto* q_params = tensor->quantization(); |
734 | if (q_params) { |
735 | // Note that the schema could hold per-channel quantization parameters |
736 | // but we really only support one value for the whole tensor. |
737 | // TODO(aselle): This breaks as well if these are nullptr's. |
738 | // TODO(aselle): This assumes non per-channel quantization. |
739 | |
740 | if (q_params->scale()) { |
741 | if (q_params->scale()->size() != 1) { |
742 | error_reporter_->Report( |
              "QuantizationParam has %d scale values (only 1 is supported).",
744 | q_params->scale()->size()); |
745 | return kTfLiteError; |
746 | } |
747 | quantization.scale = q_params->scale()->Get(0); |
748 | } |
749 | |
750 | if (q_params->zero_point()) { |
751 | if (q_params->zero_point()->size() != 1) { |
752 | error_reporter_->Report( |
753 | "QuantizationParam has %d zero_point values" |
              " (only 1 is supported).",
755 | q_params->zero_point()->size()); |
756 | return kTfLiteError; |
757 | } |
758 | quantization.zero_point = q_params->zero_point()->Get(0); |
759 | } |
760 | } |
761 | |
762 | TfLiteType type; |
763 | if (ConvertTensorType(tensor->type(), &type, error_reporter_) != |
764 | kTfLiteOk) { |
765 | status = kTfLiteError; |
766 | continue; |
767 | } |
768 | auto get_readonly_data = [&](const char** buffer_data, |
769 | size_t* buffer_size) { |
770 | // TODO(aselle): Check what happens if we have an unspecified size |
771 | // constant. |
772 | *buffer_data = nullptr; |
773 | if (tensor->buffer() == 0) return kTfLiteOk; |
774 | if (tensor->buffer() >= buffers->size()) { |
775 | error_reporter_->Report( |
            "Tensor %d specifies out of range buffer %d (only %d buffers).\n",
777 | i, tensor->buffer(), buffers->size()); |
778 | return kTfLiteError; |
779 | } |
780 | if (auto* buffer = (*buffers)[tensor->buffer()]) { |
781 | if (auto* array = buffer->data()) { |
782 | if (size_t size = array->size()) { |
783 | *buffer_size = size; |
784 | *buffer_data = reinterpret_cast<const char*>(array->data()); |
785 | return kTfLiteOk; |
786 | } |
787 | } |
788 | } |
789 | return kTfLiteOk; |
790 | }; |
791 | size_t buffer_size = 0; |
792 | const char* buffer_ptr; |
793 | TF_LITE_ENSURE_STATUS(get_readonly_data(&buffer_ptr, &buffer_size)); |
794 | |
795 | if (buffer_ptr) { |
796 | if (interpreter->SetTensorParametersReadOnly( |
797 | i, type, get_name(tensor), dims, quantization, buffer_ptr, |
798 | buffer_size, allocation_) != kTfLiteOk) { |
        error_reporter_->Report("Tensor %d is invalidly specified in schema.\n",
800 | i); |
801 | status = kTfLiteError; |
802 | } |
803 | } else { |
804 | if (interpreter->SetTensorParametersReadWrite( |
805 | i, type, get_name(tensor), dims, quantization) != kTfLiteOk) { |
        error_reporter_->Report("Tensor %d is invalidly specified in schema.\n",
807 | i); |
808 | status = kTfLiteError; |
809 | } |
810 | } |
811 | } |
812 | |
813 | return status; |
814 | } |
815 | |
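// Builds an Interpreter for the model using the default number of threads.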
816 | TfLiteStatus InterpreterBuilder::operator()( |
817 | std::unique_ptr<Interpreter>* interpreter) { |
818 | return operator()(interpreter, /*num_threads=*/-1); |
819 | } |
820 | |
821 | TfLiteStatus InterpreterBuilder::operator()( |
822 | std::unique_ptr<Interpreter>* interpreter, int num_threads) { |
823 | if (!interpreter) { |
824 | error_reporter_->Report( |
        "Null output pointer passed to InterpreterBuilder.");
826 | return kTfLiteError; |
827 | } |
828 | |
  // Safe exit by deleting the partially created interpreter, to reduce
  // verbosity on error conditions. Use as `return cleanup_and_error();`.
831 | auto cleanup_and_error = [&interpreter]() { |
832 | interpreter->reset(); |
833 | return kTfLiteError; |
834 | }; |
835 | |
836 | if (!model_) { |
    error_reporter_->Report("Null pointer passed in as model.");
838 | return cleanup_and_error(); |
839 | } |
840 | |
841 | if (model_->version() != TFLITE_SCHEMA_VERSION) { |
842 | error_reporter_->Report( |
843 | "Model provided is schema version %d not equal " |
        "to supported version %d.\n",
845 | model_->version(), TFLITE_SCHEMA_VERSION); |
846 | return cleanup_and_error(); |
847 | } |
848 | |
  // Flatbuffer model schemas define a list of opcodes independent of the
  // graph. We first map those to registrations. This reduces string lookups
  // for custom ops since we only do it once per custom op rather than once
  // per custom op invocation in the model graph.
  if (BuildLocalIndexToRegistrationMapping() != kTfLiteOk) {
    error_reporter_->Report("Registration failed.\n");
    return cleanup_and_error();
  }

  // Construct the interpreter with the correct number of tensors and
  // operators.
859 | auto* subgraphs = model_->subgraphs(); |
860 | auto* buffers = model_->buffers(); |
861 | if (subgraphs->size() != 1) { |
    error_reporter_->Report("Only 1 subgraph is currently supported.\n");
863 | return cleanup_and_error(); |
864 | } |
865 | const tflite::SubGraph* subgraph = (*subgraphs)[0]; |
866 | auto operators = subgraph->operators(); |
867 | auto tensors = subgraph->tensors(); |
868 | if (!operators || !tensors || !buffers) { |
869 | error_reporter_->Report( |
        "Did not get operators, tensors, or buffers in input flat buffer.\n");
871 | return cleanup_and_error(); |
872 | } |
873 | interpreter->reset(new Interpreter(error_reporter_)); |
874 | if ((**interpreter).AddTensors(tensors->Length()) != kTfLiteOk) { |
875 | return cleanup_and_error(); |
876 | } |
877 | // Set num threads |
878 | (**interpreter).SetNumThreads(num_threads); |
879 | // Parse inputs/outputs |
880 | (**interpreter).SetInputs(FlatBufferIntArrayToVector(subgraph->inputs())); |
881 | (**interpreter).SetOutputs(FlatBufferIntArrayToVector(subgraph->outputs())); |
882 | |
  // Finally, set up nodes and tensors.
884 | if (ParseNodes(operators, interpreter->get()) != kTfLiteOk) |
885 | return cleanup_and_error(); |
886 | if (ParseTensors(buffers, tensors, interpreter->get()) != kTfLiteOk) |
887 | return cleanup_and_error(); |
888 | |
889 | return kTfLiteOk; |
890 | } |
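
// Illustrative end-to-end usage from client code. This is only a sketch: the
// resolver type (from kernels/register.h) and the file path are assumptions
// about the caller, not something defined in this file.
//
//   auto model = tflite::FlatBufferModel::BuildFromFile("model.tflite");
//   tflite::ops::builtin::BuiltinOpResolver resolver;
//   std::unique_ptr<tflite::Interpreter> interpreter;
//   if (model &&
//       tflite::InterpreterBuilder(*model, resolver)(&interpreter) == kTfLiteOk) {
//     interpreter->AllocateTensors();
//     interpreter->Invoke();
//   }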
891 | |
892 | } // namespace tflite |
893 | |