tf_1.8_xla_doc
hlo_verifier.h
Go to the documentation of this file.
1 
3 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
4 Licensed under the Apache License, Version 2.0 (the "License");
5 you may not use this file except in compliance with the License.
6 You may obtain a copy of the License at
7  http://www.apache.org/licenses/LICENSE-2.0
8 Unless required by applicable law or agreed to in writing, software
9 distributed under the License is distributed on an "AS IS" BASIS,
10 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 See the License for the specific language governing permissions and
12 limitations under the License.
13 ==============================================================================*/
14 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_VERIFIER_H_
15 #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_VERIFIER_H_
#include <functional>
#include <memory>
#include <utility>

#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
#include "tensorflow/compiler/xla/service/shape_inference.h"
18 namespace xla {
19 // Visitor which verifies that the output shape is correctly set. Verifies
20 // against the inferred shape for the instruction.
21 // TODO(b/26024837): Check output shape for all instruction types.
22 class ShapeVerifier : public DfsHloVisitor {
23  public:
24  explicit ShapeVerifier() : allow_mixed_precision_(false) {}
25  explicit ShapeVerifier(bool allow_mixed_precision)
26  : allow_mixed_precision_(allow_mixed_precision) {}
27  Status HandleElementwiseUnary(HloInstruction* hlo) override;
28  Status HandleElementwiseBinary(HloInstruction* hlo) override;
29  Status HandleClamp(HloInstruction* clamp) override;
30  Status HandleSelect(HloInstruction* select) override;
31  Status HandleConcatenate(HloInstruction* concatenate) override;
32  Status HandleConvert(HloInstruction* convert) override;
33  Status HandleBitcastConvert(HloInstruction* convert) override;
34  Status HandleCopy(HloInstruction* copy) override;
35  Status HandleDot(HloInstruction* dot) override;
36  Status HandleConvolution(HloInstruction* convolution) override;
37  Status HandleFft(HloInstruction* fft) override;
38  Status HandleCrossReplicaSum(HloInstruction* crs) override;
39  Status HandleReducePrecision(HloInstruction* reduce_precision) override;
40  Status HandleInfeed(HloInstruction*) override;
41  Status HandleOutfeed(HloInstruction*) override;
42  Status HandleRng(HloInstruction*) override;
43  Status HandleReverse(HloInstruction* reverse) override;
44  Status HandleSort(HloInstruction* sort) override;
45  Status HandleConstant(HloInstruction* constant) override;
46  Status HandleGetTupleElement(HloInstruction* get_tuple_element) override;
47  Status HandleReduce(HloInstruction* reduce) override;
48  Status HandleBitcast(HloInstruction* bitcast) override;
49  Status HandleBroadcast(HloInstruction* broadcast) override;
50  Status HandleBroadcastDimOne(HloInstruction* broadcastDimOne) override;
51  Status HandleReshape(HloInstruction* reshape) override;
52  Status HandleTranspose(HloInstruction* transpose) override;
53  Status HandleParameter(HloInstruction*) override;
54  Status HandleFusion(HloInstruction*) override;
55  Status HandleCall(HloInstruction* call) override;
56  Status HandleCustomCall(HloInstruction*) override;
57  Status HandleHostCompute(HloInstruction*) override;
58  Status HandleSlice(HloInstruction* slice) override;
59  Status HandleDynamicSlice(HloInstruction* dynamic_slice) override;
60  Status HandleDynamicUpdateSlice(
61  HloInstruction* dynamic_update_slice) override;
62  Status HandleTuple(HloInstruction* tuple) override;
63  Status HandleMap(HloInstruction* map) override;
64  Status HandleReduceWindow(HloInstruction* reduce_window) override;
65  Status HandleSelectAndScatter(HloInstruction* instruction) override;
66  Status HandleWhile(HloInstruction* xla_while) override;
67  Status HandleConditional(HloInstruction* conditional) override;
68  Status HandlePad(HloInstruction* pad) override;
69  Status HandleSend(HloInstruction* send) override;
70  Status HandleSendDone(HloInstruction* send_done) override;
71  Status HandleRecv(HloInstruction* recv) override;
72  Status HandleRecvDone(HloInstruction* recv_done) override;
73  Status HandleBatchNormTraining(HloInstruction* batch_norm_training) override;
74  Status HandleBatchNormInference(
75  HloInstruction* batch_norm_inference) override;
76  Status HandleBatchNormGrad(HloInstruction* batch_norm_grad) override;
77  Status HandleGather(HloInstruction* gather) override;
78  Status FinishVisit(HloInstruction*) override {
79  return tensorflow::Status::OK();
80  }
81  protected:
82  // Check the instruction's shape against the shape given by ShapeInference
83  // and return an appropriate error if there is a mismatch.
84  Status CheckShape(const HloInstruction* instruction,
85  const Shape& inferred_shape);
86  // Overload which takes a StatusOr to reduce boilerplate in the caller.
87  Status CheckShape(const HloInstruction* instruction,
88  const StatusOr<Shape>& inferred_shape_status);
89  // Check a unary (binary, etc) instruction's shape against the inferred shape.
90  Status CheckUnaryShape(const HloInstruction* instruction);
91  Status CheckBinaryShape(const HloInstruction* instruction);
92  Status CheckTernaryShape(const HloInstruction* instruction);
93  Status CheckVariadicShape(const HloInstruction* instruction);
94  // Checks if the given two instructions shares the same channel id.
95  Status CheckSameChannel(const HloInstruction* instr1,
96  const HloInstruction* instr2);
97  private:
98  // Whether the inputs and output of an instruction can contain both F32s and
99  // BF16s. Tuples that include both F32s and BF16s are allowed regardless of
100  // this flag.
101  bool allow_mixed_precision_;
102 };
112 class HloVerifier : public HloPassInterface {
113  public:
114  using ShapeVerifierFactory = std::function<std::unique_ptr<ShapeVerifier>()>;
115  // Uses standard shape inference.
116  explicit HloVerifier()
117  : shape_verifier_factory_(
118  [] { return MakeUnique<ShapeVerifier>(false); }) {}
119  explicit HloVerifier(bool allow_mixed_precision)
120  : shape_verifier_factory_([allow_mixed_precision] {
121  return MakeUnique<ShapeVerifier>(allow_mixed_precision);
122  }) {}
123  // Uses custom shape verification.
124  explicit HloVerifier(ShapeVerifierFactory shape_verifier_factory)
125  : shape_verifier_factory_(std::move(shape_verifier_factory)) {}
126  ~HloVerifier() override = default;
127  tensorflow::StringPiece name() const override { return "verifier"; }
128  // Note: always returns false (no instructions are ever modified by this
129  // pass).
130  StatusOr<bool> Run(HloModule* module) override;
131  private:
132  // CHECKs various invariants of a fusion instruction.
133  Status CheckFusionInstruction(HloInstruction* fusion) const;
134  // Creates a ShapeVerifier that checks that shapes match inferred
135  // expectations. This is a factory function because ShapeVerifier, Note that
136  // ShapeVerifier, being a DfsHloVisitor, is stateful. We want a clean object
137  // for each run of the verifier.
138  ShapeVerifierFactory shape_verifier_factory_;
139 };
140 } // namespace xla
141 #endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_VERIFIER_H_
142 
Definition: hlo_verifier.h:112
Definition: hlo_instruction.h:165
namespace for xla
Definition: client_library.cc:26
Definition: hlo_module.h:52