From ede138563636c4db03fa915efed5e4627f099da5 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 6 Oct 2020 09:36:47 -0700 Subject: [PATCH] Extend space to batch to apply to larger batch sizes PiperOrigin-RevId: 335657911 Change-Id: I22a9f7f978b7d64bf09654771036d044d3d3ef41 --- .../xla/service/space_to_batch_converter.cc | 87 +++++++++++++------ .../xla/service/space_to_batch_converter.h | 5 +- .../service/space_to_batch_converter_test.cc | 39 ++++++--- 3 files changed, 90 insertions(+), 41 deletions(-) diff --git a/tensorflow/compiler/xla/service/space_to_batch_converter.cc b/tensorflow/compiler/xla/service/space_to_batch_converter.cc index 05cbe137a1e..47aee8ed5a8 100644 --- a/tensorflow/compiler/xla/service/space_to_batch_converter.cc +++ b/tensorflow/compiler/xla/service/space_to_batch_converter.cc @@ -52,7 +52,7 @@ class ConvolutionVisitor : public DfsHloVisitorWithDefault { Status HandleConvolution(HloInstruction* convolution) override; // Runs the visitor on a computation. - static bool Run(HloComputation* computation); + static bool Run(int64 limit_on_batch_size, HloComputation* computation); // Returns whether any convolution ops were rewritten. const bool changed() const { return changed_; } @@ -60,18 +60,23 @@ class ConvolutionVisitor : public DfsHloVisitorWithDefault { ~ConvolutionVisitor() override = default; private: - explicit ConvolutionVisitor(HloComputation* computation) - : computation_(computation) {} + explicit ConvolutionVisitor(int64 limit_on_batch_size, + HloComputation* computation) + : computation_(computation), limit_on_batch_size_(limit_on_batch_size) {} // Current HloComputation instance the ConvolutionVisitor is traversing. HloComputation* computation_; // Whether rewrite has occurred. bool changed_ = false; + + // Limit on batch size to apply this technique on. 
+ int64 limit_on_batch_size_; }; -bool ConvolutionVisitor::Run(HloComputation* computation) { - ConvolutionVisitor visitor(computation); +bool ConvolutionVisitor::Run(int64 limit_on_batch_size, + HloComputation* computation) { + ConvolutionVisitor visitor(limit_on_batch_size, computation); TF_CHECK_OK(computation->Accept(&visitor)); return visitor.changed_; } @@ -93,11 +98,18 @@ Status ConvolutionVisitor::HandleConvolution(HloInstruction* convolution) { constexpr int64 kLowLimitForSplitCount = 4; constexpr int64 kHighLimitForSplitCount = 24; + // Batch in batch_group_count has different semantics (it isn't true batch). + // Consider supporting this case in future if needed. + if (convolution->batch_group_count() != 1) { + return Status::OK(); + } + if (convolution->window().dimensions(kChosenSpatialDim).window_dilation() != 1) { return Status::OK(); } + // TODO(b/168316428): Support base dilations. if (convolution->window().dimensions(kChosenSpatialDim).base_dilation() != 1) { return Status::OK(); @@ -108,8 +120,7 @@ Status ConvolutionVisitor::HandleConvolution(HloInstruction* convolution) { const int64 old_batch_size = convolution->operand(0)->shape().dimensions(activations_batch_dim); - // TODO(b/168316428): Only doing this for batch 1 currently. Extend later. - if (old_batch_size != 1) { + if (old_batch_size > limit_on_batch_size_) { return Status::OK(); } @@ -261,11 +272,20 @@ Status ConvolutionVisitor::HandleConvolution(HloInstruction* convolution) { // -2 low padding and +2 high padding) to create shape B. Then, we select // between A and B such that halo regions are placed into A at the right // locations. + + // The benefit of the above mentioned scheme is that it allows for batch + // growth. Here are some examples of the size increases it causes for a 3x3 + // kernel. + // with batch=1, [1,16] -> [4,4] -> [4,6] -> [1,24] growth of 8. + // with batch=2, [2,16] -> [8,4] -> [8,6] -> [1,48] growth of 16. 
+ // with batch=3, [3,16] -> [12,4] -> [12,6] -> [1,72] growth of 24. + std::vector<int64> reshape_dimensions( activations->shape().dimensions().begin(), activations->shape().dimensions().end()); + reshape_dimensions[spatial_dimension_to_split] = spatial_split_size; - reshape_dimensions[activations_batch_dim] = num_splits; + reshape_dimensions[activations_batch_dim] = num_splits * old_batch_size; TF_ASSIGN_OR_RETURN(HloInstruction * batch_increased_reshape, MakeReshapeHlo(reshape_dimensions, activations)); @@ -337,11 +357,19 @@ Status ConvolutionVisitor::HandleConvolution(HloInstruction* convolution) { TF_ASSIGN_OR_RETURN(HloInstruction * select, MakeSelectHlo(shape_mask, straightened_activations, rotated_activations, convolution)); - VLOG(1) << "Select generated"; + VLOG(1) << "Select generated" << select->ToString(); // Increase batch size for one last time. - TF_ASSIGN_OR_RETURN( - activations, MakeReshapeHlo(pad_applied->shape().dimensions(), select)); + std::vector<int64> combined_batch_dimensions( + pad_applied->shape().dimensions().begin(), + pad_applied->shape().dimensions().end()); + + combined_batch_dimensions[activations_batch_dim] = + old_batch_size * num_splits; + TF_ASSIGN_OR_RETURN(activations, + MakeReshapeHlo(combined_batch_dimensions, select)); + + VLOG(1) << "Batch merge done " << activations->ToString(); // Now, we rewrite the convolution with a larger batch. 
const auto& activations_shape = activations->shape(); @@ -385,28 +413,35 @@ Status ConvolutionVisitor::HandleConvolution(HloInstruction* convolution) { VLOG(1) << "new_conv " << new_conv->ToString(); + const int64 output_split_spatial_dim = + new_dim_numbers.output_spatial_dimensions(kChosenSpatialDim); + const int64 output_batch_dim = new_dim_numbers.output_batch_dimension(); + Shape new_shape = new_conv->shape(); - const int64 new_batch_size = - new_shape.dimensions(new_dim_numbers.output_batch_dimension()); - const int64 new_spatial_dim_size = new_shape.dimensions( - new_dim_numbers.output_spatial_dimensions(kChosenSpatialDim)); - new_shape.set_dimensions( - new_dim_numbers.output_spatial_dimensions(kChosenSpatialDim), - new_batch_size * new_spatial_dim_size); - new_shape.set_dimensions(new_dim_numbers.output_batch_dimension(), - old_batch_size); + const int64 new_batch_size = new_shape.dimensions(output_batch_dim); + const int64 new_spatial_dim_size = + new_shape.dimensions(output_split_spatial_dim); + + CHECK_EQ(new_batch_size % old_batch_size, 0); + + const int64 output_split_batch_size = new_batch_size / old_batch_size; + + std::vector<int64> new_dimensions(new_conv->shape().dimensions().begin(), new_conv->shape().dimensions().end()); + new_dimensions[output_split_spatial_dim] = + output_split_batch_size * new_spatial_dim_size; + new_dimensions[new_dim_numbers.output_batch_dimension()] = old_batch_size; // Reshape the output of the new conv into the old convolutions shape. 
TF_ASSIGN_OR_RETURN(HloInstruction * reshape, - MakeReshapeHlo(new_shape, new_conv)); + MakeReshapeHlo(new_dimensions, new_conv)); convolution->SetupDerivedInstruction(reshape); std::vector<int64> start_indices(rank, 0), - end_indices(new_shape.dimensions().begin(), new_shape.dimensions().end()), + end_indices(new_dimensions.begin(), new_dimensions.end()), strides(rank, 1); - end_indices[new_dim_numbers.output_spatial_dimensions(kChosenSpatialDim)] = - convolution->shape().dimensions( - dim_numbers.output_spatial_dimensions(kChosenSpatialDim)); + end_indices[output_split_spatial_dim] = convolution->shape().dimensions( + dim_numbers.output_spatial_dimensions(kChosenSpatialDim)); // This slicing is getting rid of the padding we added to evenly divide space. TF_ASSIGN_OR_RETURN( @@ -431,7 +466,7 @@ StatusOr<bool> ConvolutionSpaceToBatchConverter::Run(HloModule* module) { module->ToString()); bool changed = false; for (auto* comp : module->MakeNonfusionComputations()) { - if (ConvolutionVisitor::Run(comp)) { + if (ConvolutionVisitor::Run(limit_on_batch_size_, comp)) { changed = true; } } diff --git a/tensorflow/compiler/xla/service/space_to_batch_converter.h b/tensorflow/compiler/xla/service/space_to_batch_converter.h index da4102ddf37..a92abda0337 100644 --- a/tensorflow/compiler/xla/service/space_to_batch_converter.h +++ b/tensorflow/compiler/xla/service/space_to_batch_converter.h @@ -26,7 +26,8 @@ namespace xla { // batch. class ConvolutionSpaceToBatchConverter : public HloModulePass { public: - ConvolutionSpaceToBatchConverter() = default; + explicit ConvolutionSpaceToBatchConverter(int64 limit_on_batch_size = 1) + : limit_on_batch_size_(limit_on_batch_size) {} absl::string_view name() const override { return "convolution-space-to-batch-converter"; @@ -35,6 +36,8 @@ class ConvolutionSpaceToBatchConverter : public HloModulePass { // Run convolution rewriting on the given computation. Returns whether the // computation was changed. 
StatusOr<bool> Run(HloModule* module) override; + + int64 limit_on_batch_size_; }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/space_to_batch_converter_test.cc b/tensorflow/compiler/xla/service/space_to_batch_converter_test.cc index 0495d7f1031..bbc3882cde9 100644 --- a/tensorflow/compiler/xla/service/space_to_batch_converter_test.cc +++ b/tensorflow/compiler/xla/service/space_to_batch_converter_test.cc @@ -65,31 +65,42 @@ ENTRY computation { TEST_F(ConvolutionSpaceToBatchConverterTest, SimpleBatch2) { string hlo_string = R"( - HloModule module -ENTRY computation { - %p0 = bf16[2,258,258,32] parameter(0) - %p1 = bf16[3,3,32,32] parameter(1) - ROOT %convolution = bf16[2,256,256,32] convolution(%p0, %p1), window={size=3x3}, - dim_labels=b01f_01io->b01f -} + ENTRY computation { + %p0 = bf16[2,258,258,32] parameter(0) + %p1 = bf16[3,3,32,32] parameter(1) + ROOT %convolution = bf16[2,256,256,32] convolution(%p0, %p1), window={size=3x3}, + dim_labels=b01f_01io->b01f + } )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module, ParseAndReturnVerifiedModule(hlo_string)); - ConvolutionSpaceToBatchConverter converter; - ASSERT_FALSE(converter.Run(module.get()).ValueOrDie()); + ConvolutionSpaceToBatchConverter converter(/*limit_on_batch_size=*/2); + ASSERT_TRUE(converter.Run(module.get()).ValueOrDie()); + auto computation = module->entry_computation(); + HloInstruction* root = computation->root_instruction(); + EXPECT_THAT(root, op::Transpose()); + EXPECT_THAT(root->operand(0), op::Slice()); + auto reshape = root->operand(0)->operand(0); + EXPECT_THAT(reshape, op::Reshape()); + EXPECT_THAT(reshape->operand(0), op::Convolution()); + const int64 batch_dim = reshape->operand(0) + ->convolution_dimension_numbers() + .output_batch_dimension(); + // Verify that the transform has increased the batch size. 
+ EXPECT_GT(reshape->operand(0)->shape().dimensions(batch_dim), 1); } -TEST_F(ConvolutionSpaceToBatchConverterTest, Batch1WithStrideAndPad) { +TEST_F(ConvolutionSpaceToBatchConverterTest, Batch4WithStrideAndPad) { string hlo_string = R"( HloModule module ENTRY computation { - %p0 = bf16[1,224,224,3]{3,2,1,0} parameter(0) + %p0 = bf16[4,224,224,3]{3,2,1,0} parameter(0) %p1 = bf16[7,7,3,64]{3,2,1,0} parameter(1) - ROOT %convolution.3 = bf16[1,112,112,64]{3,2,1,0} convolution(%p0, %p1), + ROOT %convolution.3 = bf16[4,112,112,64]{3,2,1,0} convolution(%p0, %p1), window={size=7x7 stride=2x2 pad=3_3x3_3}, dim_labels=b01f_01io->b01f } )"; @@ -97,7 +108,7 @@ TEST_F(ConvolutionSpaceToBatchConverterTest, Batch1WithStrideAndPad) { ParseAndReturnVerifiedModule(hlo_string)); auto computation = module->entry_computation(); - ConvolutionSpaceToBatchConverter converter; + ConvolutionSpaceToBatchConverter converter(/*limit_on_batch_size=*/4); ASSERT_TRUE(converter.Run(module.get()).ValueOrDie()); HloInstruction* root = computation->root_instruction(); EXPECT_THAT(root, op::Transpose()); @@ -109,7 +120,7 @@ TEST_F(ConvolutionSpaceToBatchConverterTest, Batch1WithStrideAndPad) { ->convolution_dimension_numbers() .output_batch_dimension(); - EXPECT_GT(reshape->operand(0)->shape().dimensions(batch_dim), 1); + EXPECT_GT(reshape->operand(0)->shape().dimensions(batch_dim), 4); } TEST_F(ConvolutionSpaceToBatchConverterTest, Batch1WithKernelDilation) {