1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include "tensorflow/compiler/tf2xla/lib/batch_dot.h"
17
18#include <memory>
19#include <vector>
20
21#include "tensorflow/compiler/xla/shape_util.h"
22#include "tensorflow/compiler/xla/status_macros.h"
23#include "tensorflow/compiler/xla/statusor.h"
24#include "tensorflow/core/lib/core/errors.h"
25
26namespace tensorflow {
27
28xla::StatusOr<xla::ComputationDataHandle> BatchDot(
29 xla::ComputationBuilder* builder, xla::ComputationDataHandle x,
30 xla::ComputationDataHandle y, bool transpose_x, bool transpose_y,
31 bool conjugate_x, bool conjugate_y) {
32 TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::Shape> x_shape,
33 builder->GetShape(x));
34 TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::Shape> y_shape,
35 builder->GetShape(y));
36
37 // Check that both tensors have the same number of dimensions. There must be
38 // at least two (the batch dimensions can be empty).
39 if (xla::ShapeUtil::Rank(*x_shape) != xla::ShapeUtil::Rank(*y_shape)) {
40 return errors::InvalidArgument(
41 "Arguments to BatchedDot have different ranks: ",
42 xla::ShapeUtil::HumanString(*x_shape), " vs. ",
43 xla::ShapeUtil::HumanString(*y_shape));
44 }
45 const int ndims = xla::ShapeUtil::Rank(*x_shape);
46 if (ndims < 2) {
47 return errors::InvalidArgument(
48 "Arguments to BatchedDot must have rank >= 2: ", ndims);
49 }
50
51 // The batch dimensions must be equal and the matrix dimensions must be
52 // valid.
53 std::vector<int64> batch_dimension_numbers;
54 for (int i = 0; i < ndims - 2; ++i) {
55 if (x_shape->dimensions(i) != y_shape->dimensions(i)) {
56 return errors::InvalidArgument(
57 "Dimension ", i, " of inputs to BatchedDot must be equal: ",
58 xla::ShapeUtil::HumanString(*x_shape), " vs ",
59 xla::ShapeUtil::HumanString(*y_shape));
60 }
61 batch_dimension_numbers.push_back(i);
62 }
63
64 int x_inner_dim = transpose_x ? (ndims - 2) : (ndims - 1);
65 int y_inner_dim = transpose_y ? (ndims - 1) : (ndims - 2);
66 if (x_shape->dimensions(x_inner_dim) != y_shape->dimensions(y_inner_dim)) {
67 return errors::InvalidArgument(
68 "Dimensions ", x_inner_dim, " and ", y_inner_dim,
69 " of arguments to BatchedDot must be equal: ",
70 xla::ShapeUtil::HumanString(*x_shape), " transpose: ", transpose_x,
71 " vs. ", xla::ShapeUtil::HumanString(*y_shape),
72 " transpose: ", transpose_y);
73 }
74
75 // Check for zero lhs/rhs dim size.
76 if (xla::ShapeUtil::HasZeroElements(*x_shape) ||
77 xla::ShapeUtil::HasZeroElements(*y_shape)) {
78 std::vector<int64> dimensions(batch_dimension_numbers.size());
79 for (int i = 0; i < batch_dimension_numbers.size(); ++i) {
80 dimensions[i] = x_shape->dimensions(batch_dimension_numbers[i]);
81 }
82 int x_outer_dim = transpose_x ? (ndims - 1) : (ndims - 2);
83 int y_outer_dim = transpose_y ? (ndims - 2) : (ndims - 1);
84 dimensions.push_back(x_shape->dimensions(x_outer_dim));
85 dimensions.push_back(y_shape->dimensions(y_outer_dim));
86 return builder->Broadcast(
87 builder->ConstantLiteral(xla::Literal::Zero(x_shape->element_type())),
88 dimensions);
89 }
90
91 if (x_shape->element_type() == xla::C64 && conjugate_x) {
92 x = builder->Conj(x);
93 }
94 if (y_shape->element_type() == xla::C64 && conjugate_y) {
95 y = builder->Conj(y);
96 }
97
98 // If there are no batch dimensions, use a regular Dot.
99 // TODO(b/69062148) Remove this code when Dot emitters can be passed
100 // dimensions to transpose directly (i.e. without requiring a Transpose HLO).
101 if (batch_dimension_numbers.empty()) {
102 auto lhs = transpose_x ? builder->Transpose(x, {1, 0}) : x;
103 auto rhs = transpose_y ? builder->Transpose(y, {1, 0}) : y;
104 return builder->Dot(lhs, rhs);
105 }
106
107 xla::DotDimensionNumbers dot_dnums;
108 dot_dnums.add_lhs_contracting_dimensions(x_inner_dim);
109 dot_dnums.add_rhs_contracting_dimensions(y_inner_dim);
110 for (auto batch_dimension_number : batch_dimension_numbers) {
111 dot_dnums.add_lhs_batch_dimensions(batch_dimension_number);
112 dot_dnums.add_rhs_batch_dimensions(batch_dimension_number);
113 }
114 return builder->DotGeneral(x, y, dot_dnums);
115}
116
117} // namespace tensorflow
118