18 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_TRANSPOSE_FOLDING_H_ 19 #define TENSORFLOW_COMPILER_XLA_SERVICE_TRANSPOSE_FOLDING_H_ 22 #include "tensorflow/compiler/xla/service/hlo_pass_interface.h" 31 using OperandIndices = std::vector<int64>;
35 using FoldableOperands = std::function<OperandIndices(
const HloInstruction&,
36 const OperandIndices&)>;
37 using TransposableGemmOperandsFn = FoldableOperands;
38 using TransposableConvOperandsFn = FoldableOperands;
42 const OperandIndices&) {
53 TransposableGemmOperandsFn transposable_gemm_operands,
54 TransposableConvOperandsFn transposable_conv_operands);
55 tensorflow::StringPiece name()
const override {
return "transpose-folding"; }
60 TransposableGemmOperandsFn transposable_gemm_operands_;
61 TransposableConvOperandsFn transposable_conv_operands_;
66 #endif // TENSORFLOW_COMPILER_XLA_SERVICE_TRANSPOSE_FOLDING_H_ StatusOr< bool > Run(HloModule *module) override
entry point of xla::TransposeFolding pass
Definition: transpose_folding.cc:193
Definition: hlo_instruction.h:165
Definition: transpose_folding.h:29
namespace for xla
Definition: client_library.cc:26
Definition: hlo_module.h:52