[XLA] Allow non-monotonic batch dimensions; they are supported by the evaluator.

PiperOrigin-RevId: 226702189
parent 1853b08d9c
commit 6e304fadf8

@@ -516,18 +516,12 @@ StatusOr<Shape> InferWindowOutputShape(const Shape& base_shape,
 // Current DotDimensionNumbers Requirements:
 //
 // Contracting Dimensions:
+// *) Same number of contracting dimensions on both lhs and rhs.
 // *) Contracting dimension size must be the same on both lhs and rhs.
-// *) Contracting dimension numbers do not need to be the same (i.e. transposes
-//    are passed on to emitter implementations).
 //
 // Batch Dimensions:
 // *) Same number of batch dimensions on both lhs and rhs.
-// *) Same batch dimension numbers (and sizes) on both lhs and rhs.
-// *) Batch dimension numbers must be ordered before contracting and
-//    non-contracting/non-batch dimension numbers.
-//
-// Non-Contracting-Non-Batch Dimensions:
-// *) Can be 0 (matrix-vector) or 1 (matrix-matrix).
+// *) Same batch dimension sizes on both lhs and rhs.
 //

 namespace {
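
To make the newly allowed case concrete, here is a minimal client-side sketch. The function and variable names are invented for illustration, and it assumes the standard xla_builder.h API; it builds the same shapes as the updated test further down, with the rhs batch dimension sitting at dimension 1 instead of dimension 0.

#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/shape_util.h"

// lhs f32[2, 11, 3] dot rhs f32[3, 2, 14], where the batch dimension of rhs is
// dimension 1, i.e. the batch dimension numbers are neither equal on both
// operands nor ordered before every other dimension.
xla::XlaOp BatchDotWithTrailingRhsBatchDim(xla::XlaBuilder* builder) {
  xla::XlaOp lhs = xla::Parameter(
      builder, 0, xla::ShapeUtil::MakeShape(xla::F32, {2, 11, 3}), "lhs");
  xla::XlaOp rhs = xla::Parameter(
      builder, 1, xla::ShapeUtil::MakeShape(xla::F32, {3, 2, 14}), "rhs");
  xla::DotDimensionNumbers dnums;
  dnums.add_lhs_batch_dimensions(0);        // batch size 2
  dnums.add_rhs_batch_dimensions(1);        // batch dimension is rhs dim 1
  dnums.add_lhs_contracting_dimensions(2);  // contract lhs dim 2 (size 3)...
  dnums.add_rhs_contracting_dimensions(0);  // ...with rhs dim 0 (size 3)
  return xla::DotGeneral(lhs, rhs, dnums);  // inferred shape: f32[2, 11, 14]
}
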
@@ -580,19 +574,6 @@ Status ValidateDotDimensionNumbers(
         dimension_numbers.DebugString());
   }

-  // Check that batch dimension numbers are ordered before all others, and
-  // that they are monotonically increasing.
-  std::vector<int64> batch_dim_numbers(lhs_batch_dimensions.size());
-  std::iota(batch_dim_numbers.begin(), batch_dim_numbers.end(), 0);
-  if (!std::equal(batch_dim_numbers.begin(), batch_dim_numbers.end(),
-                  lhs_batch_dimensions.begin()) ||
-      !std::equal(batch_dim_numbers.begin(), batch_dim_numbers.end(),
-                  rhs_batch_dimensions.begin())) {
-    return InvalidArgument(
-        "Batch dimension numbers must precede non-batch dimensions and be"
-        "monotonically increasing.");
-  }
-
   return Status::OK();
 }

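
For readers skimming the deleted block: the std::iota/std::equal pair required the batch dimension numbers to be exactly {0, 1, ..., N-1} on both operands. A standalone sketch of that restriction (plain C++ with int64_t rather than XLA's int64; not the XLA code itself):

#include <algorithm>
#include <cstdint>
#include <numeric>
#include <vector>

// Returns true iff batch_dims is exactly {0, 1, ..., batch_dims.size() - 1}.
bool BatchDimsAreLeadingAndMonotonic(const std::vector<int64_t>& batch_dims) {
  std::vector<int64_t> expected(batch_dims.size());
  std::iota(expected.begin(), expected.end(), 0);
  return std::equal(expected.begin(), expected.end(), batch_dims.begin());
}

// With this restriction in place, rhs_batch_dimensions = {1} (as in the
// updated test below) was rejected even though the evaluator handles it.
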
@@ -654,11 +635,9 @@ Status ValidateDotDimensionNumbers(

   // Check that batch dimension numbers and sizes match.
   for (int64 i = 0; i < dimension_numbers.lhs_batch_dimensions_size(); ++i) {
-    if (dimension_numbers.lhs_batch_dimensions(i) !=
-            dimension_numbers.rhs_batch_dimensions(i) ||
-        lhs.dimensions(dimension_numbers.lhs_batch_dimensions(i)) !=
+    if (lhs.dimensions(dimension_numbers.lhs_batch_dimensions(i)) !=
         rhs.dimensions(dimension_numbers.rhs_batch_dimensions(i))) {
-      return fail("Batch dimension numbers and sizes must match for lhs/rhs.");
+      return fail("Batch dimension sizes must match for lhs/rhs.");
     }
   }

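
What still fails after this change is a mismatch in batch dimension sizes. A hypothetical example (the shapes here are invented for illustration; it assumes the xla namespace and the headers the shape inference tests already include):

// lhs batch dimension 0 has size 2, rhs batch dimension 1 has size 5, so
// shape inference is expected to fail.
void ExampleMismatchedBatchSizes() {
  Shape lhs_shape = ShapeUtil::MakeShape(F32, {2, 11, 3});
  Shape rhs_shape = ShapeUtil::MakeShape(F32, {3, 5, 14});
  DotDimensionNumbers dot_dnums;
  dot_dnums.add_lhs_batch_dimensions(0);
  dot_dnums.add_rhs_batch_dimensions(1);
  dot_dnums.add_lhs_contracting_dimensions(2);
  dot_dnums.add_rhs_contracting_dimensions(0);
  auto inferred_status =
      ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums);
  // inferred_status.ok() is false; the message should contain
  // "Batch dimension sizes must match for lhs/rhs."
}
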
@@ -668,19 +647,20 @@ Status ValidateDotDimensionNumbers(
   // Generate the result dimensions in order, rhs dimensions followed by lhs
   // dimensions except the contracted and batch dimensions.
   std::vector<int64> dimensions;
-  std::unordered_set<int64> rhs_batch_dims(
-      dimension_numbers.rhs_batch_dimensions().begin(),
-      dimension_numbers.rhs_batch_dimensions().end());
+  for (int64 lhs_dim : dimension_numbers.lhs_batch_dimensions()) {
+    dimensions.push_back(lhs.dimensions(lhs_dim));
+  }
   for (int64 i = 0; i < lhs.rank(); i++) {
     if (!absl::c_linear_search(dimension_numbers.lhs_contracting_dimensions(),
-                               i)) {
+                               i) &&
+        !absl::c_linear_search(dimension_numbers.lhs_batch_dimensions(), i)) {
       dimensions.push_back(lhs.dimensions(i));
     }
   }
   for (int64 i = 0; i < rhs.rank(); i++) {
     if (!absl::c_linear_search(dimension_numbers.rhs_contracting_dimensions(),
                                i) &&
-        rhs_batch_dims.count(i) == 0) {
+        !absl::c_linear_search(dimension_numbers.rhs_batch_dimensions(), i)) {
       dimensions.push_back(rhs.dimensions(i));
     }
   }
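
The new loop structure assembles the result shape as: batch dimension sizes taken in lhs order, then the remaining (non-contracting, non-batch) lhs dimensions, then the remaining rhs dimensions. Here is a self-contained sketch of that assembly in plain C++ (not the XLA helpers), with the shapes from the updated test worked through in the trailing comment:

#include <algorithm>
#include <cstdint>
#include <vector>

std::vector<int64_t> AssembleDotResultDims(
    const std::vector<int64_t>& lhs_dims, const std::vector<int64_t>& rhs_dims,
    const std::vector<int64_t>& lhs_batch, const std::vector<int64_t>& rhs_batch,
    const std::vector<int64_t>& lhs_contracting,
    const std::vector<int64_t>& rhs_contracting) {
  auto contains = [](const std::vector<int64_t>& v, int64_t x) {
    return std::find(v.begin(), v.end(), x) != v.end();
  };
  std::vector<int64_t> result;
  // Batch dimensions first, in lhs order.
  for (int64_t d : lhs_batch) result.push_back(lhs_dims[d]);
  // Then lhs dimensions that are neither contracted nor batch.
  for (std::size_t i = 0; i < lhs_dims.size(); ++i) {
    if (!contains(lhs_contracting, i) && !contains(lhs_batch, i)) {
      result.push_back(lhs_dims[i]);
    }
  }
  // Then rhs dimensions that are neither contracted nor batch.
  for (std::size_t i = 0; i < rhs_dims.size(); ++i) {
    if (!contains(rhs_contracting, i) && !contains(rhs_batch, i)) {
      result.push_back(rhs_dims[i]);
    }
  }
  return result;
}

// AssembleDotResultDims({2, 11, 3}, {3, 2, 14}, /*lhs_batch=*/{0},
//                       /*rhs_batch=*/{1}, /*lhs_contracting=*/{2},
//                       /*rhs_contracting=*/{0}) returns {2, 11, 14}.
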
@@ -1153,11 +1153,11 @@ TEST_F(ShapeInferenceTest, DotWithMisatchedBatchDimSizesFails) {
       ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums);
   ASSERT_FALSE(inferred_status.ok());
   ASSERT_THAT(inferred_status.status().error_message(),
-              HasSubstr("Batch dimension numbers and sizes must match"));
+              HasSubstr("Batch dimension sizes must match"));
 }

-// BatchMatMul with different batch dimension numbers fails.
-TEST_F(ShapeInferenceTest, DotWithMisatchedBatchDimNumbersFails) {
+// BatchMatMul with different batch dimension numbers passes
+TEST_F(ShapeInferenceTest, DotWithMisatchedBatchDimNumbersPasses) {
   Shape lhs_shape = ShapeUtil::MakeShape(F32, {2, 11, 3});
   Shape rhs_shape = ShapeUtil::MakeShape(F32, {3, 2, 14});

@@ -1170,9 +1170,9 @@ TEST_F(ShapeInferenceTest, DotWithMisatchedBatchDimNumbersFails) {

   auto inferred_status =
       ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums);
-  ASSERT_FALSE(inferred_status.ok());
-  ASSERT_THAT(inferred_status.status().error_message(),
-              HasSubstr("Batch dimension numbers must precede non-batch"));
+  ASSERT_TRUE(inferred_status.ok());
+  ASSERT_TRUE(ShapeUtil::Equal(inferred_status.ValueOrDie(),
+                               ShapeUtil::MakeShape(F32, {2, 11, 14})));
 }

 // BatchMatMul with out-of-range dimension numbers fails.