Extend space to batch to apply to larger batch sizes
PiperOrigin-RevId: 335657911 Change-Id: I22a9f7f978b7d64bf09654771036d044d3d3ef41
This commit is contained in:
parent
64a324460a
commit
ede1385636
@ -52,7 +52,7 @@ class ConvolutionVisitor : public DfsHloVisitorWithDefault {
|
||||
Status HandleConvolution(HloInstruction* convolution) override;
|
||||
|
||||
// Runs the visitor on a computation.
|
||||
static bool Run(HloComputation* computation);
|
||||
static bool Run(int64 limit_on_batch_size, HloComputation* computation);
|
||||
|
||||
// Returns whether any convolution ops were rewritten.
|
||||
const bool changed() const { return changed_; }
|
||||
@ -60,18 +60,23 @@ class ConvolutionVisitor : public DfsHloVisitorWithDefault {
|
||||
~ConvolutionVisitor() override = default;
|
||||
|
||||
private:
|
||||
explicit ConvolutionVisitor(HloComputation* computation)
|
||||
: computation_(computation) {}
|
||||
explicit ConvolutionVisitor(int64 limit_on_batch_size,
|
||||
HloComputation* computation)
|
||||
: computation_(computation), limit_on_batch_size_(limit_on_batch_size) {}
|
||||
|
||||
// Current HloComputation instance the ConvolutionVisitor is traversing.
|
||||
HloComputation* computation_;
|
||||
|
||||
// Whether rewrite has occurred.
|
||||
bool changed_ = false;
|
||||
|
||||
// Limit on batch size to apply this technique on.
|
||||
int64 limit_on_batch_size_;
|
||||
};
|
||||
|
||||
bool ConvolutionVisitor::Run(HloComputation* computation) {
|
||||
ConvolutionVisitor visitor(computation);
|
||||
bool ConvolutionVisitor::Run(int64 limit_on_batch_size,
|
||||
HloComputation* computation) {
|
||||
ConvolutionVisitor visitor(limit_on_batch_size, computation);
|
||||
TF_CHECK_OK(computation->Accept(&visitor));
|
||||
return visitor.changed_;
|
||||
}
|
||||
@ -93,11 +98,18 @@ Status ConvolutionVisitor::HandleConvolution(HloInstruction* convolution) {
|
||||
constexpr int64 kLowLimitForSplitCount = 4;
|
||||
constexpr int64 kHighLimitForSplitCount = 24;
|
||||
|
||||
// Batch in batch_group_count has different semantics (it isn't true batch).
|
||||
// Consider supporting this case in future if needed.
|
||||
if (convolution->batch_group_count() != 1) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
if (convolution->window().dimensions(kChosenSpatialDim).window_dilation() !=
|
||||
1) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// TODO(b/168316428): Support base dilations.
|
||||
if (convolution->window().dimensions(kChosenSpatialDim).base_dilation() !=
|
||||
1) {
|
||||
return Status::OK();
|
||||
@ -108,8 +120,7 @@ Status ConvolutionVisitor::HandleConvolution(HloInstruction* convolution) {
|
||||
const int64 old_batch_size =
|
||||
convolution->operand(0)->shape().dimensions(activations_batch_dim);
|
||||
|
||||
// TODO(b/168316428): Only doing this for batch 1 currently. Extend later.
|
||||
if (old_batch_size != 1) {
|
||||
if (old_batch_size > limit_on_batch_size_) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
@ -261,11 +272,20 @@ Status ConvolutionVisitor::HandleConvolution(HloInstruction* convolution) {
|
||||
// -2 low padding and +2 high padding) to create shape B. Then, we select
|
||||
// between A and B such that halo regions are placed into A at the right
|
||||
// locations.
|
||||
|
||||
// The benefit of the above mentioned scheme is that it allows for batch
|
||||
// growth. Here are some examples of the size increases it causes for a 3x3
|
||||
// kernel.
|
||||
// with batch=1, [1,16] -> [4,4] -> [4,6] -> [1,24] growth of 8.
|
||||
// with batch=2, [2,16] -> [8,4] -> [8,6] -> [1,48] growth of 16.
|
||||
// with batch=3, [3,16] -> [12,4] -> [12,6] -> [1,72] growth of 24.
|
||||
|
||||
std::vector<int64> reshape_dimensions(
|
||||
activations->shape().dimensions().begin(),
|
||||
activations->shape().dimensions().end());
|
||||
|
||||
reshape_dimensions[spatial_dimension_to_split] = spatial_split_size;
|
||||
reshape_dimensions[activations_batch_dim] = num_splits;
|
||||
reshape_dimensions[activations_batch_dim] = num_splits * old_batch_size;
|
||||
|
||||
TF_ASSIGN_OR_RETURN(HloInstruction * batch_increased_reshape,
|
||||
MakeReshapeHlo(reshape_dimensions, activations));
|
||||
@ -337,11 +357,19 @@ Status ConvolutionVisitor::HandleConvolution(HloInstruction* convolution) {
|
||||
TF_ASSIGN_OR_RETURN(HloInstruction * select,
|
||||
MakeSelectHlo(shape_mask, straightened_activations,
|
||||
rotated_activations, convolution));
|
||||
VLOG(1) << "Select generated";
|
||||
VLOG(1) << "Select generated" << select->ToString();
|
||||
|
||||
// Increase batch size for one last time.
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
activations, MakeReshapeHlo(pad_applied->shape().dimensions(), select));
|
||||
std::vector<int64> combined_batch_dimensions(
|
||||
pad_applied->shape().dimensions().begin(),
|
||||
pad_applied->shape().dimensions().end());
|
||||
|
||||
combined_batch_dimensions[activations_batch_dim] =
|
||||
old_batch_size * num_splits;
|
||||
TF_ASSIGN_OR_RETURN(activations,
|
||||
MakeReshapeHlo(combined_batch_dimensions, select));
|
||||
|
||||
VLOG(1) << "Batch merge done " << activations->ToString();
|
||||
|
||||
// Now, we rewrite the convolution with a larger batch.
|
||||
const auto& activations_shape = activations->shape();
|
||||
@ -385,28 +413,35 @@ Status ConvolutionVisitor::HandleConvolution(HloInstruction* convolution) {
|
||||
|
||||
VLOG(1) << "new_conv " << new_conv->ToString();
|
||||
|
||||
const int64 output_split_spatial_dim =
|
||||
new_dim_numbers.output_spatial_dimensions(kChosenSpatialDim);
|
||||
const int64 output_batch_dim = new_dim_numbers.output_batch_dimension();
|
||||
|
||||
Shape new_shape = new_conv->shape();
|
||||
const int64 new_batch_size =
|
||||
new_shape.dimensions(new_dim_numbers.output_batch_dimension());
|
||||
const int64 new_spatial_dim_size = new_shape.dimensions(
|
||||
new_dim_numbers.output_spatial_dimensions(kChosenSpatialDim));
|
||||
new_shape.set_dimensions(
|
||||
new_dim_numbers.output_spatial_dimensions(kChosenSpatialDim),
|
||||
new_batch_size * new_spatial_dim_size);
|
||||
new_shape.set_dimensions(new_dim_numbers.output_batch_dimension(),
|
||||
old_batch_size);
|
||||
const int64 new_batch_size = new_shape.dimensions(output_batch_dim);
|
||||
const int64 new_spatial_dim_size =
|
||||
new_shape.dimensions(output_split_spatial_dim);
|
||||
|
||||
CHECK_EQ(new_batch_size % old_batch_size, 0);
|
||||
|
||||
const int64 output_split_batch_size = new_batch_size / old_batch_size;
|
||||
|
||||
std::vector<int64> new_dimensions(new_conv->shape().dimensions().begin(),
|
||||
new_conv->shape().dimensions().end());
|
||||
new_dimensions[output_split_spatial_dim] =
|
||||
output_split_batch_size * new_spatial_dim_size;
|
||||
new_dimensions[new_dim_numbers.output_batch_dimension()] = old_batch_size;
|
||||
|
||||
// Reshape the output of the new conv into the old convolutions shape.
|
||||
TF_ASSIGN_OR_RETURN(HloInstruction * reshape,
|
||||
MakeReshapeHlo(new_shape, new_conv));
|
||||
MakeReshapeHlo(new_dimensions, new_conv));
|
||||
convolution->SetupDerivedInstruction(reshape);
|
||||
|
||||
std::vector<int64> start_indices(rank, 0),
|
||||
end_indices(new_shape.dimensions().begin(), new_shape.dimensions().end()),
|
||||
end_indices(new_dimensions.begin(), new_dimensions.end()),
|
||||
strides(rank, 1);
|
||||
end_indices[new_dim_numbers.output_spatial_dimensions(kChosenSpatialDim)] =
|
||||
convolution->shape().dimensions(
|
||||
dim_numbers.output_spatial_dimensions(kChosenSpatialDim));
|
||||
end_indices[output_split_spatial_dim] = convolution->shape().dimensions(
|
||||
dim_numbers.output_spatial_dimensions(kChosenSpatialDim));
|
||||
|
||||
// This slicing is getting rid of the padding we added to evenly divide space.
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
@ -431,7 +466,7 @@ StatusOr<bool> ConvolutionSpaceToBatchConverter::Run(HloModule* module) {
|
||||
module->ToString());
|
||||
bool changed = false;
|
||||
for (auto* comp : module->MakeNonfusionComputations()) {
|
||||
if (ConvolutionVisitor::Run(comp)) {
|
||||
if (ConvolutionVisitor::Run(limit_on_batch_size_, comp)) {
|
||||
changed = true;
|
||||
}
|
||||
}
|
||||
|
@ -26,7 +26,8 @@ namespace xla {
|
||||
// batch.
|
||||
class ConvolutionSpaceToBatchConverter : public HloModulePass {
|
||||
public:
|
||||
ConvolutionSpaceToBatchConverter() = default;
|
||||
explicit ConvolutionSpaceToBatchConverter(int64 limit_on_batch_size = 1)
|
||||
: limit_on_batch_size_(limit_on_batch_size) {}
|
||||
|
||||
absl::string_view name() const override {
|
||||
return "convolution-space-to-batch-converter";
|
||||
@ -35,6 +36,8 @@ class ConvolutionSpaceToBatchConverter : public HloModulePass {
|
||||
// Run convolution rewriting on the given computation. Returns whether the
|
||||
// computation was changed.
|
||||
StatusOr<bool> Run(HloModule* module) override;
|
||||
|
||||
int64 limit_on_batch_size_;
|
||||
};
|
||||
|
||||
} // namespace xla
|
||||
|
@ -65,31 +65,42 @@ ENTRY computation {
|
||||
|
||||
TEST_F(ConvolutionSpaceToBatchConverterTest, SimpleBatch2) {
|
||||
string hlo_string = R"(
|
||||
|
||||
HloModule module
|
||||
ENTRY computation {
|
||||
%p0 = bf16[2,258,258,32] parameter(0)
|
||||
%p1 = bf16[3,3,32,32] parameter(1)
|
||||
ROOT %convolution = bf16[2,256,256,32] convolution(%p0, %p1), window={size=3x3},
|
||||
dim_labels=b01f_01io->b01f
|
||||
}
|
||||
ENTRY computation {
|
||||
%p0 = bf16[2,258,258,32] parameter(0)
|
||||
%p1 = bf16[3,3,32,32] parameter(1)
|
||||
ROOT %convolution = bf16[2,256,256,32] convolution(%p0, %p1), window={size=3x3},
|
||||
dim_labels=b01f_01io->b01f
|
||||
}
|
||||
|
||||
)";
|
||||
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
|
||||
ParseAndReturnVerifiedModule(hlo_string));
|
||||
|
||||
ConvolutionSpaceToBatchConverter converter;
|
||||
ASSERT_FALSE(converter.Run(module.get()).ValueOrDie());
|
||||
ConvolutionSpaceToBatchConverter converter(/*limit_on_batch_size=*/2);
|
||||
ASSERT_TRUE(converter.Run(module.get()).ValueOrDie());
|
||||
auto computation = module->entry_computation();
|
||||
HloInstruction* root = computation->root_instruction();
|
||||
EXPECT_THAT(root, op::Transpose());
|
||||
EXPECT_THAT(root->operand(0), op::Slice());
|
||||
auto reshape = root->operand(0)->operand(0);
|
||||
EXPECT_THAT(reshape, op::Reshape());
|
||||
EXPECT_THAT(reshape->operand(0), op::Convolution());
|
||||
const int64 batch_dim = reshape->operand(0)
|
||||
->convolution_dimension_numbers()
|
||||
.output_batch_dimension();
|
||||
// Verify that the transform has increased the batch size.
|
||||
EXPECT_GT(reshape->operand(0)->shape().dimensions(batch_dim), 1);
|
||||
}
|
||||
|
||||
TEST_F(ConvolutionSpaceToBatchConverterTest, Batch1WithStrideAndPad) {
|
||||
TEST_F(ConvolutionSpaceToBatchConverterTest, Batch4WithStrideAndPad) {
|
||||
string hlo_string = R"(
|
||||
HloModule module
|
||||
ENTRY computation {
|
||||
%p0 = bf16[1,224,224,3]{3,2,1,0} parameter(0)
|
||||
%p0 = bf16[4,224,224,3]{3,2,1,0} parameter(0)
|
||||
%p1 = bf16[7,7,3,64]{3,2,1,0} parameter(1)
|
||||
|
||||
ROOT %convolution.3 = bf16[1,112,112,64]{3,2,1,0} convolution(%p0, %p1),
|
||||
ROOT %convolution.3 = bf16[4,112,112,64]{3,2,1,0} convolution(%p0, %p1),
|
||||
window={size=7x7 stride=2x2 pad=3_3x3_3}, dim_labels=b01f_01io->b01f
|
||||
}
|
||||
)";
|
||||
@ -97,7 +108,7 @@ TEST_F(ConvolutionSpaceToBatchConverterTest, Batch1WithStrideAndPad) {
|
||||
ParseAndReturnVerifiedModule(hlo_string));
|
||||
|
||||
auto computation = module->entry_computation();
|
||||
ConvolutionSpaceToBatchConverter converter;
|
||||
ConvolutionSpaceToBatchConverter converter(/*limit_on_batch_size=*/4);
|
||||
ASSERT_TRUE(converter.Run(module.get()).ValueOrDie());
|
||||
HloInstruction* root = computation->root_instruction();
|
||||
EXPECT_THAT(root, op::Transpose());
|
||||
@ -109,7 +120,7 @@ TEST_F(ConvolutionSpaceToBatchConverterTest, Batch1WithStrideAndPad) {
|
||||
->convolution_dimension_numbers()
|
||||
.output_batch_dimension();
|
||||
|
||||
EXPECT_GT(reshape->operand(0)->shape().dimensions(batch_dim), 1);
|
||||
EXPECT_GT(reshape->operand(0)->shape().dimensions(batch_dim), 4);
|
||||
}
|
||||
|
||||
TEST_F(ConvolutionSpaceToBatchConverterTest, Batch1WithKernelDilation) {
|
||||
|
Loading…
x
Reference in New Issue
Block a user