[XLA] Implement GeneralDot semantics in HloEvaluator.

Also:
- add a general matmul test and enable the interpreter to run dot_operation_test.
- remove now-redundant/obsolete CHECKs for HandleDot in HloEvaluator.
- improve the documentation for DotGeneral a bit.
PiperOrigin-RevId: 185211512
Author: Kay Zhu
Date: 2018-02-09 16:39:48 -08:00
Committed by: TensorFlower Gardener
Parent: 26114e0dc1
Commit: 7a5d775685

4 changed files with 138 additions and 32 deletions

tensorflow/compiler/xla/service/hlo_evaluator.cc

@@ -1028,55 +1028,119 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
     CHECK(ShapeUtil::IsArray(lhs->shape()));
     CHECK(ShapeUtil::IsArray(rhs->shape()));

-    // Dot only supports operands of rank 1 and 2.
-    const auto dot_rank = ShapeUtil::Rank(dot->shape());
+    const auto& dnums = dot->dot_dimension_numbers();
+
     const auto lhs_rank = ShapeUtil::Rank(lhs->shape());
     const auto rhs_rank = ShapeUtil::Rank(rhs->shape());
-    CHECK(lhs_rank > 0 && lhs_rank <= 2);
-    CHECK(rhs_rank > 0 && rhs_rank <= 2);
-    CHECK_EQ(dot_rank, lhs_rank + rhs_rank - 2);
+
     CHECK(ShapeUtil::SameElementType(lhs->shape(), rhs->shape()));
     CHECK(ShapeUtil::SameElementType(lhs->shape(), dot->shape()));

-    // Check contracted dimensions are the same.
-    //
-    // Determine the index of the contracted dimensions for input tensors.
-    // dimensions -1 of lhs and dimension 0 of rhs are contracted.
-    const int64 lhs_contracted_dimension =
-        ShapeUtil::GetDimensionNumber(lhs->shape(), -1);
-    const int64 rhs_contracted_dimension = 0;
-    CHECK_EQ(lhs->shape().dimensions(lhs_contracted_dimension),
-             rhs->shape().dimensions(rhs_contracted_dimension))
+    // There must be 1 and only 1 contracting dimension for lhs and rhs.
+    CHECK_EQ(dnums.lhs_contracting_dimensions_size(), 1);
+    CHECK_EQ(dnums.rhs_contracting_dimensions_size(), 1);
+    const int64 lhs_contracting_dimension = dnums.lhs_contracting_dimensions(0);
+    const int64 rhs_contracting_dimension = dnums.rhs_contracting_dimensions(0);
+    // Contracted dimension sizes must be the same.
+    CHECK_EQ(lhs->shape().dimensions(lhs_contracting_dimension),
+             rhs->shape().dimensions(rhs_contracting_dimension))
         << "lhs contracted dimension: "
-        << lhs->shape().dimensions(lhs_contracted_dimension)
+        << lhs->shape().dimensions(lhs_contracting_dimension)
         << " rhs contracted dimension: "
-        << rhs->shape().dimensions(rhs_contracted_dimension);
+        << rhs->shape().dimensions(rhs_contracting_dimension);
+
     const int64 contracted_dimension_size =
-        lhs->shape().dimensions(lhs_contracted_dimension);
+        lhs->shape().dimensions(lhs_contracting_dimension);

     const Literal& lhs_literal = parent_->GetEvaluatedLiteralFor(lhs);
     const Literal& rhs_literal = parent_->GetEvaluatedLiteralFor(rhs);

     auto result = Literal::CreateFromShape(dot->shape());

+    CHECK_EQ(dnums.lhs_batch_dimensions_size(),
+             dnums.rhs_batch_dimensions_size());
+
+    std::vector<int64> lhs_non_contracting_dims;
+    for (int64 i = 0; i < lhs_rank; i++) {
+      if (i != lhs_contracting_dimension) {
+        lhs_non_contracting_dims.push_back(i);
+      }
+    }
+
+    std::vector<int64> rhs_non_batch_non_contracting_dims;
+    tensorflow::gtl::FlatSet<int64> batch_dims_set(
+        dnums.rhs_batch_dimensions().begin(),
+        dnums.rhs_batch_dimensions().end());
+    for (int64 i = 0; i < rhs_rank; i++) {
+      if (i != rhs_contracting_dimension && batch_dims_set.count(i) == 0) {
+        rhs_non_batch_non_contracting_dims.push_back(i);
+      }
+    }
+
+    const int64 batch_dim_size = dnums.lhs_batch_dimensions_size();
+    const int64 lhs_non_contracting_size = lhs_non_contracting_dims.size();
+
+    DimensionVector lhs_index(lhs_rank);
+    DimensionVector rhs_index(rhs_rank);
     TF_RETURN_IF_ERROR(result->Populate<ReturnT>(
-        [&](tensorflow::gtl::ArraySlice<int64> multi_index) {
+        [&](tensorflow::gtl::ArraySlice<int64> result_index) {
          ElementwiseT result_val = static_cast<ElementwiseT>(0);

-          std::vector<int64> lhs_index(lhs_rank, 0);
-          std::vector<int64> rhs_index(rhs_rank, 0);
-
-          // Set index for non-contracted dimension for lhs and rhs.
-          if (lhs_rank > 1) {
-            lhs_index[0] = multi_index[0];
+          // Find the corresponding non-contracting indices for lhs and rhs.
+          //
+          // For `result_index`, its batch dimension, if it exists, will be at
+          // the same dimension as the batch dimension of lhs and rhs. More
+          // specifically:
+          // - For lhs, the non-contracting dimensions, including the batch
+          // dimension, have the same index as `result_index`.
+          // - For rhs, the batch dimension is set separately from the other
+          // non-contracting dimensions, since these other non-contracting
+          // dimensions in rhs follow the non-contracting dimensions of lhs in
+          // the resulting index.
+          //
+          // As an example, for a resulting index:
+          //   result_index [result_batch, result_x, result_y]
+          // the effective lhs and rhs indices are:
+          //   lhs [result_batch, lhs_non_contracting_dim, contracting_dim]
+          //   rhs [result_batch, contracting_dim, rhs_non_contracting_dim]
+          // `result_x` is only affected by the lhs_non_contracting_dim and
+          // likewise `result_y` only depends on the rhs_non_contracting_dim.
+          //
+          // So we can look up the lhs and rhs indices by:
+          //
+          // lhs:
+          //   batch index is the same as `result_batch`.
+          //   non-contracting dimension is the same as
+          //   result_index[lhs_non_contracting_dim].
+          // rhs:
+          //   batch index: the same as `result_batch`.
+          //   non-contracting dimension index: *not* the same as
+          //   result_index[rhs_non_contracting_dim], since the
+          //   non-contracting dimensions of lhs are included in the
+          //   result_index first. Instead, the non_contracting_dim of rhs must
+          //   be calculated as follows:
+          //     lhs_non_contracting_dimensions_size +
+          //     (rhs_non_batch_non_contracting_dim - batch_dim_size) - 1
+          //
+          // Note that (rhs_non_batch_non_contracting_dim - batch_dim_size) is
+          // the index offset into the result_index that only depends on
+          // the non-batch and non-contracting dimensions of rhs. The -1 at the
+          // end translates a size into an index.
+          for (auto i : lhs_non_contracting_dims) {
+            lhs_index[i] = result_index[i];
           }
-          if (rhs_rank > 1) {
-            rhs_index[1] = multi_index[multi_index.size() - 1];
+          for (auto i : dnums.rhs_batch_dimensions()) {
+            rhs_index[i] = result_index[i];
           }
+          for (auto i : rhs_non_batch_non_contracting_dims) {
+            const int64 rhs_non_batch_non_contracting_dim =
+                lhs_non_contracting_size + (i - batch_dim_size) - 1;
+            rhs_index[i] = result_index[rhs_non_batch_non_contracting_dim];
+          }

           // Accumulates resulting product along the contracted dimension.
           for (int64 i = 0; i < contracted_dimension_size; ++i) {
-            lhs_index[lhs_contracted_dimension] = i;
-            rhs_index[rhs_contracted_dimension] = i;
+            lhs_index[lhs_contracting_dimension] = i;
+            rhs_index[rhs_contracting_dimension] = i;

             result_val +=
                 static_cast<ElementwiseT>(lhs_literal.Get<ReturnT>(lhs_index)) *
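To make the index arithmetic in the comment above concrete, here is a minimal standalone sketch (plain C++, not XLA code; the shapes and names B/M/K/N are illustrative) of a general dot with one batch dimension, lhs [2, 3, 4] `dot` rhs [2, 4, 5] -> result [2, 3, 5]. For rhs's only non-batch non-contracting dimension i = 2, the offset works out to lhs_non_contracting_size + (i - batch_dim_size) - 1 = 2 + (2 - 1) - 1 = 2, i.e. rhs_index[2] = result_index[2]:

```c++
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  // Shapes: lhs [B, M, K] dot rhs [B, K, N] -> result [B, M, N].
  const int64_t B = 2, M = 3, K = 4, N = 5;
  std::vector<float> lhs(B * M * K), rhs(B * K * N), result(B * M * N, 0.f);
  for (size_t i = 0; i < lhs.size(); ++i) lhs[i] = static_cast<float>(i);
  for (size_t i = 0; i < rhs.size(); ++i) rhs[i] = static_cast<float>(i % 7);

  // lhs contracting dim = 2, rhs contracting dim = 1, batch dims = {0}.
  // lhs_non_contracting_dims = {0, 1}: lhs_index[i] = result_index[i].
  // rhs batch dim 0:                   rhs_index[0] = result_index[0].
  // rhs non-batch non-contracting dim 2 reads result dim 2 (see above).
  for (int64_t b = 0; b < B; ++b) {
    for (int64_t m = 0; m < M; ++m) {
      for (int64_t n = 0; n < N; ++n) {
        float acc = 0.f;
        // Accumulate along the contracted dimension, as the evaluator does.
        for (int64_t k = 0; k < K; ++k) {
          acc += lhs[(b * M + m) * K + k] * rhs[(b * K + k) * N + n];
        }
        result[(b * M + m) * N + n] = acc;
      }
    }
  }
  std::cout << "result[1][2][3] = " << result[(1 * M + 2) * N + 3] << "\n";
  return 0;
}
```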

tensorflow/compiler/xla/tests/BUILD

@@ -601,6 +601,9 @@ xla_test(
 xla_test(
     name = "dot_operation_test",
     srcs = ["dot_operation_test.cc"],
+    tags = [
+        "enable_for_xla_interpreter",
+    ],
     deps = [
         "//tensorflow/compiler/xla:array2d",
         "//tensorflow/compiler/xla:array3d",
@@ -623,6 +626,9 @@ xla_test(
 xla_test(
     name = "dot_operation_runtime_test",
     srcs = ["dot_operation_test.cc"],
+    tags = [
+        "enable_for_xla_interpreter",
+    ],
     deps = [
         "//tensorflow/compiler/xla:array2d",
         "//tensorflow/compiler/xla:array3d",

tensorflow/compiler/xla/tests/dot_operation_test.cc

@@ -520,9 +520,39 @@ XLA_TEST_F(DotOperationTest, BatchMatMul) {
   ComputeAndCompareR4<float>(
       &builder,
-      /*expected=*/{{{{1300, 2400}, {13, 24}}, {{11400, 13600}, {114, 136}}},
-                    {{{42900, 79200}, {429, 792}},
-                     {{250800, 299200}, {2508, 2992}}}},
+      /*expected=*/
+      {{{{1300, 2400}, {13, 24}}, {{11400, 13600}, {114, 136}}},
+       {{{42900, 79200}, {429, 792}}, {{250800, 299200}, {2508, 2992}}}},
       {x_data.get(), y_data.get()}, error_spec_);
 }

+XLA_TEST_F(DotOperationTest, GeneralMatMul) {
+  ComputationBuilder builder(client_, TestName());
+
+  auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {2, 2, 2}), "x");
+  auto y = builder.Parameter(1, ShapeUtil::MakeShape(F32, {2, 2, 2}), "y");
+
+  DotDimensionNumbers dnums;
+  dnums.add_lhs_contracting_dimensions(2);
+  dnums.add_rhs_contracting_dimensions(1);
+  dnums.add_lhs_batch_dimensions(0);
+  dnums.add_rhs_batch_dimensions(0);
+
+  auto out = builder.DotGeneral(x, y, dnums);
+
+  auto x_data = client_
+                    ->TransferToServer(*Literal::CreateR3<float>(
+                        {{{1.0, 2.0}, {3.0, 4.0}}, {{5.0, 6.0}, {7.0, 8.0}}}))
+                    .ConsumeValueOrDie();
+
+  auto y_data = client_
+                    ->TransferToServer(*Literal::CreateR3<float>(
+                        {{{1.0, 0.0}, {0.0, 1.0}}, {{1.0, 0.0}, {0.0, 1.0}}}))
+                    .ConsumeValueOrDie();
+
+  ComputeAndCompareR3<float>(
+      &builder,
+      /*expected=*/
+      {{{1.0, 2.0}, {3.0, 4.0}}, {{5.0, 6.0}, {7.0, 8.0}}},
+      {x_data.get(), y_data.get()}, error_spec_);
+}
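Each 2x2 batch slice of `y` in this new test is the identity matrix, so the batched product of `x` and `y` reproduces `x`; that is why the expected R3 literal is identical to the `x` data.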

tensorflow/docs_src/performance/xla/operation_semantics.md

@@ -717,6 +717,7 @@ in 'dimension_numbers'.
 Associated contracting dimension numbers from the 'lhs' and 'rhs' do not need
 to be the same, but must be listed in the same order in both
 'lhs/rhs_contracting_dimensions' arrays and have the same dimension sizes.
+There must be exactly one contracting dimension on both 'lhs' and 'rhs'.

 Example with contracting dimension numbers:
@@ -736,8 +737,9 @@ DotGeneral(lhs, rhs, dnums) -> { {6.0, 12.0},
 ```

 Associated batch dimension numbers from the 'lhs' and 'rhs' must have the same
-dimension number, must be listed in the same order in both arrays, and must
-have the same dimension sizes.
+dimension number, must be listed in the same order in both arrays, must
+have the same dimension sizes, and must be ordered before contracting and
+non-contracting/non-batch dimension numbers.

 Example with batch dimension numbers (batch size 2, 2x2 matrices):
@@ -769,6 +771,10 @@ DotGeneral(lhs, rhs, dnums) -> { { {1.0, 2.0},
 | [b0, m, k] `dot` [b0, k, n]         | [b0, m, n]     | batch matmul |
 | [b0, b1, m, k] `dot` [b0, b1, k, n] | [b0, b1, m, n] | batch matmul |

+It follows that the resulting dimension number starts with the batch dimension,
+then the 'lhs' non-contracting/non-batch dimension, and finally the 'rhs'
+non-contracting/non-batch dimension.
+
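To connect the documentation back to the client API, a plain non-batch matmul [m, k] `dot` [k, n] -> [m, n] could be expressed with dimension numbers as in the fragment below (a sketch mirroring the GeneralMatMul test above, not part of this commit; `builder`, `lhs`, and `rhs` are assumed to be set up as in that test):

```c++
// Sketch only: assumes a ComputationBuilder `builder` with rank-2 parameters
// `lhs` ([m, k]) and `rhs` ([k, n]) declared as in the GeneralMatMul test.
DotDimensionNumbers dnums;
dnums.add_lhs_contracting_dimensions(1);  // contract lhs dimension 1 (k)
dnums.add_rhs_contracting_dimensions(0);  // with rhs dimension 0 (k)
// No batch dimensions: per the text above, the result dimensions are the
// lhs non-contracting dimension (m) followed by the rhs one (n).
auto out = builder.DotGeneral(lhs, rhs, dnums);
```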
## DynamicSlice
See also