[XLA] Don't crash in when trying to simplify batch dots with no contraction dims
Dot with no contraction is just a multiply. Not a super useful operation but valid HLO. AlgebraicSimplifier will rewrite it into a multiply, so don't even try simplifying it. PiperOrigin-RevId: 234010983
This commit is contained in:
parent
c314d0cd20
commit
b51e216294
@ -28,6 +28,13 @@ BatchDotSimplification::ElideDegenerateBatchDimensionFromBatchDot(
|
||||
*rhs = batch_dot->mutable_operand(1);
|
||||
const Shape& lhs_shape = lhs->shape();
|
||||
|
||||
// A dot with no contracting dims will be rewritten into a multiply by
|
||||
// AlgebraicSimplifier. Dots with multiple contracting dims are currently
|
||||
// unsupported.
|
||||
if (dim_numbers.lhs_contracting_dimensions_size() != 1) {
|
||||
return false;
|
||||
}
|
||||
|
||||
std::vector<int64> degenerate_dims;
|
||||
for (int64 batch_dim : dim_numbers.lhs_batch_dimensions()) {
|
||||
if (lhs_shape.dimensions(batch_dim) == 1) {
|
||||
|
||||
@ -169,5 +169,47 @@ main {
|
||||
/*lhs_contracting_dim=*/3, /*rhs_contracting_dim=*/2)));
|
||||
}
|
||||
|
||||
TEST_F(BatchDotSimplificationTest,
|
||||
ElideMultipleDegenerateBatchDotDimsNonContracting) {
|
||||
const char* hlo_text = R"(
|
||||
HloModule BatchDot
|
||||
|
||||
main {
|
||||
a = f32[1,101] parameter(0)
|
||||
b = f32[1,101] parameter(1)
|
||||
ROOT dot = f32[1,101,101] dot(a,b), lhs_batch_dims={0},
|
||||
lhs_contracting_dims={},
|
||||
rhs_batch_dims={0},
|
||||
rhs_contracting_dims={}
|
||||
}
|
||||
)";
|
||||
|
||||
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> m,
|
||||
ParseAndReturnVerifiedModule(hlo_text));
|
||||
BatchDotSimplification pass;
|
||||
ASSERT_FALSE(pass.Run(m.get()).ValueOrDie());
|
||||
}
|
||||
|
||||
TEST_F(BatchDotSimplificationTest,
|
||||
ElideMultipleDegenerateBatchDotDimsMultipleContracting) {
|
||||
const char* hlo_text = R"(
|
||||
HloModule BatchDot
|
||||
|
||||
main {
|
||||
lhs = f32[1,5,17,10,13] parameter(0)
|
||||
rhs = f32[1,9,10,13,6,5] parameter(1)
|
||||
ROOT dot = f32[10,1,17,9,6] dot(lhs,rhs), lhs_batch_dims={3,0},
|
||||
rhs_batch_dims={2,0},
|
||||
lhs_contracting_dims={1,4},
|
||||
rhs_contracting_dims={5,3}
|
||||
}
|
||||
)";
|
||||
|
||||
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> m,
|
||||
ParseAndReturnVerifiedModule(hlo_text));
|
||||
BatchDotSimplification pass;
|
||||
ASSERT_FALSE(pass.Run(m.get()).ValueOrDie());
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace xla
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user