/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/hlo_graph_dumper.h"

#include <unistd.h>
#include <algorithm>
#include <atomic>
#include <deque>
#include <map>
#include <memory>
#include <string>
#include <tuple>
#include <unordered_map>
#include <vector>

#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_tfgraph_builder.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/window_util.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/gtl/optional.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/regexp.h"

using ::tensorflow::Env;
using ::tensorflow::WriteStringToFile;
using ::tensorflow::gtl::nullopt;
using ::tensorflow::gtl::optional;
using ::tensorflow::io::JoinPath;
using ::tensorflow::str_util::Join;
using ::tensorflow::str_util::StringReplace;
using ::tensorflow::strings::StrAppend;
using ::tensorflow::strings::StrCat;

namespace xla {
namespace hlo_graph_dumper {
namespace {

// Helpers for Printf and Appendf.
template <typename T>
struct PrintfConvert {
  const T& operator()(const T& t) const { return t; }
};
template <>
struct PrintfConvert<string> {
  const char* operator()(const string& s) const { return s.c_str(); }
};

// Like tensorflow::strings::Printf/Appendf, but you don't need to call c_str()
// on strings.
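// For example, Printf("%s = %lld", some_string, some_int64) works directly;
// PrintfConvert supplies the .c_str() call for string arguments.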
template <typename... Ts>
string Printf(const char* fmt, const Ts&... ts) {
  return tensorflow::strings::Printf(fmt, PrintfConvert<Ts>()(ts)...);
}
template <typename... Ts>
void Appendf(string* s, const char* fmt, const Ts&... ts) {
  tensorflow::strings::Appendf(s, fmt, PrintfConvert<Ts>()(ts)...);
}

// Used to indicate how we should treat a given HloInstruction in the graph:
// should we treat it like a normal node, hide it, and so on?
enum NodeFilterResult {
  kNormalNode,
  kHideNode,
  // Make the node easy to find in the final graph.
  kHighlightNode,
  // "Gray out" the node to indicate that some of its operands have been
  // omitted.
  kSomeOperandsOmitted,
  // Style the node the same as kSomeOperandsOmitted, but also don't connect it
  // to its operands, even if they're present in the graph.
  kOmitNodeOperands,
  // Same style as kSomeOperandsOmitted, but used to indicate that some of the
  // node's *users* have been omitted.
  kSomeUsersOmitted,
};

// NodeFilter is essentially a map from HloInstruction*s to NodeFilterResult.
// It lets callers tell the graph-drawing routines which nodes they want to be
// shown, hidden, or highlighted.
class NodeFilter {
 public:
  NodeFilter() : filter_([](const HloInstruction*) { return kNormalNode; }) {}

  explicit NodeFilter(
      std::function<NodeFilterResult(const HloInstruction* instr)> filter)
      : filter_(std::move(filter)) {}

  bool Show(const HloInstruction* instr) const {
    return filter_(instr) != kHideNode;
  }
  bool Highlight(const HloInstruction* instr) const {
    return filter_(instr) == kHighlightNode;
  }
  bool OmitOperands(const HloInstruction* instr) const {
    return filter_(instr) == kOmitNodeOperands;
  }
  bool SomeOrAllOperandsOmitted(const HloInstruction* instr) const {
    auto result = filter_(instr);
    return result == kOmitNodeOperands || result == kSomeOperandsOmitted;
  }
  bool Deemphasized(const HloInstruction* instr) const {
    auto result = filter_(instr);
    return result == kOmitNodeOperands || result == kSomeOperandsOmitted ||
           result == kSomeUsersOmitted;
  }

  bool ShowFusionSubcomputation(const HloInstruction* instr) const {
    CHECK_EQ(instr->opcode(), HloOpcode::kFusion);
    return Show(instr) && !SomeOrAllOperandsOmitted(instr);
  }

 private:
  std::function<NodeFilterResult(const HloInstruction* instr)> filter_;
};

// Node color schemes, used by NodeColorAttributes.
enum ColorScheme {
  kBlue,
  kBrown,
  kDarkBlue,
  kDarkGreen,
  kDarkRed,
  kGray,
  kGreen,
  kOrange,
  kPurple,
  kRed,
  kWhite,
  kYellow,

  // Causes the node's border to be a dashed line, and its content to be gray
  // text on a white background, suggesting that this is an "unimportant" node.
  kDashedBorder,
};

// Graphviz attributes/colors that make up a color scheme.
struct NodeColors {
  const char* style;
  const char* fill_color;
  const char* stroke_color;
  const char* font_color;
};

NodeColors NodeColorsForScheme(ColorScheme color) {
  switch (color) {
    case kBlue:
      return NodeColors{"filled", "#bbdefb", "#8aacc8", "black"};
    case kBrown:
      return NodeColors{"filled", "#bcaaa4", "#8c7b75", "black"};
    case kDarkBlue:
      return NodeColors{"filled", "#1565c0", "#003c8f", "white"};
    case kDarkGreen:
      return NodeColors{"filled", "#2e7d32", "#005005", "white"};
    case kDarkRed:
      return NodeColors{"filled", "#b71c1c", "#7f0000", "white"};
    case kGray:
      return NodeColors{"filled", "#cfd8dc", "#9ea7aa", "black"};
    case kGreen:
      return NodeColors{"filled", "#c8e6c9", "#97b498", "black"};
    case kOrange:
      return NodeColors{"filled", "#ffe0b2", "#cbae82", "black"};
    case kPurple:
      return NodeColors{"filled", "#e1bee7", "#af8eb5", "black"};
    case kRed:
      return NodeColors{"filled", "#ffcdd2", "#cb9ca1", "black"};
    case kWhite:
      return NodeColors{"filled", "white", "black", "black"};
    case kYellow:
      return NodeColors{"filled", "#fff9c4", "#cbc693", "black"};
    case kDashedBorder:
      // "filled,dashed" looks the same as "dashed", since we have a white
      // background. But we use "filled,dashed" so that when you hover over
      // any part of the node (not just the text inside the node), our css
      // :hover rule is triggered.
      return NodeColors{"filled,dashed", "white", "#757575", "#757575"};
  }
}

// Given a ColorScheme, returns an attribute string for a node of that color.
// Sets the node's style and fill/stroke/text colors.
//
// Colors are from https://material.io/color.
string NodeColorAttributes(ColorScheme color) {
  NodeColors node_colors = NodeColorsForScheme(color);

  return Printf(
      R"(style="%s", fontcolor="%s", color="%s", fillcolor="%s")",
      node_colors.style, node_colors.font_color, node_colors.stroke_color,
      node_colors.fill_color);
}

// Replaces <> with &lt;&gt;, so that this string is safe(er) for use in a
// graphviz HTML-like string.
string HtmlLikeStringSanitize(tensorflow::StringPiece s) {
  return StringReplace(StringReplace(s, "<", "&lt;", /*replace_all=*/true), ">",
                       "&gt;", /*replace_all=*/true);
}

// Tries to generate a human-readable one-word description of the given
// computation.
//
// Currently we support:
//
//   "return param0 + param1;"   --> "add"
//   "return param0 * param1;"   --> "multiply"
//   "return min(param0, param1);" --> "min"
//   "return max(param0, param1);" --> "max"
//   "return param0 <= param1;"  --> "less-or-equal"
//   "return param0 >= param1;"  --> "greater-or-equal"
//   "return param0 > param1;"   --> "greater-than"
//   "return param0 < param1;"   --> "less-than"
//   "return param0 == param1;"  --> "equal-to"
//   "return param0 != param1;"  --> "not-equal-to"
//
// where param0 and param1 are effective scalars. For the ops that are
// commutative, we also support them with param0 and param1 swapped.
//
// This is useful primarily for reduce and map nodes. These take a
// subcomputation which is almost always one of the above, and pattern matching
// it to a short string lets us tell the user what the subcomputation is without
// drawing it as a graph.
optional<string> MatchTrivialComputation(const HloComputation* computation) {
  if (computation->instruction_count() != 3) {
    return nullopt;
  }

  HloInstruction* root = computation->root_instruction();
  if (root->operand_count() != 2) {
    return nullopt;
  }

  // Check that both of the operands to the root are parameters.
  const HloInstruction* operand0 = root->operand(0);
  const HloInstruction* operand1 = root->operand(1);
  if (operand0->opcode() != HloOpcode::kParameter ||
      operand1->opcode() != HloOpcode::kParameter) {
    return nullopt;
  }

  // Check that the two operands of root are param0 and param1, in either
  // order; whether a swapped order is acceptable is checked just below.
  auto n0 = operand0->parameter_number();
  auto n1 = operand1->parameter_number();
  if (!(n0 == 0 && n1 == 1) && !(n1 == 0 && n0 == 1)) {
    return nullopt;
  }

  // If the params are reversed, check that the operation being performed is
  // commutative.
  if (n0 == 1) {
    switch (root->opcode()) {
      case HloOpcode::kLe:
      case HloOpcode::kGe:
      case HloOpcode::kGt:
      case HloOpcode::kLt:
        return nullopt;
      default:
        break;
    }
  }

  // Check that the root and params are all effective scalars.
  if (!ShapeUtil::IsEffectiveScalar(root->shape()) ||
      !ShapeUtil::IsEffectiveScalar(operand0->shape()) ||
      !ShapeUtil::IsEffectiveScalar(operand1->shape())) {
    return nullopt;
  }

  // If we recognize the root's opcode, we've successfully pattern-matched!
  switch (root->opcode()) {
    case HloOpcode::kAdd:
      return "add";
    case HloOpcode::kMultiply:
      return "multiply";
    case HloOpcode::kMinimum:
      return "min";
    case HloOpcode::kMaximum:
      return "max";
    case HloOpcode::kLe:
      return "less-or-equal";
    case HloOpcode::kGe:
      return "greater-or-equal";
    case HloOpcode::kGt:
      return "greater-than";
    case HloOpcode::kLt:
      return "less-than";
    case HloOpcode::kEq:
      return "equal-to";
    case HloOpcode::kNe:
      return "not-equal-to";
    default:
      return nullopt;
  }
}

// Encapsulates logic for dumping an HLO module to DOT (i.e. graphviz syntax).
class HloDotDumper {
 public:
  HloDotDumper(const HloComputation* computation, tensorflow::StringPiece label,
               const DebugOptions& debug_options, bool show_metadata,
               const HloExecutionProfile* profile, NodeFilter filter)
      : computation_(computation),
        label_(label.ToString()),
        debug_options_(debug_options),
        show_metadata_(show_metadata),
        profile_(profile),
        filter_(std::move(filter)) {}

  string Dump();

 private:
  // Returns the dot graph identifier for the given instruction.
  string InstructionId(const HloInstruction* instruction) {
    return StrCat(reinterpret_cast<uint64>(instruction));
  }

  // Returns the dot graph identifier for the given computation.
  string SubcomputationId(const HloComputation* computation) {
    return StrCat("cluster_", reinterpret_cast<uint64>(computation));
  }

  // Generates graph header/footer. These should be called *after* dumping all
  // of the instructions and subcomputations for the graph, as they both use
  // data generated while dumping the graph.
  string Header();
  string Footer();

  bool ShouldShowSubcomputation(const HloComputation* subcomp);
  bool ShouldShowFusionSubcomputation(const HloInstruction* instr);

  // We omit some nodes from the graph, instead drawing them inlined into the
  // nodes that use them.
  bool ShouldMergeIntoUsers(const HloInstruction* instr) const;

  string DumpSubcomputation(const HloComputation* subcomp,
                            const HloInstruction* parent_instr);
  string DumpComputation(const HloComputation* comp);
  string DumpRootTag();
  string DumpInstruction(const HloInstruction* instr);
  ColorScheme GetInstructionColor(const HloInstruction* instr);
  string GetInstructionNodeShape(const HloInstruction* instr);
  string GetInstructionNodeLabel(const HloInstruction* instr);
  string GetInstructionNodeMetadata(const HloInstruction* instr);
  string GetInstructionNodeExtraInfo(const HloInstruction* instr);
  string GetInstructionNodeInlinedOperands(const HloInstruction* instr);
  void AddInstructionIncomingEdges(const HloInstruction* instr);

  // For most instructions, GetNodeForEdge(instr) returns instr.
  //
  // The exception is fusion nodes. For these, we walk up the chain of nested
  // fusion nodes starting at instr until we reach a node that either (a) isn't
  // a fusion node, or (b) is a fusion node for which
  // ShouldShowFusionSubcomputation is false.
  //
  // We do this because fusion nodes are expanded inline -- if
  // ShouldShowFusionSubcomputation is true, the fusion node won't be present in
  // the graph.
  //
  // In general when you want to draw an edge from A to B, you should actually
  // draw an edge from GetNodeForEdge(A) to GetNodeForEdge(B).
  const HloInstruction* GetNodeForEdge(const HloInstruction* instr);

  // If instr has just one computation and it's trivial (e.g. "return param0 +
  // param1"), returns a string you can put into the node's body that names the
  // subcomputation, e.g. "Subcomputation: <b>add</b>".
  string GetInstructionTrivialComputationStr(const HloInstruction* instr);

  const HloComputation* computation_;  // never null
  const string label_;                 // overall name for the graph
  const DebugOptions& debug_options_;
  const bool show_metadata_;
  const HloExecutionProfile* profile_;  // may be null
  const NodeFilter filter_;

  // Each HloInstruction dumped gets a monotonically-increasing node ID. This
  // must start at 1, because that's where graphviz's accounting starts.
  int64 next_node_id_ = 1;
  std::unordered_map<const HloInstruction*, int64> node_ids_;

  // The "root" tag doesn't have an associated HloInstruction pointer, so we
  // need to store it outside the map.
  int64 root_node_id_;

  // Each (from, to) edge gets a monotonically-increasing ID. This is a
  // multimap because it's possible for the same edge to appear multiple times
  // in the graph (e.g. x^2 may be represented as mul(x, x)).
  int64 next_edge_id_ = 1;
  std::unordered_multimap<
      std::pair<const HloInstruction*, const HloInstruction*>, int64,
      tensorflow::hash<std::pair<const HloInstruction*, const HloInstruction*>>>
      edge_ids_;

  // Each HloComputation that's emitted gets a monotonically-increasing ID.
  int64 next_cluster_id_ = 1;
  std::unordered_map<const HloComputation*, int64> cluster_ids_;

  // Edges to print from Footer(). Edges come at the end because graphviz is
  // unhappy if an edge from a subcomputation to a node in the outer computation
  // appears before both the inner computation and the destination node are
  // defined.
  std::vector<string> edges_;

  // When coloring by sharding information, we map each sharding's string
  // representation to a color, assigning colors from the available schemes in
  // round-robin order.
  std::unordered_map<string, ColorScheme> sharding_colors_;
  int64 next_shard_color_ = 0;
};

string HloDotDumper::Dump() {
  string body;
  StrAppend(&body, DumpComputation(computation_));
  StrAppend(&body, DumpRootTag());

  // By contract, Header() and Footer() have to be called after we've dumped all
  // our instructions, because they use state generated during that process.
  string g = Header();
  StrAppend(&g, body);
  StrAppend(&g, Footer());
  return g;
}

string HloDotDumper::Header() {
  const char* fmt = R"(digraph G {
rankdir = TB;
compound = true;
label = <<b>%s</b>>;
labelloc = t;
// Disable the tooltip. Interestingly, "" doesn't work!
tooltip = " ";
// DOT graphs accept a stylesheet as a URI. So naturally, an inline
// stylesheet is a data URI!
stylesheet="
  data:text/css,
  @import url(https://fonts.googleapis.com/css?family=Roboto:400,700);
  svg text {
    font-family: 'Roboto';
    font-size: 12px;
  }

  %s
"

)";

  VLOG(3) << "Generating Header";

  string graph_label =
      StrCat(label_, "<br/>Computation ", computation_->name());
  if (computation_->IsFusionComputation()) {
    StrAppend(&graph_label,
              StrCat(" (in fusion instruction ",
                     computation_->FusionInstruction()->name(), ")"));
  }
  if (profile_ != nullptr) {
    auto cycles = profile_->total_cycles_executed(*computation_);
    Appendf(&graph_label, "<br/>total cycles = %lld (%s)", cycles,
            tensorflow::strings::HumanReadableNum(cycles));
  }

  // Create CSS rules that say, when you hover over the given node or cluster,
  // turn the given edge the given color.
  //
  // We rely on a few properties of how graphviz generates SVGs:
  //
  //  - Nodes are named "nodeN", where N corresponds to the 1-based index of
  //    the node in our DOT (i.e. the first node in the DOT is "node1", etc.).
  //    Edges are similarly named "edgeN", and clusters are named "clustN".
  //  - Nodes come before their in- and out-edges in the SVG. We need this
  //    because the "X ~ Y" CSS selector finds a sibling of X that *comes
  //    after X in the DOM* and matches Y.
  std::vector<string> edge_css_rules;
  const char* kBlue = "#1976d2";
  const char* kRed = "#d32f2f";
  for (const auto& kv : edge_ids_) {
    const HloInstruction* from_node = kv.first.first;
    const HloInstruction* to_node = kv.first.second;
    int64 edge_id = kv.second;

    auto add_hover_css_rule = [&](string elem_type, int64 elem_id,
                                  const char* color) {
      // One could imagine other ways of writing this CSS rule that involve
      // less duplication, but this way seems to be relatively performant.
      edge_css_rules.push_back(
          Printf("  #%s%d:hover ~ #edge%lld text { fill: %s; }\n"
                 "  #%s%d:hover ~ #edge%lld path { "
                 "stroke: %s; stroke-width: .2em; }\n"
                 "  #%s%d:hover ~ #edge%lld polygon { "
                 "fill: %s; stroke: %s; stroke-width: .2em; }\n",
                 elem_type, elem_id, edge_id, color,  //
                 elem_type, elem_id, edge_id, color,  //
                 elem_type, elem_id, edge_id, color, color));
    };

    // The "to_node" value may be null, indicating that this edge points to the
    // "root" tag rather than to a normal node.
    int64 from_node_id =
        tensorflow::gtl::FindWithDefault(node_ids_, from_node, -1);
    if (from_node_id == -1) {
      LOG(FATAL) << from_node->name() << " was added to edges but not to nodes";
    }
    int64 to_node_id =
        to_node ? tensorflow::gtl::FindWithDefault(node_ids_, to_node, -1)
                : root_node_id_;
    if (to_node != nullptr && to_node_id == -1) {
      LOG(FATAL) << to_node->name() << " was added to edges but not to nodes";
    }

    add_hover_css_rule("node", from_node_id, kBlue);
    add_hover_css_rule("node", to_node_id, kRed);

    if (to_node) {
      VLOG(3) << "Adding css for edge " << edge_id << " from node "
              << from_node->name() << " to node " << to_node->name();
    } else {
      VLOG(3) << "Adding css for edge " << edge_id << " from node "
              << from_node->name() << " to root tag";
    }

    // If this edge crosses a fusion cluster boundary, highlight it when the
    // cluster is hovered over.
    if (to_node) {
      if (from_node->IsFused() &&
          from_node->parent()->root_instruction() == from_node) {
        int64 cluster_id = cluster_ids_.at(from_node->parent());
        add_hover_css_rule("clust", cluster_id, kBlue);
      }
      if (to_node->IsFused() && to_node->opcode() == HloOpcode::kParameter) {
        int64 cluster_id = cluster_ids_.at(to_node->parent());
        add_hover_css_rule("clust", cluster_id, kRed);
      }
    }
  }

  return Printf(fmt, graph_label, Join(edge_css_rules, "\n"));
}

string HloDotDumper::Footer() { return StrCat(Join(edges_, "\n"), "\n}"); }

bool HloDotDumper::ShouldShowFusionSubcomputation(const HloInstruction* instr) {
  CHECK_EQ(instr->opcode(), HloOpcode::kFusion);
  return ShouldShowSubcomputation(instr->fused_instructions_computation());
}

bool HloDotDumper::ShouldShowSubcomputation(const HloComputation* subcomp) {
  if (subcomp->IsFusionComputation()) {
    const HloInstruction* fusion = subcomp->FusionInstruction();
    if (!filter_.Show(fusion) || filter_.SomeOrAllOperandsOmitted(fusion)) {
      return false;
    }
  }

  // Don't show trivial subcomputations on non-fusion nodes -- these are inlined
  // into the graph.
  if (!subcomp->IsFusionComputation() && MatchTrivialComputation(subcomp)) {
    return false;
  }

  // Show the subcomputation if we're showing any of its members.
  return std::any_of(
      subcomp->instructions().begin(), subcomp->instructions().end(),
      [&](const HloInstruction* instr) { return filter_.Show(instr); });
}

string HloDotDumper::DumpSubcomputation(const HloComputation* subcomp,
                                        const HloInstruction* parent_instr) {
  VLOG(2) << "Dumping subcomputation " << subcomp->name();
  const char* computation_fmt = R"(subgraph %s {
%s
label = <%s>;
labelloc = t;
tooltip = " ";
%s
}  // %s

)";

  cluster_ids_[subcomp] = next_cluster_id_++;

  string id = SubcomputationId(subcomp);

  string subcomp_label, style;
  if (parent_instr->opcode() == HloOpcode::kFusion) {
    subcomp_label = Printf("Fused expression for <b>%s</b><br/>%s",
                           HtmlLikeStringSanitize(parent_instr->name()),
                           HtmlLikeStringSanitize(parent_instr->ToCategory()));
    string extra_info = GetInstructionNodeExtraInfo(parent_instr);
    if (!extra_info.empty()) {
      StrAppend(&subcomp_label, "<br/>", extra_info);
    }

    bool highlight = filter_.Highlight(parent_instr);
    const char* fillcolor;
    const char* strokecolor;
    if (debug_options_.xla_hlo_graph_sharding_color() && !highlight) {
      // Use the sharding color, if the node isn't highlighted.
      NodeColors node_colors =
          NodeColorsForScheme(GetInstructionColor(parent_instr));
      fillcolor = node_colors.fill_color;
      strokecolor = node_colors.stroke_color;
    } else {
      // Subcomputation's fill/stroke color is light/dark red/gray, depending on
      // whether or not the subcomputation's fusion node is highlighted.
      fillcolor = highlight ? "#ffcdd2" : "#f5f5f5";
      strokecolor = highlight ? "#b71c1c" : "#c2c2c2";
    }
    style =
        Printf(R"(style="rounded,filled,bold"; fillcolor="%s"; color="%s";)",
               fillcolor, strokecolor);
  } else {
    subcomp_label = Printf("Subcomputation for <b>%s</b><br/>%s",
                           HtmlLikeStringSanitize(parent_instr->name()),
                           HtmlLikeStringSanitize(subcomp->name()));
    style = "style=rounded; color=black;";
  }

  string comp_body = DumpComputation(subcomp);

  // Add an edge from the subcomputation to its parent node. If subcomp
  // belongs to a fusion node, it's drawn in place of the fusion instruction,
  // so there's no need to link those.
  if (parent_instr->opcode() != HloOpcode::kFusion) {
    const HloInstruction* from = GetNodeForEdge(subcomp->root_instruction());
    VLOG(2) << "Edge: from " << from->name() << " to " << parent_instr->name()
            << " as " << next_edge_id_;
    edge_ids_.insert({{from, parent_instr}, next_edge_id_++});
    const char* edge_fmt =
        R"(%s -> %s [ltail="%s", style="dashed" tooltip="%s -> %s"];)";
    edges_.push_back(Printf(
        edge_fmt, InstructionId(from), InstructionId(parent_instr),
        SubcomputationId(subcomp), subcomp->name(), parent_instr->name()));
  }

  string computation =
      Printf(computation_fmt, id, style, subcomp_label, comp_body, id);

  return computation;
}

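// Dumps the instructions of comp that the filter shows, first recursing into
// any subcomputations they call which should themselves be shown.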
string HloDotDumper::DumpComputation(const HloComputation* comp) {
  string g;
  for (const auto* instr : comp->instructions()) {
    if (!filter_.Show(instr)) {
      continue;
    }

    // Dump subcomputations within instr.
    for (const HloComputation* subcomp : instr->called_computations()) {
      if (ShouldShowSubcomputation(subcomp)) {
        StrAppend(&g, DumpSubcomputation(subcomp, instr));
      }
    }

    StrAppend(&g, DumpInstruction(instr));
  }
  return g;
}

string HloDotDumper::DumpRootTag() {
  const HloInstruction* from = GetNodeForEdge(computation_->root_instruction());

  // We don't display constants as separate nodes, so if the root is a constant
  // we add neither a root tag nor an edge for it.
  if (!filter_.Show(from) || from->opcode() == HloOpcode::kConstant) {
    return "";
  }

  auto from_id = InstructionId(from);

  // The ID of the root computation is otherwise unused, so it makes a good ID
  // to use for the root-tag node. However, the edge_ids_ map requires a
  // HloInstruction* pointer for the 'to' value, so we use a NULL value there
  // (rather than a pointer type-cast) to make it obvious if it is erroneously
  // dereferenced.
  HloInstruction* to = nullptr;
  auto to_id = SubcomputationId(computation_);

  string node_body = "ROOT";
  string node_shape = "circle";
  ColorScheme color = kBrown;

  VLOG(2) << "Adding root tag as node " << next_node_id_;
  root_node_id_ = next_node_id_++;

  VLOG(2) << "Adding edge from " << from->name() << " to root tag as "
          << next_edge_id_;
  edge_ids_.insert({{from, to}, next_edge_id_++});
  edges_.push_back(Printf(R"(%s -> %s [tooltip=" "];)", from_id, to_id));

  return Printf(R"(%s [label=<%s>, shape=%s, tooltip=" ", %s];)"
                "\n",
                to_id, node_body, node_shape, NodeColorAttributes(color));
}

bool HloDotDumper::ShouldMergeIntoUsers(const HloInstruction* instr) const {
  // If a node:
  //
  //  - is a tuple-shaped parameter,
  //  - is not a parameter to a fusion node,
  //  - has more than kMinUsersToOmit users shown, and
  //  - all of the shown users are get-tuple-elements,
  //
  // then we omit it from the graph, merging it with its users.
  //
  // This helps us handle the common case where a while loop body has one big
  // tuple-shaped parameter.
  const int kMinUsersToOmit = 3;
  return instr->opcode() == HloOpcode::kParameter &&
         ShapeUtil::IsTuple(instr->shape()) && !instr->IsFused() &&
         std::count_if(instr->users().begin(), instr->users().end(),
                       [&](const HloInstruction* user) {
                         return filter_.Show(user);
                       }) > kMinUsersToOmit &&
         std::all_of(instr->users().begin(), instr->users().end(),
                     [&](const HloInstruction* user) {
                       return !filter_.Show(user) ||
                              user->opcode() == HloOpcode::kGetTupleElement;
                     });
}

string HloDotDumper::DumpInstruction(const HloInstruction* instr) {
  // We don't display constants as separate nodes; they're merged into their
  // users.
  if (instr->opcode() == HloOpcode::kConstant) {
    return "";
  }
  // Skip this node if it's merged into its users.
  if (ShouldMergeIntoUsers(instr)) {
    return "";
  }
  // Omit the fusion node if its subcomputation is drawn, since the
  // subcomputation will be drawn inline.
  if (instr->opcode() == HloOpcode::kFusion &&
      ShouldShowFusionSubcomputation(instr)) {
    return "";
  }

  VLOG(2) << "Adding node " << instr->name() << " as " << next_node_id_;
  node_ids_[instr] = next_node_id_++;

  ColorScheme color = GetInstructionColor(instr);
  string node_shape = GetInstructionNodeShape(instr);
  string node_label = GetInstructionNodeLabel(instr);
  string node_metadata = GetInstructionNodeMetadata(instr);
  string extra_info = GetInstructionNodeExtraInfo(instr);
  string inlined_constants = GetInstructionNodeInlinedOperands(instr);
  string trivial_subcomputation = GetInstructionTrivialComputationStr(instr);
  AddInstructionIncomingEdges(instr);

  if (!debug_options_.xla_hlo_graph_sharding_color()) {
    // Override the node's styling if it should be (de-)emphasized.
    if (filter_.Deemphasized(instr)) {
      color = kDashedBorder;
    }
    if (filter_.Highlight(instr)) {
      node_shape = "diamond";
      color = kDarkRed;
    }
  }
  // Build the text that will be displayed inside the node.
  string node_body = node_label;
  for (const string& s :
       {trivial_subcomputation, node_metadata, extra_info, inlined_constants}) {
    if (!s.empty()) {
      StrAppend(&node_body, "<br/>", s);
    }
  }

  return Printf(R"(%s [label=<%s>, shape=%s, tooltip=" ", %s];)"
                "\n",
                InstructionId(instr), node_body, node_shape,
                NodeColorAttributes(color));
}

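// Returns text (possibly empty) describing operands that are drawn inline in
// instr's node body: constants, and operands that were merged into their users.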
string HloDotDumper::GetInstructionNodeInlinedOperands(
    const HloInstruction* instr) {
  auto stringify_constant = [](const HloInstruction* constant) {
    const auto& shape = constant->shape();

    // If the shape has a dimension of size zero, print it as e.g.
    // "{} (f32[42, 0, 10])". The alternative, calling Literal::ToString(),
    // enumerates all of its empty dimensions (e.g. "{ { {}, {} }, ..."), which
    // is just noise.
    if (ShapeUtil::HasZeroElements(shape)) {
      return Printf("{} (%s)", ShapeUtil::HumanString(constant->shape()));
    }

    // Print the literal value of constants with at most 8 elements.
    optional<int64> elem_count;
    if (!ShapeUtil::IsOpaque(shape) && !ShapeUtil::IsTuple(shape)) {
      elem_count = 1;
      for (int64 dim : shape.dimensions()) {
        *elem_count *= dim;
      }
    }
    if (elem_count.has_value() && *elem_count <= 8) {
      return Printf("%s (%s)", constant->literal().ToString(),
                    ShapeUtil::HumanString(constant->shape()));
    }

    // Otherwise, print e.g. "%constant.42 (s32[100])".
    string constant_name;
    if (tensorflow::str_util::StartsWith(constant->name(), "constant")) {
      constant_name = constant->name();
    } else {
      constant_name = StrCat("constant ", constant->name());
    }
    return Printf("%s %s", constant_name,
                  ShapeUtil::HumanString(constant->shape()));
  };

  // Special case: If instr is a parameter to a fusion node, check whether the
  // corresponding operand to the fusion node is a constant.
  if (instr->opcode() == HloOpcode::kParameter && instr->IsFused()) {
    const HloInstruction* fusion = instr->parent()->FusionInstruction();
    const HloInstruction* operand = fusion->operand(instr->parameter_number());
    if (operand->opcode() != HloOpcode::kConstant) {
      return "";
    }
    return StrCat("<b>constant</b> ", stringify_constant(operand));
  }

  std::vector<string> lines;
  for (int64 i = 0; i < instr->operand_count(); ++i) {
    const HloInstruction* operand = instr->operand(i);
    optional<string> operand_str;
    if (operand->opcode() == HloOpcode::kConstant) {
      operand_str = stringify_constant(operand);
    } else if (ShouldMergeIntoUsers(operand)) {
      // Special case: If the operand is a parameter, use its parameter number
      // rather than its name, because that's generally how people think of the
      // node.
      if (operand->opcode() == HloOpcode::kParameter) {
        operand_str = Printf("Parameter %lld", operand->parameter_number());
      } else {
        operand_str = operand->name();
      }
    }

    if (operand_str) {
      if (instr->operand_count() > 1) {
        lines.push_back(Printf("<b>operand %lld</b> = %s", i, *operand_str));
      } else {
        lines.push_back(Printf("<b>operand</b> = %s", *operand_str));
      }
    }
  }
  return Join(lines, "<br/>");
}

ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) {
  if (debug_options_.xla_hlo_graph_sharding_color()) {
    if (!instr->has_sharding()) {
      return kDashedBorder;
    }
    string shard_str = instr->sharding().ToString();
    auto it = sharding_colors_.find(shard_str);
    if (it != sharding_colors_.end()) {
      return it->second;
    }
    ColorScheme color = static_cast<ColorScheme>(
        kBlue + (next_shard_color_++ % (kDashedBorder - kBlue)));
    sharding_colors_.emplace(shard_str, color);
    return color;
  }
  const auto kParameterColor = kOrange;

  // Special case: If this instruction has a parameter merged into it, paint it
  // the same color as a parameter.
  if (std::any_of(instr->operands().begin(), instr->operands().end(),
                  [&](const HloInstruction* operand) {
                    return operand->opcode() == HloOpcode::kParameter &&
                           ShouldMergeIntoUsers(operand);
                  })) {
    return kParameterColor;
  }

  // Pick different colors or shapes for instructions which are particularly
  // expensive (eg, dot) and those which are unusual in some way or unique
  // (eg, parameter).
  switch (instr->opcode()) {
    case HloOpcode::kAbs:
    case HloOpcode::kAdd:
    case HloOpcode::kAnd:
    case HloOpcode::kAtan2:
    case HloOpcode::kBitcastConvert:
    case HloOpcode::kCeil:
    case HloOpcode::kClamp:
    case HloOpcode::kComplex:
    case HloOpcode::kConvert:
    case HloOpcode::kCos:
    case HloOpcode::kDivide:
    case HloOpcode::kEq:
    case HloOpcode::kExp:
    case HloOpcode::kFloor:
    case HloOpcode::kGe:
    case HloOpcode::kGt:
    case HloOpcode::kImag:
    case HloOpcode::kIsFinite:
    case HloOpcode::kLe:
    case HloOpcode::kLog:
    case HloOpcode::kLt:
    case HloOpcode::kMaximum:
    case HloOpcode::kMinimum:
    case HloOpcode::kMultiply:
    case HloOpcode::kNe:
    case HloOpcode::kNegate:
    case HloOpcode::kNot:
    case HloOpcode::kOr:
    case HloOpcode::kPower:
    case HloOpcode::kReal:
    case HloOpcode::kRemainder:
    case HloOpcode::kRng:
    case HloOpcode::kRoundNearestAfz:
    case HloOpcode::kShiftLeft:
    case HloOpcode::kShiftRightArithmetic:
    case HloOpcode::kShiftRightLogical:
    case HloOpcode::kSign:
    case HloOpcode::kSin:
    case HloOpcode::kSlice:
    case HloOpcode::kSort:
    case HloOpcode::kSubtract:
    case HloOpcode::kTanh:
      // De-emphasize scalar-shaped elementwise ops -- they're generally
      // uninteresting.
      if (ShapeUtil::IsEffectiveScalar(instr->shape())) {
        return kWhite;
      }
      return kYellow;
    case HloOpcode::kBitcast:
    case HloOpcode::kGetTupleElement:
    case HloOpcode::kTrace:
    case HloOpcode::kTuple:
      return kWhite;
    case HloOpcode::kBroadcast:
    case HloOpcode::kBroadcastDimOne:
      // De-emphasize nodes which broadcast a scalar within a fusion node --
      // these are essentially free.
      if (instr->IsFused() &&
          ShapeUtil::IsEffectiveScalar(instr->operand(0)->shape())) {
        return kWhite;
      }
      return kGreen;
    case HloOpcode::kConcatenate:
    case HloOpcode::kCopy:
    case HloOpcode::kDynamicSlice:
    case HloOpcode::kGather:
    case HloOpcode::kPad:
    case HloOpcode::kReshape:
    case HloOpcode::kReverse:
    case HloOpcode::kSelect:
    case HloOpcode::kTranspose:
      // De-emphasize scalar-shaped data movement ops and all data movement ops
      // inside fusion nodes, both of which are essentially free.
      if (ShapeUtil::IsEffectiveScalar(instr->shape()) || instr->IsFused()) {
        return kWhite;
      }
      return kGreen;
    case HloOpcode::kDynamicUpdateSlice:
      // Unlike the data-movement ops above, dynamic-update-slice is not ~free
      // inside of fusion nodes, so we de-emphasize it only if it's
      // scalar-shaped.
      if (ShapeUtil::IsEffectiveScalar(instr->shape())) {
        return kWhite;
      }
      return kGreen;
    case HloOpcode::kConvolution:
    case HloOpcode::kDot:
    case HloOpcode::kFft:
      return kDarkBlue;
    case HloOpcode::kReducePrecision:
      return kRed;
    case HloOpcode::kParameter:
      return kParameterColor;
    case HloOpcode::kBatchNormGrad:
    case HloOpcode::kBatchNormInference:
    case HloOpcode::kBatchNormTraining:
    case HloOpcode::kReduce:
    case HloOpcode::kReduceWindow:
    case HloOpcode::kSelectAndScatter:
      return kPurple;
    case HloOpcode::kFusion:
    case HloOpcode::kMap:
      return kGray;
    case HloOpcode::kCrossReplicaSum:
    case HloOpcode::kInfeed:
    case HloOpcode::kOutfeed:
    case HloOpcode::kRecv:
    case HloOpcode::kRecvDone:
    case HloOpcode::kSend:
    case HloOpcode::kSendDone:
      return kBrown;
    case HloOpcode::kCall:
    case HloOpcode::kConditional:
    case HloOpcode::kCustomCall:
    case HloOpcode::kHostCompute:
    case HloOpcode::kWhile:
      return kDarkGreen;
    case HloOpcode::kConstant:
      LOG(FATAL) << "Constants don't get their own nodes in the graph.";
  }
}

string HloDotDumper::GetInstructionNodeShape(const HloInstruction* instr) {
  // Give while loops a different shape so they're easier to pick out.
  switch (instr->opcode()) {
    case HloOpcode::kWhile:
      return "ellipse";
    default:
      return "rect";
  }
}

string HloDotDumper::GetInstructionNodeLabel(const HloInstruction* instr) {
  // If we have a parameter, put the param number in the name.
  if (instr->opcode() == HloOpcode::kParameter) {
    return Printf("<b>Parameter %lld</b>", instr->parameter_number());
  }

  // The HLO instruction name usually contains the opcode, e.g. "%add.42" is
  // an add instruction. In this case we render just the name.
  if (tensorflow::str_util::StartsWith(instr->name(),
                                       HloOpcodeString(instr->opcode()))) {
    return Printf("<b>%s</b>", HtmlLikeStringSanitize(instr->name()));
  }
  string extended_opcode =
      StrCat(HloOpcodeString(instr->opcode()),
             instr->opcode() != HloOpcode::kFusion
                 ? ""
                 : StrCat(":", xla::ToString(instr->fusion_kind())));
  // If the name does not contain the opcode, render both.
  return Printf("<b>%s</b><br/>%s", HtmlLikeStringSanitize(extended_opcode),
                HtmlLikeStringSanitize(instr->name()));
}

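// Renders instr's OpMetadata (op name, op type, and source location) as lines
// for the node body; returns "" when metadata display is disabled.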
string HloDotDumper::GetInstructionNodeMetadata(const HloInstruction* instr) {
  if (!show_metadata_) {
    return "";
  }

  std::vector<string> lines;
  if (!instr->metadata().op_name().empty()) {
    lines.push_back(HtmlLikeStringSanitize(instr->metadata().op_name()));
  }
  if (!instr->metadata().op_type().empty()) {
    lines.push_back(Printf(
        "op_type: %s", HtmlLikeStringSanitize(instr->metadata().op_type())));
  }
  if (!instr->metadata().source_file().empty() &&
      instr->metadata().source_line() != 0) {
    lines.push_back(Printf("%s:%d", instr->metadata().source_file(),
                           instr->metadata().source_line()));
  }

  return Join(lines, "<br/>");
}

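// Collects instr's extra attributes, its (possibly truncated) shape/layout,
// and any profile data for display inside the node.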
string HloDotDumper::GetInstructionNodeExtraInfo(const HloInstruction* instr) {
  std::vector<string> lines;

  // Get the instruction's extra attributes excluding the names of its
  // subcomputations, since those are drawn explicitly in the graph.
  for (const auto& line : instr->ExtraAttributesToString(
           HloPrintOptions().set_print_subcomputation_references(false))) {
    lines.push_back(HtmlLikeStringSanitize(line));
  }

  // Show the shape and layout of the instruction, unless it's an inlined fusion
  // node, where the shape and layout are present in the output node instead.
  if (instr->opcode() != HloOpcode::kFusion ||
      !ShouldShowFusionSubcomputation(instr)) {
    // Show layout of instructions with more than one dimension. Don't show
    // layout on tuples or tensors with just one dimension (which only have one
    // possible layout) to avoid visual noise.
    bool shape_is_multidim = false;
    ShapeUtil::ForEachSubshape(instr->shape(),
                               [&](const Shape& s, const ShapeIndex&) {
                                 shape_is_multidim |= s.dimensions_size() > 1;
                               });
    string instr_shape;
    if (instr->opcode() != HloOpcode::kTuple && shape_is_multidim) {
      instr_shape = ShapeUtil::HumanStringWithLayout(instr->shape());
    } else {
      instr_shape = ShapeUtil::HumanString(instr->shape());
    }

    // Some instructions have giant tuples as their shapes, so truncate the
    // HLO's shape to kMaxShapeLen characters.
    constexpr int kMaxShapeLen = 64;
    if (instr_shape.length() > kMaxShapeLen) {
      instr_shape = StrCat(
          tensorflow::StringPiece(instr_shape).substr(0, kMaxShapeLen - 3),
          "...");
    }
    lines.push_back(instr_shape);
  }
  if (debug_options_.xla_hlo_graph_addresses()) {
    lines.push_back(Printf("[%p]", instr));
  }
  if (profile_ != nullptr) {
    double hlo_cycles_executed = profile_->GetCyclesTakenBy(*instr);
    double total_cycles_executed =
        profile_->total_cycles_executed(*instr->parent());
    if (hlo_cycles_executed > 0 && total_cycles_executed > 0) {
      lines.push_back(
          Printf("%% of cycles executed=%.2f",
                 100 * hlo_cycles_executed / total_cycles_executed));
    }
  }
  return Join(lines, "<br/>");
}

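// Records edges from instr's operands (and control predecessors) to instr.
// The edge strings themselves are accumulated in edges_ and emitted later by
// Footer().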
void HloDotDumper::AddInstructionIncomingEdges(const HloInstruction* instr) {
  auto add_edge = [&](const HloInstruction* from, const HloInstruction* to,
                      int64 operand_num, bool control_edge = false) {
    from = GetNodeForEdge(from);

    if (!filter_.Show(from) || from->opcode() == HloOpcode::kConstant ||
        ShouldMergeIntoUsers(from)) {
      return;
    }
    VLOG(2) << "Adding edge from " << from->name() << " to " << to->name()
            << " as " << next_edge_id_;
    edge_ids_.insert({{from, to}, next_edge_id_++});

    string edge_label;
    if (instr->operand_count() > 1 && !control_edge) {
      edge_label = Printf(R"( headlabel="%lld", labeldistance=2)", operand_num);
    } else if (control_edge) {
      edge_label = "style=\"dotted\" color=\"gray\" label=\"ctrl\"";
    }
    const char* kEdgeFmt = R"(%s -> %s [tooltip="%s -> %s" %s];)";
    edges_.push_back(Printf(kEdgeFmt, InstructionId(from), InstructionId(to),
                            from->name(), to->name(), edge_label));
  };

  // Add edges from instr's operands to instr. Parameters within fusion
  // expressions are handled specially -- we draw an edge from the corresponding
  // operand on the fusion node itself to the parameter.
  if (instr->opcode() == HloOpcode::kParameter && instr->IsFused()) {
    // Only add the edge if this is not the outermost computation; otherwise it
    // will lead from a node we're not drawing.
    if (instr->parent() != computation_) {
      const HloInstruction* fusion = instr->parent()->FusionInstruction();
      add_edge(fusion->operand(instr->parameter_number()), instr,
               /*operand_num=*/0);
    }
  } else {
    for (int64 i = 0; i < instr->operand_count(); ++i) {
      add_edge(instr->operand(i), instr, i);
    }
    for (const HloInstruction* pred : instr->control_predecessors()) {
      add_edge(pred, instr, /*operand_num=*/0, /*control_edge=*/true);
    }
  }
}

string HloDotDumper::GetInstructionTrivialComputationStr(
    const HloInstruction* instr) {
  // called_computations() on a fusion node "inherits" any called computations
  // of the fused root, which isn't what we want. Just ignore fusion nodes
  // here; they're handled separately.
  if (instr->opcode() == HloOpcode::kFusion) {
    return "";
  }

  std::vector<string> lines;
  for (int64 i = 0; i < instr->called_computations().size(); ++i) {
    optional<string> computation_type =
        MatchTrivialComputation(instr->called_computations()[i]);
    if (!computation_type) {
      continue;
    }
    if (instr->called_computations().size() == 1) {
      lines.push_back(Printf("Subcomputation: <b>%s</b>",
                             HtmlLikeStringSanitize(*computation_type)));
    } else {
      lines.push_back(Printf("Subcomputation %lld: <b>%s</b>", i,
                             HtmlLikeStringSanitize(*computation_type)));
    }
  }
  return Join(lines, "<br/>");
}

const HloInstruction* HloDotDumper::GetNodeForEdge(
    const HloInstruction* instr) {
  while (instr->opcode() == HloOpcode::kFusion &&
         ShouldShowFusionSubcomputation(instr)) {
    instr = instr->fused_expression_root();
  }
  return instr;
}

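// Process-wide registry holding the renderer used by ExportGraph() when no
// dump path is given. Registering a new renderer replaces the previous one.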
class GraphRendererRegistry {
 public:
  void AddRenderer(GraphRendererInterface* graph_renderer) {
    tensorflow::mutex_lock lock(mu_);
    graph_renderer_ = graph_renderer;
  }

  GraphRendererInterface* GetDefaultRenderer() {
    tensorflow::mutex_lock lock(mu_);
    return graph_renderer_;
  }

  static GraphRendererRegistry* Default() {
    static GraphRendererRegistry* registry = new GraphRendererRegistry();
    return registry;
  }

 private:
  tensorflow::mutex mu_;
  GraphRendererInterface* graph_renderer_ = nullptr;
};

}  // namespace

Registrar::Registrar(GraphRendererInterface* dumper) {
  GraphRendererRegistry::Default()->AddRenderer(dumper);
}

namespace {

// Gets a NodeFilter that includes roughly all instructions whose distance from
// root is <= radius.
NodeFilter MakeNodeFilter(const HloInstruction* root, int64 radius) {
  // First, find the neighborhood of nodes with distance from root <= radius.
  // These nodes are our initial set of "normal" nodes.
  std::unordered_map<const HloInstruction*, NodeFilterResult> nodes;
  std::deque<std::pair<const HloInstruction*, /*depth*/ int64>> worklist;
  worklist.push_back({root, 0});
  while (!worklist.empty()) {
    const HloInstruction* instr;
    int64 depth;
    std::tie(instr, depth) = worklist.front();
    worklist.pop_front();

    nodes[instr] = kNormalNode;
    if (depth == radius) {
      continue;
    }

    // Traverse into instr's operands.
    //
    // Don't traverse into tuples' operands unless the tuple is the root.
    // Usually a tuple is the bottommost node in the graph, and so its operands
    // are not interesting to the graph at hand.
    if (instr == root || instr->opcode() != HloOpcode::kTuple) {
      for (const HloInstruction* operand : instr->operands()) {
        if (!nodes.count(operand)) {
          worklist.push_back({operand, depth + 1});
        }
      }
    }

    // Traverse into instr's nested computations.
    for (const HloComputation* computation : instr->called_computations()) {
      worklist.push_back({computation->root_instruction(), depth + 1});
    }

    // Traverse into instr's users, unless:
    //
    //  - there are a ton of them, in which case they're probably not
    //    interesting (and anyway, rendering them all would make the graph
    //    unreadable), or
    //  - instr is a constant, in which case its users are probably not
    //    interesting.
    if (instr->opcode() == HloOpcode::kConstant) {
      continue;
    }
    constexpr int kMaxUsersToRender = 16;
    if (instr->user_count() > kMaxUsersToRender) {
      // If we're going to skip this node's users, style it as such.
      nodes[instr] = kSomeUsersOmitted;
      continue;
    }
    for (const HloInstruction* user : instr->users()) {
      if (!nodes.count(user)) {
        worklist.push_back({user, depth + 1});
      }
    }
  }

  auto is_displayed = [&](const HloInstruction* instr) {
    // Constants are displayed inline with their users; they're never omitted.
    // Nodes in subcomputations are always shown.
    return nodes.count(instr) > 0 || instr->opcode() == HloOpcode::kConstant ||
           instr->parent() != root->parent();
  };

  // Make a second pass over 'nodes' to fix up the NodeFilterResults now that we
  // know which nodes will be included in the graph.
  for (auto& kv : nodes) {
    const HloInstruction* instr = kv.first;
    NodeFilterResult& filter_result = kv.second;
    const auto& operands = instr->operands();

    if (std::any_of(operands.begin(), operands.end(), is_displayed) &&
        !std::all_of(operands.begin(), operands.end(), is_displayed)) {
      // Mark nodes with some operands omitted appropriately.
      filter_result = kSomeOperandsOmitted;
    } else if (!operands.empty() &&
               std::none_of(operands.begin(), operands.end(), is_displayed)) {
      // Mark nodes with *all* operands omitted appropriately.
      filter_result = kOmitNodeOperands;
    }

    // Promote nodes with type kSomeUsersOmitted to kNormalNode if all of their
    // users made it into the graph.
    if (filter_result == kSomeUsersOmitted &&
        std::all_of(instr->users().begin(), instr->users().end(),
                    is_displayed)) {
      filter_result = kNormalNode;
    }
  }

  // Highlight the root node.
  nodes[root] = kHighlightNode;

  return NodeFilter([=](const HloInstruction* instr) {
    auto it = nodes.find(instr);
    if (it != nodes.end()) {
      return it->second;
    }
    // Show all nodes in subcomputations.
    if (instr->parent() != root->parent()) {
      return kNormalNode;
    }
    return kHideNode;
  });
}

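// Writes the rendered graph to a uniquely-named file under dest_path and
// returns the chosen path; failures are logged rather than propagated.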
string SaveGraph(const string& graph,
                 GraphRendererInterface::GraphKind graph_kind,
                 const string& dest_path) {
  static std::atomic<int> output_num(0);
  string file_extension;
  switch (graph_kind) {
    case GraphRendererInterface::DOT_GRAPH:
      file_extension = ".dot";
      break;
    case GraphRendererInterface::TF_GRAPHDEF:
      file_extension = ".pbtxt";
      break;
  }
  string path = JoinPath(dest_path, StrCat("hlo_graph_", output_num++, "."));
  auto status = Status::OK();
  auto env = tensorflow::Env::Default();
  if (!env->CreateUniqueFileName(&path, file_extension)) {
    status =
        Status(tensorflow::error::Code::UNKNOWN,
               StrCat("Failed to create temporary file to dump HLO graph: ",
                      strerror(errno)));
  } else {
    status = tensorflow::WriteStringToFile(env, path, graph);
  }
  if (!status.ok()) {
    LOG(WARNING) << "Saving HLO graph failed: " << status;
  }
  return path;
}

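// Saves the graph under --xla_hlo_graph_path if that flag is set; otherwise
// hands it to the registered renderer and returns the resulting URL.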
string ExportGraph(const string& graph,
                   GraphRendererInterface::GraphKind graph_kind,
                   const DebugOptions& debug_options) {
  string path = debug_options.xla_hlo_graph_path();
  if (!path.empty()) {
    return SaveGraph(graph, graph_kind, path);
  } else {
    auto graph_renderer =
        GraphRendererRegistry::Default()->GetDefaultRenderer();
    CHECK(graph_renderer != nullptr)
        << "No registered renderer for the HLO graph. "
           "Use --xla_hlo_graph_path=PATH to export to local file system";
    return graph_renderer->RenderGraph(graph, graph_kind, debug_options);
  }
}

}  // namespace

string DumpGraph(const HloComputation& computation, const string& label,
                 const DebugOptions& debug_options,
                 const HloExecutionProfile* hlo_execution_profile,
                 bool show_metadata) {
  GraphRendererInterface::GraphKind graph_kind;
  string graph;
  if (debug_options.xla_hlo_dump_as_graphdef()) {
    HloTfGraphBuilder builder(debug_options);
    TF_CHECK_OK(builder.AddComputation(computation));
    CHECK(tensorflow::protobuf::TextFormat::PrintToString(builder.GetGraphDef(),
                                                          &graph));
    graph_kind = GraphRendererInterface::TF_GRAPHDEF;
  } else {
    graph = HloDotDumper(&computation, label, debug_options, show_metadata,
                         hlo_execution_profile, NodeFilter())
                .Dump();
    graph_kind = GraphRendererInterface::DOT_GRAPH;
  }

  string graph_url = ExportGraph(graph, graph_kind, debug_options);
  LOG(INFO) << "computation " << computation.name() << " [" << label
            << "]: " << graph_url;
  return graph_url;
}

string DumpNeighborhoodAround(const HloInstruction& node, int radius,
                              bool show_metadata) {
  auto debug_options = node.GetModule()->config().debug_options();
  string label =
      StrCat("Neighborhood of ", radius, " nodes around ", node.name());
  NodeFilter filter = MakeNodeFilter(&node, radius);
  string graph =
      HloDotDumper(node.parent(), label, debug_options, show_metadata,
                   /*profile=*/nullptr, filter)
          .Dump();
  return ExportGraph(graph, GraphRendererInterface::DOT_GRAPH, debug_options);
}

void DumpText(const HloModule& module, const string& label,
              const string& directory_path, bool do_prefix) {
  Env* env = Env::Default();
  TF_CHECK_OK(env->RecursivelyCreateDir(directory_path));
  string prefix = StrCat(env->NowMicros());
  string filename =
      do_prefix ? StrCat(prefix, "-", label, ".txt") : StrCat(label, ".txt");
  string path = JoinPath(directory_path, filename);
  TF_CHECK_OK(WriteStringToFile(
      env, path,
      module.ToString(HloPrintOptions().set_print_large_constants(true))));
  LOG(INFO) << "dumping module '" << module.name() << "' to " << path;
}

string MaybeDumpHloModule(const HloModule& module, const string& label,
                          const HloExecutionProfile* profile) {
  const DebugOptions& debug_options = module.config().debug_options();
  VLOG(2) << "MaybeDumpHloModule called on module " << module.name()
          << " with generate_hlo_graph regex \""
          << debug_options.xla_generate_hlo_graph() << "\"";
  string graph_url;
  if (!debug_options.xla_generate_hlo_graph().empty() &&
      RE2::PartialMatch(module.name(),
                        debug_options.xla_generate_hlo_graph())) {
    graph_url =
        DumpGraph(*module.entry_computation(), label, debug_options, profile);
  }
  if (!debug_options.xla_log_hlo_text().empty() &&
      RE2::PartialMatch(module.name(), debug_options.xla_log_hlo_text())) {
    LOG(INFO) << "HLO for module " << module.name();
    LOG(INFO) << "Label: " << label;
    XLA_LOG_LINES(2, module.ToString());
  }
  if (!debug_options.xla_generate_hlo_text_to().empty()) {
    DumpText(module, label, debug_options.xla_generate_hlo_text_to());
  }
  return graph_url;
}

}  // namespace hlo_graph_dumper
}  // namespace xla