Allow space-to-batch to occur on base dilated convolutions

PiperOrigin-RevId: 343906578
Change-Id: I33dc1940dcfb3c5db19943964ddfa85a75e0553c
This commit is contained in:
A. Unique TensorFlower 2020-11-23 12:34:21 -08:00 committed by TensorFlower Gardener
parent 1dcb38c020
commit 7cfa60337b
2 changed files with 288 additions and 159 deletions

View File

@ -20,6 +20,7 @@ limitations under the License.
#include <map>
#include <memory>
#include <queue>
#include <tuple>
#include <unordered_set>
#include <utility>
@ -64,6 +65,19 @@ class ConvolutionVisitor {
// Top-level function to begin space-to-batch conversion.
Status PerformSpaceToBatchOnConvolution(HloInstruction* convolution);
// Struct containing details about a convolution.
struct ConvDetails {
int64 spatial_dimension_to_split, inherent_low_padding,
inherent_high_padding, stride, spatial_size, base_dilation_factor,
halo_size, high_padding_for_conv, low_padding_for_conv,
kernel_spatial_dim_size, input_dim_size;
};
// Return a struct containing various necessary information pieces for
// performing space-to-batch on a convolution.
ConvDetails GetConvolutionDetails(HloInstruction* convolution,
ConvolutionDimensionNumbers& dim_numbers);
// Function that determines if space-to-batch can be propagated into the
// consumer. Such propagation is only possible when all required operands are
// space-to-batch'ed.
@ -225,11 +239,29 @@ bool ConvolutionVisitor::IsConvSuitableForSpaceToBatch(
return false;
}
// TODO(b/168316428): Support base dilations.
if (convolution->window()
.dimensions(get_chosen_spatial_dim(convolution))
.base_dilation() != 1) {
return false;
const ConvDetails c = GetConvolutionDetails(convolution, dim_numbers);
const int64 low_pad = convolution->window()
.dimensions(get_chosen_spatial_dim(convolution))
.padding_low();
// TODO(b/168316428): Support base dilations more generically.
if (c.base_dilation_factor != 1) {
if (c.stride != 1) {
return false;
}
// For low pad of 0, only support a pointwise kernel.
if (low_pad == 0) {
if (c.kernel_spatial_dim_size != 1) {
return false;
}
} else if (c.kernel_spatial_dim_size != c.base_dilation_factor + 1 ||
low_pad != c.base_dilation_factor - 1) {
// Only support dilations such that base dilation factor and low pad are
// compatible with kernel_spatial_dim_size to be compatible with
// HaloDuplicateWithSlice.
return false;
}
}
int64 activations_batch_dim = dim_numbers.input_batch_dimension();
@ -240,42 +272,17 @@ bool ConvolutionVisitor::IsConvSuitableForSpaceToBatch(
if (old_batch_size > limit_on_batch_size_) {
return false;
}
auto kernel = convolution->mutable_operand(1);
const auto& kernel_shape = kernel->shape();
const int64 kernel_spatial_dim_size =
kernel_shape.dimensions(dim_numbers.kernel_spatial_dimensions(
get_chosen_spatial_dim(convolution)));
auto activations = convolution->mutable_operand(0);
const int64 input_dim_size =
activations->shape().dimensions(dim_numbers.input_spatial_dimensions(
get_chosen_spatial_dim(convolution)));
const int64 inherent_low_padding =
convolution->window()
.dimensions(get_chosen_spatial_dim(convolution))
.padding_low();
const int64 inherent_high_padding =
convolution->window()
.dimensions(get_chosen_spatial_dim(convolution))
.padding_high();
const int64 spatial_size =
input_dim_size + inherent_low_padding + inherent_high_padding;
VLOG(1) << "spatial size " << spatial_size;
const int64 num_splits = kNewBatchSize / old_batch_size;
// We currently only cater to evenly divisible cases.
if (kNewBatchSize % old_batch_size != 0) {
return false;
}
// Splitting will be incorrect in these cases.
if (spatial_size < num_splits ||
input_dim_size / num_splits < kernel_spatial_dim_size) {
VLOG(1) << "spatial size " << c.spatial_size;
const int64 num_splits = kNewBatchSize / old_batch_size;
// If the ratio is not within the 2X range, we can't Halo Pad from the next
// split.
if (c.halo_size > CeilOfRatio(c.spatial_size, num_splits)) {
return false;
}
VLOG(1) << "Legal space-to-batch convolution " << convolution->ToString();
@ -292,8 +299,8 @@ StatusOr<HloInstruction*> ConvolutionVisitor::HaloDuplicateWithSlice(
activations->shape().dimensions(spatial_dimension_to_split);
const int64 batch_size =
activations->shape().dimensions(activations_batch_dim);
CHECK_LT(low_padding, spatial_split_size);
CHECK_LE(std::abs(halo_size - low_padding), spatial_split_size);
VLOG(1) << "In HaloDuplicateWithSlice with activations "
<< activations->ToString() << " batch_size " << batch_size
<< " spatial_split_size " << spatial_split_size << " low_padding "
@ -439,6 +446,7 @@ StatusOr<bool> ConvolutionVisitor::Run() {
// Iterate through all instructions that we could not propagate through, and
// turn their operands from batch-to-space as needed.
for (auto instr : non_propagatable_instrs_) {
VLOG(1) << "Could not eventually propagate through " << instr->ToString();
absl::flat_hash_map<int64, HloInstruction*> operand_map;
for (int64 i = 0; i < instr->operand_count(); ++i) {
if (old_to_new_instrs_.count(instr->mutable_operand(i))) {
@ -480,8 +488,9 @@ bool ConvolutionVisitor::CanPropagate(HloInstruction* consumer,
if (!old_to_new_instrs_.contains(old_producer) &&
!broadcast_or_constant) {
VLOG(1) << "Cannot propagate on elementwise op "
<< consumer->ToString();
VLOG(1) << "Cannot propagate on elementwise op " << consumer->ToString()
<< " because operand " << old_producer->ToString()
<< " isn't ready ";
return false;
} else {
if (broadcast_or_constant) {
@ -496,10 +505,11 @@ bool ConvolutionVisitor::CanPropagate(HloInstruction* consumer,
pivot_operand = old_producer;
VLOG(2) << "Elementwise op: pivot " << old_producer->ToString();
} else {
VLOG(2) << "Elementwise op: checking for shape equivalence "
<< consumer->ToString();
if (instr_to_dim_map_[pivot_operand] !=
instr_to_dim_map_[old_producer]) {
VLOG(2) << "Elementwise op: checking for shape equivalence "
<< consumer->ToString()
<< " failed due to changed batch space ordering ";
return false;
}
auto pivot_new_instr = old_to_new_instrs_[pivot_operand];
@ -509,13 +519,22 @@ bool ConvolutionVisitor::CanPropagate(HloInstruction* consumer,
for (int j = 0; j < pivot_permute_dims.size(); ++j) {
// Ensure the dimension mapping is the same.
if (pivot_permute_dims[j] != permute_dims[j]) {
VLOG(2) << "Elementwise op: checking for shape equivalence "
<< consumer->ToString()
<< " failed due to permuted dimensions ";
return false;
}
// Make sure all other dimensions are of the same size.
if (pivot_new_instr->shape().dimensions(j) !=
new_instr->shape().dimensions(j)) {
return false;
if (!(consumer->IsElementwiseBinary() &&
j == instr_to_dim_map_[pivot_operand].second)) {
VLOG(2) << "Elementwise op: checking for shape equivalence "
<< consumer->ToString()
<< " failed due to changed shape sizes ";
return false;
}
}
}
}
@ -769,6 +788,28 @@ StatusOr<bool> ConvolutionVisitor::Propagate(HloInstruction* consumer,
if (IsTrivialElementwise(consumer)) {
auto dim_map_val = instr_to_dim_map_[producer];
auto new_consumer = computation->AddInstruction(consumer->Clone());
if (consumer->IsElementwiseBinary()) {
for (int64 i = 0; i < 2; ++i) {
if (consumer->operand(i)->opcode() == HloOpcode::kBroadcast) {
break;
}
CHECK(old_to_new_instrs_.contains(consumer->mutable_operand(i)));
if (i == 1) {
// Choose the larger shape to be used as the producer.
if (old_to_new_instrs_[consumer->mutable_operand(0)]
->shape()
.dimensions() >
old_to_new_instrs_[consumer->mutable_operand(1)]
->shape()
.dimensions()) {
producer = consumer->mutable_operand(0);
} else {
producer = consumer->mutable_operand(1);
}
}
}
}
for (int64 i = 0; i < consumer->operand_count(); ++i) {
if (consumer->operand(i)->opcode() == HloOpcode::kBroadcast) {
CHECK(old_to_new_instrs_.contains(producer));
@ -786,8 +827,66 @@ StatusOr<bool> ConvolutionVisitor::Propagate(HloInstruction* consumer,
new_consumer->ReplaceOperandWithDifferentShape(i, new_broadcast));
} else {
CHECK(old_to_new_instrs_.contains(consumer->mutable_operand(i)));
TF_CHECK_OK(new_consumer->ReplaceOperandWithDifferentShape(
i, old_to_new_instrs_[consumer->mutable_operand(i)]));
HloInstruction* operand_to_use = nullptr;
auto result = instr_to_dim_map_[producer];
const int64 old_batch_dim = result.first;
const int64 old_space_dim = result.second;
const int64 old_batch_size =
producer->shape().dimensions(old_batch_dim);
HloInstruction* new_instr =
old_to_new_instrs_[consumer->mutable_operand(i)];
HloInstruction* pivot_new_instr = old_to_new_instrs_[producer];
auto permute_dims = instr_to_dim_permute_map_[new_instr];
const int64 batch_dim = DimLookUp(permute_dims, old_batch_dim);
const int64 space_dim = DimLookUp(permute_dims, old_space_dim);
const int64 batch_size = new_instr->shape().dimensions(batch_dim);
if (new_instr->shape().dimensions(space_dim) !=
pivot_new_instr->shape().dimensions(space_dim)) {
// Because we do not propagate through transposes, the batch should
// always be followed by the split space dimension.
CHECK_EQ(batch_dim + 1, space_dim);
// Reshape to 1D, pad to the producer's size, reshape back to 2D.
std::vector<int64> new_dimensions(
new_instr->shape().dimensions().begin(),
new_instr->shape().dimensions().end());
new_dimensions[space_dim] *= (batch_size / old_batch_size);
new_dimensions[batch_dim] = old_batch_size;
TF_ASSIGN_OR_RETURN(HloInstruction * reshape,
MakeReshapeHlo(new_dimensions, new_instr));
const int64 pivot_space_size =
pivot_new_instr->shape().dimensions(space_dim) * batch_size /
old_batch_size;
CHECK_GT(pivot_space_size, new_dimensions[space_dim]);
PaddingConfig padding_config =
MakeNoPaddingConfig(reshape->shape().dimensions_size());
padding_config.mutable_dimensions(space_dim)->set_edge_padding_high(
pivot_space_size - new_dimensions[space_dim]);
padding_config.mutable_dimensions(space_dim)->set_edge_padding_low(0);
HloInstruction* padding =
computation_->AddInstruction(HloInstruction::CreateConstant(
LiteralUtil::Zero(reshape->shape().element_type())));
TF_ASSIGN_OR_RETURN(HloInstruction * padded_operand,
MakePadHlo(reshape, padding, padding_config));
TF_ASSIGN_OR_RETURN(
operand_to_use,
MakeReshapeHlo(pivot_new_instr->shape().dimensions(),
padded_operand));
} else {
operand_to_use = old_to_new_instrs_[consumer->mutable_operand(i)];
}
TF_CHECK_OK(
new_consumer->ReplaceOperandWithDifferentShape(i, operand_to_use));
}
}
auto old_type = new_consumer->mutable_shape()->element_type();
@ -1329,25 +1428,21 @@ Status ConvolutionVisitor::PropagateOnConv(HloInstruction* convolution) {
original_conv_dims.input_spatial_dimensions(i)));
}
int64 spatial_dimension_to_split =
permuted_conv_dims_numbers.input_spatial_dimensions(
get_chosen_spatial_dim(convolution));
const int64 old_batch_dim = original_conv_dims.input_batch_dimension();
const int64 old_batch_size =
activations_old->shape().dimensions(old_batch_dim);
const int64 input_dim_size = activations_old->shape().dimensions(
permuted_conv_dims_numbers.input_spatial_dimensions(
get_chosen_spatial_dim(convolution)));
ConvDetails c =
GetConvolutionDetails(convolution, permuted_conv_dims_numbers);
VLOG(1) << "Propagating on conv activations_batch_dim "
<< activations_batch_dim << " spatial_dimension_to_split "
<< spatial_dimension_to_split << " old_batch_size " << old_batch_size;
TF_ASSIGN_OR_RETURN(
activations_new,
BringSpaceNextToBatch(activations_new, permuted_conv_dims_numbers,
spatial_dimension_to_split, activations_batch_dim));
<< c.spatial_dimension_to_split << " old_batch_size "
<< old_batch_size;
TF_ASSIGN_OR_RETURN(activations_new,
BringSpaceNextToBatch(
activations_new, permuted_conv_dims_numbers,
c.spatial_dimension_to_split, activations_batch_dim));
auto select_val = computation_->AddInstruction(HloInstruction::CreateConstant(
LiteralUtil::Zero(activations_new->shape().element_type())));
@ -1355,32 +1450,12 @@ Status ConvolutionVisitor::PropagateOnConv(HloInstruction* convolution) {
TF_ASSIGN_OR_RETURN(
activations_new,
SelectValidPortion(activations_new, activations_old, select_val,
activations_batch_dim, spatial_dimension_to_split,
activations_batch_dim, c.spatial_dimension_to_split,
old_batch_dim, old_space_dim));
// Create the new convolution dim numbers.
auto new_dim_numbers = permuted_conv_dims_numbers;
auto kernel = convolution->mutable_operand(1);
const auto& kernel_shape = kernel->shape();
const int64 kernel_spatial_dim_size = kernel_shape.dimensions(
permuted_conv_dims_numbers.kernel_spatial_dimensions(
get_chosen_spatial_dim(convolution)));
const int64 inherent_low_padding =
convolution->window()
.dimensions(get_chosen_spatial_dim(convolution))
.padding_low();
const int64 inherent_high_padding =
convolution->window()
.dimensions(get_chosen_spatial_dim(convolution))
.padding_high();
const int64 stride = convolution->window()
.dimensions(get_chosen_spatial_dim(convolution))
.stride();
const int64 spatial_size =
input_dim_size + inherent_low_padding + inherent_high_padding;
VLOG(1) << "spatial size " << spatial_size;
VLOG(1) << "spatial size " << c.spatial_size;
const int64 num_splits = kNewBatchSize / old_batch_size;
@ -1390,18 +1465,18 @@ Status ConvolutionVisitor::PropagateOnConv(HloInstruction* convolution) {
const int64 output_offsets_per_split =
CeilOfRatio(output_offsets, num_splits);
int64 spatial_split_size = output_offsets_per_split * stride;
const int64 halo_size =
std::max(kernel_spatial_dim_size - stride, static_cast<int64>(0));
int64 spatial_split_size =
CeilOfRatio(output_offsets_per_split, c.base_dilation_factor) * c.stride;
// Keep increasing the split size so that overall size isn't smaller than the
// original spatial dimension. Unlike for the first space-to-batch'ed
// convolution, while propagating, we can use the last halo_size as available
// spatial size.
while (spatial_split_size * num_splits + halo_size - spatial_size < 0) {
spatial_split_size += stride;
while (spatial_split_size * num_splits + c.halo_size - c.spatial_size < 0) {
spatial_split_size += c.stride;
}
int64 slice_size = spatial_split_size + halo_size;
int64 slice_size = spatial_split_size + c.halo_size;
VLOG(1) << "spatial_split_size " << spatial_split_size << " slice_size "
<< slice_size;
@ -1409,7 +1484,7 @@ Status ConvolutionVisitor::PropagateOnConv(HloInstruction* convolution) {
const int64 new_batch_size =
activations_new->shape().dimensions(activations_batch_dim);
const int64 new_space_size =
activations_new->shape().dimensions(spatial_dimension_to_split);
activations_new->shape().dimensions(c.spatial_dimension_to_split);
// In the below case, we cannot use the activations directly for Halo
// Duplication. We must reshape them.
if (spatial_split_size > new_space_size) {
@ -1418,7 +1493,7 @@ Status ConvolutionVisitor::PropagateOnConv(HloInstruction* convolution) {
activations_new->shape().dimensions().end());
const int64 reshaped_space_size =
new_space_size * new_batch_size / old_batch_size;
new_dimensions[spatial_dimension_to_split] = reshaped_space_size;
new_dimensions[c.spatial_dimension_to_split] = reshaped_space_size;
new_dimensions[activations_batch_dim] = old_batch_size;
// Reshape the output of the new conv into the old convolutions shape.
@ -1427,10 +1502,10 @@ Status ConvolutionVisitor::PropagateOnConv(HloInstruction* convolution) {
PaddingConfig padding_config =
MakeNoPaddingConfig(reshaped_activations->shape().dimensions_size());
padding_config.mutable_dimensions(spatial_dimension_to_split)
padding_config.mutable_dimensions(c.spatial_dimension_to_split)
->set_edge_padding_high(spatial_split_size * new_batch_size -
reshaped_space_size);
padding_config.mutable_dimensions(spatial_dimension_to_split)
padding_config.mutable_dimensions(c.spatial_dimension_to_split)
->set_edge_padding_low(0);
HloInstruction* padding =
computation_->AddInstruction(HloInstruction::CreateConstant(
@ -1444,7 +1519,7 @@ Status ConvolutionVisitor::PropagateOnConv(HloInstruction* convolution) {
reshaped_activations->shape().dimensions().begin(),
reshaped_activations->shape().dimensions().end());
reshape_back_dims[spatial_dimension_to_split] = spatial_split_size;
reshape_back_dims[c.spatial_dimension_to_split] = spatial_split_size;
reshape_back_dims[activations_batch_dim] = new_batch_size;
TF_ASSIGN_OR_RETURN(
@ -1453,34 +1528,38 @@ Status ConvolutionVisitor::PropagateOnConv(HloInstruction* convolution) {
TF_ASSIGN_OR_RETURN(
activations_new,
HaloDuplicateWithSlice(reshaped_activations, spatial_dimension_to_split,
activations_batch_dim, old_batch_size,
/*low_padding=*/inherent_low_padding,
/*high_padding=*/inherent_high_padding,
slice_size - spatial_split_size,
old_split_dim_size));
HaloDuplicateWithSlice(
reshaped_activations, c.spatial_dimension_to_split,
activations_batch_dim, old_batch_size,
/*low_padding=*/c.base_dilation_factor != 1 &&
c.inherent_low_padding != 0
? c.base_dilation_factor - 1
: c.inherent_low_padding,
c.inherent_high_padding, slice_size - spatial_split_size,
old_split_dim_size));
} else {
// If the ideal spatial_split_size was smaller than the incoming spatial
// dimension size, we don't need reshaping. Instead, we determine the
// additional space available, and adjust the required slice size (and
// thereby the halo size).'t need reshaping. Instead, we determine the
// additional space available, and adjust the required slice size (and
// thereby the halo size).
if (spatial_split_size < new_space_size) {
const int64 additional_space_present = spatial_split_size % stride;
const int64 additional_space_present = spatial_split_size % c.stride;
spatial_split_size = new_space_size;
slice_size =
spatial_split_size +
std::max(kernel_spatial_dim_size - stride - additional_space_present,
static_cast<int64>(0));
spatial_split_size + std::max(c.kernel_spatial_dim_size - c.stride -
additional_space_present,
static_cast<int64>(0));
}
TF_ASSIGN_OR_RETURN(
activations_new,
HaloDuplicateWithSlice(activations_new, spatial_dimension_to_split,
HaloDuplicateWithSlice(activations_new, c.spatial_dimension_to_split,
activations_batch_dim, old_batch_size,
/*low_padding=*/inherent_low_padding,
/*high_padding=*/inherent_high_padding,
/*low_padding=*/c.base_dilation_factor != 1 &&
c.inherent_low_padding != 0
? c.base_dilation_factor - 1
: c.inherent_low_padding,
c.inherent_high_padding,
slice_size - spatial_split_size,
old_split_dim_size));
}
@ -1515,9 +1594,9 @@ Status ConvolutionVisitor::PropagateOnConv(HloInstruction* convolution) {
auto new_window = convolution->window();
new_window.mutable_dimensions(get_chosen_spatial_dim(convolution))
->set_padding_high(0);
->set_padding_high(c.high_padding_for_conv);
new_window.mutable_dimensions(get_chosen_spatial_dim(convolution))
->set_padding_low(0);
->set_padding_low(c.low_padding_for_conv);
TF_ASSIGN_OR_RETURN(
HloInstruction * new_conv,
MakeConvolveHlo(
@ -1855,19 +1934,9 @@ HloInstruction* ConvolutionVisitor::DoesConvolutionFeedReduceWindow(
return nullptr;
}
Status ConvolutionVisitor::PerformSpaceToBatchOnConvolution(
HloInstruction* convolution) {
VLOG(1) << "Handling conv " << convolution->ToString();
changed_ = false;
ConvolutionDimensionNumbers dim_numbers =
convolution->convolution_dimension_numbers();
int64 activations_batch_dim = dim_numbers.input_batch_dimension();
const int64 old_batch_size =
convolution->operand(0)->shape().dimensions(activations_batch_dim);
ConvolutionVisitor::ConvDetails ConvolutionVisitor::GetConvolutionDetails(
HloInstruction* convolution, ConvolutionDimensionNumbers& dim_numbers) {
auto activations = convolution->mutable_operand(0);
auto kernel = convolution->mutable_operand(1);
const auto& kernel_shape = kernel->shape();
@ -1875,14 +1944,11 @@ Status ConvolutionVisitor::PerformSpaceToBatchOnConvolution(
kernel_shape.dimensions(dim_numbers.kernel_spatial_dimensions(
get_chosen_spatial_dim(convolution)));
auto activations = convolution->mutable_operand(0);
int64 spatial_dimension_to_split =
const int64 spatial_dimension_to_split =
dim_numbers.input_spatial_dimensions(get_chosen_spatial_dim(convolution));
const int64 input_dim_size =
activations->shape().dimensions(dim_numbers.input_spatial_dimensions(
get_chosen_spatial_dim(convolution)));
activations->shape().dimensions(spatial_dimension_to_split);
const int64 inherent_low_padding =
convolution->window()
@ -1892,26 +1958,75 @@ Status ConvolutionVisitor::PerformSpaceToBatchOnConvolution(
convolution->window()
.dimensions(get_chosen_spatial_dim(convolution))
.padding_high();
const bool inherent_padding_needed =
inherent_low_padding != 0 || inherent_high_padding != 0;
const int64 stride = convolution->window()
.dimensions(get_chosen_spatial_dim(convolution))
.stride();
const int64 base_dilation_factor =
convolution->window()
.dimensions(get_chosen_spatial_dim(convolution))
.base_dilation();
const int64 spatial_size =
input_dim_size + inherent_low_padding + inherent_high_padding;
VLOG(1) << "spatial size " << spatial_size;
input_dim_size + (base_dilation_factor > 1 ? 0 : inherent_low_padding) +
inherent_high_padding;
const int64 halo_size =
std::max(kernel_spatial_dim_size - stride - (base_dilation_factor - 1),
static_cast<int64>(0));
const int64 high_padding_for_conv = base_dilation_factor == 1 ? 0
: inherent_low_padding == 0
? base_dilation_factor - 1
: 0;
const int64 low_padding_for_conv =
base_dilation_factor == 1 ? 0 : inherent_low_padding;
return ConvDetails{spatial_dimension_to_split,
inherent_low_padding,
inherent_high_padding,
stride,
spatial_size,
base_dilation_factor,
halo_size,
high_padding_for_conv,
low_padding_for_conv,
kernel_spatial_dim_size,
input_dim_size};
}
Status ConvolutionVisitor::PerformSpaceToBatchOnConvolution(
HloInstruction* convolution) {
VLOG(1) << "Handling conv " << convolution->ToString();
changed_ = false;
ConvolutionDimensionNumbers dim_numbers =
convolution->convolution_dimension_numbers();
ConvDetails c = GetConvolutionDetails(convolution, dim_numbers);
int64 activations_batch_dim = dim_numbers.input_batch_dimension();
const int64 old_batch_size =
convolution->operand(0)->shape().dimensions(activations_batch_dim);
auto activations = convolution->mutable_operand(0);
const bool inherent_padding_needed =
c.inherent_low_padding != 0 || c.inherent_high_padding != 0;
VLOG(1) << "spatial size " << c.spatial_size;
const int64 num_splits = kNewBatchSize / old_batch_size;
auto original_conv = convolution;
// We'd need transposition of activations here such that batch and space dim
// that is being split are adjacent (in that order).
TF_ASSIGN_OR_RETURN(
activations,
BringSpaceNextToBatch(activations, dim_numbers,
spatial_dimension_to_split, activations_batch_dim));
TF_ASSIGN_OR_RETURN(activations,
BringSpaceNextToBatch(activations, dim_numbers,
c.spatial_dimension_to_split,
activations_batch_dim));
// Create the new convolution dim numbers.
auto new_dim_numbers = dim_numbers;
@ -1922,11 +2037,12 @@ Status ConvolutionVisitor::PerformSpaceToBatchOnConvolution(
const int64 output_offsets_per_split =
CeilOfRatio(output_offsets, num_splits);
int64 spatial_split_size = output_offsets_per_split * stride;
int64 spatial_split_size =
CeilOfRatio(output_offsets_per_split, c.base_dilation_factor) * c.stride;
// Keep increasing the split size so that overall size isn't smaller than the
// original spatial dimension.
while (spatial_split_size * num_splits - spatial_size < 0) {
spatial_split_size += stride;
while (spatial_split_size * num_splits - c.spatial_size < 0) {
spatial_split_size += c.stride;
}
auto reduce_window = DoesConvolutionFeedReduceWindow(convolution);
@ -1938,33 +2054,32 @@ Status ConvolutionVisitor::PerformSpaceToBatchOnConvolution(
// windows.
const int64 red_win_stride =
reduce_window->window().dimensions(output_spatial_dim).stride();
while ((spatial_split_size / stride) % red_win_stride != 0) {
spatial_split_size += stride;
while ((spatial_split_size / c.stride) % red_win_stride != 0) {
spatial_split_size += c.stride;
}
}
const int64 slice_size =
spatial_split_size +
std::max(kernel_spatial_dim_size - stride, static_cast<int64>(0));
const int64 slice_size = spatial_split_size + c.halo_size;
// Pad spatial dim.
const int64 pad_size = spatial_split_size * num_splits - spatial_size;
const int64 pad_size = spatial_split_size * num_splits - c.spatial_size;
VLOG(1) << "spatial_split_size " << spatial_split_size << " stride "
<< stride;
VLOG(1) << "spatial_dimension_to_split " << spatial_dimension_to_split
<< c.stride << " slice_size " << slice_size;
VLOG(1) << "spatial_dimension_to_split " << c.spatial_dimension_to_split
<< " num_splits " << num_splits << " kernel_spatial_dim_size "
<< kernel_spatial_dim_size;
<< c.kernel_spatial_dim_size;
// Because we are splitting the spatial dimension, if convolution needed
// padding in the spatial dimension, we materialize it.
if (pad_size != 0 || inherent_padding_needed) {
PaddingConfig padding_config =
MakeNoPaddingConfig(activations->shape().dimensions_size());
padding_config.mutable_dimensions(spatial_dimension_to_split)
->set_edge_padding_high(inherent_high_padding + pad_size);
padding_config.mutable_dimensions(spatial_dimension_to_split)
->set_edge_padding_low(inherent_low_padding);
padding_config.mutable_dimensions(c.spatial_dimension_to_split)
->set_edge_padding_high(c.inherent_high_padding + pad_size);
padding_config.mutable_dimensions(c.spatial_dimension_to_split)
->set_edge_padding_low(
c.base_dilation_factor == 1 ? c.inherent_low_padding : 0);
HloInstruction* padding =
computation_->AddInstruction(HloInstruction::CreateConstant(
LiteralUtil::Zero(activations->shape().element_type())));
@ -1991,7 +2106,7 @@ Status ConvolutionVisitor::PerformSpaceToBatchOnConvolution(
activations->shape().dimensions().begin(),
activations->shape().dimensions().end());
reshape_dimensions[spatial_dimension_to_split] = spatial_split_size;
reshape_dimensions[c.spatial_dimension_to_split] = spatial_split_size;
reshape_dimensions[activations_batch_dim] = num_splits * old_batch_size;
TF_ASSIGN_OR_RETURN(HloInstruction * batch_increased_reshape,
@ -2000,12 +2115,12 @@ Status ConvolutionVisitor::PerformSpaceToBatchOnConvolution(
VLOG(1) << "First reshape done " << batch_increased_reshape->ToString();
TF_ASSIGN_OR_RETURN(activations,
HaloDuplicateWithSlice(
batch_increased_reshape, spatial_dimension_to_split,
activations_batch_dim, old_batch_size,
/*low_padding=*/0, /*high_padding=*/0,
slice_size - spatial_split_size, input_dim_size));
TF_ASSIGN_OR_RETURN(
activations, HaloDuplicateWithSlice(batch_increased_reshape,
c.spatial_dimension_to_split,
activations_batch_dim, old_batch_size,
/*low_padding=*/0, /*high_padding=*/0,
c.halo_size, c.input_dim_size));
VLOG(1) << "Batch merge done " << activations->ToString();
@ -2040,9 +2155,9 @@ Status ConvolutionVisitor::PerformSpaceToBatchOnConvolution(
<< " batch dim " << new_dim_numbers.input_batch_dimension();
auto new_window = convolution->window();
new_window.mutable_dimensions(get_chosen_spatial_dim(convolution))
->set_padding_high(0);
->set_padding_high(c.high_padding_for_conv);
new_window.mutable_dimensions(get_chosen_spatial_dim(convolution))
->set_padding_low(0);
->set_padding_low(c.low_padding_for_conv);
TF_ASSIGN_OR_RETURN(
HloInstruction * new_conv,
MakeConvolveHlo(

View File

@ -113,7 +113,7 @@ TEST_F(ConvolutionSpaceToBatchConverterTest, Batch1WithStrideAndPad) {
EXPECT_GT(reshape->operand(0)->shape().dimensions(batch_dim), 4);
}
TEST_F(ConvolutionSpaceToBatchConverterTest, Batch1WithKernelDilation) {
TEST_F(ConvolutionSpaceToBatchConverterTest, Batch1WithBaseDilation) {
string hlo_string = R"(
HloModule module
@ -129,8 +129,22 @@ ENTRY computation {
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnVerifiedModule(hlo_string));
auto computation = module->entry_computation();
ConvolutionSpaceToBatchConverter converter;
ASSERT_FALSE(converter.Run(module.get()).ValueOrDie());
ASSERT_TRUE(converter.Run(module.get()).ValueOrDie());
HloInstruction* root = computation->root_instruction();
EXPECT_THAT(root, op::Transpose());
EXPECT_THAT(root->operand(0), op::Slice());
auto reshape = root->operand(0)->operand(0);
EXPECT_THAT(reshape, op::Reshape());
EXPECT_THAT(reshape->operand(0)->operand(1), op::Convolution());
const int64 batch_dim = reshape->operand(0)
->operand(1)
->convolution_dimension_numbers()
.output_batch_dimension();
EXPECT_GT(reshape->operand(0)->shape().dimensions(batch_dim), 4);
}
} // namespace