Extend space-to-batch to apply to larger batch sizes

PiperOrigin-RevId: 335657911
Change-Id: I22a9f7f978b7d64bf09654771036d044d3d3ef41
A. Unique TensorFlower 2020-10-06 09:36:47 -07:00 committed by TensorFlower Gardener
parent 64a324460a
commit ede1385636
3 changed files with 90 additions and 41 deletions

@@ -52,7 +52,7 @@ class ConvolutionVisitor : public DfsHloVisitorWithDefault {
Status HandleConvolution(HloInstruction* convolution) override;
// Runs the visitor on a computation.
static bool Run(HloComputation* computation);
static bool Run(int64 limit_on_batch_size, HloComputation* computation);
// Returns whether any convolution ops were rewritten.
const bool changed() const { return changed_; }
@@ -60,18 +60,23 @@ class ConvolutionVisitor : public DfsHloVisitorWithDefault {
~ConvolutionVisitor() override = default;
private:
explicit ConvolutionVisitor(HloComputation* computation)
: computation_(computation) {}
explicit ConvolutionVisitor(int64 limit_on_batch_size,
HloComputation* computation)
: computation_(computation), limit_on_batch_size_(limit_on_batch_size) {}
// Current HloComputation instance the ConvolutionVisitor is traversing.
HloComputation* computation_;
// Whether rewrite has occurred.
bool changed_ = false;
// Limit on the batch size up to which this technique is applied.
int64 limit_on_batch_size_;
};
bool ConvolutionVisitor::Run(HloComputation* computation) {
ConvolutionVisitor visitor(computation);
bool ConvolutionVisitor::Run(int64 limit_on_batch_size,
HloComputation* computation) {
ConvolutionVisitor visitor(limit_on_batch_size, computation);
TF_CHECK_OK(computation->Accept(&visitor));
return visitor.changed_;
}
@@ -93,11 +98,18 @@ Status ConvolutionVisitor::HandleConvolution(HloInstruction* convolution) {
constexpr int64 kLowLimitForSplitCount = 4;
constexpr int64 kHighLimitForSplitCount = 24;
// Batch in batch_group_count has different semantics (it isn't true batch).
// Consider supporting this case in the future if needed.
if (convolution->batch_group_count() != 1) {
return Status::OK();
}
if (convolution->window().dimensions(kChosenSpatialDim).window_dilation() !=
1) {
return Status::OK();
}
// TODO(b/168316428): Support base dilations.
if (convolution->window().dimensions(kChosenSpatialDim).base_dilation() !=
1) {
return Status::OK();
@@ -108,8 +120,7 @@ Status ConvolutionVisitor::HandleConvolution(HloInstruction* convolution) {
const int64 old_batch_size =
convolution->operand(0)->shape().dimensions(activations_batch_dim);
// TODO(b/168316428): Only doing this for batch 1 currently. Extend later.
if (old_batch_size != 1) {
if (old_batch_size > limit_on_batch_size_) {
return Status::OK();
}
@@ -261,11 +272,20 @@ Status ConvolutionVisitor::HandleConvolution(HloInstruction* convolution) {
// -2 low padding and +2 high padding) to create shape B. Then, we select
// between A and B such that halo regions are placed into A at the right
// locations.
// The benefit of the above-mentioned scheme is that it allows for batch
// growth. Here are some examples of the size increases it causes for a 3x3
// kernel.
// with batch=1, [1,16] -> [4,4] -> [4,6] -> [1,24] growth of 8.
// with batch=2, [2,16] -> [8,4] -> [8,6] -> [1,48] growth of 16.
// with batch=3, [3,16] -> [12,4] -> [12,6] -> [1,72] growth of 24.
std::vector<int64> reshape_dimensions(
activations->shape().dimensions().begin(),
activations->shape().dimensions().end());
reshape_dimensions[spatial_dimension_to_split] = spatial_split_size;
reshape_dimensions[activations_batch_dim] = num_splits;
reshape_dimensions[activations_batch_dim] = num_splits * old_batch_size;
TF_ASSIGN_OR_RETURN(HloInstruction * batch_increased_reshape,
MakeReshapeHlo(reshape_dimensions, activations));
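To make the growth numbers in the comment above concrete, the split-reshape arithmetic can be sketched in a few lines of standalone C++ (toy values taken from the batch=2 row of the comment; the names are illustrative, not the XLA helper API):

// Standalone illustration of the split reshape: batch grows by num_splits
// while the chosen spatial dimension shrinks by the same factor.
#include <cstdint>
#include <cstdio>

int main() {
  const int64_t old_batch_size = 2;   // B in the comment's batch=2 row
  const int64_t spatial_size = 16;    // chosen spatial dimension
  const int64_t num_splits = 4;       // pieces the space is cut into
  const int64_t spatial_split_size = spatial_size / num_splits;  // 4
  // Mirrors reshape_dimensions above: [2,16] becomes [8,4].
  std::printf("[%lld,%lld]\n",
              static_cast<long long>(old_batch_size * num_splits),
              static_cast<long long>(spatial_split_size));
}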
@@ -337,11 +357,19 @@ Status ConvolutionVisitor::HandleConvolution(HloInstruction* convolution) {
TF_ASSIGN_OR_RETURN(HloInstruction * select,
MakeSelectHlo(shape_mask, straightened_activations,
rotated_activations, convolution));
VLOG(1) << "Select generated";
VLOG(1) << "Select generated " << select->ToString();
// Increase batch size for one last time.
TF_ASSIGN_OR_RETURN(
activations, MakeReshapeHlo(pad_applied->shape().dimensions(), select));
std::vector<int64> combined_batch_dimensions(
pad_applied->shape().dimensions().begin(),
pad_applied->shape().dimensions().end());
combined_batch_dimensions[activations_batch_dim] =
old_batch_size * num_splits;
TF_ASSIGN_OR_RETURN(activations,
MakeReshapeHlo(combined_batch_dimensions, select));
VLOG(1) << "Batch merge done " << activations->ToString();
// Now, we rewrite the convolution with a larger batch.
const auto& activations_shape = activations->shape();
@@ -385,28 +413,35 @@ Status ConvolutionVisitor::HandleConvolution(HloInstruction* convolution) {
VLOG(1) << "new_conv " << new_conv->ToString();
const int64 output_split_spatial_dim =
new_dim_numbers.output_spatial_dimensions(kChosenSpatialDim);
const int64 output_batch_dim = new_dim_numbers.output_batch_dimension();
Shape new_shape = new_conv->shape();
const int64 new_batch_size =
new_shape.dimensions(new_dim_numbers.output_batch_dimension());
const int64 new_spatial_dim_size = new_shape.dimensions(
new_dim_numbers.output_spatial_dimensions(kChosenSpatialDim));
new_shape.set_dimensions(
new_dim_numbers.output_spatial_dimensions(kChosenSpatialDim),
new_batch_size * new_spatial_dim_size);
new_shape.set_dimensions(new_dim_numbers.output_batch_dimension(),
old_batch_size);
const int64 new_batch_size = new_shape.dimensions(output_batch_dim);
const int64 new_spatial_dim_size =
new_shape.dimensions(output_split_spatial_dim);
CHECK_EQ(new_batch_size % old_batch_size, 0);
const int64 output_split_batch_size = new_batch_size / old_batch_size;
std::vector<int64> new_dimensions(new_conv->shape().dimensions().begin(),
new_conv->shape().dimensions().end());
new_dimensions[output_split_spatial_dim] =
output_split_batch_size * new_spatial_dim_size;
new_dimensions[new_dim_numbers.output_batch_dimension()] = old_batch_size;
// Reshape the output of the new conv into the old convolution's shape.
TF_ASSIGN_OR_RETURN(HloInstruction * reshape,
MakeReshapeHlo(new_shape, new_conv));
MakeReshapeHlo(new_dimensions, new_conv));
convolution->SetupDerivedInstruction(reshape);
std::vector<int64> start_indices(rank, 0),
end_indices(new_shape.dimensions().begin(), new_shape.dimensions().end()),
end_indices(new_dimensions.begin(), new_dimensions.end()),
strides(rank, 1);
end_indices[new_dim_numbers.output_spatial_dimensions(kChosenSpatialDim)] =
convolution->shape().dimensions(
dim_numbers.output_spatial_dimensions(kChosenSpatialDim));
end_indices[output_split_spatial_dim] = convolution->shape().dimensions(
dim_numbers.output_spatial_dimensions(kChosenSpatialDim));
// This slicing removes the padding we added to evenly divide the space.
TF_ASSIGN_OR_RETURN(
@@ -431,7 +466,7 @@ StatusOr<bool> ConvolutionSpaceToBatchConverter::Run(HloModule* module) {
module->ToString());
bool changed = false;
for (auto* comp : module->MakeNonfusionComputations()) {
if (ConvolutionVisitor::Run(comp)) {
if (ConvolutionVisitor::Run(limit_on_batch_size_, comp)) {
changed = true;
}
}
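The output-side bookkeeping above (output_split_batch_size and new_dimensions) can be sanity-checked with the same toy numbers. A hedged sketch with illustrative sizes, not the pass itself:

// Hypothetical check of the output reshape arithmetic: the grown batch is
// folded back into the split spatial dimension, restoring the old batch size.
#include <cassert>
#include <cstdint>
#include <cstdio>

int main() {
  const int64_t old_batch_size = 2;
  const int64_t new_batch_size = 8;        // batch after space-to-batch
  const int64_t new_spatial_dim_size = 6;  // per-split output size (toy value)
  assert(new_batch_size % old_batch_size == 0);  // mirrors the CHECK_EQ
  const int64_t output_split_batch_size = new_batch_size / old_batch_size;
  // Mirrors new_dimensions: spatial becomes 4 * 6 = 24, batch returns to 2,
  // after which the slice would trim the padding added to divide space evenly.
  std::printf("[%lld,%lld]\n", static_cast<long long>(old_batch_size),
              static_cast<long long>(output_split_batch_size *
                                     new_spatial_dim_size));
}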

@@ -26,7 +26,8 @@ namespace xla {
// batch.
class ConvolutionSpaceToBatchConverter : public HloModulePass {
public:
ConvolutionSpaceToBatchConverter() = default;
explicit ConvolutionSpaceToBatchConverter(int64 limit_on_batch_size = 1)
: limit_on_batch_size_(limit_on_batch_size) {}
absl::string_view name() const override {
return "convolution-space-to-batch-converter";
@@ -35,6 +36,8 @@ class ConvolutionSpaceToBatchConverter : public HloModulePass {
// Run convolution rewriting on the given module. Returns whether the
// module was changed.
StatusOr<bool> Run(HloModule* module) override;
int64 limit_on_batch_size_;
};
} // namespace xla
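Given the constructor above, running the pass with a batch-size ceiling would look roughly like the test code below. This fragment is a sketch and assumes an already-parsed HloModule named module:

// Hypothetical usage, mirroring the tests: rewrite convolutions whose batch
// size is at most 4.
ConvolutionSpaceToBatchConverter converter(/*limit_on_batch_size=*/4);
bool changed = converter.Run(module.get()).ValueOrDie();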

@@ -65,31 +65,42 @@ ENTRY computation {
TEST_F(ConvolutionSpaceToBatchConverterTest, SimpleBatch2) {
string hlo_string = R"(
HloModule module
ENTRY computation {
%p0 = bf16[2,258,258,32] parameter(0)
%p1 = bf16[3,3,32,32] parameter(1)
ROOT %convolution = bf16[2,256,256,32] convolution(%p0, %p1), window={size=3x3},
dim_labels=b01f_01io->b01f
}
ENTRY computation {
%p0 = bf16[2,258,258,32] parameter(0)
%p1 = bf16[3,3,32,32] parameter(1)
ROOT %convolution = bf16[2,256,256,32] convolution(%p0, %p1), window={size=3x3},
dim_labels=b01f_01io->b01f
}
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnVerifiedModule(hlo_string));
ConvolutionSpaceToBatchConverter converter;
ASSERT_FALSE(converter.Run(module.get()).ValueOrDie());
ConvolutionSpaceToBatchConverter converter(/*limit_on_batch_size=*/2);
ASSERT_TRUE(converter.Run(module.get()).ValueOrDie());
auto computation = module->entry_computation();
HloInstruction* root = computation->root_instruction();
EXPECT_THAT(root, op::Transpose());
EXPECT_THAT(root->operand(0), op::Slice());
auto reshape = root->operand(0)->operand(0);
EXPECT_THAT(reshape, op::Reshape());
EXPECT_THAT(reshape->operand(0), op::Convolution());
const int64 batch_dim = reshape->operand(0)
->convolution_dimension_numbers()
.output_batch_dimension();
// Verify that the transform has increased the batch size.
EXPECT_GT(reshape->operand(0)->shape().dimensions(batch_dim), 1);
}
TEST_F(ConvolutionSpaceToBatchConverterTest, Batch1WithStrideAndPad) {
TEST_F(ConvolutionSpaceToBatchConverterTest, Batch4WithStrideAndPad) {
string hlo_string = R"(
HloModule module
ENTRY computation {
%p0 = bf16[1,224,224,3]{3,2,1,0} parameter(0)
%p0 = bf16[4,224,224,3]{3,2,1,0} parameter(0)
%p1 = bf16[7,7,3,64]{3,2,1,0} parameter(1)
ROOT %convolution.3 = bf16[1,112,112,64]{3,2,1,0} convolution(%p0, %p1),
ROOT %convolution.3 = bf16[4,112,112,64]{3,2,1,0} convolution(%p0, %p1),
window={size=7x7 stride=2x2 pad=3_3x3_3}, dim_labels=b01f_01io->b01f
}
)";
@@ -97,7 +108,7 @@ TEST_F(ConvolutionSpaceToBatchConverterTest, Batch1WithStrideAndPad) {
ParseAndReturnVerifiedModule(hlo_string));
auto computation = module->entry_computation();
ConvolutionSpaceToBatchConverter converter;
ConvolutionSpaceToBatchConverter converter(/*limit_on_batch_size=*/4);
ASSERT_TRUE(converter.Run(module.get()).ValueOrDie());
HloInstruction* root = computation->root_instruction();
EXPECT_THAT(root, op::Transpose());
@@ -109,7 +120,7 @@ TEST_F(ConvolutionSpaceToBatchConverterTest, Batch1WithStrideAndPad) {
->convolution_dimension_numbers()
.output_batch_dimension();
EXPECT_GT(reshape->operand(0)->shape().dimensions(batch_dim), 1);
EXPECT_GT(reshape->operand(0)->shape().dimensions(batch_dim), 4);
}
TEST_F(ConvolutionSpaceToBatchConverterTest, Batch1WithKernelDilation) {