[XLA] Simplify batch dots that have no contracting dimensions into multiplies.
PiperOrigin-RevId: 233637534
This commit is contained in:
parent
ce3ee9b6a4
commit
ba941c212d
@ -1605,29 +1605,50 @@ Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// If a dot only contains batch dimensions, then transpose the rhs and lhs
|
||||
// according to the batch dimension numbers and do a simple multiply.
|
||||
if (lhs->shape().rank() ==
|
||||
dot->dot_dimension_numbers().lhs_batch_dimensions_size() &&
|
||||
rhs->shape().rank() ==
|
||||
dot->dot_dimension_numbers().rhs_batch_dimensions_size()) {
|
||||
HloInstruction* new_rhs = rhs;
|
||||
HloInstruction* new_lhs = lhs;
|
||||
if (lhs->shape().rank() > 1) {
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
new_rhs,
|
||||
MakeTransposeHlo(
|
||||
rhs, AsInt64Slice(
|
||||
dot->dot_dimension_numbers().rhs_batch_dimensions())));
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
new_lhs,
|
||||
MakeTransposeHlo(
|
||||
lhs, AsInt64Slice(
|
||||
dot->dot_dimension_numbers().lhs_batch_dimensions())));
|
||||
// If there are no contracting dimensions, a dot can be rewritten as
|
||||
// mul(broadcast(transpose(x)),broadcast(transpose(y)))
|
||||
if (dot->dot_dimension_numbers().lhs_contracting_dimensions_size() == 0) {
|
||||
std::vector<int64> lhs_transpose(
|
||||
dot->dot_dimension_numbers().lhs_batch_dimensions().begin(),
|
||||
dot->dot_dimension_numbers().lhs_batch_dimensions().end());
|
||||
for (int64 i = 0; i < lhs->shape().rank(); ++i) {
|
||||
if (!absl::c_linear_search(
|
||||
dot->dot_dimension_numbers().lhs_batch_dimensions(), i)) {
|
||||
lhs_transpose.push_back(i);
|
||||
}
|
||||
}
|
||||
TF_ASSIGN_OR_RETURN(auto new_dot,
|
||||
MakeBinaryHlo(HloOpcode::kMultiply, new_lhs, new_rhs));
|
||||
return ReplaceInstruction(dot, new_dot);
|
||||
TF_ASSIGN_OR_RETURN(HloInstruction * new_lhs,
|
||||
MakeTransposeHlo(lhs, lhs_transpose));
|
||||
if (dot->shape().rank() != lhs->shape().rank()) {
|
||||
std::vector<int64> lhs_broadcast_dims(lhs->shape().rank());
|
||||
absl::c_iota(lhs_broadcast_dims, 0);
|
||||
new_lhs = computation_->AddInstruction(HloInstruction::CreateBroadcast(
|
||||
dot->shape(), new_lhs, lhs_broadcast_dims));
|
||||
}
|
||||
std::vector<int64> rhs_transpose(
|
||||
dot->dot_dimension_numbers().rhs_batch_dimensions().begin(),
|
||||
dot->dot_dimension_numbers().rhs_batch_dimensions().end());
|
||||
for (int64 i = 0; i < rhs->shape().rank(); ++i) {
|
||||
if (!absl::c_linear_search(
|
||||
dot->dot_dimension_numbers().rhs_batch_dimensions(), i)) {
|
||||
rhs_transpose.push_back(i);
|
||||
}
|
||||
}
|
||||
TF_ASSIGN_OR_RETURN(HloInstruction * new_rhs,
|
||||
MakeTransposeHlo(rhs, rhs_transpose));
|
||||
if (dot->shape().rank() != rhs->shape().rank()) {
|
||||
std::vector<int64> rhs_broadcast_dims(
|
||||
dot->dot_dimension_numbers().lhs_batch_dimensions_size());
|
||||
absl::c_iota(rhs_broadcast_dims, 0);
|
||||
for (int64 i = lhs->shape().rank(); i < dot->shape().rank(); ++i) {
|
||||
rhs_broadcast_dims.push_back(i);
|
||||
}
|
||||
new_rhs = computation_->AddInstruction(HloInstruction::CreateBroadcast(
|
||||
dot->shape(), new_rhs, rhs_broadcast_dims));
|
||||
}
|
||||
return ReplaceWithNewInstruction(
|
||||
dot, HloInstruction::CreateBinary(dot->shape(), HloOpcode::kMultiply,
|
||||
new_lhs, new_rhs));
|
||||
}
|
||||
|
||||
if (lhs->shape().rank() > 2 || rhs->shape().rank() > 2 ||
|
||||
|
@ -4222,8 +4222,10 @@ TEST_P(BatchDotStrengthReductionTest, BatchDotStrengthReduction) {
|
||||
std::tie(m, k, n, element_type) = GetParam();
|
||||
|
||||
Shape dot_shape = ShapeUtil::MakeShape(element_type, {1, 3, 5, m, n});
|
||||
Shape lhs_shape = ShapeUtil::MakeShape(element_type, {1, 3, 5, m, k});
|
||||
Shape rhs_shape = ShapeUtil::MakeShape(element_type, {1, 3, 5, k, n});
|
||||
Shape lhs_shape = k > 0 ? ShapeUtil::MakeShape(element_type, {1, 3, 5, m, k})
|
||||
: ShapeUtil::MakeShape(element_type, {1, 3, 5, m});
|
||||
Shape rhs_shape = k > 0 ? ShapeUtil::MakeShape(element_type, {1, 3, 5, k, n})
|
||||
: ShapeUtil::MakeShape(element_type, {1, 3, 5, n});
|
||||
HloComputation::Builder builder(TestName());
|
||||
|
||||
auto lhs = builder.AddInstruction(
|
||||
@ -4237,14 +4239,16 @@ TEST_P(BatchDotStrengthReductionTest, BatchDotStrengthReduction) {
|
||||
dot_dnums.add_rhs_batch_dimensions(0);
|
||||
dot_dnums.add_rhs_batch_dimensions(1);
|
||||
dot_dnums.add_rhs_batch_dimensions(2);
|
||||
dot_dnums.add_lhs_contracting_dimensions(4);
|
||||
dot_dnums.add_rhs_contracting_dimensions(3);
|
||||
if (k > 0) {
|
||||
dot_dnums.add_lhs_contracting_dimensions(4);
|
||||
dot_dnums.add_rhs_contracting_dimensions(3);
|
||||
}
|
||||
builder.AddInstruction(HloInstruction::CreateDot(
|
||||
dot_shape, lhs, rhs, dot_dnums, DefaultPrecisionConfig(2)));
|
||||
auto computation = module->AddEntryComputation(builder.Build());
|
||||
AlgebraicSimplifier simplifier(default_options_);
|
||||
TF_ASSERT_OK_AND_ASSIGN(bool changed, simplifier.Run(module.get()));
|
||||
const bool dot_should_be_transformed = m == 1 || k == 1 || n == 1;
|
||||
const bool dot_should_be_transformed = m == 1 || k == 1 || n == 1 || k == -1;
|
||||
const bool computation_should_be_modified = dot_should_be_transformed;
|
||||
EXPECT_EQ(changed, computation_should_be_modified);
|
||||
bool has_no_dot = true;
|
||||
@ -4259,7 +4263,7 @@ TEST_P(BatchDotStrengthReductionTest, BatchDotStrengthReduction) {
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(
|
||||
BatchDotStrengthReductionTestInstantiation, BatchDotStrengthReductionTest,
|
||||
::testing::Combine(::testing::Values(1, 2), ::testing::Values(1, 2),
|
||||
::testing::Combine(::testing::Values(1, 2), ::testing::Values(-1, 1, 2),
|
||||
::testing::Values(1, 2), ::testing::Values(F32, BF16)));
|
||||
|
||||
class DotStrengthReductionTest
|
||||
|
@ -2700,11 +2700,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
|
||||
const Shape& operand, absl::Span<const int64> dimensions) {
|
||||
TF_RETURN_IF_ERROR(ExpectArray(operand, "transpose"));
|
||||
|
||||
std::vector<int64> indices(operand.rank());
|
||||
std::iota(indices.begin(), indices.end(), 0);
|
||||
if (dimensions.size() != operand.rank() ||
|
||||
!std::is_permutation(dimensions.begin(), dimensions.end(),
|
||||
indices.begin())) {
|
||||
if (!IsPermutation(dimensions, operand.rank())) {
|
||||
return InvalidArgument(
|
||||
"Transpose dimensions [%s] are not a permutation of the operand "
|
||||
"dimensions (operand shape is %s).",
|
||||
|
@ -1572,6 +1572,16 @@ TEST_F(ShapeInferenceTest, Transpose) {
|
||||
ShapeUtil::MakeShape(F32, {3, 4, 5, 2})));
|
||||
}
|
||||
|
||||
TEST_F(ShapeInferenceTest, Rank1Transpose) {
|
||||
Shape a_shape = ShapeUtil::MakeShape(F32, {5});
|
||||
auto inferred_shape_and_status =
|
||||
ShapeInference::InferTransposeShape(a_shape, {0});
|
||||
EXPECT_IS_OK(inferred_shape_and_status);
|
||||
Shape inferred_shape = inferred_shape_and_status.ValueOrDie();
|
||||
EXPECT_TRUE(
|
||||
ShapeUtil::Compatible(inferred_shape, ShapeUtil::MakeShape(F32, {5})));
|
||||
}
|
||||
|
||||
TEST_F(ShapeInferenceTest, Conditional) {
|
||||
auto inferred_status0 = ShapeInference::InferConditionalShape(
|
||||
pred_, vector_32_, vector_64_,
|
||||
|
@ -1189,6 +1189,8 @@ std::vector<EinsumParamType> GetEinsumTestCases() {
|
||||
p{v{5, 6}, v{6, 7}, "ab,cd->dcba"},
|
||||
p{v{6}, v{6, 7}, "b,bc->c"},
|
||||
p{v{77}, v{77}, "a,a->a"},
|
||||
p{v{77}, v{77, 55}, "a,ab->ba"},
|
||||
p{v{2, 3, 77}, v{77, 2, 3, 55}, "ija,aijb->baij"},
|
||||
p{v{55}, v{}, "a,->a"},
|
||||
p{v{11, 111}, v{11}, "ab,a->ab"},
|
||||
p{v{16, 34}, v{16, 34}, "ab,ab->ab"},
|
||||
|
@ -18,6 +18,7 @@ limitations under the License.
|
||||
#include <stdarg.h>
|
||||
#include <numeric>
|
||||
|
||||
#include "absl/container/inlined_vector.h"
|
||||
#include "absl/strings/match.h"
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "absl/strings/str_join.h"
|
||||
@ -80,13 +81,9 @@ bool IsPermutation(absl::Span<const int64> permutation, int64 rank) {
|
||||
if (rank != permutation.size()) {
|
||||
return false;
|
||||
}
|
||||
std::vector<int64> output(permutation.size(), -1);
|
||||
for (auto index : permutation) {
|
||||
CHECK_GE(index, 0);
|
||||
CHECK_LT(index, rank);
|
||||
output[index] = 0;
|
||||
}
|
||||
return !absl::c_linear_search(output, -1);
|
||||
absl::InlinedVector<int64, 8> trivial_permutation(rank);
|
||||
absl::c_iota(trivial_permutation, 0);
|
||||
return absl::c_is_permutation(permutation, trivial_permutation);
|
||||
}
|
||||
|
||||
std::vector<int64> InversePermutation(
|
||||
|
Loading…
x
Reference in New Issue
Block a user