SPMD sharding for convolutions with a dilated kernel

PiperOrigin-RevId: 318887130
Change-Id: Ied7c5ea0dde042fef675c52431b0e047288f0e90
This commit is contained in:
A. Unique TensorFlower 2020-06-29 13:54:09 -07:00 committed by TensorFlower Gardener
parent 92d15f9719
commit 1b53b995da

View File

@ -418,21 +418,17 @@ PartitionedHlo::ReshardAsWindowedInput(const Window& window,
continue; continue;
} }
const auto& wd = window.dimensions(i); const auto& wd = window.dimensions(i);
if (wd.window_dilation() != 1) { const auto dilated_size = 1 + (wd.size() - 1) * wd.window_dilation();
// TODO(yuanzx): Support window dilation.
VLOG(2) << "Failed to reshard window operand due to window dilation";
return absl::nullopt;
}
int64 full_size = int64 full_size =
base_shape_.dimensions(i) + base_shape_.dimensions(i) +
(wd.base_dilation() - 1) * (base_shape_.dimensions(i) - 1) + (wd.base_dilation() - 1) * (base_shape_.dimensions(i) - 1) +
wd.padding_high() + wd.padding_low(); wd.padding_high() + wd.padding_low();
if (full_size < wd.size()) { if (full_size < dilated_size) {
VLOG(2) << "Failed to reshard window operand because the window size is " VLOG(2) << "Failed to reshard window operand because the window size is "
"larger than padded base size"; "larger than padded base size";
return absl::nullopt; return absl::nullopt;
} }
int64 window_count = (full_size - wd.size()) / wd.stride() + 1; int64 window_count = (full_size - dilated_size) / wd.stride() + 1;
per_shard_window_counts[i] = CeilOfRatio(window_count, shard_count); per_shard_window_counts[i] = CeilOfRatio(window_count, shard_count);
if (wd.stride() != 1 && if (wd.stride() != 1 &&
(wd.stride() * per_shard_window_counts[i]) % wd.base_dilation() != 0) { (wd.stride() * per_shard_window_counts[i]) % wd.base_dilation() != 0) {
@ -457,7 +453,7 @@ PartitionedHlo::ReshardAsWindowedInput(const Window& window,
wd.stride() * per_shard_window_counts[i], wd.stride() * per_shard_window_counts[i],
wd.base_dilation() - 1 - swd->padding_low(), wd.base_dilation()); wd.base_dilation() - 1 - swd->padding_low(), wd.base_dilation());
int64 dilated_shard_size = int64 dilated_shard_size =
wd.stride() * (per_shard_window_counts[i] - 1) + wd.size(); wd.stride() * (per_shard_window_counts[i] - 1) + dilated_size;
limit_on_padded_calculations[i] = MultiplyAddDivideOffsetCalculation( limit_on_padded_calculations[i] = MultiplyAddDivideOffsetCalculation(
wd.stride() * per_shard_window_counts[i], wd.stride() * per_shard_window_counts[i],
dilated_shard_size + wd.base_dilation() - 1 - swd->padding_low(), dilated_shard_size + wd.base_dilation() - 1 - swd->padding_low(),
@ -493,7 +489,7 @@ PartitionedHlo::ReshardAsWindowedInput(const Window& window,
for (int64 shard_ordinal = 0; shard_ordinal < shard_count; for (int64 shard_ordinal = 0; shard_ordinal < shard_count;
++shard_ordinal) { ++shard_ordinal) {
int64 wanted_limit_on_dilated_shard = int64 wanted_limit_on_dilated_shard =
wd.stride() * (per_shard_window_counts[i] - 1) + wd.size(); wd.stride() * (per_shard_window_counts[i] - 1) + dilated_size;
int64 actual_limit_on_dilated_shard_without_pad_high = int64 actual_limit_on_dilated_shard_without_pad_high =
get_first_valid_element_offset_on_dilated_shard(shard_ordinal) + get_first_valid_element_offset_on_dilated_shard(shard_ordinal) +
(max_shard_size - 1) * wd.base_dilation() + 1; (max_shard_size - 1) * wd.base_dilation() + 1;