1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_COST_ANALYSIS_H_
17#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_COST_ANALYSIS_H_
18
19#include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h"
20#include "tensorflow/compiler/xla/service/hlo_computation.h"
21#include "tensorflow/compiler/xla/service/hlo_instruction.h"
22#include "tensorflow/compiler/xla/service/hlo_opcode.h"
23#include "tensorflow/compiler/xla/shape_util.h"
24#include "tensorflow/compiler/xla/statusor.h"
25#include "tensorflow/compiler/xla/xla_data.pb.h"
26#include "tensorflow/core/lib/gtl/array_slice.h"
27#include "tensorflow/core/platform/macros.h"
28#include "tensorflow/core/platform/types.h"
29
30namespace xla {
31
32// HloCostAnalysis traverses an HLO graph and calculates the amount of
33// computations required for the graph. Each HLO instruction handler provides
34// the computation cost of the instruction, and the values are accumulated
35// during the traversal for the entire graph. We treat normal floating point
36// operations separately from transcendental operations.
37class HloCostAnalysis : public ConstDfsHloVisitor {
38 public:
39 // Each HLO is associated to a vector of properties with the indices given
40 // below. Sub-classes can add further properties.
41 typedef std::map<string, float> Properties;
42 static constexpr char kFlopsKey[] = "flops";
43 static constexpr char kTranscendentalsKey[] = "transcendentals";
44 static constexpr char kBytesAccessedKey[] = "bytes accessed";
45 static constexpr char kOptimalSecondsKey[] = "optimal_seconds";
46
47 // shape_size is a function which returns the size in bytes of the top-level
48 // buffer of a shape.
49 using ShapeSizeFunction = std::function<int64(const Shape&)>;
50 explicit HloCostAnalysis(const ShapeSizeFunction& shape_size);
51
52 Status HandleElementwiseUnary(const HloInstruction* hlo) override;
53 Status HandleElementwiseBinary(const HloInstruction* hlo) override;
54 Status HandleConstant(const HloInstruction* constant) override;
55 Status HandleGetTupleElement(
56 const HloInstruction* get_tuple_element) override;
57 Status HandleSelect(const HloInstruction* select) override;
58 Status HandleCompare(const HloInstruction* compare) override;
59 Status HandleClamp(const HloInstruction* clamp) override;
60 Status HandleReducePrecision(const HloInstruction* hlo) override;
61 Status HandleConcatenate(const HloInstruction* concatenate) override;
62 Status HandleSend(const HloInstruction* send) override;
63 Status HandleSendDone(const HloInstruction* send_done) override;
64 Status HandleRecv(const HloInstruction* recv) override;
65 Status HandleRecvDone(const HloInstruction* recv_done) override;
66 Status HandleConvert(const HloInstruction* convert) override;
67 Status HandleCopy(const HloInstruction* copy) override;
68 Status HandleDot(const HloInstruction* dot) override;
69 Status HandleConvolution(const HloInstruction* convolution) override;
70 Status HandleFft(const HloInstruction* fft) override;
71 Status HandleCrossReplicaSum(const HloInstruction* crs) override;
72 Status HandleInfeed(const HloInstruction* infeed) override;
73 Status HandleOutfeed(const HloInstruction* outfeed) override;
74 Status HandleHostCompute(const HloInstruction* host_compute) override;
75 Status HandleRng(const HloInstruction* random) override;
76 Status HandleReverse(const HloInstruction* reverse) override;
77 Status HandleSort(const HloInstruction* sort) override;
78 Status HandleParameter(const HloInstruction* parameter) override;
79 Status HandleReduce(const HloInstruction* reduce) override;
80 Status HandleBatchNormTraining(
81 const HloInstruction* batch_norm_training) override;
82 Status HandleBatchNormInference(
83 const HloInstruction* batch_norm_inference) override;
84 Status HandleBatchNormGrad(const HloInstruction* batch_norm_grad) override;
85 Status HandleFusion(const HloInstruction* fusion) override;
86 Status HandleCall(const HloInstruction* call) override;
87 Status HandleCustomCall(const HloInstruction* custom_call) override;
88 Status HandleSlice(const HloInstruction* slice) override;
89 Status HandleDynamicSlice(const HloInstruction* dynamic_slice) override;
90 Status HandleDynamicUpdateSlice(
91 const HloInstruction* dynamic_update_slice) override;
92 Status HandleTuple(const HloInstruction* tuple) override;
93 Status HandleMap(const HloInstruction* map) override;
94 Status HandleReduceWindow(const HloInstruction* reduce_window) override;
95 Status HandleSelectAndScatter(const HloInstruction* instruction) override;
96 Status HandleBitcast(const HloInstruction* bitcast) override;
97 Status HandleBroadcast(const HloInstruction* broadcast) override;
98 Status HandleBroadcastDimOne(const HloInstruction* broadcastDimOne) override;
99 Status HandlePad(const HloInstruction* pad) override;
100 Status HandleReshape(const HloInstruction* reshape) override;
101 Status HandleTranspose(const HloInstruction* transpose) override;
102 Status HandleWhile(const HloInstruction* xla_while) override;
103 Status HandleConditional(const HloInstruction* conditional) override;
104 Status HandleGather(const HloInstruction* gather) override;
105 Status FinishVisit(const HloInstruction* root) override;
106
107 Status Preprocess(const HloInstruction* hlo) override;
108 Status Postprocess(const HloInstruction* hlo) override;
109
110 // Set the rates used to calculate the time taken by the computation. These
111 // need to be set before visiting starts.
112 void set_flops_per_second(float value) {
113 per_second_rates_[kFlopsKey] = value;
114 }
115 void set_transcendentals_per_second(float value) {
116 per_second_rates_[kTranscendentalsKey] = value;
117 }
118 void set_bytes_per_second(float value) {
119 per_second_rates_[kBytesAccessedKey] = value;
120 }
121
122 // Returns properties for the computation.
123 float flop_count() const;
124 float transcendental_count() const;
125 float bytes_accessed() const;
126 float optimal_seconds() const;
127
128 // Returns the respective cost computed for a particular HLO instruction, or 0
129 // if the HLO was not found to have a cost in the analysis.
130 int64 flop_count(const HloInstruction& hlo) const;
131 int64 transcendental_count(const HloInstruction& hlo) const;
132 int64 bytes_accessed(const HloInstruction& hlo) const;
133 float optimal_seconds(const HloInstruction& hlo) const;
134
135 const Properties& properties() const { return properties_sum_; }
136 const float property(const string& key) const {
137 return GetProperty(key, properties());
138 }
139
140 protected:
141 typedef std::unordered_map<const HloInstruction*, Properties> HloToProperties;
142
143 // An FMA counts as two floating point operations in these analyzes.
144 static constexpr int64 kFmaFlops = 2;
145
146 HloCostAnalysis(const ShapeSizeFunction& shape_size,
147 const Properties& per_second_rates);
148
149 // Returns the properties computed from visiting the computation rooted at the
150 // given hlo. Uses shape_size_ to calculate shape sizes if shape_size is null,
151 // otherwise uses shape_size_.
152 StatusOr<Properties> ProcessSubcomputation(
153 HloComputation* computation,
154 const ShapeSizeFunction* shape_size = nullptr);
155
156 // Utility function to handle all element-wise operations.
157 Status HandleElementwiseOp(const HloInstruction* hlo_instruction);
158
159 // Returns the default value if the key is not present in the
160 // properties. Otherwise, returns the value that the key maps to from the
161 // properties parameter.
162 static float GetProperty(const string& key, const Properties& properties,
163 float default_value = 0.0f);
164
165 // Returns 0.0f if the hlo is not present in hlo_to_properties or if the key
166 // is not present in hlo_to_properties[hlo]. Otherwise, returns the value that
167 // the key maps to in the properties of the given hlo.
168 static float GetPropertyForHlo(const HloInstruction& hlo, const string& key,
169 const HloToProperties& hlo_to_properties);
170
171 // Function which computes the size of the top-level of a given shape (not
172 // including nested elements, if any). If null then bytes_accessed methods
173 // return an error.
174 const ShapeSizeFunction shape_size_;
175
176 HloToProperties hlo_properties_;
177
178 // If true, the time taken will be computed from the rates for each property
179 // and the total time will be the maximum time, which is the time of the
180 // bottleneck.
181 bool current_should_compute_bottleneck_time_;
182
183 // The properties of the currently visited instruction. A HandleFoo method can
184 // modify these to change the default values computed in Preprocess.
185 Properties current_properties_;
186
187 // The sum of the properties of all HLOs in the computation.
188 Properties properties_sum_;
189
190 // How much of each property can be processed per second. E.g. if the property
191 // is bytes accessed, this is the number of bytes that can be processed per
192 // second. Is empty if no rates have been set.
193 Properties per_second_rates_;
194
195 TF_DISALLOW_COPY_AND_ASSIGN(HloCostAnalysis);
196};
197
198} // namespace xla
199
200#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_COST_ANALYSIS_H_
201