Make the HLO evaluator correctly handle the lhs contracting dim for Dots

This CL fixes the bug by rewriting how we map the result index to the lhs/rhs
index in the dot evaluator.

PiperOrigin-RevId: 202588171
This commit is contained in:
Sanjoy Das 2018-06-28 19:58:53 -07:00 committed by TensorFlower Gardener
parent 6f5feedd57
commit f04400f18f
2 changed files with 64 additions and 93 deletions

View File

@ -1025,83 +1025,47 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
CHECK_EQ(dnums.lhs_batch_dimensions_size(),
dnums.rhs_batch_dimensions_size());
std::vector<int64> lhs_non_contracting_dims;
for (int64 i = 0; i < lhs_rank; i++) {
if (i != lhs_contracting_dimension) {
lhs_non_contracting_dims.push_back(i);
}
}
std::vector<int64> rhs_non_batch_non_contracting_dims;
tensorflow::gtl::FlatSet<int64> batch_dims_set(
dnums.rhs_batch_dimensions().begin(),
dnums.rhs_batch_dimensions().end());
for (int64 i = 0; i < rhs_rank; i++) {
if (i != rhs_contracting_dimension && batch_dims_set.count(i) == 0) {
rhs_non_batch_non_contracting_dims.push_back(i);
}
}
const int64 batch_dim_size = dnums.lhs_batch_dimensions_size();
const int64 lhs_non_contracting_size = lhs_non_contracting_dims.size();
DimensionVector lhs_index(lhs_rank);
DimensionVector rhs_index(rhs_rank);
// result_index_locations[i] contains one or two pointers to the locations
// in lhs_index or rhs_index where the i'th result index should go.
tensorflow::gtl::InlinedVector<std::pair<int64*, int64*>, kInlineRank>
result_index_locations;
result_index_locations.reserve(lhs_rank + rhs_rank - 2);
// The first components in the output shape are the LHS and RHS batch
// dimensions:
for (int64 i = 0; i < dnums.lhs_batch_dimensions_size(); i++) {
result_index_locations.push_back(
{&lhs_index[dnums.lhs_batch_dimensions(i)],
&rhs_index[dnums.rhs_batch_dimensions(i)]});
}
// Then we have the LHS and RHS non-contracting dimensions, if any:
for (int64 i = 0; i < lhs_rank; i++) {
if (i != lhs_contracting_dimension &&
!ArrayContains(AsInt64Slice(dnums.lhs_batch_dimensions()), i)) {
result_index_locations.push_back({&lhs_index[i], nullptr});
}
}
for (int64 i = 0; i < rhs_rank; i++) {
if (i != rhs_contracting_dimension &&
!ArrayContains(AsInt64Slice(dnums.rhs_batch_dimensions()), i)) {
result_index_locations.push_back({&rhs_index[i], nullptr});
}
}
auto result = MakeUnique<Literal>(dot->shape());
TF_RETURN_IF_ERROR(result->Populate<ReturnT>(
[&](tensorflow::gtl::ArraySlice<int64> result_index) {
ElementwiseT result_val = static_cast<ElementwiseT>(0);
// Find the corresponding non-contracting indices for lhs and rhs.
//
// For `result_index`, its batch dimension, if exists, will be at the
// same dimension as the batch dimension of lhs and rhs. More
// specifically:
// - For lhs, the non-contracting dimensions, including the batch
// dimension have the same index as the `result_index`.
// - For rhs, the batch dimension is set separately from other
// non-contracting dimensions, since these other non-contracting
// dimensions in rhs follow the non-contracting dimensions of lhs in
// the resulting index.
//
// As an example, for a resulting index:
// result_index [result_batch, result_x, result_y]
// the affected lhs and rhs indices are:
// lhs [result_batch, lhs_non_contracting_dim, contracting_dim]
// rhs [result_batch, contracting_dim, rhs_non_contracting_dim]
// `result_x` is only affected by the lhs_non_contracting_dim and
// likewise `result_y` only depends on rhs_non_contracting_dim.
//
// so we can look up the lhs and rhs indices by:
//
// lhs:
// batch index is the same as `result_batch`.
// non-contracting dimension is the same as
// result_index[lhs_non_contracting_dim]
// rhs:
// batch index: the same as `result_batch`.
// non-contracting dimension index: *not* the same as
// result_index[rhs_non_contracting_dim], since the
// non-contracting dimensions of lhs are included in the
// result_index first. Instead, the non_contracting_dim of rhs must
// be calculated as follows:
// lhs_non_contracting_dimensions_size +
// (rhs_non_batch_non_contracting_dim - batch_dim_size) - 1
//
// Note that (rhs_non_batch_non_contracting_dim - batch_dim_size) is
// the index offset to the result_index that only depends on
// the non_batch and non-contracting dimensions of rhs. -1 at the
// end translates size to index.
for (auto i : lhs_non_contracting_dims) {
lhs_index[i] = result_index[i];
}
for (auto i : dnums.rhs_batch_dimensions()) {
rhs_index[i] = result_index[i];
}
for (auto i : rhs_non_batch_non_contracting_dims) {
const int64 rhs_non_batch_non_contracting_dim =
lhs_non_contracting_size + (i - batch_dim_size) - 1;
rhs_index[i] = result_index[rhs_non_batch_non_contracting_dim];
for (int64 i = 0; i < result_index.size(); i++) {
*result_index_locations[i].first = result_index[i];
if (result_index_locations[i].second) {
*result_index_locations[i].second = result_index[i];
}
}
// Accumulates resulting product along the contracted dimension.

View File

@ -853,10 +853,9 @@ XLA_TEST_F(DotOperationTest, DotOfGatherOptimizationWithConstLHSClassicMM) {
ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_);
}
// TODO (b/69062148) Enable when Dot implements general contracting dimensions.
XLA_TEST_F(DotOperationTest,
DISABLED_ON_CPU(DISABLED_ON_GPU(DISABLED_ON_INTERPRETER(
DotOfGatherOptimizationWithConstRHSReverseMM)))) {
DotOfGatherOptimizationWithConstRHSReverseMM) {
std::unique_ptr<Array2D<float>> constant_lhs_array(
new Array2D<float>({{1.0, 2.0, 3.0},
{4.0, 5.0, 6.0},
@ -883,10 +882,7 @@ XLA_TEST_F(DotOperationTest,
ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_);
}
// TODO (b/69062148) Enable when Dot implements general contracting dimensions.
XLA_TEST_F(DotOperationTest,
DISABLED_ON_CPU(DISABLED_ON_GPU(DISABLED_ON_INTERPRETER(
DotOfGatherOptimizationWithConstLHSReverseMM)))) {
XLA_TEST_F(DotOperationTest, DotOfGatherOptimizationWithConstLHSReverseMM) {
std::unique_ptr<Array2D<float>> constant_lhs_array(
new Array2D<float>({{1.0, 2.0, 3.0},
{4.0, 5.0, 6.0},
@ -913,10 +909,7 @@ XLA_TEST_F(DotOperationTest,
ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_);
}
// TODO (b/69062148) Enable when Dot implements general contracting dimensions.
XLA_TEST_F(DotOperationTest,
DISABLED_ON_CPU(DISABLED_ON_GPU(DISABLED_ON_INTERPRETER(
DotOfGatherOptimizationWithConstRHSRows)))) {
XLA_TEST_F(DotOperationTest, DotOfGatherOptimizationWithConstRHSRows) {
std::unique_ptr<Array2D<float>> constant_lhs_array(
new Array2D<float>({{1.0, 2.0},
{3.0, 4.0},
@ -948,10 +941,7 @@ XLA_TEST_F(DotOperationTest,
ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_);
}
// TODO (b/69062148) Enable when Dot implements general contracting dimensions.
XLA_TEST_F(DotOperationTest,
DISABLED_ON_CPU(DISABLED_ON_GPU(DISABLED_ON_INTERPRETER(
DotOfGatherOptimizationWithConstLHSRows)))) {
XLA_TEST_F(DotOperationTest, DotOfGatherOptimizationWithConstLHSRows) {
std::unique_ptr<Array2D<float>> constant_lhs_array(
new Array2D<float>({{1.0, 2.0},
{3.0, 4.0},
@ -983,10 +973,7 @@ XLA_TEST_F(DotOperationTest,
ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_);
}
// TODO (b/69062148) Enable when Dot implements general contracting dimensions.
XLA_TEST_F(DotOperationTest,
DISABLED_ON_CPU(DISABLED_ON_GPU(DISABLED_ON_INTERPRETER(
DotOfGatherOptimizationWithConstRHSCols)))) {
XLA_TEST_F(DotOperationTest, DotOfGatherOptimizationWithConstRHSCols) {
std::unique_ptr<Array2D<float>> constant_lhs_array(new Array2D<float>(
{{1.0, 2.0, 3.0, 4.0, 5.0, 6.0}, {6.0, 5.0, 4.0, 3.0, 2.0, 1.0}}));
std::unique_ptr<Array2D<float>> constant_rhs_array(
@ -1010,10 +997,7 @@ XLA_TEST_F(DotOperationTest,
ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_);
}
// TODO (b/69062148) Enable when Dot implements general contracting dimensions.
XLA_TEST_F(DotOperationTest,
DISABLED_ON_CPU(DISABLED_ON_GPU(DISABLED_ON_INTERPRETER(
DotOfGatherOptimizationWithConstLHSCols)))) {
XLA_TEST_F(DotOperationTest, DotOfGatherOptimizationWithConstLHSCols) {
std::unique_ptr<Array2D<float>> constant_lhs_array(new Array2D<float>(
{{1.0, 2.0, 3.0, 4.0, 5.0, 6.0}, {6.0, 5.0, 4.0, 3.0, 2.0, 1.0}}));
std::unique_ptr<Array2D<float>> constant_rhs_array(
@ -1036,5 +1020,28 @@ XLA_TEST_F(DotOperationTest,
Array2D<float> expected({{168.0}, {168.0}});
ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_);
}
// Verifies DotGeneral on two rank-2 operands when dimension 0 of both the lhs
// and the rhs is the contracting dimension, i.e.
//   result[i][j] = sum over k of lhs[k][i] * rhs[k][j].
XLA_TEST_F(DotOperationTest, DotRank2AndRank2NonDefaultContractionDims) {
  XlaBuilder builder(TestName());

  Array2D<float> lhs_array({{1.0f, 2.0f}, {3.0f, 4.0f}});
  auto lhs_constant = ConstantR2FromArray2D(&builder, lhs_array);

  Array2D<float> rhs_array({{5.0f, 6.0f}, {7.0f, 8.0f}});
  auto rhs_constant = ConstantR2FromArray2D(&builder, rhs_array);

  // Contract dimension 0 of the lhs against dimension 0 of the rhs.
  DotDimensionNumbers dot_dnums;
  dot_dnums.add_lhs_contracting_dimensions(0);
  dot_dnums.add_rhs_contracting_dimensions(0);
  DotGeneral(lhs_constant, rhs_constant, dot_dnums);

  // For example, expected[0][0] = 1*5 + 3*7 = 26 and
  // expected[1][1] = 2*6 + 4*8 = 44.
  Array2D<float> expected({
      {26.f, 30.f},
      {38.f, 44.f},
  });

  ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_);
}
} // namespace
} // namespace xla