[XLA] Enable RHS space-to-batch on backprop filter convolutions.
PiperOrigin-RevId: 350873443
Change-Id: I1388f4fe5f151c10062326d1b08d36c8eaec860c

parent 5975aae33f
commit 5b9181194e
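Background for the change: space-to-batch rewrites a convolution whose batch is
too small to parallelize well by splitting one spatial dimension k ways into
the batch dimension. A minimal sketch of that shape bookkeeping, assuming a
spatial size divisible by k (illustrative only; the pass itself inserts
padding and halo regions, and SpaceToBatchShape is a hypothetical name):

    #include <cstdint>
    #include <vector>

    // Shape of an operand after splitting `spatial_dim` k ways into `batch_dim`.
    std::vector<int64_t> SpaceToBatchShape(std::vector<int64_t> dims,
                                           int batch_dim, int spatial_dim,
                                           int64_t k) {
      dims[batch_dim] *= k;    // batch picks up the k splits
      dims[spatial_dim] /= k;  // split spatial dimension shrinks k-fold
      return dims;
    }

    // E.g. {8, 256, 256, 3} split 8 ways on dim 1 becomes {64, 32, 256, 3}.

Previously only the LHS (activations) of a convolution was rewritten this way;
this commit also space-to-batches the RHS (the kernel) of backprop filter
convolutions, where the batch and feature roles of the operands are inverted.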
@@ -82,7 +82,8 @@ class ConvolutionVisitor {
   // Function that determines if space-to-batch can be propagated into the
   // consumer. Such propagation is only possible when all required operands are
   // space-to-batch'ed.
-  bool CanPropagate(HloInstruction* consumer, HloInstruction* producer);
+  bool CanPropagate(HloInstruction* consumer, HloInstruction* producer,
+                    bool last_try = false);
 
   // Returns true if the op has all its direct and indirect operands being
   // created via broadcasts. Consumer uses op, and is space-to-batched.
@@ -116,7 +117,7 @@ class ConvolutionVisitor {
       HloInstruction* activations, ConvolutionDimensionNumbers& dim_numbers,
       int64& spatial_dimension_to_split, int64& activations_batch_dim,
       int64 high_padding, int64 low_padding, int64 spatial_split_size,
-      int64 num_splits, bool is_backprop = false);
+      int64 num_splits, bool is_backprop = false, bool is_rhs = false);
 
   // Perform space-to-batch propagation on the convolution. Assumes the
   // activations were already space-to-batched.
@@ -149,7 +150,13 @@ class ConvolutionVisitor {
   StatusOr<SpaceNextToBatchDetails> BringSpaceNextToBatch(
       HloInstruction* activations, ConvolutionDimensionNumbers& dim_numbers,
       int64& spatial_dimension_to_split, int64& activations_batch_dim,
-      bool is_backprop = false);
+      bool is_backprop = false, bool is_rhs = false);
 
+  // Increases the spatial dimension size in an already space-to-batched shape
+  // so that the new size is new_spatial_dim_size.
+  StatusOr<HloInstruction*> IncreaseSpatialSizeOnSpaceToBatchedShape(
+      HloInstruction* activations, int64 batch_dimension, int64 old_batch_size,
+      int64 spatial_dimension, int64 new_spatial_dim_size);
+
   // Function that converts spaced-to-batch shape back to the original.
   StatusOr<HloInstruction*> BatchToSpace(HloInstruction* old_instr);
@@ -213,6 +220,10 @@ class ConvolutionVisitor {
   absl::flat_hash_map<HloInstruction*, std::vector<int64>>
       instr_to_dim_permute_map_;
 
+  // Map maintaining previously space-to-batched broadcasts.
+  absl::flat_hash_map<HloInstruction*, absl::flat_hash_set<HloInstruction*>>
+      broadcast_map_;
+
   // Whether rewrite has occurred.
   bool changed_ = false;
 
@@ -456,7 +467,7 @@ StatusOr<ConvolutionVisitor::SpaceNextToBatchDetails>
 ConvolutionVisitor::BringSpaceNextToBatch(
     HloInstruction* activations, ConvolutionDimensionNumbers& dim_numbers,
     int64& spatial_dimension_to_split, int64& activations_batch_dim,
-    bool is_backprop) {
+    bool is_backprop, bool is_rhs) {
   std::vector<int64> transpose_dims(activations->shape().rank());
   if (spatial_dimension_to_split == activations_batch_dim + 1) {
     absl::c_iota(transpose_dims, 0);
@@ -465,49 +476,137 @@ ConvolutionVisitor::BringSpaceNextToBatch(
     int64 pushed_counter = 0;
     int64 new_batch_dim, new_spatial_dim;
     int64 dim_counter = 0;
-    for (int i = 0; i < activations->shape().rank(); ++i) {
-      if (i == activations_batch_dim) {
-        continue;
-      }
-      if (i == spatial_dimension_to_split) {
-        transpose_dims[dim_counter++] = activations_batch_dim;
-        new_batch_dim = pushed_counter;
-        pushed_counter++;
-        new_spatial_dim = pushed_counter;
-      }
-
-      if (is_backprop && i == dim_numbers.input_batch_dimension()) {
-        new_dim_numbers.set_input_batch_dimension(pushed_counter);
-      } else if (i == dim_numbers.input_feature_dimension()) {
-        new_dim_numbers.set_input_feature_dimension(pushed_counter);
-      } else {
-        for (int j = 0; j < dim_numbers.input_spatial_dimensions_size(); ++j) {
-          if (i == dim_numbers.input_spatial_dimensions(j)) {
-            new_dim_numbers.set_input_spatial_dimensions(j, pushed_counter);
-            break;
-          }
-        }
-      }
-      transpose_dims[dim_counter++] = i;
-      pushed_counter++;
-    }
-
-    activations_batch_dim = new_batch_dim;
-    spatial_dimension_to_split = new_spatial_dim;
-    TF_ASSIGN_OR_RETURN(activations,
-                        MakeTransposeHlo(activations, transpose_dims));
-
-    if (is_backprop) {
-      new_dim_numbers.set_input_feature_dimension(activations_batch_dim);
-    } else {
-      new_dim_numbers.set_input_batch_dimension(activations_batch_dim);
-    }
+    if (is_rhs) {
+      CHECK(is_backprop);
+      for (int i = 0; i < activations->shape().rank(); ++i) {
+        if (i == activations_batch_dim) {
+          continue;
+        }
+        if (i == spatial_dimension_to_split) {
+          transpose_dims[dim_counter++] = activations_batch_dim;
+          new_batch_dim = pushed_counter;
+          pushed_counter++;
+          new_spatial_dim = pushed_counter;
+        }
+
+        if (i == dim_numbers.kernel_output_feature_dimension()) {
+          new_dim_numbers.set_kernel_output_feature_dimension(pushed_counter);
+        } else {
+          auto it = absl::c_find(dim_numbers.kernel_spatial_dimensions(), i);
+          if (it != dim_numbers.kernel_spatial_dimensions().end()) {
+            int64 j = it - dim_numbers.kernel_spatial_dimensions().begin();
+            new_dim_numbers.set_kernel_spatial_dimensions(j, pushed_counter);
+          }
+        }
+        transpose_dims[dim_counter++] = i;
+        pushed_counter++;
+      }
+
+      activations_batch_dim = new_batch_dim;
+      spatial_dimension_to_split = new_spatial_dim;
+      TF_ASSIGN_OR_RETURN(activations,
+                          MakeTransposeHlo(activations, transpose_dims));
+
+      new_dim_numbers.set_kernel_input_feature_dimension(activations_batch_dim);
+
+    } else {
+      for (int i = 0; i < activations->shape().rank(); ++i) {
+        if (i == activations_batch_dim) {
+          continue;
+        }
+        if (i == spatial_dimension_to_split) {
+          transpose_dims[dim_counter++] = activations_batch_dim;
+          new_batch_dim = pushed_counter;
+          pushed_counter++;
+          new_spatial_dim = pushed_counter;
+        }
+
+        if (is_backprop && i == dim_numbers.input_batch_dimension()) {
+          new_dim_numbers.set_input_batch_dimension(pushed_counter);
+        } else if (i == dim_numbers.input_feature_dimension()) {
+          new_dim_numbers.set_input_feature_dimension(pushed_counter);
+        } else {
+          auto it = absl::c_find(dim_numbers.input_spatial_dimensions(), i);
+          if (it != dim_numbers.input_spatial_dimensions().end()) {
+            int64 j = it - dim_numbers.input_spatial_dimensions().begin();
+            new_dim_numbers.set_input_spatial_dimensions(j, pushed_counter);
+          }
+        }
+        transpose_dims[dim_counter++] = i;
+        pushed_counter++;
+      }
+
+      activations_batch_dim = new_batch_dim;
+      spatial_dimension_to_split = new_spatial_dim;
+      TF_ASSIGN_OR_RETURN(activations,
+                          MakeTransposeHlo(activations, transpose_dims));
+
+      if (is_backprop) {
+        new_dim_numbers.set_input_feature_dimension(activations_batch_dim);
+      } else {
+        new_dim_numbers.set_input_batch_dimension(activations_batch_dim);
+      }
+    }
 
     dim_numbers = new_dim_numbers;
   }
 
   return SpaceNextToBatchDetails{activations, transpose_dims};
 }
+
+StatusOr<HloInstruction*>
+ConvolutionVisitor::IncreaseSpatialSizeOnSpaceToBatchedShape(
+    HloInstruction* activations, int64 batch_dimension, int64 old_batch_size,
+    int64 spatial_dimension, int64 new_spatial_dim_size) {
+  CHECK_EQ(batch_dimension + 1, spatial_dimension);
+  std::vector<int64> new_dimensions(activations->shape().dimensions().begin(),
+                                    activations->shape().dimensions().end());
+
+  const int64 new_batch_size = activations->shape().dimensions(batch_dimension);
+  int64 spatial_dim_size = activations->shape().dimensions(spatial_dimension);
+  const int64 reshaped_space_size =
+      spatial_dim_size * new_batch_size / old_batch_size;
+
+  VLOG(3) << "Increasing the spatial size while propagating new_batch_size "
+          << new_batch_size << " old_batch_size " << old_batch_size;
+  new_dimensions[spatial_dimension] = reshaped_space_size;
+  new_dimensions[batch_dimension] = old_batch_size;
+
+  // Reshape the output of the new conv into the old convolutions shape.
+  TF_ASSIGN_OR_RETURN(HloInstruction * reshaped_activations,
+                      MakeReshapeHlo(new_dimensions, activations));
+
+  VLOG(3) << "First reshape done";
+  PaddingConfig padding_config =
+      MakeNoPaddingConfig(reshaped_activations->shape().dimensions_size());
+  padding_config.mutable_dimensions(spatial_dimension)
+      ->set_edge_padding_high(new_spatial_dim_size * new_batch_size /
+                                  old_batch_size -
+                              reshaped_space_size);
+  padding_config.mutable_dimensions(spatial_dimension)->set_edge_padding_low(0);
+  HloInstruction* padding =
+      computation_->AddInstruction(HloInstruction::CreateConstant(
+          LiteralUtil::Zero(reshaped_activations->shape().element_type())));
+
+  TF_ASSIGN_OR_RETURN(
+      reshaped_activations,
+      MakePadHlo(reshaped_activations, padding, padding_config));
+
+  std::vector<int64> reshape_back_dims(
+      reshaped_activations->shape().dimensions().begin(),
+      reshaped_activations->shape().dimensions().end());
+
+  reshape_back_dims[spatial_dimension] = new_spatial_dim_size;
+  reshape_back_dims[batch_dimension] = new_batch_size;
+
+  TF_ASSIGN_OR_RETURN(HloInstruction * activations_new,
+                      MakeReshapeHlo(reshape_back_dims, reshaped_activations));
+
+  VLOG(3) << "Size increased activations " << activations_new->ToString();
+
+  return activations_new;
+}
 
 StatusOr<bool> ConvolutionVisitor::Run() {
   for (auto conv : conv_visitor_list_) {
     if (convs_to_visit_.count(conv) > 0) {
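For intuition, the reshape-pad-reshape in the new helper with illustrative
numbers (not taken from the commit): say old_batch_size = 2, the batch has
already been split to new_batch_size = 8, spatial_dim_size = 5, and
new_spatial_dim_size = 7. Then:

    reshaped_space_size = 5 * 8 / 2 = 20       // [8, 5]  -> reshape -> [2, 20]
    edge_padding_high   = 7 * 8 / 2 - 20 = 8   // [2, 20] -> pad     -> [2, 28]
    reshape back                               // [2, 28] -> [8, 7]

so each of the 8 batch slices ends up with the requested spatial size of 7,
zero-padded at the high end.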
@@ -519,6 +618,29 @@ StatusOr<bool> ConvolutionVisitor::Run() {
   // Iterate through all instructions that we could not propagate through, and
   // turn their operands from batch-to-space as needed.
   for (auto instr : non_propagatable_instrs_) {
+    if (instr->opcode() == HloOpcode::kConvolution) {
+      VLOG(1) << "Instr " << instr->ToString();
+    }
+    // Try to propagate on backprop filters
+    if (instr->opcode() == HloOpcode::kConvolution &&
+        !IsConvSuitableForSpaceToBatch(instr)) {
+      HloInstruction* producer = nullptr;
+      if (old_to_new_instrs_.contains(instr->mutable_operand(0))) {
+        producer = instr->mutable_operand(0);
+      } else if (old_to_new_instrs_.contains(instr->mutable_operand(1))) {
+        producer = instr->mutable_operand(1);
+      }
+      if (producer) {
+        if (CanPropagate(instr, producer, /*last_try=*/true)) {
+          bool needs_further_propagation;
+          TF_ASSIGN_OR_RETURN(needs_further_propagation,
+                              Propagate(instr, producer));
+          TF_CHECK_OK(computation_->ReplaceInstruction(
+              instr, old_to_new_instrs_[instr]));
+          continue;
+        }
+      }
+    }
     VLOG(1) << "Could not eventually propagate through " << instr->ToString();
     absl::flat_hash_map<int64, HloInstruction*> operand_map;
     for (int64 i = 0; i < instr->operand_count(); ++i) {
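The retry above relies on the relaxed contract that last_try adds to
CanPropagate for backprop filter convolutions. A condensed reading of the
checks (has_new_kernel / has_new_activations are hypothetical shorthands for
the old_to_new_instrs_ lookups, not names from the code):

    // last_try == false: both operands must already be space-to-batched.
    // last_try == true : one space-to-batched operand suffices; the missing
    //                    side is space-to-batched locally inside
    //                    PropagateOnBackpropFilterConv.
    bool ready = last_try ? (has_new_kernel || has_new_activations)
                          : (has_new_kernel && has_new_activations);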
@@ -539,14 +661,14 @@ bool IsTrivialElementwise(HloInstruction* hlo) {
   if (hlo->opcode() == HloOpcode::kFusion || hlo->opcode() == HloOpcode::kRng ||
       hlo->opcode() == HloOpcode::kCopy ||
       hlo->opcode() == HloOpcode::kConstant ||
-      hlo->opcode() == HloOpcode::kIota) {
+      hlo->opcode() == HloOpcode::kIota || hlo->opcode() == HloOpcode::kMap) {
     return false;
   }
   return hlo->IsElementwise();
 }
 
 bool ConvolutionVisitor::CanPropagate(HloInstruction* consumer,
-                                      HloInstruction* producer) {
+                                      HloInstruction* producer, bool last_try) {
   if (IsTrivialElementwise(consumer)) {
     VLOG(2) << "Doing propagation check on elementwise op: "
             << consumer->ToString();
@@ -604,7 +726,8 @@ bool ConvolutionVisitor::CanPropagate(HloInstruction* consumer,
         // Make sure all other dimensions are of the same size.
         if (pivot_new_instr->shape().dimensions(j) !=
             new_instr->shape().dimensions(j)) {
-          if (!(consumer->IsElementwiseBinary() &&
+          if (!((consumer->IsElementwiseBinary() ||
+                 consumer->opcode() == HloOpcode::kSelect) &&
                 j == instr_to_dim_map_[pivot_operand].second)) {
             VLOG(2) << "Elementwise op: checking for shape equivalence "
                     << consumer->ToString()
@@ -653,20 +776,70 @@ bool ConvolutionVisitor::CanPropagate(HloInstruction* consumer,
     return false;
   }
 
+  // Same reason why we give up on batch group counts applies to features in
+  // backprop.
+  if (consumer->feature_group_count() != 1) {
+    return false;
+  }
+
   VLOG(2) << "Checking for backprop filter conv propagatability";
   CHECK_EQ(consumer->operand_count(), 2);
-  VLOG(2) << "Checking for backprop filter conv operands "
-          << consumer->operand_count();
 
   auto activations = consumer->mutable_operand(0);
   auto kernel = consumer->mutable_operand(1);
 
-  if (!old_to_new_instrs_.contains(kernel)) {
-    VLOG(2) << "Backprop filter conv not ready for propagation because of "
-               "kernel is not space-to-batched";
+  if (!last_try) {
+    if (!old_to_new_instrs_.contains(kernel) ||
+        !old_to_new_instrs_.contains(activations)) {
+      return false;
+    }
+  }
+
+  if (!old_to_new_instrs_.contains(kernel) &&
+      !old_to_new_instrs_.contains(activations)) {
     return false;
   }
 
+  const int64 rhs_dilation = consumer->window()
+                                 .dimensions(get_chosen_spatial_dim(consumer))
+                                 .window_dilation();
+
+  if (!old_to_new_instrs_.contains(kernel)) {
+    const int64 rhs_batch =
+        kernel->shape().dimensions(consumer->convolution_dimension_numbers()
+                                       .kernel_input_feature_dimension());
+    auto dim_map_val_op_0 = instr_to_dim_map_[activations];
+    const int64 old_batch_dim = dim_map_val_op_0.first;
+    const int64 old_space_dim = dim_map_val_op_0.second;
+    auto first_operand = old_to_new_instrs_[activations];
+    auto permute_dims_first_operand =
+        instr_to_dim_permute_map_[first_operand];
+    const int64 new_batch_dim =
+        DimLookUp(permute_dims_first_operand, old_batch_dim);
+    const int64 new_space_dim =
+        DimLookUp(permute_dims_first_operand, old_space_dim);
+    const int64 lhs_batch = first_operand->shape().dimensions(new_batch_dim);
+
+    if (first_operand->shape().dimensions(new_space_dim) % rhs_dilation !=
+        0) {
+      return false;
+    }
+    // Because we want to convert activations into a space-to-batched version
+    // only for backprop filter convolutions, we want to make sure that the
+    // batch dimensions (feature dimensions, technically) are same sized.
+    // Since LHS is already space-to-batched, we need to account for it too.
+    if (rhs_batch * kNumSplits != lhs_batch) {
+      return false;
+    }
+
+    // If kernel have not been propagated through, we can do
+    // space-to-batch on them provided kernel has been propagated.
+    VLOG(2)
+        << "Backprop filter conv ready for propagation: activations ready, "
+           " kernel will be space-to-batched";
+    return true;
+  }
+
   if (!old_to_new_instrs_.contains(activations)) {
     const int64 lhs_batch = activations->shape().dimensions(
         consumer->convolution_dimension_numbers().input_feature_dimension());
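A worked instance of the rhs_batch * kNumSplits == lhs_batch guard above, with
illustrative numbers: if kNumSplits = 8 and the LHS feature dimension (acting
as the batch in backprop) was already split to lhs_batch = 512, a kernel with
rhs_batch = 64 passes the check (64 * 8 == 512) and can be split the same
number of ways, while rhs_batch = 32 would fail it and block propagation.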
@@ -720,10 +893,7 @@ bool ConvolutionVisitor::CanPropagate(HloInstruction* consumer,
     return false;
   }
 
-  const int64 rhs_dilation = consumer->window()
-                                 .dimensions(get_chosen_spatial_dim(consumer))
-                                 .window_dilation();
-  if (first_operand->shape().dimensions(new_space_dim_operand_0) !=
+  if (first_operand->shape().dimensions(new_space_dim_operand_0) >
       rhs_dilation *
           second_operand->shape().dimensions(new_space_dim_operand_1)) {
     VLOG(2) << "Backprop filter conv not ready for propagation because of "
@@ -816,20 +986,57 @@ void ConvolutionVisitor::PropagateOnBroadcast(HloInstruction* consumer,
   auto new_producer = old_to_new_instrs_[producer];
   auto permute_dims = instr_to_dim_permute_map_[new_producer];
   auto dim_map_val = instr_to_dim_map_[producer];
+
+  const int64 old_batch_dim = dim_map_val.first;
+  const int64 old_space_dim = dim_map_val.second;
+
+  auto orig_broadcast_dims = consumer->dimensions();
+
+  bool batch_is_broadcasted =
+      absl::c_linear_search(orig_broadcast_dims, old_batch_dim);
+  const int64 new_batch_dim = DimLookUp(permute_dims, old_batch_dim);
+  const int64 new_space_dim = DimLookUp(permute_dims, old_space_dim);
+
+  bool map_found = broadcast_map_.contains(consumer);
+  if (map_found) {
+    // Check if we previously had created the same broadcast.
+    for (auto previous_broadcast : broadcast_map_[consumer]) {
+      if (ShapeUtil::CompatibleIgnoringElementType(previous_broadcast->shape(),
+                                                   new_producer->shape())) {
+        return;
+      }
+    }
+  }
+
+  std::vector<int64> final_shape_dims(
+      new_producer->shape().dimensions().begin(),
+      new_producer->shape().dimensions().end());
+  if (batch_is_broadcasted) {
+    final_shape_dims[new_batch_dim] =
+        producer->shape().dimensions(old_batch_dim);
+    final_shape_dims[new_space_dim] *= kNumSplits;
+  }
+
   std::vector<int64> broadcast_dims;
   for (auto j : consumer->dimensions()) {
     broadcast_dims.push_back(DimLookUp(permute_dims, j));
   }
-  auto new_broadcast =
-      MakeBroadcastHlo(consumer->mutable_operand(0), broadcast_dims,
-                       new_producer->shape().dimensions());
+  auto new_broadcast = MakeBroadcastHlo(consumer->mutable_operand(0),
+                                        broadcast_dims, final_shape_dims);
   VLOG(1) << "Created broadcast " << new_broadcast->ToString();
 
-  // Pass on the permutation information from the producer.
-  old_to_new_instrs_[consumer] = new_broadcast;
-  instr_to_dim_map_[consumer] = dim_map_val;
-  instr_to_dim_permute_map_[new_broadcast] = std::vector<int64>(
-      instr_to_dim_permute_map_[old_to_new_instrs_[producer]]);
+  if (batch_is_broadcasted) {
+    new_broadcast =
+        MakeReshapeHlo(new_producer->shape().dimensions(), new_broadcast)
+            .ValueOrDie();
+    VLOG(2) << "Created reshape of broadcast " << new_broadcast->ToString();
+  }
+
+  if (!map_found) {
+    absl::flat_hash_set<HloInstruction*> set_of_broadcasts;
+    broadcast_map_[consumer] = set_of_broadcasts;
+  }
+  broadcast_map_[consumer].insert(new_broadcast);
 }
 
 void ConvolutionVisitor::RewriteBroadcastTree(
@@ -882,11 +1089,9 @@ bool ConvolutionVisitor::IsBroadcastPropagatable(HloInstruction* broadcast,
   CHECK(instr_to_dim_map_.contains(old_other_op));
 
   auto result = instr_to_dim_map_[old_other_op];
-  const int64 batch_dim = result.first;
   const int64 space_dim = result.second;
   auto broadcast_dims = broadcast->dimensions();
-  return !absl::c_linear_search(broadcast_dims, batch_dim) &&
-         !absl::c_linear_search(broadcast_dims, space_dim);
+  return !absl::c_linear_search(broadcast_dims, space_dim);
 }
 
 bool ConvolutionVisitor::SupportedOpForPropagation(HloInstruction* consumer,
@@ -990,27 +1195,50 @@ StatusOr<bool> ConvolutionVisitor::Propagate(HloInstruction* consumer,
   // For elementwise binary ops, both of whose operands have been space-to-
   // batched, if their new spatial sizes don't match, choose the bigger one
   // as the producer.
-  if (consumer->IsElementwiseBinary() &&
-      old_to_new_instrs_.contains(consumer->mutable_operand(0)) &&
-      old_to_new_instrs_.contains(consumer->mutable_operand(1))) {
-    is_pivot_producer_modified = true;
-    if (old_to_new_instrs_[consumer->mutable_operand(0)]
-            ->shape()
-            .dimensions() > old_to_new_instrs_[consumer->mutable_operand(1)]
-                                ->shape()
-                                .dimensions()) {
-      producer = consumer->mutable_operand(0);
-    } else {
-      producer = consumer->mutable_operand(1);
+  if (consumer->IsElementwiseBinary() ||
+      consumer->opcode() == HloOpcode::kSelect) {
+    int64 pivot_operand_number = -1;
+    HloInstruction* pivot_operand = nullptr;
+    for (int i = 0; i < consumer->operand_count(); ++i) {
+      if (consumer->operand(i)->opcode() == HloOpcode::kBroadcast) {
+        continue;
+      }
+      auto operand = consumer->mutable_operand(i);
+      if (old_to_new_instrs_.contains(operand)) {
+        if (pivot_operand_number == -1 ||
+            old_to_new_instrs_[pivot_operand]->shape().dimensions() <
+                old_to_new_instrs_[operand]->shape().dimensions()) {
+          is_pivot_producer_modified = true;
+          pivot_operand_number = i;
+          pivot_operand = consumer->mutable_operand(pivot_operand_number);
+        }
+      }
+    }
+    if (pivot_operand_number != -1) {
+      producer = pivot_operand;
     }
   }
 
   for (int64 i = 0; i < consumer->operand_count(); ++i) {
     std::vector<HloInstruction*> instructions_to_transform;
 
-    if (old_to_new_instrs_.contains(consumer->mutable_operand(i))) {
+    if (consumer->operand(i)->opcode() == HloOpcode::kBroadcast) {
+      auto broadcast = consumer->mutable_operand(i);
+      PropagateOnBroadcast(broadcast, producer);
+      HloInstruction* new_broadcast = nullptr;
+      auto new_producer = old_to_new_instrs_[producer];
+      for (auto previous_broadcast : broadcast_map_[broadcast]) {
+        if (ShapeUtil::CompatibleIgnoringElementType(
+                previous_broadcast->shape(), new_producer->shape())) {
+          new_broadcast = previous_broadcast;
+          break;
+        }
+      }
+      CHECK_NE(new_broadcast, nullptr);
+      TF_CHECK_OK(
+          new_consumer->ReplaceOperandWithDifferentShape(i, new_broadcast));
+    } else if (old_to_new_instrs_.contains(consumer->mutable_operand(i))) {
       HloInstruction* operand_to_use = nullptr;
 
       auto result = instr_to_dim_map_[producer];
       const int64 old_batch_dim = result.first;
       const int64 old_space_dim = result.second;
@@ -1070,13 +1298,6 @@ StatusOr<bool> ConvolutionVisitor::Propagate(HloInstruction* consumer,
       }
       TF_CHECK_OK(
           new_consumer->ReplaceOperandWithDifferentShape(i, operand_to_use));
-    } else if (consumer->operand(i)->opcode() == HloOpcode::kBroadcast) {
-      if (!old_to_new_instrs_.contains(consumer->operand(i))) {
-        PropagateOnBroadcast(consumer->mutable_operand(i), producer);
-      }
-      auto new_broadcast = old_to_new_instrs_[consumer->mutable_operand(i)];
-      TF_CHECK_OK(
-          new_consumer->ReplaceOperandWithDifferentShape(i, new_broadcast));
     } else if (consumer->IsElementwiseBinary() &&
                IsBroadcastTree(consumer->mutable_operand(i), producer,
                                instructions_to_transform)) {
@@ -1463,8 +1684,10 @@ StatusOr<HloInstruction*> ConvolutionVisitor::SelectValidPortion(
 StatusOr<HloInstruction*> ConvolutionVisitor::BatchToSpace(
     HloInstruction* old_instr) {
   if (batch_to_space_map_.count(old_instr)) {
+    CHECK_NE(batch_to_space_map_[old_instr], nullptr);
     return batch_to_space_map_[old_instr];
   }
+
   auto result = instr_to_dim_map_[old_instr];
   const int64 old_batch_dim = result.first;
   const int64 old_space_dim = result.second;
@@ -1683,68 +1906,28 @@ Status ConvolutionVisitor::PropagateOnConv(HloInstruction* convolution) {
   VLOG(1) << "spatial_split_size " << spatial_split_size << " slice_size "
           << slice_size;
 
-  const int64 new_batch_size =
-      activations_new->shape().dimensions(activations_batch_dim);
   const int64 new_space_size =
      activations_new->shape().dimensions(c.spatial_dimension_to_split);
   // In the below case, we cannot use the activations directly for Halo
   // Duplication. We must reshape them.
   if (spatial_split_size > new_space_size) {
-    std::vector<int64> new_dimensions(
-        activations_new->shape().dimensions().begin(),
-        activations_new->shape().dimensions().end());
-    const int64 reshaped_space_size =
-        new_space_size * new_batch_size / old_batch_size;
-    VLOG(3) << "Increasing the spatial size while propagating new_batch_size "
-            << new_batch_size << " old_batch_size " << old_batch_size;
-    new_dimensions[c.spatial_dimension_to_split] = reshaped_space_size;
-    new_dimensions[activations_batch_dim] = old_batch_size;
-
-    // Reshape the output of the new conv into the old convolutions shape.
-    TF_ASSIGN_OR_RETURN(HloInstruction * reshaped_activations,
-                        MakeReshapeHlo(new_dimensions, activations_new));
-
-    VLOG(3) << "First reshape done";
-    PaddingConfig padding_config =
-        MakeNoPaddingConfig(reshaped_activations->shape().dimensions_size());
-    padding_config.mutable_dimensions(c.spatial_dimension_to_split)
-        ->set_edge_padding_high(spatial_split_size * new_batch_size /
-                                    old_batch_size -
-                                reshaped_space_size);
-    padding_config.mutable_dimensions(c.spatial_dimension_to_split)
-        ->set_edge_padding_low(0);
-    HloInstruction* padding =
-        computation_->AddInstruction(HloInstruction::CreateConstant(
-            LiteralUtil::Zero(reshaped_activations->shape().element_type())));
-
     TF_ASSIGN_OR_RETURN(
-        reshaped_activations,
-        MakePadHlo(reshaped_activations, padding, padding_config));
-
-    std::vector<int64> reshape_back_dims(
-        reshaped_activations->shape().dimensions().begin(),
-        reshaped_activations->shape().dimensions().end());
-
-    reshape_back_dims[c.spatial_dimension_to_split] = spatial_split_size;
-    reshape_back_dims[activations_batch_dim] = new_batch_size;
-
-    TF_ASSIGN_OR_RETURN(
-        reshaped_activations,
-        MakeReshapeHlo(reshape_back_dims, reshaped_activations));
-
-    VLOG(3) << "Second reshape done";
+        activations_new,
+        IncreaseSpatialSizeOnSpaceToBatchedShape(
+            activations_new, activations_batch_dim, old_batch_size,
+            c.spatial_dimension_to_split, spatial_split_size));
 
     TF_ASSIGN_OR_RETURN(
         activations_new,
-        HaloDuplicateWithSlice(
-            reshaped_activations, c.spatial_dimension_to_split,
-            activations_batch_dim, old_batch_size,
-            /*low_padding=*/c.base_dilation_factor != 1 &&
-                    c.inherent_low_padding != 0
-                ? c.base_dilation_factor - 1
-                : c.inherent_low_padding,
-            c.inherent_high_padding, slice_size - spatial_split_size,
-            old_split_dim_size));
+        HaloDuplicateWithSlice(activations_new, c.spatial_dimension_to_split,
+                               activations_batch_dim, old_batch_size,
+                               /*low_padding=*/c.base_dilation_factor != 1 &&
+                                       c.inherent_low_padding != 0
+                                   ? c.base_dilation_factor - 1
+                                   : c.inherent_low_padding,
+                               c.inherent_high_padding,
+                               slice_size - spatial_split_size,
+                               old_split_dim_size));
   } else {
     // If the ideal spatial_split_size was smaller than the incoming spatial
     // dimension size, we don't need reshaping. Instead, we determine the
@@ -1835,14 +2018,15 @@ ConvolutionVisitor::SplitSpace(HloInstruction* activations,
                                int64& spatial_dimension_to_split,
                                int64& activations_batch_dim, int64 high_padding,
                                int64 low_padding, int64 spatial_split_size,
-                               int64 num_splits, bool is_backprop) {
+                               int64 num_splits, bool is_backprop,
+                               bool is_rhs) {
   const int64 old_batch_size =
       activations->shape().dimensions(activations_batch_dim);
 
-  TF_ASSIGN_OR_RETURN(
-      auto retval, BringSpaceNextToBatch(activations, dim_numbers,
-                                         spatial_dimension_to_split,
-                                         activations_batch_dim, is_backprop));
+  TF_ASSIGN_OR_RETURN(auto retval,
+                      BringSpaceNextToBatch(
+                          activations, dim_numbers, spatial_dimension_to_split,
+                          activations_batch_dim, is_backprop, is_rhs));
 
   activations = retval.instr;
   std::vector<int64> transpose_dims = retval.transpose_dims;
@@ -1903,7 +2087,7 @@ Status ConvolutionVisitor::PropagateOnBackpropFilterConv(
                                  .window_dilation();
 
   auto original_conv_dims = convolution->convolution_dimension_numbers();
-  const int64 kernel_space_dim = original_conv_dims.kernel_spatial_dimensions(
+  int64 kernel_space_dim = original_conv_dims.kernel_spatial_dimensions(
       get_chosen_spatial_dim(convolution));
   auto kernel_old = convolution->mutable_operand(1);
   const int64 old_kernel_split_dim_size =
@@ -1914,18 +2098,24 @@ Status ConvolutionVisitor::PropagateOnBackpropFilterConv(
   int64 old_split_dim_size = activations_old->shape().dimensions(old_space_dim);
 
   int64 old_batch_dim = original_conv_dims.input_feature_dimension();
+  int64 kernel_old_batch_dim =
+      original_conv_dims.kernel_input_feature_dimension();
   const int64 old_batch_size =
       activations_old->shape().dimensions(old_batch_dim);
 
-  CHECK(old_to_new_instrs_.contains(kernel_old));
-  auto kernel_new = old_to_new_instrs_[kernel_old];
-
-  auto permute_dims_kernel = instr_to_dim_permute_map_[kernel_new];
+  CHECK(old_to_new_instrs_.contains(kernel_old) ||
+        old_to_new_instrs_.contains(activations_old));
 
   HloInstruction* activations_new = nullptr;
+  HloInstruction* kernel_new = nullptr;
   bool activations_locally_space_to_batched = false;
+  bool kernel_locally_space_to_batched = false;
+  std::vector<int64> permute_dims_kernel, permute_dims;
   // If activations were no space-to-batched, we space-to-batch them below.
   if (!old_to_new_instrs_.contains(activations_old)) {
+    kernel_new = old_to_new_instrs_[kernel_old];
+    permute_dims_kernel = instr_to_dim_permute_map_[kernel_new];
+
     VLOG(1) << "Space-to-batching activations to enable space-to-depth";
     const int64 prev_feature_dim = original_conv_dims.input_feature_dimension();
     const int64 prev_batch_dim = original_conv_dims.input_batch_dimension();
@@ -1960,14 +2150,63 @@ Status ConvolutionVisitor::PropagateOnBackpropFilterConv(
     VLOG(3) << "New Activations " << retval.first->ToString();
 
     activations_locally_space_to_batched = true;
+  } else if (!old_to_new_instrs_.contains(kernel_old)) {
+    activations_new = old_to_new_instrs_[activations_old];
+    permute_dims = instr_to_dim_permute_map_[activations_new];
+
+    VLOG(1) << "Space-to-batching kernel to enable space-to-depth";
+    const int64 prev_feature_dim =
+        original_conv_dims.kernel_input_feature_dimension();
+    const int64 prev_output_feature_dim =
+        original_conv_dims.kernel_output_feature_dimension();
+    // TODO(b/168316428): The instr_to_dim_map_ is set incorrectly here, but it
+    // doesn't matter since it is never used. Investigate further to see if just
+    // not setting it works.
+    instr_to_dim_map_[kernel_old] =
+        std::make_pair(prev_feature_dim, prev_output_feature_dim);
+
+    const int64 new_space_dim = DimLookUp(permute_dims, old_space_dim);
+    const int64 new_split_dim_size =
+        activations_new->shape().dimensions(new_space_dim);
+    const int64 needed_spatial_size =
+        CeilOfRatio(new_split_dim_size, rhs_dilation);
+    int64 old_kernel_split_dim_size =
+        kernel_old->shape().dimensions(kernel_space_dim);
+    const int64 pad_size =
+        needed_spatial_size * kNumSplits - old_kernel_split_dim_size;
+
+    ConvolutionDimensionNumbers tmp_dim_numbers;
+    tmp_dim_numbers = original_conv_dims;
+    TF_ASSIGN_OR_RETURN(
+        auto retval, SplitSpace(kernel_old, tmp_dim_numbers, kernel_space_dim,
+                                kernel_old_batch_dim,
+                                /*high_padding=*/pad_size, /*low_padding=*/0,
+                                needed_spatial_size, kNumSplits,
+                                /*is_backprop=*/true, /*is_rhs=*/true));
+
+    old_to_new_instrs_[kernel_old] = retval.first;
+
+    std::vector<int64> reversed_transpose_dims(retval.second.size());
+    for (int64 i = 0; i < retval.second.size(); ++i) {
+      reversed_transpose_dims[i] = ReverseDimLookUp(retval.second, i);
+    }
+    instr_to_dim_permute_map_[retval.first] = reversed_transpose_dims;
+
+    VLOG(3) << "New kernel " << retval.first->ToString();
+
+    kernel_locally_space_to_batched = true;
   }
 
   CHECK(old_to_new_instrs_.contains(activations_old));
+  CHECK(old_to_new_instrs_.contains(kernel_old));
   activations_new = old_to_new_instrs_[activations_old];
+  kernel_new = old_to_new_instrs_[kernel_old];
   const int64 new_spatial_dimension =
       activations_new->shape().dimensions_size();
 
-  auto permute_dims = instr_to_dim_permute_map_[activations_new];
+  permute_dims = instr_to_dim_permute_map_[activations_new];
+  permute_dims_kernel = instr_to_dim_permute_map_[kernel_new];
 
   auto permuted_conv_dims_numbers = original_conv_dims;
 
   // Note the inversion here : batch and feature are inverted in backprop
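The kernel split sizing above, with illustrative numbers (not from the
commit): new_split_dim_size = 28 on the already-split activations and
rhs_dilation = 4 give needed_spatial_size = CeilOfRatio(28, 4) = 7; with
kNumSplits = 8 and old_kernel_split_dim_size = 50, pad_size = 7 * 8 - 50 = 6,
so the kernel's spatial dimension is padded to 56 before SplitSpace carves it
into 8 pieces of size 7.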
@@ -2024,9 +2263,12 @@ Status ConvolutionVisitor::PropagateOnBackpropFilterConv(
       permuted_conv_dims_numbers.kernel_spatial_dimensions(
           get_chosen_spatial_dim(convolution));
 
-  const int64 new_split_dim_size =
+  int64 new_split_dim_size =
       activations_new->shape().dimensions(spatial_dimension_to_split);
 
+  const int64 kernel_new_split_dim_size =
+      kernel_new->shape().dimensions(kernel_spatial_dimension_to_split);
+
   permuted_conv_dims_numbers.set_input_batch_dimension(activations_feature_dim);
   permuted_conv_dims_numbers.set_input_feature_dimension(activations_batch_dim);
@@ -2040,6 +2282,7 @@ Status ConvolutionVisitor::PropagateOnBackpropFilterConv(
       BringSpaceNextToBatch(activations_new, permuted_conv_dims_numbers,
                             spatial_dimension_to_split, activations_batch_dim,
                             /*is_backprop=*/true));
 
   std::vector<int64> transpose_dims = retval.transpose_dims;
   CHECK(!transpose_dims.empty());
   activations_new = retval.instr;
@@ -2048,6 +2291,17 @@ Status ConvolutionVisitor::PropagateOnBackpropFilterConv(
           << activations_new->ToString();
   VLOG(1) << "activations_batch_dim " << activations_batch_dim
           << " activations_feature_dim " << activations_feature_dim;
+  const int64 expected_split_dim_size =
+      rhs_dilation * kernel_new_split_dim_size;
+  if (new_split_dim_size != expected_split_dim_size) {
+    CHECK_LT(new_split_dim_size, expected_split_dim_size);
+    new_split_dim_size = expected_split_dim_size;
+    TF_ASSIGN_OR_RETURN(
+        activations_new,
+        IncreaseSpatialSizeOnSpaceToBatchedShape(
+            activations_new, activations_batch_dim, old_batch_size,
+            spatial_dimension_to_split, new_split_dim_size));
+  }
 
   auto select_val = computation_->AddInstruction(HloInstruction::CreateConstant(
       LiteralUtil::Zero(activations_new->shape().element_type())));
@@ -2060,16 +2314,18 @@ Status ConvolutionVisitor::PropagateOnBackpropFilterConv(
                            activations_batch_dim, spatial_dimension_to_split,
                            old_batch_dim, old_space_dim));
   }
-  VLOG(3) << "Selecting the valid kernel area";
-  // Select kernel correctly by masking additional space.
-  TF_ASSIGN_OR_RETURN(
-      kernel_new,
-      SelectValidPortion(
-          kernel_new, kernel_old, select_val,
-          /*new_batch_dim=*/kernel_input_feature_dim,
-          kernel_spatial_dimension_to_split,
-          /*old_batch_dim=*/original_conv_dims.kernel_input_feature_dimension(),
-          kernel_space_dim));
+  if (!kernel_locally_space_to_batched) {
+    VLOG(3) << "Selecting the valid kernel area";
+    // Select kernel correctly by masking additional space.
+    TF_ASSIGN_OR_RETURN(
+        kernel_new,
+        SelectValidPortion(kernel_new, kernel_old, select_val,
+                           /*new_batch_dim=*/kernel_input_feature_dim,
+                           kernel_spatial_dimension_to_split,
+                           /*old_batch_dim=*/
+                           original_conv_dims.kernel_input_feature_dimension(),
+                           kernel_space_dim));
+  }
 
   // Create the new convolution dim numbers.
   auto new_dim_numbers = permuted_conv_dims_numbers;
@@ -2115,7 +2371,9 @@ Status ConvolutionVisitor::PropagateOnBackpropFilterConv(
       old_split_dim_size - expanded_kernel + 1 +
       (inherent_low_padding < 0 ? inherent_low_padding : 0) +
       (inherent_high_padding < 0 ? inherent_high_padding : 0);
-  VLOG(1) << "overlap_count " << overlap_count;
+  VLOG(1) << "overlap_count " << overlap_count << " inherent_low_padding "
+          << inherent_low_padding << " inherent_high_padding "
+          << inherent_high_padding;
 
   // Insert original activations.
   for (int64 i = 0; i < overlap_count; ++i) {
@@ -2190,7 +2448,7 @@ Status ConvolutionVisitor::PropagateOnBackpropFilterConv(
   new_window.mutable_dimensions(get_chosen_spatial_dim(convolution))
       ->set_padding_low(0);
   new_window.mutable_dimensions(get_chosen_spatial_dim(convolution))
-      ->set_size(new_split_dim_size / rhs_dilation);
+      ->set_size(CeilOfRatio(new_split_dim_size, rhs_dilation));
 
   // Set the window for the additional spatial dim. This is a vanilla window.
   auto window_dim = new_window.add_dimensions();
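The switch to CeilOfRatio above matters once the split size need not be an
exact multiple of the dilation: for example, new_split_dim_size = 30 and
rhs_dilation = 4 give 30 / 4 = 7 under integer truncation, but
CeilOfRatio(30, 4) = 8 keeps the window wide enough to cover the dilated
kernel.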
@@ -2211,6 +2469,8 @@ Status ConvolutionVisitor::PropagateOnBackpropFilterConv(
           /*preferred_element_type=*/convolution->shape().element_type()));
   convolution->SetupDerivedInstruction(new_conv);
 
+  VLOG(2) << "New backprop filter convolution " << new_conv->ToString();
+
   std::vector<int64> output_sizes(new_conv->shape().dimensions().begin(),
                                   new_conv->shape().dimensions().end());
 
@@ -2364,8 +2624,8 @@ Status ConvolutionVisitor::PerformSpaceToBatchOnConvolution(
       DoesConvolutionFeedReduceWindowOrSelectAndScatter(convolution);
 
   if (reduce_window_or_select_and_scatter != nullptr) {
-    VLOG(2) << "DoesConvolutionFeedReduceWindowOrSelectAndScatter "
-            << reduce_window_or_select_and_scatter;
+    VLOG(2)
+        << "DoesConvolutionFeedReduceWindowOrSelectAndScatter returned true";
     // Take into account the stride of the reduce window while choosing the
     // spatial_split_size. This will guarantee propagation through reduce
     // windows.
@@ -2457,6 +2717,11 @@ Status ConvolutionVisitor::PerformSpaceToBatchOnConvolution(
           /*preferred_element_type=*/convolution->shape().element_type()));
   convolution->SetupDerivedInstruction(new_conv);
 
+  // If the activations were to be batch-to-spaced again, simply use the
+  // original value.
+  batch_to_space_map_[convolution->mutable_operand(0)] =
+      convolution->mutable_operand(0);
+
   VLOG(1) << "Space-to-batched convolution " << new_conv->ToString();
 
   const int64 output_split_spatial_dim =
@@ -2483,7 +2748,9 @@ Status ConvolutionVisitor::PerformSpaceToBatchOnConvolution(
           get_chosen_spatial_dim(original_conv)));
 
   instr_to_dim_permute_map_[new_conv] = std::vector<int64>(transpose_dims);
+  if (non_propagatable_instrs_.count(convolution) > 0) {
+    non_propagatable_instrs_.erase(convolution);
+  }
   TF_CHECK_OK(PropagateOnUsers(original_conv));
 
   changed_ = true;