/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/grappler/optimizers/meta_optimizer.h"
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/grappler/optimizers/arithmetic_optimizer.h"
#include "tensorflow/core/grappler/optimizers/auto_parallel.h"
#include "tensorflow/core/grappler/optimizers/constant_folding.h"
#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h"
#include "tensorflow/core/grappler/optimizers/debug_stripper.h"
#include "tensorflow/core/grappler/optimizers/dependency_optimizer.h"
#include "tensorflow/core/grappler/optimizers/function_optimizer.h"
#include "tensorflow/core/grappler/optimizers/graph_optimizer.h"
#include "tensorflow/core/grappler/optimizers/layout_optimizer.h"
#include "tensorflow/core/grappler/optimizers/loop_optimizer.h"
#include "tensorflow/core/grappler/optimizers/memory_optimizer.h"
#include "tensorflow/core/grappler/optimizers/model_pruner.h"
#include "tensorflow/core/grappler/utils/colocation.h"
#include "tensorflow/core/grappler/utils/topological_sort.h"
#include "tensorflow/core/lib/core/status.h"

namespace tensorflow {
namespace grappler {

namespace {
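// Returns the number of edges in `graph`, counted as the sum of every node's
// inputs (data and control inputs alike).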
int64 NumEdges(const GraphDef& graph) {
  int64 num_edges = 0;
  for (const auto& node : graph.node()) {
    num_edges += node.input_size();
  }
  return num_edges;
}

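// Builds a human-readable summary of the node and edge counts of `after`,
// including the deltas relative to `before`.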
string PrintSizesBeforeAfter(const GraphDef& before, const GraphDef& after) {
  return strings::StrCat("Graph size after: ", after.node_size(), " nodes (",
                         after.node_size() - before.node_size(), "), ",
                         NumEdges(after), " edges (",
                         NumEdges(after) - NumEdges(before), ")");
}
}  // namespace

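// Instantiates the built-in optimizer that corresponds to the name used in
// RewriterConfig.optimizers; returns nullptr for unknown names.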
std::unique_ptr<GraphOptimizer> MetaOptimizer::NewOptimizer(
    const string& optimizer) {
  std::unique_ptr<GraphOptimizer> graph_optimizer;
  if (optimizer == "pruning") {
    graph_optimizer.reset(new ModelPruner());
  }
  if (optimizer == "function") {
    graph_optimizer.reset(new FunctionOptimizer(cfg_.function_optimization()));
  }
  if (optimizer == "constfold") {
    graph_optimizer.reset(new ConstantFolding(cpu_device_));
  }
  if (optimizer == "layout") {
    graph_optimizer.reset(new LayoutOptimizer());
  }
  if (optimizer == "memory") {
    graph_optimizer.reset(new MemoryOptimizer(RewriterConfig::MANUAL));
  }
  if (optimizer == "arithmetic") {
    graph_optimizer.reset(
        new ArithmeticOptimizer(cfg_.arithmetic_optimization()));
  }
  if (optimizer == "autoparallel") {
    graph_optimizer.reset(
        new AutoParallel(cfg_.auto_parallel().num_replicas()));
  }
  if (optimizer == "loop") {
    graph_optimizer.reset(new LoopOptimizer(cfg_.loop_optimization()));
  }
  if (optimizer == "dependency") {
    graph_optimizer.reset(
        new DependencyOptimizer(cfg_.dependency_optimization()));
  }
  if (optimizer == "debug_stripper") {
    graph_optimizer.reset(new DebugStripper());
  }
  return graph_optimizer;
}

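// Runs the configured optimizers on `item` and writes the result to
// `optimized_graph`. If no optimizer can be applied successfully, the input
// graph is returned unchanged.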
Status MetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
                               GraphDef* optimized_graph) {
  std::vector<std::unique_ptr<GraphOptimizer>> optimizers;
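  // No optimizers were requested by name: build the default pipeline from the
  // individual RewriterConfig toggles.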
  if (cfg_.optimizers().empty()) {
    if (!cfg_.disable_model_pruning()) {
      optimizers.push_back(std::unique_ptr<GraphOptimizer>(new ModelPruner()));
    }
    if (cfg_.function_optimization() != RewriterConfig::OFF) {
      optimizers.push_back(std::unique_ptr<GraphOptimizer>(
          new FunctionOptimizer(cfg_.function_optimization())));
    }
    if (cfg_.debug_stripper() == RewriterConfig::ON) {
      optimizers.push_back(
          std::unique_ptr<GraphOptimizer>(new DebugStripper()));
    }
    if (cfg_.constant_folding() != RewriterConfig::OFF) {
      optimizers.push_back(std::unique_ptr<GraphOptimizer>(
          new ConstantFolding(cfg_.constant_folding(), cpu_device_)));
    }
    if (cfg_.arithmetic_optimization() != RewriterConfig::OFF) {
      optimizers.push_back(std::unique_ptr<GraphOptimizer>(
          new ArithmeticOptimizer(cfg_.arithmetic_optimization())));
    }
    if (cfg_.loop_optimization() != RewriterConfig::OFF) {
      optimizers.push_back(std::unique_ptr<GraphOptimizer>(
          new LoopOptimizer(cfg_.loop_optimization())));
    }
    if (cfg_.dependency_optimization() != RewriterConfig::OFF) {
      optimizers.push_back(std::unique_ptr<GraphOptimizer>(
          new DependencyOptimizer(cfg_.dependency_optimization())));
    }
    if (cfg_.layout_optimizer() != RewriterConfig::OFF) {
      optimizers.push_back(
          std::unique_ptr<GraphOptimizer>(new LayoutOptimizer()));
    }
    if (cfg_.memory_optimization() != RewriterConfig::NO_MEM_OPT) {
      if (cfg_.memory_optimizer_target_node_name_scope().empty()) {
        optimizers.push_back(std::unique_ptr<GraphOptimizer>(
            // Use the default target node name prefix "gradients/"
            new MemoryOptimizer(cfg_.memory_optimization())));
      } else {
        optimizers.push_back(
            std::unique_ptr<GraphOptimizer>(new MemoryOptimizer(
                cfg_.memory_optimization(),
                cfg_.memory_optimizer_target_node_name_scope())));
      }
    }
    if (cfg_.auto_parallel().enable()) {
      optimizers.push_back(std::unique_ptr<GraphOptimizer>(
          new AutoParallel(cfg_.auto_parallel().num_replicas())));
    }
  } else {
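    // Optimizers were requested by name: instantiate the built-in ones and
    // treat every remaining name as a registered custom graph optimizer.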
    const std::set<string> available_optimizers = {
        "pruning",    "function",     "constfold",  "layout",
        "memory",     "autoparallel", "arithmetic", "loop",
        "dependency", "debug_stripper"};
    std::vector<string> custom_optimizer_names;
    for (const auto& optimizer_name : cfg_.optimizers()) {
      if (available_optimizers.find(optimizer_name) !=
          available_optimizers.end()) {
        optimizers.push_back(NewOptimizer(optimizer_name));
      } else {
        custom_optimizer_names.push_back(optimizer_name);
      }
    }
    // Now create and initialize the custom optimizers.
    for (const auto& optimizer_name : custom_optimizer_names) {
      std::unique_ptr<CustomGraphOptimizer> opt =
          CustomGraphOptimizerRegistry::CreateByNameOrNull(optimizer_name);
      if (opt == nullptr) continue;
      TF_RETURN_IF_ERROR(opt->Init());
      optimizers.push_back(std::move(opt));
    }
  }

  if (optimizers.empty()) {
    *optimized_graph = item.graph;
    return Status::OK();
  }

  // Some optimizers should be run only once.
  const std::set<string> run_once_optimizers = {"layout"};
  bool already_optimized = false;
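  // DEFAULT_NUM_ITERS leaves the iteration count up to the meta optimizer,
  // which currently means a single pass over the optimizer list.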
  const int num_iterations =
      cfg_.meta_optimizer_iterations() == RewriterConfig::DEFAULT_NUM_ITERS
          ? 1
          : cfg_.meta_optimizer_iterations();
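  // Apply every optimizer in sequence on each iteration. Until the first
  // optimizer succeeds, each one starts from the original item; after that,
  // optimizers rewrite the already-optimized graph.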
  for (int iteration = 0; iteration < num_iterations; ++iteration) {
    VLOG(1) << "Starting optimization iteration " << iteration + 1;
    for (const auto& optimizer : optimizers) {
      if (iteration > 0 && run_once_optimizers.count(optimizer->name())) {
        continue;
      }
      if (!already_optimized) {
        Status status = optimizer->Optimize(cluster, item, optimized_graph);
        string result;
        if (!status.ok()) {
          VLOG(1) << "Not able to apply optimizer " << optimizer->name()
                  << ". Return status: " << status.ToString();
          result = status.ToString();
        } else {
          already_optimized = true;
          result = strings::StrCat(
              "OK. ", PrintSizesBeforeAfter(item.graph, *optimized_graph));
        }
        result_.push_back(std::make_pair(optimizer->name(), result));
        VLOG(1) << "Optimizer " << optimizer->name()
                << " return status: " << result;
      } else {
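        // A previous optimizer already produced a graph: re-optimize it, and
        // restore it via Swap() if this optimizer fails.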
        GrapplerItem optimized_item(item, std::move(*optimized_graph));
        Status status =
            optimizer->Optimize(cluster, optimized_item, optimized_graph);
        string result;
        if (!status.ok()) {
          VLOG(1) << "Not able to apply optimizer " << optimizer->name() << ": "
                  << status.ToString();
          optimized_graph->Swap(&optimized_item.graph);
          result = status.ToString();
        } else {
          result = strings::StrCat(
              optimizer->name(), ": ",
              PrintSizesBeforeAfter(optimized_item.graph, *optimized_graph));
        }
        result_.push_back(std::make_pair(optimizer->name(), result));
        VLOG(1) << result;
      }
    }
  }

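  // At least one optimizer ran: restore a valid topological order and
  // colocation groups before returning the optimized graph.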
  if (already_optimized) {
    TF_RETURN_IF_ERROR(TopologicalSort(optimized_graph));
    ReassignColocation(optimized_graph);
    // Make sure that the optimizers preserved the graph version and library.
    DCHECK_GE(optimized_graph->library().function_size(),
              item.graph.library().function_size());
    DCHECK_GE(optimized_graph->library().gradient_size(),
              item.graph.library().gradient_size());
    DCHECK_EQ(optimized_graph->versions().producer(),
              item.graph.versions().producer());
  } else {
    *optimized_graph = item.graph;
  }

  return Status::OK();
}

void MetaOptimizer::PrintResult() {
  for (const auto& result : result_) {
    LOG(INFO) << "Return status of optimizer " << result.first << ": "
              << result.second;
  }
}

void MetaOptimizer::Feedback(Cluster* cluster, const GrapplerItem& item,
                             const GraphDef& pruned_graph, double result) {
  // Nothing to do for MetaOptimizer.
}

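// Returns true if at least one rewriter is enabled in `cfg`, i.e. if running
// the meta optimizer would do any work at all.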
bool MetaOptimizerEnabled(const RewriterConfig& cfg) {
  return !cfg.disable_model_pruning() ||
         cfg.layout_optimizer() != RewriterConfig::OFF ||
         cfg.function_optimization() != RewriterConfig::OFF ||
         cfg.constant_folding() != RewriterConfig::OFF ||
         cfg.arithmetic_optimization() != RewriterConfig::OFF ||
         cfg.loop_optimization() != RewriterConfig::OFF ||
         cfg.dependency_optimization() != RewriterConfig::OFF ||
         cfg.auto_parallel().enable() ||
         cfg.memory_optimization() != RewriterConfig::NO_MEM_OPT ||
         cfg.debug_stripper() == RewriterConfig::ON ||
         !cfg.optimizers().empty();
}

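// Convenience entry point: builds a MetaOptimizer for `cfg` and runs it once
// on `item`, writing the result to `optimized_graph`.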
Status RunMetaOptimizer(const GrapplerItem& item, const RewriterConfig& cfg,
                        DeviceBase* cpu_device, Cluster* cluster,
                        GraphDef* optimized_graph) {
  MetaOptimizer optimizer(cpu_device, cfg);
  return optimizer.Optimize(cluster, item, optimized_graph);
}

}  // namespace grappler
}  // namespace tensorflow