1 | /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_H_ |
17 | #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_H_ |
18 | |
19 | #include <atomic> |
20 | #include <list> |
21 | #include <memory> |
22 | #include <random> |
23 | #include <string> |
24 | #include <unordered_map> |
25 | #include <vector> |
26 | |
27 | #include "tensorflow/compiler/xla/iterator_util.h" |
28 | #include "tensorflow/compiler/xla/service/hlo.pb.h" |
29 | #include "tensorflow/compiler/xla/service/hlo_computation.h" |
30 | #include "tensorflow/compiler/xla/service/hlo_instruction.h" |
31 | #include "tensorflow/compiler/xla/service/hlo_module_config.h" |
32 | #include "tensorflow/compiler/xla/service/name_uniquer.h" |
33 | #include "tensorflow/compiler/xla/service/versioned_computation_handle.h" |
34 | #include "tensorflow/compiler/xla/types.h" |
35 | #include "tensorflow/core/lib/gtl/array_slice.h" |
36 | #include "tensorflow/core/lib/gtl/iterator_range.h" |
37 | #include "tensorflow/core/platform/logging.h" |
38 | #include "tensorflow/core/platform/mutex.h" |
39 | |
40 | namespace xla { |
41 | |
42 | // Describes a compilation unit at the HLO level. |
43 | // |
44 | // A HLO module contains one or more HLO computations. The module contains one |
45 | // "entry" computation which produces the result. The module also includes any |
46 | // embedded computations used by instructions such as "map" and "reduce". All |
47 | // computations are owned by the module. |
48 | class HloModule { |
49 | public: |
50 | HloModule(const string& name, |
51 | const VersionedComputationHandle& entry_computation_handle, |
52 | const HloModuleConfig& config); |
53 | |
54 | // Constructor without a versioned computation handle. This constructor should |
55 | // only be used for HloModules used outside of the XLA service (eg |
56 | // tests). The versioned handle is used by the service in the compilation |
57 | // cache. A default configuration is created for this module. |
58 | explicit HloModule(const string& name); |
59 | explicit HloModule(const string& name, const HloModuleConfig& config); |
60 | |
61 | // Adds an entry computation to the module. A module can only have one entry |
62 | // computation. Returns a pointer to the newly added computation. |
63 | HloComputation* AddEntryComputation( |
64 | std::unique_ptr<HloComputation> computation); |
65 | |
66 | // Adds an embedded computation to the module. |
67 | HloComputation* AddEmbeddedComputation( |
68 | std::unique_ptr<HloComputation> computation); |
69 | |
70 | // Removes an embedded computation. |
71 | Status RemoveEmbeddedComputation(HloComputation* to_remove); |
72 | |
73 | // Replaces all uses of computations that are keys of 'replacements' with |
74 | // the corresponding values in 'replacements'. Replaces the entry computation, |
75 | // if applicable. |
76 | // |
77 | // This function iterates over all instructions in the module to find |
78 | // computations to replace. We could speed it up by keeping track of users of |
79 | // computations. |
80 | void ReplaceComputations( |
81 | const std::unordered_map<HloComputation*, HloComputation*>& replacements); |
82 | |
83 | const string& name() const { return name_; } |
84 | |
85 | // Returns a deep copy of this module including all computations. |
86 | std::unique_ptr<HloModule> Clone(const string& suffix = "clone" ) const; |
87 | |
88 | // Performs a deep clone of the computation, by recursively cloning all |
89 | // the called computations as well. |
90 | HloComputation* DeepCloneComputation(HloComputation* computation); |
91 | |
92 | // Return a pointer to the entry computation of the module.. |
93 | const HloComputation* entry_computation() const { |
94 | CHECK_NE(nullptr, entry_computation_); |
95 | return entry_computation_; |
96 | } |
97 | HloComputation* entry_computation() { |
98 | CHECK_NE(nullptr, entry_computation_); |
99 | return entry_computation_; |
100 | } |
101 | |
102 | ComputationLayout* mutable_entry_computation_layout() { |
103 | return config_.mutable_entry_computation_layout(); |
104 | } |
105 | |
106 | const ComputationLayout& entry_computation_layout() const { |
107 | return config_.entry_computation_layout(); |
108 | } |
109 | |
110 | const VersionedComputationHandle& entry_computation_handle() const { |
111 | return entry_computation_handle_; |
112 | } |
113 | |
114 | // Gets the computations in this module. |
115 | // |
116 | // Returns a view of HloComputation*s, so you can iterate over this in the |
117 | // natural way: |
118 | // |
119 | // for (HloComputation* c : module->computations()) { ... } |
120 | // |
121 | tensorflow::gtl::iterator_range<UnwrappingIterator< |
122 | std::vector<std::unique_ptr<HloComputation>>::const_iterator>> |
123 | computations() const { |
124 | return {MakeUnwrappingIterator(computations_.begin()), |
125 | MakeUnwrappingIterator(computations_.end())}; |
126 | } |
127 | tensorflow::gtl::iterator_range<UnwrappingIterator< |
128 | std::vector<std::unique_ptr<HloComputation>>::iterator>> |
129 | computations() { |
130 | return {MakeUnwrappingIterator(computations_.begin()), |
131 | MakeUnwrappingIterator(computations_.end())}; |
132 | } |
133 | |
134 | // Gets the number of computations in this module. |
135 | int64 computation_count() const { return computations_.size(); } |
136 | |
137 | // Gets the number of instructions in this module. |
138 | int64 instruction_count() const; |
139 | |
140 | // Compute and return a post order of all computations in the module. The sort |
141 | // is defined like so: if computation A has an instruction which calls |
142 | // computation B, then A will appear after B in the sort. |
143 | std::list<HloComputation*> MakeComputationPostOrder() const; |
144 | |
145 | // Gets the computations in this module which aren't for fusion nodes. |
146 | // |
147 | // Postcondition: All computations in the returned list have |
148 | // !IsFusionComputation(). |
149 | // |
150 | // Note: Callers can and do rely on the return value here being a *snapshot* |
151 | // of the module's non-fusion computations -- that is, it's OK to add or |
152 | // remove computations from a module while iterating over |
153 | // MakeNonfusionComputations(). |
154 | std::vector<HloComputation*> MakeNonfusionComputations() const; |
155 | |
156 | const HloModuleConfig& config() const { return config_; } |
157 | |
158 | // Return a string representation of the module. |
159 | // |
160 | // (We express the default options using an overload rather than a default |
161 | // param because gdb ignores default params, but does resolve overloads.) |
162 | string ToString() const { return ToString(HloPrintOptions()); } |
163 | string ToString(const HloPrintOptions& options) const; |
164 | |
165 | // Convert an HloModule to or from a proto. |
166 | HloModuleProto ToProto() const; |
167 | static StatusOr<std::unique_ptr<HloModule>> CreateFromProto( |
168 | const HloModuleProto& proto, const HloModuleConfig& module_config, |
169 | const VersionedComputationHandle& entry_computation_handle = |
170 | VersionedComputationHandle()); |
171 | |
172 | // Creates and returns an HloModuleConfig with an appropriate program shape |
173 | // for the HLO module in the given proto. |
174 | static StatusOr<HloModuleConfig> CreateModuleConfigFromProto( |
175 | const HloModuleProto& module, const DebugOptions& debug_options); |
176 | |
177 | // Outlines the given expression from the given computation. |
178 | // instructions_to_outline contains the instructions that form the expression. |
179 | // |
180 | // Precondition: instructions in instructions_to_outline are in topological |
181 | // order (root of outlined instructions last). TODO(jingyue): takes a set of |
182 | // instructions and topologically sorts them. |
183 | HloInstruction* OutlineExpressionFromComputation( |
184 | tensorflow::gtl::ArraySlice<HloInstruction*> instructions_to_outline, |
185 | const string& outlined_computation_name, HloComputation* computation); |
186 | |
187 | // Returns a randomly generated uint64. |
188 | uint64 RandomNew64() const; |
189 | |
190 | // Returns the NameUniquer for uniquing instruction names in this module. |
191 | NameUniquer& instruction_name_uniquer() { return instruction_name_uniquer_; } |
192 | |
193 | // Assign a new unique dense id for an instruction |
194 | int NewUniqueInstructionId() { |
195 | int result = next_unique_id_; |
196 | next_unique_id_++; |
197 | return result; |
198 | } |
199 | |
200 | // Returns the number of unique intruction ids given out. All ids up to |
201 | // this point are guaranteed to be in the range [0..NumUniqueInstructionIds()) |
202 | int NumUniqueInstructionIds() const { return next_unique_id_; } |
203 | |
204 | // Returns an id that is unique to this module across all modules created over |
205 | // the lifetime of this process. |
206 | int unique_id() const { return unique_id_; } |
207 | |
208 | private: |
209 | HloComputation* AddComputationInternal( |
210 | std::unique_ptr<HloComputation> computation, bool is_entry, |
211 | bool uniquify_names); |
212 | |
213 | const string name_; |
214 | HloModuleConfig config_; |
215 | HloComputation* entry_computation_ = nullptr; |
216 | std::vector<std::unique_ptr<HloComputation>> computations_; |
217 | |
218 | // Random number generator engine to use when generating random numbers per |
219 | // HloModule compilation. |
220 | // TODO(b/25995601): Replace with better seed setting or dev/random for |
221 | // where we don't need deterministic execution. |
222 | mutable std::mt19937_64 rng_{42}; |
223 | mutable tensorflow::mutex rng_mutex_; |
224 | |
225 | // Versioned handle of the entry computation of the module. |
226 | bool has_entry_computation_handle_ = false; |
227 | VersionedComputationHandle entry_computation_handle_; |
228 | |
229 | // Unique name generator for computation and instruction names, which are |
230 | // unique per module. |
231 | NameUniquer computation_name_uniquer_{/*separator=*/"." }; |
232 | NameUniquer instruction_name_uniquer_{/*separator=*/"." }; |
233 | int next_unique_id_ = 0; |
234 | |
235 | // Used to keep track of the next unique module id that should be assigned. |
236 | static std::atomic<int> next_unique_module_id_; |
237 | // A unique id to label modules with. |
238 | int unique_id_; |
239 | }; |
240 | |
241 | } // namespace xla |
242 | |
243 | #endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_H_ |
244 | |