1 | /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | #include <cstdio> |
16 | #include <memory> |
17 | #include <string> |
18 | |
19 | #include "absl/strings/string_view.h" |
20 | #include "tensorflow/contrib/lite/toco/model.h" |
21 | #include "tensorflow/contrib/lite/toco/model_cmdline_flags.h" |
22 | #include "tensorflow/contrib/lite/toco/model_flags.pb.h" |
23 | #include "tensorflow/contrib/lite/toco/toco_cmdline_flags.h" |
24 | #include "tensorflow/contrib/lite/toco/toco_flags.pb.h" |
25 | #include "tensorflow/contrib/lite/toco/toco_port.h" |
26 | #include "tensorflow/contrib/lite/toco/toco_saved_model.h" |
27 | #include "tensorflow/contrib/lite/toco/toco_tooling.h" |
28 | #include "tensorflow/contrib/lite/toco/toco_types.h" |
29 | #include "tensorflow/core/platform/logging.h" |
30 | |
31 | namespace toco { |
32 | namespace { |
33 | |
34 | // Checks the permissions of the output file to ensure it is writeable. |
35 | void CheckOutputFilePermissions(const Arg<string>& output_file) { |
36 | QCHECK(output_file.specified()) << "Missing required flag --output_file.\n" ; |
37 | QCHECK(port::file::Writable(output_file.value()).ok()) |
38 | << "Specified output_file is not writable: " << output_file.value() |
39 | << ".\n" ; |
40 | } |
41 | |
42 | // Checks the permissions of the frozen model file. |
43 | void CheckFrozenModelPermissions(const Arg<string>& input_file) { |
44 | QCHECK(input_file.specified()) << "Missing required flag --input_file.\n" ; |
45 | QCHECK(port::file::Exists(input_file.value(), port::file::Defaults()).ok()) |
46 | << "Specified input_file does not exist: " << input_file.value() << ".\n" ; |
47 | QCHECK(port::file::Readable(input_file.value(), port::file::Defaults()).ok()) |
48 | << "Specified input_file exists, but is not readable: " |
49 | << input_file.value() << ".\n" ; |
50 | } |
51 | |
52 | // Checks the permissions of the SavedModel directory. |
53 | void CheckSavedModelPermissions(const Arg<string>& savedmodel_directory) { |
54 | QCHECK(savedmodel_directory.specified()) |
55 | << "Missing required flag --savedmodel_directory.\n" ; |
56 | QCHECK( |
57 | port::file::Exists(savedmodel_directory.value(), port::file::Defaults()) |
58 | .ok()) |
59 | << "Specified savedmodel_directory does not exist: " |
60 | << savedmodel_directory.value() << ".\n" ; |
61 | } |
62 | |
63 | // Reads the contents of the GraphDef from either the frozen graph file or the |
64 | // SavedModel directory. If it reads the SavedModel directory, it updates the |
65 | // ModelFlags and TocoFlags accordingly. |
66 | void ReadInputData(const ParsedTocoFlags& parsed_toco_flags, |
67 | const ParsedModelFlags& parsed_model_flags, |
68 | TocoFlags* toco_flags, ModelFlags* model_flags, |
69 | string* graph_def_contents) { |
70 | port::CheckInitGoogleIsDone("InitGoogle is not done yet.\n" ); |
71 | |
72 | bool has_input_file = parsed_toco_flags.input_file.specified(); |
73 | bool has_savedmodel_dir = parsed_toco_flags.savedmodel_directory.specified(); |
74 | |
75 | // Ensure either input_file or savedmodel_directory flag has been set. |
76 | QCHECK_NE(has_input_file, has_savedmodel_dir) |
77 | << "Specify either input_file or savedmodel_directory flag.\n" ; |
78 | |
79 | // Checks the input file permissions and reads the contents. |
80 | if (has_input_file) { |
81 | CheckFrozenModelPermissions(parsed_toco_flags.input_file); |
82 | CHECK(port::file::GetContents(parsed_toco_flags.input_file.value(), |
83 | graph_def_contents, port::file::Defaults()) |
84 | .ok()); |
85 | } else { |
86 | CheckSavedModelPermissions(parsed_toco_flags.savedmodel_directory); |
87 | GetSavedModelContents(parsed_toco_flags, parsed_model_flags, toco_flags, |
88 | model_flags, graph_def_contents); |
89 | } |
90 | } |
91 | |
92 | void ToolMain(const ParsedTocoFlags& parsed_toco_flags, |
93 | const ParsedModelFlags& parsed_model_flags) { |
94 | ModelFlags model_flags; |
95 | ReadModelFlagsFromCommandLineFlags(parsed_model_flags, &model_flags); |
96 | |
97 | TocoFlags toco_flags; |
98 | ReadTocoFlagsFromCommandLineFlags(parsed_toco_flags, &toco_flags); |
99 | |
100 | string graph_def_contents; |
101 | ReadInputData(parsed_toco_flags, parsed_model_flags, &toco_flags, |
102 | &model_flags, &graph_def_contents); |
103 | CheckOutputFilePermissions(parsed_toco_flags.output_file); |
104 | |
105 | std::unique_ptr<Model> model = |
106 | Import(toco_flags, model_flags, graph_def_contents); |
107 | Transform(toco_flags, model.get()); |
108 | string output_file_contents; |
109 | Export(toco_flags, *model, toco_flags.allow_custom_ops(), |
110 | &output_file_contents); |
111 | CHECK(port::file::SetContents(parsed_toco_flags.output_file.value(), |
112 | output_file_contents, port::file::Defaults()) |
113 | .ok()); |
114 | } |
115 | |
116 | } // namespace |
117 | } // namespace toco |
118 | |
119 | int main(int argc, char** argv) { |
120 | toco::string msg; |
121 | toco::ParsedTocoFlags parsed_toco_flags; |
122 | toco::ParsedModelFlags parsed_model_flags; |
123 | |
124 | // If no args were specified, give a help string to be helpful. |
125 | int* effective_argc = &argc; |
126 | char** effective_argv = argv; |
127 | if (argc == 1) { |
128 | // No arguments, so manufacture help argv. |
129 | static int dummy_argc = 2; |
130 | static char* dummy_argv[] = {argv[0], const_cast<char*>("--help" )}; |
131 | effective_argc = &dummy_argc; |
132 | effective_argv = dummy_argv; |
133 | } |
134 | |
135 | // Parse toco flags and command flags in sequence, each one strips off args, |
136 | // giving InitGoogle a chance to handle all remaining arguments. |
137 | bool toco_success = toco::ParseTocoFlagsFromCommandLineFlags( |
138 | effective_argc, effective_argv, &msg, &parsed_toco_flags); |
139 | bool model_success = toco::ParseModelFlagsFromCommandLineFlags( |
140 | effective_argc, effective_argv, &msg, &parsed_model_flags); |
141 | if (!toco_success || !model_success || !msg.empty()) { |
142 | fprintf(stderr, "%s" , msg.c_str()); |
143 | fflush(stderr); |
144 | return 1; |
145 | } |
146 | toco::port::InitGoogle(argv[0], effective_argc, &effective_argv, true); |
147 | toco::ToolMain(parsed_toco_flags, parsed_model_flags); |
148 | } |
149 | |