14 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_BATCHNORM_EXPANDER_H_ 15 #define TENSORFLOW_COMPILER_XLA_SERVICE_BATCHNORM_EXPANDER_H_ 18 #include "tensorflow/compiler/xla/service/hlo_pass_interface.h" 32 bool rewrite_inference_op =
false,
33 bool rewrite_grad_op =
false,
bool use_fusion =
true)
34 : rewrite_training_op_(rewrite_training_op),
35 rewrite_inference_op_(rewrite_inference_op),
36 rewrite_grad_op_(rewrite_grad_op),
37 use_fusion_(use_fusion) {}
// Returns this pass's identifier, the fixed string "batchnorm_expander".
// Declared `override`, so it implements a virtual from a base class
// (presumably the pass interface in hlo_pass_interface.h -- TODO confirm).
39 tensorflow::StringPiece name()
const override {
return "batchnorm_expander"; }
// Entry point of the pass; declaration only, the definition is not visible
// here. Takes a pointer to the HloModule to process and returns a
// StatusOr<bool>.
// NOTE(review): the bool presumably reports whether the module was changed --
// confirm against the definition.
42 StatusOr<bool> Run(
HloModule* module)
override;
// Flags set in the constructor's initializer list from the same-named
// constructor arguments (without the trailing underscore).
// NOTE(review): presumably each one selects whether the corresponding kind of
// batch-norm op (training / inference / grad) gets rewritten -- confirm in the
// definition of Run. The constructor also initializes `use_fusion_`, whose
// declaration is not visible in this view.
44 bool rewrite_training_op_;
45 bool rewrite_inference_op_;
46 bool rewrite_grad_op_;
50 #endif // TENSORFLOW_COMPILER_XLA_SERVICE_BATCHNORM_EXPANDER_H_ Definition: batchnorm_expander.h:28
namespace for xla
Definition: client_library.cc:26
Definition: hlo_module.h:52