/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Implements a quantized eight-bit version of the matmul operation.

#define EIGEN_USE_THREADS

#define GEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK
#include "public/gemmlowp.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/meta_support.h"
#include "tensorflow/core/kernels/quantization_utils.h"
#include "tensorflow/core/kernels/reference_gemm.h"
#include "tensorflow/core/lib/core/errors.h"

namespace tensorflow {

// We have to break this out as a separate function because there are multiple
// combinations of transpose attributes we need to support, and they have to be
// compile-time constants to work with the templates used internally.
template <bool TransposeA, bool TransposeB, bool TransposeC>
void GemmlowpMultiply(OpKernelContext* op_context, const quint8* a_data,
                      const quint8* b_data, qint32* c_data, int m, int n, int k,
                      int offset_a, int offset_b, int lda, int ldb, int ldc) {
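  // quint8 and qint32 are thin structs whose only member is the underlying
  // uint8/int32 value, so taking the address of the first element's .value
  // field yields a pointer to the raw contiguous buffer that gemmlowp expects.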
  const uint8* a_data_as_uint8 = &(a_data->value);
  const uint8* b_data_as_uint8 = &(b_data->value);
  int32* c_data_as_int32 = &(c_data->value);
  static const gemmlowp::MapOrder ResultOrder =
      !TransposeC ? gemmlowp::MapOrder::RowMajor : gemmlowp::MapOrder::ColMajor;
  static const gemmlowp::MapOrder LhsOrder =
      !TransposeA ? gemmlowp::MapOrder::RowMajor : gemmlowp::MapOrder::ColMajor;
  static const gemmlowp::MapOrder RhsOrder =
      !TransposeB ? gemmlowp::MapOrder::RowMajor : gemmlowp::MapOrder::ColMajor;
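  // Wrap the raw buffers in non-owning gemmlowp views: lhs is m x k, rhs is
  // k x n, and the result is m x n. lda, ldb, and ldc are the strides
  // (leading dimensions) of the underlying buffers.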
  gemmlowp::MatrixMap<const std::uint8_t, LhsOrder> lhs(a_data_as_uint8, m, k,
                                                        lda);
  gemmlowp::MatrixMap<const std::uint8_t, RhsOrder> rhs(b_data_as_uint8, k, n,
                                                        ldb);
  gemmlowp::MatrixMap<std::int32_t, ResultOrder> result(c_data_as_int32, m, n,
                                                        ldc);
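  // An empty output pipeline makes gemmlowp return the raw 32-bit
  // accumulators, with no requantization or activation stages applied.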
  const std::tuple<> empty_pipeline = {};
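  // Run the multiplication on TensorFlow's intra-op CPU thread pool instead
  // of letting gemmlowp manage its own worker threads.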
  auto& worker_threads =
      *(op_context->device()->tensorflow_cpu_worker_threads());
  TensorflowGemmContext context(worker_threads.num_threads,
                                worker_threads.workers);
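  // gemmlowp adds the lhs/rhs offsets to every input value before the
  // multiply-accumulate, so the negated zero points are passed in to
  // re-center the uint8 codes so that the code representing 0.0f contributes
  // zero to the product.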
  gemmlowp::GemmWithOutputPipeline<std::uint8_t, std::int32_t,
                                   gemmlowp::DefaultL8R8BitDepthParams>(
      &context, lhs, rhs, &result, -offset_a, -offset_b, empty_pipeline);
  // Since gemmlowp uses assembly to write to the output, msan won't detect
  // the output buffer as written to, so we mark it manually.
  TF_ANNOTATE_MEMORY_IS_INITIALIZED(c_data_as_int32, m * n * sizeof(int32));
}

template <class T1, class T2, class Toutput>
class QuantizedMatMulOp : public OpKernel {
 public:
  explicit QuantizedMatMulOp(OpKernelConstruction* context)
      : OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("transpose_a", &transpose_a_));
    OP_REQUIRES_OK(context, context->GetAttr("transpose_b", &transpose_b_));
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& a = context->input(0);
    const Tensor& b = context->input(1);
    const float min_a = context->input(2).flat<float>()(0);
    const float max_a = context->input(3).flat<float>()(0);
    const float min_b = context->input(4).flat<float>()(0);
    const float max_b = context->input(5).flat<float>()(0);

    // Make sure that we have valid quantization ranges for the input buffers.
    // If the difference between the min and max is negative or zero, it makes
    // it hard to do meaningful intermediate operations on the values.
    OP_REQUIRES(context, (max_a > min_a),
                errors::InvalidArgument("max_a must be larger than min_a."));
    OP_REQUIRES(context, (max_b > min_b),
                errors::InvalidArgument("max_b must be larger than min_b."));
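    // offset_a and offset_b are the quantized codes that represent the real
    // value 0.0f in each input's range (their zero points). The output stays
    // as raw qint32 accumulators, so the output offset, multiplier, and shift
    // are fixed at identity values.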
    const int32 offset_a = FloatToQuantizedUnclamped<T1>(0.0f, min_a, max_a);
    const int32 offset_b = FloatToQuantizedUnclamped<T2>(0.0f, min_b, max_b);
    const int32 offset_c = 0;
    const int32 mult_c = 1;
    const int32 shift_c = 0;

    // Check that the dimensions of the two matrices are valid.
    OP_REQUIRES(context, TensorShapeUtils::IsMatrix(a.shape()),
                errors::InvalidArgument("In[0] is not a matrix"));
    OP_REQUIRES(context, TensorShapeUtils::IsMatrix(b.shape()),
                errors::InvalidArgument("In[1] is not a matrix"));
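    // dim_pair holds the pair of dimensions that are contracted together (the
    // inner dimension of the product), which depends on the transpose flags.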
    Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> dim_pair;
    dim_pair[0].first = transpose_a_ ? 0 : 1;
    dim_pair[0].second = transpose_b_ ? 1 : 0;

    OP_REQUIRES(context,
                a.dim_size(dim_pair[0].first) == b.dim_size(dim_pair[0].second),
                errors::InvalidArgument(
                    "Matrix size-incompatible: In[0]: ", a.shape().DebugString(),
                    ", In[1]: ", b.shape().DebugString()));

    OP_REQUIRES(context, ((shift_c >= 0) && (shift_c <= 31)),
                errors::InvalidArgument("shift_c must be between 0 and 31, "
                                        "inclusive."));

    int a_dim_remaining = 1 - dim_pair[0].first;
    int b_dim_remaining = 1 - dim_pair[0].second;
    TensorShape out_shape(
        {a.dim_size(a_dim_remaining), b.dim_size(b_dim_remaining)});
    Tensor* c = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(0, out_shape, &c));
    CHECK(c);

    const T1* a_data = a.flat<T1>().data();
    const T2* b_data = b.flat<T2>().data();
    Toutput* c_data = c->flat<Toutput>().data();

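    // The output is always written untransposed and row-major. m and n are
    // the output dimensions, k is the contracted dimension, and lda, ldb, and
    // ldc are the row strides of the row-major buffers.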
    const bool transpose_c = false;
    const size_t m = a.dim_size(a_dim_remaining);
    const size_t n = b.dim_size(b_dim_remaining);
    const size_t k = a.dim_size(dim_pair[0].first);
    const size_t lda = a.dim_size(1);
    const size_t ldb = b.dim_size(1);
    const size_t ldc = n;

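    // Dispatch to the fastest implementation available for the requested
    // types and parameters: the ARM-specific meta kernels first, then
    // gemmlowp, and finally the portable reference GEMM.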
    if (meta::IsSupportedAndEnabled() && std::is_same<T1, quint8>() &&
        std::is_same<T2, quint8>() && std::is_same<Toutput, qint32>() &&
        (offset_c == 0) && (mult_c == 1) && (shift_c == 0) &&
        (transpose_c == false) && (k <= 2048)) {
      // The gemmlowp "meta" code path works on 32- and 64-bit ARM with NEON
      // SIMD and provides an optimized quantized 8-bit to 32-bit gemm.
      meta::QuantizedGemm(context, transpose_a_, transpose_b_, a_data, b_data,
                          c_data, m, n, k, -offset_a, -offset_b, lda, ldb, ldc);
    } else if (std::is_same<T1, quint8>() && std::is_same<T2, quint8>() &&
               std::is_same<Toutput, qint32>() && (offset_c == 0) &&
               (mult_c == 1) && (shift_c == 0) && (transpose_c == false)) {
      // The gemmlowp optimized library only works for a particular set of
      // data types, so check if we meet those requirements and fall back to a
      // slower reference implementation if not.
      if (transpose_a_) {
        if (transpose_b_) {
          GemmlowpMultiply<true, true, false>(context, a_data, b_data, c_data,
                                              m, n, k, offset_a, offset_b, lda,
                                              ldb, ldc);
        } else {
          GemmlowpMultiply<true, false, false>(context, a_data, b_data, c_data,
                                               m, n, k, offset_a, offset_b, lda,
                                               ldb, ldc);
        }
      } else {
        if (transpose_b_) {
          GemmlowpMultiply<false, true, false>(context, a_data, b_data, c_data,
                                               m, n, k, offset_a, offset_b, lda,
                                               ldb, ldc);
        } else {
          GemmlowpMultiply<false, false, false>(context, a_data, b_data,
                                                c_data, m, n, k, offset_a,
                                                offset_b, lda, ldb, ldc);
        }
      }
    } else {
      ReferenceGemm<T1, T2, Toutput>(
          transpose_a_, transpose_b_, transpose_c, m, n, k, a_data, offset_a,
          lda, b_data, offset_b, ldb, c_data, shift_c, offset_c, mult_c, ldc);
    }

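    // Work out the float range that the qint32 accumulators represent, which
    // follows directly from the quantization ranges of the two inputs.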
    float min_c_value;
    float max_c_value;
    QuantizationRangeForMultiplication<T1, T2, Toutput>(
        min_a, max_a, min_b, max_b, &min_c_value, &max_c_value);
    Tensor* c_min = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(1, {}, &c_min));
    c_min->flat<float>()(0) = min_c_value;

    Tensor* c_max = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(2, {}, &c_max));
    c_max->flat<float>()(0) = max_c_value;
  }

 private:
  bool transpose_a_;
  bool transpose_b_;
};

REGISTER_KERNEL_BUILDER(Name("QuantizedMatMul")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<quint8>("T1")
                            .TypeConstraint<quint8>("T2")
                            .TypeConstraint<qint32>("Toutput"),
                        QuantizedMatMulOp<quint8, quint8, qint32>);

}  // namespace tensorflow