1 | /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "tensorflow/compiler/xla/service/hlo_graph_dumper.h" |
17 | |
18 | #include <unistd.h> |
19 | #include <algorithm> |
20 | #include <atomic> |
21 | #include <deque> |
22 | #include <map> |
23 | #include <memory> |
24 | #include <string> |
25 | #include <tuple> |
26 | #include <unordered_map> |
27 | #include <vector> |
28 | |
29 | #include "tensorflow/compiler/xla/layout_util.h" |
30 | #include "tensorflow/compiler/xla/literal_util.h" |
31 | #include "tensorflow/compiler/xla/service/hlo_module.h" |
32 | #include "tensorflow/compiler/xla/service/hlo_tfgraph_builder.h" |
33 | #include "tensorflow/compiler/xla/shape_util.h" |
34 | #include "tensorflow/compiler/xla/types.h" |
35 | #include "tensorflow/compiler/xla/window_util.h" |
36 | #include "tensorflow/core/lib/core/status.h" |
37 | #include "tensorflow/core/lib/gtl/map_util.h" |
38 | #include "tensorflow/core/lib/gtl/optional.h" |
39 | #include "tensorflow/core/lib/io/path.h" |
40 | #include "tensorflow/core/lib/strings/numbers.h" |
41 | #include "tensorflow/core/lib/strings/str_util.h" |
42 | #include "tensorflow/core/lib/strings/strcat.h" |
43 | #include "tensorflow/core/lib/strings/stringprintf.h" |
44 | #include "tensorflow/core/platform/env.h" |
45 | #include "tensorflow/core/platform/protobuf.h" |
46 | #include "tensorflow/core/platform/regexp.h" |
47 | |
48 | using ::tensorflow::Env; |
49 | using ::tensorflow::WriteStringToFile; |
50 | using ::tensorflow::gtl::nullopt; |
51 | using ::tensorflow::gtl::optional; |
52 | using ::tensorflow::io::JoinPath; |
53 | using ::tensorflow::str_util::Join; |
54 | using ::tensorflow::str_util::StringReplace; |
55 | using ::tensorflow::strings::StrAppend; |
56 | using ::tensorflow::strings::StrCat; |
57 | |
58 | namespace xla { |
59 | namespace hlo_graph_dumper { |
60 | namespace { |
61 | |
// Argument adapters used by the Printf/Appendf wrappers in this file.
//
// The primary template hands its argument through unchanged.
template <typename T>
struct PrintfConvert {
  const T& operator()(const T& value) const {
    return value;
  }
};

// Strings are converted to C strings so that they can be consumed by "%s".
// The returned pointer is only valid while the argument string is alive.
template <>
struct PrintfConvert<string> {
  const char* operator()(const string& str) const {
    return str.c_str();
  }
};
71 | |
72 | // Like tensorflow::strings::Printf/Appendf, but you don't need to call c_str() |
73 | // on strings. |
74 | template <typename... Ts> |
75 | string Printf(const char* fmt, const Ts&... ts) { |
76 | return tensorflow::strings::Printf(fmt, PrintfConvert<Ts>()(ts)...); |
77 | } |
78 | template <typename... Ts> |
79 | void Appendf(string* s, const char* fmt, const Ts&... ts) { |
80 | tensorflow::strings::Appendf(s, fmt, PrintfConvert<Ts>()(ts)...); |
81 | } |
82 | |
// Describes how a given HloInstruction should be treated when drawing the
// graph: drawn normally, hidden, highlighted, and so on.
enum NodeFilterResult {
  // Draw the node with its regular styling.
  kNormalNode,
  // Leave the node out of the graph.
  kHideNode,
  // Make the node easy to find in the final graph.
  kHighlightNode,
  // "Gray out" the node to indicate that some of its operands have been
  // omitted.
  kSomeOperandsOmitted,
  // Styled like kSomeOperandsOmitted, but additionally the node is not
  // connected to its operands, even if they're present in the graph.
  kOmitNodeOperands,
  // Styled like kSomeOperandsOmitted, but indicates that some of the node's
  // *users* (rather than operands) have been omitted.
  kSomeUsersOmitted,
};
100 | |
101 | // NodeFilter is essentially a map from HloInstruction*s to NodeFilterResult. |
102 | // It lets callers tell the graph-drawing routines which nodes they want to be |
103 | // shown, hidden, or highlighted. |
104 | class NodeFilter { |
105 | public: |
106 | NodeFilter() : filter_([](const HloInstruction*) { return kNormalNode; }) {} |
107 | |
108 | explicit NodeFilter( |
109 | std::function<NodeFilterResult(const HloInstruction* instr)> filter) |
110 | : filter_(std::move(filter)) {} |
111 | |
112 | bool Show(const HloInstruction* instr) const { |
113 | return filter_(instr) != kHideNode; |
114 | } |
115 | bool Highlight(const HloInstruction* instr) const { |
116 | return filter_(instr) == kHighlightNode; |
117 | } |
118 | bool OmitOperands(const HloInstruction* instr) const { |
119 | return filter_(instr) == kOmitNodeOperands; |
120 | } |
121 | bool SomeOrAllOperandsOmitted(const HloInstruction* instr) const { |
122 | auto result = filter_(instr); |
123 | return result == kOmitNodeOperands || result == kSomeOperandsOmitted; |
124 | } |
125 | bool Deemphasized(const HloInstruction* instr) const { |
126 | auto result = filter_(instr); |
127 | return result == kOmitNodeOperands || result == kSomeOperandsOmitted || |
128 | result == kSomeUsersOmitted; |
129 | } |
130 | |
131 | bool ShowFusionSubcomputation(const HloInstruction* instr) const { |
132 | CHECK_EQ(instr->opcode(), HloOpcode::kFusion); |
133 | return Show(instr) && !SomeOrAllOperandsOmitted(instr); |
134 | } |
135 | |
136 | private: |
137 | std::function<NodeFilterResult(const HloInstruction* instr)> filter_; |
138 | }; |
139 | |
// The palette of node color schemes.  NodeColorsForScheme translates each
// value into concrete graphviz colors.
enum ColorScheme {
  kBlue,
  kBrown,
  kDarkBlue,
  kDarkGreen,
  kDarkRed,
  kGray,
  kGreen,
  kOrange,
  kPurple,
  kRed,
  kWhite,
  kYellow,

  // Draws the node with a dashed border and gray text on a white background,
  // suggesting that this is an "unimportant" node.
  kDashedBorder,
};
159 | |
// The concrete graphviz attribute values that make up one color scheme.
// This is a plain aggregate; it is brace-initialized in the order of the
// fields below.
struct NodeColors {
  // Value for the node's graphviz "style" attribute.
  const char* style;
  // Value for the node's "fillcolor" attribute.
  const char* fill_color;
  // Value for the node's "color" (outline) attribute.
  const char* stroke_color;
  // Value for the node's "fontcolor" attribute.
  const char* font_color;
};
167 | |
168 | NodeColors NodeColorsForScheme(ColorScheme color) { |
169 | switch (color) { |
170 | case kBlue: |
171 | return NodeColors{"filled" , "#bbdefb" , "#8aacc8" , "black" }; |
172 | case kBrown: |
173 | return NodeColors{"filled" , "#bcaaa4" , "#8c7b75" , "black" }; |
174 | case kDarkBlue: |
175 | return NodeColors{"filled" , "#1565c0" , "#003c8f" , "white" }; |
176 | case kDarkGreen: |
177 | return NodeColors{"filled" , "#2e7d32" , "#005005" , "white" }; |
178 | case kDarkRed: |
179 | return NodeColors{"filled" , "#b71c1c" , "#7f0000" , "white" }; |
180 | case kGray: |
181 | return NodeColors{"filled" , "#cfd8dc" , "#9ea7aa" , "black" }; |
182 | case kGreen: |
183 | return NodeColors{"filled" , "#c8e6c9" , "#97b498" , "black" }; |
184 | case kOrange: |
185 | return NodeColors{"filled" , "#ffe0b2" , "#cbae82" , "black" }; |
186 | case kPurple: |
187 | return NodeColors{"filled" , "#e1bee7" , "#af8eb5" , "black" }; |
188 | case kRed: |
189 | return NodeColors{"filled" , "#ffcdd2" , "#cb9ca1" , "black" }; |
190 | case kWhite: |
191 | return NodeColors{"filled" , "white" , "black" , "black" }; |
192 | case kYellow: |
193 | return NodeColors{"filled" , "#fff9c4" , "#cbc693" , "black" }; |
194 | case kDashedBorder: |
195 | // "filled,dashed" looks the same as "dashed", since we have a white |
196 | // background. But we use "filled,dashed" so that when you hover over |
197 | // any part of the node (not just the text inside the node), our css |
198 | // :hover rule is triggered. |
199 | return NodeColors{"filled,dashed" , "white" , "#757575" , "#757575" }; |
200 | } |
201 | } |
202 | |
203 | // Given a ColorScheme, returns an attribute string for a node of that color. |
204 | // Sets the node's style and fill/stroke/text colors. |
205 | // |
206 | // Colors are from https://material.io/color. |
207 | string NodeColorAttributes(ColorScheme color) { |
208 | NodeColors node_colors = NodeColorsForScheme(color); |
209 | |
210 | return Printf( |
211 | R"(style="%s", fontcolor="%s", color="%s", fillcolor="%s")" , |
212 | node_colors.style, node_colors.font_color, node_colors.stroke_color, |
213 | node_colors.fill_color); |
214 | } |
215 | |
216 | // Replaces <> with <>, so that this string is safe(er) for use in a |
217 | // graphviz HTML-like string. |
218 | string HtmlLikeStringSanitize(tensorflow::StringPiece s) { |
219 | return StringReplace(StringReplace(s, "<" , "<" , /*replace_all=*/true), ">" , |
220 | ">" , /*replace_all=*/true); |
221 | } |
222 | |
223 | // Tries to generates a human-readable one-word description of the given |
224 | // computation. |
225 | // |
226 | // Currently we support: |
227 | // |
228 | // "return param0 + param1;" --> "add" |
229 | // "return param0 * param1;" --> "multiply" |
230 | // "return min(param0, param1);" --> "min" |
231 | // "return max(param0, param1);" --> "max" |
232 | // "return param0 <= param1;" --> "less-or-equal" |
233 | // "return param0 >= param1;" --> "greater-or-equal" |
234 | // "return param0 > param1;" --> "greater-than" |
235 | // "return param0 < param1;" --> "less-than" |
236 | // "return param0 == param1;" --> "equal-to" |
237 | // "return param0 != param1;" --> "not-equal-to" |
238 | // |
239 | // where param0 and param1 are effective scalars. For the ops that are |
240 | // commutative, we also support them with param0 and param1 swapped. |
241 | // |
242 | // This is useful primarily for reduce and map nodes. These take a |
243 | // subcomputation which is almost always one of the above, and pattern matching |
244 | // it to a short string lets us tell the user what the subcomputation is without |
245 | // drawing it as a graph. |
246 | optional<string> MatchTrivialComputation(const HloComputation* computation) { |
247 | if (computation->instruction_count() != 3) { |
248 | return nullopt; |
249 | } |
250 | |
251 | HloInstruction* root = computation->root_instruction(); |
252 | if (root->operand_count() != 2) { |
253 | return nullopt; |
254 | } |
255 | |
256 | // Check that both of the operands to the root are parameters. |
257 | const HloInstruction* operand0 = root->operand(0); |
258 | const HloInstruction* operand1 = root->operand(1); |
259 | if (operand0->opcode() != HloOpcode::kParameter || |
260 | operand1->opcode() != HloOpcode::kParameter) { |
261 | return nullopt; |
262 | } |
263 | |
264 | // Check that the two operands of root are param0 and param1. All of the |
265 | // opcodes we recognize are commutative, so we're OK with either order. |
266 | auto n0 = operand0->parameter_number(); |
267 | auto n1 = operand1->parameter_number(); |
268 | if (!(n0 == 0 && n1 == 1) && !(n1 == 0 && n0 == 1)) { |
269 | return nullopt; |
270 | } |
271 | |
272 | // If the params are reversed, check that the operation being performed is |
273 | // commutative. |
274 | if (n0 == 1) { |
275 | switch (root->opcode()) { |
276 | case HloOpcode::kLe: |
277 | case HloOpcode::kGe: |
278 | case HloOpcode::kGt: |
279 | case HloOpcode::kLt: |
280 | return nullopt; |
281 | default: |
282 | break; |
283 | } |
284 | } |
285 | |
286 | // Check that the root and params are all effective scalars. |
287 | if (!ShapeUtil::IsEffectiveScalar(root->shape()) || |
288 | !ShapeUtil::IsEffectiveScalar(operand0->shape()) || |
289 | !ShapeUtil::IsEffectiveScalar(operand1->shape())) { |
290 | return nullopt; |
291 | } |
292 | |
293 | // If we recognize the root's opcode, we've successfully pattern-matched! |
294 | switch (root->opcode()) { |
295 | case HloOpcode::kAdd: |
296 | return "add" ; |
297 | case HloOpcode::kMultiply: |
298 | return "multiply" ; |
299 | case HloOpcode::kMinimum: |
300 | return "min" ; |
301 | case HloOpcode::kMaximum: |
302 | return "max" ; |
303 | case HloOpcode::kLe: |
304 | return "less-or-equal" ; |
305 | case HloOpcode::kGe: |
306 | return "greater-or-equal" ; |
307 | case HloOpcode::kGt: |
308 | return "greater-than" ; |
309 | case HloOpcode::kLt: |
310 | return "less-than" ; |
311 | case HloOpcode::kEq: |
312 | return "equal-to" ; |
313 | case HloOpcode::kNe: |
314 | return "not-equal-to" ; |
315 | default: |
316 | return nullopt; |
317 | } |
318 | } |
319 | |
320 | // Encapsulates logic for dumping an HLO module to DOT (i.e. graphviz syntax). |
321 | class HloDotDumper { |
322 | public: |
323 | HloDotDumper(const HloComputation* computation, tensorflow::StringPiece label, |
324 | const DebugOptions& debug_options, bool show_metadata, |
325 | const HloExecutionProfile* profile, NodeFilter filter) |
326 | : computation_(computation), |
327 | label_(label.ToString()), |
328 | debug_options_(debug_options), |
329 | show_metadata_(show_metadata), |
330 | profile_(profile), |
331 | filter_(std::move(filter)) {} |
332 | |
333 | string Dump(); |
334 | |
335 | private: |
336 | // Returns the dot graph identifier for the given instruction. |
337 | string InstructionId(const HloInstruction* instruction) { |
338 | return StrCat(reinterpret_cast<uint64>(instruction)); |
339 | } |
340 | |
341 | // Returns the dot graph identifier for the given computation. |
342 | string SubcomputationId(const HloComputation* computation) { |
343 | return StrCat("cluster_" , reinterpret_cast<uint64>(computation)); |
344 | } |
345 | |
346 | // Generates graph header/footer. These should be called *after* dumping all |
347 | // of the instructions and subcomputations for the graph, as they both use |
348 | // data generated while dumping the graph. |
349 | string Header(); |
350 | string Footer(); |
351 | |
352 | bool ShouldShowSubcomputation(const HloComputation* subcomp); |
353 | bool ShouldShowFusionSubcomputation(const HloInstruction* instr); |
354 | |
355 | // We omit some nodes from the graph, instead drawing them inlined into the |
356 | // nodes that use them. |
357 | bool ShouldMergeIntoUsers(const HloInstruction* instr) const; |
358 | |
359 | string DumpSubcomputation(const HloComputation* subcomp, |
360 | const HloInstruction* parent_instr); |
361 | string DumpComputation(const HloComputation* comp); |
362 | string DumpRootTag(); |
363 | string DumpInstruction(const HloInstruction* instr); |
364 | ColorScheme GetInstructionColor(const HloInstruction* instr); |
365 | string GetInstructionNodeShape(const HloInstruction* instr); |
366 | string GetInstructionNodeLabel(const HloInstruction* instr); |
367 | string GetInstructionNodeMetadata(const HloInstruction* instr); |
368 | string GetInstructionNodeExtraInfo(const HloInstruction* instr); |
369 | string GetInstructionNodeInlinedOperands(const HloInstruction* instr); |
370 | void AddInstructionIncomingEdges(const HloInstruction* instr); |
371 | |
372 | // For most instructions, GetNodeForEdge(instr) returns instr. |
373 | // |
374 | // The exception is fusion nodes. For these, we walk up the chain of nested |
375 | // fusion nodes starting at instr until we reach a node that either (a) isn't |
376 | // a fusion node, or (b) is a fusion node for which |
377 | // ShouldShowFusionSubcomputation is false. |
378 | // |
379 | // We do this because fusion nodes are expanded inline -- if |
380 | // ShouldShowFusionSubcomputation is true, the fusion node won't be present in |
381 | // the graph. |
382 | // |
383 | // In general when you want to draw an edge from A to B, you should actually |
384 | // draw an edge from GetNodeForEdge(A) to GetNodeForEdge(B). |
385 | const HloInstruction* GetNodeForEdge(const HloInstruction* instr); |
386 | |
387 | // If instr has just one computation and it's trivial (e.g. "return param0 + |
388 | // param1"), returns a string you can put into the node's body that names the |
389 | // subcomputation, e.g. "Subcomputation: <b>add</b>". |
390 | string GetInstructionTrivialComputationStr(const HloInstruction* instr); |
391 | |
392 | const HloComputation* computation_; // never null |
393 | const string label_; // overall name for the graph |
394 | const DebugOptions& debug_options_; |
395 | const bool show_metadata_; |
396 | const HloExecutionProfile* profile_; // may be null |
397 | const NodeFilter filter_; |
398 | |
399 | // Each HloInstruction dumped gets a monotically-increasing node ID. This |
400 | // must start at 1, because that's where graphviz's accounting starts. |
401 | int64 next_node_id_ = 1; |
402 | std::unordered_map<const HloInstruction*, int64> node_ids_; |
403 | |
404 | // The "root" tag doesn't have an associated HloInstruction pointer, so we |
405 | // need to store it outside the map. |
406 | int64 root_node_id_; |
407 | |
408 | // Each (from, to) edge gets a monotonically-increasing ID. This is a |
409 | // multimap because it's possible for the same edge to appear multiple times |
410 | // in the graph (e.g. x^2 may be represented as mul(x, x)). |
411 | int64 next_edge_id_ = 1; |
412 | std::unordered_multimap< |
413 | std::pair<const HloInstruction*, const HloInstruction*>, int64, |
414 | tensorflow::hash<std::pair<const HloInstruction*, const HloInstruction*>>> |
415 | edge_ids_; |
416 | |
417 | // Each HloComputation that's emitted gets a monotonically-increasing ID. |
418 | int64 next_cluster_id_ = 1; |
419 | std::unordered_map<const HloComputation*, int64> cluster_ids_; |
420 | |
421 | // Edges to print from Footer(). Edges come at the end because graphviz is |
422 | // unhappy if an edge from a subcomputation to a node in the outer computation |
423 | // appears before both the inner computation and the destination node are |
424 | // defined. |
425 | std::vector<string> edges_; |
426 | |
427 | // When coloring by sharding information, we track the sharding string |
428 | // representation to color association, by round-robin the color schemes. |
429 | std::unordered_map<string, ColorScheme> sharding_colors_; |
430 | int64 next_shard_color_ = 0; |
431 | }; |
432 | |
433 | string HloDotDumper::Dump() { |
434 | string body; |
435 | StrAppend(&body, DumpComputation(computation_)); |
436 | StrAppend(&body, DumpRootTag()); |
437 | |
438 | // By contract, Header() and Footer() have to be called after we've dumped all |
439 | // our instructions, because they use state generated during that process. |
440 | string g = Header(); |
441 | StrAppend(&g, body); |
442 | StrAppend(&g, Footer()); |
443 | return g; |
444 | } |
445 | |
446 | string HloDotDumper::() { |
447 | const char* fmt = R"(digraph G { |
448 | rankdir = TB; |
449 | compound = true; |
450 | label = <<b>%s</b>>; |
451 | labelloc = t; |
452 | // Disable the tooltip. Interestingly, "" doesn't work! |
453 | tooltip = " "; |
454 | // DOT graphs accept a stylesheet as a URI. So naturally, an inline |
455 | // stylesheet is a data URI! |
456 | stylesheet=" |
457 | data:text/css, |
458 | @import url(https://fonts.googleapis.com/css?family=Roboto:400,700); |
459 | svg text { |
460 | font-family: 'Roboto'; |
461 | font-size: 12px; |
462 | } |
463 | |
464 | %s |
465 | " |
466 | |
467 | )" ; |
468 | |
469 | VLOG(3) << "Generating Header" ; |
470 | |
471 | string graph_label = |
472 | StrCat(label_, "<br/>Computation " , computation_->name()); |
473 | if (computation_->IsFusionComputation()) { |
474 | StrAppend(&graph_label, |
475 | StrCat(" (in fusion instruction " , |
476 | computation_->FusionInstruction()->name(), ")" )); |
477 | } |
478 | if (profile_ != nullptr) { |
479 | auto cycles = profile_->total_cycles_executed(*computation_); |
480 | Appendf(&graph_label, "<br/>total cycles = %lld (%s)" , cycles, |
481 | tensorflow::strings::HumanReadableNum(cycles)); |
482 | } |
483 | |
484 | // Create CSS rules that say, when you hover over the given node or cluster, |
485 | // turn the given edge the given color. |
486 | // |
487 | // We rely on a few properties of how graphviz generates SVGs: |
488 | // |
489 | // - Nodes are named "nodeN", where N corresponds to the 1-based index of |
490 | // the node in our DOT (i.e. the first node in the DOT is "node1", etc.). |
491 | // Edges are similarly named "edgeN", and clusters are named "clustN". |
492 | // - Nodes come before their in- and out-edges in the SVG. We need this |
493 | // because the "X ~ Y" CSS selector finds a sibling of X that *comes |
494 | // after X in the DOM* and matches Y. |
495 | std::vector<string> edge_css_rules; |
496 | const char* kBlue = "#1976d2" ; |
497 | const char* kRed = "#d32f2f" ; |
498 | for (const auto& kv : edge_ids_) { |
499 | const HloInstruction* from_node = kv.first.first; |
500 | const HloInstruction* to_node = kv.first.second; |
501 | int64 edge_id = kv.second; |
502 | |
503 | auto add_hover_css_rule = [&](string elem_type, int64 elem_id, |
504 | const char* color) { |
505 | // One could imagine other ways of writing this CSS rule that involve |
506 | // less duplication, but this way seems to be relatively performant. |
507 | edge_css_rules.push_back( |
508 | Printf(" #%s%d:hover ~ #edge%lld text { fill: %s; }\n" |
509 | " #%s%d:hover ~ #edge%lld path { " |
510 | "stroke: %s; stroke-width: .2em; }\n" |
511 | " #%s%d:hover ~ #edge%lld polygon { " |
512 | "fill: %s; stroke: %s; stroke-width: .2em; }\n" , |
513 | elem_type, elem_id, edge_id, color, // |
514 | elem_type, elem_id, edge_id, color, // |
515 | elem_type, elem_id, edge_id, color, color)); |
516 | }; |
517 | |
518 | // The "to_node" value may be a NULL, indicating that this points to the |
519 | // "root" tag rather than a normal node. |
520 | int64 from_node_id = |
521 | tensorflow::gtl::FindWithDefault(node_ids_, from_node, -1); |
522 | if (from_node_id == -1) { |
523 | LOG(FATAL) << from_node->name() << " was added to edges but not to nodes" ; |
524 | } |
525 | int64 to_node_id = |
526 | to_node ? tensorflow::gtl::FindWithDefault(node_ids_, to_node, -1) |
527 | : root_node_id_; |
528 | if (to_node != nullptr && to_node_id == -1) { |
529 | LOG(FATAL) << to_node->name() << " was added to edges but not to nodes" ; |
530 | } |
531 | |
532 | add_hover_css_rule("node" , from_node_id, kBlue); |
533 | add_hover_css_rule("node" , to_node_id, kRed); |
534 | |
535 | if (to_node) { |
536 | VLOG(3) << "Adding css for edge " << edge_id << " from node " |
537 | << from_node->name() << " to node " << to_node->name(); |
538 | } else { |
539 | VLOG(3) << "Adding css for edge " << edge_id << " from node " |
540 | << from_node->name() << " to root tag" ; |
541 | } |
542 | |
543 | // If this edge crosses a fusion cluster boundary, highlight it when the |
544 | // cluster is hovered over. |
545 | if (to_node) { |
546 | if (from_node->IsFused() && |
547 | from_node->parent()->root_instruction() == from_node) { |
548 | int64 cluster_id = cluster_ids_.at(from_node->parent()); |
549 | add_hover_css_rule("clust" , cluster_id, kBlue); |
550 | } |
551 | if (to_node->IsFused() && to_node->opcode() == HloOpcode::kParameter) { |
552 | int64 cluster_id = cluster_ids_.at(to_node->parent()); |
553 | add_hover_css_rule("clust" , cluster_id, kRed); |
554 | } |
555 | } |
556 | } |
557 | |
558 | return Printf(fmt, graph_label, Join(edge_css_rules, "\n" )); |
559 | } |
560 | |
561 | string HloDotDumper::() { return StrCat(Join(edges_, "\n" ), "\n}" ); } |
562 | |
563 | bool HloDotDumper::ShouldShowFusionSubcomputation(const HloInstruction* instr) { |
564 | CHECK_EQ(instr->opcode(), HloOpcode::kFusion); |
565 | return ShouldShowSubcomputation(instr->fused_instructions_computation()); |
566 | } |
567 | |
568 | bool HloDotDumper::ShouldShowSubcomputation(const HloComputation* subcomp) { |
569 | if (subcomp->IsFusionComputation()) { |
570 | const HloInstruction* fusion = subcomp->FusionInstruction(); |
571 | if (!filter_.Show(fusion) || filter_.SomeOrAllOperandsOmitted(fusion)) { |
572 | return false; |
573 | } |
574 | } |
575 | |
576 | // Don't show trivial subcomputations on non-fusion nodes -- these are inlined |
577 | // into the graph. |
578 | if (!subcomp->IsFusionComputation() && MatchTrivialComputation(subcomp)) { |
579 | return false; |
580 | } |
581 | |
582 | // Show the subcomputation if we're showing any of its members. |
583 | return std::any_of( |
584 | computation_->instructions().begin(), computation_->instructions().end(), |
585 | [&](const HloInstruction* instr) { return filter_.Show(instr); }); |
586 | } |
587 | |
588 | string HloDotDumper::DumpSubcomputation(const HloComputation* subcomp, |
589 | const HloInstruction* parent_instr) { |
590 | VLOG(2) << "Dumping subcomputation " << subcomp->name(); |
591 | const char* computation_fmt = R"(subgraph %s { |
592 | %s |
593 | label = <%s>; |
594 | labelloc = t; |
595 | tooltip = " "; |
596 | %s |
597 | } // %s |
598 | |
599 | )" ; |
600 | |
601 | cluster_ids_[subcomp] = next_cluster_id_++; |
602 | |
603 | string id = SubcomputationId(subcomp); |
604 | |
605 | string subcomp_label, style; |
606 | if (parent_instr->opcode() == HloOpcode::kFusion) { |
607 | subcomp_label = Printf("Fused expression for <b>%s</b><br/>%s" , |
608 | HtmlLikeStringSanitize(parent_instr->name()), |
609 | HtmlLikeStringSanitize(parent_instr->ToCategory())); |
610 | string = GetInstructionNodeExtraInfo(parent_instr); |
611 | if (!extra_info.empty()) { |
612 | StrAppend(&subcomp_label, "<br/>" , extra_info); |
613 | } |
614 | |
615 | bool highlight = filter_.Highlight(parent_instr); |
616 | const char* fillcolor; |
617 | const char* strokecolor; |
618 | if (debug_options_.xla_hlo_graph_sharding_color() && !highlight) { |
619 | // Use the sharding color, if the node isn't highlighted. |
620 | NodeColors node_colors = |
621 | NodeColorsForScheme(GetInstructionColor(parent_instr)); |
622 | fillcolor = node_colors.fill_color; |
623 | strokecolor = node_colors.stroke_color; |
624 | } else { |
625 | // Subcomputation's fill/stroke color is light/dark red/gray, depending on |
626 | // whether or not the subcomputation's fusion node is highlighted. |
627 | fillcolor = highlight ? "#ffcdd2" : "#f5f5f5" ; |
628 | strokecolor = highlight ? "#b71c1c" : "#c2c2c2" ; |
629 | } |
630 | style = |
631 | Printf(R"(style="rounded,filled,bold"; fillcolor="%s"; color="%s;")" , |
632 | fillcolor, strokecolor); |
633 | } else { |
634 | subcomp_label = Printf("Subcomputation for <b>%s</b><br/>%s" , |
635 | HtmlLikeStringSanitize(parent_instr->name()), |
636 | HtmlLikeStringSanitize(subcomp->name())); |
637 | style = "style=rounded; color=black;" ; |
638 | } |
639 | |
640 | string comp_body = DumpComputation(subcomp); |
641 | |
642 | // Add an edge from the subcomputation to its parent node. If subcomp |
643 | // belongs to a fusion node, it's drawn in place of the fusion instruction, |
644 | // so there's no need to link those. |
645 | if (parent_instr->opcode() != HloOpcode::kFusion) { |
646 | const HloInstruction* from = GetNodeForEdge(subcomp->root_instruction()); |
647 | VLOG(2) << "Edge: from " << from->name() << " to " << parent_instr->name() |
648 | << " as " << next_edge_id_; |
649 | edge_ids_.insert({{from, parent_instr}, next_edge_id_++}); |
650 | const char* edge_fmt = |
651 | R"(%s -> %s [ltail="%s", style="dashed" tooltip="%s -> %s"];)" ; |
652 | edges_.push_back(Printf( |
653 | edge_fmt, InstructionId(from), InstructionId(parent_instr), |
654 | SubcomputationId(subcomp), subcomp->name(), parent_instr->name())); |
655 | } |
656 | |
657 | string computation = |
658 | Printf(computation_fmt, id, style, subcomp_label, comp_body, id); |
659 | |
660 | return computation; |
661 | } |
662 | |
663 | string HloDotDumper::DumpComputation(const HloComputation* comp) { |
664 | string g; |
665 | for (const auto* instr : comp->instructions()) { |
666 | if (!filter_.Show(instr)) { |
667 | continue; |
668 | } |
669 | |
670 | // Dump subcomputations within instr. |
671 | for (const HloComputation* subcomp : instr->called_computations()) { |
672 | if (ShouldShowSubcomputation(subcomp)) { |
673 | StrAppend(&g, DumpSubcomputation(subcomp, instr)); |
674 | } |
675 | } |
676 | |
677 | StrAppend(&g, DumpInstruction(instr)); |
678 | } |
679 | return g; |
680 | } |
681 | |
682 | string HloDotDumper::DumpRootTag() { |
683 | const HloInstruction* from = GetNodeForEdge(computation_->root_instruction()); |
684 | |
685 | // We didn't display constants as separate nodes; so if the root is a |
686 | // constant, we don't add root tag or edge for it. |
687 | if (!filter_.Show(from) || from->opcode() == HloOpcode::kConstant) { |
688 | return "" ; |
689 | } |
690 | |
691 | auto from_id = InstructionId(from); |
692 | |
693 | // The ID of the root computation is otherwise unused, so it makes a good ID |
694 | // to use for the root-tag node. However, the edge_ids_ map requires a |
695 | // HloInstruction* pointer for the 'to' value, so we use a NULL value there |
696 | // (rather than a pointer type-cast) to make it obvious if it is erroneously |
697 | // dereferenced. |
698 | HloInstruction* to = nullptr; |
699 | auto to_id = SubcomputationId(computation_); |
700 | |
701 | string node_body = "ROOT" ; |
702 | string node_shape = "circle" ; |
703 | ColorScheme color = kBrown; |
704 | |
705 | VLOG(2) << "Adding root tag as node " << next_node_id_; |
706 | root_node_id_ = next_node_id_++; |
707 | |
708 | VLOG(2) << "Adding edge from " << from->name() << " to root tag as " |
709 | << next_edge_id_; |
710 | edge_ids_.insert({{from, to}, next_edge_id_++}); |
711 | edges_.push_back(Printf(R"(%s -> %s [tooltip=" "];)" , from_id, to_id)); |
712 | |
713 | return Printf(R"(%s [label=<%s>, shape=%s, tooltip=" ", %s];)" |
714 | "\n" , |
715 | to_id, node_body, node_shape, NodeColorAttributes(color)); |
716 | } |
717 | |
718 | bool HloDotDumper::ShouldMergeIntoUsers(const HloInstruction* instr) const { |
719 | // If a node: |
720 | // |
721 | // - is a tuple-shaped parameter, |
722 | // - is not a parameter to a fusion node, |
723 | // - has at least kMinUsersToOmit users shown, and |
724 | // - all of the shown users are get-tuple-elements, |
725 | // |
726 | // then we omit it from the graph, merging it with its users. |
727 | // |
728 | // This helps us handle the common case where a while loop body has one big |
729 | // tuple-shaped parameter. |
730 | const int kMinUsersToOmit = 3; |
731 | return instr->opcode() == HloOpcode::kParameter && |
732 | ShapeUtil::IsTuple(instr->shape()) && !instr->IsFused() && |
733 | std::count_if(instr->users().begin(), instr->users().end(), |
734 | [&](const HloInstruction* user) { |
735 | return filter_.Show(user); |
736 | }) > kMinUsersToOmit && |
737 | std::all_of(instr->users().begin(), instr->users().end(), |
738 | [&](const HloInstruction* user) { |
739 | return !filter_.Show(user) || |
740 | user->opcode() == HloOpcode::kGetTupleElement; |
741 | }); |
742 | } |
743 | |
744 | string HloDotDumper::DumpInstruction(const HloInstruction* instr) { |
745 | // We don't display constants as separate nodes; they're merged into their |
746 | // users. |
747 | if (instr->opcode() == HloOpcode::kConstant) { |
748 | return "" ; |
749 | } |
750 | // Skip this node if it's merged into its users. |
751 | if (ShouldMergeIntoUsers(instr)) { |
752 | return "" ; |
753 | } |
754 | // Omit the fusion node if its subcomputation is drawn, since the |
755 | // subcomputation will be drawn inline. |
756 | if (instr->opcode() == HloOpcode::kFusion && |
757 | ShouldShowFusionSubcomputation(instr)) { |
758 | return "" ; |
759 | } |
760 | |
761 | VLOG(2) << "Adding node " << instr->name() << " as " << next_node_id_; |
762 | node_ids_[instr] = next_node_id_++; |
763 | |
764 | ColorScheme color = GetInstructionColor(instr); |
765 | string node_shape = GetInstructionNodeShape(instr); |
766 | string node_label = GetInstructionNodeLabel(instr); |
767 | string node_metadata = GetInstructionNodeMetadata(instr); |
768 | string = GetInstructionNodeExtraInfo(instr); |
769 | string inlined_constants = GetInstructionNodeInlinedOperands(instr); |
770 | string trivial_subcomputation = GetInstructionTrivialComputationStr(instr); |
771 | AddInstructionIncomingEdges(instr); |
772 | |
773 | if (!debug_options_.xla_hlo_graph_sharding_color()) { |
774 | // Override the node's styling if it should be (de-)emphasized. |
775 | if (filter_.Deemphasized(instr)) { |
776 | color = kDashedBorder; |
777 | } |
778 | if (filter_.Highlight(instr)) { |
779 | node_shape = "diamond" ; |
780 | color = kDarkRed; |
781 | } |
782 | } |
783 | // Build the text that will be displayed inside the node. |
784 | string node_body = node_label; |
785 | for (const string& s : |
786 | {trivial_subcomputation, node_metadata, extra_info, inlined_constants}) { |
787 | if (!s.empty()) { |
788 | StrAppend(&node_body, "<br/>" , s); |
789 | } |
790 | } |
791 | |
792 | return Printf(R"(%s [label=<%s>, shape=%s, tooltip=" ", %s];)" |
793 | "\n" , |
794 | InstructionId(instr), node_body, node_shape, |
795 | NodeColorAttributes(color)); |
796 | } |
797 | |
798 | string HloDotDumper::GetInstructionNodeInlinedOperands( |
799 | const HloInstruction* instr) { |
800 | auto stringify_constant = [](const HloInstruction* constant) { |
801 | const auto& shape = constant->shape(); |
802 | |
803 | // If the shape has a dimension of size zero, print it as e.g. |
804 | // "{} (f32[42, 0, 10])". The alternative, calling Literal::ToString(), |
805 | // enumerates all of its empty dimensions (e.g. "{ { {}, {} }, ..."), which |
806 | // is just noise. |
807 | if (ShapeUtil::HasZeroElements(shape)) { |
808 | return Printf("{} (%s)" , ShapeUtil::HumanString(constant->shape())); |
809 | } |
810 | |
811 | // Print the literal value of constants with <= K elements. |
812 | optional<int64> elem_count; |
813 | if (!ShapeUtil::IsOpaque(shape) && !ShapeUtil::IsTuple(shape)) { |
814 | elem_count = 1; |
815 | for (int64 dim : shape.dimensions()) { |
816 | *elem_count *= dim; |
817 | } |
818 | } |
819 | if (elem_count.has_value() && *elem_count <= 8) { |
820 | return Printf("%s (%s)" , constant->literal().ToString(), |
821 | ShapeUtil::HumanString(constant->shape())); |
822 | } |
823 | |
824 | // Otherwise, print e.g. "%constant.42 (s32[100])". |
825 | string constant_name; |
826 | if (tensorflow::str_util::StartsWith(constant->name(), "constant" )) { |
827 | constant_name = constant->name(); |
828 | } else { |
829 | constant_name = StrCat("constant " , constant->name()); |
830 | } |
831 | return Printf("%s %s" , constant_name, |
832 | ShapeUtil::HumanString(constant->shape())); |
833 | }; |
834 | |
835 | // Special case: If instr is a parameter to a fusion node, check whether the |
836 | // corresponding operand to the fusion node is a constant. |
837 | if (instr->opcode() == HloOpcode::kParameter && instr->IsFused()) { |
838 | const HloInstruction* fusion = instr->parent()->FusionInstruction(); |
839 | const HloInstruction* operand = fusion->operand(instr->parameter_number()); |
840 | if (operand->opcode() != HloOpcode::kConstant) { |
841 | return "" ; |
842 | } |
843 | return StrCat("<b>constant</b> " , stringify_constant(operand)); |
844 | } |
845 | |
846 | std::vector<string> lines; |
847 | for (int64 i = 0; i < instr->operand_count(); ++i) { |
848 | const HloInstruction* operand = instr->operand(i); |
849 | optional<string> operand_str; |
850 | if (operand->opcode() == HloOpcode::kConstant) { |
851 | operand_str = stringify_constant(operand); |
852 | } else if (ShouldMergeIntoUsers(operand)) { |
853 | // Special case: If the operand is a parameter, use its parameter number |
854 | // rather than its name, because that's generally how people think of the |
855 | // node. |
856 | if (operand->opcode() == HloOpcode::kParameter) { |
857 | operand_str = Printf("Parameter %lld" , operand->parameter_number()); |
858 | } else { |
859 | operand_str = operand->name(); |
860 | } |
861 | } |
862 | |
863 | if (operand_str) { |
864 | if (instr->operand_count() > 1) { |
865 | lines.push_back(Printf("<b>operand %lld</b> = %s" , i, *operand_str)); |
866 | } else { |
867 | lines.push_back(Printf("<b>operand</b> = %s" , *operand_str)); |
868 | } |
869 | } |
870 | } |
871 | return Join(lines, "<br/>" ); |
872 | } |
873 | |
874 | ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) { |
875 | if (debug_options_.xla_hlo_graph_sharding_color()) { |
876 | if (!instr->has_sharding()) { |
877 | return kDashedBorder; |
878 | } |
879 | string shard_str = instr->sharding().ToString(); |
880 | auto it = sharding_colors_.find(shard_str); |
881 | if (it != sharding_colors_.end()) { |
882 | return it->second; |
883 | } |
884 | ColorScheme color = static_cast<ColorScheme>( |
885 | kBlue + (next_shard_color_++ % (kDashedBorder - kBlue))); |
886 | sharding_colors_.emplace(shard_str, color); |
887 | return color; |
888 | } |
889 | const auto kParameterColor = kOrange; |
890 | |
891 | // Special case: If this instruction has a parameter merged into it, paint it |
892 | // the same color as a parameter. |
893 | if (std::any_of(instr->operands().begin(), instr->operands().end(), |
894 | [&](const HloInstruction* operand) { |
895 | return operand->opcode() == HloOpcode::kParameter && |
896 | ShouldMergeIntoUsers(operand); |
897 | })) { |
898 | return kParameterColor; |
899 | } |
900 | |
901 | // Pick different colors or shapes for instructions which are particularly |
902 | // expensive (eg, dot) and those which are unusual in some way or unique |
903 | // (eg, parameter). |
904 | switch (instr->opcode()) { |
905 | case HloOpcode::kAbs: |
906 | case HloOpcode::kAdd: |
907 | case HloOpcode::kAnd: |
908 | case HloOpcode::kAtan2: |
909 | case HloOpcode::kBitcastConvert: |
910 | case HloOpcode::kCeil: |
911 | case HloOpcode::kClamp: |
912 | case HloOpcode::kComplex: |
913 | case HloOpcode::kConvert: |
914 | case HloOpcode::kCos: |
915 | case HloOpcode::kDivide: |
916 | case HloOpcode::kEq: |
917 | case HloOpcode::kExp: |
918 | case HloOpcode::kFloor: |
919 | case HloOpcode::kGe: |
920 | case HloOpcode::kGt: |
921 | case HloOpcode::kImag: |
922 | case HloOpcode::kIsFinite: |
923 | case HloOpcode::kLe: |
924 | case HloOpcode::kLog: |
925 | case HloOpcode::kLt: |
926 | case HloOpcode::kMaximum: |
927 | case HloOpcode::kMinimum: |
928 | case HloOpcode::kMultiply: |
929 | case HloOpcode::kNe: |
930 | case HloOpcode::kNegate: |
931 | case HloOpcode::kNot: |
932 | case HloOpcode::kOr: |
933 | case HloOpcode::kPower: |
934 | case HloOpcode::kReal: |
935 | case HloOpcode::kRemainder: |
936 | case HloOpcode::kRng: |
937 | case HloOpcode::kRoundNearestAfz: |
938 | case HloOpcode::kShiftLeft: |
939 | case HloOpcode::kShiftRightArithmetic: |
940 | case HloOpcode::kShiftRightLogical: |
941 | case HloOpcode::kSign: |
942 | case HloOpcode::kSin: |
943 | case HloOpcode::kSlice: |
944 | case HloOpcode::kSort: |
945 | case HloOpcode::kSubtract: |
946 | case HloOpcode::kTanh: |
947 | // De-emphasize scalar-shaped elementwise ops -- they're generally |
948 | // uninteresting. |
949 | if (ShapeUtil::IsEffectiveScalar(instr->shape())) { |
950 | return kWhite; |
951 | } |
952 | return kYellow; |
953 | case HloOpcode::kBitcast: |
954 | case HloOpcode::kGetTupleElement: |
955 | case HloOpcode::kTrace: |
956 | case HloOpcode::kTuple: |
957 | return kWhite; |
958 | case HloOpcode::kBroadcast: |
959 | case HloOpcode::kBroadcastDimOne: |
960 | // De-emphasize nodes which broadcast a scalar within a fusion node -- |
961 | // these are essentially free. |
962 | if (instr->IsFused() && |
963 | ShapeUtil::IsEffectiveScalar(instr->operand(0)->shape())) { |
964 | return kWhite; |
965 | } |
966 | return kGreen; |
967 | case HloOpcode::kConcatenate: |
968 | case HloOpcode::kCopy: |
969 | case HloOpcode::kDynamicSlice: |
970 | case HloOpcode::kGather: |
971 | case HloOpcode::kPad: |
972 | case HloOpcode::kReshape: |
973 | case HloOpcode::kReverse: |
974 | case HloOpcode::kSelect: |
975 | case HloOpcode::kTranspose: |
976 | // De-emphasize scalar-shaped data movement ops and all data movement ops |
977 | // inside fusion nodes, both of which are essentially free. |
978 | if (ShapeUtil::IsEffectiveScalar(instr->shape()) || instr->IsFused()) { |
979 | return kWhite; |
980 | } |
981 | return kGreen; |
982 | case HloOpcode::kDynamicUpdateSlice: |
983 | // Unlike the data-movement ops above, dynamic-update-slice is not ~free |
984 | // inside of fusion nodes, so we de-emphasize it only if it's |
985 | // scalar-shaped. |
986 | if (ShapeUtil::IsEffectiveScalar(instr->shape())) { |
987 | return kWhite; |
988 | } |
989 | return kGreen; |
990 | case HloOpcode::kConvolution: |
991 | case HloOpcode::kDot: |
992 | case HloOpcode::kFft: |
993 | return kDarkBlue; |
994 | case HloOpcode::kReducePrecision: |
995 | return kRed; |
996 | case HloOpcode::kParameter: |
997 | return kParameterColor; |
998 | case HloOpcode::kBatchNormGrad: |
999 | case HloOpcode::kBatchNormInference: |
1000 | case HloOpcode::kBatchNormTraining: |
1001 | case HloOpcode::kReduce: |
1002 | case HloOpcode::kReduceWindow: |
1003 | case HloOpcode::kSelectAndScatter: |
1004 | return kPurple; |
1005 | case HloOpcode::kFusion: |
1006 | case HloOpcode::kMap: |
1007 | return kGray; |
1008 | case HloOpcode::kCrossReplicaSum: |
1009 | case HloOpcode::kInfeed: |
1010 | case HloOpcode::kOutfeed: |
1011 | case HloOpcode::kRecv: |
1012 | case HloOpcode::kRecvDone: |
1013 | case HloOpcode::kSend: |
1014 | case HloOpcode::kSendDone: |
1015 | return kBrown; |
1016 | case HloOpcode::kCall: |
1017 | case HloOpcode::kConditional: |
1018 | case HloOpcode::kCustomCall: |
1019 | case HloOpcode::kHostCompute: |
1020 | case HloOpcode::kWhile: |
1021 | return kDarkGreen; |
1022 | case HloOpcode::kConstant: |
1023 | LOG(FATAL) << "Constants don't get their own nodes in the graph." ; |
1024 | } |
1025 | } |
1026 | |
1027 | string HloDotDumper::GetInstructionNodeShape(const HloInstruction* instr) { |
1028 | // Give while loops a different shape so they're easier to pick out. |
1029 | switch (instr->opcode()) { |
1030 | case HloOpcode::kWhile: |
1031 | return "ellipse" ; |
1032 | default: |
1033 | return "rect" ; |
1034 | } |
1035 | } |
1036 | |
1037 | string HloDotDumper::GetInstructionNodeLabel(const HloInstruction* instr) { |
1038 | // If we have a parameter, put the param number in the name. |
1039 | if (instr->opcode() == HloOpcode::kParameter) { |
1040 | return Printf("<b>Parameter %lld</b>" , instr->parameter_number()); |
1041 | } |
1042 | |
1043 | // The HLO instruction name contains usually the opcode, e.g. "%add.42" is |
1044 | // an add instruction. In this case we render just the name. |
1045 | if (tensorflow::str_util::StartsWith(instr->name(), |
1046 | HloOpcodeString(instr->opcode()))) { |
1047 | return Printf("<b>%s</b>" , HtmlLikeStringSanitize(instr->name())); |
1048 | } |
1049 | string extended_opcode = |
1050 | StrCat(HloOpcodeString(instr->opcode()), |
1051 | instr->opcode() != HloOpcode::kFusion |
1052 | ? "" |
1053 | : StrCat(":" , xla::ToString(instr->fusion_kind()))); |
1054 | // If the name does not contain the opcode, render both. |
1055 | return Printf("<b>%s</b><br/>%s" , HtmlLikeStringSanitize(extended_opcode), |
1056 | HtmlLikeStringSanitize(instr->name())); |
1057 | } |
1058 | |
1059 | string HloDotDumper::GetInstructionNodeMetadata(const HloInstruction* instr) { |
1060 | if (!show_metadata_) { |
1061 | return "" ; |
1062 | } |
1063 | |
1064 | std::vector<string> lines; |
1065 | if (!instr->metadata().op_name().empty()) { |
1066 | lines.push_back(HtmlLikeStringSanitize(instr->metadata().op_name())); |
1067 | } |
1068 | if (!instr->metadata().op_type().empty()) { |
1069 | lines.push_back(Printf( |
1070 | "op_type: %s" , HtmlLikeStringSanitize(instr->metadata().op_type()))); |
1071 | } |
1072 | if (!instr->metadata().source_file().empty() && |
1073 | instr->metadata().source_line() != 0) { |
1074 | lines.push_back(Printf("op_type: %s" , instr->metadata().source_file(), |
1075 | instr->metadata().source_line())); |
1076 | } |
1077 | |
1078 | return Join(lines, "<br/>" ); |
1079 | } |
1080 | |
1081 | string HloDotDumper::(const HloInstruction* instr) { |
1082 | std::vector<string> lines; |
1083 | |
1084 | // Get the instruction's extra attributes excluding the names of its |
1085 | // subcomputations, since those are drawn explicitly in the graph. |
1086 | for (const auto& line : instr->ExtraAttributesToString( |
1087 | HloPrintOptions().set_print_subcomputation_references(false))) { |
1088 | lines.push_back(HtmlLikeStringSanitize(line)); |
1089 | } |
1090 | |
1091 | // Show the shape and layout of the instruction, unless it's an inlined fusion |
1092 | // node -- there the shape and layout is present in the output node. |
1093 | if (instr->opcode() != HloOpcode::kFusion || |
1094 | !ShouldShowFusionSubcomputation(instr)) { |
1095 | // Show layout of instructions with more than one dimension. Don't show |
1096 | // layout on tuples or tensors with just one dimension (which only have one |
1097 | // possible layout) to avoid visual noise. |
1098 | bool shape_is_multidim = false; |
1099 | ShapeUtil::ForEachSubshape(instr->shape(), |
1100 | [&](const Shape& s, const ShapeIndex&) { |
1101 | shape_is_multidim |= s.dimensions_size() > 1; |
1102 | }); |
1103 | string instr_shape; |
1104 | if (instr->opcode() != HloOpcode::kTuple && shape_is_multidim) { |
1105 | instr_shape = ShapeUtil::HumanStringWithLayout(instr->shape()); |
1106 | } else { |
1107 | instr_shape = ShapeUtil::HumanString(instr->shape()); |
1108 | } |
1109 | |
1110 | // Some instructions have giant tuples as their shapes, so truncate the |
1111 | // HLO's shape to kMaxShapeLen characters. |
1112 | constexpr int kMaxShapeLen = 64; |
1113 | if (instr_shape.length() > kMaxShapeLen) { |
1114 | instr_shape = StrCat( |
1115 | tensorflow::StringPiece(instr_shape).substr(0, kMaxShapeLen - 3), |
1116 | "..." ); |
1117 | } |
1118 | lines.push_back(instr_shape); |
1119 | } |
1120 | if (debug_options_.xla_hlo_graph_addresses()) { |
1121 | lines.push_back(Printf("[%p]" , instr)); |
1122 | } |
1123 | if (profile_ != nullptr) { |
1124 | double hlo_cycles_executed = profile_->GetCyclesTakenBy(*instr); |
1125 | double total_cycles_executed = |
1126 | profile_->total_cycles_executed(*instr->parent()); |
1127 | if (hlo_cycles_executed > 0 && total_cycles_executed > 0) { |
1128 | lines.push_back( |
1129 | Printf("%% of cycles executed=%.2f" , |
1130 | 100 * hlo_cycles_executed / total_cycles_executed)); |
1131 | } |
1132 | } |
1133 | return Join(lines, "<br/>" ); |
1134 | } |
1135 | |
1136 | void HloDotDumper::AddInstructionIncomingEdges(const HloInstruction* instr) { |
1137 | auto add_edge = [&](const HloInstruction* from, const HloInstruction* to, |
1138 | int64 operand_num, bool control_edge = false) { |
1139 | from = GetNodeForEdge(from); |
1140 | |
1141 | if (!filter_.Show(from) || from->opcode() == HloOpcode::kConstant || |
1142 | ShouldMergeIntoUsers(from)) { |
1143 | return; |
1144 | } |
1145 | VLOG(2) << "Adding edge from " << from->name() << " to " << to->name() |
1146 | << " as " << next_edge_id_; |
1147 | edge_ids_.insert({{from, to}, next_edge_id_++}); |
1148 | |
1149 | string edge_label; |
1150 | if (instr->operand_count() > 1 && !control_edge) { |
1151 | edge_label = Printf(R"( headlabel="%lld", labeldistance=2)" , operand_num); |
1152 | } else if (control_edge) { |
1153 | edge_label = "style=\"dotted\" color=\"gray\" label=\"ctrl\"" ; |
1154 | } |
1155 | const char* kEdgeFmt = R"(%s -> %s [tooltip="%s -> %s" %s];)" ; |
1156 | edges_.push_back(Printf(kEdgeFmt, InstructionId(from), InstructionId(to), |
1157 | from->name(), to->name(), edge_label)); |
1158 | }; |
1159 | |
1160 | // Add edges from instr's operands to instr. Parameters within fusion |
1161 | // expressions are handled specially -- we draw an edge from the corresponding |
1162 | // operand on the fusion node itself to the parameter. |
1163 | if (instr->opcode() == HloOpcode::kParameter && instr->IsFused()) { |
1164 | // Only add the edge if this is not the outermost computation; otherwise it |
1165 | // will lead from a node we're not drawing. |
1166 | if (instr->parent() != computation_) { |
1167 | const HloInstruction* fusion = instr->parent()->FusionInstruction(); |
1168 | add_edge(fusion->operand(instr->parameter_number()), instr, |
1169 | /*operand_num=*/0); |
1170 | } |
1171 | } else { |
1172 | for (int64 i = 0; i < instr->operand_count(); ++i) { |
1173 | add_edge(instr->operand(i), instr, i); |
1174 | } |
1175 | for (const HloInstruction* pred : instr->control_predecessors()) { |
1176 | add_edge(pred, instr, /*operand_num=*/0, /*control_edge=*/true); |
1177 | } |
1178 | } |
1179 | } |
1180 | |
1181 | string HloDotDumper::GetInstructionTrivialComputationStr( |
1182 | const HloInstruction* instr) { |
1183 | // called_computations() on a fusion node "inherits" any called computations |
1184 | // of the fused root, which isn't what we want. Just ignore fusion nodes |
1185 | // here; they're handled separately. |
1186 | if (instr->opcode() == HloOpcode::kFusion) { |
1187 | return "" ; |
1188 | } |
1189 | |
1190 | std::vector<string> lines; |
1191 | for (int64 i = 0; i < instr->called_computations().size(); ++i) { |
1192 | optional<string> computation_type = |
1193 | MatchTrivialComputation(instr->called_computations()[i]); |
1194 | if (!computation_type) { |
1195 | continue; |
1196 | } |
1197 | if (instr->called_computations().size() == 1) { |
1198 | lines.push_back(Printf("Subcomputation: <b>%s</b>" , |
1199 | HtmlLikeStringSanitize(*computation_type))); |
1200 | } else { |
1201 | lines.push_back(Printf("Subcomputation %lld: <b>%s</b>" , i, |
1202 | HtmlLikeStringSanitize(*computation_type))); |
1203 | } |
1204 | } |
1205 | return Join(lines, "<br/>" ); |
1206 | } |
1207 | |
1208 | const HloInstruction* HloDotDumper::GetNodeForEdge( |
1209 | const HloInstruction* instr) { |
1210 | while (instr->opcode() == HloOpcode::kFusion && |
1211 | ShouldShowFusionSubcomputation(instr)) { |
1212 | instr = instr->fused_expression_root(); |
1213 | } |
1214 | return instr; |
1215 | } |
1216 | |
1217 | class GraphRendererRegistry { |
1218 | public: |
1219 | void AddRenderer(GraphRendererInterface* graph_renderer) { |
1220 | tensorflow::mutex_lock lock(mu_); |
1221 | graph_renderer_ = graph_renderer; |
1222 | } |
1223 | |
1224 | GraphRendererInterface* GetDefaultRenderer() { |
1225 | tensorflow::mutex_lock lock(mu_); |
1226 | return graph_renderer_; |
1227 | } |
1228 | |
1229 | static GraphRendererRegistry* Default() { |
1230 | static GraphRendererRegistry* registry = new GraphRendererRegistry(); |
1231 | return registry; |
1232 | } |
1233 | |
1234 | private: |
1235 | tensorflow::mutex mu_; |
1236 | GraphRendererInterface* graph_renderer_ = nullptr; |
1237 | }; |
1238 | |
1239 | } // namespace |
1240 | |
1241 | Registrar::Registrar(GraphRendererInterface* dumper) { |
1242 | GraphRendererRegistry::Default()->AddRenderer(dumper); |
1243 | } |
1244 | |
1245 | namespace { |
1246 | |
1247 | // Gets a NodeFilter that includes roughly all instructions whose distance from |
1248 | // root is <= radius. |
1249 | NodeFilter MakeNodeFilter(const HloInstruction* root, int64 radius) { |
1250 | // First, find the neighborhood of nodes with distance from root <= radius. |
1251 | // These nodes are our initial set of "normal" nodes. |
1252 | std::unordered_map<const HloInstruction*, NodeFilterResult> nodes; |
1253 | std::deque<std::pair<const HloInstruction*, /*depth*/ int64>> worklist; |
1254 | worklist.push_back({root, 0}); |
1255 | while (!worklist.empty()) { |
1256 | const HloInstruction* instr; |
1257 | int64 depth; |
1258 | std::tie(instr, depth) = worklist.front(); |
1259 | worklist.pop_front(); |
1260 | |
1261 | nodes[instr] = kNormalNode; |
1262 | if (depth == radius) { |
1263 | continue; |
1264 | } |
1265 | |
1266 | // Traverse into instr's operands. |
1267 | // |
1268 | // Don't traverse into tuples' operands unless the tuple is the root. |
1269 | // Usually a tuple is the bottommost node in the graph, and so its operands |
1270 | // are not interesting to the graph at hand. |
1271 | if (instr == root || instr->opcode() != HloOpcode::kTuple) { |
1272 | for (const HloInstruction* operand : instr->operands()) { |
1273 | if (!nodes.count(operand)) { |
1274 | worklist.push_back({operand, depth + 1}); |
1275 | } |
1276 | } |
1277 | } |
1278 | |
1279 | // Traverse into instr's nested computations. |
1280 | for (const HloComputation* computation : instr->called_computations()) { |
1281 | worklist.push_back({computation->root_instruction(), depth + 1}); |
1282 | } |
1283 | |
1284 | // Traverse into instr's users, unless: |
1285 | // |
1286 | // - there are a ton of them, in which case they're probably not |
1287 | // interesting (and anyway, rendering them all would make the graph |
1288 | // unreadable), or |
1289 | // - instr is a constant, in which case its users are probably not |
1290 | // interesting. |
1291 | if (instr->opcode() == HloOpcode::kConstant) { |
1292 | continue; |
1293 | } |
1294 | constexpr int kMaxUsersToRender = 16; |
1295 | if (instr->user_count() > kMaxUsersToRender) { |
1296 | // If we're going to skip this node's users, style it as such. |
1297 | nodes[instr] = kSomeUsersOmitted; |
1298 | continue; |
1299 | } |
1300 | for (const HloInstruction* user : instr->users()) { |
1301 | if (!nodes.count(user)) { |
1302 | worklist.push_back({user, depth + 1}); |
1303 | } |
1304 | } |
1305 | } |
1306 | |
1307 | auto is_displayed = [&](const HloInstruction* instr) { |
1308 | // Constants are displayed inline with their users; they're never omitted. |
1309 | // Nodes in subcomputations are always shown. |
1310 | return nodes.count(instr) > 0 || instr->opcode() == HloOpcode::kConstant || |
1311 | instr->parent() != root->parent(); |
1312 | }; |
1313 | |
1314 | // Make a second pass over 'nodes' to fix up the NodeFilterResults now that we |
1315 | // know which nodes will be included in the graph. |
1316 | for (auto& kv : nodes) { |
1317 | const HloInstruction* instr = kv.first; |
1318 | NodeFilterResult& filter_result = kv.second; |
1319 | const auto& operands = instr->operands(); |
1320 | |
1321 | if (std::any_of(operands.begin(), operands.end(), is_displayed) && |
1322 | !std::all_of(operands.begin(), operands.end(), is_displayed)) { |
1323 | // Mark nodes with some operands omitted appropriately. |
1324 | filter_result = kSomeOperandsOmitted; |
1325 | } else if (!operands.empty() && |
1326 | std::none_of(operands.begin(), operands.end(), is_displayed)) { |
1327 | // Mark nodes with *all* operands omitted appropriately. |
1328 | filter_result = kOmitNodeOperands; |
1329 | } |
1330 | |
1331 | // Promote nodes with type kSomeUsersOmitted to kNormalNode if all of their |
1332 | // users made it into the graph. |
1333 | if (filter_result == kSomeUsersOmitted && |
1334 | std::all_of(instr->users().begin(), instr->users().end(), |
1335 | is_displayed)) { |
1336 | filter_result = kNormalNode; |
1337 | } |
1338 | } |
1339 | |
1340 | // Highlight the root node. |
1341 | nodes[root] = kHighlightNode; |
1342 | |
1343 | return NodeFilter([=](const HloInstruction* instr) { |
1344 | auto it = nodes.find(instr); |
1345 | if (it != nodes.end()) { |
1346 | return it->second; |
1347 | } |
1348 | // Show all nodes in subcomputations. |
1349 | if (instr->parent() != root->parent()) { |
1350 | return kNormalNode; |
1351 | } |
1352 | return kHideNode; |
1353 | }); |
1354 | } |
1355 | |
1356 | string SaveGraph(const string& graph, |
1357 | GraphRendererInterface::GraphKind graph_kind, |
1358 | const string& dest_path) { |
1359 | static std::atomic<int> output_num(0); |
1360 | string file_extension; |
1361 | switch (graph_kind) { |
1362 | case GraphRendererInterface::DOT_GRAPH: |
1363 | file_extension = ".dot" ; |
1364 | break; |
1365 | case GraphRendererInterface::TF_GRAPHDEF: |
1366 | file_extension = ".pbtxt" ; |
1367 | break; |
1368 | } |
1369 | string path = JoinPath(dest_path, StrCat("hlo_graph_" , output_num++, "." )); |
1370 | auto status = Status::OK(); |
1371 | auto env = tensorflow::Env::Default(); |
1372 | if (!env->CreateUniqueFileName(&path, file_extension)) { |
1373 | status = |
1374 | Status(tensorflow::error::Code::UNKNOWN, |
1375 | StrCat("Failed to create temporary file to dump HLO graph: " , |
1376 | strerror(errno))); |
1377 | } else { |
1378 | status = tensorflow::WriteStringToFile(env, path, graph); |
1379 | } |
1380 | if (!status.ok()) { |
1381 | LOG(WARNING) << "Saving HLO graph failed: " << status; |
1382 | } |
1383 | return path; |
1384 | } |
1385 | |
1386 | string ExportGraph(const string& graph, |
1387 | GraphRendererInterface::GraphKind graph_kind, |
1388 | const DebugOptions& debug_options) { |
1389 | string path = debug_options.xla_hlo_graph_path(); |
1390 | if (!path.empty()) { |
1391 | return SaveGraph(graph, graph_kind, path); |
1392 | } else { |
1393 | auto graph_renderer = |
1394 | GraphRendererRegistry::Default()->GetDefaultRenderer(); |
1395 | CHECK(graph_renderer != nullptr) |
1396 | << "No registered renderer for the HLO graph. " |
1397 | "Use --xla_hlo_graph_path=PATH to export to local file system" ; |
1398 | return graph_renderer->RenderGraph(graph, graph_kind, debug_options); |
1399 | } |
1400 | } |
1401 | |
1402 | } // namespace |
1403 | |
1404 | string DumpGraph(const HloComputation& computation, const string& label, |
1405 | const DebugOptions& debug_options, |
1406 | const HloExecutionProfile* hlo_execution_profile, |
1407 | bool show_metadata) { |
1408 | GraphRendererInterface::GraphKind graph_kind; |
1409 | string graph; |
1410 | if (debug_options.xla_hlo_dump_as_graphdef()) { |
1411 | HloTfGraphBuilder builder(debug_options); |
1412 | TF_CHECK_OK(builder.AddComputation(computation)); |
1413 | CHECK(tensorflow::protobuf::TextFormat::PrintToString(builder.GetGraphDef(), |
1414 | &graph)); |
1415 | graph_kind = GraphRendererInterface::TF_GRAPHDEF; |
1416 | } else { |
1417 | graph = HloDotDumper(&computation, label, debug_options, show_metadata, |
1418 | hlo_execution_profile, NodeFilter()) |
1419 | .Dump(); |
1420 | graph_kind = GraphRendererInterface::DOT_GRAPH; |
1421 | } |
1422 | |
1423 | string graph_url = ExportGraph(graph, graph_kind, debug_options); |
1424 | LOG(INFO) << "computation " << computation.name() << " [" << label |
1425 | << "]: " << graph_url; |
1426 | return graph_url; |
1427 | } |
1428 | |
1429 | string DumpNeighborhoodAround(const HloInstruction& node, int radius, |
1430 | bool show_metadata) { |
1431 | auto debug_options = node.GetModule()->config().debug_options(); |
1432 | string label = |
1433 | StrCat("Neighborhood of " , radius, " nodes around " , node.name()); |
1434 | NodeFilter filter = MakeNodeFilter(&node, radius); |
1435 | string graph = |
1436 | HloDotDumper(node.parent(), label, debug_options, show_metadata, |
1437 | /*profile=*/nullptr, filter) |
1438 | .Dump(); |
1439 | return ExportGraph(graph, GraphRendererInterface::DOT_GRAPH, debug_options); |
1440 | } |
1441 | |
1442 | void DumpText(const HloModule& module, const string& label, |
1443 | const string& directory_path, bool do_prefix) { |
1444 | Env* env = Env::Default(); |
1445 | TF_CHECK_OK(env->RecursivelyCreateDir(directory_path)); |
1446 | string prefix = StrCat(env->NowMicros()); |
1447 | string filename = |
1448 | do_prefix ? StrCat(prefix, "-" , label, ".txt" ) : StrCat(label, ".txt" ); |
1449 | string path = JoinPath(directory_path, filename); |
1450 | TF_CHECK_OK(WriteStringToFile( |
1451 | env, path, |
1452 | module.ToString(HloPrintOptions().set_print_large_constants(true)))); |
1453 | LOG(INFO) << "dumping module '" << module.name() << "' to " << path; |
1454 | } |
1455 | |
1456 | string MaybeDumpHloModule(const HloModule& module, const string& label, |
1457 | const HloExecutionProfile* profile) { |
1458 | const DebugOptions& debug_options = module.config().debug_options(); |
1459 | VLOG(2) << "MaybeDumpHloModule called on module " << module.name() |
1460 | << " with generate_hlo_graph regex \"" |
1461 | << debug_options.xla_generate_hlo_graph() << "\"" ; |
1462 | string graph_url; |
1463 | if (!debug_options.xla_generate_hlo_graph().empty() && |
1464 | RE2::PartialMatch(module.name(), |
1465 | debug_options.xla_generate_hlo_graph())) { |
1466 | graph_url = |
1467 | DumpGraph(*module.entry_computation(), label, debug_options, profile); |
1468 | } |
1469 | if (!debug_options.xla_log_hlo_text().empty() && |
1470 | RE2::PartialMatch(module.name(), debug_options.xla_log_hlo_text())) { |
1471 | LOG(INFO) << "HLO for module " << module.name(); |
1472 | LOG(INFO) << "Label: " << label; |
1473 | XLA_LOG_LINES(2, module.ToString()); |
1474 | } |
1475 | if (!debug_options.xla_generate_hlo_text_to().empty()) { |
1476 | DumpText(module, label, debug_options.xla_generate_hlo_text_to()); |
1477 | } |
1478 | return graph_url; |
1479 | } |
1480 | |
1481 | } // namespace hlo_graph_dumper |
1482 | } // namespace xla |
1483 | |