1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
// XLA implementation of the OneHot operator.
17
18#include "tensorflow/compiler/tf2xla/literal_util.h"
19#include "tensorflow/compiler/tf2xla/xla_helpers.h"
20#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
21#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
22
23namespace tensorflow {
24namespace {
25
26class OneHotOp : public XlaOpKernel {
27 public:
28 explicit OneHotOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
29 OP_REQUIRES_OK(ctx, ctx->GetAttr("axis", &axis_));
30 }
31
32 void Compile(XlaOpKernelContext* ctx) override {
33 const TensorShape indices_shape = ctx->InputShape(0);
34 const TensorShape depth_shape = ctx->InputShape(1);
35 const TensorShape on_value_shape = ctx->InputShape(2);
36 const TensorShape off_value_shape = ctx->InputShape(3);
37
38 const int indices_dims = indices_shape.dims();
39 const int output_dims = indices_dims + 1;
40
41 // Preliminary validation of sizes.
42 OP_REQUIRES(
43 ctx, axis_ == -1 || (axis_ >= 0 && axis_ < output_dims),
44 errors::InvalidArgument("Expected axis to be -1 or between [0, ",
45 output_dims, "). But received: ", axis_));
46 OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(depth_shape),
47 errors::InvalidArgument("depth must be a scalar, but got: ",
48 depth_shape.DebugString()));
49 OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(on_value_shape),
50 errors::InvalidArgument("on_value must be a scalar, but got: ",
51 on_value_shape.DebugString()));
52 OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(off_value_shape),
53 errors::InvalidArgument("off_value must be a scalar, but got: ",
54 off_value_shape.DebugString()));
55
56 const int axis = (axis_ == -1) ? indices_dims : axis_;
57
58 // The one-hot dimension.
59 int64 depth;
60 OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntScalar(1, &depth));
61 OP_REQUIRES(
62 ctx, depth >= 0,
63 errors::InvalidArgument("depth must be non-negative, got: ", depth));
64
65 xla::ComputationDataHandle one_hot;
66 OP_REQUIRES_OK(
67 ctx, XlaHelpers::OneHot(ctx->builder(), depth, axis, input_type(0),
68 indices_shape, ctx->Input(0), ctx->Input(2),
69 ctx->Input(3), &one_hot));
70 ctx->SetOutput(0, one_hot);
71 }
72
73 private:
74 int32 axis_;
75
76 TF_DISALLOW_COPY_AND_ASSIGN(OneHotOp);
77};
78
79REGISTER_XLA_OP(Name("OneHot").CompileTimeConstInput("depth"), OneHotOp);
80
81} // namespace
82} // namespace tensorflow
83