[XLA] Don't crash in when trying to simplify batch dots with no contraction dims

Dot with no contraction is just a multiply. Not a super useful operation but
valid HLO. AlgebraicSimplifier will rewrite it into a multiply, so don't even
try simplifying it.

PiperOrigin-RevId: 234010983
This commit is contained in:
Benjamin Kramer 2019-02-14 12:51:22 -08:00 committed by TensorFlower Gardener
parent c314d0cd20
commit b51e216294
2 changed files with 49 additions and 0 deletions

View File

@ -28,6 +28,13 @@ BatchDotSimplification::ElideDegenerateBatchDimensionFromBatchDot(
*rhs = batch_dot->mutable_operand(1);
const Shape& lhs_shape = lhs->shape();
// A dot with no contracting dims will be rewritten into a multiply by
// AlgebraicSimplifier. Dots with multiple contracting dims are currently
// unsupported.
if (dim_numbers.lhs_contracting_dimensions_size() != 1) {
return false;
}
std::vector<int64> degenerate_dims;
for (int64 batch_dim : dim_numbers.lhs_batch_dimensions()) {
if (lhs_shape.dimensions(batch_dim) == 1) {

View File

@ -169,5 +169,47 @@ main {
/*lhs_contracting_dim=*/3, /*rhs_contracting_dim=*/2)));
}
TEST_F(BatchDotSimplificationTest,
ElideMultipleDegenerateBatchDotDimsNonContracting) {
const char* hlo_text = R"(
HloModule BatchDot
main {
a = f32[1,101] parameter(0)
b = f32[1,101] parameter(1)
ROOT dot = f32[1,101,101] dot(a,b), lhs_batch_dims={0},
lhs_contracting_dims={},
rhs_batch_dims={0},
rhs_contracting_dims={}
}
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> m,
ParseAndReturnVerifiedModule(hlo_text));
BatchDotSimplification pass;
ASSERT_FALSE(pass.Run(m.get()).ValueOrDie());
}
TEST_F(BatchDotSimplificationTest,
ElideMultipleDegenerateBatchDotDimsMultipleContracting) {
const char* hlo_text = R"(
HloModule BatchDot
main {
lhs = f32[1,5,17,10,13] parameter(0)
rhs = f32[1,9,10,13,6,5] parameter(1)
ROOT dot = f32[10,1,17,9,6] dot(lhs,rhs), lhs_batch_dims={3,0},
rhs_batch_dims={2,0},
lhs_contracting_dims={1,4},
rhs_contracting_dims={5,3}
}
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> m,
ParseAndReturnVerifiedModule(hlo_text));
BatchDotSimplification pass;
ASSERT_FALSE(pass.Run(m.get()).ValueOrDie());
}
} // namespace
} // namespace xla