[XLA] Implement GeneralDot semantics in HloEvaluator.

Also:
- add a general matmul test, enable interpreter to run dot_operation_test.
- remove now-redundant/obsolete CHECKs for HandleDot in HloEvaluator.
- improve documentation for DotGeneral a bit.

PiperOrigin-RevId: 185211512
parent 26114e0dc1
commit 7a5d775685
@@ -1028,55 +1028,119 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
     CHECK(ShapeUtil::IsArray(lhs->shape()));
     CHECK(ShapeUtil::IsArray(rhs->shape()));

-    // Dot only supports operands of rank 1 and 2.
     const auto dot_rank = ShapeUtil::Rank(dot->shape());
+    const auto& dnums = dot->dot_dimension_numbers();

     const auto lhs_rank = ShapeUtil::Rank(lhs->shape());
     const auto rhs_rank = ShapeUtil::Rank(rhs->shape());
-    CHECK(lhs_rank > 0 && lhs_rank <= 2);
-    CHECK(rhs_rank > 0 && rhs_rank <= 2);
-    CHECK_EQ(dot_rank, lhs_rank + rhs_rank - 2);

     CHECK(ShapeUtil::SameElementType(lhs->shape(), rhs->shape()));
     CHECK(ShapeUtil::SameElementType(lhs->shape(), dot->shape()));

-    // Check contracted dimensions are the same.
-    //
-    // Determine the index of the contracted dimensions for input tensors.
-    // dimensions -1 of lhs and dimension 0 of rhs are contracted.
-    const int64 lhs_contracted_dimension =
-        ShapeUtil::GetDimensionNumber(lhs->shape(), -1);
-    const int64 rhs_contracted_dimension = 0;
-    CHECK_EQ(lhs->shape().dimensions(lhs_contracted_dimension),
-             rhs->shape().dimensions(rhs_contracted_dimension))
+    // There must be 1 and only 1 contracting dimension for lhs and rhs.
+    CHECK_EQ(dnums.lhs_contracting_dimensions_size(), 1);
+    CHECK_EQ(dnums.rhs_contracting_dimensions_size(), 1);
+    const int64 lhs_contracting_dimension = dnums.lhs_contracting_dimensions(0);
+    const int64 rhs_contracting_dimension = dnums.rhs_contracting_dimensions(0);
+    // Contracted dimension sizes must be the same.
+    CHECK_EQ(lhs->shape().dimensions(lhs_contracting_dimension),
+             rhs->shape().dimensions(rhs_contracting_dimension))
         << "lhs contracted dimension: "
-        << lhs->shape().dimensions(lhs_contracted_dimension)
+        << lhs->shape().dimensions(lhs_contracting_dimension)
         << " rhs contracted dimension: "
-        << rhs->shape().dimensions(rhs_contracted_dimension);
+        << rhs->shape().dimensions(rhs_contracting_dimension);
     const int64 contracted_dimension_size =
-        lhs->shape().dimensions(lhs_contracted_dimension);
+        lhs->shape().dimensions(lhs_contracting_dimension);

     const Literal& lhs_literal = parent_->GetEvaluatedLiteralFor(lhs);
     const Literal& rhs_literal = parent_->GetEvaluatedLiteralFor(rhs);

     auto result = Literal::CreateFromShape(dot->shape());

+    CHECK_EQ(dnums.lhs_batch_dimensions_size(),
+             dnums.rhs_batch_dimensions_size());
+
+    std::vector<int64> lhs_non_contracting_dims;
+    for (int64 i = 0; i < lhs_rank; i++) {
+      if (i != lhs_contracting_dimension) {
+        lhs_non_contracting_dims.push_back(i);
+      }
+    }
+
+    std::vector<int64> rhs_non_batch_non_contracting_dims;
+    tensorflow::gtl::FlatSet<int64> batch_dims_set(
+        dnums.rhs_batch_dimensions().begin(),
+        dnums.rhs_batch_dimensions().end());
+    for (int64 i = 0; i < rhs_rank; i++) {
+      if (i != rhs_contracting_dimension && batch_dims_set.count(i) == 0) {
+        rhs_non_batch_non_contracting_dims.push_back(i);
+      }
+    }
+
+    const int64 batch_dim_size = dnums.lhs_batch_dimensions_size();
+    const int64 lhs_non_contracting_size = lhs_non_contracting_dims.size();
+
+    DimensionVector lhs_index(lhs_rank);
+    DimensionVector rhs_index(rhs_rank);
     TF_RETURN_IF_ERROR(result->Populate<ReturnT>(
-        [&](tensorflow::gtl::ArraySlice<int64> multi_index) {
+        [&](tensorflow::gtl::ArraySlice<int64> result_index) {
           ElementwiseT result_val = static_cast<ElementwiseT>(0);

-          std::vector<int64> lhs_index(lhs_rank, 0);
-          std::vector<int64> rhs_index(rhs_rank, 0);
-          // Set index for non-contracted dimension for lhs and rhs.
-          if (lhs_rank > 1) {
-            lhs_index[0] = multi_index[0];
-          }
-          if (rhs_rank > 1) {
-            rhs_index[1] = multi_index[multi_index.size() - 1];
-          }
+          // Find the corresponding non-contracting indices for lhs and rhs.
+          //
+          // For `result_index`, its batch dimension, if it exists, will be at
+          // the same dimension as the batch dimension of lhs and rhs. More
+          // specifically:
+          // - For lhs, the non-contracting dimensions, including the batch
+          //   dimension, have the same index as `result_index`.
+          // - For rhs, the batch dimension is set separately from the other
+          //   non-contracting dimensions, since these other non-contracting
+          //   dimensions in rhs follow the non-contracting dimensions of lhs
+          //   in the resulting index.
+          //
+          // As an example, for a resulting index:
+          //   result_index [result_batch, result_x, result_y]
+          // the effecting lhs and rhs indices are:
+          //   lhs [result_batch, lhs_non_contracting_dim, contracting_dim]
+          //   rhs [result_batch, contracting_dim, rhs_non_contracting_dim]
+          // `result_x` is only affected by the lhs_non_contracting_dim and
+          // likewise `result_y` only depends on rhs_non_contracting_dim.
+          //
+          // So we can look up the lhs and rhs indices by:
+          //
+          // lhs:
+          //   batch index is the same as `result_batch`.
+          //   non-contracting dimension is the same as
+          //   result_index[lhs_non_contracting_dim].
+          // rhs:
+          //   batch index: the same as `result_batch`.
+          //   non-contracting dimension index: *not* the same as
+          //   result_index[rhs_non_contracting_dim], since the
+          //   non-contracting dimensions of lhs are included in the
+          //   result_index first. Instead, the non_contracting_dim of rhs
+          //   must be calculated as follows:
+          //     lhs_non_contracting_dimensions_size +
+          //     (rhs_non_batch_non_contracting_dim - batch_dim_size) - 1
+          //
+          //   Note that (rhs_non_batch_non_contracting_dim - batch_dim_size)
+          //   is the index offset into result_index that only depends on the
+          //   non-batch and non-contracting dimensions of rhs. The -1 at the
+          //   end translates size to index.
+          for (auto i : lhs_non_contracting_dims) {
+            lhs_index[i] = result_index[i];
+          }
+          for (auto i : dnums.rhs_batch_dimensions()) {
+            rhs_index[i] = result_index[i];
+          }
+          for (auto i : rhs_non_batch_non_contracting_dims) {
+            const int64 rhs_non_batch_non_contracting_dim =
+                lhs_non_contracting_size + (i - batch_dim_size) - 1;
+            rhs_index[i] = result_index[rhs_non_batch_non_contracting_dim];
+          }

           // Accumulates resulting product along the contracted dimension.
           for (int64 i = 0; i < contracted_dimension_size; ++i) {
-            lhs_index[lhs_contracted_dimension] = i;
-            rhs_index[rhs_contracted_dimension] = i;
+            lhs_index[lhs_contracting_dimension] = i;
+            rhs_index[rhs_contracting_dimension] = i;

             result_val +=
                 static_cast<ElementwiseT>(lhs_literal.Get<ReturnT>(lhs_index)) *
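To make the index bookkeeping above concrete, here is a minimal standalone sketch of the same per-output-element accumulation, specialized to the `[b, m, k]` `dot` `[b, k, n]` case. It is plain C++ over flat buffers with illustrative names (`GeneralDot`, `B/M/K/N`), not the evaluator's actual `Literal`/`DimensionVector` machinery:

```c++
// A sketch of per-output-element general dot, assuming row-major flat
// buffers and a single batch dimension. Names here are illustrative.
#include <cstdint>
#include <iostream>
#include <vector>

// result[b, m, n] = sum_k lhs[b, m, k] * rhs[b, k, n]
std::vector<float> GeneralDot(const std::vector<float>& lhs,
                              const std::vector<float>& rhs, int64_t B,
                              int64_t M, int64_t K, int64_t N) {
  std::vector<float> result(B * M * N, 0.0f);
  for (int64_t b = 0; b < B; ++b) {      // batch dim: shared by lhs/rhs/result
    for (int64_t m = 0; m < M; ++m) {    // lhs non-contracting dim
      for (int64_t n = 0; n < N; ++n) {  // rhs non-contracting dim
        float acc = 0.0f;
        // Accumulate along the contracting dimension, as the evaluator's
        // Populate lambda does for each result_index.
        for (int64_t k = 0; k < K; ++k) {
          acc += lhs[(b * M + m) * K + k] * rhs[(b * K + k) * N + n];
        }
        result[(b * M + m) * N + n] = acc;
      }
    }
  }
  return result;
}

int main() {
  // Same inputs as the GeneralMatMul test below: x is two 2x2 matrices and
  // y is two 2x2 identity matrices, so the result equals x.
  std::vector<float> x = {1, 2, 3, 4, 5, 6, 7, 8};
  std::vector<float> y = {1, 0, 0, 1, 1, 0, 0, 1};
  for (float v : GeneralDot(x, y, /*B=*/2, /*M=*/2, /*K=*/2, /*N=*/2)) {
    std::cout << v << " ";  // prints: 1 2 3 4 5 6 7 8
  }
  std::cout << "\n";
}
```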
@@ -601,6 +601,9 @@ xla_test(
 xla_test(
     name = "dot_operation_test",
     srcs = ["dot_operation_test.cc"],
+    tags = [
+        "enable_for_xla_interpreter",
+    ],
     deps = [
         "//tensorflow/compiler/xla:array2d",
         "//tensorflow/compiler/xla:array3d",

@@ -623,6 +626,9 @@ xla_test(
 xla_test(
     name = "dot_operation_runtime_test",
     srcs = ["dot_operation_test.cc"],
+    tags = [
+        "enable_for_xla_interpreter",
+    ],
     deps = [
         "//tensorflow/compiler/xla:array2d",
         "//tensorflow/compiler/xla:array3d",
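The `enable_for_xla_interpreter` tag added to both targets is, per the commit description, what opts `dot_operation_test` in to running on the interpreter backend, which the HloEvaluator change above makes able to pass it.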
@@ -520,9 +520,39 @@ XLA_TEST_F(DotOperationTest, BatchMatMul) {
   ComputeAndCompareR4<float>(
       &builder,
-      /*expected=*/{{{{1300, 2400}, {13, 24}}, {{11400, 13600}, {114, 136}}},
-                    {{{42900, 79200}, {429, 792}},
-                     {{250800, 299200}, {2508, 2992}}}},
+      /*expected=*/
+      {{{{1300, 2400}, {13, 24}}, {{11400, 13600}, {114, 136}}},
+       {{{42900, 79200}, {429, 792}}, {{250800, 299200}, {2508, 2992}}}},
       {x_data.get(), y_data.get()}, error_spec_);
 }

+XLA_TEST_F(DotOperationTest, GeneralMatMul) {
+  ComputationBuilder builder(client_, TestName());
+  auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {2, 2, 2}), "x");
+  auto y = builder.Parameter(1, ShapeUtil::MakeShape(F32, {2, 2, 2}), "y");
+
+  DotDimensionNumbers dnums;
+  dnums.add_lhs_contracting_dimensions(2);
+  dnums.add_rhs_contracting_dimensions(1);
+  dnums.add_lhs_batch_dimensions(0);
+  dnums.add_rhs_batch_dimensions(0);
+
+  auto out = builder.DotGeneral(x, y, dnums);
+
+  auto x_data = client_
+                    ->TransferToServer(*Literal::CreateR3<float>(
+                        {{{1.0, 2.0}, {3.0, 4.0}}, {{5.0, 6.0}, {7.0, 8.0}}}))
+                    .ConsumeValueOrDie();
+
+  auto y_data = client_
+                    ->TransferToServer(*Literal::CreateR3<float>(
+                        {{{1.0, 0.0}, {0.0, 1.0}}, {{1.0, 0.0}, {0.0, 1.0}}}))
+                    .ConsumeValueOrDie();
+
+  ComputeAndCompareR3<float>(
+      &builder,
+      /*expected=*/
+      {{{1.0, 2.0}, {3.0, 4.0}}, {{5.0, 6.0}, {7.0, 8.0}}},
+      {x_data.get(), y_data.get()}, error_spec_);
+}
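In the new `GeneralMatMul` test, `dnums` describes a batch matmul: dimension 0 of both operands is the batch dimension, and lhs dimension 2 is contracted against rhs dimension 1, i.e. `[b, m, k]` `dot` `[b, k, n]`. Since each 2x2 slice of `y` is an identity matrix, the expected output is simply `x` itself; for batch 0, [[1, 2], [3, 4]] * I = [[1, 2], [3, 4]].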
@@ -717,6 +717,7 @@ in 'dimension_numbers'.
 Associated contracting dimension numbers from the 'lhs' and 'rhs' do not need
 to be the same, but must be listed in the same order in both
 'lhs/rhs_contracting_dimensions' arrays and have the same dimension sizes.
+There must be exactly one contracting dimension on both 'lhs' and 'rhs'.

 Example with contracting dimension numbers:

@@ -736,8 +737,9 @@ DotGeneral(lhs, rhs, dnums) -> { {6.0, 12.0},
 ```

 Associated batch dimension numbers from the 'lhs' and 'rhs' must have the same
-dimension number, must be listed in the same order in both arrays, and must
-have the same dimension sizes.
+dimension number, must be listed in the same order in both arrays, must
+have the same dimension sizes, and must be ordered before contracting and
+non-contracting/non-batch dimension numbers.

 Example with batch dimension numbers (batch size 2, 2x2 matrices):

@@ -769,6 +771,10 @@ DotGeneral(lhs, rhs, dnums) -> { { {1.0, 2.0},
 | [b0, m, k] `dot` [b0, k, n] | [b0, m, n] | batch matmul |
 | [b0, b1, m, k] `dot` [b0, b1, k, n] | [b0, b1, m, n] | batch matmul |

+It follows that the resulting dimension number starts with the batch dimension,
+then the 'lhs' non-contracting/non-batch dimension, and finally the 'rhs'
+non-contracting/non-batch dimension.
+
 ## DynamicSlice

 See also
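As a concrete instance of that ordering (the `[b0, m, k]` `dot` `[b0, k, n]` row above): for lhs shape [2, 3, 4] and rhs shape [2, 4, 5], with batch dimension 0 on both sides and contracting dimensions lhs=2, rhs=1, the result has shape [2, 3, 5] -- the batch dimension (2) first, then the lhs non-contracting dimension (3), then the rhs non-contracting dimension (5).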