[XLA] Adding an option to disable negative padding replacement in the algebraic_simplifier.
PiperOrigin-RevId: 353340055 Change-Id: If59e9a1c5a5f82c3124912dba19bc49bf6104d13
This commit is contained in:
parent
aa334ce6f3
commit
c306f73ce9
@ -3425,7 +3425,7 @@ Status AlgebraicSimplifierVisitor::HandlePad(HloInstruction* pad) {
|
||||
}
|
||||
}
|
||||
|
||||
if (has_negative) {
|
||||
if (has_negative && options_.enable_negative_padding_replacement()) {
|
||||
// Pad has negative padding. Replace with a pad with the non-negative
|
||||
// padding followed by a slice which effectively performs the negative
|
||||
// padding.
|
||||
|
@ -138,6 +138,15 @@ class AlgebraicSimplifierOptions {
|
||||
|
||||
bool enable_reduce_of_reshape() const { return enable_reduce_of_reshape_; }
|
||||
|
||||
void set_enable_negative_padding_replacement(
|
||||
bool enable_negative_padding_replacement) {
|
||||
enable_negative_padding_replacement_ = enable_negative_padding_replacement;
|
||||
}
|
||||
|
||||
bool enable_negative_padding_replacement() const {
|
||||
return enable_negative_padding_replacement_;
|
||||
}
|
||||
|
||||
void set_replace_transpose_with_bitcast(bool replace_transpose_with_bitcast) {
|
||||
replace_transpose_with_bitcast_ = replace_transpose_with_bitcast;
|
||||
}
|
||||
@ -169,6 +178,7 @@ class AlgebraicSimplifierOptions {
|
||||
bool enable_floats_are_real_{false};
|
||||
bool enable_window_reduce_to_reduce_replacement_{true};
|
||||
bool enable_reduce_of_reshape_{true};
|
||||
bool enable_negative_padding_replacement_{true};
|
||||
bool replace_transpose_with_bitcast_{true};
|
||||
int64 very_small_gather_size_{4};
|
||||
Metadata metadata_;
|
||||
|
@ -3317,6 +3317,54 @@ TEST_F(AlgebraicSimplifierTest, NegativePadding) {
|
||||
has_negative_padding(computation->root_instruction()->operand(0)));
|
||||
}
|
||||
|
||||
TEST_F(AlgebraicSimplifierTest, CanDisableNegativePadding) {
|
||||
// Verify that a pad instruction with negative padding is replaced with a
|
||||
// pad with non-negative padding followed by a slice.
|
||||
HloComputation::Builder builder(TestName());
|
||||
HloInstruction* param =
|
||||
builder.AddInstruction(HloInstruction::CreateParameter(
|
||||
0, ShapeUtil::MakeShape(F32, {10, 10}), "param"));
|
||||
HloInstruction* zero = builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
|
||||
PaddingConfig padding;
|
||||
int64 low_padding[2] = {-1, -2};
|
||||
int64 high_padding[2] = {2, -3};
|
||||
for (int i = 0; i < 2; ++i) {
|
||||
auto dimension = padding.add_dimensions();
|
||||
dimension->set_edge_padding_low(low_padding[i]);
|
||||
dimension->set_edge_padding_high(high_padding[i]);
|
||||
dimension->set_interior_padding(0);
|
||||
}
|
||||
HloInstruction* pad = builder.AddInstruction(HloInstruction::CreatePad(
|
||||
ShapeUtil::MakeShape(F32, {11, 5}), param, zero, padding));
|
||||
|
||||
auto module = CreateNewVerifiedModule();
|
||||
HloComputation* computation = module->AddEntryComputation(builder.Build());
|
||||
|
||||
// Verify that we can disable the negative padding optimization.
|
||||
AlgebraicSimplifierOptions opts = default_options_;
|
||||
opts.set_enable_negative_padding_replacement(false);
|
||||
|
||||
AlgebraicSimplifier simplifier(opts);
|
||||
|
||||
auto has_negative_padding = [](const HloInstruction* pad) {
|
||||
for (auto& padding_dimension : pad->padding_config().dimensions()) {
|
||||
if (padding_dimension.edge_padding_low() < 0 ||
|
||||
padding_dimension.edge_padding_high() < 0) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
};
|
||||
|
||||
EXPECT_THAT(computation->root_instruction(),
|
||||
GmockMatch(m::Pad(m::Parameter(0), m::Op().Is(zero))));
|
||||
EXPECT_TRUE(has_negative_padding(pad));
|
||||
|
||||
// Nothing has changed since the negative padding replacement is disabled.
|
||||
ASSERT_FALSE(simplifier.Run(module.get()).ValueOrDie());
|
||||
}
|
||||
|
||||
TEST_F(AlgebraicSimplifierTest, TrivialInteriorPadding) {
|
||||
// Verify that a pad instruction with interior padding on one-sized
|
||||
// dimensions, removes the interior padding.
|
||||
|
Loading…
x
Reference in New Issue
Block a user