diff --git a/tensorflow/compiler/xla/service/space_to_batch_converter.cc b/tensorflow/compiler/xla/service/space_to_batch_converter.cc index ba1431b11c7..05cdc6e24b7 100644 --- a/tensorflow/compiler/xla/service/space_to_batch_converter.cc +++ b/tensorflow/compiler/xla/service/space_to_batch_converter.cc @@ -20,6 +20,7 @@ limitations under the License. #include #include #include +#include #include #include @@ -64,6 +65,19 @@ class ConvolutionVisitor { // Top-level function to begin space-to-batch conversion. Status PerformSpaceToBatchOnConvolution(HloInstruction* convolution); + // Struct containing details about a convolution. + struct ConvDetails { + int64 spatial_dimension_to_split, inherent_low_padding, + inherent_high_padding, stride, spatial_size, base_dilation_factor, + halo_size, high_padding_for_conv, low_padding_for_conv, + kernel_spatial_dim_size, input_dim_size; + }; + + // Return a struct containing various necessary information pieces for + // performing space-to-batch on a convolution. + ConvDetails GetConvolutionDetails(HloInstruction* convolution, + ConvolutionDimensionNumbers& dim_numbers); + // Function that determines if space-to-batch can be propagated into the // consumer. Such propagation is only possible when all required operands are // space-to-batch'ed. @@ -225,11 +239,29 @@ bool ConvolutionVisitor::IsConvSuitableForSpaceToBatch( return false; } - // TODO(b/168316428): Support base dilations. - if (convolution->window() - .dimensions(get_chosen_spatial_dim(convolution)) - .base_dilation() != 1) { - return false; + const ConvDetails c = GetConvolutionDetails(convolution, dim_numbers); + + const int64 low_pad = convolution->window() + .dimensions(get_chosen_spatial_dim(convolution)) + .padding_low(); + + // TODO(b/168316428): Support base dilations more generically. + if (c.base_dilation_factor != 1) { + if (c.stride != 1) { + return false; + } + // For low pad of 0, only support a pointwise kernel. + if (low_pad == 0) { + if (c.kernel_spatial_dim_size != 1) { + return false; + } + } else if (c.kernel_spatial_dim_size != c.base_dilation_factor + 1 || + low_pad != c.base_dilation_factor - 1) { + // Only support dilations such that base dilation factor and low pad are + // compatible with kernel_spatial_dim_size to be compatible with + // HaloDuplicateWithSlice. + return false; + } } int64 activations_batch_dim = dim_numbers.input_batch_dimension(); @@ -240,42 +272,17 @@ bool ConvolutionVisitor::IsConvSuitableForSpaceToBatch( if (old_batch_size > limit_on_batch_size_) { return false; } - - auto kernel = convolution->mutable_operand(1); - const auto& kernel_shape = kernel->shape(); - const int64 kernel_spatial_dim_size = - kernel_shape.dimensions(dim_numbers.kernel_spatial_dimensions( - get_chosen_spatial_dim(convolution))); - - auto activations = convolution->mutable_operand(0); - - const int64 input_dim_size = - activations->shape().dimensions(dim_numbers.input_spatial_dimensions( - get_chosen_spatial_dim(convolution))); - - const int64 inherent_low_padding = - convolution->window() - .dimensions(get_chosen_spatial_dim(convolution)) - .padding_low(); - const int64 inherent_high_padding = - convolution->window() - .dimensions(get_chosen_spatial_dim(convolution)) - .padding_high(); - - const int64 spatial_size = - input_dim_size + inherent_low_padding + inherent_high_padding; - VLOG(1) << "spatial size " << spatial_size; - - const int64 num_splits = kNewBatchSize / old_batch_size; - // We currently only cater to evenly divisible cases. if (kNewBatchSize % old_batch_size != 0) { return false; } - // Splitting will be incorrect in these cases. - if (spatial_size < num_splits || - input_dim_size / num_splits < kernel_spatial_dim_size) { + VLOG(1) << "spatial size " << c.spatial_size; + + const int64 num_splits = kNewBatchSize / old_batch_size; + // If the ratio is not within the 2X range, we can't Halo Pad from the next + // split. + if (c.halo_size > CeilOfRatio(c.spatial_size, num_splits)) { return false; } VLOG(1) << "Legal space-to-batch convolution " << convolution->ToString(); @@ -292,8 +299,8 @@ StatusOr ConvolutionVisitor::HaloDuplicateWithSlice( activations->shape().dimensions(spatial_dimension_to_split); const int64 batch_size = activations->shape().dimensions(activations_batch_dim); - CHECK_LT(low_padding, spatial_split_size); + CHECK_LE(std::abs(halo_size - low_padding), spatial_split_size); VLOG(1) << "In HaloDuplicateWithSlice with activations " << activations->ToString() << " batch_size " << batch_size << " spatial_split_size " << spatial_split_size << " low_padding " @@ -439,6 +446,7 @@ StatusOr ConvolutionVisitor::Run() { // Iterate through all instructions that we could not propagate through, and // turn their operands from batch-to-space as needed. for (auto instr : non_propagatable_instrs_) { + VLOG(1) << "Could not eventually propagate through " << instr->ToString(); absl::flat_hash_map operand_map; for (int64 i = 0; i < instr->operand_count(); ++i) { if (old_to_new_instrs_.count(instr->mutable_operand(i))) { @@ -480,8 +488,9 @@ bool ConvolutionVisitor::CanPropagate(HloInstruction* consumer, if (!old_to_new_instrs_.contains(old_producer) && !broadcast_or_constant) { - VLOG(1) << "Cannot propagate on elementwise op " - << consumer->ToString(); + VLOG(1) << "Cannot propagate on elementwise op " << consumer->ToString() + << " because operand " << old_producer->ToString() + << " isn't ready "; return false; } else { if (broadcast_or_constant) { @@ -496,10 +505,11 @@ bool ConvolutionVisitor::CanPropagate(HloInstruction* consumer, pivot_operand = old_producer; VLOG(2) << "Elementwise op: pivot " << old_producer->ToString(); } else { - VLOG(2) << "Elementwise op: checking for shape equivalence " - << consumer->ToString(); if (instr_to_dim_map_[pivot_operand] != instr_to_dim_map_[old_producer]) { + VLOG(2) << "Elementwise op: checking for shape equivalence " + << consumer->ToString() + << " failed due to changed batch space ordering "; return false; } auto pivot_new_instr = old_to_new_instrs_[pivot_operand]; @@ -509,13 +519,22 @@ bool ConvolutionVisitor::CanPropagate(HloInstruction* consumer, for (int j = 0; j < pivot_permute_dims.size(); ++j) { // Ensure the dimension mapping is the same. if (pivot_permute_dims[j] != permute_dims[j]) { + VLOG(2) << "Elementwise op: checking for shape equivalence " + << consumer->ToString() + << " failed due to permuted dimensions "; return false; } // Make sure all other dimensions are of the same size. if (pivot_new_instr->shape().dimensions(j) != new_instr->shape().dimensions(j)) { - return false; + if (!(consumer->IsElementwiseBinary() && + j == instr_to_dim_map_[pivot_operand].second)) { + VLOG(2) << "Elementwise op: checking for shape equivalence " + << consumer->ToString() + << " failed due to changed shape sizes "; + return false; + } } } } @@ -769,6 +788,28 @@ StatusOr ConvolutionVisitor::Propagate(HloInstruction* consumer, if (IsTrivialElementwise(consumer)) { auto dim_map_val = instr_to_dim_map_[producer]; auto new_consumer = computation->AddInstruction(consumer->Clone()); + if (consumer->IsElementwiseBinary()) { + for (int64 i = 0; i < 2; ++i) { + if (consumer->operand(i)->opcode() == HloOpcode::kBroadcast) { + break; + } + CHECK(old_to_new_instrs_.contains(consumer->mutable_operand(i))); + if (i == 1) { + // Choose the larger shape to be used as the producer. + if (old_to_new_instrs_[consumer->mutable_operand(0)] + ->shape() + .dimensions() > + old_to_new_instrs_[consumer->mutable_operand(1)] + ->shape() + .dimensions()) { + producer = consumer->mutable_operand(0); + } else { + producer = consumer->mutable_operand(1); + } + } + } + } + for (int64 i = 0; i < consumer->operand_count(); ++i) { if (consumer->operand(i)->opcode() == HloOpcode::kBroadcast) { CHECK(old_to_new_instrs_.contains(producer)); @@ -786,8 +827,66 @@ StatusOr ConvolutionVisitor::Propagate(HloInstruction* consumer, new_consumer->ReplaceOperandWithDifferentShape(i, new_broadcast)); } else { CHECK(old_to_new_instrs_.contains(consumer->mutable_operand(i))); - TF_CHECK_OK(new_consumer->ReplaceOperandWithDifferentShape( - i, old_to_new_instrs_[consumer->mutable_operand(i)])); + HloInstruction* operand_to_use = nullptr; + + auto result = instr_to_dim_map_[producer]; + const int64 old_batch_dim = result.first; + const int64 old_space_dim = result.second; + const int64 old_batch_size = + producer->shape().dimensions(old_batch_dim); + HloInstruction* new_instr = + old_to_new_instrs_[consumer->mutable_operand(i)]; + HloInstruction* pivot_new_instr = old_to_new_instrs_[producer]; + + auto permute_dims = instr_to_dim_permute_map_[new_instr]; + const int64 batch_dim = DimLookUp(permute_dims, old_batch_dim); + const int64 space_dim = DimLookUp(permute_dims, old_space_dim); + const int64 batch_size = new_instr->shape().dimensions(batch_dim); + + if (new_instr->shape().dimensions(space_dim) != + pivot_new_instr->shape().dimensions(space_dim)) { + // Because we do not propagate through transposes, the batch should + // always be followed by the split space dimension. + CHECK_EQ(batch_dim + 1, space_dim); + + // Reshape to 1D, pad to the producer's size, reshape back to 2D. + std::vector new_dimensions( + new_instr->shape().dimensions().begin(), + new_instr->shape().dimensions().end()); + new_dimensions[space_dim] *= (batch_size / old_batch_size); + new_dimensions[batch_dim] = old_batch_size; + + TF_ASSIGN_OR_RETURN(HloInstruction * reshape, + MakeReshapeHlo(new_dimensions, new_instr)); + + const int64 pivot_space_size = + pivot_new_instr->shape().dimensions(space_dim) * batch_size / + old_batch_size; + + CHECK_GT(pivot_space_size, new_dimensions[space_dim]); + + PaddingConfig padding_config = + MakeNoPaddingConfig(reshape->shape().dimensions_size()); + padding_config.mutable_dimensions(space_dim)->set_edge_padding_high( + pivot_space_size - new_dimensions[space_dim]); + padding_config.mutable_dimensions(space_dim)->set_edge_padding_low(0); + HloInstruction* padding = + computation_->AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::Zero(reshape->shape().element_type()))); + + TF_ASSIGN_OR_RETURN(HloInstruction * padded_operand, + MakePadHlo(reshape, padding, padding_config)); + + TF_ASSIGN_OR_RETURN( + operand_to_use, + MakeReshapeHlo(pivot_new_instr->shape().dimensions(), + padded_operand)); + + } else { + operand_to_use = old_to_new_instrs_[consumer->mutable_operand(i)]; + } + TF_CHECK_OK( + new_consumer->ReplaceOperandWithDifferentShape(i, operand_to_use)); } } auto old_type = new_consumer->mutable_shape()->element_type(); @@ -1329,25 +1428,21 @@ Status ConvolutionVisitor::PropagateOnConv(HloInstruction* convolution) { original_conv_dims.input_spatial_dimensions(i))); } - int64 spatial_dimension_to_split = - permuted_conv_dims_numbers.input_spatial_dimensions( - get_chosen_spatial_dim(convolution)); - const int64 old_batch_dim = original_conv_dims.input_batch_dimension(); const int64 old_batch_size = activations_old->shape().dimensions(old_batch_dim); - const int64 input_dim_size = activations_old->shape().dimensions( - permuted_conv_dims_numbers.input_spatial_dimensions( - get_chosen_spatial_dim(convolution))); + ConvDetails c = + GetConvolutionDetails(convolution, permuted_conv_dims_numbers); VLOG(1) << "Propagating on conv activations_batch_dim " << activations_batch_dim << " spatial_dimension_to_split " - << spatial_dimension_to_split << " old_batch_size " << old_batch_size; - TF_ASSIGN_OR_RETURN( - activations_new, - BringSpaceNextToBatch(activations_new, permuted_conv_dims_numbers, - spatial_dimension_to_split, activations_batch_dim)); + << c.spatial_dimension_to_split << " old_batch_size " + << old_batch_size; + TF_ASSIGN_OR_RETURN(activations_new, + BringSpaceNextToBatch( + activations_new, permuted_conv_dims_numbers, + c.spatial_dimension_to_split, activations_batch_dim)); auto select_val = computation_->AddInstruction(HloInstruction::CreateConstant( LiteralUtil::Zero(activations_new->shape().element_type()))); @@ -1355,32 +1450,12 @@ Status ConvolutionVisitor::PropagateOnConv(HloInstruction* convolution) { TF_ASSIGN_OR_RETURN( activations_new, SelectValidPortion(activations_new, activations_old, select_val, - activations_batch_dim, spatial_dimension_to_split, + activations_batch_dim, c.spatial_dimension_to_split, old_batch_dim, old_space_dim)); // Create the new convolution dim numbers. auto new_dim_numbers = permuted_conv_dims_numbers; - auto kernel = convolution->mutable_operand(1); - const auto& kernel_shape = kernel->shape(); - const int64 kernel_spatial_dim_size = kernel_shape.dimensions( - permuted_conv_dims_numbers.kernel_spatial_dimensions( - get_chosen_spatial_dim(convolution))); - - const int64 inherent_low_padding = - convolution->window() - .dimensions(get_chosen_spatial_dim(convolution)) - .padding_low(); - const int64 inherent_high_padding = - convolution->window() - .dimensions(get_chosen_spatial_dim(convolution)) - .padding_high(); - const int64 stride = convolution->window() - .dimensions(get_chosen_spatial_dim(convolution)) - .stride(); - - const int64 spatial_size = - input_dim_size + inherent_low_padding + inherent_high_padding; - VLOG(1) << "spatial size " << spatial_size; + VLOG(1) << "spatial size " << c.spatial_size; const int64 num_splits = kNewBatchSize / old_batch_size; @@ -1390,18 +1465,18 @@ Status ConvolutionVisitor::PropagateOnConv(HloInstruction* convolution) { const int64 output_offsets_per_split = CeilOfRatio(output_offsets, num_splits); - int64 spatial_split_size = output_offsets_per_split * stride; - const int64 halo_size = - std::max(kernel_spatial_dim_size - stride, static_cast(0)); + int64 spatial_split_size = + CeilOfRatio(output_offsets_per_split, c.base_dilation_factor) * c.stride; + // Keep increasing the split size so that overall size isn't smaller than the // original spatial dimension. Unlike for the first space-to-batch'ed // convolution, while propagating, we can use the last halo_size as available // spatial size. - while (spatial_split_size * num_splits + halo_size - spatial_size < 0) { - spatial_split_size += stride; + while (spatial_split_size * num_splits + c.halo_size - c.spatial_size < 0) { + spatial_split_size += c.stride; } - int64 slice_size = spatial_split_size + halo_size; + int64 slice_size = spatial_split_size + c.halo_size; VLOG(1) << "spatial_split_size " << spatial_split_size << " slice_size " << slice_size; @@ -1409,7 +1484,7 @@ Status ConvolutionVisitor::PropagateOnConv(HloInstruction* convolution) { const int64 new_batch_size = activations_new->shape().dimensions(activations_batch_dim); const int64 new_space_size = - activations_new->shape().dimensions(spatial_dimension_to_split); + activations_new->shape().dimensions(c.spatial_dimension_to_split); // In the below case, we cannot use the activations directly for Halo // Duplication. We must reshape them. if (spatial_split_size > new_space_size) { @@ -1418,7 +1493,7 @@ Status ConvolutionVisitor::PropagateOnConv(HloInstruction* convolution) { activations_new->shape().dimensions().end()); const int64 reshaped_space_size = new_space_size * new_batch_size / old_batch_size; - new_dimensions[spatial_dimension_to_split] = reshaped_space_size; + new_dimensions[c.spatial_dimension_to_split] = reshaped_space_size; new_dimensions[activations_batch_dim] = old_batch_size; // Reshape the output of the new conv into the old convolutions shape. @@ -1427,10 +1502,10 @@ Status ConvolutionVisitor::PropagateOnConv(HloInstruction* convolution) { PaddingConfig padding_config = MakeNoPaddingConfig(reshaped_activations->shape().dimensions_size()); - padding_config.mutable_dimensions(spatial_dimension_to_split) + padding_config.mutable_dimensions(c.spatial_dimension_to_split) ->set_edge_padding_high(spatial_split_size * new_batch_size - reshaped_space_size); - padding_config.mutable_dimensions(spatial_dimension_to_split) + padding_config.mutable_dimensions(c.spatial_dimension_to_split) ->set_edge_padding_low(0); HloInstruction* padding = computation_->AddInstruction(HloInstruction::CreateConstant( @@ -1444,7 +1519,7 @@ Status ConvolutionVisitor::PropagateOnConv(HloInstruction* convolution) { reshaped_activations->shape().dimensions().begin(), reshaped_activations->shape().dimensions().end()); - reshape_back_dims[spatial_dimension_to_split] = spatial_split_size; + reshape_back_dims[c.spatial_dimension_to_split] = spatial_split_size; reshape_back_dims[activations_batch_dim] = new_batch_size; TF_ASSIGN_OR_RETURN( @@ -1453,34 +1528,38 @@ Status ConvolutionVisitor::PropagateOnConv(HloInstruction* convolution) { TF_ASSIGN_OR_RETURN( activations_new, - HaloDuplicateWithSlice(reshaped_activations, spatial_dimension_to_split, - activations_batch_dim, old_batch_size, - /*low_padding=*/inherent_low_padding, - /*high_padding=*/inherent_high_padding, - slice_size - spatial_split_size, - old_split_dim_size)); + HaloDuplicateWithSlice( + reshaped_activations, c.spatial_dimension_to_split, + activations_batch_dim, old_batch_size, + /*low_padding=*/c.base_dilation_factor != 1 && + c.inherent_low_padding != 0 + ? c.base_dilation_factor - 1 + : c.inherent_low_padding, + c.inherent_high_padding, slice_size - spatial_split_size, + old_split_dim_size)); } else { // If the ideal spatial_split_size was smaller than the incoming spatial // dimension size, we don't need reshaping. Instead, we determine the // additional space available, and adjust the required slice size (and - // thereby the halo size).'t need reshaping. Instead, we determine the - // additional space available, and adjust the required slice size (and // thereby the halo size). if (spatial_split_size < new_space_size) { - const int64 additional_space_present = spatial_split_size % stride; + const int64 additional_space_present = spatial_split_size % c.stride; spatial_split_size = new_space_size; slice_size = - spatial_split_size + - std::max(kernel_spatial_dim_size - stride - additional_space_present, - static_cast(0)); + spatial_split_size + std::max(c.kernel_spatial_dim_size - c.stride - + additional_space_present, + static_cast(0)); } TF_ASSIGN_OR_RETURN( activations_new, - HaloDuplicateWithSlice(activations_new, spatial_dimension_to_split, + HaloDuplicateWithSlice(activations_new, c.spatial_dimension_to_split, activations_batch_dim, old_batch_size, - /*low_padding=*/inherent_low_padding, - /*high_padding=*/inherent_high_padding, + /*low_padding=*/c.base_dilation_factor != 1 && + c.inherent_low_padding != 0 + ? c.base_dilation_factor - 1 + : c.inherent_low_padding, + c.inherent_high_padding, slice_size - spatial_split_size, old_split_dim_size)); } @@ -1515,9 +1594,9 @@ Status ConvolutionVisitor::PropagateOnConv(HloInstruction* convolution) { auto new_window = convolution->window(); new_window.mutable_dimensions(get_chosen_spatial_dim(convolution)) - ->set_padding_high(0); + ->set_padding_high(c.high_padding_for_conv); new_window.mutable_dimensions(get_chosen_spatial_dim(convolution)) - ->set_padding_low(0); + ->set_padding_low(c.low_padding_for_conv); TF_ASSIGN_OR_RETURN( HloInstruction * new_conv, MakeConvolveHlo( @@ -1855,19 +1934,9 @@ HloInstruction* ConvolutionVisitor::DoesConvolutionFeedReduceWindow( return nullptr; } -Status ConvolutionVisitor::PerformSpaceToBatchOnConvolution( - HloInstruction* convolution) { - VLOG(1) << "Handling conv " << convolution->ToString(); - - changed_ = false; - - ConvolutionDimensionNumbers dim_numbers = - convolution->convolution_dimension_numbers(); - - int64 activations_batch_dim = dim_numbers.input_batch_dimension(); - - const int64 old_batch_size = - convolution->operand(0)->shape().dimensions(activations_batch_dim); +ConvolutionVisitor::ConvDetails ConvolutionVisitor::GetConvolutionDetails( + HloInstruction* convolution, ConvolutionDimensionNumbers& dim_numbers) { + auto activations = convolution->mutable_operand(0); auto kernel = convolution->mutable_operand(1); const auto& kernel_shape = kernel->shape(); @@ -1875,14 +1944,11 @@ Status ConvolutionVisitor::PerformSpaceToBatchOnConvolution( kernel_shape.dimensions(dim_numbers.kernel_spatial_dimensions( get_chosen_spatial_dim(convolution))); - auto activations = convolution->mutable_operand(0); - - int64 spatial_dimension_to_split = + const int64 spatial_dimension_to_split = dim_numbers.input_spatial_dimensions(get_chosen_spatial_dim(convolution)); const int64 input_dim_size = - activations->shape().dimensions(dim_numbers.input_spatial_dimensions( - get_chosen_spatial_dim(convolution))); + activations->shape().dimensions(spatial_dimension_to_split); const int64 inherent_low_padding = convolution->window() @@ -1892,26 +1958,75 @@ Status ConvolutionVisitor::PerformSpaceToBatchOnConvolution( convolution->window() .dimensions(get_chosen_spatial_dim(convolution)) .padding_high(); - const bool inherent_padding_needed = - inherent_low_padding != 0 || inherent_high_padding != 0; const int64 stride = convolution->window() .dimensions(get_chosen_spatial_dim(convolution)) .stride(); + const int64 base_dilation_factor = + convolution->window() + .dimensions(get_chosen_spatial_dim(convolution)) + .base_dilation(); + const int64 spatial_size = - input_dim_size + inherent_low_padding + inherent_high_padding; - VLOG(1) << "spatial size " << spatial_size; + input_dim_size + (base_dilation_factor > 1 ? 0 : inherent_low_padding) + + inherent_high_padding; + + const int64 halo_size = + std::max(kernel_spatial_dim_size - stride - (base_dilation_factor - 1), + static_cast(0)); + const int64 high_padding_for_conv = base_dilation_factor == 1 ? 0 + : inherent_low_padding == 0 + ? base_dilation_factor - 1 + : 0; + const int64 low_padding_for_conv = + base_dilation_factor == 1 ? 0 : inherent_low_padding; + + return ConvDetails{spatial_dimension_to_split, + inherent_low_padding, + inherent_high_padding, + stride, + spatial_size, + base_dilation_factor, + halo_size, + high_padding_for_conv, + low_padding_for_conv, + kernel_spatial_dim_size, + input_dim_size}; +} + +Status ConvolutionVisitor::PerformSpaceToBatchOnConvolution( + HloInstruction* convolution) { + VLOG(1) << "Handling conv " << convolution->ToString(); + + changed_ = false; + + ConvolutionDimensionNumbers dim_numbers = + convolution->convolution_dimension_numbers(); + + ConvDetails c = GetConvolutionDetails(convolution, dim_numbers); + + int64 activations_batch_dim = dim_numbers.input_batch_dimension(); + + const int64 old_batch_size = + convolution->operand(0)->shape().dimensions(activations_batch_dim); + + auto activations = convolution->mutable_operand(0); + + const bool inherent_padding_needed = + c.inherent_low_padding != 0 || c.inherent_high_padding != 0; + + VLOG(1) << "spatial size " << c.spatial_size; const int64 num_splits = kNewBatchSize / old_batch_size; auto original_conv = convolution; // We'd need transposition of activations here such that batch and space dim // that is being split are adjacent (in that order). - TF_ASSIGN_OR_RETURN( - activations, - BringSpaceNextToBatch(activations, dim_numbers, - spatial_dimension_to_split, activations_batch_dim)); + TF_ASSIGN_OR_RETURN(activations, + BringSpaceNextToBatch(activations, dim_numbers, + c.spatial_dimension_to_split, + activations_batch_dim)); // Create the new convolution dim numbers. auto new_dim_numbers = dim_numbers; @@ -1922,11 +2037,12 @@ Status ConvolutionVisitor::PerformSpaceToBatchOnConvolution( const int64 output_offsets_per_split = CeilOfRatio(output_offsets, num_splits); - int64 spatial_split_size = output_offsets_per_split * stride; + int64 spatial_split_size = + CeilOfRatio(output_offsets_per_split, c.base_dilation_factor) * c.stride; // Keep increasing the split size so that overall size isn't smaller than the // original spatial dimension. - while (spatial_split_size * num_splits - spatial_size < 0) { - spatial_split_size += stride; + while (spatial_split_size * num_splits - c.spatial_size < 0) { + spatial_split_size += c.stride; } auto reduce_window = DoesConvolutionFeedReduceWindow(convolution); @@ -1938,33 +2054,32 @@ Status ConvolutionVisitor::PerformSpaceToBatchOnConvolution( // windows. const int64 red_win_stride = reduce_window->window().dimensions(output_spatial_dim).stride(); - while ((spatial_split_size / stride) % red_win_stride != 0) { - spatial_split_size += stride; + while ((spatial_split_size / c.stride) % red_win_stride != 0) { + spatial_split_size += c.stride; } } - const int64 slice_size = - spatial_split_size + - std::max(kernel_spatial_dim_size - stride, static_cast(0)); + const int64 slice_size = spatial_split_size + c.halo_size; // Pad spatial dim. - const int64 pad_size = spatial_split_size * num_splits - spatial_size; + const int64 pad_size = spatial_split_size * num_splits - c.spatial_size; VLOG(1) << "spatial_split_size " << spatial_split_size << " stride " - << stride; - VLOG(1) << "spatial_dimension_to_split " << spatial_dimension_to_split + << c.stride << " slice_size " << slice_size; + VLOG(1) << "spatial_dimension_to_split " << c.spatial_dimension_to_split << " num_splits " << num_splits << " kernel_spatial_dim_size " - << kernel_spatial_dim_size; + << c.kernel_spatial_dim_size; // Because we are splitting the spatial dimension, if convolution needed // padding in the spatial dimension, we materialize it. if (pad_size != 0 || inherent_padding_needed) { PaddingConfig padding_config = MakeNoPaddingConfig(activations->shape().dimensions_size()); - padding_config.mutable_dimensions(spatial_dimension_to_split) - ->set_edge_padding_high(inherent_high_padding + pad_size); - padding_config.mutable_dimensions(spatial_dimension_to_split) - ->set_edge_padding_low(inherent_low_padding); + padding_config.mutable_dimensions(c.spatial_dimension_to_split) + ->set_edge_padding_high(c.inherent_high_padding + pad_size); + padding_config.mutable_dimensions(c.spatial_dimension_to_split) + ->set_edge_padding_low( + c.base_dilation_factor == 1 ? c.inherent_low_padding : 0); HloInstruction* padding = computation_->AddInstruction(HloInstruction::CreateConstant( LiteralUtil::Zero(activations->shape().element_type()))); @@ -1991,7 +2106,7 @@ Status ConvolutionVisitor::PerformSpaceToBatchOnConvolution( activations->shape().dimensions().begin(), activations->shape().dimensions().end()); - reshape_dimensions[spatial_dimension_to_split] = spatial_split_size; + reshape_dimensions[c.spatial_dimension_to_split] = spatial_split_size; reshape_dimensions[activations_batch_dim] = num_splits * old_batch_size; TF_ASSIGN_OR_RETURN(HloInstruction * batch_increased_reshape, @@ -2000,12 +2115,12 @@ Status ConvolutionVisitor::PerformSpaceToBatchOnConvolution( VLOG(1) << "First reshape done " << batch_increased_reshape->ToString(); - TF_ASSIGN_OR_RETURN(activations, - HaloDuplicateWithSlice( - batch_increased_reshape, spatial_dimension_to_split, - activations_batch_dim, old_batch_size, - /*low_padding=*/0, /*high_padding=*/0, - slice_size - spatial_split_size, input_dim_size)); + TF_ASSIGN_OR_RETURN( + activations, HaloDuplicateWithSlice(batch_increased_reshape, + c.spatial_dimension_to_split, + activations_batch_dim, old_batch_size, + /*low_padding=*/0, /*high_padding=*/0, + c.halo_size, c.input_dim_size)); VLOG(1) << "Batch merge done " << activations->ToString(); @@ -2040,9 +2155,9 @@ Status ConvolutionVisitor::PerformSpaceToBatchOnConvolution( << " batch dim " << new_dim_numbers.input_batch_dimension(); auto new_window = convolution->window(); new_window.mutable_dimensions(get_chosen_spatial_dim(convolution)) - ->set_padding_high(0); + ->set_padding_high(c.high_padding_for_conv); new_window.mutable_dimensions(get_chosen_spatial_dim(convolution)) - ->set_padding_low(0); + ->set_padding_low(c.low_padding_for_conv); TF_ASSIGN_OR_RETURN( HloInstruction * new_conv, MakeConvolveHlo( diff --git a/tensorflow/compiler/xla/service/space_to_batch_converter_test.cc b/tensorflow/compiler/xla/service/space_to_batch_converter_test.cc index d53bb7d75f3..8921d98cad0 100644 --- a/tensorflow/compiler/xla/service/space_to_batch_converter_test.cc +++ b/tensorflow/compiler/xla/service/space_to_batch_converter_test.cc @@ -113,7 +113,7 @@ TEST_F(ConvolutionSpaceToBatchConverterTest, Batch1WithStrideAndPad) { EXPECT_GT(reshape->operand(0)->shape().dimensions(batch_dim), 4); } -TEST_F(ConvolutionSpaceToBatchConverterTest, Batch1WithKernelDilation) { +TEST_F(ConvolutionSpaceToBatchConverterTest, Batch1WithBaseDilation) { string hlo_string = R"( HloModule module @@ -129,8 +129,22 @@ ENTRY computation { TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(hlo_string)); + auto computation = module->entry_computation(); ConvolutionSpaceToBatchConverter converter; - ASSERT_FALSE(converter.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(converter.Run(module.get()).ValueOrDie()); + + HloInstruction* root = computation->root_instruction(); + EXPECT_THAT(root, op::Transpose()); + EXPECT_THAT(root->operand(0), op::Slice()); + auto reshape = root->operand(0)->operand(0); + EXPECT_THAT(reshape, op::Reshape()); + EXPECT_THAT(reshape->operand(0)->operand(1), op::Convolution()); + const int64 batch_dim = reshape->operand(0) + ->operand(1) + ->convolution_dimension_numbers() + .output_batch_dimension(); + + EXPECT_GT(reshape->operand(0)->shape().dimensions(batch_dim), 4); } } // namespace