Line up select-and-scatter stride with convolutions during space-to-batch.
PiperOrigin-RevId: 344132359 Change-Id: I917cb4f9aed1939ab159214e8bc9f2bea10c34d8
This commit is contained in:
parent
4e9a867e42
commit
d45e8258f1
@ -153,8 +153,8 @@ class ConvolutionVisitor {
|
||||
return permute_dims[id];
|
||||
}
|
||||
|
||||
HloInstruction* DoesConvolutionFeedReduceWindow(HloInstruction* instr,
|
||||
int64 depth);
|
||||
HloInstruction* DoesConvolutionFeedReduceWindowOrSelectAndScatter(
|
||||
HloInstruction* instr, int64 depth);
|
||||
|
||||
private:
|
||||
// Current HloComputation instance the ConvolutionVisitor is traversing.
|
||||
@ -650,6 +650,8 @@ bool ConvolutionVisitor::CanPropagate(HloInstruction* consumer,
|
||||
|
||||
// The permuting must match.
|
||||
if (permute_dims_first_operand != permute_dims_second_operand) {
|
||||
VLOG(2) << "Can't propagate through select and scatter due to "
|
||||
"permutation mismatch";
|
||||
return false;
|
||||
}
|
||||
|
||||
@ -663,6 +665,8 @@ bool ConvolutionVisitor::CanPropagate(HloInstruction* consumer,
|
||||
|
||||
if (first_operand->shape().dimensions(new_batch_dim) !=
|
||||
second_operand->shape().dimensions(new_batch_dim)) {
|
||||
VLOG(2)
|
||||
<< "Can't propagate through select and scatter due to dim mismatch";
|
||||
return false;
|
||||
}
|
||||
|
||||
@ -676,6 +680,8 @@ bool ConvolutionVisitor::CanPropagate(HloInstruction* consumer,
|
||||
pad_low) /
|
||||
stride !=
|
||||
second_operand->shape().dimensions(new_space_dim)) {
|
||||
VLOG(2) << "Can't propagate through select and scatter due to stride "
|
||||
"mismatch";
|
||||
return false;
|
||||
}
|
||||
VLOG(1) << "Can propagate through select and scatter";
|
||||
@ -1910,14 +1916,16 @@ Status ConvolutionVisitor::PropagateOnBackpropFilterConv(
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
HloInstruction* ConvolutionVisitor::DoesConvolutionFeedReduceWindow(
|
||||
HloInstruction*
|
||||
ConvolutionVisitor::DoesConvolutionFeedReduceWindowOrSelectAndScatter(
|
||||
HloInstruction* instr, int64 depth = kReduceWindowSearchDepth) {
|
||||
if (depth == 0) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
for (auto user : instr->users()) {
|
||||
if (user->opcode() == HloOpcode::kReduceWindow) {
|
||||
if (user->opcode() == HloOpcode::kReduceWindow ||
|
||||
user->opcode() == HloOpcode::kSelectAndScatter) {
|
||||
return user;
|
||||
}
|
||||
// Stop the search if these ops are encountered.
|
||||
@ -1926,7 +1934,8 @@ HloInstruction* ConvolutionVisitor::DoesConvolutionFeedReduceWindow(
|
||||
user->opcode() == HloOpcode::kTranspose) {
|
||||
continue;
|
||||
}
|
||||
auto ret = DoesConvolutionFeedReduceWindow(user, depth - 1);
|
||||
auto ret =
|
||||
DoesConvolutionFeedReduceWindowOrSelectAndScatter(user, depth - 1);
|
||||
if (ret != nullptr) {
|
||||
return ret;
|
||||
}
|
||||
@ -2045,16 +2054,19 @@ Status ConvolutionVisitor::PerformSpaceToBatchOnConvolution(
|
||||
spatial_split_size += c.stride;
|
||||
}
|
||||
|
||||
auto reduce_window = DoesConvolutionFeedReduceWindow(convolution);
|
||||
auto reduce_window_or_select_and_scatter =
|
||||
DoesConvolutionFeedReduceWindowOrSelectAndScatter(convolution);
|
||||
|
||||
if (reduce_window != nullptr) {
|
||||
VLOG(2) << "DoesConvolutionFeedReduceWindow " << reduce_window;
|
||||
if (reduce_window_or_select_and_scatter != nullptr) {
|
||||
VLOG(2) << "DoesConvolutionFeedReduceWindowOrSelectAndScatter "
|
||||
<< reduce_window_or_select_and_scatter;
|
||||
// Take into account the stride of the reduce window while choosing the
|
||||
// spatial_split_size. This will guarantee propagation through reduce
|
||||
// windows.
|
||||
const int64 red_win_stride =
|
||||
reduce_window->window().dimensions(output_spatial_dim).stride();
|
||||
while ((spatial_split_size / c.stride) % red_win_stride != 0) {
|
||||
const int64 win_stride = reduce_window_or_select_and_scatter->window()
|
||||
.dimensions(output_spatial_dim)
|
||||
.stride();
|
||||
while ((spatial_split_size / c.stride) % win_stride != 0) {
|
||||
spatial_split_size += c.stride;
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user