1 | /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | // XLA implementation of OneHot operator. |
17 | |
18 | #include "tensorflow/compiler/tf2xla/literal_util.h" |
19 | #include "tensorflow/compiler/tf2xla/xla_helpers.h" |
20 | #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" |
21 | #include "tensorflow/compiler/tf2xla/xla_op_registry.h" |
22 | |
23 | namespace tensorflow { |
24 | namespace { |
25 | |
26 | class OneHotOp : public XlaOpKernel { |
27 | public: |
28 | explicit OneHotOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { |
29 | OP_REQUIRES_OK(ctx, ctx->GetAttr("axis" , &axis_)); |
30 | } |
31 | |
32 | void Compile(XlaOpKernelContext* ctx) override { |
33 | const TensorShape indices_shape = ctx->InputShape(0); |
34 | const TensorShape depth_shape = ctx->InputShape(1); |
35 | const TensorShape on_value_shape = ctx->InputShape(2); |
36 | const TensorShape off_value_shape = ctx->InputShape(3); |
37 | |
38 | const int indices_dims = indices_shape.dims(); |
39 | const int output_dims = indices_dims + 1; |
40 | |
41 | // Preliminary validation of sizes. |
42 | OP_REQUIRES( |
43 | ctx, axis_ == -1 || (axis_ >= 0 && axis_ < output_dims), |
44 | errors::InvalidArgument("Expected axis to be -1 or between [0, " , |
45 | output_dims, "). But received: " , axis_)); |
46 | OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(depth_shape), |
47 | errors::InvalidArgument("depth must be a scalar, but got: " , |
48 | depth_shape.DebugString())); |
49 | OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(on_value_shape), |
50 | errors::InvalidArgument("on_value must be a scalar, but got: " , |
51 | on_value_shape.DebugString())); |
52 | OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(off_value_shape), |
53 | errors::InvalidArgument("off_value must be a scalar, but got: " , |
54 | off_value_shape.DebugString())); |
55 | |
56 | const int axis = (axis_ == -1) ? indices_dims : axis_; |
57 | |
58 | // The one-hot dimension. |
59 | int64 depth; |
60 | OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntScalar(1, &depth)); |
61 | OP_REQUIRES( |
62 | ctx, depth >= 0, |
63 | errors::InvalidArgument("depth must be non-negative, got: " , depth)); |
64 | |
65 | xla::ComputationDataHandle one_hot; |
66 | OP_REQUIRES_OK( |
67 | ctx, XlaHelpers::OneHot(ctx->builder(), depth, axis, input_type(0), |
68 | indices_shape, ctx->Input(0), ctx->Input(2), |
69 | ctx->Input(3), &one_hot)); |
70 | ctx->SetOutput(0, one_hot); |
71 | } |
72 | |
73 | private: |
74 | int32 axis_; |
75 | |
76 | TF_DISALLOW_COPY_AND_ASSIGN(OneHotOp); |
77 | }; |
78 | |
79 | REGISTER_XLA_OP(Name("OneHot" ).CompileTimeConstInput("depth" ), OneHotOp); |
80 | |
81 | } // namespace |
82 | } // namespace tensorflow |
83 | |