Allow space-to-batch to occur on base dilated convolutions

PiperOrigin-RevId: 343906578
Change-Id: I33dc1940dcfb3c5db19943964ddfa85a75e0553c
A. Unique TensorFlower 2020-11-23 12:34:21 -08:00 committed by TensorFlower Gardener
parent 1dcb38c020
commit 7cfa60337b
2 changed files with 288 additions and 159 deletions


@@ -20,6 +20,7 @@ limitations under the License.
 #include <map>
 #include <memory>
 #include <queue>
+#include <tuple>
 #include <unordered_set>
 #include <utility>
@@ -64,6 +65,19 @@ class ConvolutionVisitor {
   // Top-level function to begin space-to-batch conversion.
   Status PerformSpaceToBatchOnConvolution(HloInstruction* convolution);
 
+  // Struct containing details about a convolution.
+  struct ConvDetails {
+    int64 spatial_dimension_to_split, inherent_low_padding,
+        inherent_high_padding, stride, spatial_size, base_dilation_factor,
+        halo_size, high_padding_for_conv, low_padding_for_conv,
+        kernel_spatial_dim_size, input_dim_size;
+  };
+
+  // Return a struct containing various necessary information pieces for
+  // performing space-to-batch on a convolution.
+  ConvDetails GetConvolutionDetails(HloInstruction* convolution,
+                                    ConvolutionDimensionNumbers& dim_numbers);
+
   // Function that determines if space-to-batch can be propagated into the
   // consumer. Such propagation is only possible when all required operands are
   // space-to-batch'ed.
@@ -225,11 +239,29 @@ bool ConvolutionVisitor::IsConvSuitableForSpaceToBatch(
     return false;
   }
-  // TODO(b/168316428): Support base dilations.
-  if (convolution->window()
-          .dimensions(get_chosen_spatial_dim(convolution))
-          .base_dilation() != 1) {
-    return false;
+  const ConvDetails c = GetConvolutionDetails(convolution, dim_numbers);
+
+  const int64 low_pad = convolution->window()
+                            .dimensions(get_chosen_spatial_dim(convolution))
+                            .padding_low();
+
+  // TODO(b/168316428): Support base dilations more generically.
+  if (c.base_dilation_factor != 1) {
+    if (c.stride != 1) {
+      return false;
+    }
+    // For low pad of 0, only support a pointwise kernel.
+    if (low_pad == 0) {
+      if (c.kernel_spatial_dim_size != 1) {
+        return false;
+      }
+    } else if (c.kernel_spatial_dim_size != c.base_dilation_factor + 1 ||
+               low_pad != c.base_dilation_factor - 1) {
+      // Only support dilations such that base dilation factor and low pad are
+      // compatible with kernel_spatial_dim_size to be compatible with
+      // HaloDuplicateWithSlice.
+      return false;
+    }
   }
 
   int64 activations_batch_dim = dim_numbers.input_batch_dimension();
@@ -240,42 +272,17 @@ bool ConvolutionVisitor::IsConvSuitableForSpaceToBatch(
   if (old_batch_size > limit_on_batch_size_) {
     return false;
   }
 
-  auto kernel = convolution->mutable_operand(1);
-  const auto& kernel_shape = kernel->shape();
-  const int64 kernel_spatial_dim_size =
-      kernel_shape.dimensions(dim_numbers.kernel_spatial_dimensions(
-          get_chosen_spatial_dim(convolution)));
-
-  auto activations = convolution->mutable_operand(0);
-  const int64 input_dim_size =
-      activations->shape().dimensions(dim_numbers.input_spatial_dimensions(
-          get_chosen_spatial_dim(convolution)));
-
-  const int64 inherent_low_padding =
-      convolution->window()
-          .dimensions(get_chosen_spatial_dim(convolution))
-          .padding_low();
-  const int64 inherent_high_padding =
-      convolution->window()
-          .dimensions(get_chosen_spatial_dim(convolution))
-          .padding_high();
-
-  const int64 spatial_size =
-      input_dim_size + inherent_low_padding + inherent_high_padding;
-  VLOG(1) << "spatial size " << spatial_size;
-
-  const int64 num_splits = kNewBatchSize / old_batch_size;
   // We currently only cater to evenly divisible cases.
   if (kNewBatchSize % old_batch_size != 0) {
     return false;
   }
 
-  // Splitting will be incorrect in these cases.
-  if (spatial_size < num_splits ||
-      input_dim_size / num_splits < kernel_spatial_dim_size) {
+  VLOG(1) << "spatial size " << c.spatial_size;
+
+  const int64 num_splits = kNewBatchSize / old_batch_size;
+  // If the ratio is not within the 2X range, we can't Halo Pad from the next
+  // split.
+  if (c.halo_size > CeilOfRatio(c.spatial_size, num_splits)) {
     return false;
   }
   VLOG(1) << "Legal space-to-batch convolution " << convolution->ToString();
@@ -292,8 +299,8 @@ StatusOr<HloInstruction*> ConvolutionVisitor::HaloDuplicateWithSlice(
       activations->shape().dimensions(spatial_dimension_to_split);
   const int64 batch_size =
       activations->shape().dimensions(activations_batch_dim);
-  CHECK_LT(low_padding, spatial_split_size);
+  CHECK_LE(std::abs(halo_size - low_padding), spatial_split_size);
 
   VLOG(1) << "In HaloDuplicateWithSlice with activations "
           << activations->ToString() << " batch_size " << batch_size
           << " spatial_split_size " << spatial_split_size << " low_padding "
@@ -439,6 +446,7 @@ StatusOr<bool> ConvolutionVisitor::Run() {
   // Iterate through all instructions that we could not propagate through, and
   // turn their operands from batch-to-space as needed.
   for (auto instr : non_propagatable_instrs_) {
+    VLOG(1) << "Could not eventually propagate through " << instr->ToString();
     absl::flat_hash_map<int64, HloInstruction*> operand_map;
     for (int64 i = 0; i < instr->operand_count(); ++i) {
       if (old_to_new_instrs_.count(instr->mutable_operand(i))) {
@@ -480,8 +488,9 @@ bool ConvolutionVisitor::CanPropagate(HloInstruction* consumer,
     if (!old_to_new_instrs_.contains(old_producer) &&
         !broadcast_or_constant) {
-      VLOG(1) << "Cannot propagate on elementwise op "
-              << consumer->ToString();
+      VLOG(1) << "Cannot propagate on elementwise op " << consumer->ToString()
+              << " because operand " << old_producer->ToString()
+              << " isn't ready ";
       return false;
     } else {
       if (broadcast_or_constant) {
@@ -496,10 +505,11 @@ bool ConvolutionVisitor::CanPropagate(HloInstruction* consumer,
         pivot_operand = old_producer;
         VLOG(2) << "Elementwise op: pivot " << old_producer->ToString();
       } else {
-        VLOG(2) << "Elementwise op: checking for shape equivalence "
-                << consumer->ToString();
         if (instr_to_dim_map_[pivot_operand] !=
             instr_to_dim_map_[old_producer]) {
+          VLOG(2) << "Elementwise op: checking for shape equivalence "
+                  << consumer->ToString()
+                  << " failed due to changed batch space ordering ";
           return false;
         }
         auto pivot_new_instr = old_to_new_instrs_[pivot_operand];
@@ -509,13 +519,22 @@ bool ConvolutionVisitor::CanPropagate(HloInstruction* consumer,
         for (int j = 0; j < pivot_permute_dims.size(); ++j) {
           // Ensure the dimension mapping is the same.
           if (pivot_permute_dims[j] != permute_dims[j]) {
+            VLOG(2) << "Elementwise op: checking for shape equivalence "
+                    << consumer->ToString()
+                    << " failed due to permuted dimensions ";
            return false;
           }
 
           // Make sure all other dimensions are of the same size.
           if (pivot_new_instr->shape().dimensions(j) !=
               new_instr->shape().dimensions(j)) {
-            return false;
+            if (!(consumer->IsElementwiseBinary() &&
+                  j == instr_to_dim_map_[pivot_operand].second)) {
+              VLOG(2) << "Elementwise op: checking for shape equivalence "
+                      << consumer->ToString()
+                      << " failed due to changed shape sizes ";
+              return false;
+            }
           }
         }
       }
@@ -769,6 +788,28 @@ StatusOr<bool> ConvolutionVisitor::Propagate(HloInstruction* consumer,
   if (IsTrivialElementwise(consumer)) {
     auto dim_map_val = instr_to_dim_map_[producer];
     auto new_consumer = computation->AddInstruction(consumer->Clone());
+
+    if (consumer->IsElementwiseBinary()) {
+      for (int64 i = 0; i < 2; ++i) {
+        if (consumer->operand(i)->opcode() == HloOpcode::kBroadcast) {
+          break;
+        }
+        CHECK(old_to_new_instrs_.contains(consumer->mutable_operand(i)));
+        if (i == 1) {
+          // Choose the larger shape to be used as the producer.
+          if (old_to_new_instrs_[consumer->mutable_operand(0)]
+                  ->shape()
+                  .dimensions() >
+              old_to_new_instrs_[consumer->mutable_operand(1)]
+                  ->shape()
+                  .dimensions()) {
+            producer = consumer->mutable_operand(0);
+          } else {
+            producer = consumer->mutable_operand(1);
+          }
+        }
+      }
+    }
+
     for (int64 i = 0; i < consumer->operand_count(); ++i) {
       if (consumer->operand(i)->opcode() == HloOpcode::kBroadcast) {
         CHECK(old_to_new_instrs_.contains(producer));
@@ -786,8 +827,66 @@ StatusOr<bool> ConvolutionVisitor::Propagate(HloInstruction* consumer,
             new_consumer->ReplaceOperandWithDifferentShape(i, new_broadcast));
       } else {
         CHECK(old_to_new_instrs_.contains(consumer->mutable_operand(i)));
-        TF_CHECK_OK(new_consumer->ReplaceOperandWithDifferentShape(
-            i, old_to_new_instrs_[consumer->mutable_operand(i)]));
+        HloInstruction* operand_to_use = nullptr;
+
+        auto result = instr_to_dim_map_[producer];
+        const int64 old_batch_dim = result.first;
+        const int64 old_space_dim = result.second;
+        const int64 old_batch_size =
+            producer->shape().dimensions(old_batch_dim);
+
+        HloInstruction* new_instr =
+            old_to_new_instrs_[consumer->mutable_operand(i)];
+        HloInstruction* pivot_new_instr = old_to_new_instrs_[producer];
+
+        auto permute_dims = instr_to_dim_permute_map_[new_instr];
+        const int64 batch_dim = DimLookUp(permute_dims, old_batch_dim);
+        const int64 space_dim = DimLookUp(permute_dims, old_space_dim);
+        const int64 batch_size = new_instr->shape().dimensions(batch_dim);
+
+        if (new_instr->shape().dimensions(space_dim) !=
+            pivot_new_instr->shape().dimensions(space_dim)) {
+          // Because we do not propagate through transposes, the batch should
+          // always be followed by the split space dimension.
+          CHECK_EQ(batch_dim + 1, space_dim);
+
+          // Reshape to 1D, pad to the producer's size, reshape back to 2D.
+          std::vector<int64> new_dimensions(
+              new_instr->shape().dimensions().begin(),
+              new_instr->shape().dimensions().end());
+          new_dimensions[space_dim] *= (batch_size / old_batch_size);
+          new_dimensions[batch_dim] = old_batch_size;
+
+          TF_ASSIGN_OR_RETURN(HloInstruction * reshape,
+                              MakeReshapeHlo(new_dimensions, new_instr));
+
+          const int64 pivot_space_size =
+              pivot_new_instr->shape().dimensions(space_dim) * batch_size /
+              old_batch_size;
+          CHECK_GT(pivot_space_size, new_dimensions[space_dim]);
+
+          PaddingConfig padding_config =
+              MakeNoPaddingConfig(reshape->shape().dimensions_size());
+          padding_config.mutable_dimensions(space_dim)->set_edge_padding_high(
+              pivot_space_size - new_dimensions[space_dim]);
+          padding_config.mutable_dimensions(space_dim)->set_edge_padding_low(0);
+          HloInstruction* padding =
+              computation_->AddInstruction(HloInstruction::CreateConstant(
+                  LiteralUtil::Zero(reshape->shape().element_type())));
+          TF_ASSIGN_OR_RETURN(HloInstruction * padded_operand,
+                              MakePadHlo(reshape, padding, padding_config));
+
+          TF_ASSIGN_OR_RETURN(
+              operand_to_use,
+              MakeReshapeHlo(pivot_new_instr->shape().dimensions(),
+                             padded_operand));
+        } else {
+          operand_to_use = old_to_new_instrs_[consumer->mutable_operand(i)];
+        }
+        TF_CHECK_OK(
+            new_consumer->ReplaceOperandWithDifferentShape(i, operand_to_use));
       }
     }
     auto old_type = new_consumer->mutable_shape()->element_type();
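The new else-branch above harmonizes the shapes of space-to-batch'ed operands of an elementwise binary op by flattening the smaller operand, padding its space dimension up to the pivot's size, and reshaping back to the pivot's shape. A minimal arithmetic sketch with made-up dimensions (not taken from the change) of how the padding amount comes out:

#include <cstdint>
#include <iostream>

int main() {
  // Assumed: pivot operand shape [batch=8, space=7], smaller operand
  // [batch=8, space=5], original (old) batch size 1.
  const int64_t old_batch_size = 1, batch_size = 8;
  const int64_t operand_space = 5, pivot_space = 7;
  // Flatten the smaller operand: [8, 5] -> [1, 40].
  const int64_t flat_space = operand_space * (batch_size / old_batch_size);
  // Pad high up to the pivot's flattened space size 7 * 8 / 1 = 56,
  // then reshape back to the pivot's [8, 7] shape.
  const int64_t pivot_flat_space = pivot_space * batch_size / old_batch_size;
  const int64_t pad_high = pivot_flat_space - flat_space;
  std::cout << "flattened " << flat_space << " pad_high " << pad_high << "\n";
  return 0;
}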
@@ -1329,25 +1428,21 @@ Status ConvolutionVisitor::PropagateOnConv(HloInstruction* convolution) {
         original_conv_dims.input_spatial_dimensions(i)));
   }
 
-  int64 spatial_dimension_to_split =
-      permuted_conv_dims_numbers.input_spatial_dimensions(
-          get_chosen_spatial_dim(convolution));
-
   const int64 old_batch_dim = original_conv_dims.input_batch_dimension();
   const int64 old_batch_size =
       activations_old->shape().dimensions(old_batch_dim);
 
-  const int64 input_dim_size = activations_old->shape().dimensions(
-      permuted_conv_dims_numbers.input_spatial_dimensions(
-          get_chosen_spatial_dim(convolution)));
+  ConvDetails c =
+      GetConvolutionDetails(convolution, permuted_conv_dims_numbers);
 
   VLOG(1) << "Propagating on conv activations_batch_dim "
           << activations_batch_dim << " spatial_dimension_to_split "
-          << spatial_dimension_to_split << " old_batch_size " << old_batch_size;
+          << c.spatial_dimension_to_split << " old_batch_size "
+          << old_batch_size;
 
-  TF_ASSIGN_OR_RETURN(
-      activations_new,
-      BringSpaceNextToBatch(activations_new, permuted_conv_dims_numbers,
-                            spatial_dimension_to_split, activations_batch_dim));
+  TF_ASSIGN_OR_RETURN(activations_new,
+                      BringSpaceNextToBatch(
+                          activations_new, permuted_conv_dims_numbers,
+                          c.spatial_dimension_to_split, activations_batch_dim));
 
   auto select_val = computation_->AddInstruction(HloInstruction::CreateConstant(
       LiteralUtil::Zero(activations_new->shape().element_type())));
@@ -1355,32 +1450,12 @@ Status ConvolutionVisitor::PropagateOnConv(HloInstruction* convolution) {
   TF_ASSIGN_OR_RETURN(
       activations_new,
       SelectValidPortion(activations_new, activations_old, select_val,
-                         activations_batch_dim, spatial_dimension_to_split,
+                         activations_batch_dim, c.spatial_dimension_to_split,
                          old_batch_dim, old_space_dim));
   // Create the new convolution dim numbers.
   auto new_dim_numbers = permuted_conv_dims_numbers;
 
-  auto kernel = convolution->mutable_operand(1);
-  const auto& kernel_shape = kernel->shape();
-  const int64 kernel_spatial_dim_size = kernel_shape.dimensions(
-      permuted_conv_dims_numbers.kernel_spatial_dimensions(
-          get_chosen_spatial_dim(convolution)));
-
-  const int64 inherent_low_padding =
-      convolution->window()
-          .dimensions(get_chosen_spatial_dim(convolution))
-          .padding_low();
-  const int64 inherent_high_padding =
-      convolution->window()
-          .dimensions(get_chosen_spatial_dim(convolution))
-          .padding_high();
-  const int64 stride = convolution->window()
-                           .dimensions(get_chosen_spatial_dim(convolution))
-                           .stride();
-
-  const int64 spatial_size =
-      input_dim_size + inherent_low_padding + inherent_high_padding;
-  VLOG(1) << "spatial size " << spatial_size;
+  VLOG(1) << "spatial size " << c.spatial_size;
 
   const int64 num_splits = kNewBatchSize / old_batch_size;
@@ -1390,18 +1465,18 @@ Status ConvolutionVisitor::PropagateOnConv(HloInstruction* convolution) {
   const int64 output_offsets_per_split =
       CeilOfRatio(output_offsets, num_splits);
 
-  int64 spatial_split_size = output_offsets_per_split * stride;
-  const int64 halo_size =
-      std::max(kernel_spatial_dim_size - stride, static_cast<int64>(0));
+  int64 spatial_split_size =
+      CeilOfRatio(output_offsets_per_split, c.base_dilation_factor) * c.stride;
+
   // Keep increasing the split size so that overall size isn't smaller than the
   // original spatial dimension. Unlike for the first space-to-batch'ed
   // convolution, while propagating, we can use the last halo_size as available
   // spatial size.
-  while (spatial_split_size * num_splits + halo_size - spatial_size < 0) {
-    spatial_split_size += stride;
+  while (spatial_split_size * num_splits + c.halo_size - c.spatial_size < 0) {
+    spatial_split_size += c.stride;
   }
 
-  int64 slice_size = spatial_split_size + halo_size;
+  int64 slice_size = spatial_split_size + c.halo_size;
   VLOG(1) << "spatial_split_size " << spatial_split_size << " slice_size "
           << slice_size;
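With base dilation, the propagated convolution's split size now starts from CeilOfRatio(output_offsets_per_split, base_dilation_factor) * stride rather than output_offsets_per_split * stride. A small standalone sketch with assumed numbers (CeilOfRatio re-implemented locally; the values are illustrative, not from the change):

#include <cstdint>
#include <iostream>

int64_t CeilOfRatio(int64_t a, int64_t b) { return (a + b - 1) / b; }

int main() {
  // Assumed: 56 output offsets, 8 splits, base dilation 2, stride 1,
  // halo size 1, spatial size 29.
  const int64_t output_offsets = 56, num_splits = 8;
  const int64_t base_dilation_factor = 2, stride = 1, halo_size = 1;
  const int64_t spatial_size = 29;
  const int64_t output_offsets_per_split =
      CeilOfRatio(output_offsets, num_splits);  // 7
  int64_t spatial_split_size =
      CeilOfRatio(output_offsets_per_split, base_dilation_factor) * stride;  // 4
  // Grow the split until the splits plus the last halo cover the whole space.
  while (spatial_split_size * num_splits + halo_size - spatial_size < 0) {
    spatial_split_size += stride;
  }
  const int64_t slice_size = spatial_split_size + halo_size;
  std::cout << spatial_split_size << " " << slice_size << "\n";  // prints 4 5
  return 0;
}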
@@ -1409,7 +1484,7 @@ Status ConvolutionVisitor::PropagateOnConv(HloInstruction* convolution) {
   const int64 new_batch_size =
       activations_new->shape().dimensions(activations_batch_dim);
   const int64 new_space_size =
-      activations_new->shape().dimensions(spatial_dimension_to_split);
+      activations_new->shape().dimensions(c.spatial_dimension_to_split);
 
   // In the below case, we cannot use the activations directly for Halo
   // Duplication. We must reshape them.
   if (spatial_split_size > new_space_size) {
@@ -1418,7 +1493,7 @@ Status ConvolutionVisitor::PropagateOnConv(HloInstruction* convolution) {
         activations_new->shape().dimensions().end());
     const int64 reshaped_space_size =
         new_space_size * new_batch_size / old_batch_size;
-    new_dimensions[spatial_dimension_to_split] = reshaped_space_size;
+    new_dimensions[c.spatial_dimension_to_split] = reshaped_space_size;
     new_dimensions[activations_batch_dim] = old_batch_size;
 
     // Reshape the output of the new conv into the old convolutions shape.
@@ -1427,10 +1502,10 @@ Status ConvolutionVisitor::PropagateOnConv(HloInstruction* convolution) {
     PaddingConfig padding_config =
         MakeNoPaddingConfig(reshaped_activations->shape().dimensions_size());
-    padding_config.mutable_dimensions(spatial_dimension_to_split)
+    padding_config.mutable_dimensions(c.spatial_dimension_to_split)
         ->set_edge_padding_high(spatial_split_size * new_batch_size -
                                 reshaped_space_size);
-    padding_config.mutable_dimensions(spatial_dimension_to_split)
+    padding_config.mutable_dimensions(c.spatial_dimension_to_split)
         ->set_edge_padding_low(0);
     HloInstruction* padding =
         computation_->AddInstruction(HloInstruction::CreateConstant(
@@ -1444,7 +1519,7 @@ Status ConvolutionVisitor::PropagateOnConv(HloInstruction* convolution) {
         reshaped_activations->shape().dimensions().begin(),
         reshaped_activations->shape().dimensions().end());
-    reshape_back_dims[spatial_dimension_to_split] = spatial_split_size;
+    reshape_back_dims[c.spatial_dimension_to_split] = spatial_split_size;
     reshape_back_dims[activations_batch_dim] = new_batch_size;
 
     TF_ASSIGN_OR_RETURN(
@@ -1453,34 +1528,38 @@ Status ConvolutionVisitor::PropagateOnConv(HloInstruction* convolution) {
     TF_ASSIGN_OR_RETURN(
         activations_new,
-        HaloDuplicateWithSlice(reshaped_activations, spatial_dimension_to_split,
-                               activations_batch_dim, old_batch_size,
-                               /*low_padding=*/inherent_low_padding,
-                               /*high_padding=*/inherent_high_padding,
-                               slice_size - spatial_split_size,
-                               old_split_dim_size));
+        HaloDuplicateWithSlice(
+            reshaped_activations, c.spatial_dimension_to_split,
+            activations_batch_dim, old_batch_size,
+            /*low_padding=*/c.base_dilation_factor != 1 &&
+                    c.inherent_low_padding != 0
+                ? c.base_dilation_factor - 1
+                : c.inherent_low_padding,
+            c.inherent_high_padding, slice_size - spatial_split_size,
+            old_split_dim_size));
   } else {
     // If the ideal spatial_split_size was smaller than the incoming spatial
     // dimension size, we don't need reshaping. Instead, we determine the
     // additional space available, and adjust the required slice size (and
     // thereby the halo size).
     if (spatial_split_size < new_space_size) {
-      const int64 additional_space_present = spatial_split_size % stride;
+      const int64 additional_space_present = spatial_split_size % c.stride;
       spatial_split_size = new_space_size;
       slice_size =
-          spatial_split_size +
-          std::max(kernel_spatial_dim_size - stride - additional_space_present,
-                   static_cast<int64>(0));
+          spatial_split_size + std::max(c.kernel_spatial_dim_size - c.stride -
+                                            additional_space_present,
+                                        static_cast<int64>(0));
     }
 
     TF_ASSIGN_OR_RETURN(
         activations_new,
-        HaloDuplicateWithSlice(activations_new, spatial_dimension_to_split,
-                               activations_batch_dim, old_batch_size,
-                               /*low_padding=*/inherent_low_padding,
-                               /*high_padding=*/inherent_high_padding,
-                               slice_size - spatial_split_size,
-                               old_split_dim_size));
+        HaloDuplicateWithSlice(activations_new, c.spatial_dimension_to_split,
+                               activations_batch_dim, old_batch_size,
+                               /*low_padding=*/c.base_dilation_factor != 1 &&
+                                       c.inherent_low_padding != 0
+                                   ? c.base_dilation_factor - 1
+                                   : c.inherent_low_padding,
+                               c.inherent_high_padding,
+                               slice_size - spatial_split_size,
+                               old_split_dim_size));
   }
@@ -1515,9 +1594,9 @@ Status ConvolutionVisitor::PropagateOnConv(HloInstruction* convolution) {
   auto new_window = convolution->window();
   new_window.mutable_dimensions(get_chosen_spatial_dim(convolution))
-      ->set_padding_high(0);
+      ->set_padding_high(c.high_padding_for_conv);
   new_window.mutable_dimensions(get_chosen_spatial_dim(convolution))
-      ->set_padding_low(0);
+      ->set_padding_low(c.low_padding_for_conv);
   TF_ASSIGN_OR_RETURN(
       HloInstruction * new_conv,
       MakeConvolveHlo(
@@ -1855,19 +1934,9 @@ HloInstruction* ConvolutionVisitor::DoesConvolutionFeedReduceWindow(
   return nullptr;
 }
 
-Status ConvolutionVisitor::PerformSpaceToBatchOnConvolution(
-    HloInstruction* convolution) {
-  VLOG(1) << "Handling conv " << convolution->ToString();
-  changed_ = false;
-
-  ConvolutionDimensionNumbers dim_numbers =
-      convolution->convolution_dimension_numbers();
-
-  int64 activations_batch_dim = dim_numbers.input_batch_dimension();
-
-  const int64 old_batch_size =
-      convolution->operand(0)->shape().dimensions(activations_batch_dim);
+ConvolutionVisitor::ConvDetails ConvolutionVisitor::GetConvolutionDetails(
+    HloInstruction* convolution, ConvolutionDimensionNumbers& dim_numbers) {
+  auto activations = convolution->mutable_operand(0);
 
   auto kernel = convolution->mutable_operand(1);
   const auto& kernel_shape = kernel->shape();
@@ -1875,14 +1944,11 @@ Status ConvolutionVisitor::PerformSpaceToBatchOnConvolution(
       kernel_shape.dimensions(dim_numbers.kernel_spatial_dimensions(
           get_chosen_spatial_dim(convolution)));
 
-  auto activations = convolution->mutable_operand(0);
-
-  int64 spatial_dimension_to_split =
+  const int64 spatial_dimension_to_split =
       dim_numbers.input_spatial_dimensions(get_chosen_spatial_dim(convolution));
 
   const int64 input_dim_size =
-      activations->shape().dimensions(dim_numbers.input_spatial_dimensions(
-          get_chosen_spatial_dim(convolution)));
+      activations->shape().dimensions(spatial_dimension_to_split);
 
   const int64 inherent_low_padding =
       convolution->window()
@@ -1892,26 +1958,75 @@ Status ConvolutionVisitor::PerformSpaceToBatchOnConvolution(
       convolution->window()
           .dimensions(get_chosen_spatial_dim(convolution))
           .padding_high();
 
-  const bool inherent_padding_needed =
-      inherent_low_padding != 0 || inherent_high_padding != 0;
-
   const int64 stride = convolution->window()
                            .dimensions(get_chosen_spatial_dim(convolution))
                            .stride();
 
+  const int64 base_dilation_factor =
+      convolution->window()
+          .dimensions(get_chosen_spatial_dim(convolution))
+          .base_dilation();
+
   const int64 spatial_size =
-      input_dim_size + inherent_low_padding + inherent_high_padding;
-  VLOG(1) << "spatial size " << spatial_size;
+      input_dim_size + (base_dilation_factor > 1 ? 0 : inherent_low_padding) +
+      inherent_high_padding;
+
+  const int64 halo_size =
+      std::max(kernel_spatial_dim_size - stride - (base_dilation_factor - 1),
+               static_cast<int64>(0));
+  const int64 high_padding_for_conv = base_dilation_factor == 1 ? 0
+                                      : inherent_low_padding == 0
+                                          ? base_dilation_factor - 1
+                                          : 0;
+  const int64 low_padding_for_conv =
+      base_dilation_factor == 1 ? 0 : inherent_low_padding;
+
+  return ConvDetails{spatial_dimension_to_split,
+                     inherent_low_padding,
+                     inherent_high_padding,
+                     stride,
+                     spatial_size,
+                     base_dilation_factor,
+                     halo_size,
+                     high_padding_for_conv,
+                     low_padding_for_conv,
+                     kernel_spatial_dim_size,
+                     input_dim_size};
+}
+
+Status ConvolutionVisitor::PerformSpaceToBatchOnConvolution(
+    HloInstruction* convolution) {
+  VLOG(1) << "Handling conv " << convolution->ToString();
+  changed_ = false;
+
+  ConvolutionDimensionNumbers dim_numbers =
+      convolution->convolution_dimension_numbers();
+
+  ConvDetails c = GetConvolutionDetails(convolution, dim_numbers);
+
+  int64 activations_batch_dim = dim_numbers.input_batch_dimension();
+
+  const int64 old_batch_size =
+      convolution->operand(0)->shape().dimensions(activations_batch_dim);
+
+  auto activations = convolution->mutable_operand(0);
+
+  const bool inherent_padding_needed =
+      c.inherent_low_padding != 0 || c.inherent_high_padding != 0;
+
+  VLOG(1) << "spatial size " << c.spatial_size;
 
   const int64 num_splits = kNewBatchSize / old_batch_size;
 
   auto original_conv = convolution;
 
   // We'd need transposition of activations here such that batch and space dim
   // that is being split are adjacent (in that order).
-  TF_ASSIGN_OR_RETURN(
-      activations,
-      BringSpaceNextToBatch(activations, dim_numbers,
-                            spatial_dimension_to_split, activations_batch_dim));
+  TF_ASSIGN_OR_RETURN(activations,
+                      BringSpaceNextToBatch(activations, dim_numbers,
+                                            c.spatial_dimension_to_split,
+                                            activations_batch_dim));
 
   // Create the new convolution dim numbers.
   auto new_dim_numbers = dim_numbers;
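The new GetConvolutionDetails helper folds the base-dilation handling into a few derived quantities: when the conv is base dilated, its low padding is carried on the new conv itself (low_padding_for_conv) instead of being materialized as a pad, and the halo shrinks by base_dilation_factor - 1. A standalone sketch with assumed numbers (illustrative only, not part of the change):

#include <algorithm>
#include <cstdint>
#include <iostream>

int main() {
  // Assumed: kernel size 3, base dilation 2, stride 1, low/high pad 1,
  // input spatial size 20.
  const int64_t input_dim_size = 20, kernel_spatial_dim_size = 3;
  const int64_t inherent_low_padding = 1, inherent_high_padding = 1;
  const int64_t stride = 1, base_dilation_factor = 2;
  // Low padding is not materialized when base dilated.
  const int64_t spatial_size =
      input_dim_size +
      (base_dilation_factor > 1 ? 0 : inherent_low_padding) +
      inherent_high_padding;  // 21
  const int64_t halo_size = std::max<int64_t>(
      kernel_spatial_dim_size - stride - (base_dilation_factor - 1), 0);  // 1
  const int64_t high_padding_for_conv =
      base_dilation_factor == 1 ? 0
      : inherent_low_padding == 0 ? base_dilation_factor - 1 : 0;  // 0
  const int64_t low_padding_for_conv =
      base_dilation_factor == 1 ? 0 : inherent_low_padding;  // 1
  std::cout << spatial_size << " " << halo_size << " " << high_padding_for_conv
            << " " << low_padding_for_conv << "\n";
  return 0;
}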
@@ -1922,11 +2037,12 @@ Status ConvolutionVisitor::PerformSpaceToBatchOnConvolution(
   const int64 output_offsets_per_split =
       CeilOfRatio(output_offsets, num_splits);
 
-  int64 spatial_split_size = output_offsets_per_split * stride;
+  int64 spatial_split_size =
+      CeilOfRatio(output_offsets_per_split, c.base_dilation_factor) * c.stride;
 
   // Keep increasing the split size so that overall size isn't smaller than the
   // original spatial dimension.
-  while (spatial_split_size * num_splits - spatial_size < 0) {
-    spatial_split_size += stride;
+  while (spatial_split_size * num_splits - c.spatial_size < 0) {
+    spatial_split_size += c.stride;
   }
 
   auto reduce_window = DoesConvolutionFeedReduceWindow(convolution);
@@ -1938,33 +2054,32 @@ Status ConvolutionVisitor::PerformSpaceToBatchOnConvolution(
     // windows.
     const int64 red_win_stride =
         reduce_window->window().dimensions(output_spatial_dim).stride();
-    while ((spatial_split_size / stride) % red_win_stride != 0) {
-      spatial_split_size += stride;
+    while ((spatial_split_size / c.stride) % red_win_stride != 0) {
+      spatial_split_size += c.stride;
     }
   }
 
-  const int64 slice_size =
-      spatial_split_size +
-      std::max(kernel_spatial_dim_size - stride, static_cast<int64>(0));
+  const int64 slice_size = spatial_split_size + c.halo_size;
 
   // Pad spatial dim.
-  const int64 pad_size = spatial_split_size * num_splits - spatial_size;
+  const int64 pad_size = spatial_split_size * num_splits - c.spatial_size;
 
   VLOG(1) << "spatial_split_size " << spatial_split_size << " stride "
-          << stride;
-  VLOG(1) << "spatial_dimension_to_split " << spatial_dimension_to_split
+          << c.stride << " slice_size " << slice_size;
+  VLOG(1) << "spatial_dimension_to_split " << c.spatial_dimension_to_split
           << " num_splits " << num_splits << " kernel_spatial_dim_size "
-          << kernel_spatial_dim_size;
+          << c.kernel_spatial_dim_size;
 
   // Because we are splitting the spatial dimension, if convolution needed
   // padding in the spatial dimension, we materialize it.
   if (pad_size != 0 || inherent_padding_needed) {
     PaddingConfig padding_config =
         MakeNoPaddingConfig(activations->shape().dimensions_size());
-    padding_config.mutable_dimensions(spatial_dimension_to_split)
-        ->set_edge_padding_high(inherent_high_padding + pad_size);
-    padding_config.mutable_dimensions(spatial_dimension_to_split)
-        ->set_edge_padding_low(inherent_low_padding);
+    padding_config.mutable_dimensions(c.spatial_dimension_to_split)
+        ->set_edge_padding_high(c.inherent_high_padding + pad_size);
+    padding_config.mutable_dimensions(c.spatial_dimension_to_split)
+        ->set_edge_padding_low(
+            c.base_dilation_factor == 1 ? c.inherent_low_padding : 0);
     HloInstruction* padding =
         computation_->AddInstruction(HloInstruction::CreateConstant(
             LiteralUtil::Zero(activations->shape().element_type())));
@@ -1991,7 +2106,7 @@ Status ConvolutionVisitor::PerformSpaceToBatchOnConvolution(
       activations->shape().dimensions().begin(),
       activations->shape().dimensions().end());
-  reshape_dimensions[spatial_dimension_to_split] = spatial_split_size;
+  reshape_dimensions[c.spatial_dimension_to_split] = spatial_split_size;
   reshape_dimensions[activations_batch_dim] = num_splits * old_batch_size;
 
   TF_ASSIGN_OR_RETURN(HloInstruction * batch_increased_reshape,
@@ -2000,12 +2115,12 @@ Status ConvolutionVisitor::PerformSpaceToBatchOnConvolution(
   VLOG(1) << "First reshape done " << batch_increased_reshape->ToString();
 
-  TF_ASSIGN_OR_RETURN(activations,
-                      HaloDuplicateWithSlice(
-                          batch_increased_reshape, spatial_dimension_to_split,
-                          activations_batch_dim, old_batch_size,
-                          /*low_padding=*/0, /*high_padding=*/0,
-                          slice_size - spatial_split_size, input_dim_size));
+  TF_ASSIGN_OR_RETURN(
+      activations, HaloDuplicateWithSlice(batch_increased_reshape,
+                                          c.spatial_dimension_to_split,
+                                          activations_batch_dim, old_batch_size,
+                                          /*low_padding=*/0, /*high_padding=*/0,
+                                          c.halo_size, c.input_dim_size));
 
   VLOG(1) << "Batch merge done " << activations->ToString();
 
@@ -2040,9 +2155,9 @@ Status ConvolutionVisitor::PerformSpaceToBatchOnConvolution(
           << " batch dim " << new_dim_numbers.input_batch_dimension();
   auto new_window = convolution->window();
   new_window.mutable_dimensions(get_chosen_spatial_dim(convolution))
-      ->set_padding_high(0);
+      ->set_padding_high(c.high_padding_for_conv);
   new_window.mutable_dimensions(get_chosen_spatial_dim(convolution))
-      ->set_padding_low(0);
+      ->set_padding_low(c.low_padding_for_conv);
   TF_ASSIGN_OR_RETURN(
       HloInstruction * new_conv,
       MakeConvolveHlo(


@@ -113,7 +113,7 @@ TEST_F(ConvolutionSpaceToBatchConverterTest, Batch1WithStrideAndPad) {
   EXPECT_GT(reshape->operand(0)->shape().dimensions(batch_dim), 4);
 }
 
-TEST_F(ConvolutionSpaceToBatchConverterTest, Batch1WithKernelDilation) {
+TEST_F(ConvolutionSpaceToBatchConverterTest, Batch1WithBaseDilation) {
   string hlo_string = R"(
   HloModule module
@@ -129,8 +129,22 @@ ENTRY computation {
   TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                           ParseAndReturnVerifiedModule(hlo_string));
 
+  auto computation = module->entry_computation();
   ConvolutionSpaceToBatchConverter converter;
-  ASSERT_FALSE(converter.Run(module.get()).ValueOrDie());
+  ASSERT_TRUE(converter.Run(module.get()).ValueOrDie());
+
+  HloInstruction* root = computation->root_instruction();
+  EXPECT_THAT(root, op::Transpose());
+  EXPECT_THAT(root->operand(0), op::Slice());
+  auto reshape = root->operand(0)->operand(0);
+  EXPECT_THAT(reshape, op::Reshape());
+  EXPECT_THAT(reshape->operand(0)->operand(1), op::Convolution());
+  const int64 batch_dim = reshape->operand(0)
+                              ->operand(1)
+                              ->convolution_dimension_numbers()
+                              .output_batch_dimension();
+  EXPECT_GT(reshape->operand(0)->shape().dimensions(batch_dim), 4);
 }
 
 }  // namespace