Make the HLO evaluator correctly handle the lhs contracting dim for Dots
This CL fixes the bug by rewriting how we map the result index to the lhs/rhs index in the dot evaluator. PiperOrigin-RevId: 202588171
This commit is contained in:
parent
6f5feedd57
commit
f04400f18f
@ -1025,83 +1025,47 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
|
||||
CHECK_EQ(dnums.lhs_batch_dimensions_size(),
|
||||
dnums.rhs_batch_dimensions_size());
|
||||
|
||||
std::vector<int64> lhs_non_contracting_dims;
|
||||
for (int64 i = 0; i < lhs_rank; i++) {
|
||||
if (i != lhs_contracting_dimension) {
|
||||
lhs_non_contracting_dims.push_back(i);
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<int64> rhs_non_batch_non_contracting_dims;
|
||||
tensorflow::gtl::FlatSet<int64> batch_dims_set(
|
||||
dnums.rhs_batch_dimensions().begin(),
|
||||
dnums.rhs_batch_dimensions().end());
|
||||
for (int64 i = 0; i < rhs_rank; i++) {
|
||||
if (i != rhs_contracting_dimension && batch_dims_set.count(i) == 0) {
|
||||
rhs_non_batch_non_contracting_dims.push_back(i);
|
||||
}
|
||||
}
|
||||
|
||||
const int64 batch_dim_size = dnums.lhs_batch_dimensions_size();
|
||||
const int64 lhs_non_contracting_size = lhs_non_contracting_dims.size();
|
||||
|
||||
DimensionVector lhs_index(lhs_rank);
|
||||
DimensionVector rhs_index(rhs_rank);
|
||||
|
||||
// result_index_locations[i] contains one or two pointers to the locations
|
||||
// in lhs_index or rhs_index where the i'th result index should go.
|
||||
tensorflow::gtl::InlinedVector<std::pair<int64*, int64*>, kInlineRank>
|
||||
result_index_locations;
|
||||
result_index_locations.reserve(lhs_rank + rhs_rank - 2);
|
||||
|
||||
// The first components in the output shape are the LHS and RHS batch
|
||||
// dimensions:
|
||||
for (int64 i = 0; i < dnums.lhs_batch_dimensions_size(); i++) {
|
||||
result_index_locations.push_back(
|
||||
{&lhs_index[dnums.lhs_batch_dimensions(i)],
|
||||
&rhs_index[dnums.rhs_batch_dimensions(i)]});
|
||||
}
|
||||
|
||||
// Then we have the LHS and RHS non-contracting dimensions, if any:
|
||||
for (int64 i = 0; i < lhs_rank; i++) {
|
||||
if (i != lhs_contracting_dimension &&
|
||||
!ArrayContains(AsInt64Slice(dnums.lhs_batch_dimensions()), i)) {
|
||||
result_index_locations.push_back({&lhs_index[i], nullptr});
|
||||
}
|
||||
}
|
||||
for (int64 i = 0; i < rhs_rank; i++) {
|
||||
if (i != rhs_contracting_dimension &&
|
||||
!ArrayContains(AsInt64Slice(dnums.rhs_batch_dimensions()), i)) {
|
||||
result_index_locations.push_back({&rhs_index[i], nullptr});
|
||||
}
|
||||
}
|
||||
|
||||
auto result = MakeUnique<Literal>(dot->shape());
|
||||
TF_RETURN_IF_ERROR(result->Populate<ReturnT>(
|
||||
[&](tensorflow::gtl::ArraySlice<int64> result_index) {
|
||||
ElementwiseT result_val = static_cast<ElementwiseT>(0);
|
||||
|
||||
// Find the corresponding non-contracting indices for lhs and rhs.
|
||||
//
|
||||
// For `result_index`, its batch dimension, if exists, will be at the
|
||||
// same dimension as the batch dimension of lhs and rhs. More
|
||||
// specifically:
|
||||
// - For lhs, the non-contracting dimensions, including the batch
|
||||
// dimension have the same index as the `result_index`.
|
||||
// - For rhs, the batch dimension is set separately from other
|
||||
// non-contracting dimensions, since these other non-contracting
|
||||
// dimensions in rhs follow the non-contracting dimensions of lhs in
|
||||
// the resulting index.
|
||||
//
|
||||
// As an example, for a resulting index:
|
||||
// result_index [result_batch, result_x, result_y]
|
||||
// the affecting lhs and rhs indices are:
|
||||
// lhs [result_batch, lhs_non_contracting_dim, contracting_dim]
|
||||
// rhs [result_batch, contracting_dim, rhs_non_contracting_dim]
|
||||
// `result_x` is only affected by the lhs_non_contracting_dim and
|
||||
// likewise `result_y` only depends on rhs_non_contracting_dim.
|
||||
//
|
||||
// so we can look up the lhs and rhs indices by:
|
||||
//
|
||||
// lhs:
|
||||
// batch index is the same as `result_batch`.
|
||||
// non-contracting dimension is the same as
|
||||
// result_index[lhs_non_contracting_dim]
|
||||
// rhs:
|
||||
// batch index: the same as `result_batch`.
|
||||
// non-contracting dimension index: *not* the same as
|
||||
// result_index[rhs_non_contracting_dim], since the
|
||||
// non-contracting dimensions of lhs are included in the
|
||||
// result_index first. Instead, the non_contracting_dim of rhs must
|
||||
// be calculated as following:
|
||||
// lhs_non_contracting_dimensions_size +
|
||||
// (rhs_non_batch_non_contracting_dim - batch_dim_size) - 1
|
||||
//
|
||||
// Note that (rhs_non_batch_non_contracting_dim - batch_dim_size) is
|
||||
// the index offset to the result_index that only depends on
|
||||
// the non_batch and non-contracting dimensions of rhs. -1 at the
|
||||
// end translates size to index.
|
||||
for (auto i : lhs_non_contracting_dims) {
|
||||
lhs_index[i] = result_index[i];
|
||||
}
|
||||
for (auto i : dnums.rhs_batch_dimensions()) {
|
||||
rhs_index[i] = result_index[i];
|
||||
}
|
||||
for (auto i : rhs_non_batch_non_contracting_dims) {
|
||||
const int64 rhs_non_batch_non_contracting_dim =
|
||||
lhs_non_contracting_size + (i - batch_dim_size) - 1;
|
||||
rhs_index[i] = result_index[rhs_non_batch_non_contracting_dim];
|
||||
for (int64 i = 0; i < result_index.size(); i++) {
|
||||
*result_index_locations[i].first = result_index[i];
|
||||
if (result_index_locations[i].second) {
|
||||
*result_index_locations[i].second = result_index[i];
|
||||
}
|
||||
}
|
||||
|
||||
// Accumulates resulting product along the contracted dimension.
|
||||
|
@ -853,10 +853,9 @@ XLA_TEST_F(DotOperationTest, DotOfGatherOptimizationWithConstLHSClassicMM) {
|
||||
ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_);
|
||||
}
|
||||
|
||||
// TODO (b/69062148) Enable when Dot implements general contracting dimensions.
|
||||
XLA_TEST_F(DotOperationTest,
|
||||
DISABLED_ON_CPU(DISABLED_ON_GPU(DISABLED_ON_INTERPRETER(
|
||||
DotOfGatherOptimizationWithConstRHSReverseMM)))) {
|
||||
|
||||
DotOfGatherOptimizationWithConstRHSReverseMM) {
|
||||
std::unique_ptr<Array2D<float>> constant_lhs_array(
|
||||
new Array2D<float>({{1.0, 2.0, 3.0},
|
||||
{4.0, 5.0, 6.0},
|
||||
@ -883,10 +882,7 @@ XLA_TEST_F(DotOperationTest,
|
||||
ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_);
|
||||
}
|
||||
|
||||
// TODO (b/69062148) Enable when Dot implements general contracting dimensions.
|
||||
XLA_TEST_F(DotOperationTest,
|
||||
DISABLED_ON_CPU(DISABLED_ON_GPU(DISABLED_ON_INTERPRETER(
|
||||
DotOfGatherOptimizationWithConstLHSReverseMM)))) {
|
||||
XLA_TEST_F(DotOperationTest, DotOfGatherOptimizationWithConstLHSReverseMM) {
|
||||
std::unique_ptr<Array2D<float>> constant_lhs_array(
|
||||
new Array2D<float>({{1.0, 2.0, 3.0},
|
||||
{4.0, 5.0, 6.0},
|
||||
@ -913,10 +909,7 @@ XLA_TEST_F(DotOperationTest,
|
||||
ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_);
|
||||
}
|
||||
|
||||
// TODO (b/69062148) Enable when Dot implements general contracting dimensions.
|
||||
XLA_TEST_F(DotOperationTest,
|
||||
DISABLED_ON_CPU(DISABLED_ON_GPU(DISABLED_ON_INTERPRETER(
|
||||
DotOfGatherOptimizationWithConstRHSRows)))) {
|
||||
XLA_TEST_F(DotOperationTest, DotOfGatherOptimizationWithConstRHSRows) {
|
||||
std::unique_ptr<Array2D<float>> constant_lhs_array(
|
||||
new Array2D<float>({{1.0, 2.0},
|
||||
{3.0, 4.0},
|
||||
@ -948,10 +941,7 @@ XLA_TEST_F(DotOperationTest,
|
||||
ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_);
|
||||
}
|
||||
|
||||
// TODO (b/69062148) Enable when Dot implements general contracting dimensions.
|
||||
XLA_TEST_F(DotOperationTest,
|
||||
DISABLED_ON_CPU(DISABLED_ON_GPU(DISABLED_ON_INTERPRETER(
|
||||
DotOfGatherOptimizationWithConstLHSRows)))) {
|
||||
XLA_TEST_F(DotOperationTest, DotOfGatherOptimizationWithConstLHSRows) {
|
||||
std::unique_ptr<Array2D<float>> constant_lhs_array(
|
||||
new Array2D<float>({{1.0, 2.0},
|
||||
{3.0, 4.0},
|
||||
@ -983,10 +973,7 @@ XLA_TEST_F(DotOperationTest,
|
||||
ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_);
|
||||
}
|
||||
|
||||
// TODO (b/69062148) Enable when Dot implements general contracting dimensions.
|
||||
XLA_TEST_F(DotOperationTest,
|
||||
DISABLED_ON_CPU(DISABLED_ON_GPU(DISABLED_ON_INTERPRETER(
|
||||
DotOfGatherOptimizationWithConstRHSCols)))) {
|
||||
XLA_TEST_F(DotOperationTest, DotOfGatherOptimizationWithConstRHSCols) {
|
||||
std::unique_ptr<Array2D<float>> constant_lhs_array(new Array2D<float>(
|
||||
{{1.0, 2.0, 3.0, 4.0, 5.0, 6.0}, {6.0, 5.0, 4.0, 3.0, 2.0, 1.0}}));
|
||||
std::unique_ptr<Array2D<float>> constant_rhs_array(
|
||||
@ -1010,10 +997,7 @@ XLA_TEST_F(DotOperationTest,
|
||||
ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_);
|
||||
}
|
||||
|
||||
// TODO (b/69062148) Enable when Dot implements general contracting dimensions.
|
||||
XLA_TEST_F(DotOperationTest,
|
||||
DISABLED_ON_CPU(DISABLED_ON_GPU(DISABLED_ON_INTERPRETER(
|
||||
DotOfGatherOptimizationWithConstLHSCols)))) {
|
||||
XLA_TEST_F(DotOperationTest, DotOfGatherOptimizationWithConstLHSCols) {
|
||||
std::unique_ptr<Array2D<float>> constant_lhs_array(new Array2D<float>(
|
||||
{{1.0, 2.0, 3.0, 4.0, 5.0, 6.0}, {6.0, 5.0, 4.0, 3.0, 2.0, 1.0}}));
|
||||
std::unique_ptr<Array2D<float>> constant_rhs_array(
|
||||
@ -1036,5 +1020,28 @@ XLA_TEST_F(DotOperationTest,
|
||||
Array2D<float> expected({{168.0}, {168.0}});
|
||||
ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_);
|
||||
}
|
||||
|
||||
XLA_TEST_F(DotOperationTest, DotRank2AndRank2NonDefaultContractionDims) {
|
||||
XlaBuilder builder(TestName());
|
||||
|
||||
Array2D<float> lhs_array({{1.0f, 2.0f}, {3.0f, 4.0f}});
|
||||
auto lhs_constant = ConstantR2FromArray2D(&builder, lhs_array);
|
||||
|
||||
Array2D<float> rhs_array({{5.0f, 6.0f}, {7.0f, 8.0f}});
|
||||
auto rhs_constant = ConstantR2FromArray2D(&builder, rhs_array);
|
||||
|
||||
Shape shape = ShapeUtil::MakeShape(F32, {2, 2});
|
||||
DotDimensionNumbers dot_dnums;
|
||||
dot_dnums.add_lhs_contracting_dimensions(0);
|
||||
dot_dnums.add_rhs_contracting_dimensions(0);
|
||||
DotGeneral(lhs_constant, rhs_constant, dot_dnums);
|
||||
|
||||
Array2D<float> expected({
|
||||
{26.f, 30.f},
|
||||
{38.f, 44.f},
|
||||
});
|
||||
|
||||
ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_);
|
||||
}
|
||||
} // namespace
|
||||
} // namespace xla
|
||||
|
Loading…
Reference in New Issue
Block a user