diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc index 84d16633110..b0e241d216d 100644 --- a/tensorflow/compiler/xla/service/shape_inference.cc +++ b/tensorflow/compiler/xla/service/shape_inference.cc @@ -516,18 +516,12 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, // Current DotDimensionNumbers Requirements: // // Contracting Dimensions: +// *) Same number of contracting dimensions on both lhs and rhs. // *) Contracting dimension size must be the same on both lhs and rhs. -// *) Contracting dimension numbers do not need to be the same (i.e. transposes -// are passed on to emitter implementations). // // Batch Dimensions: // *) Same number of batch dimensions on both lhs and rhs. -// *) Same batch dimension numbers (and sizes) on both lhs and rhs. -// *) Batch dimension numbers must be ordered before contracting and -// non-contracting/non-batch dimension numbers. -// -// Non-Contracting-Non-Batch Dimensions: -// *) Can be 0 (matrix-vector) or 1 (matrix-matrix). +// *) Same batch dimension sizes on both lhs and rhs. // namespace { @@ -580,19 +574,6 @@ Status ValidateDotDimensionNumbers( dimension_numbers.DebugString()); } - // Check that batch dimension numbers are ordered before all others, and - // that they are monotonically increasing. - std::vector batch_dim_numbers(lhs_batch_dimensions.size()); - std::iota(batch_dim_numbers.begin(), batch_dim_numbers.end(), 0); - if (!std::equal(batch_dim_numbers.begin(), batch_dim_numbers.end(), - lhs_batch_dimensions.begin()) || - !std::equal(batch_dim_numbers.begin(), batch_dim_numbers.end(), - rhs_batch_dimensions.begin())) { - return InvalidArgument( - "Batch dimension numbers must precede non-batch dimensions and be" - "monotonically increasing."); - } - return Status::OK(); } @@ -654,11 +635,9 @@ Status ValidateDotDimensionNumbers( // Check that batch dimension numbers and sizes match. for (int64 i = 0; i < dimension_numbers.lhs_batch_dimensions_size(); ++i) { - if (dimension_numbers.lhs_batch_dimensions(i) != - dimension_numbers.rhs_batch_dimensions(i) || - lhs.dimensions(dimension_numbers.lhs_batch_dimensions(i)) != - rhs.dimensions(dimension_numbers.rhs_batch_dimensions(i))) { - return fail("Batch dimension numbers and sizes must match for lhs/rhs."); + if (lhs.dimensions(dimension_numbers.lhs_batch_dimensions(i)) != + rhs.dimensions(dimension_numbers.rhs_batch_dimensions(i))) { + return fail("Batch dimension sizes must match for lhs/rhs."); } } @@ -668,19 +647,20 @@ Status ValidateDotDimensionNumbers( // Generate the result dimensions in order, rhs dimensions followed by lhs // dimensions except the contracted and batch dimensions. std::vector dimensions; - std::unordered_set rhs_batch_dims( - dimension_numbers.rhs_batch_dimensions().begin(), - dimension_numbers.rhs_batch_dimensions().end()); + for (int64 lhs_dim : dimension_numbers.lhs_batch_dimensions()) { + dimensions.push_back(lhs.dimensions(lhs_dim)); + } for (int64 i = 0; i < lhs.rank(); i++) { if (!absl::c_linear_search(dimension_numbers.lhs_contracting_dimensions(), - i)) { + i) && + !absl::c_linear_search(dimension_numbers.lhs_batch_dimensions(), i)) { dimensions.push_back(lhs.dimensions(i)); } } for (int64 i = 0; i < rhs.rank(); i++) { if (!absl::c_linear_search(dimension_numbers.rhs_contracting_dimensions(), i) && - rhs_batch_dims.count(i) == 0) { + !absl::c_linear_search(dimension_numbers.rhs_batch_dimensions(), i)) { dimensions.push_back(rhs.dimensions(i)); } } diff --git a/tensorflow/compiler/xla/service/shape_inference_test.cc b/tensorflow/compiler/xla/service/shape_inference_test.cc index b44ae08af13..bb00bfd2eee 100644 --- a/tensorflow/compiler/xla/service/shape_inference_test.cc +++ b/tensorflow/compiler/xla/service/shape_inference_test.cc @@ -1153,11 +1153,11 @@ TEST_F(ShapeInferenceTest, DotWithMisatchedBatchDimSizesFails) { ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums); ASSERT_FALSE(inferred_status.ok()); ASSERT_THAT(inferred_status.status().error_message(), - HasSubstr("Batch dimension numbers and sizes must match")); + HasSubstr("Batch dimension sizes must match")); } -// BatchMatMul with different batch dimension numbers fails. -TEST_F(ShapeInferenceTest, DotWithMisatchedBatchDimNumbersFails) { +// BatchMatMul with different batch dimension numbers passes +TEST_F(ShapeInferenceTest, DotWithMisatchedBatchDimNumbersPasses) { Shape lhs_shape = ShapeUtil::MakeShape(F32, {2, 11, 3}); Shape rhs_shape = ShapeUtil::MakeShape(F32, {3, 2, 14}); @@ -1170,9 +1170,9 @@ TEST_F(ShapeInferenceTest, DotWithMisatchedBatchDimNumbersFails) { auto inferred_status = ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums); - ASSERT_FALSE(inferred_status.ok()); - ASSERT_THAT(inferred_status.status().error_message(), - HasSubstr("Batch dimension numbers must precede non-batch")); + ASSERT_TRUE(inferred_status.ok()); + ASSERT_TRUE(ShapeUtil::Equal(inferred_status.ValueOrDie(), + ShapeUtil::MakeShape(F32, {2, 11, 14}))); } // BatchMatMul with out-of-range dimension numbers fails.