Skip space-to-batch optimization on convs that are used by a different rank reduce-window or select-and-scatter.
PiperOrigin-RevId: 361053054 Change-Id: Idf82848912ceea722aebbd07f0c87a1b14499673
This commit is contained in:
parent
0911cc1470
commit
dd1ce23e93
@ -2798,7 +2798,9 @@ Status ConvolutionVisitor::PerformSpaceToBatchOnConvolution(
|
||||
auto reduce_window_or_select_and_scatter =
|
||||
DoesConvolutionFeedReduceWindowOrSelectAndScatter(convolution);
|
||||
|
||||
if (reduce_window_or_select_and_scatter != nullptr) {
|
||||
if (reduce_window_or_select_and_scatter != nullptr &&
|
||||
reduce_window_or_select_and_scatter->shape().rank() ==
|
||||
convolution->shape().rank()) {
|
||||
VLOG(2)
|
||||
<< "DoesConvolutionFeedReduceWindowOrSelectAndScatter returned true";
|
||||
// Take into account the stride of the reduce window while choosing the
|
||||
|
@ -64,13 +64,46 @@ ENTRY computation {
|
||||
EXPECT_GT(reshape->operand(0)->shape().dimensions(batch_dim), 1);
|
||||
}
|
||||
|
||||
TEST_F(ConvolutionSpaceToBatchConverterTest, SimpleBatch1WithReduceWindow) {
|
||||
string hlo_string = R"(
|
||||
HloModule module
|
||||
adder (lhs: bf16[], rhs: bf16[]) -> bf16[] {
|
||||
lhs = bf16[] parameter(0)
|
||||
rhs = bf16[] parameter(1)
|
||||
ROOT add = bf16[] add(lhs, rhs)
|
||||
}
|
||||
|
||||
ENTRY computation {
|
||||
%p0 = bf16[1,258,258,32] parameter(0)
|
||||
%p1 = bf16[3,3,32,32] parameter(1)
|
||||
%convolution = bf16[1,256,256,32] convolution(%p0, %p1), window={size=3x3},
|
||||
dim_labels=b01f_01io->b01f
|
||||
%constant = bf16[3] constant({1.0, 2.0, 3.0})
|
||||
%tuple = (bf16[1,256,256,32], bf16[3])tuple(%convolution, %constant)
|
||||
ROOT %gte = bf16[1,256,256,32] get-tuple-element(%tuple), index=0
|
||||
%gte2 = bf16[3]get-tuple-element(%tuple), index=1
|
||||
%init = bf16[] constant(1.0)
|
||||
%reduce-window = bf16[3] reduce-window(bf16[3] %gte2, bf16[] %init),
|
||||
window={size=1}, to_apply=%adder
|
||||
}
|
||||
|
||||
)";
|
||||
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
|
||||
ParseAndReturnVerifiedModule(hlo_string));
|
||||
|
||||
ConvolutionSpaceToBatchConverter converter;
|
||||
// Test that a reduce window consumer with different rank won't freeze the
|
||||
// compiler.
|
||||
ASSERT_TRUE(converter.Run(module.get()).ValueOrDie());
|
||||
}
|
||||
|
||||
TEST_F(ConvolutionSpaceToBatchConverterTest, SimpleBatch2) {
|
||||
string hlo_string = R"(
|
||||
HloModule module
|
||||
ENTRY computation {
|
||||
%p0 = bf16[2,258,258,32] parameter(0)
|
||||
%p1 = bf16[3,3,32,32] parameter(1)
|
||||
ROOT %convolution = bf16[2,256,256,32] convolution(%p0, %p1), window={size=3x3},
|
||||
ROOT %convolution = bf16[2,256,256,32] convolution(%p0, %p1), window={size=3x3},
|
||||
dim_labels=b01f_01io->b01f
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user