Allow space-to-batch to occur on base dilated convolutions
PiperOrigin-RevId: 343906578 Change-Id: I33dc1940dcfb3c5db19943964ddfa85a75e0553c
parent 1dcb38c020
commit 7cfa60337b
@@ -20,6 +20,7 @@ limitations under the License.
#include <map>
#include <memory>
#include <queue>
#include <tuple>
#include <unordered_set>
#include <utility>

@@ -64,6 +65,19 @@ class ConvolutionVisitor {
  // Top-level function to begin space-to-batch conversion.
  Status PerformSpaceToBatchOnConvolution(HloInstruction* convolution);

  // Struct containing details about a convolution.
  struct ConvDetails {
    int64 spatial_dimension_to_split, inherent_low_padding,
        inherent_high_padding, stride, spatial_size, base_dilation_factor,
        halo_size, high_padding_for_conv, low_padding_for_conv,
        kernel_spatial_dim_size, input_dim_size;
  };

  // Return a struct containing various necessary information pieces for
  // performing space-to-batch on a convolution.
  ConvDetails GetConvolutionDetails(HloInstruction* convolution,
                                    ConvolutionDimensionNumbers& dim_numbers);

  // Function that determines if space-to-batch can be propagated into the
  // consumer. Such propagation is only possible when all required operands are
  // space-to-batch'ed.
@@ -225,11 +239,29 @@ bool ConvolutionVisitor::IsConvSuitableForSpaceToBatch(
    return false;
  }

  // TODO(b/168316428): Support base dilations.
  if (convolution->window()
          .dimensions(get_chosen_spatial_dim(convolution))
          .base_dilation() != 1) {
    return false;
  const ConvDetails c = GetConvolutionDetails(convolution, dim_numbers);

  const int64 low_pad = convolution->window()
                            .dimensions(get_chosen_spatial_dim(convolution))
                            .padding_low();

  // TODO(b/168316428): Support base dilations more generically.
  if (c.base_dilation_factor != 1) {
    if (c.stride != 1) {
      return false;
    }
    // For low pad of 0, only support a pointwise kernel.
    if (low_pad == 0) {
      if (c.kernel_spatial_dim_size != 1) {
        return false;
      }
    } else if (c.kernel_spatial_dim_size != c.base_dilation_factor + 1 ||
               low_pad != c.base_dilation_factor - 1) {
      // Only support dilations where the base dilation factor and the low pad
      // are compatible with kernel_spatial_dim_size, so that
      // HaloDuplicateWithSlice can handle the result.
      return false;
    }
  }
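  // Illustrative reading of the checks above (not part of the original
  // change): with base_dilation_factor = 2, a stride-1 convolution is
  // admitted when kernel_spatial_dim_size == 3 (= 2 + 1) and low_pad == 1
  // (= 2 - 1); with low_pad == 0 only a pointwise (size-1) kernel passes.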

  int64 activations_batch_dim = dim_numbers.input_batch_dimension();
@@ -240,42 +272,17 @@ bool ConvolutionVisitor::IsConvSuitableForSpaceToBatch(
  if (old_batch_size > limit_on_batch_size_) {
    return false;
  }

  auto kernel = convolution->mutable_operand(1);
  const auto& kernel_shape = kernel->shape();
  const int64 kernel_spatial_dim_size =
      kernel_shape.dimensions(dim_numbers.kernel_spatial_dimensions(
          get_chosen_spatial_dim(convolution)));

  auto activations = convolution->mutable_operand(0);

  const int64 input_dim_size =
      activations->shape().dimensions(dim_numbers.input_spatial_dimensions(
          get_chosen_spatial_dim(convolution)));

  const int64 inherent_low_padding =
      convolution->window()
          .dimensions(get_chosen_spatial_dim(convolution))
          .padding_low();
  const int64 inherent_high_padding =
      convolution->window()
          .dimensions(get_chosen_spatial_dim(convolution))
          .padding_high();

  const int64 spatial_size =
      input_dim_size + inherent_low_padding + inherent_high_padding;
  VLOG(1) << "spatial size " << spatial_size;

  const int64 num_splits = kNewBatchSize / old_batch_size;

  // We currently only cater to evenly divisible cases.
  if (kNewBatchSize % old_batch_size != 0) {
    return false;
  }

  // Splitting will be incorrect in these cases.
  if (spatial_size < num_splits ||
      input_dim_size / num_splits < kernel_spatial_dim_size) {
  VLOG(1) << "spatial size " << c.spatial_size;

  const int64 num_splits = kNewBatchSize / old_batch_size;
  // If the ratio is not within the 2X range, we can't Halo Pad from the next
  // split.
  if (c.halo_size > CeilOfRatio(c.spatial_size, num_splits)) {
    return false;
  }
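  // Illustrative numbers (not part of the original change): with
  // c.spatial_size = 28 and num_splits = 8, each split holds
  // CeilOfRatio(28, 8) = 4 elements, so a halo larger than 4 would need data
  // from beyond the immediately adjacent split and is rejected.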
VLOG(1) << "Legal space-to-batch convolution " << convolution->ToString();
|
||||
@@ -292,8 +299,8 @@ StatusOr<HloInstruction*> ConvolutionVisitor::HaloDuplicateWithSlice(
      activations->shape().dimensions(spatial_dimension_to_split);
  const int64 batch_size =
      activations->shape().dimensions(activations_batch_dim);
  CHECK_LT(low_padding, spatial_split_size);
  CHECK_LE(std::abs(halo_size - low_padding), spatial_split_size);
  VLOG(1) << "In HaloDuplicateWithSlice with activations "
          << activations->ToString() << " batch_size " << batch_size
          << " spatial_split_size " << spatial_split_size << " low_padding "
@@ -439,6 +446,7 @@ StatusOr<bool> ConvolutionVisitor::Run() {
  // Iterate through all instructions that we could not propagate through, and
  // turn their operands from batch-to-space as needed.
  for (auto instr : non_propagatable_instrs_) {
    VLOG(1) << "Could not eventually propagate through " << instr->ToString();
    absl::flat_hash_map<int64, HloInstruction*> operand_map;
    for (int64 i = 0; i < instr->operand_count(); ++i) {
      if (old_to_new_instrs_.count(instr->mutable_operand(i))) {
@@ -480,8 +488,9 @@ bool ConvolutionVisitor::CanPropagate(HloInstruction* consumer,

    if (!old_to_new_instrs_.contains(old_producer) &&
        !broadcast_or_constant) {
      VLOG(1) << "Cannot propagate on elementwise op "
              << consumer->ToString();
      VLOG(1) << "Cannot propagate on elementwise op " << consumer->ToString()
              << " because operand " << old_producer->ToString()
              << " isn't ready ";
      return false;
    } else {
      if (broadcast_or_constant) {
@@ -496,10 +505,11 @@ bool ConvolutionVisitor::CanPropagate(HloInstruction* consumer,
        pivot_operand = old_producer;
        VLOG(2) << "Elementwise op: pivot " << old_producer->ToString();
      } else {
        VLOG(2) << "Elementwise op: checking for shape equivalence "
                << consumer->ToString();
        if (instr_to_dim_map_[pivot_operand] !=
            instr_to_dim_map_[old_producer]) {
          VLOG(2) << "Elementwise op: checking for shape equivalence "
                  << consumer->ToString()
                  << " failed due to changed batch space ordering ";
          return false;
        }
        auto pivot_new_instr = old_to_new_instrs_[pivot_operand];
@@ -509,13 +519,22 @@ bool ConvolutionVisitor::CanPropagate(HloInstruction* consumer,
        for (int j = 0; j < pivot_permute_dims.size(); ++j) {
          // Ensure the dimension mapping is the same.
          if (pivot_permute_dims[j] != permute_dims[j]) {
            VLOG(2) << "Elementwise op: checking for shape equivalence "
                    << consumer->ToString()
                    << " failed due to permuted dimensions ";
            return false;
          }

          // Make sure all other dimensions are of the same size.
          if (pivot_new_instr->shape().dimensions(j) !=
              new_instr->shape().dimensions(j)) {
            return false;
            if (!(consumer->IsElementwiseBinary() &&
                  j == instr_to_dim_map_[pivot_operand].second)) {
              VLOG(2) << "Elementwise op: checking for shape equivalence "
                      << consumer->ToString()
                      << " failed due to changed shape sizes ";
              return false;
            }
          }
        }
      }
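      // Illustrative reading (not part of the original change): after this
      // change, two space-to-batched operands of an elementwise binary op may
      // differ in size only along the split space dimension
      // (instr_to_dim_map_[pivot_operand].second); a mismatch in any other
      // dimension still blocks propagation.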
@@ -769,6 +788,28 @@ StatusOr<bool> ConvolutionVisitor::Propagate(HloInstruction* consumer,
  if (IsTrivialElementwise(consumer)) {
    auto dim_map_val = instr_to_dim_map_[producer];
    auto new_consumer = computation->AddInstruction(consumer->Clone());
    if (consumer->IsElementwiseBinary()) {
      for (int64 i = 0; i < 2; ++i) {
        if (consumer->operand(i)->opcode() == HloOpcode::kBroadcast) {
          break;
        }
        CHECK(old_to_new_instrs_.contains(consumer->mutable_operand(i)));
        if (i == 1) {
          // Choose the larger shape to be used as the producer.
          if (old_to_new_instrs_[consumer->mutable_operand(0)]
                  ->shape()
                  .dimensions() >
              old_to_new_instrs_[consumer->mutable_operand(1)]
                  ->shape()
                  .dimensions()) {
            producer = consumer->mutable_operand(0);
          } else {
            producer = consumer->mutable_operand(1);
          }
        }
      }
    }
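    // Illustrative reading (not part of the original change): if the two
    // space-to-batched operands have shapes [8, 5, ...] and [8, 4, ...], the
    // operand with the larger dimensions becomes the producer (pivot), and
    // the smaller operand is padded up to the pivot's shape further below.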

    for (int64 i = 0; i < consumer->operand_count(); ++i) {
      if (consumer->operand(i)->opcode() == HloOpcode::kBroadcast) {
        CHECK(old_to_new_instrs_.contains(producer));
@@ -786,8 +827,66 @@ StatusOr<bool> ConvolutionVisitor::Propagate(HloInstruction* consumer,
            new_consumer->ReplaceOperandWithDifferentShape(i, new_broadcast));
      } else {
        CHECK(old_to_new_instrs_.contains(consumer->mutable_operand(i)));
        TF_CHECK_OK(new_consumer->ReplaceOperandWithDifferentShape(
            i, old_to_new_instrs_[consumer->mutable_operand(i)]));
        HloInstruction* operand_to_use = nullptr;

        auto result = instr_to_dim_map_[producer];
        const int64 old_batch_dim = result.first;
        const int64 old_space_dim = result.second;
        const int64 old_batch_size =
            producer->shape().dimensions(old_batch_dim);
        HloInstruction* new_instr =
            old_to_new_instrs_[consumer->mutable_operand(i)];
        HloInstruction* pivot_new_instr = old_to_new_instrs_[producer];

        auto permute_dims = instr_to_dim_permute_map_[new_instr];
        const int64 batch_dim = DimLookUp(permute_dims, old_batch_dim);
        const int64 space_dim = DimLookUp(permute_dims, old_space_dim);
        const int64 batch_size = new_instr->shape().dimensions(batch_dim);

        if (new_instr->shape().dimensions(space_dim) !=
            pivot_new_instr->shape().dimensions(space_dim)) {
          // Because we do not propagate through transposes, the batch should
          // always be followed by the split space dimension.
          CHECK_EQ(batch_dim + 1, space_dim);

          // Reshape to 1D, pad to the producer's size, reshape back to 2D.
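          // Illustrative numbers (not part of the original change): if this
          // operand was split to batch_size = 8 from old_batch_size = 2 with
          // a per-split space of 5, the reshape below merges the splits into
          // a space of 5 * (8 / 2) = 20 at the old batch size, the pad grows
          // that to pivot_space_size, and the final reshape re-splits it to
          // the pivot's dimensions.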
          std::vector<int64> new_dimensions(
              new_instr->shape().dimensions().begin(),
              new_instr->shape().dimensions().end());
          new_dimensions[space_dim] *= (batch_size / old_batch_size);
          new_dimensions[batch_dim] = old_batch_size;

          TF_ASSIGN_OR_RETURN(HloInstruction * reshape,
                              MakeReshapeHlo(new_dimensions, new_instr));

          const int64 pivot_space_size =
              pivot_new_instr->shape().dimensions(space_dim) * batch_size /
              old_batch_size;

          CHECK_GT(pivot_space_size, new_dimensions[space_dim]);

          PaddingConfig padding_config =
              MakeNoPaddingConfig(reshape->shape().dimensions_size());
          padding_config.mutable_dimensions(space_dim)->set_edge_padding_high(
              pivot_space_size - new_dimensions[space_dim]);
          padding_config.mutable_dimensions(space_dim)->set_edge_padding_low(0);
          HloInstruction* padding =
              computation_->AddInstruction(HloInstruction::CreateConstant(
                  LiteralUtil::Zero(reshape->shape().element_type())));

          TF_ASSIGN_OR_RETURN(HloInstruction * padded_operand,
                              MakePadHlo(reshape, padding, padding_config));

          TF_ASSIGN_OR_RETURN(
              operand_to_use,
              MakeReshapeHlo(pivot_new_instr->shape().dimensions(),
                             padded_operand));

        } else {
          operand_to_use = old_to_new_instrs_[consumer->mutable_operand(i)];
        }
        TF_CHECK_OK(
            new_consumer->ReplaceOperandWithDifferentShape(i, operand_to_use));
      }
    }
    auto old_type = new_consumer->mutable_shape()->element_type();
@@ -1329,25 +1428,21 @@ Status ConvolutionVisitor::PropagateOnConv(HloInstruction* convolution) {
        original_conv_dims.input_spatial_dimensions(i)));
  }

  int64 spatial_dimension_to_split =
      permuted_conv_dims_numbers.input_spatial_dimensions(
          get_chosen_spatial_dim(convolution));

  const int64 old_batch_dim = original_conv_dims.input_batch_dimension();
  const int64 old_batch_size =
      activations_old->shape().dimensions(old_batch_dim);

  const int64 input_dim_size = activations_old->shape().dimensions(
      permuted_conv_dims_numbers.input_spatial_dimensions(
          get_chosen_spatial_dim(convolution)));
  ConvDetails c =
      GetConvolutionDetails(convolution, permuted_conv_dims_numbers);

  VLOG(1) << "Propagating on conv activations_batch_dim "
          << activations_batch_dim << " spatial_dimension_to_split "
          << spatial_dimension_to_split << " old_batch_size " << old_batch_size;
  TF_ASSIGN_OR_RETURN(
      activations_new,
      BringSpaceNextToBatch(activations_new, permuted_conv_dims_numbers,
                            spatial_dimension_to_split, activations_batch_dim));
          << c.spatial_dimension_to_split << " old_batch_size "
          << old_batch_size;
  TF_ASSIGN_OR_RETURN(activations_new,
                      BringSpaceNextToBatch(
                          activations_new, permuted_conv_dims_numbers,
                          c.spatial_dimension_to_split, activations_batch_dim));

  auto select_val = computation_->AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::Zero(activations_new->shape().element_type())));
@@ -1355,32 +1450,12 @@ Status ConvolutionVisitor::PropagateOnConv(HloInstruction* convolution) {
  TF_ASSIGN_OR_RETURN(
      activations_new,
      SelectValidPortion(activations_new, activations_old, select_val,
                         activations_batch_dim, spatial_dimension_to_split,
                         activations_batch_dim, c.spatial_dimension_to_split,
                         old_batch_dim, old_space_dim));
  // Create the new convolution dim numbers.
  auto new_dim_numbers = permuted_conv_dims_numbers;

  auto kernel = convolution->mutable_operand(1);
  const auto& kernel_shape = kernel->shape();
  const int64 kernel_spatial_dim_size = kernel_shape.dimensions(
      permuted_conv_dims_numbers.kernel_spatial_dimensions(
          get_chosen_spatial_dim(convolution)));

  const int64 inherent_low_padding =
      convolution->window()
          .dimensions(get_chosen_spatial_dim(convolution))
          .padding_low();
  const int64 inherent_high_padding =
      convolution->window()
          .dimensions(get_chosen_spatial_dim(convolution))
          .padding_high();
  const int64 stride = convolution->window()
                           .dimensions(get_chosen_spatial_dim(convolution))
                           .stride();

  const int64 spatial_size =
      input_dim_size + inherent_low_padding + inherent_high_padding;
  VLOG(1) << "spatial size " << spatial_size;
  VLOG(1) << "spatial size " << c.spatial_size;

  const int64 num_splits = kNewBatchSize / old_batch_size;

@@ -1390,18 +1465,18 @@ Status ConvolutionVisitor::PropagateOnConv(HloInstruction* convolution) {
  const int64 output_offsets_per_split =
      CeilOfRatio(output_offsets, num_splits);

  int64 spatial_split_size = output_offsets_per_split * stride;
  const int64 halo_size =
      std::max(kernel_spatial_dim_size - stride, static_cast<int64>(0));
  int64 spatial_split_size =
      CeilOfRatio(output_offsets_per_split, c.base_dilation_factor) * c.stride;

  // Keep increasing the split size so that overall size isn't smaller than the
  // original spatial dimension. Unlike for the first space-to-batch'ed
  // convolution, while propagating, we can use the last halo_size as available
  // spatial size.
  while (spatial_split_size * num_splits + halo_size - spatial_size < 0) {
    spatial_split_size += stride;
  while (spatial_split_size * num_splits + c.halo_size - c.spatial_size < 0) {
    spatial_split_size += c.stride;
  }

  int64 slice_size = spatial_split_size + halo_size;
  int64 slice_size = spatial_split_size + c.halo_size;

  VLOG(1) << "spatial_split_size " << spatial_split_size << " slice_size "
          << slice_size;
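  // Illustrative numbers (not part of the original change): with
  // output_offsets_per_split = 10, c.base_dilation_factor = 2 and
  // c.stride = 1, spatial_split_size starts at CeilOfRatio(10, 2) * 1 = 5
  // and grows in steps of c.stride until spatial_split_size * num_splits +
  // c.halo_size covers c.spatial_size.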
@@ -1409,7 +1484,7 @@ Status ConvolutionVisitor::PropagateOnConv(HloInstruction* convolution) {
  const int64 new_batch_size =
      activations_new->shape().dimensions(activations_batch_dim);
  const int64 new_space_size =
      activations_new->shape().dimensions(spatial_dimension_to_split);
      activations_new->shape().dimensions(c.spatial_dimension_to_split);
  // In the below case, we cannot use the activations directly for Halo
  // Duplication. We must reshape them.
  if (spatial_split_size > new_space_size) {
@@ -1418,7 +1493,7 @@ Status ConvolutionVisitor::PropagateOnConv(HloInstruction* convolution) {
        activations_new->shape().dimensions().end());
    const int64 reshaped_space_size =
        new_space_size * new_batch_size / old_batch_size;
    new_dimensions[spatial_dimension_to_split] = reshaped_space_size;
    new_dimensions[c.spatial_dimension_to_split] = reshaped_space_size;
    new_dimensions[activations_batch_dim] = old_batch_size;

    // Reshape the output of the new conv into the old convolution's shape.
@@ -1427,10 +1502,10 @@ Status ConvolutionVisitor::PropagateOnConv(HloInstruction* convolution) {

    PaddingConfig padding_config =
        MakeNoPaddingConfig(reshaped_activations->shape().dimensions_size());
    padding_config.mutable_dimensions(spatial_dimension_to_split)
    padding_config.mutable_dimensions(c.spatial_dimension_to_split)
        ->set_edge_padding_high(spatial_split_size * new_batch_size -
                                reshaped_space_size);
    padding_config.mutable_dimensions(spatial_dimension_to_split)
    padding_config.mutable_dimensions(c.spatial_dimension_to_split)
        ->set_edge_padding_low(0);
    HloInstruction* padding =
        computation_->AddInstruction(HloInstruction::CreateConstant(
@@ -1444,7 +1519,7 @@ Status ConvolutionVisitor::PropagateOnConv(HloInstruction* convolution) {
        reshaped_activations->shape().dimensions().begin(),
        reshaped_activations->shape().dimensions().end());

    reshape_back_dims[spatial_dimension_to_split] = spatial_split_size;
    reshape_back_dims[c.spatial_dimension_to_split] = spatial_split_size;
    reshape_back_dims[activations_batch_dim] = new_batch_size;

    TF_ASSIGN_OR_RETURN(
@@ -1453,34 +1528,38 @@ Status ConvolutionVisitor::PropagateOnConv(HloInstruction* convolution) {

    TF_ASSIGN_OR_RETURN(
        activations_new,
        HaloDuplicateWithSlice(reshaped_activations, spatial_dimension_to_split,
                               activations_batch_dim, old_batch_size,
                               /*low_padding=*/inherent_low_padding,
                               /*high_padding=*/inherent_high_padding,
                               slice_size - spatial_split_size,
                               old_split_dim_size));
        HaloDuplicateWithSlice(
            reshaped_activations, c.spatial_dimension_to_split,
            activations_batch_dim, old_batch_size,
            /*low_padding=*/c.base_dilation_factor != 1 &&
                    c.inherent_low_padding != 0
                ? c.base_dilation_factor - 1
                : c.inherent_low_padding,
            c.inherent_high_padding, slice_size - spatial_split_size,
            old_split_dim_size));
  } else {
    // If the ideal spatial_split_size was smaller than the incoming spatial
    // dimension size, we don't need reshaping. Instead, we determine the
    // additional space available, and adjust the required slice size (and
    // thereby the halo size).
    if (spatial_split_size < new_space_size) {
      const int64 additional_space_present = spatial_split_size % stride;
      const int64 additional_space_present = spatial_split_size % c.stride;
      spatial_split_size = new_space_size;
      slice_size =
          spatial_split_size +
          std::max(kernel_spatial_dim_size - stride - additional_space_present,
                   static_cast<int64>(0));
          spatial_split_size + std::max(c.kernel_spatial_dim_size - c.stride -
                                            additional_space_present,
                                        static_cast<int64>(0));
    }

    TF_ASSIGN_OR_RETURN(
        activations_new,
        HaloDuplicateWithSlice(activations_new, spatial_dimension_to_split,
        HaloDuplicateWithSlice(activations_new, c.spatial_dimension_to_split,
                               activations_batch_dim, old_batch_size,
                               /*low_padding=*/inherent_low_padding,
                               /*high_padding=*/inherent_high_padding,
                               /*low_padding=*/c.base_dilation_factor != 1 &&
                                       c.inherent_low_padding != 0
                                   ? c.base_dilation_factor - 1
                                   : c.inherent_low_padding,
                               c.inherent_high_padding,
                               slice_size - spatial_split_size,
                               old_split_dim_size));
  }
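  // Illustrative reading of the low_padding argument above (not part of the
  // original change): for a base-dilated conv admitted by
  // IsConvSuitableForSpaceToBatch, either inherent_low_padding == 0
  // (pointwise kernel) and it is passed through unchanged, or
  // inherent_low_padding == base_dilation_factor - 1, which is exactly the
  // value passed; e.g. with base_dilation_factor = 2 the halo is duplicated
  // with a low padding of 1.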
@@ -1515,9 +1594,9 @@ Status ConvolutionVisitor::PropagateOnConv(HloInstruction* convolution) {

  auto new_window = convolution->window();
  new_window.mutable_dimensions(get_chosen_spatial_dim(convolution))
      ->set_padding_high(0);
      ->set_padding_high(c.high_padding_for_conv);
  new_window.mutable_dimensions(get_chosen_spatial_dim(convolution))
      ->set_padding_low(0);
      ->set_padding_low(c.low_padding_for_conv);
  TF_ASSIGN_OR_RETURN(
      HloInstruction * new_conv,
      MakeConvolveHlo(
@@ -1855,19 +1934,9 @@ HloInstruction* ConvolutionVisitor::DoesConvolutionFeedReduceWindow(
  return nullptr;
}

Status ConvolutionVisitor::PerformSpaceToBatchOnConvolution(
    HloInstruction* convolution) {
  VLOG(1) << "Handling conv " << convolution->ToString();

  changed_ = false;

  ConvolutionDimensionNumbers dim_numbers =
      convolution->convolution_dimension_numbers();

  int64 activations_batch_dim = dim_numbers.input_batch_dimension();

  const int64 old_batch_size =
      convolution->operand(0)->shape().dimensions(activations_batch_dim);
ConvolutionVisitor::ConvDetails ConvolutionVisitor::GetConvolutionDetails(
    HloInstruction* convolution, ConvolutionDimensionNumbers& dim_numbers) {
  auto activations = convolution->mutable_operand(0);

  auto kernel = convolution->mutable_operand(1);
  const auto& kernel_shape = kernel->shape();
@@ -1875,14 +1944,11 @@ Status ConvolutionVisitor::PerformSpaceToBatchOnConvolution(
      kernel_shape.dimensions(dim_numbers.kernel_spatial_dimensions(
          get_chosen_spatial_dim(convolution)));

  auto activations = convolution->mutable_operand(0);

  int64 spatial_dimension_to_split =
  const int64 spatial_dimension_to_split =
      dim_numbers.input_spatial_dimensions(get_chosen_spatial_dim(convolution));

  const int64 input_dim_size =
      activations->shape().dimensions(dim_numbers.input_spatial_dimensions(
          get_chosen_spatial_dim(convolution)));
      activations->shape().dimensions(spatial_dimension_to_split);

  const int64 inherent_low_padding =
      convolution->window()
@@ -1892,26 +1958,75 @@ Status ConvolutionVisitor::PerformSpaceToBatchOnConvolution(
      convolution->window()
          .dimensions(get_chosen_spatial_dim(convolution))
          .padding_high();
  const bool inherent_padding_needed =
      inherent_low_padding != 0 || inherent_high_padding != 0;

  const int64 stride = convolution->window()
                           .dimensions(get_chosen_spatial_dim(convolution))
                           .stride();

  const int64 base_dilation_factor =
      convolution->window()
          .dimensions(get_chosen_spatial_dim(convolution))
          .base_dilation();

  const int64 spatial_size =
      input_dim_size + inherent_low_padding + inherent_high_padding;
  VLOG(1) << "spatial size " << spatial_size;
      input_dim_size + (base_dilation_factor > 1 ? 0 : inherent_low_padding) +
      inherent_high_padding;

  const int64 halo_size =
      std::max(kernel_spatial_dim_size - stride - (base_dilation_factor - 1),
               static_cast<int64>(0));
  const int64 high_padding_for_conv = base_dilation_factor == 1 ? 0
                                      : inherent_low_padding == 0
                                          ? base_dilation_factor - 1
                                          : 0;
  const int64 low_padding_for_conv =
      base_dilation_factor == 1 ? 0 : inherent_low_padding;

  return ConvDetails{spatial_dimension_to_split,
                     inherent_low_padding,
                     inherent_high_padding,
                     stride,
                     spatial_size,
                     base_dilation_factor,
                     halo_size,
                     high_padding_for_conv,
                     low_padding_for_conv,
                     kernel_spatial_dim_size,
                     input_dim_size};
}
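// Illustrative numbers (not part of the original change): for a conv with
// input_dim_size = 56, kernel_spatial_dim_size = 3, stride = 1, base
// dilation 2 and inherent low/high padding of 1 each, GetConvolutionDetails
// yields spatial_size = 56 + 0 + 1 = 57 (low padding is excluded when base
// dilation > 1), halo_size = max(3 - 1 - 1, 0) = 1, low_padding_for_conv = 1
// and high_padding_for_conv = 0.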

Status ConvolutionVisitor::PerformSpaceToBatchOnConvolution(
    HloInstruction* convolution) {
  VLOG(1) << "Handling conv " << convolution->ToString();

  changed_ = false;

  ConvolutionDimensionNumbers dim_numbers =
      convolution->convolution_dimension_numbers();

  ConvDetails c = GetConvolutionDetails(convolution, dim_numbers);

  int64 activations_batch_dim = dim_numbers.input_batch_dimension();

  const int64 old_batch_size =
      convolution->operand(0)->shape().dimensions(activations_batch_dim);

  auto activations = convolution->mutable_operand(0);

  const bool inherent_padding_needed =
      c.inherent_low_padding != 0 || c.inherent_high_padding != 0;

  VLOG(1) << "spatial size " << c.spatial_size;

  const int64 num_splits = kNewBatchSize / old_batch_size;
  auto original_conv = convolution;

  // We'd need transposition of activations here such that batch and space dim
  // that is being split are adjacent (in that order).
  TF_ASSIGN_OR_RETURN(
      activations,
      BringSpaceNextToBatch(activations, dim_numbers,
                            spatial_dimension_to_split, activations_batch_dim));
  TF_ASSIGN_OR_RETURN(activations,
                      BringSpaceNextToBatch(activations, dim_numbers,
                                            c.spatial_dimension_to_split,
                                            activations_batch_dim));
  // Create the new convolution dim numbers.
  auto new_dim_numbers = dim_numbers;
@@ -1922,11 +2037,12 @@ Status ConvolutionVisitor::PerformSpaceToBatchOnConvolution(
  const int64 output_offsets_per_split =
      CeilOfRatio(output_offsets, num_splits);

  int64 spatial_split_size = output_offsets_per_split * stride;
  int64 spatial_split_size =
      CeilOfRatio(output_offsets_per_split, c.base_dilation_factor) * c.stride;
  // Keep increasing the split size so that overall size isn't smaller than the
  // original spatial dimension.
  while (spatial_split_size * num_splits - spatial_size < 0) {
    spatial_split_size += stride;
  while (spatial_split_size * num_splits - c.spatial_size < 0) {
    spatial_split_size += c.stride;
  }

  auto reduce_window = DoesConvolutionFeedReduceWindow(convolution);
@@ -1938,33 +2054,32 @@ Status ConvolutionVisitor::PerformSpaceToBatchOnConvolution(
    // windows.
    const int64 red_win_stride =
        reduce_window->window().dimensions(output_spatial_dim).stride();
    while ((spatial_split_size / stride) % red_win_stride != 0) {
      spatial_split_size += stride;
    while ((spatial_split_size / c.stride) % red_win_stride != 0) {
      spatial_split_size += c.stride;
    }
  }
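  // Illustrative numbers (not part of the original change): if c.stride = 1
  // and the consuming reduce-window has red_win_stride = 3, a
  // spatial_split_size of 7 is bumped to 9 so that each split produces a
  // multiple of 3 output offsets and reduce-window boundaries never straddle
  // two splits.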

  const int64 slice_size =
      spatial_split_size +
      std::max(kernel_spatial_dim_size - stride, static_cast<int64>(0));
  const int64 slice_size = spatial_split_size + c.halo_size;

  // Pad spatial dim.
  const int64 pad_size = spatial_split_size * num_splits - spatial_size;
  const int64 pad_size = spatial_split_size * num_splits - c.spatial_size;

  VLOG(1) << "spatial_split_size " << spatial_split_size << " stride "
          << stride;
  VLOG(1) << "spatial_dimension_to_split " << spatial_dimension_to_split
          << c.stride << " slice_size " << slice_size;
  VLOG(1) << "spatial_dimension_to_split " << c.spatial_dimension_to_split
          << " num_splits " << num_splits << " kernel_spatial_dim_size "
          << kernel_spatial_dim_size;
          << c.kernel_spatial_dim_size;

  // Because we are splitting the spatial dimension, if convolution needed
  // padding in the spatial dimension, we materialize it.
  if (pad_size != 0 || inherent_padding_needed) {
    PaddingConfig padding_config =
        MakeNoPaddingConfig(activations->shape().dimensions_size());
    padding_config.mutable_dimensions(spatial_dimension_to_split)
        ->set_edge_padding_high(inherent_high_padding + pad_size);
    padding_config.mutable_dimensions(spatial_dimension_to_split)
        ->set_edge_padding_low(inherent_low_padding);
    padding_config.mutable_dimensions(c.spatial_dimension_to_split)
        ->set_edge_padding_high(c.inherent_high_padding + pad_size);
    padding_config.mutable_dimensions(c.spatial_dimension_to_split)
        ->set_edge_padding_low(
            c.base_dilation_factor == 1 ? c.inherent_low_padding : 0);
    HloInstruction* padding =
        computation_->AddInstruction(HloInstruction::CreateConstant(
            LiteralUtil::Zero(activations->shape().element_type())));
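  // Illustrative reading (not part of the original change): when
  // c.base_dilation_factor > 1, the inherent low padding is not materialized
  // here (edge_padding_low is 0); it is instead carried by
  // low_padding_for_conv on the new convolution's window, e.g. with
  // base_dilation_factor = 2 and inherent_low_padding = 1 the pad above adds
  // no low padding and the new conv keeps padding_low = 1.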
@@ -1991,7 +2106,7 @@ Status ConvolutionVisitor::PerformSpaceToBatchOnConvolution(
      activations->shape().dimensions().begin(),
      activations->shape().dimensions().end());

  reshape_dimensions[spatial_dimension_to_split] = spatial_split_size;
  reshape_dimensions[c.spatial_dimension_to_split] = spatial_split_size;
  reshape_dimensions[activations_batch_dim] = num_splits * old_batch_size;

  TF_ASSIGN_OR_RETURN(HloInstruction * batch_increased_reshape,
@@ -2000,12 +2115,12 @@ Status ConvolutionVisitor::PerformSpaceToBatchOnConvolution(

  VLOG(1) << "First reshape done " << batch_increased_reshape->ToString();

  TF_ASSIGN_OR_RETURN(activations,
                      HaloDuplicateWithSlice(
                          batch_increased_reshape, spatial_dimension_to_split,
                          activations_batch_dim, old_batch_size,
                          /*low_padding=*/0, /*high_padding=*/0,
                          slice_size - spatial_split_size, input_dim_size));
  TF_ASSIGN_OR_RETURN(
      activations, HaloDuplicateWithSlice(batch_increased_reshape,
                                          c.spatial_dimension_to_split,
                                          activations_batch_dim, old_batch_size,
                                          /*low_padding=*/0, /*high_padding=*/0,
                                          c.halo_size, c.input_dim_size));

  VLOG(1) << "Batch merge done " << activations->ToString();

@@ -2040,9 +2155,9 @@ Status ConvolutionVisitor::PerformSpaceToBatchOnConvolution(
          << " batch dim " << new_dim_numbers.input_batch_dimension();
  auto new_window = convolution->window();
  new_window.mutable_dimensions(get_chosen_spatial_dim(convolution))
      ->set_padding_high(0);
      ->set_padding_high(c.high_padding_for_conv);
  new_window.mutable_dimensions(get_chosen_spatial_dim(convolution))
      ->set_padding_low(0);
      ->set_padding_low(c.low_padding_for_conv);
  TF_ASSIGN_OR_RETURN(
      HloInstruction * new_conv,
      MakeConvolveHlo(

@@ -113,7 +113,7 @@ TEST_F(ConvolutionSpaceToBatchConverterTest, Batch1WithStrideAndPad) {
  EXPECT_GT(reshape->operand(0)->shape().dimensions(batch_dim), 4);
}

TEST_F(ConvolutionSpaceToBatchConverterTest, Batch1WithKernelDilation) {
TEST_F(ConvolutionSpaceToBatchConverterTest, Batch1WithBaseDilation) {
  string hlo_string = R"(

  HloModule module
@@ -129,8 +129,22 @@ ENTRY computation {
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(hlo_string));

  auto computation = module->entry_computation();
  ConvolutionSpaceToBatchConverter converter;
  ASSERT_FALSE(converter.Run(module.get()).ValueOrDie());
  ASSERT_TRUE(converter.Run(module.get()).ValueOrDie());

  HloInstruction* root = computation->root_instruction();
  EXPECT_THAT(root, op::Transpose());
  EXPECT_THAT(root->operand(0), op::Slice());
  auto reshape = root->operand(0)->operand(0);
  EXPECT_THAT(reshape, op::Reshape());
  EXPECT_THAT(reshape->operand(0)->operand(1), op::Convolution());
  const int64 batch_dim = reshape->operand(0)
                              ->operand(1)
                              ->convolution_dimension_numbers()
                              .output_batch_dimension();

  EXPECT_GT(reshape->operand(0)->shape().dimensions(batch_dim), 4);
}

}  // namespace