diff --git a/tensorflow/compiler/xla/service/batch_dot_simplification.cc b/tensorflow/compiler/xla/service/batch_dot_simplification.cc index eda026ac568..dbabd82dd55 100644 --- a/tensorflow/compiler/xla/service/batch_dot_simplification.cc +++ b/tensorflow/compiler/xla/service/batch_dot_simplification.cc @@ -28,6 +28,13 @@ BatchDotSimplification::ElideDegenerateBatchDimensionFromBatchDot( *rhs = batch_dot->mutable_operand(1); const Shape& lhs_shape = lhs->shape(); + // A dot with no contracting dims will be rewritten into a multiply by + // AlgebraicSimplifier. Dots with multiple contracting dims are currently + // unsupported. + if (dim_numbers.lhs_contracting_dimensions_size() != 1) { + return false; + } + std::vector degenerate_dims; for (int64 batch_dim : dim_numbers.lhs_batch_dimensions()) { if (lhs_shape.dimensions(batch_dim) == 1) { diff --git a/tensorflow/compiler/xla/service/batch_dot_simplification_test.cc b/tensorflow/compiler/xla/service/batch_dot_simplification_test.cc index 52ec1a794c5..a81f394a38f 100644 --- a/tensorflow/compiler/xla/service/batch_dot_simplification_test.cc +++ b/tensorflow/compiler/xla/service/batch_dot_simplification_test.cc @@ -169,5 +169,47 @@ main { /*lhs_contracting_dim=*/3, /*rhs_contracting_dim=*/2))); } +TEST_F(BatchDotSimplificationTest, + ElideMultipleDegenerateBatchDotDimsNonContracting) { + const char* hlo_text = R"( +HloModule BatchDot + +main { + a = f32[1,101] parameter(0) + b = f32[1,101] parameter(1) + ROOT dot = f32[1,101,101] dot(a,b), lhs_batch_dims={0}, + lhs_contracting_dims={}, + rhs_batch_dims={0}, + rhs_contracting_dims={} +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr m, + ParseAndReturnVerifiedModule(hlo_text)); + BatchDotSimplification pass; + ASSERT_FALSE(pass.Run(m.get()).ValueOrDie()); +} + +TEST_F(BatchDotSimplificationTest, + ElideMultipleDegenerateBatchDotDimsMultipleContracting) { + const char* hlo_text = R"( +HloModule BatchDot + +main { + lhs = f32[1,5,17,10,13] parameter(0) + rhs = f32[1,9,10,13,6,5] parameter(1) + ROOT dot = f32[10,1,17,9,6] dot(lhs,rhs), lhs_batch_dims={3,0}, + rhs_batch_dims={2,0}, + lhs_contracting_dims={1,4}, + rhs_contracting_dims={5,3} +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr m, + ParseAndReturnVerifiedModule(hlo_text)); + BatchDotSimplification pass; + ASSERT_FALSE(pass.Run(m.get()).ValueOrDie()); +} + } // namespace } // namespace xla