1 | /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_VALUE_H_ |
17 | #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_VALUE_H_ |
18 | |
19 | #include <ostream> |
20 | #include <string> |
21 | #include <vector> |
22 | |
23 | #include "tensorflow/compiler/xla/service/hlo_instruction.h" |
24 | #include "tensorflow/compiler/xla/shape_tree.h" |
25 | #include "tensorflow/compiler/xla/types.h" |
26 | #include "tensorflow/compiler/xla/xla_data.pb.h" |
27 | #include "tensorflow/core/lib/gtl/array_slice.h" |
28 | #include "tensorflow/core/platform/macros.h" |
29 | |
30 | namespace xla { |
31 | |
32 | // Abstraction which identifies a specific point in the XLA graph. An |
33 | // HloPosition specifies a ShapeIndex within the output of a specific |
34 | // instruction. |
35 | struct HloPosition { |
36 | HloInstruction* instruction; |
37 | ShapeIndex index; |
38 | |
39 | // Returns the shape at this position. |
40 | const Shape& shape() const; |
41 | |
42 | string ToString() const; |
43 | |
44 | bool operator==(const HloPosition& other) const { |
45 | return instruction == other.instruction && index == other.index; |
46 | } |
47 | bool operator!=(const HloPosition& other) const { return !(*this == other); } |
48 | |
49 | // Stable less-than operator using instruction id and index. |
50 | bool operator<(const HloPosition& other) const { |
51 | return instruction->unique_id() < other.instruction->unique_id() || |
52 | (instruction->unique_id() == other.instruction->unique_id() && |
53 | index < other.index); |
54 | } |
55 | }; |
56 | |
// Stream-insertion overload for HloPosition. Only declared in this header; the
// definition is provided out of line.
std::ostream& operator<<(std::ostream& out, const HloPosition& position);
58 | |
59 | // Defines a single use of an HLO value. |
60 | struct HloUse { |
61 | // Instruction at which the value is used. |
62 | HloInstruction* instruction; |
63 | |
64 | // The operand number in which the value is appears. |
65 | int64 operand_number; |
66 | |
67 | // The shape index within the operand in which the value appears. |
68 | ShapeIndex operand_index; |
69 | |
70 | string ToString() const; |
71 | |
72 | bool operator==(const HloUse& other) const { |
73 | return instruction == other.instruction && |
74 | operand_number == other.operand_number && |
75 | operand_index == other.operand_index; |
76 | } |
77 | |
78 | bool operator!=(const HloUse& other) const { return !(*this == other); } |
79 | }; |
80 | |
// Stream-insertion overload for HloUse. Only declared in this header; the
// definition is provided out of line.
std::ostream& operator<<(std::ostream& out, const HloUse& use);
82 | |
83 | // Class describing a value used by the dataflow analysis. XLA arrays are |
84 | // trivially a single HloValue. Tuples are made up of more than one HloValue: an |
85 | // HloValue for the pointer vector, and an HloValue for each child element. |
86 | // |
87 | // Every HloValue is defined by a particular instruction and most instructions |
88 | // define only a single HloValue. Instructions which define a single HloValue |
89 | // include array-shaped instructions such as Add but also includes Tuple-shaped |
90 | // instructions such as Tuple. The Tuple instruction defines a single HloValue |
91 | // which is a vector of pointers to the values containing the Tuple |
92 | // instruction's operands. Though the result of the Tuple instruction includes |
93 | // multiple values only the top-level HloValue (the vector of pointers) is |
94 | // defined by the Tuple instruction. The values containing the tuple elements |
95 | // are defined by earlier instructions, usually the operands of the Tuple |
96 | // instruction. |
97 | // |
98 | // Instructions which construct both the tuple *and* the tuple elements define |
99 | // more than one HloValue. This includes (at least) tuple-shaped Constant, |
100 | // Parameter, Infeed and While instructions. These tuple-shaped instructions do |
101 | // not assemble a tuple from existing HloValues like the Tuple instruction does, |
102 | // but rather define all the HloValues in the tuple. |
103 | class HloValue { |
104 | public: |
105 | using Id = int64; |
106 | |
107 | // Predicate comparing HloValues by increasing id, useful for std::sort. |
108 | static bool IdLessThan(const HloValue* a, const HloValue* b) { |
109 | return a->id() < b->id(); |
110 | } |
111 | |
112 | // Predicate comparing HloValues by equal id, useful for std::unique. |
113 | static bool IdEqual(const HloValue* a, const HloValue* b) { |
114 | return a->id() == b->id(); |
115 | } |
116 | |
117 | // Construct an HloValue defined by 'instruction' at shape index 'index'. If |
118 | // is_phi is true, then this value is a phi value, for example, at the |
119 | // parameter of a while body computation. Phi values are only used in the SSA |
120 | // dataflow analysis (HloDataflowAnalysis::ssa_form_ is true). |
121 | HloValue(Id id, HloInstruction* instruction, const ShapeIndex& index, |
122 | bool is_phi = false); |
123 | |
124 | // Sets the positions in the module at which the HloValue appears. Updates |
125 | // uses. Should be called once and only once. The defining position should not |
126 | // be included in 'positions' as this is set at construction time. |
127 | void SetPositionsAndComputeUses( |
128 | tensorflow::gtl::ArraySlice<HloPosition> positions); |
129 | |
130 | // Return a unique identifier for this HloValue. This value is used for stable |
131 | // sorting and iteration |
132 | Id id() const { return id_; } |
133 | |
134 | // Returns whether this value is a phi value. |
135 | bool is_phi() const { return is_phi_; } |
136 | |
137 | // Return the position where this value is defined. |
138 | const HloPosition& defining_position() const { return positions_[0]; } |
139 | |
140 | // Return the instruction which defines this HloValue. |
141 | HloInstruction* defining_instruction() const { |
142 | return defining_position().instruction; |
143 | } |
144 | |
145 | // Return the shape index at which this HloValue is defined in the output of |
146 | // its defining instruction. |
147 | const ShapeIndex& defining_index() const { return defining_position().index; } |
148 | |
149 | // Return the shape of this HloValue. |
150 | const Shape& shape() const { return defining_position().shape(); } |
151 | |
152 | // Return all positions of the HloValue in the module. |
153 | const std::vector<HloPosition>& positions() const { return positions_; } |
154 | |
155 | // Return all uses of the HloValue. |
156 | const std::vector<HloUse>& uses() const { return uses_; } |
157 | |
158 | // Get whether this HloValue is live out of the module. |
159 | bool live_out_of_module() const { return live_out_of_module_; } |
160 | |
161 | bool operator==(const HloValue& other) const; |
162 | bool operator!=(const HloValue& other) const; |
163 | |
164 | // Return a single-line string representation of the value. |
165 | string ToShortString() const; |
166 | |
167 | string ToString(int indent = 0) const; |
168 | |
169 | private: |
170 | // Unique identifier for this HloValue. Used for stable sorting and iteration. |
171 | const Id id_; |
172 | |
173 | // Whether this instruction is a phi value. |
174 | const bool is_phi_; |
175 | |
176 | // The set of positions of this HloValue. The first element is always the |
177 | // position of the definition. |
178 | std::vector<HloPosition> positions_; |
179 | |
180 | // The set of uses of this HloValue. |
181 | std::vector<HloUse> uses_; |
182 | |
183 | // Whether this value is live out of the HLO module. |
184 | bool live_out_of_module_ = false; |
185 | |
186 | // Whether this value is live out of its computation. |
187 | bool live_out_of_computation_ = false; |
188 | }; |
189 | |
// Stream-insertion overload for HloValue. Only declared in this header; the
// definition is provided out of line.
std::ostream& operator<<(std::ostream& out, const HloValue& hlo_value);
191 | |
192 | // A class representing the possible set of HloValues at a particular point |
193 | // (shape index in the output of an instruction) in the XLA graph. This set |
194 | // contains the set of reaching HloValue definitions. For a simple array-shaped |
195 | // instruction like Add, the HloValueSet of the top-level of the instruction's |
196 | // output trivially contains only the HloValue defined by the instruction. For |
197 | // instructions which have non-trivial dataflow such as Tuple or Select, the |
198 | // HloValueSets of the instruction's output contains one or more HloValues |
199 | // defined by the instruction's operands or defined further up in the XLA graph. |
200 | class HloValueSet { |
201 | public: |
202 | HloValueSet() = default; |
203 | |
204 | explicit HloValueSet(tensorflow::gtl::ArraySlice<const HloValue*> values) |
205 | : values_(values.begin(), values.end()) { |
206 | SortAndUniquifyValues(); |
207 | } |
208 | |
209 | // Sets this value set to the union of the given value sets. Returns whether |
210 | // this value set changed. |
211 | bool AssignUnionOf(tensorflow::gtl::ArraySlice<const HloValueSet*> inputs); |
212 | |
213 | // Return the vector of HloValues in the set. Values in the vector are unique |
214 | // and stably sorted by value id. |
215 | const std::vector<const HloValue*>& values() const { return values_; } |
216 | |
217 | // Adds the value to the set. Returns true iff the value was added and didn't |
218 | // already exist in the set. |
219 | bool AddValue(const HloValue* value); |
220 | |
221 | // Clear all values from the set. |
222 | void Clear() { values_.clear(); } |
223 | |
224 | // Return the unique HLO value in the set. CHECKs if the set does not contain |
225 | // exactly one value. |
226 | const HloValue& GetUniqueValue() const { |
227 | CHECK_EQ(values_.size(), 1); |
228 | return *values_[0]; |
229 | } |
230 | |
231 | bool operator==(const HloValueSet& other) const { |
232 | if (values_.size() != other.values_.size()) return false; |
233 | for (size_t i = 0; i < values_.size(); ++i) { |
234 | if (values_[i]->id() != other.values_[i]->id()) { |
235 | return false; |
236 | } |
237 | } |
238 | return true; |
239 | } |
240 | bool operator!=(const HloValueSet& other) const { return !(*this == other); } |
241 | |
242 | string ToString() const; |
243 | |
244 | private: |
245 | // Sorts value_ and removes duplicates. This should be called after adding any |
246 | // elements to values_. |
247 | void SortAndUniquifyValues(); |
248 | |
249 | // HloValues sorted by HloValue::Id. |
250 | std::vector<const HloValue*> values_; |
251 | }; |
252 | |
// Stream-insertion overload for HloValueSet. Only declared in this header; the
// definition is provided out of line.
std::ostream& operator<<(std::ostream& out, const HloValueSet& hlo_value);
254 | |
255 | // A class collecting the HloValues which might be contained in the output of |
256 | // an HLO instruction. For array-shaped instructions, an InstructionValueSet |
257 | // trivially holds a single HloValueSet. Tuple-shaped InstructionValueSets |
258 | // hold multiple HloValueSets. |
259 | class InstructionValueSet : public ShapeTree<HloValueSet> { |
260 | public: |
261 | InstructionValueSet(const Shape& shape) : ShapeTree<HloValueSet>(shape) {} |
262 | |
263 | // Sets this value set to the union of the given value sets. Returns whether |
264 | // this value set changed. |
265 | bool AssignUnionOf( |
266 | tensorflow::gtl::ArraySlice<const InstructionValueSet*> inputs); |
267 | |
268 | string ToString() const; |
269 | }; |
270 | |
// Stream-insertion overload for InstructionValueSet. Only declared in this
// header; the definition is provided out of line.
std::ostream& operator<<(std::ostream& out,
                         const InstructionValueSet& instruction_value_set);
273 | |
274 | } // namespace xla |
275 | |
276 | #endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_VALUE_H_ |
277 | |