/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/tf2xla/lib/batch_dot.h"

#include <memory>
#include <vector>

#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/core/lib/core/errors.h"

namespace tensorflow {

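// Multiplies the innermost two dimensions of x and y as matrices, batching
// over any leading dimensions, with optional transposition and conjugation
// of either operand (see the declaration in batch_dot.h for the full
// contract). In outline:
//
//   output[..., :, :] = matrix(x[..., :, :]) * matrix(y[..., :, :])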
xla::StatusOr<xla::ComputationDataHandle> BatchDot(
    xla::ComputationBuilder* builder, xla::ComputationDataHandle x,
    xla::ComputationDataHandle y, bool transpose_x, bool transpose_y,
    bool conjugate_x, bool conjugate_y) {
  TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::Shape> x_shape,
                      builder->GetShape(x));
  TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::Shape> y_shape,
                      builder->GetShape(y));

  // Check that both tensors have the same number of dimensions. There must be
  // at least two (the batch dimensions can be empty).
  if (xla::ShapeUtil::Rank(*x_shape) != xla::ShapeUtil::Rank(*y_shape)) {
    return errors::InvalidArgument(
        "Arguments to BatchedDot have different ranks: ",
        xla::ShapeUtil::HumanString(*x_shape), " vs. ",
        xla::ShapeUtil::HumanString(*y_shape));
  }
  const int ndims = xla::ShapeUtil::Rank(*x_shape);
  if (ndims < 2) {
    return errors::InvalidArgument(
        "Arguments to BatchedDot must have rank >= 2: ", ndims);
  }

  // The batch dimensions must be equal and the matrix dimensions must be
  // valid.
  std::vector<int64> batch_dimension_numbers;
  for (int i = 0; i < ndims - 2; ++i) {
    if (x_shape->dimensions(i) != y_shape->dimensions(i)) {
      return errors::InvalidArgument(
          "Dimension ", i, " of inputs to BatchedDot must be equal: ",
          xla::ShapeUtil::HumanString(*x_shape), " vs ",
          xla::ShapeUtil::HumanString(*y_shape));
    }
    batch_dimension_numbers.push_back(i);
  }

  int x_inner_dim = transpose_x ? (ndims - 2) : (ndims - 1);
  int y_inner_dim = transpose_y ? (ndims - 1) : (ndims - 2);
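  // For example (hypothetical shapes), with ndims == 3 and no transposes,
  // x: [B, M, K] and y: [B, K, N] give x_inner_dim == 2 (the K of x), which
  // must match y_inner_dim == 1 (the K of y).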
  if (x_shape->dimensions(x_inner_dim) != y_shape->dimensions(y_inner_dim)) {
    return errors::InvalidArgument(
        "Dimensions ", x_inner_dim, " and ", y_inner_dim,
        " of arguments to BatchedDot must be equal: ",
        xla::ShapeUtil::HumanString(*x_shape), " transpose: ", transpose_x,
        " vs. ", xla::ShapeUtil::HumanString(*y_shape),
        " transpose: ", transpose_y);
  }

  // If either operand has no elements, the result is either empty or all
  // zeros (a contraction over an empty dimension sums nothing), so a
  // broadcast of zero into the output shape is the correct result and we
  // can skip emitting a Dot entirely.
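  // For example (hypothetical shapes): x: f32[2,3,0] times y: f32[2,0,4]
  // contracts an empty dimension and yields an all-zeros f32[2,3,4].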
  if (xla::ShapeUtil::HasZeroElements(*x_shape) ||
      xla::ShapeUtil::HasZeroElements(*y_shape)) {
    std::vector<int64> dimensions(batch_dimension_numbers.size());
    for (size_t i = 0; i < batch_dimension_numbers.size(); ++i) {
      dimensions[i] = x_shape->dimensions(batch_dimension_numbers[i]);
    }
    int x_outer_dim = transpose_x ? (ndims - 1) : (ndims - 2);
    int y_outer_dim = transpose_y ? (ndims - 2) : (ndims - 1);
    dimensions.push_back(x_shape->dimensions(x_outer_dim));
    dimensions.push_back(y_shape->dimensions(y_outer_dim));
    return builder->Broadcast(
        builder->ConstantLiteral(xla::Literal::Zero(x_shape->element_type())),
        dimensions);
  }
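
  // Conjugation is only meaningful for complex element types; for real
  // inputs the conjugate_x/conjugate_y flags are no-ops.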
  if (x_shape->element_type() == xla::C64 && conjugate_x) {
    x = builder->Conj(x);
  }
  if (y_shape->element_type() == xla::C64 && conjugate_y) {
    y = builder->Conj(y);
  }

  // If there are no batch dimensions, use a regular Dot.
  // TODO(b/69062148) Remove this code when Dot emitters can be passed
  // dimensions to transpose directly (i.e. without requiring a Transpose HLO).
  if (batch_dimension_numbers.empty()) {
    auto lhs = transpose_x ? builder->Transpose(x, {1, 0}) : x;
    auto rhs = transpose_y ? builder->Transpose(y, {1, 0}) : y;
    return builder->Dot(lhs, rhs);
  }
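
  // Express the batched matmul as a single DotGeneral: the inner dimensions
  // contract and the leading dimensions are batched. For example
  // (hypothetical shapes), x: [B, M, K] and y: [B, K, N] with no transposes
  // use lhs_contracting = {2}, rhs_contracting = {1}, and batch dimensions
  // {0} on both sides, producing a [B, M, N] result.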
  xla::DotDimensionNumbers dot_dnums;
  dot_dnums.add_lhs_contracting_dimensions(x_inner_dim);
  dot_dnums.add_rhs_contracting_dimensions(y_inner_dim);
  for (auto batch_dimension_number : batch_dimension_numbers) {
    dot_dnums.add_lhs_batch_dimensions(batch_dimension_number);
    dot_dnums.add_rhs_batch_dimensions(batch_dimension_number);
  }
  return builder->DotGeneral(x, y, dot_dnums);
}

}  // namespace tensorflow