Skip the space-to-batch optimization on convolutions whose reduce-window or select-and-scatter consumer has a rank different from the convolution's.

PiperOrigin-RevId: 361053054
Change-Id: Idf82848912ceea722aebbd07f0c87a1b14499673
This commit is contained in:
Yunxing Dai 2021-03-04 19:04:49 -08:00 committed by TensorFlower Gardener
parent 0911cc1470
commit dd1ce23e93
2 changed files with 37 additions and 2 deletions

View File

@ -2798,7 +2798,9 @@ Status ConvolutionVisitor::PerformSpaceToBatchOnConvolution(
auto reduce_window_or_select_and_scatter =
DoesConvolutionFeedReduceWindowOrSelectAndScatter(convolution);
if (reduce_window_or_select_and_scatter != nullptr) {
if (reduce_window_or_select_and_scatter != nullptr &&
reduce_window_or_select_and_scatter->shape().rank() ==
convolution->shape().rank()) {
VLOG(2)
<< "DoesConvolutionFeedReduceWindowOrSelectAndScatter returned true";
// Take into account the stride of the reduce window while choosing the

View File

@ -64,13 +64,46 @@ ENTRY computation {
EXPECT_GT(reshape->operand(0)->shape().dimensions(batch_dim), 1);
}
// Regression test for a batch-1 convolution placed in a tuple next to a rank-1
// constant, where a rank-1 reduce-window consumes the constant's tuple element.
// The reduce-window's rank differs from the rank-4 convolution; running the
// pass on such a module must not freeze the compiler and is expected to report
// that the module changed.
TEST_F(ConvolutionSpaceToBatchConverterTest, SimpleBatch1WithReduceWindow) {
  const string hlo_string = R"(
HloModule module
adder (lhs: bf16[], rhs: bf16[]) -> bf16[] {
lhs = bf16[] parameter(0)
rhs = bf16[] parameter(1)
ROOT add = bf16[] add(lhs, rhs)
}
ENTRY computation {
%p0 = bf16[1,258,258,32] parameter(0)
%p1 = bf16[3,3,32,32] parameter(1)
%convolution = bf16[1,256,256,32] convolution(%p0, %p1), window={size=3x3},
dim_labels=b01f_01io->b01f
%constant = bf16[3] constant({1.0, 2.0, 3.0})
%tuple = (bf16[1,256,256,32], bf16[3])tuple(%convolution, %constant)
ROOT %gte = bf16[1,256,256,32] get-tuple-element(%tuple), index=0
%gte2 = bf16[3]get-tuple-element(%tuple), index=1
%init = bf16[] constant(1.0)
%reduce-window = bf16[3] reduce-window(bf16[3] %gte2, bf16[] %init),
window={size=1}, to_apply=%adder
}
)";

  // Parse and verify the module up front so a malformed HLO string fails here
  // rather than inside the pass under test.
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(hlo_string));

  // The pass must terminate on this input and report that it rewrote the
  // module.
  ConvolutionSpaceToBatchConverter space_to_batch_pass;
  const bool module_changed =
      space_to_batch_pass.Run(module.get()).ValueOrDie();
  ASSERT_TRUE(module_changed);
}
TEST_F(ConvolutionSpaceToBatchConverterTest, SimpleBatch2) {
string hlo_string = R"(
HloModule module
ENTRY computation {
%p0 = bf16[2,258,258,32] parameter(0)
%p1 = bf16[3,3,32,32] parameter(1)
ROOT %convolution = bf16[2,256,256,32] convolution(%p0, %p1), window={size=3x3},
ROOT %convolution = bf16[2,256,256,32] convolution(%p0, %p1), window={size=3x3},
dim_labels=b01f_01io->b01f
}