Support 0-sized shape with dynamic reshapes.

- Return finer granule control in CommonFactors when input and output are the same and one of dimensions is zero.
- Skip instruction with dynamic size in ZeroSizedHloElimination.

PiperOrigin-RevId: 352189050
Change-Id: Ieda2d8c89d20e66670b3cd85484f4e1e97bc4c93
This commit is contained in:
Yunxing Dai 2021-01-16 10:15:40 -08:00 committed by TensorFlower Gardener
parent a7c1a23ebb
commit 3b94e9cfdb
4 changed files with 15 additions and 4 deletions

View File

@ -36,7 +36,8 @@ StatusOr<bool> ZeroSizedHloElimination::Run(HloModule* module) {
continue;
}
if (comp->IsSafelyRemovable(instruction) &&
ShapeUtil::IsZeroElementArray(instruction->shape())) {
ShapeUtil::IsZeroElementArray(instruction->shape()) &&
instruction->shape().is_static()) {
// If the instruction doesn't have a layout, use a default layout for
// the literal.
Shape shape = instruction->shape();

View File

@ -265,11 +265,18 @@ int64 Product(absl::Span<const int64> xs) {
absl::InlinedVector<std::pair<int64, int64>, 8> CommonFactors(
absl::Span<const int64> a, absl::Span<const int64> b) {
CHECK_EQ(Product(a), Product(b));
absl::InlinedVector<std::pair<int64, int64>, 8> bounds;
if (absl::c_equal(a, b)) {
bounds.reserve(a.size() + 1);
for (int64 i = 0; i <= a.size(); ++i) {
bounds.emplace_back(i, i);
}
return bounds;
}
if (0 == Product(a)) {
return {std::make_pair(0, 0), std::make_pair(a.size(), b.size())};
}
absl::InlinedVector<std::pair<int64, int64>, 8> bounds;
for (int64 i = 0, j = 0, prior_i = -1, prior_j = -1, partial_size_a = 1,
partial_size_b = 1;
;) {

View File

@ -502,8 +502,10 @@ int64 Product(absl::Span<const int64> xs);
// b[j_k] × b[j_k + 1] × ... × b[j_(k+1) - 1]
// where `CommonFactors(a, b)[CommonFactors(a, b).size - 1] = (a.size, b.size)`
//
// If the given shapes have non-zero size, returns the bounds of the shortest
// possible such subsequences; else, returns `{(0, 0), (a.size, b.size)}`.
// If input and output are the same, return {(0, 0), {1, 1}, ... {a.size,
// b.size}}, otherwise if the given shapes have non-zero size, returns the
// bounds of the shortest possible such subsequences; else, returns `{(0, 0),
// (a.size, b.size)}`.
absl::InlinedVector<std::pair<int64, int64>, 8> CommonFactors(
absl::Span<const int64> a, absl::Span<const int64> b);

View File

@ -74,6 +74,7 @@ TEST(UtilTest, CommonFactors) {
absl::InlinedVector<std::pair<int64, int64>, 8> expected;
} test_cases[] = {
{/*.a =*/{0}, /*.b =*/{0}, /*.expected =*/{{0, 0}, {1, 1}}},
{/*.a =*/{0, 1}, /*.b =*/{0, 1}, /*.expected =*/{{0, 0}, {1, 1}, {2, 2}}},
{/*.a =*/{}, /*.b =*/{}, /*.expected =*/{{0, 0}}},
{/*.a =*/{2, 5, 1, 3},
/*.b =*/{1, 10, 3, 1},