Line up select-and-scatter stride with convolutions during space-to-batch.

PiperOrigin-RevId: 344132359
Change-Id: I917cb4f9aed1939ab159214e8bc9f2bea10c34d8
This commit is contained in:
A. Unique TensorFlower 2020-11-24 14:27:20 -08:00 committed by TensorFlower Gardener
parent 4e9a867e42
commit d45e8258f1

View File

@ -153,8 +153,8 @@ class ConvolutionVisitor {
return permute_dims[id];
}
HloInstruction* DoesConvolutionFeedReduceWindow(HloInstruction* instr,
int64 depth);
HloInstruction* DoesConvolutionFeedReduceWindowOrSelectAndScatter(
HloInstruction* instr, int64 depth);
private:
// Current HloComputation instance the ConvolutionVisitor is traversing.
@ -650,6 +650,8 @@ bool ConvolutionVisitor::CanPropagate(HloInstruction* consumer,
// The permuting must match.
if (permute_dims_first_operand != permute_dims_second_operand) {
VLOG(2) << "Can't propagate through select and scatter due to "
"permutation mismatch";
return false;
}
@ -663,6 +665,8 @@ bool ConvolutionVisitor::CanPropagate(HloInstruction* consumer,
if (first_operand->shape().dimensions(new_batch_dim) !=
second_operand->shape().dimensions(new_batch_dim)) {
VLOG(2)
<< "Can't propagate through select and scatter due to dim mismatch";
return false;
}
@ -676,6 +680,8 @@ bool ConvolutionVisitor::CanPropagate(HloInstruction* consumer,
pad_low) /
stride !=
second_operand->shape().dimensions(new_space_dim)) {
VLOG(2) << "Can't propagate through select and scatter due to stride "
"mismatch";
return false;
}
VLOG(1) << "Can propagate through select and scatter";
@ -1910,14 +1916,16 @@ Status ConvolutionVisitor::PropagateOnBackpropFilterConv(
return Status::OK();
}
HloInstruction* ConvolutionVisitor::DoesConvolutionFeedReduceWindow(
HloInstruction*
ConvolutionVisitor::DoesConvolutionFeedReduceWindowOrSelectAndScatter(
HloInstruction* instr, int64 depth = kReduceWindowSearchDepth) {
if (depth == 0) {
return nullptr;
}
for (auto user : instr->users()) {
if (user->opcode() == HloOpcode::kReduceWindow) {
if (user->opcode() == HloOpcode::kReduceWindow ||
user->opcode() == HloOpcode::kSelectAndScatter) {
return user;
}
// Stop the search if these ops are encountered.
@ -1926,7 +1934,8 @@ HloInstruction* ConvolutionVisitor::DoesConvolutionFeedReduceWindow(
user->opcode() == HloOpcode::kTranspose) {
continue;
}
auto ret = DoesConvolutionFeedReduceWindow(user, depth - 1);
auto ret =
DoesConvolutionFeedReduceWindowOrSelectAndScatter(user, depth - 1);
if (ret != nullptr) {
return ret;
}
@ -2045,16 +2054,19 @@ Status ConvolutionVisitor::PerformSpaceToBatchOnConvolution(
spatial_split_size += c.stride;
}
auto reduce_window = DoesConvolutionFeedReduceWindow(convolution);
auto reduce_window_or_select_and_scatter =
DoesConvolutionFeedReduceWindowOrSelectAndScatter(convolution);
if (reduce_window != nullptr) {
VLOG(2) << "DoesConvolutionFeedReduceWindow " << reduce_window;
if (reduce_window_or_select_and_scatter != nullptr) {
VLOG(2) << "DoesConvolutionFeedReduceWindowOrSelectAndScatter "
<< reduce_window_or_select_and_scatter;
// Take into account the stride of the reduce window while choosing the
// spatial_split_size. This will guarantee propagation through reduce
// windows.
const int64 red_win_stride =
reduce_window->window().dimensions(output_spatial_dim).stride();
while ((spatial_split_size / c.stride) % red_win_stride != 0) {
const int64 win_stride = reduce_window_or_select_and_scatter->window()
.dimensions(output_spatial_dim)
.stride();
while ((spatial_split_size / c.stride) % win_stride != 0) {
spatial_split_size += c.stride;
}
}