1/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16// Implements a quantized eight-bit version of the matmul operation.
17
18#define EIGEN_USE_THREADS
19
20#define GEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK
21#include "public/gemmlowp.h"
22#include "tensorflow/core/framework/op_kernel.h"
23#include "tensorflow/core/framework/tensor.h"
24#include "tensorflow/core/kernels/meta_support.h"
25#include "tensorflow/core/kernels/quantization_utils.h"
26#include "tensorflow/core/kernels/reference_gemm.h"
27#include "tensorflow/core/lib/core/errors.h"
28
29namespace tensorflow {
30
31// We have to break this out as a separate function because there are multiple
32// combinations of transpose attributes we need to support, and they have to be
33// compile-time constants to work with the templates used internally.
34template <bool TransposeA, bool TransposeB, bool TransposeC>
35void GemmlowpMultiply(OpKernelContext* op_context, const quint8* a_data,
36 const quint8* b_data, qint32* c_data, int m, int n, int k,
37 int offset_a, int offset_b, int lda, int ldb, int ldc) {
38 const uint8* a_data_as_uint8 = &(a_data->value);
39 const uint8* b_data_as_uint8 = &(b_data->value);
40 int32* c_data_as_int32 = &(c_data->value);
41 static const gemmlowp::MapOrder ResultOrder =
42 !TransposeC ? gemmlowp::MapOrder::RowMajor : gemmlowp::MapOrder::ColMajor;
43 static const gemmlowp::MapOrder LhsOrder =
44 !TransposeA ? gemmlowp::MapOrder::RowMajor : gemmlowp::MapOrder::ColMajor;
45 static const gemmlowp::MapOrder RhsOrder =
46 !TransposeB ? gemmlowp::MapOrder::RowMajor : gemmlowp::MapOrder::ColMajor;
47 gemmlowp::MatrixMap<const std::uint8_t, LhsOrder> lhs(a_data_as_uint8, m, k,
48 lda);
49 gemmlowp::MatrixMap<const std::uint8_t, RhsOrder> rhs(b_data_as_uint8, k, n,
50 ldb);
51 gemmlowp::MatrixMap<std::int32_t, ResultOrder> result(c_data_as_int32, m, n,
52 ldc);
53 const std::tuple<> empty_pipeline = {};
54 auto& worker_threads =
55 *(op_context->device()->tensorflow_cpu_worker_threads());
56 TensorflowGemmContext context(worker_threads.num_threads,
57 worker_threads.workers);
58 gemmlowp::GemmWithOutputPipeline<std::uint8_t, std::int32_t,
59 gemmlowp::DefaultL8R8BitDepthParams>(
60 &context, lhs, rhs, &result, -offset_a, -offset_b, empty_pipeline);
61 // Since gemmlowp uses assembly to write to the output, msan won't detect
62 // the output buffer as written to, so we mark it manually.
63 TF_ANNOTATE_MEMORY_IS_INITIALIZED(c_data_as_int32, m * n * sizeof(int32));
64}
65
66template <class T1, class T2, class Toutput>
67class QuantizedMatMulOp : public OpKernel {
68 public:
69 explicit QuantizedMatMulOp(OpKernelConstruction* context)
70 : OpKernel(context) {
71 OP_REQUIRES_OK(context, context->GetAttr("transpose_a", &transpose_a_));
72 OP_REQUIRES_OK(context, context->GetAttr("transpose_b", &transpose_b_));
73 }
74
75 void Compute(OpKernelContext* context) override {
76 const Tensor& a = context->input(0);
77 const Tensor& b = context->input(1);
78 const float min_a = context->input(2).flat<float>()(0);
79 const float max_a = context->input(3).flat<float>()(0);
80 const float min_b = context->input(4).flat<float>()(0);
81 const float max_b = context->input(5).flat<float>()(0);
82
83 // Make sure that we have valid quantization ranges for the input buffers.
84 // If the difference between the min and max is negative or zero, it makes
85 // it hard to do meaningful intermediate operations on the values.
86 OP_REQUIRES(context, (max_a > min_a),
87 errors::InvalidArgument("max_a must be larger than min_a."));
88 OP_REQUIRES(context, (max_b > min_b),
89 errors::InvalidArgument("max_b must be larger than min_b."));
90 const int32 offset_a = FloatToQuantizedUnclamped<T1>(0.0f, min_a, max_a);
91 const int32 offset_b = FloatToQuantizedUnclamped<T2>(0.0f, min_b, max_b);
92 const int32 offset_c = 0;
93 const int32 mult_c = 1;
94 const int32 shift_c = 0;
95
96 // Check that the dimensions of the two matrices are valid.
97 OP_REQUIRES(context, TensorShapeUtils::IsMatrix(a.shape()),
98 errors::InvalidArgument("In[0] is not a matrix"));
99 OP_REQUIRES(context, TensorShapeUtils::IsMatrix(b.shape()),
100 errors::InvalidArgument("In[1] is not a matrix"));
101 Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> dim_pair;
102 dim_pair[0].first = transpose_a_ ? 0 : 1;
103 dim_pair[0].second = transpose_b_ ? 1 : 0;
104
105 OP_REQUIRES(context,
106 a.dim_size(dim_pair[0].first) == b.dim_size(dim_pair[0].second),
107 errors::InvalidArgument(
108 "Matrix size-compatible: In[0]: ", a.shape().DebugString(),
109 ", In[1]: ", b.shape().DebugString()));
110
111 OP_REQUIRES(context, ((shift_c >= 0) && (shift_c <= 31)),
112 errors::InvalidArgument("shift_c must be between 0 and 31, "
113 "inclusive."));
114
115 int a_dim_remaining = 1 - dim_pair[0].first;
116 int b_dim_remaining = 1 - dim_pair[0].second;
117 TensorShape out_shape(
118 {a.dim_size(a_dim_remaining), b.dim_size(b_dim_remaining)});
119 Tensor* c = nullptr;
120 OP_REQUIRES_OK(context, context->allocate_output(0, out_shape, &c));
121 CHECK(c);
122
123 const T1* a_data = a.flat<T1>().data();
124 const T2* b_data = b.flat<T2>().data();
125 Toutput* c_data = c->flat<Toutput>().data();
126
127 const bool transpose_c = false;
128 const size_t m = a.dim_size(a_dim_remaining);
129 const size_t n = b.dim_size(b_dim_remaining);
130 const size_t k = a.dim_size(dim_pair[0].first);
131 const size_t lda = a.dim_size(1);
132 const size_t ldb = b.dim_size(1);
133 const size_t ldc = n;
134
135 if (meta::IsSupportedAndEnabled() && std::is_same<T1, quint8>() &&
136 std::is_same<T2, quint8>() && std::is_same<Toutput, qint32>() &&
137 (offset_c == 0) && (mult_c == 1) && (shift_c == 0) &&
138 (transpose_c == false) && (k <= 2048)) {
139 // Gemmlowp/meta code path works on 32 & 64 bit Arm with NEON Simd and
140 // allows optimized quantized 8bit to 32bit gemm.
141 meta::QuantizedGemm(context, transpose_a_, transpose_b_, a_data, b_data,
142 c_data, m, n, k, -offset_a, -offset_b, lda, ldb, ldc);
143 } else if (std::is_same<T1, quint8>() && std::is_same<T2, quint8>() &&
144 std::is_same<Toutput, qint32>() && (offset_c == 0) &&
145 (mult_c == 1) && (shift_c == 0) && (transpose_c == false)) {
146 // The gemmlowp optimized library only works for a particular set of data
147 // types, so check if we meet those requirements and fall back to a slower
148 // reference implementation if not.
149 if (transpose_a_) {
150 if (transpose_b_) {
151 GemmlowpMultiply<true, true, false>(context, a_data, b_data, c_data,
152 m, n, k, offset_a, offset_b, lda,
153 ldb, ldc);
154 } else {
155 GemmlowpMultiply<true, false, false>(context, a_data, b_data, c_data,
156 m, n, k, offset_a, offset_b, lda,
157 ldb, ldc);
158 }
159 } else {
160 if (transpose_b_) {
161 GemmlowpMultiply<false, true, false>(context, a_data, b_data, c_data,
162 m, n, k, offset_a, offset_b, lda,
163 ldb, ldc);
164 } else {
165 GemmlowpMultiply<false, false, false>(context, a_data, b_data, c_data,
166 m, n, k, offset_a, offset_b,
167 lda, ldb, ldc);
168 }
169 }
170 } else {
171 ReferenceGemm<T1, T2, Toutput>(
172 transpose_a_, transpose_b_, transpose_c, m, n, k, a_data, offset_a,
173 lda, b_data, offset_b, ldb, c_data, shift_c, offset_c, mult_c, ldc);
174 }
175
176 float min_c_value;
177 float max_c_value;
178 QuantizationRangeForMultiplication<T1, T2, Toutput>(
179 min_a, max_a, min_b, max_b, &min_c_value, &max_c_value);
180 Tensor* c_min = nullptr;
181 OP_REQUIRES_OK(context, context->allocate_output(1, {}, &c_min));
182 c_min->flat<float>()(0) = min_c_value;
183
184 Tensor* c_max = nullptr;
185 OP_REQUIRES_OK(context, context->allocate_output(2, {}, &c_max));
186 c_max->flat<float>()(0) = max_c_value;
187 }
188
189 private:
190 bool transpose_a_;
191 bool transpose_b_;
192};
193
// Registers QuantizedMatMulOp<quint8, quint8, qint32> as the CPU kernel for
// the "QuantizedMatMul" op, constrained to T1 = quint8, T2 = quint8 and
// Toutput = qint32.
REGISTER_KERNEL_BUILDER(Name("QuantizedMatMul")
 .Device(DEVICE_CPU)
 .TypeConstraint<quint8>("T1")
 .TypeConstraint<quint8>("T2")
 .TypeConstraint<qint32>("Toutput"),
 QuantizedMatMulOp<quint8, quint8, qint32>);
200
201} // namespace tensorflow
202