SPMD sharding for convolutions with a dilated kernel
PiperOrigin-RevId: 318887130 Change-Id: Ied7c5ea0dde042fef675c52431b0e047288f0e90
This commit is contained in:
parent
92d15f9719
commit
1b53b995da
@ -418,21 +418,17 @@ PartitionedHlo::ReshardAsWindowedInput(const Window& window,
|
||||
continue;
|
||||
}
|
||||
const auto& wd = window.dimensions(i);
|
||||
if (wd.window_dilation() != 1) {
|
||||
// TODO(yuanzx): Support window dilation.
|
||||
VLOG(2) << "Failed to reshard window operand due to window dilation";
|
||||
return absl::nullopt;
|
||||
}
|
||||
const auto dilated_size = 1 + (wd.size() - 1) * wd.window_dilation();
|
||||
int64 full_size =
|
||||
base_shape_.dimensions(i) +
|
||||
(wd.base_dilation() - 1) * (base_shape_.dimensions(i) - 1) +
|
||||
wd.padding_high() + wd.padding_low();
|
||||
if (full_size < wd.size()) {
|
||||
if (full_size < dilated_size) {
|
||||
VLOG(2) << "Failed to reshard window operand because the window size is "
|
||||
"larger than padded base size";
|
||||
return absl::nullopt;
|
||||
}
|
||||
int64 window_count = (full_size - wd.size()) / wd.stride() + 1;
|
||||
int64 window_count = (full_size - dilated_size) / wd.stride() + 1;
|
||||
per_shard_window_counts[i] = CeilOfRatio(window_count, shard_count);
|
||||
if (wd.stride() != 1 &&
|
||||
(wd.stride() * per_shard_window_counts[i]) % wd.base_dilation() != 0) {
|
||||
@ -457,7 +453,7 @@ PartitionedHlo::ReshardAsWindowedInput(const Window& window,
|
||||
wd.stride() * per_shard_window_counts[i],
|
||||
wd.base_dilation() - 1 - swd->padding_low(), wd.base_dilation());
|
||||
int64 dilated_shard_size =
|
||||
wd.stride() * (per_shard_window_counts[i] - 1) + wd.size();
|
||||
wd.stride() * (per_shard_window_counts[i] - 1) + dilated_size;
|
||||
limit_on_padded_calculations[i] = MultiplyAddDivideOffsetCalculation(
|
||||
wd.stride() * per_shard_window_counts[i],
|
||||
dilated_shard_size + wd.base_dilation() - 1 - swd->padding_low(),
|
||||
@ -493,7 +489,7 @@ PartitionedHlo::ReshardAsWindowedInput(const Window& window,
|
||||
for (int64 shard_ordinal = 0; shard_ordinal < shard_count;
|
||||
++shard_ordinal) {
|
||||
int64 wanted_limit_on_dilated_shard =
|
||||
wd.stride() * (per_shard_window_counts[i] - 1) + wd.size();
|
||||
wd.stride() * (per_shard_window_counts[i] - 1) + dilated_size;
|
||||
int64 actual_limit_on_dilated_shard_without_pad_high =
|
||||
get_first_valid_element_offset_on_dilated_shard(shard_ordinal) +
|
||||
(max_shard_size - 1) * wd.base_dilation() + 1;
|
||||
|
Loading…
x
Reference in New Issue
Block a user