1 | /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_OPCODE_H_ |
17 | #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_OPCODE_H_ |
18 | |
19 | #include <iosfwd> |
20 | #include <string> |
21 | #include "tensorflow/compiler/xla/statusor.h" |
22 | #include "tensorflow/compiler/xla/types.h" |
23 | |
24 | namespace xla { |
25 | |
// High-level optimizer instruction opcodes -- these are linear-algebra level
// opcodes. They are a flattened form of the UnaryOp, BinaryOp, ... opcodes
// present in the XLA service protobuf.
//
// See the XLA documentation for the semantics of each opcode.
//
// HLO_OPCODE_LIST is an X-macro: it invokes the caller-supplied macro V once
// per opcode, in the order listed below. Each entry has the format:
//   (enum_name, opcode_name)
// or
//   (enum_name, opcode_name, p1 | p2 | ...)
//
// where p1, p2, ... are members of HloOpcodeProperty. They are combined
// using bitwise-or.
//
// Because some entries pass two arguments and others pass three, a macro
// supplied as V must accept a variable number of arguments.
//
// The HloOpcode enum is generated from this list, so the position of an entry
// determines the numeric value of its enumerator: inserting, removing or
// reordering entries renumbers every enumerator that follows.
//
// Note: Do not use ':' in opcode names. It is used as a special character
// in these places:
// - In extended opcode strings (HloInstruction::ExtendedOpcodeString()), to
//   separate the opcode from the fusion kind
// - In fully qualified names (HloInstruction::FullyQualifiedName()), to
//   separate the qualifiers (name of the computation and potentially the
//   fusion instruction) from the name
#define HLO_OPCODE_LIST(V) \
  V(kAbs, "abs") \
  V(kAdd, "add") \
  V(kAtan2, "atan2") \
  V(kBatchNormGrad, "batch-norm-grad") \
  V(kBatchNormInference, "batch-norm-inference") \
  V(kBatchNormTraining, "batch-norm-training") \
  V(kBitcast, "bitcast") \
  V(kBitcastConvert, "bitcast-convert") \
  V(kBroadcast, "broadcast") \
  V(kBroadcastDimOne, "broadcast-dim-one") \
  V(kCall, "call", kHloOpcodeIsVariadic) \
  V(kCeil, "ceil") \
  V(kClamp, "clamp") \
  V(kComplex, "complex") \
  V(kConcatenate, "concatenate", kHloOpcodeIsVariadic) \
  V(kConditional, "conditional") \
  V(kConstant, "constant") \
  V(kConvert, "convert") \
  V(kConvolution, "convolution") \
  V(kCopy, "copy") \
  V(kCos, "cosine") \
  V(kCrossReplicaSum, "cross-replica-sum") \
  V(kCustomCall, "custom-call") \
  V(kDivide, "divide") \
  V(kDot, "dot") \
  V(kDynamicSlice, "dynamic-slice") \
  V(kDynamicUpdateSlice, "dynamic-update-slice") \
  V(kEq, "equal-to", kHloOpcodeIsComparison) \
  V(kExp, "exponential") \
  V(kFft, "fft") \
  V(kFloor, "floor") \
  V(kFusion, "fusion", kHloOpcodeIsVariadic) \
  V(kGather, "gather") \
  V(kGe, "greater-than-or-equal-to", kHloOpcodeIsComparison) \
  V(kGetTupleElement, "get-tuple-element") \
  V(kGt, "greater-than", kHloOpcodeIsComparison) \
  V(kHostCompute, "host-compute") \
  V(kImag, "imag") \
  V(kInfeed, "infeed") \
  V(kIsFinite, "is-finite") \
  V(kLe, "less-than-or-equal-to", kHloOpcodeIsComparison) \
  V(kLog, "log") \
  V(kAnd, "and") \
  V(kNot, "not") \
  V(kOr, "or") \
  V(kLt, "less-than", kHloOpcodeIsComparison) \
  V(kMap, "map", kHloOpcodeIsVariadic) \
  V(kMaximum, "maximum") \
  V(kMinimum, "minimum") \
  V(kMultiply, "multiply") \
  V(kNe, "not-equal-to", kHloOpcodeIsComparison) \
  V(kNegate, "negate") \
  V(kOutfeed, "outfeed") \
  V(kPad, "pad") \
  V(kParameter, "parameter") \
  V(kPower, "power") \
  V(kReal, "real") \
  V(kRecv, "recv") \
  V(kRecvDone, "recv-done") \
  V(kReduce, "reduce") \
  V(kReducePrecision, "reduce-precision") \
  V(kReduceWindow, "reduce-window") \
  V(kRemainder, "remainder") \
  V(kReshape, "reshape") \
  V(kReverse, "reverse") \
  V(kRng, "rng") \
  V(kRoundNearestAfz, "round-nearest-afz") \
  V(kSelect, "select") \
  V(kSelectAndScatter, "select-and-scatter") \
  V(kSend, "send") \
  V(kSendDone, "send-done") \
  V(kShiftLeft, "shift-left") \
  V(kShiftRightArithmetic, "shift-right-arithmetic") \
  V(kShiftRightLogical, "shift-right-logical") \
  V(kSign, "sign") \
  V(kSin, "sine") \
  V(kSlice, "slice") \
  V(kSort, "sort") \
  V(kSubtract, "subtract") \
  V(kTanh, "tanh") \
  V(kTrace, "trace") \
  V(kTranspose, "transpose") \
  V(kTuple, "tuple", kHloOpcodeIsVariadic) \
  V(kWhile, "while")
132 | |
133 | enum class HloOpcode { |
134 | #define DECLARE_ENUM(enum_name, opcode_name, ...) enum_name, |
135 | HLO_OPCODE_LIST(DECLARE_ENUM) |
136 | #undef DECLARE_ENUM |
137 | }; |
138 | |
// Bit flags describing per-opcode properties.
// Every flag occupies its own bit, so several flags can be merged with
// bitwise-or and an individual flag can be queried with bitwise-and.
enum HloOpcodeProperty {
  kHloOpcodeIsComparison = 0x1,  // bit 0
  kHloOpcodeIsVariadic = 0x2,    // bit 1
};
146 | |
// Returns a string representation of the opcode.
// NOTE(review): presumably this is the opcode_name listed for the opcode in
// HLO_OPCODE_LIST (e.g. "add" for kAdd); the definition is not in this
// header, so confirm against the implementation.
string HloOpcodeString(HloOpcode opcode);
149 | |
// Converts an opcode name into the corresponding HloOpcode (the opposite
// direction of HloOpcodeString).
// NOTE(review): the StatusOr return type suggests an error status is produced
// when opcode_name does not name any opcode; the definition is not in this
// header, so confirm against the implementation.
StatusOr<HloOpcode> StringToHloOpcode(const string& opcode_name);
152 | |
153 | inline std::ostream& operator<<(std::ostream& os, HloOpcode opcode) { |
154 | return os << HloOpcodeString(opcode); |
155 | } |
156 | |
// Returns true iff the given opcode is a comparison operation.
// NOTE(review): presumably true exactly for the opcodes tagged with
// kHloOpcodeIsComparison in HLO_OPCODE_LIST (kEq, kGe, kGt, kLe, kLt, kNe);
// the definition is not in this header, so confirm against the
// implementation.
bool HloOpcodeIsComparison(HloOpcode opcode);
159 | |
// Returns true iff the given opcode has variadic operands.
// NOTE(review): presumably true exactly for the opcodes tagged with
// kHloOpcodeIsVariadic in HLO_OPCODE_LIST (kCall, kConcatenate, kFusion,
// kMap, kTuple); the definition is not in this header, so confirm against
// the implementation.
bool HloOpcodeIsVariadic(HloOpcode opcode);
162 | |
163 | // Returns the number of HloOpcode values. |
164 | inline const uint32_t HloOpcodeCount() { |
165 | #define HLO_COUNT_ONE(...) +1 |
166 | #define HLO_XLIST_LENGTH(list) list(HLO_COUNT_ONE) |
167 | return HLO_XLIST_LENGTH(HLO_OPCODE_LIST); |
168 | } |
169 | |
170 | } // namespace xla |
171 | |
172 | #endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_OPCODE_H_ |
173 | |