Extend space-to-batch to apply to larger batch sizes

PiperOrigin-RevId: 335657911
Change-Id: I22a9f7f978b7d64bf09654771036d044d3d3ef41
A. Unique TensorFlower 2020-10-06 09:36:47 -07:00 committed by TensorFlower Gardener
parent 64a324460a
commit ede1385636
3 changed files with 90 additions and 41 deletions
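
As a usage sketch (hedged: the pipeline name and limit value below are illustrative, not part of this commit), the pass now takes an optional batch-size limit; the default of 1 preserves the previous batch-1-only behavior:

// Illustrative only: opt convolutions with batch size up to 4 into the rewrite.
HloPassPipeline pipeline("space-to-batch-example");  // made-up pipeline name
pipeline.AddPass<ConvolutionSpaceToBatchConverter>(/*limit_on_batch_size=*/4);
// HloPassPipeline::Run returns StatusOr<bool>: whether the module changed.
TF_ASSIGN_OR_RETURN(bool changed, pipeline.Run(module));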

View File

@@ -52,7 +52,7 @@ class ConvolutionVisitor : public DfsHloVisitorWithDefault {
   Status HandleConvolution(HloInstruction* convolution) override;
 
   // Runs the visitor on a computation.
-  static bool Run(HloComputation* computation);
+  static bool Run(int64 limit_on_batch_size, HloComputation* computation);
 
   // Returns whether any convolution ops were rewritten.
   const bool changed() const { return changed_; }
@@ -60,18 +60,23 @@ class ConvolutionVisitor : public DfsHloVisitorWithDefault {
   ~ConvolutionVisitor() override = default;
 
  private:
-  explicit ConvolutionVisitor(HloComputation* computation)
-      : computation_(computation) {}
+  explicit ConvolutionVisitor(int64 limit_on_batch_size,
+                              HloComputation* computation)
+      : computation_(computation), limit_on_batch_size_(limit_on_batch_size) {}
 
   // Current HloComputation instance the ConvolutionVisitor is traversing.
   HloComputation* computation_;
 
   // Whether rewrite has occurred.
   bool changed_ = false;
+
+  // Limit on batch size to apply this technique on.
+  int64 limit_on_batch_size_;
 };
 
-bool ConvolutionVisitor::Run(HloComputation* computation) {
-  ConvolutionVisitor visitor(computation);
+bool ConvolutionVisitor::Run(int64 limit_on_batch_size,
+                             HloComputation* computation) {
+  ConvolutionVisitor visitor(limit_on_batch_size, computation);
   TF_CHECK_OK(computation->Accept(&visitor));
   return visitor.changed_;
 }
@@ -93,11 +98,18 @@ Status ConvolutionVisitor::HandleConvolution(HloInstruction* convolution) {
   constexpr int64 kLowLimitForSplitCount = 4;
   constexpr int64 kHighLimitForSplitCount = 24;
 
+  // Batch in batch_group_count has different semantics (it isn't true batch).
+  // Consider supporting this case in future if needed.
+  if (convolution->batch_group_count() != 1) {
+    return Status::OK();
+  }
+
   if (convolution->window().dimensions(kChosenSpatialDim).window_dilation() !=
       1) {
     return Status::OK();
   }
 
+  // TODO(b/168316428): Support base dilations.
   if (convolution->window().dimensions(kChosenSpatialDim).base_dilation() !=
       1) {
     return Status::OK();
@@ -108,8 +120,7 @@ Status ConvolutionVisitor::HandleConvolution(HloInstruction* convolution) {
   const int64 old_batch_size =
       convolution->operand(0)->shape().dimensions(activations_batch_dim);
 
-  // TODO(b/168316428): Only doing this for batch 1 currently. Extend later.
-  if (old_batch_size != 1) {
+  if (old_batch_size > limit_on_batch_size_) {
     return Status::OK();
   }
@@ -261,11 +272,20 @@ Status ConvolutionVisitor::HandleConvolution(HloInstruction* convolution) {
   // -2 low padding and +2 high padding) to create shape B. Then, we select
   // between A and B such that halo regions are placed into A at the right
   // locations.
+  // The benefit of the above mentioned scheme is that it allows for batch
+  // growth. Here are some examples of the size increases it causes for a 3x3
+  // kernel.
+  // with batch=1, [1,16] -> [4,4] -> [4,6] -> [1,24] growth of 8.
+  // with batch=2, [2,16] -> [8,4] -> [8,6] -> [1,48] growth of 16.
+  // with batch=3, [3,16] -> [12,4] -> [12,6] -> [1,72] growth of 24.
+
   std::vector<int64> reshape_dimensions(
       activations->shape().dimensions().begin(),
       activations->shape().dimensions().end());
 
   reshape_dimensions[spatial_dimension_to_split] = spatial_split_size;
-  reshape_dimensions[activations_batch_dim] = num_splits;
+  reshape_dimensions[activations_batch_dim] = num_splits * old_batch_size;
 
   TF_ASSIGN_OR_RETURN(HloInstruction * batch_increased_reshape,
                       MakeReshapeHlo(reshape_dimensions, activations));
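
The growth figures in the new comment above can be reproduced with a small standalone sketch (plain C++ with the comment's made-up sizes; int64_t stands in for XLA's int64):

#include <cstdint>
#include <iostream>

int main() {
  const int64_t spatial_size = 16;  // spatial extent in the comment's examples
  const int64_t num_splits = 4;
  const int64_t halo = 2;  // a 3x3 kernel needs 2 halo columns per split
  for (int64_t batch = 1; batch <= 3; ++batch) {
    const int64_t split_size = spatial_size / num_splits;  // [b,16] -> [4b,4]
    const int64_t haloed_size = split_size + halo;         // [4b,4] -> [4b,6]
    const int64_t flattened = batch * num_splits * haloed_size;  // -> [1,24b]
    std::cout << "batch=" << batch << ": " << batch * spatial_size << " -> "
              << flattened << ", growth of "
              << (flattened - batch * spatial_size) << "\n";
  }
  return 0;
}

This prints growths of 8, 16, and 24, matching the comment.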
@@ -337,11 +357,19 @@ Status ConvolutionVisitor::HandleConvolution(HloInstruction* convolution) {
   TF_ASSIGN_OR_RETURN(HloInstruction * select,
                       MakeSelectHlo(shape_mask, straightened_activations,
                                     rotated_activations, convolution));
-  VLOG(1) << "Select generated";
+  VLOG(1) << "Select generated" << select->ToString();
 
   // Increase batch size for one last time.
-  TF_ASSIGN_OR_RETURN(
-      activations, MakeReshapeHlo(pad_applied->shape().dimensions(), select));
+  std::vector<int64> combined_batch_dimensions(
+      pad_applied->shape().dimensions().begin(),
+      pad_applied->shape().dimensions().end());
+  combined_batch_dimensions[activations_batch_dim] =
+      old_batch_size * num_splits;
+  TF_ASSIGN_OR_RETURN(activations,
+                      MakeReshapeHlo(combined_batch_dimensions, select));
+
+  VLOG(1) << "Batch merge done " << activations->ToString();
 
   // Now, we rewrite the convolution with a larger batch.
   const auto& activations_shape = activations->shape();
@@ -385,27 +413,34 @@ Status ConvolutionVisitor::HandleConvolution(HloInstruction* convolution) {
   VLOG(1) << "new_conv " << new_conv->ToString();
 
+  const int64 output_split_spatial_dim =
+      new_dim_numbers.output_spatial_dimensions(kChosenSpatialDim);
+  const int64 output_batch_dim = new_dim_numbers.output_batch_dimension();
+
   Shape new_shape = new_conv->shape();
-  const int64 new_batch_size =
-      new_shape.dimensions(new_dim_numbers.output_batch_dimension());
-  const int64 new_spatial_dim_size = new_shape.dimensions(
-      new_dim_numbers.output_spatial_dimensions(kChosenSpatialDim));
-  new_shape.set_dimensions(
-      new_dim_numbers.output_spatial_dimensions(kChosenSpatialDim),
-      new_batch_size * new_spatial_dim_size);
-  new_shape.set_dimensions(new_dim_numbers.output_batch_dimension(),
-                           old_batch_size);
+  const int64 new_batch_size = new_shape.dimensions(output_batch_dim);
+  const int64 new_spatial_dim_size =
+      new_shape.dimensions(output_split_spatial_dim);
+
+  CHECK_EQ(new_batch_size % old_batch_size, 0);
+
+  const int64 output_split_batch_size = new_batch_size / old_batch_size;
+
+  std::vector<int64> new_dimensions(new_conv->shape().dimensions().begin(),
+                                    new_conv->shape().dimensions().end());
+  new_dimensions[output_split_spatial_dim] =
+      output_split_batch_size * new_spatial_dim_size;
+  new_dimensions[new_dim_numbers.output_batch_dimension()] = old_batch_size;
 
   // Reshape the output of the new conv into the old convolutions shape.
   TF_ASSIGN_OR_RETURN(HloInstruction * reshape,
-                      MakeReshapeHlo(new_shape, new_conv));
+                      MakeReshapeHlo(new_dimensions, new_conv));
   convolution->SetupDerivedInstruction(reshape);
 
   std::vector<int64> start_indices(rank, 0),
-      end_indices(new_shape.dimensions().begin(), new_shape.dimensions().end()),
+      end_indices(new_dimensions.begin(), new_dimensions.end()),
       strides(rank, 1);
-  end_indices[new_dim_numbers.output_spatial_dimensions(kChosenSpatialDim)] =
-      convolution->shape().dimensions(
-          dim_numbers.output_spatial_dimensions(kChosenSpatialDim));
+  end_indices[output_split_spatial_dim] = convolution->shape().dimensions(
+      dim_numbers.output_spatial_dimensions(kChosenSpatialDim));
 
   // This slicing is getting rid of the padding we added to evenly divide space.
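
The dimension bookkeeping in this hunk can be illustrated standalone (hedged: the concrete sizes are made up; int64_t stands in for XLA's int64):

#include <cassert>
#include <cstdint>
#include <vector>

int main() {
  const int64_t old_batch_size = 2;
  const int64_t new_batch_size = 8;        // old_batch_size * num_splits
  const int64_t new_spatial_dim_size = 6;  // per-split convolution output

  // The enlarged batch is always a multiple of the original batch, which is
  // what the CHECK_EQ above asserts; the split factor is recovered exactly.
  assert(new_batch_size % old_batch_size == 0);
  const int64_t output_split_batch_size = new_batch_size / old_batch_size;

  // Fold the split factor back into the spatial dimension: [8,6] -> [2,24].
  const std::vector<int64_t> new_dimensions = {
      old_batch_size, output_split_batch_size * new_spatial_dim_size};

  // A final slice (the end_indices above) then trims the even-division
  // padding back to the original convolution's output spatial size.
  return 0;
}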
@@ -431,7 +466,7 @@ StatusOr<bool> ConvolutionSpaceToBatchConverter::Run(HloModule* module) {
                          module->ToString());
   bool changed = false;
   for (auto* comp : module->MakeNonfusionComputations()) {
-    if (ConvolutionVisitor::Run(comp)) {
+    if (ConvolutionVisitor::Run(limit_on_batch_size_, comp)) {
       changed = true;
     }
   }

View File

@@ -26,7 +26,8 @@ namespace xla {
 // batch.
 class ConvolutionSpaceToBatchConverter : public HloModulePass {
  public:
-  ConvolutionSpaceToBatchConverter() = default;
+  explicit ConvolutionSpaceToBatchConverter(int64 limit_on_batch_size = 1)
+      : limit_on_batch_size_(limit_on_batch_size) {}
 
   absl::string_view name() const override {
     return "convolution-space-to-batch-converter";
@@ -35,6 +36,8 @@ class ConvolutionSpaceToBatchConverter : public HloModulePass {
   // Run convolution rewriting on the given computation. Returns whether the
   // computation was changed.
   StatusOr<bool> Run(HloModule* module) override;
+
+  int64 limit_on_batch_size_;
 };
 
 }  // namespace xla

View File

@@ -65,31 +65,42 @@ ENTRY computation {
 TEST_F(ConvolutionSpaceToBatchConverterTest, SimpleBatch2) {
   string hlo_string = R"(
   HloModule module
 ENTRY computation {
   %p0 = bf16[2,258,258,32] parameter(0)
   %p1 = bf16[3,3,32,32] parameter(1)
   ROOT %convolution = bf16[2,256,256,32] convolution(%p0, %p1), window={size=3x3},
   dim_labels=b01f_01io->b01f
 }
 )";
   TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                           ParseAndReturnVerifiedModule(hlo_string));
-  ConvolutionSpaceToBatchConverter converter;
-  ASSERT_FALSE(converter.Run(module.get()).ValueOrDie());
+  ConvolutionSpaceToBatchConverter converter(/*limit_on_batch_size=*/2);
+  ASSERT_TRUE(converter.Run(module.get()).ValueOrDie());
+
+  auto computation = module->entry_computation();
+  HloInstruction* root = computation->root_instruction();
+  EXPECT_THAT(root, op::Transpose());
+  EXPECT_THAT(root->operand(0), op::Slice());
+  auto reshape = root->operand(0)->operand(0);
+  EXPECT_THAT(reshape, op::Reshape());
+  EXPECT_THAT(reshape->operand(0), op::Convolution());
+  const int64 batch_dim = reshape->operand(0)
+                              ->convolution_dimension_numbers()
+                              .output_batch_dimension();
+  // Verify that the transform has increased the batch size.
+  EXPECT_GT(reshape->operand(0)->shape().dimensions(batch_dim), 1);
 }
 
-TEST_F(ConvolutionSpaceToBatchConverterTest, Batch1WithStrideAndPad) {
+TEST_F(ConvolutionSpaceToBatchConverterTest, Batch4WithStrideAndPad) {
   string hlo_string = R"(
   HloModule module
 ENTRY computation {
-  %p0 = bf16[1,224,224,3]{3,2,1,0} parameter(0)
+  %p0 = bf16[4,224,224,3]{3,2,1,0} parameter(0)
   %p1 = bf16[7,7,3,64]{3,2,1,0} parameter(1)
-  ROOT %convolution.3 = bf16[1,112,112,64]{3,2,1,0} convolution(%p0, %p1),
+  ROOT %convolution.3 = bf16[4,112,112,64]{3,2,1,0} convolution(%p0, %p1),
   window={size=7x7 stride=2x2 pad=3_3x3_3}, dim_labels=b01f_01io->b01f
 }
 )";
@@ -97,7 +108,7 @@ TEST_F(ConvolutionSpaceToBatchConverterTest, Batch1WithStrideAndPad) {
                           ParseAndReturnVerifiedModule(hlo_string));
   auto computation = module->entry_computation();
 
-  ConvolutionSpaceToBatchConverter converter;
+  ConvolutionSpaceToBatchConverter converter(/*limit_on_batch_size=*/4);
   ASSERT_TRUE(converter.Run(module.get()).ValueOrDie());
   HloInstruction* root = computation->root_instruction();
   EXPECT_THAT(root, op::Transpose());
@@ -109,7 +120,7 @@ TEST_F(ConvolutionSpaceToBatchConverterTest, Batch1WithStrideAndPad) {
                               ->convolution_dimension_numbers()
                               .output_batch_dimension();
-  EXPECT_GT(reshape->operand(0)->shape().dimensions(batch_dim), 1);
+  EXPECT_GT(reshape->operand(0)->shape().dimensions(batch_dim), 4);
 }
 
 TEST_F(ConvolutionSpaceToBatchConverterTest, Batch1WithKernelDilation) {