// NOTE(review): this copy looks like text extracted from a generated source
// listing. The leading integers (14, 15, 18, 30, 31, ...) are listing line
// numbers, not C++. Because each line starts with such a number rather than
// `#`, none of the guard/include text below is processed as a preprocessor
// directive. TODO restore this region from the upstream header.
// NOTE(review): std::function (below) and std::move (used in the constructor's
// initializer list) need <functional> and <utility>. Neither is visibly
// included here; they may arrive transitively via hlo_pass_interface.h —
// confirm.
// ValidBitcastCallback: a callable taking (from_shape, to_shape), both
// `const Shape&`, and returning bool. It is the type of the
// valid_bitcast_callback_ member.
14 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_ALGEBRAIC_SIMPLIFIER_H_ 15 #define TENSORFLOW_COMPILER_XLA_SERVICE_ALGEBRAIC_SIMPLIFIER_H_ 18 #include "tensorflow/compiler/xla/service/hlo_pass_interface.h" 30 using ValidBitcastCallback =
31 std::function<bool(const Shape& from_shape, const Shape& to_shape)>;
// NOTE(review): the opening of this constructor is missing from this copy.
// That includes its name and the parameter that initializes
// is_layout_sensitive_ (referenced below but not declared in view). The
// leading integers (37-43) are source-listing line numbers, not C++.
// TODO restore from the upstream header.
// Parameters visible here:
//   valid_bitcast_callback: taken by value and moved into
//     valid_bitcast_callback_, avoiding a second copy of the callable.
//   enable_dot_strength_reduction: stored in enable_dot_strength_reduction_;
//     defaults to true.
//   enable_conv_simplification: stored in enable_conv_simplification_;
//     defaults to true.
37 ValidBitcastCallback valid_bitcast_callback,
38 bool enable_dot_strength_reduction =
true,
39 bool enable_conv_simplification =
true)
40 : is_layout_sensitive_(is_layout_sensitive),
41 valid_bitcast_callback_(std::move(valid_bitcast_callback)),
42 enable_dot_strength_reduction_(enable_dot_strength_reduction),
43 enable_conv_simplification_(enable_conv_simplification) {}
45 tensorflow::StringPiece name()
const override {
return "algsimp"; }
48 StatusOr<bool> Run(
HloModule* module)
override;
50 bool is_layout_sensitive_;
51 ValidBitcastCallback valid_bitcast_callback_;
53 bool enable_dot_strength_reduction_;
55 bool enable_conv_simplification_;
// NOTE(review): the leading "58" is a source-listing line number. Because it
// precedes `#`, this line is not a preprocessor directive, so the include
// guard is not actually closed here. TODO restore a plain `#endif` line.
58 #endif // TENSORFLOW_COMPILER_XLA_SERVICE_ALGEBRAIC_SIMPLIFIER_H_
// Cross-reference residue from the generated source listing (not header code):
//   xla: namespace for xla
//   Definition: client_library.cc:26
//   Definition: hlo_module.h:52