1 | /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_COST_ANALYSIS_H_ |
17 | #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_COST_ANALYSIS_H_ |
18 | |
19 | #include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h" |
20 | #include "tensorflow/compiler/xla/service/hlo_computation.h" |
21 | #include "tensorflow/compiler/xla/service/hlo_instruction.h" |
22 | #include "tensorflow/compiler/xla/service/hlo_opcode.h" |
23 | #include "tensorflow/compiler/xla/shape_util.h" |
24 | #include "tensorflow/compiler/xla/statusor.h" |
25 | #include "tensorflow/compiler/xla/xla_data.pb.h" |
26 | #include "tensorflow/core/lib/gtl/array_slice.h" |
27 | #include "tensorflow/core/platform/macros.h" |
28 | #include "tensorflow/core/platform/types.h" |
29 | |
30 | namespace xla { |
31 | |
32 | // HloCostAnalysis traverses an HLO graph and calculates the amount of |
33 | // computations required for the graph. Each HLO instruction handler provides |
34 | // the computation cost of the instruction, and the values are accumulated |
35 | // during the traversal for the entire graph. We treat normal floating point |
36 | // operations separately from transcendental operations. |
37 | class HloCostAnalysis : public ConstDfsHloVisitor { |
38 | public: |
39 | // Each HLO is associated to a vector of properties with the indices given |
40 | // below. Sub-classes can add further properties. |
41 | typedef std::map<string, float> Properties; |
42 | static constexpr char kFlopsKey[] = "flops" ; |
43 | static constexpr char kTranscendentalsKey[] = "transcendentals" ; |
44 | static constexpr char kBytesAccessedKey[] = "bytes accessed" ; |
45 | static constexpr char kOptimalSecondsKey[] = "optimal_seconds" ; |
46 | |
47 | // shape_size is a function which returns the size in bytes of the top-level |
48 | // buffer of a shape. |
49 | using ShapeSizeFunction = std::function<int64(const Shape&)>; |
50 | explicit HloCostAnalysis(const ShapeSizeFunction& shape_size); |
51 | |
52 | Status HandleElementwiseUnary(const HloInstruction* hlo) override; |
53 | Status HandleElementwiseBinary(const HloInstruction* hlo) override; |
54 | Status HandleConstant(const HloInstruction* constant) override; |
55 | Status HandleGetTupleElement( |
56 | const HloInstruction* get_tuple_element) override; |
57 | Status HandleSelect(const HloInstruction* select) override; |
58 | Status HandleCompare(const HloInstruction* compare) override; |
59 | Status HandleClamp(const HloInstruction* clamp) override; |
60 | Status HandleReducePrecision(const HloInstruction* hlo) override; |
61 | Status HandleConcatenate(const HloInstruction* concatenate) override; |
62 | Status HandleSend(const HloInstruction* send) override; |
63 | Status HandleSendDone(const HloInstruction* send_done) override; |
64 | Status HandleRecv(const HloInstruction* recv) override; |
65 | Status HandleRecvDone(const HloInstruction* recv_done) override; |
66 | Status HandleConvert(const HloInstruction* convert) override; |
67 | Status HandleCopy(const HloInstruction* copy) override; |
68 | Status HandleDot(const HloInstruction* dot) override; |
69 | Status HandleConvolution(const HloInstruction* convolution) override; |
70 | Status HandleFft(const HloInstruction* fft) override; |
71 | Status HandleCrossReplicaSum(const HloInstruction* crs) override; |
72 | Status HandleInfeed(const HloInstruction* infeed) override; |
73 | Status HandleOutfeed(const HloInstruction* outfeed) override; |
74 | Status HandleHostCompute(const HloInstruction* host_compute) override; |
75 | Status HandleRng(const HloInstruction* random) override; |
76 | Status HandleReverse(const HloInstruction* reverse) override; |
77 | Status HandleSort(const HloInstruction* sort) override; |
78 | Status HandleParameter(const HloInstruction* parameter) override; |
79 | Status HandleReduce(const HloInstruction* reduce) override; |
80 | Status HandleBatchNormTraining( |
81 | const HloInstruction* batch_norm_training) override; |
82 | Status HandleBatchNormInference( |
83 | const HloInstruction* batch_norm_inference) override; |
84 | Status HandleBatchNormGrad(const HloInstruction* batch_norm_grad) override; |
85 | Status HandleFusion(const HloInstruction* fusion) override; |
86 | Status HandleCall(const HloInstruction* call) override; |
87 | Status HandleCustomCall(const HloInstruction* custom_call) override; |
88 | Status HandleSlice(const HloInstruction* slice) override; |
89 | Status HandleDynamicSlice(const HloInstruction* dynamic_slice) override; |
90 | Status HandleDynamicUpdateSlice( |
91 | const HloInstruction* dynamic_update_slice) override; |
92 | Status HandleTuple(const HloInstruction* tuple) override; |
93 | Status HandleMap(const HloInstruction* map) override; |
94 | Status HandleReduceWindow(const HloInstruction* reduce_window) override; |
95 | Status HandleSelectAndScatter(const HloInstruction* instruction) override; |
96 | Status HandleBitcast(const HloInstruction* bitcast) override; |
97 | Status HandleBroadcast(const HloInstruction* broadcast) override; |
98 | Status HandleBroadcastDimOne(const HloInstruction* broadcastDimOne) override; |
99 | Status HandlePad(const HloInstruction* pad) override; |
100 | Status HandleReshape(const HloInstruction* reshape) override; |
101 | Status HandleTranspose(const HloInstruction* transpose) override; |
102 | Status HandleWhile(const HloInstruction* xla_while) override; |
103 | Status HandleConditional(const HloInstruction* conditional) override; |
104 | Status HandleGather(const HloInstruction* gather) override; |
105 | Status FinishVisit(const HloInstruction* root) override; |
106 | |
107 | Status Preprocess(const HloInstruction* hlo) override; |
108 | Status Postprocess(const HloInstruction* hlo) override; |
109 | |
110 | // Set the rates used to calculate the time taken by the computation. These |
111 | // need to be set before visiting starts. |
112 | void set_flops_per_second(float value) { |
113 | per_second_rates_[kFlopsKey] = value; |
114 | } |
115 | void set_transcendentals_per_second(float value) { |
116 | per_second_rates_[kTranscendentalsKey] = value; |
117 | } |
118 | void set_bytes_per_second(float value) { |
119 | per_second_rates_[kBytesAccessedKey] = value; |
120 | } |
121 | |
122 | // Returns properties for the computation. |
123 | float flop_count() const; |
124 | float transcendental_count() const; |
125 | float bytes_accessed() const; |
126 | float optimal_seconds() const; |
127 | |
128 | // Returns the respective cost computed for a particular HLO instruction, or 0 |
129 | // if the HLO was not found to have a cost in the analysis. |
130 | int64 flop_count(const HloInstruction& hlo) const; |
131 | int64 transcendental_count(const HloInstruction& hlo) const; |
132 | int64 bytes_accessed(const HloInstruction& hlo) const; |
133 | float optimal_seconds(const HloInstruction& hlo) const; |
134 | |
135 | const Properties& properties() const { return properties_sum_; } |
136 | const float property(const string& key) const { |
137 | return GetProperty(key, properties()); |
138 | } |
139 | |
140 | protected: |
141 | typedef std::unordered_map<const HloInstruction*, Properties> HloToProperties; |
142 | |
143 | // An FMA counts as two floating point operations in these analyzes. |
144 | static constexpr int64 kFmaFlops = 2; |
145 | |
146 | HloCostAnalysis(const ShapeSizeFunction& shape_size, |
147 | const Properties& per_second_rates); |
148 | |
149 | // Returns the properties computed from visiting the computation rooted at the |
150 | // given hlo. Uses shape_size_ to calculate shape sizes if shape_size is null, |
151 | // otherwise uses shape_size_. |
152 | StatusOr<Properties> ProcessSubcomputation( |
153 | HloComputation* computation, |
154 | const ShapeSizeFunction* shape_size = nullptr); |
155 | |
156 | // Utility function to handle all element-wise operations. |
157 | Status HandleElementwiseOp(const HloInstruction* hlo_instruction); |
158 | |
159 | // Returns the default value if the key is not present in the |
160 | // properties. Otherwise, returns the value that the key maps to from the |
161 | // properties parameter. |
162 | static float GetProperty(const string& key, const Properties& properties, |
163 | float default_value = 0.0f); |
164 | |
165 | // Returns 0.0f if the hlo is not present in hlo_to_properties or if the key |
166 | // is not present in hlo_to_properties[hlo]. Otherwise, returns the value that |
167 | // the key maps to in the properties of the given hlo. |
168 | static float GetPropertyForHlo(const HloInstruction& hlo, const string& key, |
169 | const HloToProperties& hlo_to_properties); |
170 | |
171 | // Function which computes the size of the top-level of a given shape (not |
172 | // including nested elements, if any). If null then bytes_accessed methods |
173 | // return an error. |
174 | const ShapeSizeFunction shape_size_; |
175 | |
176 | HloToProperties hlo_properties_; |
177 | |
178 | // If true, the time taken will be computed from the rates for each property |
179 | // and the total time will be the maximum time, which is the time of the |
180 | // bottleneck. |
181 | bool current_should_compute_bottleneck_time_; |
182 | |
183 | // The properties of the currently visited instruction. A HandleFoo method can |
184 | // modify these to change the default values computed in Preprocess. |
185 | Properties current_properties_; |
186 | |
187 | // The sum of the properties of all HLOs in the computation. |
188 | Properties properties_sum_; |
189 | |
190 | // How much of each property can be processed per second. E.g. if the property |
191 | // is bytes accessed, this is the number of bytes that can be processed per |
192 | // second. Is empty if no rates have been set. |
193 | Properties per_second_rates_; |
194 | |
195 | TF_DISALLOW_COPY_AND_ASSIGN(HloCostAnalysis); |
196 | }; |
197 | |
198 | } // namespace xla |
199 | |
200 | #endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_COST_ANALYSIS_H_ |
201 | |