/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/grappler/optimizers/meta_optimizer.h"
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/grappler/optimizers/arithmetic_optimizer.h"
#include "tensorflow/core/grappler/optimizers/auto_parallel.h"
#include "tensorflow/core/grappler/optimizers/constant_folding.h"
#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h"
#include "tensorflow/core/grappler/optimizers/debug_stripper.h"
#include "tensorflow/core/grappler/optimizers/dependency_optimizer.h"
#include "tensorflow/core/grappler/optimizers/function_optimizer.h"
#include "tensorflow/core/grappler/optimizers/graph_optimizer.h"
#include "tensorflow/core/grappler/optimizers/layout_optimizer.h"
#include "tensorflow/core/grappler/optimizers/loop_optimizer.h"
#include "tensorflow/core/grappler/optimizers/memory_optimizer.h"
#include "tensorflow/core/grappler/optimizers/model_pruner.h"
#include "tensorflow/core/grappler/utils/colocation.h"
#include "tensorflow/core/grappler/utils/topological_sort.h"
#include "tensorflow/core/lib/core/status.h"

namespace tensorflow {
namespace grappler {

namespace {
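// Returns the total number of edges in the graph, counted as the sum of the
// inputs (data and control) of every node.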
int64 NumEdges(const GraphDef& graph) {
  int64 num_edges = 0;
  for (const auto& node : graph.node()) {
    num_edges += node.input_size();
  }
  return num_edges;
}

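// Formats a human-readable summary of the node and edge counts of `after`,
// together with the deltas relative to `before`.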
string PrintSizesBeforeAfter(const GraphDef& before, const GraphDef& after) {
  return strings::StrCat("Graph size after: ", after.node_size(), " nodes (",
                         after.node_size() - before.node_size(), "), ",
                         NumEdges(after), " edges (",
                         NumEdges(after) - NumEdges(before), ")");
}
}  // namespace

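// Creates the built-in graph optimizer that corresponds to the given name, or
// returns nullptr if the name is not recognized.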
std::unique_ptr<GraphOptimizer> MetaOptimizer::NewOptimizer(
    const string& optimizer) {
  std::unique_ptr<GraphOptimizer> graph_optimizer;
  if (optimizer == "pruning") {
    graph_optimizer.reset(new ModelPruner());
  }
  if (optimizer == "function") {
    graph_optimizer.reset(new FunctionOptimizer(cfg_.function_optimization()));
  }
  if (optimizer == "constfold") {
    graph_optimizer.reset(new ConstantFolding(cpu_device_));
  }
  if (optimizer == "layout") {
    graph_optimizer.reset(new LayoutOptimizer());
  }
  if (optimizer == "memory") {
    graph_optimizer.reset(new MemoryOptimizer(RewriterConfig::MANUAL));
  }
  if (optimizer == "arithmetic") {
    graph_optimizer.reset(
        new ArithmeticOptimizer(cfg_.arithmetic_optimization()));
  }
  if (optimizer == "autoparallel") {
    graph_optimizer.reset(
        new AutoParallel(cfg_.auto_parallel().num_replicas()));
  }
  if (optimizer == "loop") {
    graph_optimizer.reset(new LoopOptimizer(cfg_.loop_optimization()));
  }
  if (optimizer == "dependency") {
    graph_optimizer.reset(
        new DependencyOptimizer(cfg_.dependency_optimization()));
  }
  if (optimizer == "debug_stripper") {
    graph_optimizer.reset(new DebugStripper());
  }
  return graph_optimizer;
}

Status MetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
                               GraphDef* optimized_graph) {
  std::vector<std::unique_ptr<GraphOptimizer>> optimizers;
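  // If no optimizer was requested explicitly, build the default pipeline from
  // the individual rewriter options; otherwise instantiate exactly the named
  // optimizers.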
  if (cfg_.optimizers().empty()) {
    if (!cfg_.disable_model_pruning()) {
      optimizers.push_back(std::unique_ptr<GraphOptimizer>(new ModelPruner()));
    }
    if (cfg_.function_optimization() != RewriterConfig::OFF) {
      optimizers.push_back(std::unique_ptr<GraphOptimizer>(
          new FunctionOptimizer(cfg_.function_optimization())));
    }
    if (cfg_.debug_stripper() == RewriterConfig::ON) {
      optimizers.push_back(
          std::unique_ptr<GraphOptimizer>(new DebugStripper()));
    }
    if (cfg_.constant_folding() != RewriterConfig::OFF) {
      optimizers.push_back(std::unique_ptr<GraphOptimizer>(
          new ConstantFolding(cfg_.constant_folding(), cpu_device_)));
    }
    if (cfg_.arithmetic_optimization() != RewriterConfig::OFF) {
      optimizers.push_back(std::unique_ptr<GraphOptimizer>(
          new ArithmeticOptimizer(cfg_.arithmetic_optimization())));
    }
    if (cfg_.loop_optimization() != RewriterConfig::OFF) {
      optimizers.push_back(std::unique_ptr<GraphOptimizer>(
          new LoopOptimizer(cfg_.loop_optimization())));
    }
    if (cfg_.dependency_optimization() != RewriterConfig::OFF) {
      optimizers.push_back(std::unique_ptr<GraphOptimizer>(
          new DependencyOptimizer(cfg_.dependency_optimization())));
    }
    if (cfg_.layout_optimizer() != RewriterConfig::OFF) {
      optimizers.push_back(
          std::unique_ptr<GraphOptimizer>(new LayoutOptimizer()));
    }
    if (cfg_.memory_optimization() != RewriterConfig::NO_MEM_OPT) {
      if (cfg_.memory_optimizer_target_node_name_scope().empty()) {
        optimizers.push_back(std::unique_ptr<GraphOptimizer>(
            // Use the default target node name prefix "gradients/".
            new MemoryOptimizer(cfg_.memory_optimization())));
      } else {
        optimizers.push_back(
            std::unique_ptr<GraphOptimizer>(new MemoryOptimizer(
                cfg_.memory_optimization(),
                cfg_.memory_optimizer_target_node_name_scope())));
      }
    }
    if (cfg_.auto_parallel().enable()) {
      optimizers.push_back(std::unique_ptr<GraphOptimizer>(
          new AutoParallel(cfg_.auto_parallel().num_replicas())));
    }
  } else {
    const std::set<string> available_optimizers = {
        "pruning",    "function",     "constfold",  "layout",
        "memory",     "autoparallel", "arithmetic", "loop",
        "dependency", "debug_stripper"};
    std::vector<string> custom_optimizer_names;
    for (const auto& optimizer_name : cfg_.optimizers()) {
      if (available_optimizers.find(optimizer_name) !=
          available_optimizers.end()) {
        optimizers.push_back(NewOptimizer(optimizer_name));
      } else {
        custom_optimizer_names.push_back(optimizer_name);
      }
    }
    // Now create and initialize the custom optimizers from the registry.
    for (const auto& optimizer_name : custom_optimizer_names) {
      std::unique_ptr<CustomGraphOptimizer> opt =
          CustomGraphOptimizerRegistry::CreateByNameOrNull(optimizer_name);
      if (opt == nullptr) continue;
      TF_RETURN_IF_ERROR(opt->Init());
      optimizers.push_back(std::move(opt));
    }
  }

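  // Nothing is enabled: forward the input graph unchanged.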
  if (optimizers.empty()) {
    *optimized_graph = item.graph;
    return Status::OK();
  }

  // Some optimizers should be run only once.
  const std::set<string> run_once_optimizers = {"layout"};
  bool already_optimized = false;
  const int num_iterations =
      cfg_.meta_optimizer_iterations() == RewriterConfig::DEFAULT_NUM_ITERS
          ? 1
          : cfg_.meta_optimizer_iterations();
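  // Apply the optimizers in sequence, repeating the whole pipeline for the
  // configured number of meta-optimizer iterations.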
  for (int iteration = 0; iteration < num_iterations; ++iteration) {
    VLOG(1) << "Starting optimization iteration " << iteration + 1;
    for (const auto& optimizer : optimizers) {
      if (iteration > 0 && run_once_optimizers.count(optimizer->name())) {
        continue;
      }
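      // Until the first optimizer succeeds, keep optimizing the original item;
      // afterwards, feed each optimizer the output of the previous one.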
      if (!already_optimized) {
        Status status = optimizer->Optimize(cluster, item, optimized_graph);
        string result;
        if (!status.ok()) {
          VLOG(1) << "Not able to apply optimizer " << optimizer->name()
                  << ". Return status: " << status.ToString();
          result = status.ToString();
        } else {
          already_optimized = true;
          result = strings::StrCat(
              "OK. ", PrintSizesBeforeAfter(item.graph, *optimized_graph));
        }
        result_.push_back(std::make_pair(optimizer->name(), result));
        VLOG(1) << "Optimizer " << optimizer->name()
                << " return status: " << result;
      } else {
        GrapplerItem optimized_item(item, std::move(*optimized_graph));
        Status status =
            optimizer->Optimize(cluster, optimized_item, optimized_graph);
        string result;
        if (!status.ok()) {
          VLOG(1) << "Not able to apply optimizer " << optimizer->name() << ": "
                  << status.ToString();
          optimized_graph->Swap(&optimized_item.graph);
          result = status.ToString();
        } else {
          result = strings::StrCat(
              optimizer->name(), ": ",
              PrintSizesBeforeAfter(optimized_item.graph, *optimized_graph));
        }
        result_.push_back(std::make_pair(optimizer->name(), result));
        VLOG(1) << result;
      }
    }
  }

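  // If at least one optimizer ran successfully, restore a topological order
  // and valid colocation groups on the rewritten graph before returning it.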
  if (already_optimized) {
    TF_RETURN_IF_ERROR(TopologicalSort(optimized_graph));
    ReassignColocation(optimized_graph);
    // Make sure that the optimizers preserved the graph version and library.
    DCHECK_GE(optimized_graph->library().function_size(),
              item.graph.library().function_size());
    DCHECK_GE(optimized_graph->library().gradient_size(),
              item.graph.library().gradient_size());
    DCHECK_EQ(optimized_graph->versions().producer(),
              item.graph.versions().producer());
  } else {
    *optimized_graph = item.graph;
  }

  return Status::OK();
}

void MetaOptimizer::PrintResult() {
  for (const auto& result : result_) {
    LOG(INFO) << "Return status of optimizer " << result.first << ": "
              << result.second;
  }
}

void MetaOptimizer::Feedback(Cluster* cluster, const GrapplerItem& item,
                             const GraphDef& pruned_graph, double result) {
  // Nothing to do for MetaOptimizer.
}

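// Returns true if at least one built-in rewriter is enabled or a custom
// optimizer list has been provided in the config.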
bool MetaOptimizerEnabled(const RewriterConfig& cfg) {
  return !cfg.disable_model_pruning() ||
         cfg.layout_optimizer() != RewriterConfig::OFF ||
         cfg.function_optimization() != RewriterConfig::OFF ||
         cfg.constant_folding() != RewriterConfig::OFF ||
         cfg.arithmetic_optimization() != RewriterConfig::OFF ||
         cfg.loop_optimization() != RewriterConfig::OFF ||
         cfg.dependency_optimization() != RewriterConfig::OFF ||
         cfg.auto_parallel().enable() ||
         cfg.memory_optimization() != RewriterConfig::NO_MEM_OPT ||
         cfg.debug_stripper() == RewriterConfig::ON ||
         !cfg.optimizers().empty();
}

Status RunMetaOptimizer(const GrapplerItem& item, const RewriterConfig& cfg,
                        DeviceBase* cpu_device, Cluster* cluster,
                        GraphDef* optimized_graph) {
  MetaOptimizer optimizer(cpu_device, cfg);
  return optimizer.Optimize(cluster, item, optimized_graph);
}

}  // namespace grappler
}  // namespace tensorflow