/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/hlo_module.h"

#include <iterator>
#include <set>
#include <sstream>
#include <unordered_map>
#include <unordered_set>
#include <utility>

#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/types.h"

namespace xla {

HloModule::HloModule(const string& name,
                     const VersionedComputationHandle& entry_computation_handle,
                     const HloModuleConfig& config)
    : name_(NameUniquer::GetSanitizedName(name)),
      config_(config),
      has_entry_computation_handle_(true),
      entry_computation_handle_(entry_computation_handle),
      unique_id_(next_unique_module_id_++) {}

HloModule::HloModule(const string& name)
    : name_(NameUniquer::GetSanitizedName(name)),
      unique_id_(next_unique_module_id_++) {}
HloModule::HloModule(const string& name, const HloModuleConfig& config)
    : name_(NameUniquer::GetSanitizedName(name)),
      config_(config),
      unique_id_(next_unique_module_id_++) {}

HloComputation* HloModule::AddComputationInternal(
    std::unique_ptr<HloComputation> computation, bool is_entry,
    bool uniquify_names) {
  if (is_entry) {
    CHECK_EQ(nullptr, entry_computation_);
    entry_computation_ = computation.get();

    // If the module configuration has no entry computation layout set, create
    // a default one based on the program shape.
    if (!config_.has_entry_computation_layout()) {
      config_.SetDefaultComputationLayout(
          entry_computation_->ComputeProgramShape());
    }
  }

  if (uniquify_names) {
    computation->UniquifyName(&computation_name_uniquer_);
    for (auto* instruction : computation->instructions()) {
      instruction->UniquifyName(&instruction_name_uniquer_);
    }
  } else {
    // Don't uniquify the names of the computation or its instructions, but we
    // must still run the names through the uniquifiers to prevent collisions
    // with computations and instructions created later.
    computation_name_uniquer_.GetUniqueName(computation->name());
    for (auto* instruction : computation->instructions()) {
      instruction_name_uniquer_.GetUniqueName(instruction->name());
    }
  }

  // Pick unique IDs for each instruction.
  for (auto* instruction : computation->instructions()) {
    instruction->SetUniqueId(NewUniqueInstructionId());
  }
  // Assign the computation a unique id, reusing the root instruction's id.
  CHECK_NE(computation->root_instruction()->unique_id(), -1)
      << "Root has no valid id: " << computation->ToString();
  computation->SetUniqueId(computation->root_instruction()->unique_id());

  computation->set_parent(this);
  computations_.push_back(std::move(computation));
  return computations_.back().get();
}

HloComputation* HloModule::AddEntryComputation(
    std::unique_ptr<HloComputation> computation) {
  return AddComputationInternal(std::move(computation), /*is_entry=*/true,
                                /*uniquify_names=*/true);
}

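// Removes `to_remove` from the module and destroys it. Callers are
// responsible for ensuring that no remaining instruction still calls the
// removed computation.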
Status HloModule::RemoveEmbeddedComputation(HloComputation* to_remove) {
  auto it =
      std::find_if(computations_.begin(), computations_.end(),
                   [&to_remove](const std::unique_ptr<HloComputation>& comp) {
                     return comp.get() == to_remove;
                   });
  // Fail cleanly (rather than dereferencing end()) if the computation is not
  // owned by this module.
  TF_RET_CHECK(it != computations_.end());
  computations_.erase(it);
  return Status::OK();
}

HloComputation* HloModule::AddEmbeddedComputation(
    std::unique_ptr<HloComputation> computation) {
  return AddComputationInternal(std::move(computation), /*is_entry=*/false,
                                /*uniquify_names=*/true);
}

void HloModule::ReplaceComputations(
    const std::unordered_map<HloComputation*, HloComputation*>& replacements) {
  // Replace all uses of non-canonical computations with their
  // representatives.
  std::vector<std::unique_ptr<HloComputation>> new_computations;
  new_computations.reserve(computations_.size());

  for (std::unique_ptr<HloComputation>& computation : computations_) {
    for (auto* instruction : computation->instructions()) {
      switch (instruction->opcode()) {
        case HloOpcode::kCall:
        case HloOpcode::kMap:
        case HloOpcode::kReduce:
        case HloOpcode::kReduceWindow: {
          HloComputation* new_arg = tensorflow::gtl::FindWithDefault(
              replacements, instruction->to_apply(), nullptr);
          if (new_arg != nullptr) {
            instruction->set_to_apply(new_arg);
          }
          break;
        }
        case HloOpcode::kWhile: {
          HloComputation* new_condition = tensorflow::gtl::FindWithDefault(
              replacements, instruction->while_condition(), nullptr);
          if (new_condition != nullptr) {
            instruction->set_while_condition(new_condition);
          }
          HloComputation* new_body = tensorflow::gtl::FindWithDefault(
              replacements, instruction->while_body(), nullptr);
          if (new_body != nullptr) {
            instruction->set_while_body(new_body);
          }
          break;
        }
        case HloOpcode::kConditional: {
          HloComputation* new_true_computation =
              tensorflow::gtl::FindWithDefault(
                  replacements, instruction->true_computation(), nullptr);
          if (new_true_computation != nullptr) {
            instruction->set_true_computation(new_true_computation);
          }
          HloComputation* new_false_computation =
              tensorflow::gtl::FindWithDefault(
                  replacements, instruction->false_computation(), nullptr);
          if (new_false_computation != nullptr) {
            instruction->set_false_computation(new_false_computation);
          }
          break;
        }
        case HloOpcode::kSelectAndScatter: {
          HloComputation* new_select = tensorflow::gtl::FindWithDefault(
              replacements, instruction->select(), nullptr);
          if (new_select != nullptr) {
            instruction->set_select(new_select);
          }
          HloComputation* new_scatter = tensorflow::gtl::FindWithDefault(
              replacements, instruction->scatter(), nullptr);
          if (new_scatter != nullptr) {
            instruction->set_scatter(new_scatter);
          }
          break;
        }
        default:
          break;
      }
    }

    if (replacements.find(computation.get()) == replacements.end()) {
      new_computations.push_back(std::move(computation));
    }
  }

  // Replace entry_computation if necessary.
  entry_computation_ = tensorflow::gtl::FindWithDefault(
      replacements, entry_computation_, entry_computation_);

  computations_ = std::move(new_computations);
}

string HloModule::ToString(const HloPrintOptions& options) const {
  std::ostringstream s;
  s << "HloModule " << name() << "\n\n";
  for (const HloComputation* computation : MakeComputationPostOrder()) {
    if (computation == entry_computation()) {
      s << "ENTRY ";
    }
    s << computation->ToString(options) << "\n\n";
  }
  return s.str();
}

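// Computations are serialized in post order so that callees always precede
// their callers; the entry computation's program shape doubles as the
// module-level program shape.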
HloModuleProto HloModule::ToProto() const {
  HloModuleProto proto;
  proto.set_id(unique_id_);
  proto.set_name(name_);
  proto.set_entry_computation_name(entry_computation_->name());
  proto.set_entry_computation_id(entry_computation_->unique_id());
  for (const HloComputation* computation : MakeComputationPostOrder()) {
    HloComputationProto computation_proto = computation->ToProto();
    if (computation->name() == entry_computation_->name()) {
      *proto.mutable_program_shape() = computation_proto.program_shape();
    }
    proto.add_computations()->Swap(&computation_proto);
  }
  return proto;
}

/* static */
StatusOr<std::unique_ptr<HloModule>> HloModule::CreateFromProto(
    const HloModuleProto& proto, const HloModuleConfig& module_config,
    const VersionedComputationHandle& entry_computation_handle) {
  // The ProgramShape in the passed-in module config must match the shapes of
  // the entry parameters and root.
  TF_RET_CHECK(proto.has_program_shape())
      << "No program shape found in the proto";
  const auto& expected_program_shape = proto.program_shape();
  TF_RET_CHECK(expected_program_shape.parameters_size() ==
               module_config.entry_computation_layout().parameter_count());
  for (int i = 0; i < expected_program_shape.parameters_size(); ++i) {
    const Shape& parameter_shape =
        module_config.entry_computation_layout().parameter_layout(i).shape();
    TF_RET_CHECK(ShapeUtil::Compatible(expected_program_shape.parameters(i),
                                       parameter_shape))
        << "HloModuleConfig has a different shape for parameter " << i
        << " than the HLO module. Expected: "
        << ShapeUtil::HumanStringWithLayout(
               expected_program_shape.parameters(i))
        << ", actual: " << ShapeUtil::HumanStringWithLayout(parameter_shape);
  }
  const Shape& result_shape =
      module_config.entry_computation_layout().result_layout().shape();
  TF_RET_CHECK(
      ShapeUtil::Compatible(expected_program_shape.result(), result_shape))
      << "HloModuleConfig has a different result shape than the HLO module. "
         "Expected: "
      << ShapeUtil::HumanStringWithLayout(expected_program_shape.result())
      << ", actual: " << ShapeUtil::HumanStringWithLayout(result_shape);

  auto module = MakeUnique<HloModule>(proto.name(), entry_computation_handle,
                                      module_config);

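  // Maps a computation's proto id to the deserialized computation, so that
  // computations created later in the loop can resolve the ids of the
  // computations they call.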
  tensorflow::gtl::FlatMap<int64, HloComputation*> computation_map;
  for (const HloComputationProto& computation_proto : proto.computations()) {
    TF_ASSIGN_OR_RETURN(std::unique_ptr<HloComputation> computation,
                        HloComputation::CreateFromProto(
                            module.get(), computation_proto, computation_map));
    CHECK_NE(computation.get(), nullptr);
    int64 computation_id = computation_proto.id();
    TF_RET_CHECK(computation_id != -1);
    TF_RET_CHECK(!ContainsKey(computation_map, computation_id));
    // Don't uniquify names, because we want names to be stable across
    // serialization and deserialization.
    computation_map[computation_id] = module->AddComputationInternal(
        std::move(computation),
        /*is_entry=*/proto.entry_computation_id() == computation_id,
        /*uniquify_names=*/false);
  }
  TF_RET_CHECK(module->entry_computation_ != nullptr);

  // Because we didn't uniquify the names, double-check that the computation
  // and instruction names from the proto are unique.
  tensorflow::gtl::FlatSet<string> computation_names;
  tensorflow::gtl::FlatSet<string> instruction_names;
  for (HloComputation* computation : module->computations()) {
    TF_RET_CHECK(!ContainsKey(computation_names, computation->name()))
        << "Computation name is not unique: " << computation->name();
    computation_names.insert(computation->name());
    for (HloInstruction* instruction : computation->instructions()) {
      TF_RET_CHECK(!ContainsKey(instruction_names, instruction->name()))
          << "Instruction name is not unique: " << instruction->name();
      instruction_names.insert(instruction->name());
    }
  }

  return std::move(module);
}

/* static */
StatusOr<HloModuleConfig> HloModule::CreateModuleConfigFromProto(
    const HloModuleProto& module, const DebugOptions& debug_options) {
  TF_RET_CHECK(module.has_program_shape())
      << "No program shape found in the proto";
  const auto& program_shape = module.program_shape();

  HloModuleConfig module_config(program_shape);
  module_config.set_debug_options(debug_options);

  // The module config is constructed with default layouts regardless of what
  // is passed in via the ProgramShape. Set the layouts to the appropriate
  // values.
  ComputationLayout* entry_layout =
      module_config.mutable_entry_computation_layout();
  for (int64 i = 0; i < entry_layout->parameter_count(); ++i) {
    TF_RETURN_IF_ERROR(
        entry_layout->mutable_parameter_layout(i)->CopyLayoutFromShape(
            program_shape.parameters(i)));
  }
  TF_RETURN_IF_ERROR(entry_layout->mutable_result_layout()->CopyLayoutFromShape(
      program_shape.result()));

  return module_config;
}

namespace {
// Returns whether `hlo` is used outside the given subcomputation.
// `instructions_in_subcomputation` is the instruction set of the given
// subcomputation.
bool IsUsedOutsideSubcomputation(
    const HloInstruction& hlo,
    const std::unordered_set<HloInstruction*>& instructions_in_subcomputation) {
  for (HloInstruction* user : hlo.users()) {
    if (!instructions_in_subcomputation.count(user)) {
      return true;
    }
  }
  return false;
}
}  // anonymous namespace

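// A sketch of how outlining rewrites the graph (`mul`, `add`, and `entry` are
// hypothetical instruction/computation pointers; the instructions must be
// given in topological order and form a single-output expression):
//
//   HloInstruction* call = module->OutlineExpressionFromComputation(
//       {mul, add}, "outlined_mul_add", entry);
//
// Afterwards `entry` contains a kCall instruction in place of `mul` and
// `add`, and the new embedded computation computes the same value.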
HloInstruction* HloModule::OutlineExpressionFromComputation(
    tensorflow::gtl::ArraySlice<HloInstruction*> instructions_to_outline,
    const string& outlined_computation_name, HloComputation* computation) {
  auto builder = HloComputation::Builder(outlined_computation_name);

  // A map from original instructions to their counterparts in the new outlined
  // function.
  std::unordered_map<HloInstruction*, HloInstruction*> outlined_instructions;
  // A set that contains all instructions to be outlined.
  std::unordered_set<HloInstruction*> instruction_set_to_outline(
      instructions_to_outline.begin(), instructions_to_outline.end());
  std::vector<HloInstruction*> arguments;
  std::vector<HloInstruction*> outputs;
  int64 parameter_count = 0;
  for (HloInstruction* instruction_to_outline : instructions_to_outline) {
    // Clone the original instruction.
    HloInstruction* outlined_instruction =
        builder.AddInstruction(instruction_to_outline->Clone());

    // Replace its operands with their counterparts in the new function.
    for (int64 operand_num = 0;
         operand_num < outlined_instruction->operand_count(); ++operand_num) {
      HloInstruction* old_operand =
          outlined_instruction->mutable_operand(operand_num);

      HloInstruction** operand_slot = &(outlined_instructions[old_operand]);
      if (*operand_slot == nullptr) {
        // Because instructions_to_outline is in topological order, if
        // old_operand is not in outlined_instructions, old_operand must be an
        // input of the outlined subcomputation and thus should be represented
        // as a parameter in the new function.
        arguments.push_back(old_operand);
        *operand_slot = builder.AddInstruction(HloInstruction::CreateParameter(
            parameter_count, old_operand->shape(), ""));
        ++parameter_count;
      }
      TF_CHECK_OK(
          outlined_instruction->ReplaceOperandWith(operand_num, *operand_slot));
    }

    // Insert the new instruction into the outlined_instructions map.
    InsertOrDie(&outlined_instructions, instruction_to_outline,
                outlined_instruction);

    // Mark instruction_to_outline as an output if it is used outside the
    // subcomputation or is the output of the original computation (i.e. it has
    // no users at all).
    if (instruction_to_outline->user_count() == 0 ||
        IsUsedOutsideSubcomputation(*instruction_to_outline,
                                    instruction_set_to_outline)) {
      outputs.push_back(instruction_to_outline);
    }
  }

  if (outputs.size() != 1) {
    string error_message =
        "The subcomputation to outline has multiple outputs:\n";
    for (HloInstruction* output : outputs) {
      tensorflow::strings::StrAppend(&error_message, output->ToString(), "\n");
    }
    LOG(FATAL) << error_message;
  }
  HloInstruction* output = outputs[0];

  // Create a call to the nested computation.
  HloComputation* nested_computation = AddEmbeddedComputation(
      builder.Build(FindOrDie(outlined_instructions, output)));
  HloInstruction* call = computation->AddInstruction(HloInstruction::CreateCall(
      output->shape(), arguments, nested_computation));

  VLOG(2) << "Outlining the following instructions";
  for (auto* instruction_to_outline : instructions_to_outline) {
    VLOG(2) << "  " << instruction_to_outline->ToString();
  }
  VLOG(2) << "as a call " << call->ToString();
  VLOG(2) << "to " << nested_computation->ToString();

  TF_CHECK_OK(output->ReplaceAllUsesWith(call));
  // Remove the outlined instructions in reverse topological order, so each
  // instruction has no remaining users when it is removed.
  for (auto i = instructions_to_outline.rbegin();
       i != instructions_to_outline.rend(); ++i) {
    TF_CHECK_OK(computation->RemoveInstruction(*i));
  }

  return call;
}

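// Counts instructions across every computation in the module, including
// fusion computations.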
int64 HloModule::instruction_count() const {
  int64 n = 0;
  for (const auto& computation : computations_) {
    n += computation->instruction_count();
  }
  return n;
}

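// Returns the computations in a post order: every computation appears after
// all of the computations it (transitively) calls.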
std::list<HloComputation*> HloModule::MakeComputationPostOrder() const {
  // First determine all root computations by building a set of nonroot
  // computations (computations which are called by an instruction in the
  // module).
  std::set<HloComputation*> nonroot_computations;
  for (auto& computation : computations_) {
    for (auto* instruction : computation->instructions()) {
      for (HloComputation* called_computation :
           instruction->called_computations()) {
        nonroot_computations.insert(called_computation);
      }
    }
  }

  // Keep track of computations which have already been added to the post
  // order. This prevents duplication as an embedded computation may be called
  // from two different root computations.
  std::set<HloComputation*> added_computations;
  std::list<HloComputation*> post_order;
  for (auto& computation : computations_) {
    if (nonroot_computations.count(computation.get()) == 0) {
      for (HloComputation* embedded_computation :
           computation->MakeEmbeddedComputationsList()) {
        if (added_computations.count(embedded_computation) == 0) {
          post_order.push_back(embedded_computation);
          added_computations.insert(embedded_computation);
        }
      }
      // Root computations should only be encountered once.
      CHECK_EQ(0, added_computations.count(computation.get()));
      post_order.push_back(computation.get());
      added_computations.insert(computation.get());
    }
  }
  CHECK_EQ(post_order.size(), computations_.size());
  return post_order;
}

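// Fusion computations are handled through their fusion instructions, so
// passes typically iterate over the non-fusion computations returned here.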
std::vector<HloComputation*> HloModule::MakeNonfusionComputations() const {
  std::vector<HloComputation*> result;
  for (auto* c : computations()) {
    if (c->IsFusionComputation()) {
      continue;
    }
    result.push_back(c);
  }
  return result;
}

std::unique_ptr<HloModule> HloModule::Clone(const string& suffix) const {
  VLOG(1) << "Cloning module: " << name_ << " --> " << suffix << "\n";
  auto module = MakeUnique<HloModule>(name_ + "-" + suffix);
  module->config_ = config_;
  module->entry_computation_handle_ = entry_computation_handle_;
  module->has_entry_computation_handle_ = has_entry_computation_handle_;

  std::unordered_map<HloComputation*, HloComputation*> clone_map;
  for (auto& computation : computations_) {
    if (computation->IsFusionComputation()) {
      // Cloning of a fused computation is handled by its fusion instruction.
      continue;
    }

    // When cloning a computation, pass in the new module, so that for any
    // fusion instruction in this computation, the fused computation will be
    // deep cloned to the new module.
    auto cloned_computation = computation->Clone(suffix, module.get());
    InsertOrDie(&clone_map, computation.get(), cloned_computation.get());

    if (entry_computation_ == computation.get()) {
      module->AddEntryComputation(std::move(cloned_computation));
    } else {
      module->AddEmbeddedComputation(std::move(cloned_computation));
    }
  }

  for (auto& cloned_computation : module->computations_) {
    for (auto* instruction : cloned_computation->instructions()) {
      // Rewrite the instruction's called computations to point to the cloned
      // ones.
      instruction->ReplaceCalledComputations([&](HloComputation* hlo) {
        if (hlo->IsFusionComputation()) {
          // Cloning of a fused computation was already handled when its
          // fusion instruction was cloned, so this computation is already
          // the clone.
          return hlo;
        }
        return FindOrDie(clone_map, hlo);
      });
    }
  }
  return module;
}

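// Recursively deep-clones `computation` and every computation it calls into
// this module, rewriting called-computation pointers to the new clones.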
HloComputation* HloModule::DeepCloneComputation(HloComputation* computation) {
  HloComputation* clone = AddEmbeddedComputation(computation->Clone("", this));
  TF_CHECK_OK(
      clone->root_instruction()->Accept([this](HloInstruction* instruction) {
        instruction->ReplaceCalledComputations([this](HloComputation* callee) {
          return DeepCloneComputation(callee);
        });
        return Status::OK();
      }));
  return clone;
}

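// The RNG is shared by all callers, so serialize access with the mutex; this
// keeps RandomNew64() safe to call from multiple threads.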
uint64 HloModule::RandomNew64() const {
  tensorflow::mutex_lock l(rng_mutex_);
  return rng_();
}

/* static */ std::atomic<int> HloModule::next_unique_module_id_(0);

}  // namespace xla