/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/math_ops.cc.

#define EIGEN_USE_THREADS

#include "tensorflow/core/kernels/cast_op.h"

#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/work_sharder.h"

#include "tensorflow/core/kernels/cast_op_impl.h"

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
#ifdef TENSORFLOW_USE_SYCL
typedef Eigen::SyclDevice SYCLDevice;
#endif  // TENSORFLOW_USE_SYCL

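// Invokes FN once for every destination type supported by Cast, with arg0
// fixed as the source type.  For example, CURRY_TYPES2(REGISTER_CAST_GPU, bool)
// expands to REGISTER_CAST_GPU(bool, bool); REGISTER_CAST_GPU(bool, uint8); ...
// for each of the twelve types listed below.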
#define CURRY_TYPES2(FN, arg0)   \
  FN(arg0, bool);                \
  FN(arg0, uint8);               \
  FN(arg0, int8);                \
  FN(arg0, uint16);              \
  FN(arg0, int16);               \
  FN(arg0, int32);               \
  FN(arg0, int64);               \
  FN(arg0, Eigen::half);         \
  FN(arg0, float);               \
  FN(arg0, double);              \
  FN(arg0, std::complex<float>); \
  FN(arg0, std::complex<double>)

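// Base class shared by the CPU, GPU, and SYCL kernels below: it reads the
// SrcT and DstT attributes from the node; subclasses fill in work_ with a
// matching conversion functor (or leave it null for an identity cast).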
CastOpBase::CastOpBase(OpKernelConstruction* ctx) : OpKernel(ctx) {
  OP_REQUIRES_OK(ctx, ctx->GetAttr("SrcT", &src_dtype_));
  OP_REQUIRES_OK(ctx, ctx->GetAttr("DstT", &dst_dtype_));
}

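// For an identity cast (work_ == nullptr) the input tensor is forwarded to the
// output without copying; otherwise an output of the same shape is allocated
// and the conversion functor fills it in.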
void CastOpBase::Compute(OpKernelContext* ctx) {
  const Tensor& inp = ctx->input(0);
  if (work_ == nullptr) {
    ctx->set_output(0, inp);
  } else {
    Tensor* out = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, inp.shape(), &out));
    work_(ctx, inp, out);
  }
}

Status CastOpBase::Unimplemented() {
  return errors::Unimplemented("Cast ", DataTypeString(src_dtype_), " to ",
                               DataTypeString(dst_dtype_), " is not supported");
}

CpuCastOp::CpuCastOp(OpKernelConstruction* ctx) : CastOpBase(ctx) {
  OP_REQUIRES_OK(ctx, Prepare());
}

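// Dispatches on the source dtype to one of the GetCpuCastFrom* factories
// declared in cast_op_impl.h; each returns a functor that converts to
// dst_dtype_.  If no factory handles the (src, dst) pair, work_ stays null and
// the op reports Unimplemented.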
Status CpuCastOp::Prepare() {
  if (src_dtype_ == dst_dtype_) {
    work_ = nullptr;  // Identity
    return Status::OK();
  }
  if (src_dtype_ == DT_BOOL) {
    work_ = GetCpuCastFromBool(dst_dtype_);
  } else if (src_dtype_ == DT_UINT8) {
    work_ = GetCpuCastFromUint8(dst_dtype_);
  } else if (src_dtype_ == DT_INT8) {
    work_ = GetCpuCastFromInt8(dst_dtype_);
  } else if (src_dtype_ == DT_UINT16) {
    work_ = GetCpuCastFromUint16(dst_dtype_);
  } else if (src_dtype_ == DT_INT16) {
    work_ = GetCpuCastFromInt16(dst_dtype_);
  } else if (src_dtype_ == DT_INT32) {
    work_ = GetCpuCastFromInt32(dst_dtype_);
  } else if (src_dtype_ == DT_INT64) {
    work_ = GetCpuCastFromInt64(dst_dtype_);
  } else if (src_dtype_ == DT_HALF) {
    work_ = GetCpuCastFromHalf(dst_dtype_);
  } else if (src_dtype_ == DT_FLOAT) {
    work_ = GetCpuCastFromFloat(dst_dtype_);
  } else if (src_dtype_ == DT_DOUBLE) {
    work_ = GetCpuCastFromDouble(dst_dtype_);
  } else if (src_dtype_ == DT_COMPLEX64) {
    work_ = GetCpuCastFromComplex64(dst_dtype_);
  } else if (src_dtype_ == DT_COMPLEX128) {
    work_ = GetCpuCastFromComplex128(dst_dtype_);
  } else if (src_dtype_ == DT_BFLOAT16) {
    work_ = GetCpuCastFromBfloat(dst_dtype_);
  }

  // TODO(sesse): If CPU casting to or from Eigen::half ever becomes a
  // bottleneck, we could probably implement specialized support for
  // vectorized versions (not the least based on F16C for Haswell
  // or newer).

  return work_ == nullptr ? Unimplemented() : Status::OK();
}

#if GOOGLE_CUDA
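// GPU counterpart of CpuCastOp: the same dispatch logic, but using the
// GetGpuCastFrom* factories.  Note that only float <-> bfloat16 is registered
// for bfloat16 on GPU (see the registrations below).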
class GpuCastOp : public CastOpBase {
 public:
  explicit GpuCastOp(OpKernelConstruction* ctx) : CastOpBase(ctx) {
    OP_REQUIRES_OK(ctx, Prepare());
  }

 private:
  Status Prepare() {
    if (src_dtype_ == dst_dtype_) {
      work_ = nullptr;  // Identity
      return Status::OK();
    }
    if (src_dtype_ == DT_BOOL) {
      work_ = GetGpuCastFromBool(dst_dtype_);
    } else if (src_dtype_ == DT_UINT8) {
      work_ = GetGpuCastFromUint8(dst_dtype_);
    } else if (src_dtype_ == DT_INT8) {
      work_ = GetGpuCastFromInt8(dst_dtype_);
    } else if (src_dtype_ == DT_UINT16) {
      work_ = GetGpuCastFromUint16(dst_dtype_);
    } else if (src_dtype_ == DT_INT16) {
      work_ = GetGpuCastFromInt16(dst_dtype_);
    } else if (src_dtype_ == DT_INT32) {
      work_ = GetGpuCastFromInt32(dst_dtype_);
    } else if (src_dtype_ == DT_INT64) {
      work_ = GetGpuCastFromInt64(dst_dtype_);
    } else if (src_dtype_ == DT_HALF) {
      work_ = GetGpuCastFromHalf(dst_dtype_);
    } else if (src_dtype_ == DT_FLOAT) {
      work_ = GetGpuCastFromFloat(dst_dtype_);
    } else if (src_dtype_ == DT_DOUBLE) {
      work_ = GetGpuCastFromDouble(dst_dtype_);
    } else if (src_dtype_ == DT_COMPLEX64) {
      work_ = GetGpuCastFromComplex64(dst_dtype_);
    } else if (src_dtype_ == DT_COMPLEX128) {
      work_ = GetGpuCastFromComplex128(dst_dtype_);
    } else if (src_dtype_ == DT_BFLOAT16) {
      work_ = GetGpuCastFromBfloat(dst_dtype_);
    }

    return work_ == nullptr ? Unimplemented() : Status::OK();
  }
};
#endif  // GOOGLE_CUDA

#undef CAST_CASE

REGISTER_KERNEL_BUILDER(Name("Cast").Device(DEVICE_CPU), CpuCastOp);

#if GOOGLE_CUDA
#define REGISTER_CAST_GPU(srctype, dsttype)                    \
  REGISTER_KERNEL_BUILDER(Name("Cast")                         \
                              .TypeConstraint<srctype>("SrcT") \
                              .TypeConstraint<dsttype>("DstT") \
                              .Device(DEVICE_GPU),             \
                          GpuCastOp)

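// Register a GPU kernel for every (SrcT, DstT) pair in the type list above,
// plus the two bfloat16 <-> float conversions.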
CURRY_TYPES2(REGISTER_CAST_GPU, bool);
CURRY_TYPES2(REGISTER_CAST_GPU, uint8);
CURRY_TYPES2(REGISTER_CAST_GPU, int8);
CURRY_TYPES2(REGISTER_CAST_GPU, uint16);
CURRY_TYPES2(REGISTER_CAST_GPU, int16);
CURRY_TYPES2(REGISTER_CAST_GPU, int32);
CURRY_TYPES2(REGISTER_CAST_GPU, int64);
CURRY_TYPES2(REGISTER_CAST_GPU, Eigen::half);
CURRY_TYPES2(REGISTER_CAST_GPU, float);
CURRY_TYPES2(REGISTER_CAST_GPU, double);
CURRY_TYPES2(REGISTER_CAST_GPU, std::complex<float>);
CURRY_TYPES2(REGISTER_CAST_GPU, std::complex<double>);
REGISTER_CAST_GPU(float, bfloat16);
REGISTER_CAST_GPU(bfloat16, float);

#undef REGISTER_CAST_GPU
#endif  // GOOGLE_CUDA

#ifdef TENSORFLOW_USE_SYCL
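// SYCL counterpart of CpuCastOp.  Only bool, int32, int64, float, and double
// are supported as source types; anything else falls through to Unimplemented.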
class SyclCastOp : public CastOpBase {
 public:
  explicit SyclCastOp(OpKernelConstruction* ctx) : CastOpBase(ctx) {
    OP_REQUIRES_OK(ctx, Prepare());
  }

 private:
  Status Prepare() {
    if (src_dtype_ == dst_dtype_) {
      work_ = nullptr;  // Identity
      return Status::OK();
    }
    if (src_dtype_ == DT_BOOL) {
      work_ = GetSyclCastFromBool(dst_dtype_);
    } else if (src_dtype_ == DT_INT32) {
      work_ = GetSyclCastFromInt32(dst_dtype_);
    } else if (src_dtype_ == DT_INT64) {
      work_ = GetSyclCastFromInt64(dst_dtype_);
    } else if (src_dtype_ == DT_FLOAT) {
      work_ = GetSyclCastFromFloat(dst_dtype_);
    } else if (src_dtype_ == DT_DOUBLE) {
      work_ = GetSyclCastFromDouble(dst_dtype_);
    }

    return work_ == nullptr ? Unimplemented() : Status::OK();
  }
};

#define REGISTER_CAST_SYCL(srctype, dsttype)                   \
  REGISTER_KERNEL_BUILDER(Name("Cast")                         \
                              .TypeConstraint<srctype>("SrcT") \
                              .TypeConstraint<dsttype>("DstT") \
                              .Device(DEVICE_SYCL),            \
                          SyclCastOp)
CURRY_TYPES2(REGISTER_CAST_SYCL, bool);
CURRY_TYPES2(REGISTER_CAST_SYCL, int32);
CURRY_TYPES2(REGISTER_CAST_SYCL, int64);
CURRY_TYPES2(REGISTER_CAST_SYCL, float);
CURRY_TYPES2(REGISTER_CAST_SYCL, double);

#undef REGISTER_CAST_SYCL

#endif  // TENSORFLOW_USE_SYCL

#undef CURRY_TYPES2

// HostCast differs from Cast in that its input and output are in host memory.
REGISTER_KERNEL_BUILDER(Name("_HostCast").Device(DEVICE_CPU), CpuCastOp);
REGISTER_KERNEL_BUILDER(
    Name("_HostCast").Device(DEVICE_GPU).HostMemory("x").HostMemory("y"),
    CpuCastOp);
#ifdef TENSORFLOW_USE_SYCL
REGISTER_KERNEL_BUILDER(
    Name("_HostCast").Device(DEVICE_SYCL).HostMemory("x").HostMemory("y"),
    CpuCastOp);
#endif  // TENSORFLOW_USE_SYCL
}  // end namespace tensorflow