14 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_REDUCE_PRECISION_INSERTION_H_ 15 #define TENSORFLOW_COMPILER_XLA_SERVICE_REDUCE_PRECISION_INSERTION_H_ 16 #include "tensorflow/compiler/xla/service/buffer_liveness.h" 20 #include "tensorflow/compiler/xla/service/hlo_pass_interface.h" 22 #include "tensorflow/core/lib/gtl/flatmap.h" 31 using InstructionFilterFunction = std::function<bool(const HloInstruction*)>;
38 const int exponent_bits,
const int mantissa_bits,
39 const HloReducePrecisionOptions::Location location,
40 const InstructionFilterFunction& instruction_filter_function)
41 : exponent_bits_(exponent_bits),
42 mantissa_bits_(mantissa_bits),
44 instruction_filter_function_(instruction_filter_function) {}
49 const HloReducePrecisionOptions& reduce_precision_options)
50 : exponent_bits_(reduce_precision_options.exponent_bits()),
51 mantissa_bits_(reduce_precision_options.mantissa_bits()),
52 location_(reduce_precision_options.location()),
53 instruction_filter_function_(
54 make_filter_function(reduce_precision_options)) {}
// Returns this pass's name: the fixed string "reduce-precision-insertion".
// Marked `override`, so it implements a virtual declared in a base class that
// is not visible in this part of the file.
56 tensorflow::StringPiece name()
const override {
57 return "reduce-precision-insertion";
// Applies this pass to `module`. Only the declaration appears here; the body
// is defined out of line. Errors are reported through the StatusOr wrapper.
// NOTE(review): the bool payload presumably reports whether the module was
// modified -- confirm against the base-class contract being overridden.
61 StatusOr<bool> Run(
HloModule* module)
override;
// Builds an InstructionFilterFunction (a predicate over const HloInstruction*)
// from an options proto. The options-based constructor calls this to
// initialize instruction_filter_function_. Declaration only; how the proto's
// fields are turned into a predicate is not shown here.
64 static InstructionFilterFunction make_filter_function(
65 const HloReducePrecisionOptions& reduce_precision_options);
// Assembles an HloReducePrecisionOptions value from individual settings.
// Declaration only; the definition is out of line.
//
//   location               - an HloReducePrecisionOptions::Location value.
//   exponent_bits          - exponent bit count to record in the options.
//   mantissa_bits          - mantissa bit count to record in the options.
//   opcode_filter_function - predicate over HloOpcode.
//                            NOTE(review): presumably selects the opcodes the
//                            pass should apply to -- confirm in the definition.
//   opname_substring_list  - optional list of strings; defaults to empty.
//                            NOTE(review): presumably substrings matched
//                            against operation names -- confirm in the
//                            definition.
66 static HloReducePrecisionOptions make_options_proto(
67 const HloReducePrecisionOptions::Location location,
68 const int exponent_bits,
const int mantissa_bits,
69 const std::function<
bool(HloOpcode)>& opcode_filter_function,
70 const std::vector<string>& opname_substring_list = {});
// Selector passed to AddPasses as its `pass_timing` argument. Two values
// exist: BEFORE_OPTIMIZATION and AFTER_FUSION.
// NOTE(review): by their names these presumably identify the pipeline stage
// for which passes are being added -- confirm in AddPasses.
72 enum class PassTiming { BEFORE_OPTIMIZATION, AFTER_FUSION };
80 const DebugOptions& debug_options,
85 std::vector<HloInstruction*> instructions_to_modify(
// Declaration only; takes the instructions to process and reports errors
// through StatusOr.
// NOTE(review): by its name, presumably inserts reduce-precision operations
// on the inputs (operands) of each listed instruction, with the bool meaning
// "something was changed" -- confirm in the definition.
94 StatusOr<bool> insert_on_inputs(
95 const std::vector<HloInstruction*>& instructions);
// Declaration only; takes the instructions to process and reports errors
// through StatusOr.
// NOTE(review): by its name, presumably inserts reduce-precision operations
// on the outputs of each listed instruction, with the bool meaning
// "something was changed" -- confirm in the definition.
99 StatusOr<bool> insert_on_outputs(
100 const std::vector<HloInstruction*>& instructions);
// Shape predicate: true only when `shape` has element type F32 and is not a
// scalar according to ShapeUtil::IsScalar. Every other shape yields false.
// The element-type test is evaluated first; IsScalar is consulted only for
// F32 shapes because && short-circuits.
102 bool is_valid_shape(
const Shape& shape) {
115 return shape.element_type() == PrimitiveType::F32 &&
116 !ShapeUtil::IsScalar(shape);
// True when `instruction` is itself a kReducePrecision operation whose
// exponent and mantissa bit counts are both less than or equal to this pass's
// exponent_bits_ and mantissa_bits_.
// The opcode test comes first, so exponent_bits()/mantissa_bits() are only
// queried on reduce-precision instructions (&& short-circuits).
// NOTE(review): presumably used to avoid placing a new reduce-precision op
// next to one that already reduces at least as much -- confirm at call sites.
120 bool is_redundant(
const HloInstruction* instruction) {
121 return instruction->opcode() == HloOpcode::kReducePrecision &&
122 instruction->exponent_bits() <= exponent_bits_ &&
123 instruction->mantissa_bits() <= mantissa_bits_;
// Pass configuration. All members are const, so they are fixed once the
// constructor's initializer list has run.

// Exponent bit count; is_redundant compares an existing reduce-precision
// instruction's exponent_bits() against this.
126 const int exponent_bits_;
// Mantissa bit count; is_redundant compares an existing reduce-precision
// instruction's mantissa_bits() against this.
127 const int mantissa_bits_;
// Location setting; the options-based constructor copies it from
// reduce_precision_options.location().
130 const HloReducePrecisionOptions::Location location_;
// Predicate over const HloInstruction*, supplied directly by the caller or
// built by make_filter_function from an options proto.
// NOTE(review): presumably decides which instructions this pass touches --
// confirm where it is invoked.
133 const InstructionFilterFunction instruction_filter_function_;
136 #endif // TENSORFLOW_COMPILER_XLA_SERVICE_REDUCE_PRECISION_INSERTION_H_
Definition: hlo_computation.h:60
static bool AddPasses(HloPassPipeline *pipeline, const DebugOptions &debug_options, const PassTiming pass_timing)
Definition: reduce_precision_insertion.cc:256
Definition: hlo_instruction.h:165
Definition: reduce_precision_insertion.h:30
namespace for xla
Definition: client_library.cc:26
PassTiming
Definition: reduce_precision_insertion.h:72
Definition: hlo_module.h:52
Definition: hlo_pass_pipeline.h:30