From 1b53b995da446faf4772380f4c40ded498c84954 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 29 Jun 2020 13:54:09 -0700 Subject: [PATCH] SPMD sharding for convolutions with a dilated kernel PiperOrigin-RevId: 318887130 Change-Id: Ied7c5ea0dde042fef675c52431b0e047288f0e90 --- .../compiler/xla/service/spmd/spmd_partitioner.cc | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) diff --git a/tensorflow/compiler/xla/service/spmd/spmd_partitioner.cc b/tensorflow/compiler/xla/service/spmd/spmd_partitioner.cc index 635446a18a1..7e136be54e6 100644 --- a/tensorflow/compiler/xla/service/spmd/spmd_partitioner.cc +++ b/tensorflow/compiler/xla/service/spmd/spmd_partitioner.cc @@ -418,21 +418,17 @@ PartitionedHlo::ReshardAsWindowedInput(const Window& window, continue; } const auto& wd = window.dimensions(i); - if (wd.window_dilation() != 1) { - // TODO(yuanzx): Support window dilation. - VLOG(2) << "Failed to reshard window operand due to window dilation"; - return absl::nullopt; - } + const auto dilated_size = 1 + (wd.size() - 1) * wd.window_dilation(); int64 full_size = base_shape_.dimensions(i) + (wd.base_dilation() - 1) * (base_shape_.dimensions(i) - 1) + wd.padding_high() + wd.padding_low(); - if (full_size < wd.size()) { + if (full_size < dilated_size) { VLOG(2) << "Failed to reshard window operand because the window size is " "larger than padded base size"; return absl::nullopt; } - int64 window_count = (full_size - wd.size()) / wd.stride() + 1; + int64 window_count = (full_size - dilated_size) / wd.stride() + 1; per_shard_window_counts[i] = CeilOfRatio(window_count, shard_count); if (wd.stride() != 1 && (wd.stride() * per_shard_window_counts[i]) % wd.base_dilation() != 0) { @@ -457,7 +453,7 @@ PartitionedHlo::ReshardAsWindowedInput(const Window& window, wd.stride() * per_shard_window_counts[i], wd.base_dilation() - 1 - swd->padding_low(), wd.base_dilation()); int64 dilated_shard_size = - wd.stride() * (per_shard_window_counts[i] - 1) + wd.size(); + wd.stride() * (per_shard_window_counts[i] - 1) + dilated_size; limit_on_padded_calculations[i] = MultiplyAddDivideOffsetCalculation( wd.stride() * per_shard_window_counts[i], dilated_shard_size + wd.base_dilation() - 1 - swd->padding_low(), @@ -493,7 +489,7 @@ PartitionedHlo::ReshardAsWindowedInput(const Window& window, for (int64 shard_ordinal = 0; shard_ordinal < shard_count; ++shard_ordinal) { int64 wanted_limit_on_dilated_shard = - wd.stride() * (per_shard_window_counts[i] - 1) + wd.size(); + wd.stride() * (per_shard_window_counts[i] - 1) + dilated_size; int64 actual_limit_on_dilated_shard_without_pad_high = get_first_valid_element_offset_on_dilated_shard(shard_ordinal) + (max_shard_size - 1) * wd.base_dilation() + 1;