SPMD sharding for convolutions with a dilated kernel
PiperOrigin-RevId: 318887130 Change-Id: Ied7c5ea0dde042fef675c52431b0e047288f0e90
This commit is contained in:
parent
92d15f9719
commit
1b53b995da
@ -418,21 +418,17 @@ PartitionedHlo::ReshardAsWindowedInput(const Window& window,
|
|||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
const auto& wd = window.dimensions(i);
|
const auto& wd = window.dimensions(i);
|
||||||
if (wd.window_dilation() != 1) {
|
const auto dilated_size = 1 + (wd.size() - 1) * wd.window_dilation();
|
||||||
// TODO(yuanzx): Support window dilation.
|
|
||||||
VLOG(2) << "Failed to reshard window operand due to window dilation";
|
|
||||||
return absl::nullopt;
|
|
||||||
}
|
|
||||||
int64 full_size =
|
int64 full_size =
|
||||||
base_shape_.dimensions(i) +
|
base_shape_.dimensions(i) +
|
||||||
(wd.base_dilation() - 1) * (base_shape_.dimensions(i) - 1) +
|
(wd.base_dilation() - 1) * (base_shape_.dimensions(i) - 1) +
|
||||||
wd.padding_high() + wd.padding_low();
|
wd.padding_high() + wd.padding_low();
|
||||||
if (full_size < wd.size()) {
|
if (full_size < dilated_size) {
|
||||||
VLOG(2) << "Failed to reshard window operand because the window size is "
|
VLOG(2) << "Failed to reshard window operand because the window size is "
|
||||||
"larger than padded base size";
|
"larger than padded base size";
|
||||||
return absl::nullopt;
|
return absl::nullopt;
|
||||||
}
|
}
|
||||||
int64 window_count = (full_size - wd.size()) / wd.stride() + 1;
|
int64 window_count = (full_size - dilated_size) / wd.stride() + 1;
|
||||||
per_shard_window_counts[i] = CeilOfRatio(window_count, shard_count);
|
per_shard_window_counts[i] = CeilOfRatio(window_count, shard_count);
|
||||||
if (wd.stride() != 1 &&
|
if (wd.stride() != 1 &&
|
||||||
(wd.stride() * per_shard_window_counts[i]) % wd.base_dilation() != 0) {
|
(wd.stride() * per_shard_window_counts[i]) % wd.base_dilation() != 0) {
|
||||||
@ -457,7 +453,7 @@ PartitionedHlo::ReshardAsWindowedInput(const Window& window,
|
|||||||
wd.stride() * per_shard_window_counts[i],
|
wd.stride() * per_shard_window_counts[i],
|
||||||
wd.base_dilation() - 1 - swd->padding_low(), wd.base_dilation());
|
wd.base_dilation() - 1 - swd->padding_low(), wd.base_dilation());
|
||||||
int64 dilated_shard_size =
|
int64 dilated_shard_size =
|
||||||
wd.stride() * (per_shard_window_counts[i] - 1) + wd.size();
|
wd.stride() * (per_shard_window_counts[i] - 1) + dilated_size;
|
||||||
limit_on_padded_calculations[i] = MultiplyAddDivideOffsetCalculation(
|
limit_on_padded_calculations[i] = MultiplyAddDivideOffsetCalculation(
|
||||||
wd.stride() * per_shard_window_counts[i],
|
wd.stride() * per_shard_window_counts[i],
|
||||||
dilated_shard_size + wd.base_dilation() - 1 - swd->padding_low(),
|
dilated_shard_size + wd.base_dilation() - 1 - swd->padding_low(),
|
||||||
@ -493,7 +489,7 @@ PartitionedHlo::ReshardAsWindowedInput(const Window& window,
|
|||||||
for (int64 shard_ordinal = 0; shard_ordinal < shard_count;
|
for (int64 shard_ordinal = 0; shard_ordinal < shard_count;
|
||||||
++shard_ordinal) {
|
++shard_ordinal) {
|
||||||
int64 wanted_limit_on_dilated_shard =
|
int64 wanted_limit_on_dilated_shard =
|
||||||
wd.stride() * (per_shard_window_counts[i] - 1) + wd.size();
|
wd.stride() * (per_shard_window_counts[i] - 1) + dilated_size;
|
||||||
int64 actual_limit_on_dilated_shard_without_pad_high =
|
int64 actual_limit_on_dilated_shard_without_pad_high =
|
||||||
get_first_valid_element_offset_on_dilated_shard(shard_ordinal) +
|
get_first_valid_element_offset_on_dilated_shard(shard_ordinal) +
|
||||||
(max_shard_size - 1) * wd.base_dilation() + 1;
|
(max_shard_size - 1) * wd.base_dilation() + 1;
|
||||||
|
Loading…
x
Reference in New Issue
Block a user