1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15#include <cstdio>
16#include <memory>
17#include <string>
18
19#include "absl/strings/string_view.h"
20#include "tensorflow/contrib/lite/toco/model.h"
21#include "tensorflow/contrib/lite/toco/model_cmdline_flags.h"
22#include "tensorflow/contrib/lite/toco/model_flags.pb.h"
23#include "tensorflow/contrib/lite/toco/toco_cmdline_flags.h"
24#include "tensorflow/contrib/lite/toco/toco_flags.pb.h"
25#include "tensorflow/contrib/lite/toco/toco_port.h"
26#include "tensorflow/contrib/lite/toco/toco_saved_model.h"
27#include "tensorflow/contrib/lite/toco/toco_tooling.h"
28#include "tensorflow/contrib/lite/toco/toco_types.h"
29#include "tensorflow/core/platform/logging.h"
30
31namespace toco {
32namespace {
33
34// Checks the permissions of the output file to ensure it is writeable.
35void CheckOutputFilePermissions(const Arg<string>& output_file) {
36 QCHECK(output_file.specified()) << "Missing required flag --output_file.\n";
37 QCHECK(port::file::Writable(output_file.value()).ok())
38 << "Specified output_file is not writable: " << output_file.value()
39 << ".\n";
40}
41
42// Checks the permissions of the frozen model file.
43void CheckFrozenModelPermissions(const Arg<string>& input_file) {
44 QCHECK(input_file.specified()) << "Missing required flag --input_file.\n";
45 QCHECK(port::file::Exists(input_file.value(), port::file::Defaults()).ok())
46 << "Specified input_file does not exist: " << input_file.value() << ".\n";
47 QCHECK(port::file::Readable(input_file.value(), port::file::Defaults()).ok())
48 << "Specified input_file exists, but is not readable: "
49 << input_file.value() << ".\n";
50}
51
52// Checks the permissions of the SavedModel directory.
53void CheckSavedModelPermissions(const Arg<string>& savedmodel_directory) {
54 QCHECK(savedmodel_directory.specified())
55 << "Missing required flag --savedmodel_directory.\n";
56 QCHECK(
57 port::file::Exists(savedmodel_directory.value(), port::file::Defaults())
58 .ok())
59 << "Specified savedmodel_directory does not exist: "
60 << savedmodel_directory.value() << ".\n";
61}
62
63// Reads the contents of the GraphDef from either the frozen graph file or the
64// SavedModel directory. If it reads the SavedModel directory, it updates the
65// ModelFlags and TocoFlags accordingly.
66void ReadInputData(const ParsedTocoFlags& parsed_toco_flags,
67 const ParsedModelFlags& parsed_model_flags,
68 TocoFlags* toco_flags, ModelFlags* model_flags,
69 string* graph_def_contents) {
70 port::CheckInitGoogleIsDone("InitGoogle is not done yet.\n");
71
72 bool has_input_file = parsed_toco_flags.input_file.specified();
73 bool has_savedmodel_dir = parsed_toco_flags.savedmodel_directory.specified();
74
75 // Ensure either input_file or savedmodel_directory flag has been set.
76 QCHECK_NE(has_input_file, has_savedmodel_dir)
77 << "Specify either input_file or savedmodel_directory flag.\n";
78
79 // Checks the input file permissions and reads the contents.
80 if (has_input_file) {
81 CheckFrozenModelPermissions(parsed_toco_flags.input_file);
82 CHECK(port::file::GetContents(parsed_toco_flags.input_file.value(),
83 graph_def_contents, port::file::Defaults())
84 .ok());
85 } else {
86 CheckSavedModelPermissions(parsed_toco_flags.savedmodel_directory);
87 GetSavedModelContents(parsed_toco_flags, parsed_model_flags, toco_flags,
88 model_flags, graph_def_contents);
89 }
90}
91
92void ToolMain(const ParsedTocoFlags& parsed_toco_flags,
93 const ParsedModelFlags& parsed_model_flags) {
94 ModelFlags model_flags;
95 ReadModelFlagsFromCommandLineFlags(parsed_model_flags, &model_flags);
96
97 TocoFlags toco_flags;
98 ReadTocoFlagsFromCommandLineFlags(parsed_toco_flags, &toco_flags);
99
100 string graph_def_contents;
101 ReadInputData(parsed_toco_flags, parsed_model_flags, &toco_flags,
102 &model_flags, &graph_def_contents);
103 CheckOutputFilePermissions(parsed_toco_flags.output_file);
104
105 std::unique_ptr<Model> model =
106 Import(toco_flags, model_flags, graph_def_contents);
107 Transform(toco_flags, model.get());
108 string output_file_contents;
109 Export(toco_flags, *model, toco_flags.allow_custom_ops(),
110 &output_file_contents);
111 CHECK(port::file::SetContents(parsed_toco_flags.output_file.value(),
112 output_file_contents, port::file::Defaults())
113 .ok());
114}
115
116} // namespace
117} // namespace toco
118
119int main(int argc, char** argv) {
120 toco::string msg;
121 toco::ParsedTocoFlags parsed_toco_flags;
122 toco::ParsedModelFlags parsed_model_flags;
123
124 // If no args were specified, give a help string to be helpful.
125 int* effective_argc = &argc;
126 char** effective_argv = argv;
127 if (argc == 1) {
128 // No arguments, so manufacture help argv.
129 static int dummy_argc = 2;
130 static char* dummy_argv[] = {argv[0], const_cast<char*>("--help")};
131 effective_argc = &dummy_argc;
132 effective_argv = dummy_argv;
133 }
134
135 // Parse toco flags and command flags in sequence, each one strips off args,
136 // giving InitGoogle a chance to handle all remaining arguments.
137 bool toco_success = toco::ParseTocoFlagsFromCommandLineFlags(
138 effective_argc, effective_argv, &msg, &parsed_toco_flags);
139 bool model_success = toco::ParseModelFlagsFromCommandLineFlags(
140 effective_argc, effective_argv, &msg, &parsed_model_flags);
141 if (!toco_success || !model_success || !msg.empty()) {
142 fprintf(stderr, "%s", msg.c_str());
143 fflush(stderr);
144 return 1;
145 }
146 toco::port::InitGoogle(argv[0], effective_argc, &effective_argv, true);
147 toco::ToolMain(parsed_toco_flags, parsed_model_flags);
148}
149