[XLA] Enable RHS space-to-batch on backprop filter convolutions.

PiperOrigin-RevId: 350873443
Change-Id: I1388f4fe5f151c10062326d1b08d36c8eaec860c
This commit is contained in:
A. Unique TensorFlower 2021-01-08 18:44:59 -08:00 committed by TensorFlower Gardener
parent 5975aae33f
commit 5b9181194e

View File

@ -82,7 +82,8 @@ class ConvolutionVisitor {
// Function that determines if space-to-batch can be propagated into the // Function that determines if space-to-batch can be propagated into the
// consumer. Such propagation is only possible when all required operands are // consumer. Such propagation is only possible when all required operands are
// space-to-batch'ed. // space-to-batch'ed.
bool CanPropagate(HloInstruction* consumer, HloInstruction* producer); bool CanPropagate(HloInstruction* consumer, HloInstruction* producer,
bool last_try = false);
// Returns true if the op has all its direct and indirect operands being // Returns true if the op has all its direct and indirect operands being
// created via broadcasts. Consumer uses op, and is space-to-batched. // created via broadcasts. Consumer uses op, and is space-to-batched.
@ -116,7 +117,7 @@ class ConvolutionVisitor {
HloInstruction* activations, ConvolutionDimensionNumbers& dim_numbers, HloInstruction* activations, ConvolutionDimensionNumbers& dim_numbers,
int64& spatial_dimension_to_split, int64& activations_batch_dim, int64& spatial_dimension_to_split, int64& activations_batch_dim,
int64 high_padding, int64 low_padding, int64 spatial_split_size, int64 high_padding, int64 low_padding, int64 spatial_split_size,
int64 num_splits, bool is_backprop = false); int64 num_splits, bool is_backprop = false, bool is_rhs = false);
// Perform space-to-batch propagation on the convolution. Assumes the // Perform space-to-batch propagation on the convolution. Assumes the
// activations were already space-to-batched. // activations were already space-to-batched.
@ -149,7 +150,13 @@ class ConvolutionVisitor {
StatusOr<SpaceNextToBatchDetails> BringSpaceNextToBatch( StatusOr<SpaceNextToBatchDetails> BringSpaceNextToBatch(
HloInstruction* activations, ConvolutionDimensionNumbers& dim_numbers, HloInstruction* activations, ConvolutionDimensionNumbers& dim_numbers,
int64& spatial_dimension_to_split, int64& activations_batch_dim, int64& spatial_dimension_to_split, int64& activations_batch_dim,
bool is_backprop = false); bool is_backprop = false, bool is_rhs = false);
// Increases the spatial dimension size in an already space-to-batched shape
// so that the new size is new_spatial_dim_size.
StatusOr<HloInstruction*> IncreaseSpatialSizeOnSpaceToBatchedShape(
HloInstruction* activations, int64 batch_dimension, int64 old_batch_size,
int64 spatial_dimension, int64 new_spatial_dim_size);
// Function that converts spaced-to-batch shape back to the original. // Function that converts spaced-to-batch shape back to the original.
StatusOr<HloInstruction*> BatchToSpace(HloInstruction* old_instr); StatusOr<HloInstruction*> BatchToSpace(HloInstruction* old_instr);
@ -213,6 +220,10 @@ class ConvolutionVisitor {
absl::flat_hash_map<HloInstruction*, std::vector<int64>> absl::flat_hash_map<HloInstruction*, std::vector<int64>>
instr_to_dim_permute_map_; instr_to_dim_permute_map_;
// Map maintaining previously space-to-batched broadcasts.
absl::flat_hash_map<HloInstruction*, absl::flat_hash_set<HloInstruction*>>
broadcast_map_;
// Whether rewrite has occurred. // Whether rewrite has occurred.
bool changed_ = false; bool changed_ = false;
@ -456,7 +467,7 @@ StatusOr<ConvolutionVisitor::SpaceNextToBatchDetails>
ConvolutionVisitor::BringSpaceNextToBatch( ConvolutionVisitor::BringSpaceNextToBatch(
HloInstruction* activations, ConvolutionDimensionNumbers& dim_numbers, HloInstruction* activations, ConvolutionDimensionNumbers& dim_numbers,
int64& spatial_dimension_to_split, int64& activations_batch_dim, int64& spatial_dimension_to_split, int64& activations_batch_dim,
bool is_backprop) { bool is_backprop, bool is_rhs) {
std::vector<int64> transpose_dims(activations->shape().rank()); std::vector<int64> transpose_dims(activations->shape().rank());
if (spatial_dimension_to_split == activations_batch_dim + 1) { if (spatial_dimension_to_split == activations_batch_dim + 1) {
absl::c_iota(transpose_dims, 0); absl::c_iota(transpose_dims, 0);
@ -465,49 +476,137 @@ ConvolutionVisitor::BringSpaceNextToBatch(
int64 pushed_counter = 0; int64 pushed_counter = 0;
int64 new_batch_dim, new_spatial_dim; int64 new_batch_dim, new_spatial_dim;
int64 dim_counter = 0; int64 dim_counter = 0;
for (int i = 0; i < activations->shape().rank(); ++i) { if (is_rhs) {
if (i == activations_batch_dim) { CHECK(is_backprop);
continue; for (int i = 0; i < activations->shape().rank(); ++i) {
} if (i == activations_batch_dim) {
if (i == spatial_dimension_to_split) { continue;
transpose_dims[dim_counter++] = activations_batch_dim; }
new_batch_dim = pushed_counter; if (i == spatial_dimension_to_split) {
pushed_counter++; transpose_dims[dim_counter++] = activations_batch_dim;
new_spatial_dim = pushed_counter; new_batch_dim = pushed_counter;
} pushed_counter++;
new_spatial_dim = pushed_counter;
}
if (is_backprop && i == dim_numbers.input_batch_dimension()) { if (i == dim_numbers.kernel_output_feature_dimension()) {
new_dim_numbers.set_input_batch_dimension(pushed_counter); new_dim_numbers.set_kernel_output_feature_dimension(pushed_counter);
} else if (i == dim_numbers.input_feature_dimension()) { } else {
new_dim_numbers.set_input_feature_dimension(pushed_counter); auto it = absl::c_find(dim_numbers.kernel_spatial_dimensions(), i);
} else { if (it != dim_numbers.kernel_spatial_dimensions().end()) {
for (int j = 0; j < dim_numbers.input_spatial_dimensions_size(); ++j) { int64 j = it - dim_numbers.kernel_spatial_dimensions().begin();
if (i == dim_numbers.input_spatial_dimensions(j)) { new_dim_numbers.set_kernel_spatial_dimensions(j, pushed_counter);
new_dim_numbers.set_input_spatial_dimensions(j, pushed_counter);
break;
} }
} }
transpose_dims[dim_counter++] = i;
pushed_counter++;
} }
transpose_dims[dim_counter++] = i;
pushed_counter++;
}
activations_batch_dim = new_batch_dim; activations_batch_dim = new_batch_dim;
spatial_dimension_to_split = new_spatial_dim; spatial_dimension_to_split = new_spatial_dim;
TF_ASSIGN_OR_RETURN(activations, TF_ASSIGN_OR_RETURN(activations,
MakeTransposeHlo(activations, transpose_dims)); MakeTransposeHlo(activations, transpose_dims));
new_dim_numbers.set_kernel_input_feature_dimension(activations_batch_dim);
if (is_backprop) {
new_dim_numbers.set_input_feature_dimension(activations_batch_dim);
} else { } else {
new_dim_numbers.set_input_batch_dimension(activations_batch_dim); for (int i = 0; i < activations->shape().rank(); ++i) {
if (i == activations_batch_dim) {
continue;
}
if (i == spatial_dimension_to_split) {
transpose_dims[dim_counter++] = activations_batch_dim;
new_batch_dim = pushed_counter;
pushed_counter++;
new_spatial_dim = pushed_counter;
}
if (is_backprop && i == dim_numbers.input_batch_dimension()) {
new_dim_numbers.set_input_batch_dimension(pushed_counter);
} else if (i == dim_numbers.input_feature_dimension()) {
new_dim_numbers.set_input_feature_dimension(pushed_counter);
} else {
auto it = absl::c_find(dim_numbers.input_spatial_dimensions(), i);
if (it != dim_numbers.input_spatial_dimensions().end()) {
int64 j = it - dim_numbers.input_spatial_dimensions().begin();
new_dim_numbers.set_input_spatial_dimensions(j, pushed_counter);
}
}
transpose_dims[dim_counter++] = i;
pushed_counter++;
}
activations_batch_dim = new_batch_dim;
spatial_dimension_to_split = new_spatial_dim;
TF_ASSIGN_OR_RETURN(activations,
MakeTransposeHlo(activations, transpose_dims));
if (is_backprop) {
new_dim_numbers.set_input_feature_dimension(activations_batch_dim);
} else {
new_dim_numbers.set_input_batch_dimension(activations_batch_dim);
}
} }
dim_numbers = new_dim_numbers; dim_numbers = new_dim_numbers;
} }
return SpaceNextToBatchDetails{activations, transpose_dims}; return SpaceNextToBatchDetails{activations, transpose_dims};
} }
StatusOr<HloInstruction*>
ConvolutionVisitor::IncreaseSpatialSizeOnSpaceToBatchedShape(
HloInstruction* activations, int64 batch_dimension, int64 old_batch_size,
int64 spatial_dimension, int64 new_spatial_dim_size) {
CHECK_EQ(batch_dimension + 1, spatial_dimension);
std::vector<int64> new_dimensions(activations->shape().dimensions().begin(),
activations->shape().dimensions().end());
const int64 new_batch_size = activations->shape().dimensions(batch_dimension);
int64 spatial_dim_size = activations->shape().dimensions(spatial_dimension);
const int64 reshaped_space_size =
spatial_dim_size * new_batch_size / old_batch_size;
VLOG(3) << "Increasing the spatial size while propagating new_batch_size "
<< new_batch_size << " old_batch_size " << old_batch_size;
new_dimensions[spatial_dimension] = reshaped_space_size;
new_dimensions[batch_dimension] = old_batch_size;
// Reshape the output of the new conv into the old convolutions shape.
TF_ASSIGN_OR_RETURN(HloInstruction * reshaped_activations,
MakeReshapeHlo(new_dimensions, activations));
VLOG(3) << "First reshape done";
PaddingConfig padding_config =
MakeNoPaddingConfig(reshaped_activations->shape().dimensions_size());
padding_config.mutable_dimensions(spatial_dimension)
->set_edge_padding_high(new_spatial_dim_size * new_batch_size /
old_batch_size -
reshaped_space_size);
padding_config.mutable_dimensions(spatial_dimension)->set_edge_padding_low(0);
HloInstruction* padding =
computation_->AddInstruction(HloInstruction::CreateConstant(
LiteralUtil::Zero(reshaped_activations->shape().element_type())));
TF_ASSIGN_OR_RETURN(
reshaped_activations,
MakePadHlo(reshaped_activations, padding, padding_config));
std::vector<int64> reshape_back_dims(
reshaped_activations->shape().dimensions().begin(),
reshaped_activations->shape().dimensions().end());
reshape_back_dims[spatial_dimension] = new_spatial_dim_size;
reshape_back_dims[batch_dimension] = new_batch_size;
TF_ASSIGN_OR_RETURN(HloInstruction * activations_new,
MakeReshapeHlo(reshape_back_dims, reshaped_activations));
VLOG(3) << "Size increased activations " << activations_new->ToString();
return activations_new;
}
StatusOr<bool> ConvolutionVisitor::Run() { StatusOr<bool> ConvolutionVisitor::Run() {
for (auto conv : conv_visitor_list_) { for (auto conv : conv_visitor_list_) {
if (convs_to_visit_.count(conv) > 0) { if (convs_to_visit_.count(conv) > 0) {
@ -519,6 +618,29 @@ StatusOr<bool> ConvolutionVisitor::Run() {
// Iterate through all instructions that we could not propagate through, and // Iterate through all instructions that we could not propagate through, and
// turn their operands from batch-to-space as needed. // turn their operands from batch-to-space as needed.
for (auto instr : non_propagatable_instrs_) { for (auto instr : non_propagatable_instrs_) {
if (instr->opcode() == HloOpcode::kConvolution) {
VLOG(1) << "Instr " << instr->ToString();
}
// Try to propagate on backprop filters
if (instr->opcode() == HloOpcode::kConvolution &&
!IsConvSuitableForSpaceToBatch(instr)) {
HloInstruction* producer = nullptr;
if (old_to_new_instrs_.contains(instr->mutable_operand(0))) {
producer = instr->mutable_operand(0);
} else if (old_to_new_instrs_.contains(instr->mutable_operand(1))) {
producer = instr->mutable_operand(1);
}
if (producer) {
if (CanPropagate(instr, producer, /*last_try=*/true)) {
bool needs_further_propagation;
TF_ASSIGN_OR_RETURN(needs_further_propagation,
Propagate(instr, producer));
TF_CHECK_OK(computation_->ReplaceInstruction(
instr, old_to_new_instrs_[instr]));
continue;
}
}
}
VLOG(1) << "Could not eventually propagate through " << instr->ToString(); VLOG(1) << "Could not eventually propagate through " << instr->ToString();
absl::flat_hash_map<int64, HloInstruction*> operand_map; absl::flat_hash_map<int64, HloInstruction*> operand_map;
for (int64 i = 0; i < instr->operand_count(); ++i) { for (int64 i = 0; i < instr->operand_count(); ++i) {
@ -539,14 +661,14 @@ bool IsTrivialElementwise(HloInstruction* hlo) {
if (hlo->opcode() == HloOpcode::kFusion || hlo->opcode() == HloOpcode::kRng || if (hlo->opcode() == HloOpcode::kFusion || hlo->opcode() == HloOpcode::kRng ||
hlo->opcode() == HloOpcode::kCopy || hlo->opcode() == HloOpcode::kCopy ||
hlo->opcode() == HloOpcode::kConstant || hlo->opcode() == HloOpcode::kConstant ||
hlo->opcode() == HloOpcode::kIota) { hlo->opcode() == HloOpcode::kIota || hlo->opcode() == HloOpcode::kMap) {
return false; return false;
} }
return hlo->IsElementwise(); return hlo->IsElementwise();
} }
bool ConvolutionVisitor::CanPropagate(HloInstruction* consumer, bool ConvolutionVisitor::CanPropagate(HloInstruction* consumer,
HloInstruction* producer) { HloInstruction* producer, bool last_try) {
if (IsTrivialElementwise(consumer)) { if (IsTrivialElementwise(consumer)) {
VLOG(2) << "Doing propagation check on elementwise op: " VLOG(2) << "Doing propagation check on elementwise op: "
<< consumer->ToString(); << consumer->ToString();
@ -604,7 +726,8 @@ bool ConvolutionVisitor::CanPropagate(HloInstruction* consumer,
// Make sure all other dimensions are of the same size. // Make sure all other dimensions are of the same size.
if (pivot_new_instr->shape().dimensions(j) != if (pivot_new_instr->shape().dimensions(j) !=
new_instr->shape().dimensions(j)) { new_instr->shape().dimensions(j)) {
if (!(consumer->IsElementwiseBinary() && if (!((consumer->IsElementwiseBinary() ||
consumer->opcode() == HloOpcode::kSelect) &&
j == instr_to_dim_map_[pivot_operand].second)) { j == instr_to_dim_map_[pivot_operand].second)) {
VLOG(2) << "Elementwise op: checking for shape equivalence " VLOG(2) << "Elementwise op: checking for shape equivalence "
<< consumer->ToString() << consumer->ToString()
@ -653,20 +776,70 @@ bool ConvolutionVisitor::CanPropagate(HloInstruction* consumer,
return false; return false;
} }
// Same reason why we give up on batch group counts applies to features in
// backprop.
if (consumer->feature_group_count() != 1) {
return false;
}
VLOG(2) << "Checking for backprop filter conv propagatability"; VLOG(2) << "Checking for backprop filter conv propagatability";
CHECK_EQ(consumer->operand_count(), 2); CHECK_EQ(consumer->operand_count(), 2);
VLOG(2) << "Checking for backprop filter conv operands "
<< consumer->operand_count();
auto activations = consumer->mutable_operand(0); auto activations = consumer->mutable_operand(0);
auto kernel = consumer->mutable_operand(1); auto kernel = consumer->mutable_operand(1);
if (!old_to_new_instrs_.contains(kernel)) { if (!last_try) {
VLOG(2) << "Backprop filter conv not ready for propagation because of " if (!old_to_new_instrs_.contains(kernel) ||
"kernel is not space-to-batched"; !old_to_new_instrs_.contains(activations)) {
return false;
}
}
if (!old_to_new_instrs_.contains(kernel) &&
!old_to_new_instrs_.contains(activations)) {
return false; return false;
} }
const int64 rhs_dilation = consumer->window()
.dimensions(get_chosen_spatial_dim(consumer))
.window_dilation();
if (!old_to_new_instrs_.contains(kernel)) {
const int64 rhs_batch =
kernel->shape().dimensions(consumer->convolution_dimension_numbers()
.kernel_input_feature_dimension());
auto dim_map_val_op_0 = instr_to_dim_map_[activations];
const int64 old_batch_dim = dim_map_val_op_0.first;
const int64 old_space_dim = dim_map_val_op_0.second;
auto first_operand = old_to_new_instrs_[activations];
auto permute_dims_first_operand =
instr_to_dim_permute_map_[first_operand];
const int64 new_batch_dim =
DimLookUp(permute_dims_first_operand, old_batch_dim);
const int64 new_space_dim =
DimLookUp(permute_dims_first_operand, old_space_dim);
const int64 lhs_batch = first_operand->shape().dimensions(new_batch_dim);
if (first_operand->shape().dimensions(new_space_dim) % rhs_dilation !=
0) {
return false;
}
// Because we want to convert activations into a space-to-batched version
// only for backprop filter convolutions, we want to make sure that the
// batch dimensions (feature dimensions, technically) are same sized.
// Since LHS is already space-to-batched, we need to account for it too.
if (rhs_batch * kNumSplits != lhs_batch) {
return false;
}
// If kernel have not been propagated through, we can do
// space-to-batch on them provided kernel has been propagated.
VLOG(2)
<< "Backprop filter conv ready for propagation: activations ready, "
" kernel will be space-to-batched";
return true;
}
if (!old_to_new_instrs_.contains(activations)) { if (!old_to_new_instrs_.contains(activations)) {
const int64 lhs_batch = activations->shape().dimensions( const int64 lhs_batch = activations->shape().dimensions(
consumer->convolution_dimension_numbers().input_feature_dimension()); consumer->convolution_dimension_numbers().input_feature_dimension());
@ -720,10 +893,7 @@ bool ConvolutionVisitor::CanPropagate(HloInstruction* consumer,
return false; return false;
} }
const int64 rhs_dilation = consumer->window() if (first_operand->shape().dimensions(new_space_dim_operand_0) >
.dimensions(get_chosen_spatial_dim(consumer))
.window_dilation();
if (first_operand->shape().dimensions(new_space_dim_operand_0) !=
rhs_dilation * rhs_dilation *
second_operand->shape().dimensions(new_space_dim_operand_1)) { second_operand->shape().dimensions(new_space_dim_operand_1)) {
VLOG(2) << "Backprop filter conv not ready for propagation because of " VLOG(2) << "Backprop filter conv not ready for propagation because of "
@ -816,20 +986,57 @@ void ConvolutionVisitor::PropagateOnBroadcast(HloInstruction* consumer,
auto new_producer = old_to_new_instrs_[producer]; auto new_producer = old_to_new_instrs_[producer];
auto permute_dims = instr_to_dim_permute_map_[new_producer]; auto permute_dims = instr_to_dim_permute_map_[new_producer];
auto dim_map_val = instr_to_dim_map_[producer]; auto dim_map_val = instr_to_dim_map_[producer];
const int64 old_batch_dim = dim_map_val.first;
const int64 old_space_dim = dim_map_val.second;
auto orig_broadcast_dims = consumer->dimensions();
bool batch_is_broadcasted =
absl::c_linear_search(orig_broadcast_dims, old_batch_dim);
const int64 new_batch_dim = DimLookUp(permute_dims, old_batch_dim);
const int64 new_space_dim = DimLookUp(permute_dims, old_space_dim);
bool map_found = broadcast_map_.contains(consumer);
if (map_found) {
// Check if we previously had created the same broadcast.
for (auto previous_broadcast : broadcast_map_[consumer]) {
if (ShapeUtil::CompatibleIgnoringElementType(previous_broadcast->shape(),
new_producer->shape())) {
return;
}
}
}
std::vector<int64> final_shape_dims(
new_producer->shape().dimensions().begin(),
new_producer->shape().dimensions().end());
if (batch_is_broadcasted) {
final_shape_dims[new_batch_dim] =
producer->shape().dimensions(old_batch_dim);
final_shape_dims[new_space_dim] *= kNumSplits;
}
std::vector<int64> broadcast_dims; std::vector<int64> broadcast_dims;
for (auto j : consumer->dimensions()) { for (auto j : consumer->dimensions()) {
broadcast_dims.push_back(DimLookUp(permute_dims, j)); broadcast_dims.push_back(DimLookUp(permute_dims, j));
} }
auto new_broadcast = auto new_broadcast = MakeBroadcastHlo(consumer->mutable_operand(0),
MakeBroadcastHlo(consumer->mutable_operand(0), broadcast_dims, broadcast_dims, final_shape_dims);
new_producer->shape().dimensions());
VLOG(1) << "Created broadcast " << new_broadcast->ToString(); VLOG(1) << "Created broadcast " << new_broadcast->ToString();
// Pass on the permutation information from the producer. if (batch_is_broadcasted) {
old_to_new_instrs_[consumer] = new_broadcast; new_broadcast =
instr_to_dim_map_[consumer] = dim_map_val; MakeReshapeHlo(new_producer->shape().dimensions(), new_broadcast)
instr_to_dim_permute_map_[new_broadcast] = std::vector<int64>( .ValueOrDie();
instr_to_dim_permute_map_[old_to_new_instrs_[producer]]); VLOG(2) << "Created reshape of broadcast " << new_broadcast->ToString();
}
if (!map_found) {
absl::flat_hash_set<HloInstruction*> set_of_broadcasts;
broadcast_map_[consumer] = set_of_broadcasts;
}
broadcast_map_[consumer].insert(new_broadcast);
} }
void ConvolutionVisitor::RewriteBroadcastTree( void ConvolutionVisitor::RewriteBroadcastTree(
@ -882,11 +1089,9 @@ bool ConvolutionVisitor::IsBroadcastPropagatable(HloInstruction* broadcast,
CHECK(instr_to_dim_map_.contains(old_other_op)); CHECK(instr_to_dim_map_.contains(old_other_op));
auto result = instr_to_dim_map_[old_other_op]; auto result = instr_to_dim_map_[old_other_op];
const int64 batch_dim = result.first;
const int64 space_dim = result.second; const int64 space_dim = result.second;
auto broadcast_dims = broadcast->dimensions(); auto broadcast_dims = broadcast->dimensions();
return !absl::c_linear_search(broadcast_dims, batch_dim) && return !absl::c_linear_search(broadcast_dims, space_dim);
!absl::c_linear_search(broadcast_dims, space_dim);
} }
bool ConvolutionVisitor::SupportedOpForPropagation(HloInstruction* consumer, bool ConvolutionVisitor::SupportedOpForPropagation(HloInstruction* consumer,
@ -990,27 +1195,50 @@ StatusOr<bool> ConvolutionVisitor::Propagate(HloInstruction* consumer,
// For elementwise binary ops, both of whose operands have been space-to- // For elementwise binary ops, both of whose operands have been space-to-
// batched, if their new spatial sizes don't match, choose the bigger one // batched, if their new spatial sizes don't match, choose the bigger one
// as the producer. // as the producer.
if (consumer->IsElementwiseBinary() && if (consumer->IsElementwiseBinary() ||
old_to_new_instrs_.contains(consumer->mutable_operand(0)) && consumer->opcode() == HloOpcode::kSelect) {
old_to_new_instrs_.contains(consumer->mutable_operand(1))) { int64 pivot_operand_number = -1;
is_pivot_producer_modified = true; HloInstruction* pivot_operand = nullptr;
if (old_to_new_instrs_[consumer->mutable_operand(0)] for (int i = 0; i < consumer->operand_count(); ++i) {
->shape() if (consumer->operand(i)->opcode() == HloOpcode::kBroadcast) {
.dimensions() > old_to_new_instrs_[consumer->mutable_operand(1)] continue;
->shape() }
.dimensions()) { auto operand = consumer->mutable_operand(i);
producer = consumer->mutable_operand(0); if (old_to_new_instrs_.contains(operand)) {
} else { if (pivot_operand_number == -1 ||
producer = consumer->mutable_operand(1); old_to_new_instrs_[pivot_operand]->shape().dimensions() <
old_to_new_instrs_[operand]->shape().dimensions()) {
is_pivot_producer_modified = true;
pivot_operand_number = i;
pivot_operand = consumer->mutable_operand(pivot_operand_number);
}
}
}
if (pivot_operand_number != -1) {
producer = pivot_operand;
} }
} }
for (int64 i = 0; i < consumer->operand_count(); ++i) { for (int64 i = 0; i < consumer->operand_count(); ++i) {
std::vector<HloInstruction*> instructions_to_transform; std::vector<HloInstruction*> instructions_to_transform;
if (old_to_new_instrs_.contains(consumer->mutable_operand(i))) { if (consumer->operand(i)->opcode() == HloOpcode::kBroadcast) {
auto broadcast = consumer->mutable_operand(i);
PropagateOnBroadcast(broadcast, producer);
HloInstruction* new_broadcast = nullptr;
auto new_producer = old_to_new_instrs_[producer];
for (auto previous_broadcast : broadcast_map_[broadcast]) {
if (ShapeUtil::CompatibleIgnoringElementType(
previous_broadcast->shape(), new_producer->shape())) {
new_broadcast = previous_broadcast;
break;
}
}
CHECK_NE(new_broadcast, nullptr);
TF_CHECK_OK(
new_consumer->ReplaceOperandWithDifferentShape(i, new_broadcast));
} else if (old_to_new_instrs_.contains(consumer->mutable_operand(i))) {
HloInstruction* operand_to_use = nullptr; HloInstruction* operand_to_use = nullptr;
auto result = instr_to_dim_map_[producer]; auto result = instr_to_dim_map_[producer];
const int64 old_batch_dim = result.first; const int64 old_batch_dim = result.first;
const int64 old_space_dim = result.second; const int64 old_space_dim = result.second;
@ -1070,13 +1298,6 @@ StatusOr<bool> ConvolutionVisitor::Propagate(HloInstruction* consumer,
} }
TF_CHECK_OK( TF_CHECK_OK(
new_consumer->ReplaceOperandWithDifferentShape(i, operand_to_use)); new_consumer->ReplaceOperandWithDifferentShape(i, operand_to_use));
} else if (consumer->operand(i)->opcode() == HloOpcode::kBroadcast) {
if (!old_to_new_instrs_.contains(consumer->operand(i))) {
PropagateOnBroadcast(consumer->mutable_operand(i), producer);
}
auto new_broadcast = old_to_new_instrs_[consumer->mutable_operand(i)];
TF_CHECK_OK(
new_consumer->ReplaceOperandWithDifferentShape(i, new_broadcast));
} else if (consumer->IsElementwiseBinary() && } else if (consumer->IsElementwiseBinary() &&
IsBroadcastTree(consumer->mutable_operand(i), producer, IsBroadcastTree(consumer->mutable_operand(i), producer,
instructions_to_transform)) { instructions_to_transform)) {
@ -1463,8 +1684,10 @@ StatusOr<HloInstruction*> ConvolutionVisitor::SelectValidPortion(
StatusOr<HloInstruction*> ConvolutionVisitor::BatchToSpace( StatusOr<HloInstruction*> ConvolutionVisitor::BatchToSpace(
HloInstruction* old_instr) { HloInstruction* old_instr) {
if (batch_to_space_map_.count(old_instr)) { if (batch_to_space_map_.count(old_instr)) {
CHECK_NE(batch_to_space_map_[old_instr], nullptr);
return batch_to_space_map_[old_instr]; return batch_to_space_map_[old_instr];
} }
auto result = instr_to_dim_map_[old_instr]; auto result = instr_to_dim_map_[old_instr];
const int64 old_batch_dim = result.first; const int64 old_batch_dim = result.first;
const int64 old_space_dim = result.second; const int64 old_space_dim = result.second;
@ -1683,68 +1906,28 @@ Status ConvolutionVisitor::PropagateOnConv(HloInstruction* convolution) {
VLOG(1) << "spatial_split_size " << spatial_split_size << " slice_size " VLOG(1) << "spatial_split_size " << spatial_split_size << " slice_size "
<< slice_size; << slice_size;
const int64 new_batch_size =
activations_new->shape().dimensions(activations_batch_dim);
const int64 new_space_size = const int64 new_space_size =
activations_new->shape().dimensions(c.spatial_dimension_to_split); activations_new->shape().dimensions(c.spatial_dimension_to_split);
// In the below case, we cannot use the activations directly for Halo // In the below case, we cannot use the activations directly for Halo
// Duplication. We must reshape them. // Duplication. We must reshape them.
if (spatial_split_size > new_space_size) { if (spatial_split_size > new_space_size) {
std::vector<int64> new_dimensions(
activations_new->shape().dimensions().begin(),
activations_new->shape().dimensions().end());
const int64 reshaped_space_size =
new_space_size * new_batch_size / old_batch_size;
VLOG(3) << "Increasing the spatial size while propagating new_batch_size "
<< new_batch_size << " old_batch_size " << old_batch_size;
new_dimensions[c.spatial_dimension_to_split] = reshaped_space_size;
new_dimensions[activations_batch_dim] = old_batch_size;
// Reshape the output of the new conv into the old convolutions shape.
TF_ASSIGN_OR_RETURN(HloInstruction * reshaped_activations,
MakeReshapeHlo(new_dimensions, activations_new));
VLOG(3) << "First reshape done";
PaddingConfig padding_config =
MakeNoPaddingConfig(reshaped_activations->shape().dimensions_size());
padding_config.mutable_dimensions(c.spatial_dimension_to_split)
->set_edge_padding_high(spatial_split_size * new_batch_size /
old_batch_size -
reshaped_space_size);
padding_config.mutable_dimensions(c.spatial_dimension_to_split)
->set_edge_padding_low(0);
HloInstruction* padding =
computation_->AddInstruction(HloInstruction::CreateConstant(
LiteralUtil::Zero(reshaped_activations->shape().element_type())));
TF_ASSIGN_OR_RETURN( TF_ASSIGN_OR_RETURN(
reshaped_activations, activations_new,
MakePadHlo(reshaped_activations, padding, padding_config)); IncreaseSpatialSizeOnSpaceToBatchedShape(
activations_new, activations_batch_dim, old_batch_size,
std::vector<int64> reshape_back_dims( c.spatial_dimension_to_split, spatial_split_size));
reshaped_activations->shape().dimensions().begin(),
reshaped_activations->shape().dimensions().end());
reshape_back_dims[c.spatial_dimension_to_split] = spatial_split_size;
reshape_back_dims[activations_batch_dim] = new_batch_size;
TF_ASSIGN_OR_RETURN(
reshaped_activations,
MakeReshapeHlo(reshape_back_dims, reshaped_activations));
VLOG(3) << "Second reshape done";
TF_ASSIGN_OR_RETURN( TF_ASSIGN_OR_RETURN(
activations_new, activations_new,
HaloDuplicateWithSlice( HaloDuplicateWithSlice(activations_new, c.spatial_dimension_to_split,
reshaped_activations, c.spatial_dimension_to_split, activations_batch_dim, old_batch_size,
activations_batch_dim, old_batch_size, /*low_padding=*/c.base_dilation_factor != 1 &&
/*low_padding=*/c.base_dilation_factor != 1 && c.inherent_low_padding != 0
c.inherent_low_padding != 0 ? c.base_dilation_factor - 1
? c.base_dilation_factor - 1 : c.inherent_low_padding,
: c.inherent_low_padding, c.inherent_high_padding,
c.inherent_high_padding, slice_size - spatial_split_size, slice_size - spatial_split_size,
old_split_dim_size)); old_split_dim_size));
} else { } else {
// If the ideal spatial_split_size was smaller than the incoming spatial // If the ideal spatial_split_size was smaller than the incoming spatial
// dimension size, we don't need reshaping. Instead, we determine the // dimension size, we don't need reshaping. Instead, we determine the
@ -1835,14 +2018,15 @@ ConvolutionVisitor::SplitSpace(HloInstruction* activations,
int64& spatial_dimension_to_split, int64& spatial_dimension_to_split,
int64& activations_batch_dim, int64 high_padding, int64& activations_batch_dim, int64 high_padding,
int64 low_padding, int64 spatial_split_size, int64 low_padding, int64 spatial_split_size,
int64 num_splits, bool is_backprop) { int64 num_splits, bool is_backprop,
bool is_rhs) {
const int64 old_batch_size = const int64 old_batch_size =
activations->shape().dimensions(activations_batch_dim); activations->shape().dimensions(activations_batch_dim);
TF_ASSIGN_OR_RETURN( TF_ASSIGN_OR_RETURN(auto retval,
auto retval, BringSpaceNextToBatch(activations, dim_numbers, BringSpaceNextToBatch(
spatial_dimension_to_split, activations, dim_numbers, spatial_dimension_to_split,
activations_batch_dim, is_backprop)); activations_batch_dim, is_backprop, is_rhs));
activations = retval.instr; activations = retval.instr;
std::vector<int64> transpose_dims = retval.transpose_dims; std::vector<int64> transpose_dims = retval.transpose_dims;
@ -1903,7 +2087,7 @@ Status ConvolutionVisitor::PropagateOnBackpropFilterConv(
.window_dilation(); .window_dilation();
auto original_conv_dims = convolution->convolution_dimension_numbers(); auto original_conv_dims = convolution->convolution_dimension_numbers();
const int64 kernel_space_dim = original_conv_dims.kernel_spatial_dimensions( int64 kernel_space_dim = original_conv_dims.kernel_spatial_dimensions(
get_chosen_spatial_dim(convolution)); get_chosen_spatial_dim(convolution));
auto kernel_old = convolution->mutable_operand(1); auto kernel_old = convolution->mutable_operand(1);
const int64 old_kernel_split_dim_size = const int64 old_kernel_split_dim_size =
@ -1914,18 +2098,24 @@ Status ConvolutionVisitor::PropagateOnBackpropFilterConv(
int64 old_split_dim_size = activations_old->shape().dimensions(old_space_dim); int64 old_split_dim_size = activations_old->shape().dimensions(old_space_dim);
int64 old_batch_dim = original_conv_dims.input_feature_dimension(); int64 old_batch_dim = original_conv_dims.input_feature_dimension();
int64 kernel_old_batch_dim =
original_conv_dims.kernel_input_feature_dimension();
const int64 old_batch_size = const int64 old_batch_size =
activations_old->shape().dimensions(old_batch_dim); activations_old->shape().dimensions(old_batch_dim);
CHECK(old_to_new_instrs_.contains(kernel_old)); CHECK(old_to_new_instrs_.contains(kernel_old) ||
auto kernel_new = old_to_new_instrs_[kernel_old]; old_to_new_instrs_.contains(activations_old));
auto permute_dims_kernel = instr_to_dim_permute_map_[kernel_new];
HloInstruction* activations_new = nullptr; HloInstruction* activations_new = nullptr;
HloInstruction* kernel_new = nullptr;
bool activations_locally_space_to_batched = false; bool activations_locally_space_to_batched = false;
bool kernel_locally_space_to_batched = false;
std::vector<int64> permute_dims_kernel, permute_dims;
// If activations were no space-to-batched, we space-to-batch them below. // If activations were no space-to-batched, we space-to-batch them below.
if (!old_to_new_instrs_.contains(activations_old)) { if (!old_to_new_instrs_.contains(activations_old)) {
kernel_new = old_to_new_instrs_[kernel_old];
permute_dims_kernel = instr_to_dim_permute_map_[kernel_new];
VLOG(1) << "Space-to-batching activations to enable space-to-depth"; VLOG(1) << "Space-to-batching activations to enable space-to-depth";
const int64 prev_feature_dim = original_conv_dims.input_feature_dimension(); const int64 prev_feature_dim = original_conv_dims.input_feature_dimension();
const int64 prev_batch_dim = original_conv_dims.input_batch_dimension(); const int64 prev_batch_dim = original_conv_dims.input_batch_dimension();
@ -1960,14 +2150,63 @@ Status ConvolutionVisitor::PropagateOnBackpropFilterConv(
VLOG(3) << "New Activations " << retval.first->ToString(); VLOG(3) << "New Activations " << retval.first->ToString();
activations_locally_space_to_batched = true; activations_locally_space_to_batched = true;
} else if (!old_to_new_instrs_.contains(kernel_old)) {
activations_new = old_to_new_instrs_[activations_old];
permute_dims = instr_to_dim_permute_map_[activations_new];
VLOG(1) << "Space-to-batching kernel to enable space-to-depth";
const int64 prev_feature_dim =
original_conv_dims.kernel_input_feature_dimension();
const int64 prev_output_feature_dim =
original_conv_dims.kernel_output_feature_dimension();
// TODO(b/168316428): The instr_to_dim_map_ is set incorrectly here, but it
// doesn't matter since it is never used. Investigate further to see if just
// not setting it works.
instr_to_dim_map_[kernel_old] =
std::make_pair(prev_feature_dim, prev_output_feature_dim);
const int64 new_space_dim = DimLookUp(permute_dims, old_space_dim);
const int64 new_split_dim_size =
activations_new->shape().dimensions(new_space_dim);
const int64 needed_spatial_size =
CeilOfRatio(new_split_dim_size, rhs_dilation);
int64 old_kernel_split_dim_size =
kernel_old->shape().dimensions(kernel_space_dim);
const int64 pad_size =
needed_spatial_size * kNumSplits - old_kernel_split_dim_size;
ConvolutionDimensionNumbers tmp_dim_numbers;
tmp_dim_numbers = original_conv_dims;
TF_ASSIGN_OR_RETURN(
auto retval, SplitSpace(kernel_old, tmp_dim_numbers, kernel_space_dim,
kernel_old_batch_dim,
/*high_padding=*/pad_size, /*low_padding=*/0,
needed_spatial_size, kNumSplits,
/*is_backprop=*/true, /*is_rhs=*/true));
old_to_new_instrs_[kernel_old] = retval.first;
std::vector<int64> reversed_transpose_dims(retval.second.size());
for (int64 i = 0; i < retval.second.size(); ++i) {
reversed_transpose_dims[i] = ReverseDimLookUp(retval.second, i);
}
instr_to_dim_permute_map_[retval.first] = reversed_transpose_dims;
VLOG(3) << "New kernel " << retval.first->ToString();
kernel_locally_space_to_batched = true;
} }
CHECK(old_to_new_instrs_.contains(activations_old)); CHECK(old_to_new_instrs_.contains(activations_old));
CHECK(old_to_new_instrs_.contains(kernel_old));
activations_new = old_to_new_instrs_[activations_old]; activations_new = old_to_new_instrs_[activations_old];
kernel_new = old_to_new_instrs_[kernel_old];
const int64 new_spatial_dimension = const int64 new_spatial_dimension =
activations_new->shape().dimensions_size(); activations_new->shape().dimensions_size();
auto permute_dims = instr_to_dim_permute_map_[activations_new]; permute_dims = instr_to_dim_permute_map_[activations_new];
permute_dims_kernel = instr_to_dim_permute_map_[kernel_new];
auto permuted_conv_dims_numbers = original_conv_dims; auto permuted_conv_dims_numbers = original_conv_dims;
// Note the inversion here : batch and feature are inverted in backprop // Note the inversion here : batch and feature are inverted in backprop
@ -2024,9 +2263,12 @@ Status ConvolutionVisitor::PropagateOnBackpropFilterConv(
permuted_conv_dims_numbers.kernel_spatial_dimensions( permuted_conv_dims_numbers.kernel_spatial_dimensions(
get_chosen_spatial_dim(convolution)); get_chosen_spatial_dim(convolution));
const int64 new_split_dim_size = int64 new_split_dim_size =
activations_new->shape().dimensions(spatial_dimension_to_split); activations_new->shape().dimensions(spatial_dimension_to_split);
const int64 kernel_new_split_dim_size =
kernel_new->shape().dimensions(kernel_spatial_dimension_to_split);
permuted_conv_dims_numbers.set_input_batch_dimension(activations_feature_dim); permuted_conv_dims_numbers.set_input_batch_dimension(activations_feature_dim);
permuted_conv_dims_numbers.set_input_feature_dimension(activations_batch_dim); permuted_conv_dims_numbers.set_input_feature_dimension(activations_batch_dim);
@ -2040,6 +2282,7 @@ Status ConvolutionVisitor::PropagateOnBackpropFilterConv(
BringSpaceNextToBatch(activations_new, permuted_conv_dims_numbers, BringSpaceNextToBatch(activations_new, permuted_conv_dims_numbers,
spatial_dimension_to_split, activations_batch_dim, spatial_dimension_to_split, activations_batch_dim,
/*is_backprop=*/true)); /*is_backprop=*/true));
std::vector<int64> transpose_dims = retval.transpose_dims; std::vector<int64> transpose_dims = retval.transpose_dims;
CHECK(!transpose_dims.empty()); CHECK(!transpose_dims.empty());
activations_new = retval.instr; activations_new = retval.instr;
@ -2048,6 +2291,17 @@ Status ConvolutionVisitor::PropagateOnBackpropFilterConv(
<< activations_new->ToString(); << activations_new->ToString();
VLOG(1) << "activations_batch_dim " << activations_batch_dim VLOG(1) << "activations_batch_dim " << activations_batch_dim
<< " activations_feature_dim " << activations_feature_dim; << " activations_feature_dim " << activations_feature_dim;
const int64 expected_split_dim_size =
rhs_dilation * kernel_new_split_dim_size;
if (new_split_dim_size != expected_split_dim_size) {
CHECK_LT(new_split_dim_size, expected_split_dim_size);
new_split_dim_size = expected_split_dim_size;
TF_ASSIGN_OR_RETURN(
activations_new,
IncreaseSpatialSizeOnSpaceToBatchedShape(
activations_new, activations_batch_dim, old_batch_size,
spatial_dimension_to_split, new_split_dim_size));
}
auto select_val = computation_->AddInstruction(HloInstruction::CreateConstant( auto select_val = computation_->AddInstruction(HloInstruction::CreateConstant(
LiteralUtil::Zero(activations_new->shape().element_type()))); LiteralUtil::Zero(activations_new->shape().element_type())));
@ -2060,16 +2314,18 @@ Status ConvolutionVisitor::PropagateOnBackpropFilterConv(
activations_batch_dim, spatial_dimension_to_split, activations_batch_dim, spatial_dimension_to_split,
old_batch_dim, old_space_dim)); old_batch_dim, old_space_dim));
} }
VLOG(3) << "Selecting the valid kernel area"; if (!kernel_locally_space_to_batched) {
// Select kernel correctly by masking additional space. VLOG(3) << "Selecting the valid kernel area";
TF_ASSIGN_OR_RETURN( // Select kernel correctly by masking additional space.
kernel_new, TF_ASSIGN_OR_RETURN(
SelectValidPortion( kernel_new,
kernel_new, kernel_old, select_val, SelectValidPortion(kernel_new, kernel_old, select_val,
/*new_batch_dim=*/kernel_input_feature_dim, /*new_batch_dim=*/kernel_input_feature_dim,
kernel_spatial_dimension_to_split, kernel_spatial_dimension_to_split,
/*old_batch_dim=*/original_conv_dims.kernel_input_feature_dimension(), /*old_batch_dim=*/
kernel_space_dim)); original_conv_dims.kernel_input_feature_dimension(),
kernel_space_dim));
}
// Create the new convolution dim numbers. // Create the new convolution dim numbers.
auto new_dim_numbers = permuted_conv_dims_numbers; auto new_dim_numbers = permuted_conv_dims_numbers;
@ -2115,7 +2371,9 @@ Status ConvolutionVisitor::PropagateOnBackpropFilterConv(
old_split_dim_size - expanded_kernel + 1 + old_split_dim_size - expanded_kernel + 1 +
(inherent_low_padding < 0 ? inherent_low_padding : 0) + (inherent_low_padding < 0 ? inherent_low_padding : 0) +
(inherent_high_padding < 0 ? inherent_high_padding : 0); (inherent_high_padding < 0 ? inherent_high_padding : 0);
VLOG(1) << "overlap_count " << overlap_count; VLOG(1) << "overlap_count " << overlap_count << " inherent_low_padding "
<< inherent_low_padding << " inherent_high_padding "
<< inherent_high_padding;
// Insert original activations. // Insert original activations.
for (int64 i = 0; i < overlap_count; ++i) { for (int64 i = 0; i < overlap_count; ++i) {
@ -2190,7 +2448,7 @@ Status ConvolutionVisitor::PropagateOnBackpropFilterConv(
new_window.mutable_dimensions(get_chosen_spatial_dim(convolution)) new_window.mutable_dimensions(get_chosen_spatial_dim(convolution))
->set_padding_low(0); ->set_padding_low(0);
new_window.mutable_dimensions(get_chosen_spatial_dim(convolution)) new_window.mutable_dimensions(get_chosen_spatial_dim(convolution))
->set_size(new_split_dim_size / rhs_dilation); ->set_size(CeilOfRatio(new_split_dim_size, rhs_dilation));
// Set the window for the additional spatial dim. This is a vanilla window. // Set the window for the additional spatial dim. This is a vanilla window.
auto window_dim = new_window.add_dimensions(); auto window_dim = new_window.add_dimensions();
@ -2211,6 +2469,8 @@ Status ConvolutionVisitor::PropagateOnBackpropFilterConv(
/*preferred_element_type=*/convolution->shape().element_type())); /*preferred_element_type=*/convolution->shape().element_type()));
convolution->SetupDerivedInstruction(new_conv); convolution->SetupDerivedInstruction(new_conv);
VLOG(2) << "New backprop filter convolution " << new_conv->ToString();
std::vector<int64> output_sizes(new_conv->shape().dimensions().begin(), std::vector<int64> output_sizes(new_conv->shape().dimensions().begin(),
new_conv->shape().dimensions().end()); new_conv->shape().dimensions().end());
@ -2364,8 +2624,8 @@ Status ConvolutionVisitor::PerformSpaceToBatchOnConvolution(
DoesConvolutionFeedReduceWindowOrSelectAndScatter(convolution); DoesConvolutionFeedReduceWindowOrSelectAndScatter(convolution);
if (reduce_window_or_select_and_scatter != nullptr) { if (reduce_window_or_select_and_scatter != nullptr) {
VLOG(2) << "DoesConvolutionFeedReduceWindowOrSelectAndScatter " VLOG(2)
<< reduce_window_or_select_and_scatter; << "DoesConvolutionFeedReduceWindowOrSelectAndScatter returned true";
// Take into account the stride of the reduce window while choosing the // Take into account the stride of the reduce window while choosing the
// spatial_split_size. This will guarantee propagation through reduce // spatial_split_size. This will guarantee propagation through reduce
// windows. // windows.
@ -2457,6 +2717,11 @@ Status ConvolutionVisitor::PerformSpaceToBatchOnConvolution(
/*preferred_element_type=*/convolution->shape().element_type())); /*preferred_element_type=*/convolution->shape().element_type()));
convolution->SetupDerivedInstruction(new_conv); convolution->SetupDerivedInstruction(new_conv);
// If the activations were to be batch-to-spaced again, simply use the
// original value.
batch_to_space_map_[convolution->mutable_operand(0)] =
convolution->mutable_operand(0);
VLOG(1) << "Space-to-batched convolution " << new_conv->ToString(); VLOG(1) << "Space-to-batched convolution " << new_conv->ToString();
const int64 output_split_spatial_dim = const int64 output_split_spatial_dim =
@ -2483,7 +2748,9 @@ Status ConvolutionVisitor::PerformSpaceToBatchOnConvolution(
get_chosen_spatial_dim(original_conv))); get_chosen_spatial_dim(original_conv)));
instr_to_dim_permute_map_[new_conv] = std::vector<int64>(transpose_dims); instr_to_dim_permute_map_[new_conv] = std::vector<int64>(transpose_dims);
if (non_propagatable_instrs_.count(convolution) > 0) {
non_propagatable_instrs_.erase(convolution);
}
TF_CHECK_OK(PropagateOnUsers(original_conv)); TF_CHECK_OK(PropagateOnUsers(original_conv));
changed_ = true; changed_ = true;