[XLA] Enable RHS space-to-batch on backprop filter convolutions.

PiperOrigin-RevId: 350873443
Change-Id: I1388f4fe5f151c10062326d1b08d36c8eaec860c
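For context: space-to-batch folds part of a spatial dimension into the batch
dimension so that convolutions with small batch sizes expose more parallelism.
A minimal shape sketch with illustrative values (a split factor of 8 is
assumed here; the shapes are examples, not taken from this change):

    // activations (NHWC):  f32[2, 256, 256, 64]
    // split H by 8 and fold the factor into N:
    // space-to-batched:    f32[2 * 8, 256 / 8, 256, 64] = f32[16, 32, 256, 64]

Previously only the LHS (activations) of a convolution was rewritten this way.
For backprop filter convolutions the RHS (kernel) has an analogous dimension,
kernel_input_feature, that plays the role of batch, so this change teaches the
pass to space-to-batch that operand as well (the new is_rhs paths below).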

@@ -82,7 +82,8 @@ class ConvolutionVisitor {
// Function that determines if space-to-batch can be propagated into the
// consumer. Such propagation is only possible when all required operands are
// space-to-batch'ed.
bool CanPropagate(HloInstruction* consumer, HloInstruction* producer);
bool CanPropagate(HloInstruction* consumer, HloInstruction* producer,
bool last_try = false);
// Returns true if all of the op's direct and indirect operands are created
// via broadcasts. Consumer uses op, and is space-to-batched.
@@ -116,7 +117,7 @@ class ConvolutionVisitor {
HloInstruction* activations, ConvolutionDimensionNumbers& dim_numbers,
int64& spatial_dimension_to_split, int64& activations_batch_dim,
int64 high_padding, int64 low_padding, int64 spatial_split_size,
int64 num_splits, bool is_backprop = false);
int64 num_splits, bool is_backprop = false, bool is_rhs = false);
// Perform space-to-batch propagation on the convolution. Assumes the
// activations were already space-to-batched.
@@ -149,7 +150,13 @@ class ConvolutionVisitor {
StatusOr<SpaceNextToBatchDetails> BringSpaceNextToBatch(
HloInstruction* activations, ConvolutionDimensionNumbers& dim_numbers,
int64& spatial_dimension_to_split, int64& activations_batch_dim,
bool is_backprop = false);
bool is_backprop = false, bool is_rhs = false);
// Increases the spatial dimension size in an already space-to-batched shape
// so that the new size is new_spatial_dim_size.
StatusOr<HloInstruction*> IncreaseSpatialSizeOnSpaceToBatchedShape(
HloInstruction* activations, int64 batch_dimension, int64 old_batch_size,
int64 spatial_dimension, int64 new_spatial_dim_size);
// Function that converts a space-to-batched shape back to the original.
StatusOr<HloInstruction*> BatchToSpace(HloInstruction* old_instr);
@@ -213,6 +220,10 @@ class ConvolutionVisitor {
absl::flat_hash_map<HloInstruction*, std::vector<int64>>
instr_to_dim_permute_map_;
// Map maintaining previously space-to-batched broadcasts.
absl::flat_hash_map<HloInstruction*, absl::flat_hash_set<HloInstruction*>>
broadcast_map_;
// Whether rewrite has occurred.
bool changed_ = false;
@@ -456,7 +467,7 @@ StatusOr<ConvolutionVisitor::SpaceNextToBatchDetails>
ConvolutionVisitor::BringSpaceNextToBatch(
HloInstruction* activations, ConvolutionDimensionNumbers& dim_numbers,
int64& spatial_dimension_to_split, int64& activations_batch_dim,
bool is_backprop) {
bool is_backprop, bool is_rhs) {
std::vector<int64> transpose_dims(activations->shape().rank());
if (spatial_dimension_to_split == activations_batch_dim + 1) {
absl::c_iota(transpose_dims, 0);
@@ -465,49 +476,137 @@ ConvolutionVisitor::BringSpaceNextToBatch(
int64 pushed_counter = 0;
int64 new_batch_dim, new_spatial_dim;
int64 dim_counter = 0;
    if (is_rhs) {
      CHECK(is_backprop);
      for (int i = 0; i < activations->shape().rank(); ++i) {
        if (i == activations_batch_dim) {
          continue;
        }
        if (i == spatial_dimension_to_split) {
          transpose_dims[dim_counter++] = activations_batch_dim;
          new_batch_dim = pushed_counter;
          pushed_counter++;
          new_spatial_dim = pushed_counter;
        }
        if (i == dim_numbers.kernel_output_feature_dimension()) {
          new_dim_numbers.set_kernel_output_feature_dimension(pushed_counter);
        } else {
          auto it = absl::c_find(dim_numbers.kernel_spatial_dimensions(), i);
          if (it != dim_numbers.kernel_spatial_dimensions().end()) {
            int64 j = it - dim_numbers.kernel_spatial_dimensions().begin();
            new_dim_numbers.set_kernel_spatial_dimensions(j, pushed_counter);
          }
        }
        transpose_dims[dim_counter++] = i;
        pushed_counter++;
      }

      activations_batch_dim = new_batch_dim;
      spatial_dimension_to_split = new_spatial_dim;
      TF_ASSIGN_OR_RETURN(activations,
                          MakeTransposeHlo(activations, transpose_dims));

      new_dim_numbers.set_kernel_input_feature_dimension(activations_batch_dim);
    } else {
      for (int i = 0; i < activations->shape().rank(); ++i) {
        if (i == activations_batch_dim) {
          continue;
        }
        if (i == spatial_dimension_to_split) {
          transpose_dims[dim_counter++] = activations_batch_dim;
          new_batch_dim = pushed_counter;
          pushed_counter++;
          new_spatial_dim = pushed_counter;
        }
        if (is_backprop && i == dim_numbers.input_batch_dimension()) {
          new_dim_numbers.set_input_batch_dimension(pushed_counter);
        } else if (i == dim_numbers.input_feature_dimension()) {
          new_dim_numbers.set_input_feature_dimension(pushed_counter);
        } else {
          auto it = absl::c_find(dim_numbers.input_spatial_dimensions(), i);
          if (it != dim_numbers.input_spatial_dimensions().end()) {
            int64 j = it - dim_numbers.input_spatial_dimensions().begin();
            new_dim_numbers.set_input_spatial_dimensions(j, pushed_counter);
          }
        }
        transpose_dims[dim_counter++] = i;
        pushed_counter++;
      }

      activations_batch_dim = new_batch_dim;
      spatial_dimension_to_split = new_spatial_dim;
      TF_ASSIGN_OR_RETURN(activations,
                          MakeTransposeHlo(activations, transpose_dims));

      if (is_backprop) {
        new_dim_numbers.set_input_feature_dimension(activations_batch_dim);
      } else {
        new_dim_numbers.set_input_batch_dimension(activations_batch_dim);
      }
    }
dim_numbers = new_dim_numbers;
}
return SpaceNextToBatchDetails{activations, transpose_dims};
}
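// Worked example (illustrative, not part of this change): for rank-4
// activations with activations_batch_dim = 3 and spatial_dimension_to_split
// = 1, either loop above produces transpose_dims = {0, 3, 1, 2}; after
// MakeTransposeHlo the batch sits at index 1 with the split spatial
// dimension immediately after it (new_batch_dim = 1, new_spatial_dim = 2).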
StatusOr<HloInstruction*>
ConvolutionVisitor::IncreaseSpatialSizeOnSpaceToBatchedShape(
HloInstruction* activations, int64 batch_dimension, int64 old_batch_size,
int64 spatial_dimension, int64 new_spatial_dim_size) {
CHECK_EQ(batch_dimension + 1, spatial_dimension);
std::vector<int64> new_dimensions(activations->shape().dimensions().begin(),
activations->shape().dimensions().end());
const int64 new_batch_size = activations->shape().dimensions(batch_dimension);
int64 spatial_dim_size = activations->shape().dimensions(spatial_dimension);
const int64 reshaped_space_size =
spatial_dim_size * new_batch_size / old_batch_size;
VLOG(3) << "Increasing the spatial size while propagating new_batch_size "
<< new_batch_size << " old_batch_size " << old_batch_size;
new_dimensions[spatial_dimension] = reshaped_space_size;
new_dimensions[batch_dimension] = old_batch_size;
// Reshape the activations so that the extra batch factor folds back into
// the spatial dimension, restoring the old batch size.
TF_ASSIGN_OR_RETURN(HloInstruction * reshaped_activations,
MakeReshapeHlo(new_dimensions, activations));
VLOG(3) << "First reshape done";
PaddingConfig padding_config =
MakeNoPaddingConfig(reshaped_activations->shape().dimensions_size());
padding_config.mutable_dimensions(spatial_dimension)
->set_edge_padding_high(new_spatial_dim_size * new_batch_size /
old_batch_size -
reshaped_space_size);
padding_config.mutable_dimensions(spatial_dimension)->set_edge_padding_low(0);
HloInstruction* padding =
computation_->AddInstruction(HloInstruction::CreateConstant(
LiteralUtil::Zero(reshaped_activations->shape().element_type())));
TF_ASSIGN_OR_RETURN(
reshaped_activations,
MakePadHlo(reshaped_activations, padding, padding_config));
std::vector<int64> reshape_back_dims(
reshaped_activations->shape().dimensions().begin(),
reshaped_activations->shape().dimensions().end());
reshape_back_dims[spatial_dimension] = new_spatial_dim_size;
reshape_back_dims[batch_dimension] = new_batch_size;
TF_ASSIGN_OR_RETURN(HloInstruction * activations_new,
MakeReshapeHlo(reshape_back_dims, reshaped_activations));
VLOG(3) << "Size increased activations " << activations_new->ToString();
return activations_new;
}
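// Shape walk-through (illustrative values): given old_batch_size = 2, an
// input of f32[8, 3] with batch_dimension = 0, spatial_dimension = 1, and
// new_spatial_dim_size = 5:
//   reshaped_space_size = 3 * 8 / 2 = 12      -> reshape to f32[2, 12]
//   edge_padding_high   = 5 * 8 / 2 - 12 = 8  -> pad to     f32[2, 20]
//   reshape back with the requested sizes     -> f32[8, 5]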
StatusOr<bool> ConvolutionVisitor::Run() {
for (auto conv : conv_visitor_list_) {
if (convs_to_visit_.count(conv) > 0) {
@@ -519,6 +618,29 @@ StatusOr<bool> ConvolutionVisitor::Run() {
// Iterate through all instructions that we could not propagate through, and
// turn their operands from batch-to-space as needed.
for (auto instr : non_propagatable_instrs_) {
if (instr->opcode() == HloOpcode::kConvolution) {
VLOG(1) << "Instr " << instr->ToString();
}
// Try to propagate on backprop filters
if (instr->opcode() == HloOpcode::kConvolution &&
!IsConvSuitableForSpaceToBatch(instr)) {
HloInstruction* producer = nullptr;
if (old_to_new_instrs_.contains(instr->mutable_operand(0))) {
producer = instr->mutable_operand(0);
} else if (old_to_new_instrs_.contains(instr->mutable_operand(1))) {
producer = instr->mutable_operand(1);
}
if (producer) {
if (CanPropagate(instr, producer, /*last_try=*/true)) {
bool needs_further_propagation;
TF_ASSIGN_OR_RETURN(needs_further_propagation,
Propagate(instr, producer));
TF_CHECK_OK(computation_->ReplaceInstruction(
instr, old_to_new_instrs_[instr]));
continue;
}
}
}
VLOG(1) << "Could not eventually propagate through " << instr->ToString();
absl::flat_hash_map<int64, HloInstruction*> operand_map;
for (int64 i = 0; i < instr->operand_count(); ++i) {
@@ -539,14 +661,14 @@ bool IsTrivialElementwise(HloInstruction* hlo) {
if (hlo->opcode() == HloOpcode::kFusion || hlo->opcode() == HloOpcode::kRng ||
hlo->opcode() == HloOpcode::kCopy ||
hlo->opcode() == HloOpcode::kConstant ||
hlo->opcode() == HloOpcode::kIota) {
hlo->opcode() == HloOpcode::kIota || hlo->opcode() == HloOpcode::kMap) {
return false;
}
return hlo->IsElementwise();
}
bool ConvolutionVisitor::CanPropagate(HloInstruction* consumer,
HloInstruction* producer) {
HloInstruction* producer, bool last_try) {
if (IsTrivialElementwise(consumer)) {
VLOG(2) << "Doing propagation check on elementwise op: "
<< consumer->ToString();
@@ -604,7 +726,8 @@ bool ConvolutionVisitor::CanPropagate(HloInstruction* consumer,
// Make sure all other dimensions are of the same size.
if (pivot_new_instr->shape().dimensions(j) !=
new_instr->shape().dimensions(j)) {
if (!(consumer->IsElementwiseBinary() &&
if (!((consumer->IsElementwiseBinary() ||
consumer->opcode() == HloOpcode::kSelect) &&
j == instr_to_dim_map_[pivot_operand].second)) {
VLOG(2) << "Elementwise op: checking for shape equivalence "
<< consumer->ToString()
@@ -653,20 +776,70 @@ bool ConvolutionVisitor::CanPropagate(HloInstruction* consumer,
return false;
}
// The same reason we give up on batch group counts applies to feature
// groups in backprop.
if (consumer->feature_group_count() != 1) {
return false;
}
VLOG(2) << "Checking for backprop filter conv propagatability";
CHECK_EQ(consumer->operand_count(), 2);
VLOG(2) << "Checking for backprop filter conv operands "
<< consumer->operand_count();
auto activations = consumer->mutable_operand(0);
auto kernel = consumer->mutable_operand(1);
if (!old_to_new_instrs_.contains(kernel)) {
  VLOG(2) << "Backprop filter conv not ready for propagation because the "
             "kernel is not space-to-batched";
  return false;
}
if (!last_try) {
if (!old_to_new_instrs_.contains(kernel) ||
!old_to_new_instrs_.contains(activations)) {
return false;
}
}
if (!old_to_new_instrs_.contains(kernel) &&
!old_to_new_instrs_.contains(activations)) {
return false;
}
const int64 rhs_dilation = consumer->window()
.dimensions(get_chosen_spatial_dim(consumer))
.window_dilation();
if (!old_to_new_instrs_.contains(kernel)) {
const int64 rhs_batch =
kernel->shape().dimensions(consumer->convolution_dimension_numbers()
.kernel_input_feature_dimension());
auto dim_map_val_op_0 = instr_to_dim_map_[activations];
const int64 old_batch_dim = dim_map_val_op_0.first;
const int64 old_space_dim = dim_map_val_op_0.second;
auto first_operand = old_to_new_instrs_[activations];
auto permute_dims_first_operand =
instr_to_dim_permute_map_[first_operand];
const int64 new_batch_dim =
DimLookUp(permute_dims_first_operand, old_batch_dim);
const int64 new_space_dim =
DimLookUp(permute_dims_first_operand, old_space_dim);
const int64 lhs_batch = first_operand->shape().dimensions(new_batch_dim);
if (first_operand->shape().dimensions(new_space_dim) % rhs_dilation !=
0) {
return false;
}
// Because we want to convert the kernel into a space-to-batched version
// only for backprop filter convolutions, we want to make sure that the
// batch dimensions (feature dimensions, technically) are the same size.
// Since the LHS is already space-to-batched, we account for that here.
if (rhs_batch * kNumSplits != lhs_batch) {
return false;
}
// If the kernel has not been propagated through, we can space-to-batch it
// here, provided the activations have been propagated.
VLOG(2)
<< "Backprop filter conv ready for propagation: activations ready, "
" kernel will be space-to-batched";
return true;
}
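// Numeric example (illustrative, assuming a split factor kNumSplits = 8): a
// kernel whose kernel_input_feature dimension is 4 only matches a
// space-to-batched LHS whose batch dimension is 4 * 8 = 32; any other
// lhs_batch fails the check above.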
if (!old_to_new_instrs_.contains(activations)) {
const int64 lhs_batch = activations->shape().dimensions(
consumer->convolution_dimension_numbers().input_feature_dimension());
@@ -720,10 +893,7 @@ bool ConvolutionVisitor::CanPropagate(HloInstruction* consumer,
return false;
}
const int64 rhs_dilation = consumer->window()
.dimensions(get_chosen_spatial_dim(consumer))
.window_dilation();
if (first_operand->shape().dimensions(new_space_dim_operand_0) !=
if (first_operand->shape().dimensions(new_space_dim_operand_0) >
rhs_dilation *
second_operand->shape().dimensions(new_space_dim_operand_1)) {
VLOG(2) << "Backprop filter conv not ready for propagation because of "
@@ -816,20 +986,57 @@ void ConvolutionVisitor::PropagateOnBroadcast(HloInstruction* consumer,
auto new_producer = old_to_new_instrs_[producer];
auto permute_dims = instr_to_dim_permute_map_[new_producer];
auto dim_map_val = instr_to_dim_map_[producer];
const int64 old_batch_dim = dim_map_val.first;
const int64 old_space_dim = dim_map_val.second;
auto orig_broadcast_dims = consumer->dimensions();
bool batch_is_broadcasted =
absl::c_linear_search(orig_broadcast_dims, old_batch_dim);
const int64 new_batch_dim = DimLookUp(permute_dims, old_batch_dim);
const int64 new_space_dim = DimLookUp(permute_dims, old_space_dim);
bool map_found = broadcast_map_.contains(consumer);
if (map_found) {
// Check if we previously had created the same broadcast.
for (auto previous_broadcast : broadcast_map_[consumer]) {
if (ShapeUtil::CompatibleIgnoringElementType(previous_broadcast->shape(),
new_producer->shape())) {
return;
}
}
}
std::vector<int64> final_shape_dims(
new_producer->shape().dimensions().begin(),
new_producer->shape().dimensions().end());
if (batch_is_broadcasted) {
final_shape_dims[new_batch_dim] =
producer->shape().dimensions(old_batch_dim);
final_shape_dims[new_space_dim] *= kNumSplits;
}
std::vector<int64> broadcast_dims;
for (auto j : consumer->dimensions()) {
broadcast_dims.push_back(DimLookUp(permute_dims, j));
}
auto new_broadcast =
MakeBroadcastHlo(consumer->mutable_operand(0), broadcast_dims,
new_producer->shape().dimensions());
auto new_broadcast = MakeBroadcastHlo(consumer->mutable_operand(0),
broadcast_dims, final_shape_dims);
VLOG(1) << "Created broadcast " << new_broadcast->ToString();
// Pass on the permutation information from the producer.
old_to_new_instrs_[consumer] = new_broadcast;
instr_to_dim_map_[consumer] = dim_map_val;
instr_to_dim_permute_map_[new_broadcast] = std::vector<int64>(
instr_to_dim_permute_map_[old_to_new_instrs_[producer]]);
if (batch_is_broadcasted) {
new_broadcast =
MakeReshapeHlo(new_producer->shape().dimensions(), new_broadcast)
.ValueOrDie();
VLOG(2) << "Created reshape of broadcast " << new_broadcast->ToString();
}
if (!map_found) {
absl::flat_hash_set<HloInstruction*> set_of_broadcasts;
broadcast_map_[consumer] = set_of_broadcasts;
}
broadcast_map_[consumer].insert(new_broadcast);
}
void ConvolutionVisitor::RewriteBroadcastTree(
@@ -882,11 +1089,9 @@ bool ConvolutionVisitor::IsBroadcastPropagatable(HloInstruction* broadcast,
CHECK(instr_to_dim_map_.contains(old_other_op));
auto result = instr_to_dim_map_[old_other_op];
const int64 batch_dim = result.first;
const int64 space_dim = result.second;
auto broadcast_dims = broadcast->dimensions();
return !absl::c_linear_search(broadcast_dims, batch_dim) &&
!absl::c_linear_search(broadcast_dims, space_dim);
return !absl::c_linear_search(broadcast_dims, space_dim);
}
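// Note on the relaxation above: a broadcast covering the batch dimension no
// longer blocks propagation, presumably because PropagateOnBroadcast now
// rebuilds such broadcasts explicitly (final_shape_dims plus a reshape);
// only a broadcast over the space dimension still returns false.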
bool ConvolutionVisitor::SupportedOpForPropagation(HloInstruction* consumer,
@@ -990,27 +1195,50 @@ StatusOr<bool> ConvolutionVisitor::Propagate(HloInstruction* consumer,
// For elementwise binary ops, both of whose operands have been space-to-
// batched, if their new spatial sizes don't match, choose the bigger one
// as the producer.
if (consumer->IsElementwiseBinary() &&
old_to_new_instrs_.contains(consumer->mutable_operand(0)) &&
old_to_new_instrs_.contains(consumer->mutable_operand(1))) {
is_pivot_producer_modified = true;
if (old_to_new_instrs_[consumer->mutable_operand(0)]
->shape()
.dimensions() > old_to_new_instrs_[consumer->mutable_operand(1)]
->shape()
.dimensions()) {
producer = consumer->mutable_operand(0);
} else {
producer = consumer->mutable_operand(1);
if (consumer->IsElementwiseBinary() ||
consumer->opcode() == HloOpcode::kSelect) {
int64 pivot_operand_number = -1;
HloInstruction* pivot_operand = nullptr;
for (int i = 0; i < consumer->operand_count(); ++i) {
if (consumer->operand(i)->opcode() == HloOpcode::kBroadcast) {
continue;
}
auto operand = consumer->mutable_operand(i);
if (old_to_new_instrs_.contains(operand)) {
if (pivot_operand_number == -1 ||
old_to_new_instrs_[pivot_operand]->shape().dimensions() <
old_to_new_instrs_[operand]->shape().dimensions()) {
is_pivot_producer_modified = true;
pivot_operand_number = i;
pivot_operand = consumer->mutable_operand(pivot_operand_number);
}
}
}
if (pivot_operand_number != -1) {
producer = pivot_operand;
}
}
for (int64 i = 0; i < consumer->operand_count(); ++i) {
std::vector<HloInstruction*> instructions_to_transform;
if (old_to_new_instrs_.contains(consumer->mutable_operand(i))) {
if (consumer->operand(i)->opcode() == HloOpcode::kBroadcast) {
auto broadcast = consumer->mutable_operand(i);
PropagateOnBroadcast(broadcast, producer);
HloInstruction* new_broadcast = nullptr;
auto new_producer = old_to_new_instrs_[producer];
for (auto previous_broadcast : broadcast_map_[broadcast]) {
if (ShapeUtil::CompatibleIgnoringElementType(
previous_broadcast->shape(), new_producer->shape())) {
new_broadcast = previous_broadcast;
break;
}
}
CHECK_NE(new_broadcast, nullptr);
TF_CHECK_OK(
new_consumer->ReplaceOperandWithDifferentShape(i, new_broadcast));
} else if (old_to_new_instrs_.contains(consumer->mutable_operand(i))) {
HloInstruction* operand_to_use = nullptr;
auto result = instr_to_dim_map_[producer];
const int64 old_batch_dim = result.first;
const int64 old_space_dim = result.second;
@@ -1070,13 +1298,6 @@ StatusOr<bool> ConvolutionVisitor::Propagate(HloInstruction* consumer,
}
TF_CHECK_OK(
new_consumer->ReplaceOperandWithDifferentShape(i, operand_to_use));
} else if (consumer->operand(i)->opcode() == HloOpcode::kBroadcast) {
if (!old_to_new_instrs_.contains(consumer->operand(i))) {
PropagateOnBroadcast(consumer->mutable_operand(i), producer);
}
auto new_broadcast = old_to_new_instrs_[consumer->mutable_operand(i)];
TF_CHECK_OK(
new_consumer->ReplaceOperandWithDifferentShape(i, new_broadcast));
} else if (consumer->IsElementwiseBinary() &&
IsBroadcastTree(consumer->mutable_operand(i), producer,
instructions_to_transform)) {
@@ -1463,8 +1684,10 @@ StatusOr<HloInstruction*> ConvolutionVisitor::SelectValidPortion(
StatusOr<HloInstruction*> ConvolutionVisitor::BatchToSpace(
HloInstruction* old_instr) {
if (batch_to_space_map_.count(old_instr)) {
CHECK_NE(batch_to_space_map_[old_instr], nullptr);
return batch_to_space_map_[old_instr];
}
auto result = instr_to_dim_map_[old_instr];
const int64 old_batch_dim = result.first;
const int64 old_space_dim = result.second;
@@ -1683,68 +1906,28 @@ Status ConvolutionVisitor::PropagateOnConv(HloInstruction* convolution) {
VLOG(1) << "spatial_split_size " << spatial_split_size << " slice_size "
<< slice_size;
const int64 new_batch_size =
activations_new->shape().dimensions(activations_batch_dim);
const int64 new_space_size =
activations_new->shape().dimensions(c.spatial_dimension_to_split);
// In the below case, we cannot use the activations directly for Halo
// Duplication. We must reshape them.
if (spatial_split_size > new_space_size) {
std::vector<int64> new_dimensions(
activations_new->shape().dimensions().begin(),
activations_new->shape().dimensions().end());
const int64 reshaped_space_size =
new_space_size * new_batch_size / old_batch_size;
VLOG(3) << "Increasing the spatial size while propagating new_batch_size "
<< new_batch_size << " old_batch_size " << old_batch_size;
new_dimensions[c.spatial_dimension_to_split] = reshaped_space_size;
new_dimensions[activations_batch_dim] = old_batch_size;
// Reshape the output of the new conv into the old convolutions shape.
TF_ASSIGN_OR_RETURN(HloInstruction * reshaped_activations,
MakeReshapeHlo(new_dimensions, activations_new));
VLOG(3) << "First reshape done";
PaddingConfig padding_config =
MakeNoPaddingConfig(reshaped_activations->shape().dimensions_size());
padding_config.mutable_dimensions(c.spatial_dimension_to_split)
->set_edge_padding_high(spatial_split_size * new_batch_size /
old_batch_size -
reshaped_space_size);
padding_config.mutable_dimensions(c.spatial_dimension_to_split)
->set_edge_padding_low(0);
HloInstruction* padding =
computation_->AddInstruction(HloInstruction::CreateConstant(
LiteralUtil::Zero(reshaped_activations->shape().element_type())));
TF_ASSIGN_OR_RETURN(
reshaped_activations,
MakePadHlo(reshaped_activations, padding, padding_config));
std::vector<int64> reshape_back_dims(
reshaped_activations->shape().dimensions().begin(),
reshaped_activations->shape().dimensions().end());
reshape_back_dims[c.spatial_dimension_to_split] = spatial_split_size;
reshape_back_dims[activations_batch_dim] = new_batch_size;
TF_ASSIGN_OR_RETURN(
reshaped_activations,
MakeReshapeHlo(reshape_back_dims, reshaped_activations));
VLOG(3) << "Second reshape done";
    TF_ASSIGN_OR_RETURN(
        activations_new,
        IncreaseSpatialSizeOnSpaceToBatchedShape(
            activations_new, activations_batch_dim, old_batch_size,
            c.spatial_dimension_to_split, spatial_split_size));

    TF_ASSIGN_OR_RETURN(
        activations_new,
        HaloDuplicateWithSlice(activations_new, c.spatial_dimension_to_split,
                               activations_batch_dim, old_batch_size,
                               /*low_padding=*/c.base_dilation_factor != 1 &&
                                       c.inherent_low_padding != 0
                                   ? c.base_dilation_factor - 1
                                   : c.inherent_low_padding,
                               c.inherent_high_padding,
                               slice_size - spatial_split_size,
                               old_split_dim_size));
} else {
// If the ideal spatial_split_size was smaller than the incoming spatial
// dimension size, we don't need reshaping. Instead, we determine the
@@ -1835,14 +2018,15 @@ ConvolutionVisitor::SplitSpace(HloInstruction* activations,
int64& spatial_dimension_to_split,
int64& activations_batch_dim, int64 high_padding,
int64 low_padding, int64 spatial_split_size,
int64 num_splits, bool is_backprop) {
int64 num_splits, bool is_backprop,
bool is_rhs) {
const int64 old_batch_size =
activations->shape().dimensions(activations_batch_dim);
TF_ASSIGN_OR_RETURN(
auto retval, BringSpaceNextToBatch(activations, dim_numbers,
spatial_dimension_to_split,
activations_batch_dim, is_backprop));
TF_ASSIGN_OR_RETURN(auto retval,
BringSpaceNextToBatch(
activations, dim_numbers, spatial_dimension_to_split,
activations_batch_dim, is_backprop, is_rhs));
activations = retval.instr;
std::vector<int64> transpose_dims = retval.transpose_dims;
@@ -1903,7 +2087,7 @@ Status ConvolutionVisitor::PropagateOnBackpropFilterConv(
.window_dilation();
auto original_conv_dims = convolution->convolution_dimension_numbers();
const int64 kernel_space_dim = original_conv_dims.kernel_spatial_dimensions(
int64 kernel_space_dim = original_conv_dims.kernel_spatial_dimensions(
get_chosen_spatial_dim(convolution));
auto kernel_old = convolution->mutable_operand(1);
const int64 old_kernel_split_dim_size =
@@ -1914,18 +2098,24 @@ Status ConvolutionVisitor::PropagateOnBackpropFilterConv(
int64 old_split_dim_size = activations_old->shape().dimensions(old_space_dim);
int64 old_batch_dim = original_conv_dims.input_feature_dimension();
int64 kernel_old_batch_dim =
original_conv_dims.kernel_input_feature_dimension();
const int64 old_batch_size =
activations_old->shape().dimensions(old_batch_dim);
CHECK(old_to_new_instrs_.contains(kernel_old));
auto kernel_new = old_to_new_instrs_[kernel_old];
auto permute_dims_kernel = instr_to_dim_permute_map_[kernel_new];
CHECK(old_to_new_instrs_.contains(kernel_old) ||
old_to_new_instrs_.contains(activations_old));
HloInstruction* activations_new = nullptr;
HloInstruction* kernel_new = nullptr;
bool activations_locally_space_to_batched = false;
bool kernel_locally_space_to_batched = false;
std::vector<int64> permute_dims_kernel, permute_dims;
// If activations were not space-to-batched, we space-to-batch them below.
if (!old_to_new_instrs_.contains(activations_old)) {
kernel_new = old_to_new_instrs_[kernel_old];
permute_dims_kernel = instr_to_dim_permute_map_[kernel_new];
VLOG(1) << "Space-to-batching activations to enable space-to-depth";
const int64 prev_feature_dim = original_conv_dims.input_feature_dimension();
const int64 prev_batch_dim = original_conv_dims.input_batch_dimension();
@@ -1960,14 +2150,63 @@ Status ConvolutionVisitor::PropagateOnBackpropFilterConv(
VLOG(3) << "New Activations " << retval.first->ToString();
activations_locally_space_to_batched = true;
} else if (!old_to_new_instrs_.contains(kernel_old)) {
activations_new = old_to_new_instrs_[activations_old];
permute_dims = instr_to_dim_permute_map_[activations_new];
VLOG(1) << "Space-to-batching kernel to enable space-to-depth";
const int64 prev_feature_dim =
original_conv_dims.kernel_input_feature_dimension();
const int64 prev_output_feature_dim =
original_conv_dims.kernel_output_feature_dimension();
// TODO(b/168316428): The instr_to_dim_map_ is set incorrectly here, but it
// doesn't matter since it is never used. Investigate further to see if just
// not setting it works.
instr_to_dim_map_[kernel_old] =
std::make_pair(prev_feature_dim, prev_output_feature_dim);
const int64 new_space_dim = DimLookUp(permute_dims, old_space_dim);
const int64 new_split_dim_size =
activations_new->shape().dimensions(new_space_dim);
const int64 needed_spatial_size =
CeilOfRatio(new_split_dim_size, rhs_dilation);
int64 old_kernel_split_dim_size =
kernel_old->shape().dimensions(kernel_space_dim);
const int64 pad_size =
needed_spatial_size * kNumSplits - old_kernel_split_dim_size;
ConvolutionDimensionNumbers tmp_dim_numbers;
tmp_dim_numbers = original_conv_dims;
TF_ASSIGN_OR_RETURN(
auto retval, SplitSpace(kernel_old, tmp_dim_numbers, kernel_space_dim,
kernel_old_batch_dim,
/*high_padding=*/pad_size, /*low_padding=*/0,
needed_spatial_size, kNumSplits,
/*is_backprop=*/true, /*is_rhs=*/true));
old_to_new_instrs_[kernel_old] = retval.first;
std::vector<int64> reversed_transpose_dims(retval.second.size());
for (int64 i = 0; i < retval.second.size(); ++i) {
reversed_transpose_dims[i] = ReverseDimLookUp(retval.second, i);
}
instr_to_dim_permute_map_[retval.first] = reversed_transpose_dims;
VLOG(3) << "New kernel " << retval.first->ToString();
kernel_locally_space_to_batched = true;
}
CHECK(old_to_new_instrs_.contains(activations_old));
CHECK(old_to_new_instrs_.contains(kernel_old));
activations_new = old_to_new_instrs_[activations_old];
kernel_new = old_to_new_instrs_[kernel_old];
const int64 new_spatial_dimension =
activations_new->shape().dimensions_size();
auto permute_dims = instr_to_dim_permute_map_[activations_new];
permute_dims = instr_to_dim_permute_map_[activations_new];
permute_dims_kernel = instr_to_dim_permute_map_[kernel_new];
auto permuted_conv_dims_numbers = original_conv_dims;
// Note the inversion here: batch and feature are inverted in backprop
@@ -2024,9 +2263,12 @@ Status ConvolutionVisitor::PropagateOnBackpropFilterConv(
permuted_conv_dims_numbers.kernel_spatial_dimensions(
get_chosen_spatial_dim(convolution));
const int64 new_split_dim_size =
int64 new_split_dim_size =
activations_new->shape().dimensions(spatial_dimension_to_split);
const int64 kernel_new_split_dim_size =
kernel_new->shape().dimensions(kernel_spatial_dimension_to_split);
permuted_conv_dims_numbers.set_input_batch_dimension(activations_feature_dim);
permuted_conv_dims_numbers.set_input_feature_dimension(activations_batch_dim);
@@ -2040,6 +2282,7 @@ Status ConvolutionVisitor::PropagateOnBackpropFilterConv(
BringSpaceNextToBatch(activations_new, permuted_conv_dims_numbers,
spatial_dimension_to_split, activations_batch_dim,
/*is_backprop=*/true));
std::vector<int64> transpose_dims = retval.transpose_dims;
CHECK(!transpose_dims.empty());
activations_new = retval.instr;
@@ -2048,6 +2291,17 @@ Status ConvolutionVisitor::PropagateOnBackpropFilterConv(
<< activations_new->ToString();
VLOG(1) << "activations_batch_dim " << activations_batch_dim
<< " activations_feature_dim " << activations_feature_dim;
const int64 expected_split_dim_size =
rhs_dilation * kernel_new_split_dim_size;
if (new_split_dim_size != expected_split_dim_size) {
CHECK_LT(new_split_dim_size, expected_split_dim_size);
new_split_dim_size = expected_split_dim_size;
TF_ASSIGN_OR_RETURN(
activations_new,
IncreaseSpatialSizeOnSpaceToBatchedShape(
activations_new, activations_batch_dim, old_batch_size,
spatial_dimension_to_split, new_split_dim_size));
}
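// Size example (illustrative): with rhs_dilation = 2 and
// kernel_new_split_dim_size = 5, the activations' split dimension must be
// 2 * 5 = 10; if it currently holds 8, it is grown to 10 by the call above.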
auto select_val = computation_->AddInstruction(HloInstruction::CreateConstant(
LiteralUtil::Zero(activations_new->shape().element_type())));
@@ -2060,16 +2314,18 @@ Status ConvolutionVisitor::PropagateOnBackpropFilterConv(
activations_batch_dim, spatial_dimension_to_split,
old_batch_dim, old_space_dim));
}
VLOG(3) << "Selecting the valid kernel area";
// Select kernel correctly by masking additional space.
TF_ASSIGN_OR_RETURN(
kernel_new,
SelectValidPortion(
kernel_new, kernel_old, select_val,
/*new_batch_dim=*/kernel_input_feature_dim,
kernel_spatial_dimension_to_split,
/*old_batch_dim=*/original_conv_dims.kernel_input_feature_dimension(),
kernel_space_dim));
if (!kernel_locally_space_to_batched) {
VLOG(3) << "Selecting the valid kernel area";
// Select kernel correctly by masking additional space.
TF_ASSIGN_OR_RETURN(
kernel_new,
SelectValidPortion(kernel_new, kernel_old, select_val,
/*new_batch_dim=*/kernel_input_feature_dim,
kernel_spatial_dimension_to_split,
/*old_batch_dim=*/
original_conv_dims.kernel_input_feature_dimension(),
kernel_space_dim));
}
// Create the new convolution dim numbers.
auto new_dim_numbers = permuted_conv_dims_numbers;
@@ -2115,7 +2371,9 @@ Status ConvolutionVisitor::PropagateOnBackpropFilterConv(
old_split_dim_size - expanded_kernel + 1 +
(inherent_low_padding < 0 ? inherent_low_padding : 0) +
(inherent_high_padding < 0 ? inherent_high_padding : 0);
VLOG(1) << "overlap_count " << overlap_count;
VLOG(1) << "overlap_count " << overlap_count << " inherent_low_padding "
<< inherent_low_padding << " inherent_high_padding "
<< inherent_high_padding;
// Insert original activations.
for (int64 i = 0; i < overlap_count; ++i) {
@@ -2190,7 +2448,7 @@ Status ConvolutionVisitor::PropagateOnBackpropFilterConv(
new_window.mutable_dimensions(get_chosen_spatial_dim(convolution))
->set_padding_low(0);
new_window.mutable_dimensions(get_chosen_spatial_dim(convolution))
->set_size(new_split_dim_size / rhs_dilation);
->set_size(CeilOfRatio(new_split_dim_size, rhs_dilation));
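// e.g. (illustrative) new_split_dim_size = 10, rhs_dilation = 3: the window
// size becomes CeilOfRatio(10, 3) = 4, where the old integer division gave
// 3 and could under-cover the space-to-batched dimension.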
// Set the window for the additional spatial dim. This is a vanilla window.
auto window_dim = new_window.add_dimensions();
@@ -2211,6 +2469,8 @@ Status ConvolutionVisitor::PropagateOnBackpropFilterConv(
/*preferred_element_type=*/convolution->shape().element_type()));
convolution->SetupDerivedInstruction(new_conv);
VLOG(2) << "New backprop filter convolution " << new_conv->ToString();
std::vector<int64> output_sizes(new_conv->shape().dimensions().begin(),
new_conv->shape().dimensions().end());
@@ -2364,8 +2624,8 @@ Status ConvolutionVisitor::PerformSpaceToBatchOnConvolution(
DoesConvolutionFeedReduceWindowOrSelectAndScatter(convolution);
if (reduce_window_or_select_and_scatter != nullptr) {
VLOG(2) << "DoesConvolutionFeedReduceWindowOrSelectAndScatter "
<< reduce_window_or_select_and_scatter;
VLOG(2)
<< "DoesConvolutionFeedReduceWindowOrSelectAndScatter returned true";
// Take into account the stride of the reduce window while choosing the
// spatial_split_size. This will guarantee propagation through reduce
// windows.
@@ -2457,6 +2717,11 @@ Status ConvolutionVisitor::PerformSpaceToBatchOnConvolution(
/*preferred_element_type=*/convolution->shape().element_type()));
convolution->SetupDerivedInstruction(new_conv);
// If the activations were to be batch-to-spaced again, simply use the
// original value.
batch_to_space_map_[convolution->mutable_operand(0)] =
convolution->mutable_operand(0);
VLOG(1) << "Space-to-batched convolution " << new_conv->ToString();
const int64 output_split_spatial_dim =
@@ -2483,7 +2748,9 @@ Status ConvolutionVisitor::PerformSpaceToBatchOnConvolution(
get_chosen_spatial_dim(original_conv)));
instr_to_dim_permute_map_[new_conv] = std::vector<int64>(transpose_dims);
if (non_propagatable_instrs_.count(convolution) > 0) {
non_propagatable_instrs_.erase(convolution);
}
TF_CHECK_OK(PropagateOnUsers(original_conv));
changed_ = true;