/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/math_ops.cc.

#define EIGEN_USE_THREADS

#include "tensorflow/core/kernels/cast_op.h"

#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/work_sharder.h"

#include "tensorflow/core/kernels/cast_op_impl.h"

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
#ifdef TENSORFLOW_USE_SYCL
typedef Eigen::SyclDevice SYCLDevice;
#endif  // TENSORFLOW_USE_SYCL

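// CURRY_TYPES2(FN, arg0) invokes FN(arg0, T) for every destination type T
// supported by Cast; e.g. CURRY_TYPES2(REGISTER_CAST_GPU, bool) registers a
// kernel casting bool to each of the types listed below.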
#define CURRY_TYPES2(FN, arg0)   \
  FN(arg0, bool);                \
  FN(arg0, uint8);               \
  FN(arg0, int8);                \
  FN(arg0, uint16);              \
  FN(arg0, int16);               \
  FN(arg0, int32);               \
  FN(arg0, int64);               \
  FN(arg0, Eigen::half);         \
  FN(arg0, float);               \
  FN(arg0, double);              \
  FN(arg0, std::complex<float>); \
  FN(arg0, std::complex<double>)

CastOpBase::CastOpBase(OpKernelConstruction* ctx) : OpKernel(ctx) {
  OP_REQUIRES_OK(ctx, ctx->GetAttr("SrcT", &src_dtype_));
  OP_REQUIRES_OK(ctx, ctx->GetAttr("DstT", &dst_dtype_));
}

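// When SrcT == DstT the cast is an identity: the input tensor is forwarded
// directly to the output without a copy. Otherwise an output of the same
// shape is allocated and the type-specific work_ functor fills it.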
void CastOpBase::Compute(OpKernelContext* ctx) {
  const Tensor& inp = ctx->input(0);
  if (work_ == nullptr) {
    ctx->set_output(0, inp);
  } else {
    Tensor* out = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, inp.shape(), &out));
    work_(ctx, inp, out);
  }
}

Status CastOpBase::Unimplemented() {
  return errors::Unimplemented("Cast ", DataTypeString(src_dtype_), " to ",
                               DataTypeString(dst_dtype_),
                               " is not supported");
}

CpuCastOp::CpuCastOp(OpKernelConstruction* ctx) : CastOpBase(ctx) {
  OP_REQUIRES_OK(ctx, Prepare());
}

Status CpuCastOp::Prepare() {
  if (src_dtype_ == dst_dtype_) {
    work_ = nullptr;  // Identity
    return Status::OK();
  }
  if (src_dtype_ == DT_BOOL) {
    work_ = GetCpuCastFromBool(dst_dtype_);
  } else if (src_dtype_ == DT_UINT8) {
    work_ = GetCpuCastFromUint8(dst_dtype_);
  } else if (src_dtype_ == DT_INT8) {
    work_ = GetCpuCastFromInt8(dst_dtype_);
  } else if (src_dtype_ == DT_UINT16) {
    work_ = GetCpuCastFromUint16(dst_dtype_);
  } else if (src_dtype_ == DT_INT16) {
    work_ = GetCpuCastFromInt16(dst_dtype_);
  } else if (src_dtype_ == DT_INT32) {
    work_ = GetCpuCastFromInt32(dst_dtype_);
  } else if (src_dtype_ == DT_INT64) {
    work_ = GetCpuCastFromInt64(dst_dtype_);
  } else if (src_dtype_ == DT_HALF) {
    work_ = GetCpuCastFromHalf(dst_dtype_);
  } else if (src_dtype_ == DT_FLOAT) {
    work_ = GetCpuCastFromFloat(dst_dtype_);
  } else if (src_dtype_ == DT_DOUBLE) {
    work_ = GetCpuCastFromDouble(dst_dtype_);
  } else if (src_dtype_ == DT_COMPLEX64) {
    work_ = GetCpuCastFromComplex64(dst_dtype_);
  } else if (src_dtype_ == DT_COMPLEX128) {
    work_ = GetCpuCastFromComplex128(dst_dtype_);
  } else if (src_dtype_ == DT_BFLOAT16) {
    work_ = GetCpuCastFromBfloat(dst_dtype_);
  }

  // TODO(sesse): If CPU casting to or from Eigen::half ever becomes a
  // bottleneck, we could probably implement specialized support for
  // vectorized versions (not the least based on F16C for Haswell
  // or newer).

  return work_ == nullptr ? Unimplemented() : Status::OK();
}

#if GOOGLE_CUDA
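// GpuCastOp mirrors CpuCastOp, but looks up the GPU implementation of each
// supported source/destination pair.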
class GpuCastOp : public CastOpBase {
 public:
  explicit GpuCastOp(OpKernelConstruction* ctx) : CastOpBase(ctx) {
    OP_REQUIRES_OK(ctx, Prepare());
  }

 private:
  Status Prepare() {
    if (src_dtype_ == dst_dtype_) {
      work_ = nullptr;  // Identity
      return Status::OK();
    }
    if (src_dtype_ == DT_BOOL) {
      work_ = GetGpuCastFromBool(dst_dtype_);
    } else if (src_dtype_ == DT_UINT8) {
      work_ = GetGpuCastFromUint8(dst_dtype_);
    } else if (src_dtype_ == DT_INT8) {
      work_ = GetGpuCastFromInt8(dst_dtype_);
    } else if (src_dtype_ == DT_UINT16) {
      work_ = GetGpuCastFromUint16(dst_dtype_);
    } else if (src_dtype_ == DT_INT16) {
      work_ = GetGpuCastFromInt16(dst_dtype_);
    } else if (src_dtype_ == DT_INT32) {
      work_ = GetGpuCastFromInt32(dst_dtype_);
    } else if (src_dtype_ == DT_INT64) {
      work_ = GetGpuCastFromInt64(dst_dtype_);
    } else if (src_dtype_ == DT_HALF) {
      work_ = GetGpuCastFromHalf(dst_dtype_);
    } else if (src_dtype_ == DT_FLOAT) {
      work_ = GetGpuCastFromFloat(dst_dtype_);
    } else if (src_dtype_ == DT_DOUBLE) {
      work_ = GetGpuCastFromDouble(dst_dtype_);
    } else if (src_dtype_ == DT_COMPLEX64) {
      work_ = GetGpuCastFromComplex64(dst_dtype_);
    } else if (src_dtype_ == DT_COMPLEX128) {
      work_ = GetGpuCastFromComplex128(dst_dtype_);
    } else if (src_dtype_ == DT_BFLOAT16) {
      work_ = GetGpuCastFromBfloat(dst_dtype_);
    }

    return work_ == nullptr ? Unimplemented() : Status::OK();
  }
};
#endif  // GOOGLE_CUDA

#undef CAST_CASE

REGISTER_KERNEL_BUILDER(Name("Cast").Device(DEVICE_CPU), CpuCastOp);

#if GOOGLE_CUDA
#define REGISTER_CAST_GPU(srctype, dsttype)                    \
  REGISTER_KERNEL_BUILDER(Name("Cast")                         \
                              .TypeConstraint<srctype>("SrcT") \
                              .TypeConstraint<dsttype>("DstT") \
                              .Device(DEVICE_GPU),             \
                          GpuCastOp)

CURRY_TYPES2(REGISTER_CAST_GPU, bool);
CURRY_TYPES2(REGISTER_CAST_GPU, uint8);
CURRY_TYPES2(REGISTER_CAST_GPU, int8);
CURRY_TYPES2(REGISTER_CAST_GPU, uint16);
CURRY_TYPES2(REGISTER_CAST_GPU, int16);
CURRY_TYPES2(REGISTER_CAST_GPU, int32);
CURRY_TYPES2(REGISTER_CAST_GPU, int64);
CURRY_TYPES2(REGISTER_CAST_GPU, Eigen::half);
CURRY_TYPES2(REGISTER_CAST_GPU, float);
CURRY_TYPES2(REGISTER_CAST_GPU, double);
CURRY_TYPES2(REGISTER_CAST_GPU, std::complex<float>);
CURRY_TYPES2(REGISTER_CAST_GPU, std::complex<double>);
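// On GPU, bfloat16 casts are registered only to and from float.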
REGISTER_CAST_GPU(float, bfloat16);
REGISTER_CAST_GPU(bfloat16, float);

#undef REGISTER_CAST_GPU
#endif  // GOOGLE_CUDA

#ifdef TENSORFLOW_USE_SYCL
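// SyclCastOp supports a narrower set of source types than the CPU and GPU
// kernels: bool, int32, int64, float, and double.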
class SyclCastOp : public CastOpBase {
 public:
  explicit SyclCastOp(OpKernelConstruction* ctx) : CastOpBase(ctx) {
    OP_REQUIRES_OK(ctx, Prepare());
  }

 private:
  Status Prepare() {
    if (src_dtype_ == dst_dtype_) {
      work_ = nullptr;  // Identity
      return Status::OK();
    }
    if (src_dtype_ == DT_BOOL) {
      work_ = GetSyclCastFromBool(dst_dtype_);
    } else if (src_dtype_ == DT_INT32) {
      work_ = GetSyclCastFromInt32(dst_dtype_);
    } else if (src_dtype_ == DT_INT64) {
      work_ = GetSyclCastFromInt64(dst_dtype_);
    } else if (src_dtype_ == DT_FLOAT) {
      work_ = GetSyclCastFromFloat(dst_dtype_);
    } else if (src_dtype_ == DT_DOUBLE) {
      work_ = GetSyclCastFromDouble(dst_dtype_);
    }

    return work_ == nullptr ? Unimplemented() : Status::OK();
  }
};

#define REGISTER_CAST_SYCL(srctype, dsttype)                   \
  REGISTER_KERNEL_BUILDER(Name("Cast")                         \
                              .TypeConstraint<srctype>("SrcT") \
                              .TypeConstraint<dsttype>("DstT") \
                              .Device(DEVICE_SYCL),            \
                          SyclCastOp)
CURRY_TYPES2(REGISTER_CAST_SYCL, bool);
CURRY_TYPES2(REGISTER_CAST_SYCL, int32);
CURRY_TYPES2(REGISTER_CAST_SYCL, int64);
CURRY_TYPES2(REGISTER_CAST_SYCL, float);
CURRY_TYPES2(REGISTER_CAST_SYCL, double);

#undef REGISTER_CAST_SYCL

#endif  // TENSORFLOW_USE_SYCL

#undef CURRY_TYPES2

// HostCast differs from Cast in that its input and output are in host memory.
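// The GPU (and SYCL) _HostCast registrations below therefore reuse CpuCastOp,
// with both the "x" input and the "y" output pinned to host memory.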
REGISTER_KERNEL_BUILDER(Name("_HostCast").Device(DEVICE_CPU), CpuCastOp);
REGISTER_KERNEL_BUILDER(
    Name("_HostCast").Device(DEVICE_GPU).HostMemory("x").HostMemory("y"),
    CpuCastOp);
#ifdef TENSORFLOW_USE_SYCL
REGISTER_KERNEL_BUILDER(
    Name("_HostCast").Device(DEVICE_SYCL).HostMemory("x").HostMemory("y"),
    CpuCastOp);
#endif  // TENSORFLOW_USE_SYCL
}  // end namespace tensorflow