tf_1.8_xla_doc
transpose_folding.h
Go to the documentation of this file.
1 
3 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
4 
5 Licensed under the Apache License, Version 2.0 (the "License");
6 you may not use this file except in compliance with the License.
7 You may obtain a copy of the License at
8 
9  http://www.apache.org/licenses/LICENSE-2.0
10 
11 Unless required by applicable law or agreed to in writing, software
12 distributed under the License is distributed on an "AS IS" BASIS,
13 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 See the License for the specific language governing permissions and
15 limitations under the License.
16 ==============================================================================*/
17 
18 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_TRANSPOSE_FOLDING_H_
19 #define TENSORFLOW_COMPILER_XLA_SERVICE_TRANSPOSE_FOLDING_H_
20 
22 #include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
23 
24 namespace xla {
29 class TransposeFolding : public HloPassInterface {
30  public:
31  using OperandIndices = std::vector<int64>;
32 
33  // Returns the set of foldable operands for a given HLO and some candidate
34  // operands.
35  using FoldableOperands = std::function<OperandIndices(const HloInstruction&,
36  const OperandIndices&)>;
37  using TransposableGemmOperandsFn = FoldableOperands;
38  using TransposableConvOperandsFn = FoldableOperands;
39 
40  // Helper function to explicitly not fold transposes.
41  static OperandIndices NeverFoldTranspose(const HloInstruction&,
42  const OperandIndices&) {
43  return {};
44  }
45  // transposable_gemm_operands returns the set of operands it wants to fold if
46  // the instruction argument is implemented as a GEMM kernel that supports
47  // transposing its arguments.
48  //
49  // transposable_conv_operands returns the set of operands it wants to fold if
50  // the instruction argument is implemented as a convolution that supports
51  // transposing its arguments.
52  explicit TransposeFolding(
53  TransposableGemmOperandsFn transposable_gemm_operands,
54  TransposableConvOperandsFn transposable_conv_operands);
55  tensorflow::StringPiece name() const override { return "transpose-folding"; }
56 
57  StatusOr<bool> Run(HloModule* module) override;
58 
59  private:
60  TransposableGemmOperandsFn transposable_gemm_operands_;
61  TransposableConvOperandsFn transposable_conv_operands_;
62 };
63 
64 } // namespace xla
65 
66 #endif // TENSORFLOW_COMPILER_XLA_SERVICE_TRANSPOSE_FOLDING_H_
StatusOr< bool > Run(HloModule *module) override
Entry point of the xla::TransposeFolding pass.
Definition: transpose_folding.cc:193
Definition: hlo_instruction.h:165
Definition: transpose_folding.h:29
namespace for xla
Definition: client_library.cc:26
Definition: hlo_module.h:52