14 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_VERIFIER_H_ 15 #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_VERIFIER_H_ 16 #include "tensorflow/compiler/xla/service/hlo_pass_interface.h" 17 #include "tensorflow/compiler/xla/service/shape_inference.h" 22 class ShapeVerifier :
public DfsHloVisitor {
24 explicit ShapeVerifier() : allow_mixed_precision_(false) {}
25 explicit ShapeVerifier(
bool allow_mixed_precision)
26 : allow_mixed_precision_(allow_mixed_precision) {}
27 Status HandleElementwiseUnary(HloInstruction* hlo)
override;
28 Status HandleElementwiseBinary(HloInstruction* hlo)
override;
29 Status HandleClamp(HloInstruction* clamp)
override;
30 Status HandleSelect(HloInstruction* select)
override;
31 Status HandleConcatenate(HloInstruction* concatenate)
override;
32 Status HandleConvert(HloInstruction* convert)
override;
33 Status HandleBitcastConvert(HloInstruction* convert)
override;
34 Status HandleCopy(HloInstruction* copy)
override;
35 Status HandleDot(HloInstruction* dot)
override;
36 Status HandleConvolution(HloInstruction* convolution)
override;
37 Status HandleFft(HloInstruction* fft)
override;
38 Status HandleCrossReplicaSum(HloInstruction* crs)
override;
39 Status HandleReducePrecision(HloInstruction* reduce_precision)
override;
40 Status HandleInfeed(HloInstruction*)
override;
41 Status HandleOutfeed(HloInstruction*)
override;
42 Status HandleRng(HloInstruction*)
override;
43 Status HandleReverse(HloInstruction* reverse)
override;
44 Status HandleSort(HloInstruction* sort)
override;
45 Status HandleConstant(HloInstruction* constant)
override;
46 Status HandleGetTupleElement(HloInstruction* get_tuple_element)
override;
47 Status HandleReduce(HloInstruction* reduce)
override;
48 Status HandleBitcast(HloInstruction* bitcast)
override;
49 Status HandleBroadcast(HloInstruction* broadcast)
override;
50 Status HandleBroadcastDimOne(HloInstruction* broadcastDimOne)
override;
51 Status HandleReshape(HloInstruction* reshape)
override;
52 Status HandleTranspose(HloInstruction* transpose)
override;
53 Status HandleParameter(HloInstruction*)
override;
54 Status HandleFusion(HloInstruction*)
override;
55 Status HandleCall(HloInstruction* call)
override;
56 Status HandleCustomCall(HloInstruction*)
override;
57 Status HandleHostCompute(HloInstruction*)
override;
58 Status HandleSlice(HloInstruction* slice)
override;
59 Status HandleDynamicSlice(HloInstruction* dynamic_slice)
override;
60 Status HandleDynamicUpdateSlice(
61 HloInstruction* dynamic_update_slice)
override;
62 Status HandleTuple(HloInstruction* tuple)
override;
63 Status HandleMap(HloInstruction* map)
override;
64 Status HandleReduceWindow(HloInstruction* reduce_window)
override;
65 Status HandleSelectAndScatter(HloInstruction* instruction)
override;
66 Status HandleWhile(HloInstruction* xla_while)
override;
67 Status HandleConditional(HloInstruction* conditional)
override;
68 Status HandlePad(HloInstruction* pad)
override;
69 Status HandleSend(HloInstruction* send)
override;
70 Status HandleSendDone(HloInstruction* send_done)
override;
71 Status HandleRecv(HloInstruction* recv)
override;
72 Status HandleRecvDone(HloInstruction* recv_done)
override;
73 Status HandleBatchNormTraining(HloInstruction* batch_norm_training)
override;
74 Status HandleBatchNormInference(
75 HloInstruction* batch_norm_inference)
override;
76 Status HandleBatchNormGrad(HloInstruction* batch_norm_grad)
override;
77 Status HandleGather(HloInstruction* gather)
override;
78 Status FinishVisit(HloInstruction*)
override {
79 return tensorflow::Status::OK();
84 Status CheckShape(
const HloInstruction* instruction,
85 const Shape& inferred_shape);
87 Status CheckShape(
const HloInstruction* instruction,
88 const StatusOr<Shape>& inferred_shape_status);
90 Status CheckUnaryShape(
const HloInstruction* instruction);
91 Status CheckBinaryShape(
const HloInstruction* instruction);
92 Status CheckTernaryShape(
const HloInstruction* instruction);
93 Status CheckVariadicShape(
const HloInstruction* instruction);
95 Status CheckSameChannel(
const HloInstruction* instr1,
96 const HloInstruction* instr2);
101 bool allow_mixed_precision_;
114 using ShapeVerifierFactory = std::function<std::unique_ptr<ShapeVerifier>()>;
117 : shape_verifier_factory_(
118 [] {
return MakeUnique<ShapeVerifier>(
false); }) {}
120 : shape_verifier_factory_([allow_mixed_precision] {
121 return MakeUnique<ShapeVerifier>(allow_mixed_precision);
124 explicit HloVerifier(ShapeVerifierFactory shape_verifier_factory)
125 : shape_verifier_factory_(std::move(shape_verifier_factory)) {}
127 tensorflow::StringPiece name()
const override {
return "verifier"; }
130 StatusOr<bool> Run(
HloModule* module)
override;
138 ShapeVerifierFactory shape_verifier_factory_;
141 #endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_VERIFIER_H_ Definition: hlo_verifier.h:112
// Cross-reference residue from a rendered documentation listing:
//   Definition: hlo_instruction.h:165
//   namespace for xla
//   Definition: client_library.cc:26
//   Definition: hlo_module.h:52