1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_H_
17#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_H_
18
19#include <atomic>
20#include <list>
21#include <memory>
22#include <random>
23#include <string>
24#include <unordered_map>
25#include <vector>
26
27#include "tensorflow/compiler/xla/iterator_util.h"
28#include "tensorflow/compiler/xla/service/hlo.pb.h"
29#include "tensorflow/compiler/xla/service/hlo_computation.h"
30#include "tensorflow/compiler/xla/service/hlo_instruction.h"
31#include "tensorflow/compiler/xla/service/hlo_module_config.h"
32#include "tensorflow/compiler/xla/service/name_uniquer.h"
33#include "tensorflow/compiler/xla/service/versioned_computation_handle.h"
34#include "tensorflow/compiler/xla/types.h"
35#include "tensorflow/core/lib/gtl/array_slice.h"
36#include "tensorflow/core/lib/gtl/iterator_range.h"
37#include "tensorflow/core/platform/logging.h"
38#include "tensorflow/core/platform/mutex.h"
39
40namespace xla {
41
42// Describes a compilation unit at the HLO level.
43//
44// A HLO module contains one or more HLO computations. The module contains one
45// "entry" computation which produces the result. The module also includes any
46// embedded computations used by instructions such as "map" and "reduce". All
47// computations are owned by the module.
48class HloModule {
49 public:
50 HloModule(const string& name,
51 const VersionedComputationHandle& entry_computation_handle,
52 const HloModuleConfig& config);
53
54 // Constructor without a versioned computation handle. This constructor should
55 // only be used for HloModules used outside of the XLA service (eg
56 // tests). The versioned handle is used by the service in the compilation
57 // cache. A default configuration is created for this module.
58 explicit HloModule(const string& name);
59 explicit HloModule(const string& name, const HloModuleConfig& config);
60
61 // Adds an entry computation to the module. A module can only have one entry
62 // computation. Returns a pointer to the newly added computation.
63 HloComputation* AddEntryComputation(
64 std::unique_ptr<HloComputation> computation);
65
66 // Adds an embedded computation to the module.
67 HloComputation* AddEmbeddedComputation(
68 std::unique_ptr<HloComputation> computation);
69
70 // Removes an embedded computation.
71 Status RemoveEmbeddedComputation(HloComputation* to_remove);
72
73 // Replaces all uses of computations that are keys of 'replacements' with
74 // the corresponding values in 'replacements'. Replaces the entry computation,
75 // if applicable.
76 //
77 // This function iterates over all instructions in the module to find
78 // computations to replace. We could speed it up by keeping track of users of
79 // computations.
80 void ReplaceComputations(
81 const std::unordered_map<HloComputation*, HloComputation*>& replacements);
82
83 const string& name() const { return name_; }
84
85 // Returns a deep copy of this module including all computations.
86 std::unique_ptr<HloModule> Clone(const string& suffix = "clone") const;
87
88 // Performs a deep clone of the computation, by recursively cloning all
89 // the called computations as well.
90 HloComputation* DeepCloneComputation(HloComputation* computation);
91
92 // Return a pointer to the entry computation of the module..
93 const HloComputation* entry_computation() const {
94 CHECK_NE(nullptr, entry_computation_);
95 return entry_computation_;
96 }
97 HloComputation* entry_computation() {
98 CHECK_NE(nullptr, entry_computation_);
99 return entry_computation_;
100 }
101
102 ComputationLayout* mutable_entry_computation_layout() {
103 return config_.mutable_entry_computation_layout();
104 }
105
106 const ComputationLayout& entry_computation_layout() const {
107 return config_.entry_computation_layout();
108 }
109
110 const VersionedComputationHandle& entry_computation_handle() const {
111 return entry_computation_handle_;
112 }
113
114 // Gets the computations in this module.
115 //
116 // Returns a view of HloComputation*s, so you can iterate over this in the
117 // natural way:
118 //
119 // for (HloComputation* c : module->computations()) { ... }
120 //
121 tensorflow::gtl::iterator_range<UnwrappingIterator<
122 std::vector<std::unique_ptr<HloComputation>>::const_iterator>>
123 computations() const {
124 return {MakeUnwrappingIterator(computations_.begin()),
125 MakeUnwrappingIterator(computations_.end())};
126 }
127 tensorflow::gtl::iterator_range<UnwrappingIterator<
128 std::vector<std::unique_ptr<HloComputation>>::iterator>>
129 computations() {
130 return {MakeUnwrappingIterator(computations_.begin()),
131 MakeUnwrappingIterator(computations_.end())};
132 }
133
134 // Gets the number of computations in this module.
135 int64 computation_count() const { return computations_.size(); }
136
137 // Gets the number of instructions in this module.
138 int64 instruction_count() const;
139
140 // Compute and return a post order of all computations in the module. The sort
141 // is defined like so: if computation A has an instruction which calls
142 // computation B, then A will appear after B in the sort.
143 std::list<HloComputation*> MakeComputationPostOrder() const;
144
145 // Gets the computations in this module which aren't for fusion nodes.
146 //
147 // Postcondition: All computations in the returned list have
148 // !IsFusionComputation().
149 //
150 // Note: Callers can and do rely on the return value here being a *snapshot*
151 // of the module's non-fusion computations -- that is, it's OK to add or
152 // remove computations from a module while iterating over
153 // MakeNonfusionComputations().
154 std::vector<HloComputation*> MakeNonfusionComputations() const;
155
156 const HloModuleConfig& config() const { return config_; }
157
158 // Return a string representation of the module.
159 //
160 // (We express the default options using an overload rather than a default
161 // param because gdb ignores default params, but does resolve overloads.)
162 string ToString() const { return ToString(HloPrintOptions()); }
163 string ToString(const HloPrintOptions& options) const;
164
165 // Convert an HloModule to or from a proto.
166 HloModuleProto ToProto() const;
167 static StatusOr<std::unique_ptr<HloModule>> CreateFromProto(
168 const HloModuleProto& proto, const HloModuleConfig& module_config,
169 const VersionedComputationHandle& entry_computation_handle =
170 VersionedComputationHandle());
171
172 // Creates and returns an HloModuleConfig with an appropriate program shape
173 // for the HLO module in the given proto.
174 static StatusOr<HloModuleConfig> CreateModuleConfigFromProto(
175 const HloModuleProto& module, const DebugOptions& debug_options);
176
177 // Outlines the given expression from the given computation.
178 // instructions_to_outline contains the instructions that form the expression.
179 //
180 // Precondition: instructions in instructions_to_outline are in topological
181 // order (root of outlined instructions last). TODO(jingyue): takes a set of
182 // instructions and topologically sorts them.
183 HloInstruction* OutlineExpressionFromComputation(
184 tensorflow::gtl::ArraySlice<HloInstruction*> instructions_to_outline,
185 const string& outlined_computation_name, HloComputation* computation);
186
187 // Returns a randomly generated uint64.
188 uint64 RandomNew64() const;
189
190 // Returns the NameUniquer for uniquing instruction names in this module.
191 NameUniquer& instruction_name_uniquer() { return instruction_name_uniquer_; }
192
193 // Assign a new unique dense id for an instruction
194 int NewUniqueInstructionId() {
195 int result = next_unique_id_;
196 next_unique_id_++;
197 return result;
198 }
199
200 // Returns the number of unique intruction ids given out. All ids up to
201 // this point are guaranteed to be in the range [0..NumUniqueInstructionIds())
202 int NumUniqueInstructionIds() const { return next_unique_id_; }
203
204 // Returns an id that is unique to this module across all modules created over
205 // the lifetime of this process.
206 int unique_id() const { return unique_id_; }
207
208 private:
209 HloComputation* AddComputationInternal(
210 std::unique_ptr<HloComputation> computation, bool is_entry,
211 bool uniquify_names);
212
213 const string name_;
214 HloModuleConfig config_;
215 HloComputation* entry_computation_ = nullptr;
216 std::vector<std::unique_ptr<HloComputation>> computations_;
217
218 // Random number generator engine to use when generating random numbers per
219 // HloModule compilation.
220 // TODO(b/25995601): Replace with better seed setting or dev/random for
221 // where we don't need deterministic execution.
222 mutable std::mt19937_64 rng_{42};
223 mutable tensorflow::mutex rng_mutex_;
224
225 // Versioned handle of the entry computation of the module.
226 bool has_entry_computation_handle_ = false;
227 VersionedComputationHandle entry_computation_handle_;
228
229 // Unique name generator for computation and instruction names, which are
230 // unique per module.
231 NameUniquer computation_name_uniquer_{/*separator=*/"."};
232 NameUniquer instruction_name_uniquer_{/*separator=*/"."};
233 int next_unique_id_ = 0;
234
235 // Used to keep track of the next unique module id that should be assigned.
236 static std::atomic<int> next_unique_module_id_;
237 // A unique id to label modules with.
238 int unique_id_;
239};
240
241} // namespace xla
242
243#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_H_
244