[XLA] Allow non-monotonic batch dimensions, they are supported by the
evaluator. PiperOrigin-RevId: 226702189
This commit is contained in:
parent
1853b08d9c
commit
6e304fadf8
@ -516,18 +516,12 @@ StatusOr<Shape> InferWindowOutputShape(const Shape& base_shape,
|
||||
// Current DotDimensionNumbers Requirements:
|
||||
//
|
||||
// Contracting Dimensions:
|
||||
// *) Same number of contracting dimensions on both lhs and rhs.
|
||||
// *) Contracting dimension size must be the same on both lhs and rhs.
|
||||
// *) Contracting dimension numbers do not need to be the same (i.e. transposes
|
||||
// are passed on to emitter implementations).
|
||||
//
|
||||
// Batch Dimensions:
|
||||
// *) Same number of batch dimensions on both lhs and rhs.
|
||||
// *) Same batch dimension numbers (and sizes) on both lhs and rhs.
|
||||
// *) Batch dimension numbers must be ordered before contracting and
|
||||
// non-contracting/non-batch dimension numbers.
|
||||
//
|
||||
// Non-Contracting-Non-Batch Dimensions:
|
||||
// *) Can be 0 (matrix-vector) or 1 (matrix-matrix).
|
||||
// *) Same batch dimension sizes on both lhs and rhs.
|
||||
//
|
||||
|
||||
namespace {
|
||||
@ -580,19 +574,6 @@ Status ValidateDotDimensionNumbers(
|
||||
dimension_numbers.DebugString());
|
||||
}
|
||||
|
||||
// Check that batch dimension numbers are ordered before all others, and
|
||||
// that they are monotonically increasing.
|
||||
std::vector<int64> batch_dim_numbers(lhs_batch_dimensions.size());
|
||||
std::iota(batch_dim_numbers.begin(), batch_dim_numbers.end(), 0);
|
||||
if (!std::equal(batch_dim_numbers.begin(), batch_dim_numbers.end(),
|
||||
lhs_batch_dimensions.begin()) ||
|
||||
!std::equal(batch_dim_numbers.begin(), batch_dim_numbers.end(),
|
||||
rhs_batch_dimensions.begin())) {
|
||||
return InvalidArgument(
|
||||
"Batch dimension numbers must precede non-batch dimensions and be"
|
||||
"monotonically increasing.");
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
@ -654,11 +635,9 @@ Status ValidateDotDimensionNumbers(
|
||||
|
||||
// Check that batch dimension numbers and sizes match.
|
||||
for (int64 i = 0; i < dimension_numbers.lhs_batch_dimensions_size(); ++i) {
|
||||
if (dimension_numbers.lhs_batch_dimensions(i) !=
|
||||
dimension_numbers.rhs_batch_dimensions(i) ||
|
||||
lhs.dimensions(dimension_numbers.lhs_batch_dimensions(i)) !=
|
||||
rhs.dimensions(dimension_numbers.rhs_batch_dimensions(i))) {
|
||||
return fail("Batch dimension numbers and sizes must match for lhs/rhs.");
|
||||
if (lhs.dimensions(dimension_numbers.lhs_batch_dimensions(i)) !=
|
||||
rhs.dimensions(dimension_numbers.rhs_batch_dimensions(i))) {
|
||||
return fail("Batch dimension sizes must match for lhs/rhs.");
|
||||
}
|
||||
}
|
||||
|
||||
@ -668,19 +647,20 @@ Status ValidateDotDimensionNumbers(
|
||||
// Generate the result dimensions in order, rhs dimensions followed by lhs
|
||||
// dimensions except the contracted and batch dimensions.
|
||||
std::vector<int64> dimensions;
|
||||
std::unordered_set<int64> rhs_batch_dims(
|
||||
dimension_numbers.rhs_batch_dimensions().begin(),
|
||||
dimension_numbers.rhs_batch_dimensions().end());
|
||||
for (int64 lhs_dim : dimension_numbers.lhs_batch_dimensions()) {
|
||||
dimensions.push_back(lhs.dimensions(lhs_dim));
|
||||
}
|
||||
for (int64 i = 0; i < lhs.rank(); i++) {
|
||||
if (!absl::c_linear_search(dimension_numbers.lhs_contracting_dimensions(),
|
||||
i)) {
|
||||
i) &&
|
||||
!absl::c_linear_search(dimension_numbers.lhs_batch_dimensions(), i)) {
|
||||
dimensions.push_back(lhs.dimensions(i));
|
||||
}
|
||||
}
|
||||
for (int64 i = 0; i < rhs.rank(); i++) {
|
||||
if (!absl::c_linear_search(dimension_numbers.rhs_contracting_dimensions(),
|
||||
i) &&
|
||||
rhs_batch_dims.count(i) == 0) {
|
||||
!absl::c_linear_search(dimension_numbers.rhs_batch_dimensions(), i)) {
|
||||
dimensions.push_back(rhs.dimensions(i));
|
||||
}
|
||||
}
|
||||
|
@ -1153,11 +1153,11 @@ TEST_F(ShapeInferenceTest, DotWithMisatchedBatchDimSizesFails) {
|
||||
ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums);
|
||||
ASSERT_FALSE(inferred_status.ok());
|
||||
ASSERT_THAT(inferred_status.status().error_message(),
|
||||
HasSubstr("Batch dimension numbers and sizes must match"));
|
||||
HasSubstr("Batch dimension sizes must match"));
|
||||
}
|
||||
|
||||
// BatchMatMul with different batch dimension numbers fails.
|
||||
TEST_F(ShapeInferenceTest, DotWithMisatchedBatchDimNumbersFails) {
|
||||
// BatchMatMul with different batch dimension numbers passes
|
||||
TEST_F(ShapeInferenceTest, DotWithMisatchedBatchDimNumbersPasses) {
|
||||
Shape lhs_shape = ShapeUtil::MakeShape(F32, {2, 11, 3});
|
||||
Shape rhs_shape = ShapeUtil::MakeShape(F32, {3, 2, 14});
|
||||
|
||||
@ -1170,9 +1170,9 @@ TEST_F(ShapeInferenceTest, DotWithMisatchedBatchDimNumbersFails) {
|
||||
|
||||
auto inferred_status =
|
||||
ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums);
|
||||
ASSERT_FALSE(inferred_status.ok());
|
||||
ASSERT_THAT(inferred_status.status().error_message(),
|
||||
HasSubstr("Batch dimension numbers must precede non-batch"));
|
||||
ASSERT_TRUE(inferred_status.ok());
|
||||
ASSERT_TRUE(ShapeUtil::Equal(inferred_status.ValueOrDie(),
|
||||
ShapeUtil::MakeShape(F32, {2, 11, 14})));
|
||||
}
|
||||
|
||||
// BatchMatMul with out-of-range dimension numbers fails.
|
||||
|
Loading…
Reference in New Issue
Block a user