diff --git a/tensorflow/compiler/xla/service/space_to_batch_converter.cc b/tensorflow/compiler/xla/service/space_to_batch_converter.cc
index 8f7cc1af74a..7050a5289e9 100644
--- a/tensorflow/compiler/xla/service/space_to_batch_converter.cc
+++ b/tensorflow/compiler/xla/service/space_to_batch_converter.cc
@@ -82,7 +82,8 @@ class ConvolutionVisitor {
   // Function that determines if space-to-batch can be propagated into the
   // consumer. Such propagation is only possible when all required operands are
   // space-to-batch'ed.
-  bool CanPropagate(HloInstruction* consumer, HloInstruction* producer);
+  bool CanPropagate(HloInstruction* consumer, HloInstruction* producer,
+                    bool last_try = false);
 
   // Returns true if the op has all its direct and indirect operands being
   // created via broadcasts. Consumer uses op, and is space-to-batched.
@@ -116,7 +117,7 @@ class ConvolutionVisitor {
       HloInstruction* activations, ConvolutionDimensionNumbers& dim_numbers,
       int64& spatial_dimension_to_split, int64& activations_batch_dim,
       int64 high_padding, int64 low_padding, int64 spatial_split_size,
-      int64 num_splits, bool is_backprop = false);
+      int64 num_splits, bool is_backprop = false, bool is_rhs = false);
 
   // Perform space-to-batch propagation on the convolution. Assumes the
   // activations were already space-to-batched.
@@ -149,7 +150,13 @@ class ConvolutionVisitor {
   StatusOr<SpaceNextToBatchDetails> BringSpaceNextToBatch(
       HloInstruction* activations, ConvolutionDimensionNumbers& dim_numbers,
       int64& spatial_dimension_to_split, int64& activations_batch_dim,
-      bool is_backprop = false);
+      bool is_backprop = false, bool is_rhs = false);
+
+  // Increases the spatial dimension size in an already space-to-batched shape
+  // so that the new size is new_spatial_dim_size.
+  StatusOr<HloInstruction*> IncreaseSpatialSizeOnSpaceToBatchedShape(
+      HloInstruction* activations, int64 batch_dimension, int64 old_batch_size,
+      int64 spatial_dimension, int64 new_spatial_dim_size);
 
   // Function that converts spaced-to-batch shape back to the original.
   StatusOr<HloInstruction*> BatchToSpace(HloInstruction* old_instr);
@@ -213,6 +220,10 @@ class ConvolutionVisitor {
   absl::flat_hash_map<HloInstruction*, std::vector<int64>>
       instr_to_dim_permute_map_;
 
+  // Map maintaining previously space-to-batched broadcasts.
+  absl::flat_hash_map<HloInstruction*, absl::flat_hash_set<HloInstruction*>>
+      broadcast_map_;
+
   // Whether rewrite has occurred.
   bool changed_ = false;
 
@@ -456,7 +467,7 @@ StatusOr<ConvolutionVisitor::SpaceNextToBatchDetails>
 ConvolutionVisitor::BringSpaceNextToBatch(
     HloInstruction* activations, ConvolutionDimensionNumbers& dim_numbers,
     int64& spatial_dimension_to_split, int64& activations_batch_dim,
-    bool is_backprop) {
+    bool is_backprop, bool is_rhs) {
   std::vector<int64> transpose_dims(activations->shape().rank());
   if (spatial_dimension_to_split == activations_batch_dim + 1) {
     absl::c_iota(transpose_dims, 0);
   }
@@ -465,49 +476,137 @@ ConvolutionVisitor::BringSpaceNextToBatch(
     int64 pushed_counter = 0;
     int64 new_batch_dim, new_spatial_dim;
     int64 dim_counter = 0;
-    for (int i = 0; i < activations->shape().rank(); ++i) {
-      if (i == activations_batch_dim) {
-        continue;
-      }
-      if (i == spatial_dimension_to_split) {
-        transpose_dims[dim_counter++] = activations_batch_dim;
-        new_batch_dim = pushed_counter;
-        pushed_counter++;
-        new_spatial_dim = pushed_counter;
-      }
+    if (is_rhs) {
+      CHECK(is_backprop);
+      for (int i = 0; i < activations->shape().rank(); ++i) {
+        if (i == activations_batch_dim) {
+          continue;
+        }
+        if (i == spatial_dimension_to_split) {
+          transpose_dims[dim_counter++] = activations_batch_dim;
+          new_batch_dim = pushed_counter;
+          pushed_counter++;
+          new_spatial_dim = pushed_counter;
+        }
 
-      if (is_backprop && i == dim_numbers.input_batch_dimension()) {
-        new_dim_numbers.set_input_batch_dimension(pushed_counter);
-      } else if (i == dim_numbers.input_feature_dimension()) {
-        new_dim_numbers.set_input_feature_dimension(pushed_counter);
-      } else {
-        for (int j = 0; j < dim_numbers.input_spatial_dimensions_size(); ++j) {
-          if (i == dim_numbers.input_spatial_dimensions(j)) {
-            new_dim_numbers.set_input_spatial_dimensions(j, pushed_counter);
-            break;
+        if (i == dim_numbers.kernel_output_feature_dimension()) {
+          new_dim_numbers.set_kernel_output_feature_dimension(pushed_counter);
+        } else {
+          auto it = absl::c_find(dim_numbers.kernel_spatial_dimensions(), i);
+          if (it != dim_numbers.kernel_spatial_dimensions().end()) {
+            int64 j = it - dim_numbers.kernel_spatial_dimensions().begin();
+            new_dim_numbers.set_kernel_spatial_dimensions(j, pushed_counter);
           }
         }
+        transpose_dims[dim_counter++] = i;
+        pushed_counter++;
       }
-      transpose_dims[dim_counter++] = i;
-      pushed_counter++;
-    }
 
-    activations_batch_dim = new_batch_dim;
-    spatial_dimension_to_split = new_spatial_dim;
-    TF_ASSIGN_OR_RETURN(activations,
-                        MakeTransposeHlo(activations, transpose_dims));
+      activations_batch_dim = new_batch_dim;
+      spatial_dimension_to_split = new_spatial_dim;
+      TF_ASSIGN_OR_RETURN(activations,
+                          MakeTransposeHlo(activations, transpose_dims));
+
+      new_dim_numbers.set_kernel_input_feature_dimension(activations_batch_dim);
 
-    if (is_backprop) {
-      new_dim_numbers.set_input_feature_dimension(activations_batch_dim);
     } else {
-      new_dim_numbers.set_input_batch_dimension(activations_batch_dim);
+      for (int i = 0; i < activations->shape().rank(); ++i) {
+        if (i == activations_batch_dim) {
+          continue;
+        }
+        if (i == spatial_dimension_to_split) {
+          transpose_dims[dim_counter++] = activations_batch_dim;
+          new_batch_dim = pushed_counter;
+          pushed_counter++;
+          new_spatial_dim = pushed_counter;
+        }
+
+        if (is_backprop && i == dim_numbers.input_batch_dimension()) {
+          new_dim_numbers.set_input_batch_dimension(pushed_counter);
+        } else if (i == dim_numbers.input_feature_dimension()) {
+          new_dim_numbers.set_input_feature_dimension(pushed_counter);
+        } else {
+          auto it = absl::c_find(dim_numbers.input_spatial_dimensions(), i);
+          if (it != dim_numbers.input_spatial_dimensions().end()) {
+            int64 j = it -
+                dim_numbers.input_spatial_dimensions().begin();
+            new_dim_numbers.set_input_spatial_dimensions(j, pushed_counter);
+          }
+        }
+        transpose_dims[dim_counter++] = i;
+        pushed_counter++;
+      }
+
+      activations_batch_dim = new_batch_dim;
+      spatial_dimension_to_split = new_spatial_dim;
+      TF_ASSIGN_OR_RETURN(activations,
+                          MakeTransposeHlo(activations, transpose_dims));
+
+      if (is_backprop) {
+        new_dim_numbers.set_input_feature_dimension(activations_batch_dim);
+      } else {
+        new_dim_numbers.set_input_batch_dimension(activations_batch_dim);
+      }
     }
+
     dim_numbers = new_dim_numbers;
   }
 
   return SpaceNextToBatchDetails{activations, transpose_dims};
 }
 
+StatusOr<HloInstruction*>
+ConvolutionVisitor::IncreaseSpatialSizeOnSpaceToBatchedShape(
+    HloInstruction* activations, int64 batch_dimension, int64 old_batch_size,
+    int64 spatial_dimension, int64 new_spatial_dim_size) {
+  CHECK_EQ(batch_dimension + 1, spatial_dimension);
+  std::vector<int64> new_dimensions(activations->shape().dimensions().begin(),
+                                    activations->shape().dimensions().end());
+
+  const int64 new_batch_size = activations->shape().dimensions(batch_dimension);
+  int64 spatial_dim_size = activations->shape().dimensions(spatial_dimension);
+  const int64 reshaped_space_size =
+      spatial_dim_size * new_batch_size / old_batch_size;
+
+  VLOG(3) << "Increasing the spatial size while propagating new_batch_size "
+          << new_batch_size << " old_batch_size " << old_batch_size;
+  new_dimensions[spatial_dimension] = reshaped_space_size;
+  new_dimensions[batch_dimension] = old_batch_size;
+
+  // Reshape the activations so that the batch size is brought back to
+  // old_batch_size.
+  TF_ASSIGN_OR_RETURN(HloInstruction * reshaped_activations,
+                      MakeReshapeHlo(new_dimensions, activations));
+
+  VLOG(3) << "First reshape done";
+  PaddingConfig padding_config =
+      MakeNoPaddingConfig(reshaped_activations->shape().dimensions_size());
+  padding_config.mutable_dimensions(spatial_dimension)
+      ->set_edge_padding_high(new_spatial_dim_size * new_batch_size /
+                                  old_batch_size -
+                              reshaped_space_size);
+  padding_config.mutable_dimensions(spatial_dimension)->set_edge_padding_low(0);
+  HloInstruction* padding =
+      computation_->AddInstruction(HloInstruction::CreateConstant(
+          LiteralUtil::Zero(reshaped_activations->shape().element_type())));
+
+  TF_ASSIGN_OR_RETURN(
+      reshaped_activations,
+      MakePadHlo(reshaped_activations, padding, padding_config));
+
+  std::vector<int64> reshape_back_dims(
+      reshaped_activations->shape().dimensions().begin(),
+      reshaped_activations->shape().dimensions().end());
+
+  reshape_back_dims[spatial_dimension] = new_spatial_dim_size;
+  reshape_back_dims[batch_dimension] = new_batch_size;
+
+  TF_ASSIGN_OR_RETURN(HloInstruction * activations_new,
+                      MakeReshapeHlo(reshape_back_dims, reshaped_activations));
+
+  VLOG(3) << "Size increased activations " << activations_new->ToString();
+
+  return activations_new;
+}
+
 StatusOr<bool> ConvolutionVisitor::Run() {
   for (auto conv : conv_visitor_list_) {
     if (convs_to_visit_.count(conv) > 0) {
@@ -519,6 +618,29 @@ StatusOr<bool> ConvolutionVisitor::Run() {
   // Iterate through all instructions that we could not propagate through, and
   // turn their operands from batch-to-space as needed.
   for (auto instr : non_propagatable_instrs_) {
+    if (instr->opcode() == HloOpcode::kConvolution) {
+      VLOG(1) << "Instr " << instr->ToString();
+    }
+    // Try to propagate on backprop filters.
+    if (instr->opcode() == HloOpcode::kConvolution &&
+        !IsConvSuitableForSpaceToBatch(instr)) {
+      HloInstruction* producer = nullptr;
+      if (old_to_new_instrs_.contains(instr->mutable_operand(0))) {
+        producer = instr->mutable_operand(0);
+      } else if (old_to_new_instrs_.contains(instr->mutable_operand(1))) {
+        producer = instr->mutable_operand(1);
+      }
+      if (producer) {
+        if (CanPropagate(instr, producer, /*last_try=*/true)) {
+          bool needs_further_propagation;
+          TF_ASSIGN_OR_RETURN(needs_further_propagation,
+                              Propagate(instr, producer));
+          TF_CHECK_OK(computation_->ReplaceInstruction(
+              instr, old_to_new_instrs_[instr]));
+          continue;
+        }
+      }
+    }
     VLOG(1) << "Could not eventually propagate through " << instr->ToString();
     absl::flat_hash_map<int64, HloInstruction*> operand_map;
     for (int64 i = 0; i < instr->operand_count(); ++i) {
@@ -539,14 +661,14 @@ bool IsTrivialElementwise(HloInstruction* hlo) {
   if (hlo->opcode() == HloOpcode::kFusion || hlo->opcode() == HloOpcode::kRng ||
      hlo->opcode() == HloOpcode::kCopy ||
      hlo->opcode() == HloOpcode::kConstant ||
-      hlo->opcode() == HloOpcode::kIota) {
+      hlo->opcode() == HloOpcode::kIota || hlo->opcode() == HloOpcode::kMap) {
     return false;
   }
   return hlo->IsElementwise();
 }
 
 bool ConvolutionVisitor::CanPropagate(HloInstruction* consumer,
-                                      HloInstruction* producer) {
+                                      HloInstruction* producer, bool last_try) {
   if (IsTrivialElementwise(consumer)) {
     VLOG(2) << "Doing propagation check on elementwise op: "
             << consumer->ToString();
@@ -604,7 +726,8 @@ bool ConvolutionVisitor::CanPropagate(HloInstruction* consumer,
           // Make sure all other dimensions are of the same size.
           if (pivot_new_instr->shape().dimensions(j) !=
               new_instr->shape().dimensions(j)) {
-            if (!(consumer->IsElementwiseBinary() &&
+            if (!((consumer->IsElementwiseBinary() ||
+                   consumer->opcode() == HloOpcode::kSelect) &&
                   j == instr_to_dim_map_[pivot_operand].second)) {
              VLOG(2) << "Elementwise op: checking for shape equivalence "
                      << consumer->ToString()
@@ -653,20 +776,70 @@ bool ConvolutionVisitor::CanPropagate(HloInstruction* consumer,
       return false;
     }
 
+    // Same reason why we give up on batch group counts applies to features in
+    // backprop.
+    if (consumer->feature_group_count() != 1) {
+      return false;
+    }
+
+    VLOG(2) << "Checking for backprop filter conv propagatability";
     CHECK_EQ(consumer->operand_count(), 2);
 
-    VLOG(2) << "Checking for backprop filter conv operands "
-            << consumer->operand_count();
     auto activations = consumer->mutable_operand(0);
     auto kernel = consumer->mutable_operand(1);
 
-    if (!old_to_new_instrs_.contains(kernel)) {
-      VLOG(2) << "Backprop filter conv not ready for propagation because of "
-                 "kernel is not space-to-batched";
+    if (!last_try) {
+      if (!old_to_new_instrs_.contains(kernel) ||
+          !old_to_new_instrs_.contains(activations)) {
+        return false;
+      }
+    }
+
+    if (!old_to_new_instrs_.contains(kernel) &&
+        !old_to_new_instrs_.contains(activations)) {
       return false;
     }
 
+    const int64 rhs_dilation = consumer->window()
+                                   .dimensions(get_chosen_spatial_dim(consumer))
+                                   .window_dilation();
+
+    if (!old_to_new_instrs_.contains(kernel)) {
+      const int64 rhs_batch =
+          kernel->shape().dimensions(consumer->convolution_dimension_numbers()
+                                         .kernel_input_feature_dimension());
+      auto dim_map_val_op_0 = instr_to_dim_map_[activations];
+      const int64 old_batch_dim = dim_map_val_op_0.first;
+      const int64 old_space_dim = dim_map_val_op_0.second;
+      auto first_operand = old_to_new_instrs_[activations];
+      auto permute_dims_first_operand =
+          instr_to_dim_permute_map_[first_operand];
+      const int64 new_batch_dim =
+          DimLookUp(permute_dims_first_operand, old_batch_dim);
+      const int64 new_space_dim =
+          DimLookUp(permute_dims_first_operand, old_space_dim);
+      const int64 lhs_batch = first_operand->shape().dimensions(new_batch_dim);
+
+      if (first_operand->shape().dimensions(new_space_dim) % rhs_dilation !=
+          0) {
+        return false;
+      }
+      // Because we want to convert activations into a space-to-batched version
+      // only for backprop filter convolutions, we want to make sure that the
+      // batch dimensions (feature dimensions, technically) are same sized.
+      // Since LHS is already space-to-batched, we need to account for it too.
+      if (rhs_batch * kNumSplits != lhs_batch) {
+        return false;
+      }
+
+      // If the kernel has not been propagated through, we can still
+      // space-to-batch it, provided the activations have been propagated.
+      VLOG(2)
+          << "Backprop filter conv ready for propagation: activations ready, "
+             "kernel will be space-to-batched";
+      return true;
+    }
+
     if (!old_to_new_instrs_.contains(activations)) {
       const int64 lhs_batch = activations->shape().dimensions(
           consumer->convolution_dimension_numbers().input_feature_dimension());
@@ -720,10 +893,7 @@ bool ConvolutionVisitor::CanPropagate(HloInstruction* consumer,
       return false;
     }
 
-    const int64 rhs_dilation = consumer->window()
-                                   .dimensions(get_chosen_spatial_dim(consumer))
-                                   .window_dilation();
-    if (first_operand->shape().dimensions(new_space_dim_operand_0) !=
+    if (first_operand->shape().dimensions(new_space_dim_operand_0) >
         rhs_dilation *
             second_operand->shape().dimensions(new_space_dim_operand_1)) {
       VLOG(2) << "Backprop filter conv not ready for propagation because of "
@@ -816,20 +986,57 @@ void ConvolutionVisitor::PropagateOnBroadcast(HloInstruction* consumer,
   auto new_producer = old_to_new_instrs_[producer];
   auto permute_dims = instr_to_dim_permute_map_[new_producer];
   auto dim_map_val = instr_to_dim_map_[producer];
+
+  const int64 old_batch_dim = dim_map_val.first;
+  const int64 old_space_dim = dim_map_val.second;
+
+  auto orig_broadcast_dims = consumer->dimensions();
+
+  bool batch_is_broadcasted =
+      absl::c_linear_search(orig_broadcast_dims, old_batch_dim);
+  const int64 new_batch_dim = DimLookUp(permute_dims, old_batch_dim);
+  const int64 new_space_dim = DimLookUp(permute_dims, old_space_dim);
+
+  bool map_found = broadcast_map_.contains(consumer);
+  if (map_found) {
+    // Check if we previously had created the same broadcast.
+    for (auto previous_broadcast : broadcast_map_[consumer]) {
+      if (ShapeUtil::CompatibleIgnoringElementType(previous_broadcast->shape(),
+                                                   new_producer->shape())) {
+        return;
+      }
+    }
+  }
+
+  std::vector<int64> final_shape_dims(
+      new_producer->shape().dimensions().begin(),
+      new_producer->shape().dimensions().end());
+  if (batch_is_broadcasted) {
+    final_shape_dims[new_batch_dim] =
+        producer->shape().dimensions(old_batch_dim);
+    final_shape_dims[new_space_dim] *= kNumSplits;
+  }
+
   std::vector<int64> broadcast_dims;
   for (auto j : consumer->dimensions()) {
     broadcast_dims.push_back(DimLookUp(permute_dims, j));
   }
-  auto new_broadcast =
-      MakeBroadcastHlo(consumer->mutable_operand(0), broadcast_dims,
-                       new_producer->shape().dimensions());
+  auto new_broadcast = MakeBroadcastHlo(consumer->mutable_operand(0),
+                                        broadcast_dims, final_shape_dims);
   VLOG(1) << "Created broadcast " << new_broadcast->ToString();
 
-  // Pass on the permutation information from the producer.
-  old_to_new_instrs_[consumer] = new_broadcast;
-  instr_to_dim_map_[consumer] = dim_map_val;
-  instr_to_dim_permute_map_[new_broadcast] = std::vector<int64>(
-      instr_to_dim_permute_map_[old_to_new_instrs_[producer]]);
+  if (batch_is_broadcasted) {
+    new_broadcast =
+        MakeReshapeHlo(new_producer->shape().dimensions(), new_broadcast)
+            .ValueOrDie();
+    VLOG(2) << "Created reshape of broadcast " << new_broadcast->ToString();
+  }
+
+  if (!map_found) {
+    absl::flat_hash_set<HloInstruction*> set_of_broadcasts;
+    broadcast_map_[consumer] = set_of_broadcasts;
+  }
+  broadcast_map_[consumer].insert(new_broadcast);
 }
 
 void ConvolutionVisitor::RewriteBroadcastTree(
@@ -882,11 +1089,9 @@ bool ConvolutionVisitor::IsBroadcastPropagatable(HloInstruction* broadcast,
   CHECK(instr_to_dim_map_.contains(old_other_op));
 
   auto result = instr_to_dim_map_[old_other_op];
-  const int64 batch_dim = result.first;
   const int64 space_dim = result.second;
   auto broadcast_dims = broadcast->dimensions();
-  return !absl::c_linear_search(broadcast_dims, batch_dim) &&
-         !absl::c_linear_search(broadcast_dims, space_dim);
+  return !absl::c_linear_search(broadcast_dims, space_dim);
 }
 
 bool ConvolutionVisitor::SupportedOpForPropagation(HloInstruction* consumer,
@@ -990,27 +1195,50 @@ StatusOr<bool> ConvolutionVisitor::Propagate(HloInstruction* consumer,
   // For elementwise binary ops, both of whose operands have been space-to-
   // batched, if their new spatial sizes don't match, choose the bigger one
   // as the producer.
-  if (consumer->IsElementwiseBinary() &&
-      old_to_new_instrs_.contains(consumer->mutable_operand(0)) &&
-      old_to_new_instrs_.contains(consumer->mutable_operand(1))) {
-    is_pivot_producer_modified = true;
-    if (old_to_new_instrs_[consumer->mutable_operand(0)]
-            ->shape()
-            .dimensions() > old_to_new_instrs_[consumer->mutable_operand(1)]
-                                ->shape()
-                                .dimensions()) {
-      producer = consumer->mutable_operand(0);
-    } else {
-      producer = consumer->mutable_operand(1);
+  if (consumer->IsElementwiseBinary() ||
+      consumer->opcode() == HloOpcode::kSelect) {
+    int64 pivot_operand_number = -1;
+    HloInstruction* pivot_operand = nullptr;
+    for (int i = 0; i < consumer->operand_count(); ++i) {
+      if (consumer->operand(i)->opcode() == HloOpcode::kBroadcast) {
+        continue;
+      }
+      auto operand = consumer->mutable_operand(i);
+      if (old_to_new_instrs_.contains(operand)) {
+        if (pivot_operand_number == -1 ||
+            old_to_new_instrs_[pivot_operand]->shape().dimensions() <
+                old_to_new_instrs_[operand]->shape().dimensions()) {
+          is_pivot_producer_modified = true;
+          pivot_operand_number = i;
+          pivot_operand = consumer->mutable_operand(pivot_operand_number);
+        }
+      }
+    }
+    if (pivot_operand_number != -1) {
+      producer = pivot_operand;
     }
   }
 
   for (int64 i = 0; i < consumer->operand_count(); ++i) {
     std::vector<HloInstruction*> instructions_to_transform;
 
-    if (old_to_new_instrs_.contains(consumer->mutable_operand(i))) {
+    if (consumer->operand(i)->opcode() == HloOpcode::kBroadcast) {
+      auto broadcast = consumer->mutable_operand(i);
+      PropagateOnBroadcast(broadcast, producer);
+      HloInstruction* new_broadcast = nullptr;
+      auto new_producer = old_to_new_instrs_[producer];
+      for (auto previous_broadcast : broadcast_map_[broadcast]) {
+        if (ShapeUtil::CompatibleIgnoringElementType(
+                previous_broadcast->shape(), new_producer->shape())) {
+          new_broadcast = previous_broadcast;
+          break;
+        }
+      }
+      CHECK_NE(new_broadcast, nullptr);
+      TF_CHECK_OK(
+          new_consumer->ReplaceOperandWithDifferentShape(i, new_broadcast));
+    } else if (old_to_new_instrs_.contains(consumer->mutable_operand(i))) {
      HloInstruction* operand_to_use = nullptr;
-
      auto result = instr_to_dim_map_[producer];
      const int64 old_batch_dim = result.first;
      const int64 old_space_dim = result.second;
@@ -1070,13 +1298,6 @@ StatusOr<bool> ConvolutionVisitor::Propagate(HloInstruction* consumer,
      }
      TF_CHECK_OK(
          new_consumer->ReplaceOperandWithDifferentShape(i, operand_to_use));
-    } else if (consumer->operand(i)->opcode() == HloOpcode::kBroadcast) {
-      if (!old_to_new_instrs_.contains(consumer->operand(i))) {
-        PropagateOnBroadcast(consumer->mutable_operand(i), producer);
-      }
-      auto new_broadcast = old_to_new_instrs_[consumer->mutable_operand(i)];
-      TF_CHECK_OK(
-          new_consumer->ReplaceOperandWithDifferentShape(i, new_broadcast));
    } else if (consumer->IsElementwiseBinary() &&
               IsBroadcastTree(consumer->mutable_operand(i), producer,
                               instructions_to_transform)) {
@@ -1463,8 +1684,10 @@ StatusOr<HloInstruction*> ConvolutionVisitor::SelectValidPortion(
 StatusOr<HloInstruction*> ConvolutionVisitor::BatchToSpace(
     HloInstruction* old_instr) {
   if (batch_to_space_map_.count(old_instr)) {
+    CHECK_NE(batch_to_space_map_[old_instr], nullptr);
     return batch_to_space_map_[old_instr];
   }
+
   auto result = instr_to_dim_map_[old_instr];
   const int64 old_batch_dim = result.first;
   const int64 old_space_dim = result.second;
@@ -1683,68 +1906,28 @@ Status ConvolutionVisitor::PropagateOnConv(HloInstruction* convolution) {
   VLOG(1) << "spatial_split_size " << spatial_split_size << " slice_size "
           << slice_size;
 
-  const int64 new_batch_size =
-      activations_new->shape().dimensions(activations_batch_dim);
   const int64 new_space_size =
       activations_new->shape().dimensions(c.spatial_dimension_to_split);
 
   // In the below case, we cannot use the activations directly for Halo
   // Duplication. We must reshape them.
   if (spatial_split_size > new_space_size) {
-    std::vector<int64> new_dimensions(
-        activations_new->shape().dimensions().begin(),
-        activations_new->shape().dimensions().end());
-    const int64 reshaped_space_size =
-        new_space_size * new_batch_size / old_batch_size;
-    VLOG(3) << "Increasing the spatial size while propagating new_batch_size "
-            << new_batch_size << " old_batch_size " << old_batch_size;
-    new_dimensions[c.spatial_dimension_to_split] = reshaped_space_size;
-    new_dimensions[activations_batch_dim] = old_batch_size;
-
-    // Reshape the output of the new conv into the old convolutions shape.
-    TF_ASSIGN_OR_RETURN(HloInstruction * reshaped_activations,
-                        MakeReshapeHlo(new_dimensions, activations_new));
-
-    VLOG(3) << "First reshape done";
-    PaddingConfig padding_config =
-        MakeNoPaddingConfig(reshaped_activations->shape().dimensions_size());
-    padding_config.mutable_dimensions(c.spatial_dimension_to_split)
-        ->set_edge_padding_high(spatial_split_size * new_batch_size /
-                                    old_batch_size -
-                                reshaped_space_size);
-    padding_config.mutable_dimensions(c.spatial_dimension_to_split)
-        ->set_edge_padding_low(0);
-    HloInstruction* padding =
-        computation_->AddInstruction(HloInstruction::CreateConstant(
-            LiteralUtil::Zero(reshaped_activations->shape().element_type())));
-    TF_ASSIGN_OR_RETURN(
-        reshaped_activations,
-        MakePadHlo(reshaped_activations, padding, padding_config));
-
-    std::vector<int64> reshape_back_dims(
-        reshaped_activations->shape().dimensions().begin(),
-        reshaped_activations->shape().dimensions().end());
-
-    reshape_back_dims[c.spatial_dimension_to_split] = spatial_split_size;
-    reshape_back_dims[activations_batch_dim] = new_batch_size;
-
-    TF_ASSIGN_OR_RETURN(
-        reshaped_activations,
-        MakeReshapeHlo(reshape_back_dims, reshaped_activations));
-
-    VLOG(3) << "Second reshape done";
+    TF_ASSIGN_OR_RETURN(
+        activations_new,
+        IncreaseSpatialSizeOnSpaceToBatchedShape(
+            activations_new, activations_batch_dim, old_batch_size,
+            c.spatial_dimension_to_split, spatial_split_size));
 
     TF_ASSIGN_OR_RETURN(
         activations_new,
-        HaloDuplicateWithSlice(
-            reshaped_activations, c.spatial_dimension_to_split,
-            activations_batch_dim, old_batch_size,
-            /*low_padding=*/c.base_dilation_factor != 1 &&
-                    c.inherent_low_padding != 0
-                ? c.base_dilation_factor - 1
-                : c.inherent_low_padding,
-            c.inherent_high_padding, slice_size - spatial_split_size,
-            old_split_dim_size));
+        HaloDuplicateWithSlice(activations_new, c.spatial_dimension_to_split,
+                               activations_batch_dim, old_batch_size,
+                               /*low_padding=*/c.base_dilation_factor != 1 &&
+                                       c.inherent_low_padding != 0
+                                   ? c.base_dilation_factor - 1
+                                   : c.inherent_low_padding,
+                               c.inherent_high_padding,
+                               slice_size - spatial_split_size,
+                               old_split_dim_size));
  } else {
    // If the ideal spatial_split_size was smaller than the incoming spatial
    // dimension size, we don't need reshaping. Instead, we determine the
@@ -1835,14 +2018,15 @@ ConvolutionVisitor::SplitSpace(HloInstruction* activations,
                                int64& spatial_dimension_to_split,
                                int64& activations_batch_dim, int64 high_padding,
                                int64 low_padding, int64 spatial_split_size,
-                               int64 num_splits, bool is_backprop) {
+                               int64 num_splits, bool is_backprop,
+                               bool is_rhs) {
   const int64 old_batch_size =
       activations->shape().dimensions(activations_batch_dim);
 
-  TF_ASSIGN_OR_RETURN(
-      auto retval, BringSpaceNextToBatch(activations, dim_numbers,
-                                         spatial_dimension_to_split,
-                                         activations_batch_dim, is_backprop));
+  TF_ASSIGN_OR_RETURN(auto retval,
+                      BringSpaceNextToBatch(
+                          activations, dim_numbers, spatial_dimension_to_split,
+                          activations_batch_dim, is_backprop, is_rhs));
   activations = retval.instr;
   std::vector<int64> transpose_dims = retval.transpose_dims;
@@ -1903,7 +2087,7 @@ Status ConvolutionVisitor::PropagateOnBackpropFilterConv(
           .window_dilation();
 
   auto original_conv_dims = convolution->convolution_dimension_numbers();
-  const int64 kernel_space_dim = original_conv_dims.kernel_spatial_dimensions(
+  int64 kernel_space_dim = original_conv_dims.kernel_spatial_dimensions(
       get_chosen_spatial_dim(convolution));
   auto kernel_old = convolution->mutable_operand(1);
   const int64 old_kernel_split_dim_size =
@@ -1914,18 +2098,24 @@ Status ConvolutionVisitor::PropagateOnBackpropFilterConv(
   int64 old_split_dim_size =
       activations_old->shape().dimensions(old_space_dim);
 
   int64 old_batch_dim = original_conv_dims.input_feature_dimension();
+  int64 kernel_old_batch_dim =
+      original_conv_dims.kernel_input_feature_dimension();
   const int64 old_batch_size =
       activations_old->shape().dimensions(old_batch_dim);
 
-  CHECK(old_to_new_instrs_.contains(kernel_old));
-  auto kernel_new = old_to_new_instrs_[kernel_old];
-
-  auto permute_dims_kernel = instr_to_dim_permute_map_[kernel_new];
+  CHECK(old_to_new_instrs_.contains(kernel_old) ||
+        old_to_new_instrs_.contains(activations_old));
 
   HloInstruction* activations_new = nullptr;
+  HloInstruction* kernel_new = nullptr;
   bool activations_locally_space_to_batched = false;
+  bool kernel_locally_space_to_batched = false;
+  std::vector<int64> permute_dims_kernel, permute_dims;
 
   // If activations were not space-to-batched, we space-to-batch them below.
   if (!old_to_new_instrs_.contains(activations_old)) {
+    kernel_new = old_to_new_instrs_[kernel_old];
+    permute_dims_kernel = instr_to_dim_permute_map_[kernel_new];
+
     VLOG(1) << "Space-to-batching activations to enable space-to-depth";
 
     const int64 prev_feature_dim = original_conv_dims.input_feature_dimension();
     const int64 prev_batch_dim = original_conv_dims.input_batch_dimension();
@@ -1960,14 +2150,63 @@ Status ConvolutionVisitor::PropagateOnBackpropFilterConv(
     VLOG(3) << "New Activations " << retval.first->ToString();
 
     activations_locally_space_to_batched = true;
+  } else if (!old_to_new_instrs_.contains(kernel_old)) {
+    activations_new = old_to_new_instrs_[activations_old];
+    permute_dims = instr_to_dim_permute_map_[activations_new];
+
+    VLOG(1) << "Space-to-batching kernel to enable space-to-depth";
+
+    const int64 prev_feature_dim =
+        original_conv_dims.kernel_input_feature_dimension();
+    const int64 prev_output_feature_dim =
+        original_conv_dims.kernel_output_feature_dimension();
+    // TODO(b/168316428): The instr_to_dim_map_ is set incorrectly here, but it
+    // doesn't matter since it is never used. Investigate further to see if
+    // just not setting it works.
+    instr_to_dim_map_[kernel_old] =
+        std::make_pair(prev_feature_dim, prev_output_feature_dim);
+
+    const int64 new_space_dim = DimLookUp(permute_dims, old_space_dim);
+    const int64 new_split_dim_size =
+        activations_new->shape().dimensions(new_space_dim);
+    const int64 needed_spatial_size =
+        CeilOfRatio(new_split_dim_size, rhs_dilation);
+    int64 old_kernel_split_dim_size =
+        kernel_old->shape().dimensions(kernel_space_dim);
+    const int64 pad_size =
+        needed_spatial_size * kNumSplits - old_kernel_split_dim_size;
+
+    ConvolutionDimensionNumbers tmp_dim_numbers;
+    tmp_dim_numbers = original_conv_dims;
+    TF_ASSIGN_OR_RETURN(
+        auto retval, SplitSpace(kernel_old, tmp_dim_numbers, kernel_space_dim,
+                                kernel_old_batch_dim,
+                                /*high_padding=*/pad_size, /*low_padding=*/0,
+                                needed_spatial_size, kNumSplits,
+                                /*is_backprop=*/true, /*is_rhs=*/true));
+
+    old_to_new_instrs_[kernel_old] = retval.first;
+
+    std::vector<int64> reversed_transpose_dims(retval.second.size());
+    for (int64 i = 0; i < retval.second.size(); ++i) {
+      reversed_transpose_dims[i] = ReverseDimLookUp(retval.second, i);
+    }
+    instr_to_dim_permute_map_[retval.first] = reversed_transpose_dims;
+
+    VLOG(3) << "New kernel " << retval.first->ToString();
+
+    kernel_locally_space_to_batched = true;
   }
 
   CHECK(old_to_new_instrs_.contains(activations_old));
+  CHECK(old_to_new_instrs_.contains(kernel_old));
   activations_new = old_to_new_instrs_[activations_old];
+  kernel_new = old_to_new_instrs_[kernel_old];
   const int64 new_spatial_dimension =
       activations_new->shape().dimensions_size();
-  auto permute_dims = instr_to_dim_permute_map_[activations_new];
+  permute_dims = instr_to_dim_permute_map_[activations_new];
+  permute_dims_kernel = instr_to_dim_permute_map_[kernel_new];
+
   auto permuted_conv_dims_numbers = original_conv_dims;
 
   // Note the inversion here : batch and feature are inverted in backprop
@@ -2024,9 +2263,12 @@ Status ConvolutionVisitor::PropagateOnBackpropFilterConv(
       permuted_conv_dims_numbers.kernel_spatial_dimensions(
          get_chosen_spatial_dim(convolution));
 
-  const int64 new_split_dim_size =
+  int64 new_split_dim_size =
      activations_new->shape().dimensions(spatial_dimension_to_split);
 
+  const int64 kernel_new_split_dim_size =
+      kernel_new->shape().dimensions(kernel_spatial_dimension_to_split);
+
   permuted_conv_dims_numbers.set_input_batch_dimension(activations_feature_dim);
   permuted_conv_dims_numbers.set_input_feature_dimension(activations_batch_dim);
@@ -2040,6 +2282,7 @@ Status ConvolutionVisitor::PropagateOnBackpropFilterConv(
       BringSpaceNextToBatch(activations_new, permuted_conv_dims_numbers,
                             spatial_dimension_to_split, activations_batch_dim,
                             /*is_backprop=*/true));
+
   std::vector<int64> transpose_dims = retval.transpose_dims;
   CHECK(!transpose_dims.empty());
   activations_new = retval.instr;
@@ -2048,6 +2291,17 @@ Status ConvolutionVisitor::PropagateOnBackpropFilterConv(
           << activations_new->ToString();
   VLOG(1) << "activations_batch_dim " << activations_batch_dim
           << " activations_feature_dim " << activations_feature_dim;
+  const int64 expected_split_dim_size =
+      rhs_dilation * kernel_new_split_dim_size;
+  if (new_split_dim_size != expected_split_dim_size) {
+    CHECK_LT(new_split_dim_size, expected_split_dim_size);
+    new_split_dim_size = expected_split_dim_size;
+    TF_ASSIGN_OR_RETURN(
+        activations_new,
+        IncreaseSpatialSizeOnSpaceToBatchedShape(
+            activations_new, activations_batch_dim, old_batch_size,
+            spatial_dimension_to_split, new_split_dim_size));
+  }
 
   auto select_val = computation_->AddInstruction(HloInstruction::CreateConstant(
       LiteralUtil::Zero(activations_new->shape().element_type())));
@@ -2060,16 +2314,18 @@ Status ConvolutionVisitor::PropagateOnBackpropFilterConv(
                            activations_batch_dim, spatial_dimension_to_split,
                            old_batch_dim, old_space_dim));
   }
-  VLOG(3) << "Selecting the valid kernel area";
-  // Select kernel correctly by masking additional space.
-  TF_ASSIGN_OR_RETURN(
-      kernel_new,
-      SelectValidPortion(
-          kernel_new, kernel_old, select_val,
-          /*new_batch_dim=*/kernel_input_feature_dim,
-          kernel_spatial_dimension_to_split,
-          /*old_batch_dim=*/original_conv_dims.kernel_input_feature_dimension(),
-          kernel_space_dim));
+  if (!kernel_locally_space_to_batched) {
+    VLOG(3) << "Selecting the valid kernel area";
+    // Select kernel correctly by masking additional space.
+    TF_ASSIGN_OR_RETURN(
+        kernel_new,
+        SelectValidPortion(kernel_new, kernel_old, select_val,
+                           /*new_batch_dim=*/kernel_input_feature_dim,
+                           kernel_spatial_dimension_to_split,
+                           /*old_batch_dim=*/
+                           original_conv_dims.kernel_input_feature_dimension(),
+                           kernel_space_dim));
+  }
 
   // Create the new convolution dim numbers.
   auto new_dim_numbers = permuted_conv_dims_numbers;
@@ -2115,7 +2371,9 @@ Status ConvolutionVisitor::PropagateOnBackpropFilterConv(
       old_split_dim_size - expanded_kernel + 1 +
       (inherent_low_padding < 0 ? inherent_low_padding : 0) +
       (inherent_high_padding < 0 ? inherent_high_padding : 0);
-  VLOG(1) << "overlap_count " << overlap_count;
+  VLOG(1) << "overlap_count " << overlap_count << " inherent_low_padding "
+          << inherent_low_padding << " inherent_high_padding "
+          << inherent_high_padding;
 
   // Insert original activations.
   for (int64 i = 0; i < overlap_count; ++i) {
@@ -2190,7 +2448,7 @@ Status ConvolutionVisitor::PropagateOnBackpropFilterConv(
   new_window.mutable_dimensions(get_chosen_spatial_dim(convolution))
       ->set_padding_low(0);
   new_window.mutable_dimensions(get_chosen_spatial_dim(convolution))
-      ->set_size(new_split_dim_size / rhs_dilation);
+      ->set_size(CeilOfRatio(new_split_dim_size, rhs_dilation));
 
   // Set the window for the additional spatial dim. This is a vanilla window.
   auto window_dim = new_window.add_dimensions();
@@ -2211,6 +2469,8 @@ Status ConvolutionVisitor::PropagateOnBackpropFilterConv(
           /*preferred_element_type=*/convolution->shape().element_type()));
   convolution->SetupDerivedInstruction(new_conv);
 
+  VLOG(2) << "New backprop filter convolution " << new_conv->ToString();
+
   std::vector<int64> output_sizes(new_conv->shape().dimensions().begin(),
                                   new_conv->shape().dimensions().end());
 
@@ -2364,8 +2624,8 @@ Status ConvolutionVisitor::PerformSpaceToBatchOnConvolution(
       DoesConvolutionFeedReduceWindowOrSelectAndScatter(convolution);
 
   if (reduce_window_or_select_and_scatter != nullptr) {
-    VLOG(2) << "DoesConvolutionFeedReduceWindowOrSelectAndScatter "
-            << reduce_window_or_select_and_scatter;
+    VLOG(2)
+        << "DoesConvolutionFeedReduceWindowOrSelectAndScatter returned true";
     // Take into account the stride of the reduce window while choosing the
     // spatial_split_size. This will guarantee propagation through reduce
     // windows.
@@ -2457,6 +2717,11 @@ Status ConvolutionVisitor::PerformSpaceToBatchOnConvolution(
           /*preferred_element_type=*/convolution->shape().element_type()));
   convolution->SetupDerivedInstruction(new_conv);
 
+  // If the activations were to be batch-to-spaced again, simply use the
+  // original value.
+  batch_to_space_map_[convolution->mutable_operand(0)] =
+      convolution->mutable_operand(0);
+
   VLOG(1) << "Space-to-batched convolution " << new_conv->ToString();
 
   const int64 output_split_spatial_dim =
@@ -2483,7 +2748,9 @@ Status ConvolutionVisitor::PerformSpaceToBatchOnConvolution(
           get_chosen_spatial_dim(original_conv)));
 
   instr_to_dim_permute_map_[new_conv] = std::vector<int64>(transpose_dims);
-
+  if (non_propagatable_instrs_.count(convolution) > 0) {
+    non_propagatable_instrs_.erase(convolution);
+  }
   TF_CHECK_OK(PropagateOnUsers(original_conv));
 
   changed_ = true;
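For context, the reshape→pad→reshape trick that the new `IncreaseSpatialSizeOnSpaceToBatchedShape` helper factors out can be illustrated with plain shape arithmetic. The sketch below is illustrative only: it manipulates raw dimension vectors rather than HLO instructions, and the function and variable names are ours, not XLA's. It shows how a space-to-batched shape is first folded back to the old batch size (merging the split factor into one long spatial dimension), padded high on that spatial dimension, and then unfolded so each split piece reaches `new_spatial_dim_size`.

```cpp
// Illustrative-only sketch of the shape arithmetic performed by
// IncreaseSpatialSizeOnSpaceToBatchedShape; works on dimension vectors
// instead of HLO. Helper name and parameters are hypothetical.
#include <cassert>
#include <cstdint>
#include <iostream>
#include <vector>

std::vector<int64_t> IncreaseSpatialSize(std::vector<int64_t> dims,
                                         int batch_dim, int64_t old_batch_size,
                                         int spatial_dim,
                                         int64_t new_spatial_dim_size) {
  // The split spatial dimension must sit right next to the batch dimension.
  assert(spatial_dim == batch_dim + 1);
  const int64_t new_batch_size = dims[batch_dim];
  const int64_t spatial_size = dims[spatial_dim];

  // Step 1: reshape [new_batch, space] -> [old_batch, space * splits],
  // recombining the split batch factor into one long spatial dimension.
  const int64_t reshaped_space = spatial_size * new_batch_size / old_batch_size;
  dims[batch_dim] = old_batch_size;
  dims[spatial_dim] = reshaped_space;

  // Step 2: zero-pad the merged spatial dimension on the high side so that,
  // once re-split, each piece has new_spatial_dim_size elements. In the real
  // pass, MakePadHlo inserts the zero values here.
  const int64_t padded_space =
      new_spatial_dim_size * new_batch_size / old_batch_size;
  dims[spatial_dim] = padded_space;

  // Step 3: reshape back to [new_batch, new_spatial_dim_size].
  dims[batch_dim] = new_batch_size;
  dims[spatial_dim] = new_spatial_dim_size;
  return dims;
}

int main() {
  // A shape that was split 8 ways: batch 1 -> 8, spatial size 10 per split.
  std::vector<int64_t> dims = {8, 10, 128};
  dims = IncreaseSpatialSize(dims, /*batch_dim=*/0, /*old_batch_size=*/1,
                             /*spatial_dim=*/1, /*new_spatial_dim_size=*/12);
  for (int64_t d : dims) std::cout << d << " ";  // prints: 8 12 128
  std::cout << "\n";
}
```

Padding in the merged layout places the zeros at the end of the overall spatial extent rather than after each split piece, which is presumably why the helper cannot simply pad the space-to-batched shape directly; the subsequent halo duplication then restores the per-split overlap.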