1 | /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "tensorflow/compiler/xla/service/hlo_module.h" |
17 | |
18 | #include <iterator> |
19 | #include <set> |
20 | #include <sstream> |
21 | #include <unordered_map> |
22 | #include <unordered_set> |
23 | #include <utility> |
24 | |
25 | #include "tensorflow/compiler/xla/map_util.h" |
26 | #include "tensorflow/compiler/xla/ptr_util.h" |
27 | #include "tensorflow/compiler/xla/shape_util.h" |
28 | #include "tensorflow/compiler/xla/types.h" |
29 | #include "tensorflow/core/lib/gtl/map_util.h" |
30 | #include "tensorflow/core/lib/strings/strcat.h" |
31 | #include "tensorflow/core/platform/types.h" |
32 | |
33 | namespace xla { |
34 | |
35 | HloModule::HloModule(const string& name, |
36 | const VersionedComputationHandle& entry_computation_handle, |
37 | const HloModuleConfig& config) |
38 | : name_(NameUniquer::GetSanitizedName(name)), |
39 | config_(config), |
40 | has_entry_computation_handle_(true), |
41 | entry_computation_handle_(entry_computation_handle), |
42 | unique_id_(next_unique_module_id_++) {} |
43 | |
44 | HloModule::HloModule(const string& name) |
45 | : name_(NameUniquer::GetSanitizedName(name)), |
      unique_id_(next_unique_module_id_++) {}

HloModule::HloModule(const string& name, const HloModuleConfig& config)
48 | : name_(NameUniquer::GetSanitizedName(name)), |
49 | config_(config), |
50 | unique_id_(next_unique_module_id_++) {} |
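
// Usage sketch (illustrative only, not exercised in this file): constructing
// a module with an explicit config. The ProgramShape here is a placeholder.
//
//   ProgramShape program_shape = ...;  // shapes of the entry params/result
//   HloModuleConfig config(program_shape);
//   HloModule module("my_module", config);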
51 | |
52 | HloComputation* HloModule::AddComputationInternal( |
53 | std::unique_ptr<HloComputation> computation, bool is_entry, |
54 | bool uniquify_names) { |
55 | if (is_entry) { |
56 | CHECK_EQ(nullptr, entry_computation_); |
57 | entry_computation_ = computation.get(); |
58 | |
    // If the module configuration has no entry computation layout set, create
    // a default one based on the program shape.
61 | if (!config_.has_entry_computation_layout()) { |
62 | config_.SetDefaultComputationLayout( |
63 | entry_computation_->ComputeProgramShape()); |
64 | } |
65 | } |
66 | |
67 | if (uniquify_names) { |
68 | computation->UniquifyName(&computation_name_uniquer_); |
69 | for (auto* instruction : computation->instructions()) { |
70 | instruction->UniquifyName(&instruction_name_uniquer_); |
71 | } |
72 | } else { |
    // Don't uniquify the names of the computation or its instructions, but
    // run the names through the uniquifiers anyway to prevent future name
    // collisions for computations and instructions created later.
76 | computation_name_uniquer_.GetUniqueName(computation->name()); |
77 | for (auto* instruction : computation->instructions()) { |
78 | instruction_name_uniquer_.GetUniqueName(instruction->name()); |
79 | } |
80 | } |
81 | |
  // Assign a unique id to each instruction.
83 | for (auto* instruction : computation->instructions()) { |
84 | instruction->SetUniqueId(NewUniqueInstructionId()); |
85 | } |
  // Give the computation the same unique id as its root instruction.
87 | CHECK_NE(computation->root_instruction()->unique_id(), -1) |
88 | << "Root has no valid id: " << computation->ToString(); |
89 | computation->SetUniqueId(computation->root_instruction()->unique_id()); |
90 | |
91 | computation->set_parent(this); |
92 | computations_.push_back(std::move(computation)); |
93 | return computations_.back().get(); |
94 | } |
95 | |
96 | HloComputation* HloModule::AddEntryComputation( |
97 | std::unique_ptr<HloComputation> computation) { |
98 | return AddComputationInternal(std::move(computation), /*is_entry=*/true, |
99 | /*uniquify_names=*/true); |
100 | } |
101 | |
102 | Status HloModule::RemoveEmbeddedComputation(HloComputation* to_remove) { |
103 | auto it = |
104 | std::find_if(computations_.begin(), computations_.end(), |
105 | [&to_remove](const std::unique_ptr<HloComputation>& comp) { |
106 | return comp.get() == to_remove; |
107 | }); |
  TF_RET_CHECK(it != computations_.end());
109 | computations_.erase(it); |
110 | return Status::OK(); |
111 | } |
112 | |
113 | HloComputation* HloModule::AddEmbeddedComputation( |
114 | std::unique_ptr<HloComputation> computation) { |
115 | return AddComputationInternal(std::move(computation), /*is_entry=*/false, |
116 | /*uniquify_names=*/true); |
117 | } |
118 | |
119 | void HloModule::ReplaceComputations( |
120 | const std::unordered_map<HloComputation*, HloComputation*>& replacements) { |
121 | // Replace all uses of non-canonical computations with their |
122 | // representatives. |
123 | std::vector<std::unique_ptr<HloComputation>> new_computations; |
124 | new_computations.reserve(computations_.size()); |
125 | |
126 | for (std::unique_ptr<HloComputation>& computation : computations_) { |
127 | for (auto* instruction : computation->instructions()) { |
128 | switch (instruction->opcode()) { |
129 | case HloOpcode::kCall: |
130 | case HloOpcode::kMap: |
131 | case HloOpcode::kReduce: |
132 | case HloOpcode::kReduceWindow: { |
133 | HloComputation* new_arg = tensorflow::gtl::FindWithDefault( |
134 | replacements, instruction->to_apply(), nullptr); |
135 | if (new_arg != nullptr) { |
136 | instruction->set_to_apply(new_arg); |
137 | } |
138 | break; |
139 | } |
140 | case HloOpcode::kWhile: { |
141 | HloComputation* new_condition = tensorflow::gtl::FindWithDefault( |
142 | replacements, instruction->while_condition(), nullptr); |
143 | if (new_condition != nullptr) { |
144 | instruction->set_while_condition(new_condition); |
145 | } |
146 | HloComputation* new_body = tensorflow::gtl::FindWithDefault( |
147 | replacements, instruction->while_body(), nullptr); |
148 | if (new_body != nullptr) { |
149 | instruction->set_while_body(new_body); |
150 | } |
151 | break; |
152 | } |
153 | case HloOpcode::kConditional: { |
154 | HloComputation* new_true_computation = |
155 | tensorflow::gtl::FindWithDefault( |
156 | replacements, instruction->true_computation(), nullptr); |
157 | if (new_true_computation != nullptr) { |
158 | instruction->set_true_computation(new_true_computation); |
159 | } |
160 | HloComputation* new_false_computation = |
161 | tensorflow::gtl::FindWithDefault( |
162 | replacements, instruction->false_computation(), nullptr); |
163 | if (new_false_computation != nullptr) { |
164 | instruction->set_false_computation(new_false_computation); |
165 | } |
166 | break; |
167 | } |
168 | case HloOpcode::kSelectAndScatter: { |
169 | HloComputation* new_select = tensorflow::gtl::FindWithDefault( |
170 | replacements, instruction->select(), nullptr); |
171 | if (new_select != nullptr) { |
172 | instruction->set_select(new_select); |
173 | } |
174 | HloComputation* new_scatter = tensorflow::gtl::FindWithDefault( |
175 | replacements, instruction->scatter(), nullptr); |
176 | if (new_scatter != nullptr) { |
177 | instruction->set_scatter(new_scatter); |
178 | } |
179 | break; |
180 | } |
181 | default: |
182 | break; |
183 | } |
184 | } |
185 | |
186 | if (replacements.find(computation.get()) == replacements.end()) { |
187 | new_computations.push_back(std::move(computation)); |
188 | } |
189 | } |
190 | |
191 | // Replace entry_computation if necessary. |
192 | entry_computation_ = tensorflow::gtl::FindWithDefault( |
193 | replacements, entry_computation_, entry_computation_); |
194 | |
195 | computations_ = std::move(new_computations); |
196 | } |
197 | |
198 | string HloModule::ToString(const HloPrintOptions& options) const { |
199 | std::ostringstream s; |
200 | s << "HloModule " << name() << "\n\n" ; |
201 | for (const HloComputation* computation : MakeComputationPostOrder()) { |
202 | if (computation == entry_computation()) { |
203 | s << "ENTRY " ; |
204 | } |
205 | s << computation->ToString(options) << "\n\n" ; |
206 | } |
207 | return s.str(); |
208 | } |
209 | |
210 | HloModuleProto HloModule::ToProto() const { |
211 | HloModuleProto proto; |
212 | proto.set_id(unique_id_); |
213 | proto.set_name(name_); |
214 | proto.set_entry_computation_name(entry_computation_->name()); |
215 | proto.set_entry_computation_id(entry_computation_->unique_id()); |
216 | for (const HloComputation* computation : MakeComputationPostOrder()) { |
217 | HloComputationProto computation_proto = computation->ToProto(); |
218 | if (computation->name() == entry_computation_->name()) { |
219 | *proto.mutable_program_shape() = computation_proto.program_shape(); |
220 | } |
221 | proto.add_computations()->Swap(&computation_proto); |
222 | } |
223 | return proto; |
224 | } |
225 | |
226 | /* static */ |
227 | StatusOr<std::unique_ptr<HloModule>> HloModule::CreateFromProto( |
228 | const HloModuleProto& proto, const HloModuleConfig& module_config, |
229 | const VersionedComputationHandle& entry_computation_handle) { |
  // The ProgramShape in the passed-in module config must match the shapes of
231 | // the entry parameters and root. |
232 | TF_RET_CHECK(proto.has_program_shape()) |
233 | << "No program shape found in the proto" ; |
234 | const auto& expected_program_shape = proto.program_shape(); |
235 | TF_RET_CHECK(expected_program_shape.parameters_size() == |
236 | module_config.entry_computation_layout().parameter_count()); |
237 | for (int i = 0; i < expected_program_shape.parameters_size(); ++i) { |
238 | const Shape& parameter_shape = |
239 | module_config.entry_computation_layout().parameter_layout(i).shape(); |
240 | TF_RET_CHECK(ShapeUtil::Compatible(expected_program_shape.parameters(i), |
241 | parameter_shape)) |
242 | << "HloModuleConfig has different shape for parameter " << i |
243 | << " than the HLO module. Expected: " |
244 | << ShapeUtil::HumanStringWithLayout( |
245 | expected_program_shape.parameters(i)) |
246 | << ", actual: " << ShapeUtil::HumanStringWithLayout(parameter_shape); |
247 | } |
248 | const Shape& result_shape = |
249 | module_config.entry_computation_layout().result_layout().shape(); |
250 | TF_RET_CHECK( |
251 | ShapeUtil::Compatible(expected_program_shape.result(), result_shape)) |
252 | << "HloModuleConfig has different result shape than the HLO module. " |
253 | "Expected: " |
254 | << ShapeUtil::HumanStringWithLayout(expected_program_shape.result()) |
255 | << ", actual: " << ShapeUtil::HumanStringWithLayout(result_shape); |
256 | |
257 | auto module = MakeUnique<HloModule>(proto.name(), entry_computation_handle, |
258 | module_config); |
259 | |
260 | tensorflow::gtl::FlatMap<int64, HloComputation*> computation_map; |
261 | for (const HloComputationProto& computation_proto : proto.computations()) { |
262 | TF_ASSIGN_OR_RETURN(std::unique_ptr<HloComputation> computation, |
263 | HloComputation::CreateFromProto( |
264 | module.get(), computation_proto, computation_map)); |
265 | CHECK_NE(computation.get(), nullptr); |
266 | int64 computation_id = computation_proto.id(); |
267 | TF_RET_CHECK(computation_id != -1); |
268 | TF_RET_CHECK(!ContainsKey(computation_map, computation_id)); |
269 | // Don't uniquify names because we want names to be stable across |
270 | // serialization and deserialization. |
271 | computation_map[computation_id] = module->AddComputationInternal( |
272 | std::move(computation), |
273 | /*is_entry=*/proto.entry_computation_id() == computation_id, |
274 | /*uniquify_names=*/false); |
275 | } |
276 | TF_RET_CHECK(module->entry_computation_ != nullptr); |
277 | |
  // Because we didn't uniquify the names, double-check that the computation
  // and instruction names from the proto are unique.
280 | tensorflow::gtl::FlatSet<string> computation_names; |
281 | tensorflow::gtl::FlatSet<string> instruction_names; |
282 | for (HloComputation* computation : module->computations()) { |
283 | TF_RET_CHECK(!ContainsKey(computation_names, computation->name())) |
284 | << "Computation name is not unique: " << computation->name(); |
285 | computation_names.insert(computation->name()); |
286 | for (HloInstruction* instruction : computation->instructions()) { |
287 | TF_RET_CHECK(!ContainsKey(instruction_names, instruction->name())) |
288 | << "Instruction name is not unique: " << instruction->name(); |
289 | instruction_names.insert(instruction->name()); |
290 | } |
291 | } |
292 | |
293 | return std::move(module); |
294 | } |
295 | |
296 | /* static */ |
297 | StatusOr<HloModuleConfig> HloModule::CreateModuleConfigFromProto( |
298 | const HloModuleProto& module, const DebugOptions& debug_options) { |
299 | TF_RET_CHECK(module.has_program_shape()) |
300 | << "No program shape found in the proto" ; |
301 | const auto& program_shape = module.program_shape(); |
302 | |
303 | HloModuleConfig module_config(program_shape); |
304 | module_config.set_debug_options(debug_options); |
305 | |
306 | // The module config is constructed with default layouts regardless of what is |
307 | // passed in via the ProgramShape. Set the layouts to the appropriate values. |
308 | ComputationLayout* entry_layout = |
309 | module_config.mutable_entry_computation_layout(); |
310 | for (int64 i = 0; i < entry_layout->parameter_count(); ++i) { |
311 | TF_RETURN_IF_ERROR( |
312 | entry_layout->mutable_parameter_layout(i)->CopyLayoutFromShape( |
313 | program_shape.parameters(i))); |
314 | } |
315 | TF_RETURN_IF_ERROR(entry_layout->mutable_result_layout()->CopyLayoutFromShape( |
316 | program_shape.result())); |
317 | |
318 | return module_config; |
319 | } |
320 | |
321 | namespace { |
322 | // Returns whether `hlo` is used outside the given subcomputation. |
323 | // `instructions_in_subcomputation` is the instruction set of the given |
324 | // subcomputation. |
325 | bool IsUsedOutsideSubcomputation( |
326 | const HloInstruction& hlo, |
327 | const std::unordered_set<HloInstruction*>& instructions_in_subcomputation) { |
328 | for (HloInstruction* user : hlo.users()) { |
329 | if (!instructions_in_subcomputation.count(user)) { |
330 | return true; |
331 | } |
332 | } |
333 | return false; |
334 | } |
335 | } // anonymous namespace |
336 | |
337 | HloInstruction* HloModule::OutlineExpressionFromComputation( |
338 | tensorflow::gtl::ArraySlice<HloInstruction*> instructions_to_outline, |
339 | const string& outlined_computation_name, HloComputation* computation) { |
340 | auto builder = HloComputation::Builder(outlined_computation_name); |
341 | |
342 | // A map from original instructions to their counterparts in the new outlined |
343 | // function. |
344 | std::unordered_map<HloInstruction*, HloInstruction*> outlined_instructions; |
345 | // A set that contains all instructions to be outlined. |
346 | std::unordered_set<HloInstruction*> instruction_set_to_outline( |
347 | instructions_to_outline.begin(), instructions_to_outline.end()); |
348 | std::vector<HloInstruction*> arguments; |
349 | std::vector<HloInstruction*> outputs; |
350 | int64 parameter_count = 0; |
351 | for (HloInstruction* instruction_to_outline : instructions_to_outline) { |
352 | // Clone the original instruction. |
353 | HloInstruction* outlined_instruction = |
354 | builder.AddInstruction(instruction_to_outline->Clone()); |
355 | |
    // Replace its operands with their counterparts in the new function.
357 | for (int64 operand_num = 0; |
358 | operand_num < outlined_instruction->operand_count(); ++operand_num) { |
359 | HloInstruction* old_operand = |
360 | outlined_instruction->mutable_operand(operand_num); |
361 | |
362 | HloInstruction** operand_slot = &(outlined_instructions[old_operand]); |
363 | if (*operand_slot == nullptr) { |
364 | // Because instructions_to_outline is in topological order, if |
365 | // old_operand is not in outlined_instructions, old_operand must be an |
366 | // input of the outlined subcomputation and thus should be represented |
367 | // as a parameter in the new function. |
368 | arguments.push_back(old_operand); |
369 | *operand_slot = builder.AddInstruction(HloInstruction::CreateParameter( |
            parameter_count, old_operand->shape(), ""));
371 | ++parameter_count; |
372 | } |
373 | TF_CHECK_OK( |
374 | outlined_instruction->ReplaceOperandWith(operand_num, *operand_slot)); |
375 | } |
376 | |
377 | // Insert the new instruction into the outlined_instructions map. |
378 | InsertOrDie(&outlined_instructions, instruction_to_outline, |
379 | outlined_instruction); |
380 | |
    // Mark instruction_to_outline as an output if it is used outside the
    // subcomputation or has no users at all (i.e. it is the root of the
    // original computation).
384 | if (instruction_to_outline->user_count() == 0 || |
385 | IsUsedOutsideSubcomputation(*instruction_to_outline, |
386 | instruction_set_to_outline)) { |
387 | outputs.push_back(instruction_to_outline); |
388 | } |
389 | } |
390 | |
391 | if (outputs.size() != 1) { |
392 | string error_message = |
393 | "The subcomputation to outline has multiple outputs:\n" ; |
394 | for (HloInstruction* output : outputs) { |
      tensorflow::strings::StrAppend(&error_message, output->ToString(), "\n");
396 | } |
397 | LOG(FATAL) << error_message; |
398 | } |
399 | HloInstruction* output = outputs[0]; |
400 | |
  // Create a call to the nested computation.
402 | HloComputation* nested_computation = AddEmbeddedComputation( |
403 | builder.Build(FindOrDie(outlined_instructions, output))); |
404 | HloInstruction* call = computation->AddInstruction(HloInstruction::CreateCall( |
405 | output->shape(), arguments, nested_computation)); |
406 | |
407 | VLOG(2) << "Outlining the following instructions" ; |
408 | for (auto* instruction_to_outline : instructions_to_outline) { |
409 | VLOG(2) << " " << instruction_to_outline->ToString(); |
410 | } |
411 | VLOG(2) << "as a call " << call->ToString(); |
412 | VLOG(2) << "to " << nested_computation->ToString(); |
413 | |
414 | TF_CHECK_OK(output->ReplaceAllUsesWith(call)); |
415 | for (auto i = instructions_to_outline.rbegin(); |
416 | i != instructions_to_outline.rend(); ++i) { |
417 | TF_CHECK_OK(computation->RemoveInstruction(*i)); |
418 | } |
419 | |
420 | return call; |
421 | } |
422 | |
423 | int64 HloModule::instruction_count() const { |
424 | int64 n = 0; |
425 | for (const auto& computation : computations_) { |
426 | n += computation->instruction_count(); |
427 | } |
428 | return n; |
429 | } |
430 | |
431 | std::list<HloComputation*> HloModule::MakeComputationPostOrder() const { |
432 | // First determine all root computations by building a set of nonroot |
433 | // computations (computations which are called by an instruction in the |
434 | // module). |
435 | std::set<HloComputation*> nonroot_computations; |
436 | for (auto& computation : computations_) { |
437 | for (auto* instruction : computation->instructions()) { |
438 | for (HloComputation* called_computation : |
439 | instruction->called_computations()) { |
440 | nonroot_computations.insert(called_computation); |
441 | } |
442 | } |
443 | } |
444 | |
445 | // Keep track of computations which have already been added to the post |
446 | // order. This prevents duplication as an embedded computation may be called |
447 | // from two different root computations. |
448 | std::set<HloComputation*> added_computations; |
449 | std::list<HloComputation*> post_order; |
450 | for (auto& computation : computations_) { |
451 | if (nonroot_computations.count(computation.get()) == 0) { |
452 | for (HloComputation* embedded_computation : |
453 | computation->MakeEmbeddedComputationsList()) { |
454 | if (added_computations.count(embedded_computation) == 0) { |
455 | post_order.push_back(embedded_computation); |
456 | added_computations.insert(embedded_computation); |
457 | } |
458 | } |
459 | // Root computations should only be encountered once. |
460 | CHECK_EQ(0, added_computations.count(computation.get())); |
461 | post_order.push_back(computation.get()); |
462 | added_computations.insert(computation.get()); |
463 | } |
464 | } |
465 | CHECK_EQ(post_order.size(), computations_.size()); |
466 | return post_order; |
467 | } |
468 | |
469 | std::vector<HloComputation*> HloModule::MakeNonfusionComputations() const { |
470 | std::vector<HloComputation*> result; |
471 | for (auto* c : computations()) { |
472 | if (c->IsFusionComputation()) { |
473 | continue; |
474 | } |
475 | result.push_back(c); |
476 | } |
477 | return result; |
478 | } |
479 | |
480 | std::unique_ptr<HloModule> HloModule::Clone(const string& suffix) const { |
481 | VLOG(1) << "Cloning module :" << name_ << " --> " << suffix << "\n" ; |
482 | auto module = MakeUnique<HloModule>(name_ + "-" + suffix); |
483 | module->config_ = config_; |
484 | module->entry_computation_handle_ = entry_computation_handle_; |
485 | module->has_entry_computation_handle_ = has_entry_computation_handle_; |
486 | |
487 | std::unordered_map<HloComputation*, HloComputation*> clone_map; |
488 | for (auto& computation : computations_) { |
489 | if (computation->IsFusionComputation()) { |
490 | // Cloning of a fused computation is handled by its fusion instruction. |
491 | continue; |
492 | } |
493 | |
494 | // When cloning a computation, pass in the new module, so that for any |
495 | // fusion instruction in this computation, the fused computation will be |
496 | // deep cloned to the new module. |
497 | auto cloned_computation = computation->Clone(suffix, module.get()); |
498 | InsertOrDie(&clone_map, computation.get(), cloned_computation.get()); |
499 | |
500 | if (entry_computation_ == computation.get()) { |
501 | module->AddEntryComputation(std::move(cloned_computation)); |
502 | } else { |
503 | module->AddEmbeddedComputation(std::move(cloned_computation)); |
504 | } |
505 | } |
506 | |
507 | for (auto& cloned_computation : module->computations_) { |
508 | for (auto* instruction : cloned_computation->instructions()) { |
      // Rewrite the instruction's called computations to point at the cloned
      // computations.
511 | instruction->ReplaceCalledComputations([&](HloComputation* hlo) { |
512 | if (hlo->IsFusionComputation()) { |
513 | // Cloning of a fused computation has already been handled when its |
514 | // fusion instruction is cloned. So this hlo computation is already |
515 | // the cloned one. |
516 | return hlo; |
517 | } |
518 | return FindOrDie(clone_map, hlo); |
519 | }); |
520 | } |
521 | } |
522 | return module; |
523 | } |
524 | |
525 | HloComputation* HloModule::DeepCloneComputation(HloComputation* computation) { |
  HloComputation* clone = AddEmbeddedComputation(computation->Clone("", this));
527 | TF_CHECK_OK( |
528 | clone->root_instruction()->Accept([this](HloInstruction* instruction) { |
529 | instruction->ReplaceCalledComputations([this](HloComputation* callee) { |
530 | return DeepCloneComputation(callee); |
531 | }); |
532 | return Status::OK(); |
533 | })); |
534 | return clone; |
535 | } |
536 | |
537 | uint64 HloModule::RandomNew64() const { |
538 | tensorflow::mutex_lock l(rng_mutex_); |
539 | return rng_(); |
540 | } |
541 | |
542 | /* static */ std::atomic<int> HloModule::next_unique_module_id_(0); |
543 | |
544 | } // namespace xla |
545 | |