[XLA] Allow non-monotonic batch dimensions, as they are supported by the evaluator.

PiperOrigin-RevId: 226702189
Blake Hechtman 2018-12-23 14:28:34 -08:00 committed by TensorFlower Gardener
parent 1853b08d9c
commit 6e304fadf8
2 changed files with 17 additions and 37 deletions
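
To make the behavior change concrete, here is a minimal sketch of the kind of dot that is now accepted. It mirrors the updated test at the bottom of this diff and assumes the usual XLA shape-inference headers; the setter names are the generated DotDimensionNumbers proto accessors.

// Sketch only: mirrors the updated DotWithMisatchedBatchDimNumbers test below.
DotDimensionNumbers dot_dnums;
dot_dnums.add_lhs_batch_dimensions(0);        // batch dim is dim 0 of lhs...
dot_dnums.add_rhs_batch_dimensions(1);        // ...but dim 1 of rhs (numbers differ)
dot_dnums.add_lhs_contracting_dimensions(2);
dot_dnums.add_rhs_contracting_dimensions(0);
Shape lhs_shape = ShapeUtil::MakeShape(F32, {2, 11, 3});
Shape rhs_shape = ShapeUtil::MakeShape(F32, {3, 2, 14});
// Before this commit: InvalidArgument("Batch dimension numbers must precede
// non-batch dimensions and be monotonically increasing.").
// After: infers F32[2, 11, 14].
auto inferred = ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums);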


@@ -516,18 +516,12 @@ StatusOr<Shape> InferWindowOutputShape(const Shape& base_shape,
 // Current DotDimensionNumbers Requirements:
 //
 // Contracting Dimensions:
 // *) Same number of contracting dimensions on both lhs and rhs.
 // *) Contracting dimension size must be the same on both lhs and rhs.
 // *) Contracting dimension numbers do not need to be the same (i.e. transposes
 //    are passed on to emitter implementations).
 //
 // Batch Dimensions:
 // *) Same number of batch dimensions on both lhs and rhs.
-// *) Same batch dimension numbers (and sizes) on both lhs and rhs.
-// *) Batch dimension numbers must be ordered before contracting and
-//    non-contracting/non-batch dimension numbers.
-//
-// Non-Contracting-Non-Batch Dimensions:
-// *) Can be 0 (matrix-vector) or 1 (matrix-matrix).
+// *) Same batch dimension sizes on both lhs and rhs.
 //
 namespace {
@@ -580,19 +574,6 @@ Status ValidateDotDimensionNumbers(
         dimension_numbers.DebugString());
   }
-  // Check that batch dimension numbers are ordered before all others, and
-  // that they are monotonically increasing.
-  std::vector<int64> batch_dim_numbers(lhs_batch_dimensions.size());
-  std::iota(batch_dim_numbers.begin(), batch_dim_numbers.end(), 0);
-  if (!std::equal(batch_dim_numbers.begin(), batch_dim_numbers.end(),
-                  lhs_batch_dimensions.begin()) ||
-      !std::equal(batch_dim_numbers.begin(), batch_dim_numbers.end(),
-                  rhs_batch_dimensions.begin())) {
-    return InvalidArgument(
-        "Batch dimension numbers must precede non-batch dimensions and be"
-        "monotonically increasing.");
-  }
   return Status::OK();
 }
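
For reference, the deleted validation accepted only batch dimension numbers equal to 0, 1, ..., n-1 on both operands, i.e. batch dims had to lead the shape and be monotonically increasing. A self-contained sketch of that predicate (hypothetical helper name, not XLA code):

#include <algorithm>
#include <cstdint>
#include <numeric>
#include <vector>

// Mirrors the deleted check: batch dims must be exactly {0, 1, ..., n-1}.
bool BatchDimsLeadingAndMonotonic(const std::vector<int64_t>& batch_dims) {
  std::vector<int64_t> expected(batch_dims.size());
  std::iota(expected.begin(), expected.end(), 0);
  return std::equal(expected.begin(), expected.end(), batch_dims.begin());
}

// BatchDimsLeadingAndMonotonic({0, 1}) -> true   (was accepted)
// BatchDimsLeadingAndMonotonic({1, 0}) -> false  (was rejected; now allowed,
//                                                 since the evaluator handles it)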
@@ -654,11 +635,9 @@ Status ValidateDotDimensionNumbers(
   // Check that batch dimension numbers and sizes match.
   for (int64 i = 0; i < dimension_numbers.lhs_batch_dimensions_size(); ++i) {
-    if (dimension_numbers.lhs_batch_dimensions(i) !=
-            dimension_numbers.rhs_batch_dimensions(i) ||
-        lhs.dimensions(dimension_numbers.lhs_batch_dimensions(i)) !=
+    if (lhs.dimensions(dimension_numbers.lhs_batch_dimensions(i)) !=
         rhs.dimensions(dimension_numbers.rhs_batch_dimensions(i))) {
-      return fail("Batch dimension numbers and sizes must match for lhs/rhs.");
+      return fail("Batch dimension sizes must match for lhs/rhs.");
     }
   }
@@ -668,19 +647,20 @@ Status ValidateDotDimensionNumbers(
   // Generate the result dimensions in order, lhs dimensions followed by rhs
   // dimensions except the contracted and batch dimensions.
   std::vector<int64> dimensions;
-  std::unordered_set<int64> rhs_batch_dims(
-      dimension_numbers.rhs_batch_dimensions().begin(),
-      dimension_numbers.rhs_batch_dimensions().end());
+  for (int64 lhs_dim : dimension_numbers.lhs_batch_dimensions()) {
+    dimensions.push_back(lhs.dimensions(lhs_dim));
+  }
   for (int64 i = 0; i < lhs.rank(); i++) {
     if (!absl::c_linear_search(dimension_numbers.lhs_contracting_dimensions(),
-                               i)) {
+                               i) &&
+        !absl::c_linear_search(dimension_numbers.lhs_batch_dimensions(), i)) {
       dimensions.push_back(lhs.dimensions(i));
     }
   }
   for (int64 i = 0; i < rhs.rank(); i++) {
     if (!absl::c_linear_search(dimension_numbers.rhs_contracting_dimensions(),
                                i) &&
-        rhs_batch_dims.count(i) == 0) {
+        !absl::c_linear_search(dimension_numbers.rhs_batch_dimensions(), i)) {
       dimensions.push_back(rhs.dimensions(i));
     }
   }
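
Read off the loops above, the result ordering after this change is: lhs batch dimensions first, in the order they appear in the dimension numbers, then lhs non-contracting/non-batch dimensions, then rhs non-contracting/non-batch dimensions. A standalone sketch of that logic, with hypothetical names:

#include <algorithm>
#include <cstdint>
#include <vector>

// Hypothetical standalone version of the result-dimension loop above.
std::vector<int64_t> DotResultDims(const std::vector<int64_t>& lhs,
                                   const std::vector<int64_t>& rhs,
                                   const std::vector<int64_t>& lhs_batch,
                                   const std::vector<int64_t>& rhs_batch,
                                   const std::vector<int64_t>& lhs_contracting,
                                   const std::vector<int64_t>& rhs_contracting) {
  auto contains = [](const std::vector<int64_t>& v, int64_t x) {
    return std::find(v.begin(), v.end(), x) != v.end();
  };
  std::vector<int64_t> result;
  for (int64_t d : lhs_batch) result.push_back(lhs[d]);  // batch dims, lhs order
  for (int64_t i = 0; i < static_cast<int64_t>(lhs.size()); ++i) {
    if (!contains(lhs_contracting, i) && !contains(lhs_batch, i)) {
      result.push_back(lhs[i]);  // lhs free dims
    }
  }
  for (int64_t i = 0; i < static_cast<int64_t>(rhs.size()); ++i) {
    if (!contains(rhs_contracting, i) && !contains(rhs_batch, i)) {
      result.push_back(rhs[i]);  // rhs free dims
    }
  }
  return result;
}

// DotResultDims({2, 11, 3}, {3, 2, 14}, {0}, {1}, {2}, {0}) == {2, 11, 14},
// matching the updated test below.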


@@ -1153,11 +1153,11 @@ TEST_F(ShapeInferenceTest, DotWithMisatchedBatchDimSizesFails) {
       ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums);
   ASSERT_FALSE(inferred_status.ok());
   ASSERT_THAT(inferred_status.status().error_message(),
-              HasSubstr("Batch dimension numbers and sizes must match"));
+              HasSubstr("Batch dimension sizes must match"));
 }

-// BatchMatMul with different batch dimension numbers fails.
-TEST_F(ShapeInferenceTest, DotWithMisatchedBatchDimNumbersFails) {
+// BatchMatMul with different batch dimension numbers passes.
+TEST_F(ShapeInferenceTest, DotWithMisatchedBatchDimNumbersPasses) {
   Shape lhs_shape = ShapeUtil::MakeShape(F32, {2, 11, 3});
   Shape rhs_shape = ShapeUtil::MakeShape(F32, {3, 2, 14});
@@ -1170,9 +1170,9 @@ TEST_F(ShapeInferenceTest, DotWithMisatchedBatchDimNumbersFails) {
   auto inferred_status =
       ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums);
-  ASSERT_FALSE(inferred_status.ok());
-  ASSERT_THAT(inferred_status.status().error_message(),
-              HasSubstr("Batch dimension numbers must precede non-batch"));
+  ASSERT_TRUE(inferred_status.ok());
+  ASSERT_TRUE(ShapeUtil::Equal(inferred_status.ValueOrDie(),
+                               ShapeUtil::MakeShape(F32, {2, 11, 14})));
 }

 // BatchMatMul with out-of-range dimension numbers fails.
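
Working through the updated test: lhs is F32[2, 11, 3] with batch dim 0 (size 2) and contracting dim 2 (size 3); rhs is F32[3, 2, 14] with batch dim 1 (size 2) and contracting dim 0 (size 3). The batch sizes (2) and contracting sizes (3) match even though the batch dimension numbers differ, so validation now passes, and under the ordering implemented above the result is the lhs batch dim, then the lhs free dim, then the rhs free dim: F32[2, 11, 14].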