tf_1.8_xla_doc
reduce_precision_insertion.h
Go to the documentation of this file.
1 
3 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
4 Licensed under the Apache License, Version 2.0 (the "License");
5 you may not use this file except in compliance with the License.
6 You may obtain a copy of the License at
7  http://www.apache.org/licenses/LICENSE-2.0
8 Unless required by applicable law or agreed to in writing, software
9 distributed under the License is distributed on an "AS IS" BASIS,
10 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 See the License for the specific language governing permissions and
12 limitations under the License.
13 ==============================================================================*/
14 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_REDUCE_PRECISION_INSERTION_H_
15 #define TENSORFLOW_COMPILER_XLA_SERVICE_REDUCE_PRECISION_INSERTION_H_
16 #include "tensorflow/compiler/xla/service/buffer_liveness.h"
20 #include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
22 #include "tensorflow/core/lib/gtl/flatmap.h"
23 namespace xla {
30 class ReducePrecisionInsertion : public HloPassInterface {
31  using InstructionFilterFunction = std::function<bool(const HloInstruction*)>;
32  public:
33  // The exponent_bits and mantissa_bits arguments specify the parameters of
34  // the instructions to insert. The instructions will be inserted after each
35  // instruction with an opcode for which the instruction_filter_function
36  // function returns true and the output type is F32.
37  explicit ReducePrecisionInsertion(
38  const int exponent_bits, const int mantissa_bits,
39  const HloReducePrecisionOptions::Location location,
40  const InstructionFilterFunction& instruction_filter_function)
41  : exponent_bits_(exponent_bits),
42  mantissa_bits_(mantissa_bits),
43  location_(location),
44  instruction_filter_function_(instruction_filter_function) {}
45  // Version of the constructor that takes an HloReducePrecisionOptions proto
46  // rather than explicitly-enumerated parameters, for convenience when
47  // creating passes based on DebugOptions.
48  explicit ReducePrecisionInsertion(
49  const HloReducePrecisionOptions& reduce_precision_options)
50  : exponent_bits_(reduce_precision_options.exponent_bits()),
51  mantissa_bits_(reduce_precision_options.mantissa_bits()),
52  location_(reduce_precision_options.location()),
53  instruction_filter_function_(
54  make_filter_function(reduce_precision_options)) {}
55  ~ReducePrecisionInsertion() override{};
56  tensorflow::StringPiece name() const override {
57  return "reduce-precision-insertion";
58  }
59  // Run the pass on the given module. Returns whether the module was changed
60  // (reduce-precision instructions were inserted).
61  StatusOr<bool> Run(HloModule* module) override;
62  // Convert between the (inconvenient) xla.proto HloReducePrecisionOptions
63  // representation and InstructionFilterFunction functions.
64  static InstructionFilterFunction make_filter_function(
65  const HloReducePrecisionOptions& reduce_precision_options);
66  static HloReducePrecisionOptions make_options_proto(
67  const HloReducePrecisionOptions::Location location,
68  const int exponent_bits, const int mantissa_bits,
69  const std::function<bool(HloOpcode)>& opcode_filter_function,
70  const std::vector<string>& opname_substring_list = {});
72  enum class PassTiming { BEFORE_OPTIMIZATION, AFTER_FUSION };
79  static bool AddPasses(HloPassPipeline* pipeline,
80  const DebugOptions& debug_options,
81  const PassTiming pass_timing);
82  private:
83  // Select the instructions that should have reduce-precision operations
84  // attached to them.
85  std::vector<HloInstruction*> instructions_to_modify(
86  const HloComputation* computation);
87  // Insert a reduce-precision operation into the graph on the output of the
88  // given instruction.
89  StatusOr<bool> insert_after(HloInstruction* instruction);
90  // Insert reduce-precision operations into the graph on the inputs of the
91  // given instructions. (For fusion instructions, the operations will be
92  // inserted inside the fusion computation, on the outputs of the relevant
93  // input parameters.)
94  StatusOr<bool> insert_on_inputs(
95  const std::vector<HloInstruction*>& instructions);
96  // Insert reduce-precision operations into the graph on the outputs of the
97  // given instructions. (For fusion instructions, the operations will be
98  // inserted inside the fusion computation as a new root.)
99  StatusOr<bool> insert_on_outputs(
100  const std::vector<HloInstruction*>& instructions);
101  // Is this shape valid for inserting a reduce-precision operation?
102  bool is_valid_shape(const Shape& shape) {
103  // For now, ReducePrecision is only implemented for F32 arrays, so this
104  // ignores instructions that produce other data. In particular, this
105  // currently ignores instructions producing tuples, even if those tuples
106  // contain F32 arrays inside them. The assumption is that in most cases
107  // equivalent behavior can be obtained by adding ReducePrecision
108  // instructions after the instructions that pull the F32 arrays out of
109  // the tuples.
110  //
111  // TODO(b/64093391): Remove the IsScalar check once this won't cause
112  // failures on the GPU backend if the ReducePrecision instruction ends up
113  // inserted between a scalar constant and the init_value argument of a
114  // Reduce operation.
115  return shape.element_type() == PrimitiveType::F32 &&
116  !ShapeUtil::IsScalar(shape);
117  }
118  // Is this instruction one such that following or preceding it with a new
119  // reduce-precision operation will be redundant?
120  bool is_redundant(const HloInstruction* instruction) {
121  return instruction->opcode() == HloOpcode::kReducePrecision &&
122  instruction->exponent_bits() <= exponent_bits_ &&
123  instruction->mantissa_bits() <= mantissa_bits_;
124  }
125  // Parameters for the precision reduction to be added.
126  const int exponent_bits_;
127  const int mantissa_bits_;
128  // Pass "timing" parameter. This also controls aspects of how the pass
129  // selects locations to insert instructions.
130  const HloReducePrecisionOptions::Location location_;
131  // User-provided Function to determine whether a given instruction should
132  // have a reduce-precision instruction inserted in its output stream.
133  const InstructionFilterFunction instruction_filter_function_;
134 };
135 } // namespace xla
136 #endif // TENSORFLOW_COMPILER_XLA_SERVICE_REDUCE_PRECISION_INSERTION_H_
137 
Definition: hlo_computation.h:60
static bool AddPasses(HloPassPipeline *pipeline, const DebugOptions &debug_options, const PassTiming pass_timing)
Definition: reduce_precision_insertion.cc:256
Definition: hlo_instruction.h:165
Definition: reduce_precision_insertion.h:30
Namespace for XLA.
Definition: client_library.cc:26
PassTiming
Definition: reduce_precision_insertion.h:72
Definition: hlo_module.h:52
Definition: hlo_pass_pipeline.h:30