[XLA] Add some more slice of pad optimizations.

PiperOrigin-RevId: 296361878
Change-Id: I4dbef5e94d95f3337c1004e8c3f09c7a94148075
This commit is contained in:
Blake Hechtman 2020-02-20 21:16:59 -08:00 committed by TensorFlower Gardener
parent 2ca35b7a30
commit 41b6bae3d1
2 changed files with 68 additions and 57 deletions

View File

@ -3204,53 +3204,6 @@ StatusOr<bool> AlgebraicSimplifierVisitor::TrySimplifyScalarSlice(
return false;
}
if (slice->operand(0)->opcode() == HloOpcode::kPad) {
VLOG(10) << "Trying to simplify scalar slice of pad";
// Check there's no internal padding. Again, we could handle that too, since
// everything is statically known, but it's not worth it.
auto pad = Cast<HloPadInstruction>(slice->mutable_operand(0));
auto padding_config = pad->padding_config();
int64 rank = padding_config.dimensions_size();
if (HasInteriorPadding(padding_config)) {
VLOG(10) << "Not folding scalar slice of pad, pad has interior padding";
return false;
}
// Check whether the scalar we're slicing out falls into the padding.
bool in_padding = [&]() {
for (int64 i = 0; i < rank; ++i) {
int64 start = slice->slice_starts(i);
int64 low = padding_config.dimensions(i).edge_padding_low();
int64 data = pad->operand(0)->shape().dimensions(i);
if (start < low || start >= low + data) {
return true;
}
}
return false;
}();
if (in_padding) {
VLOG(10) << "Folding scalar slice of pad into padding value";
TF_RETURN_IF_ERROR(ReplaceWithNewInstruction(
slice, HloInstruction::CreateReshape(slice->shape(),
pad->mutable_padding_value())));
return true;
} else {
// We already know the output of the slice is scalar. If the padded
// value is scalar, and it's not in the padding, then it's exactly the
// output value.
bool replaced =
ReplaceInstructionIfSameShape(slice, pad->mutable_operand(0));
if (replaced) {
VLOG(10) << "Folding scalar slice of pad into padded value";
} else {
VLOG(10) << "Not folding scalar slice of pad into padded value as they "
"have different shapes.";
}
return replaced;
}
}
if (slice->operand(0)->opcode() == HloOpcode::kConcatenate) {
VLOG(10) << "Trying to simplify scalar slice of concat";
// Only do this for R1, there's no chance of this being useful otherwise.
@ -3356,20 +3309,54 @@ Status AlgebraicSimplifierVisitor::HandleSlice(HloInstruction* slice) {
HloInstruction* pad;
HloInstruction* pad_operand;
if (Match(slice, m::Slice(m::Pad(&pad, m::Op(&pad_operand), m::Op())))) {
// Is the result of the slice the pad operand.
bool slice_undoes_pad = true;
// Can the slice be moved to the pad_operand without any padding being read.
bool slice_inside_pad = true;
// Does this slice slice out pading only.
bool slice_in_padding = false;
std::vector<int64> new_starts = slice->slice_starts();
std::vector<int64> new_limits = slice->slice_limits();
for (int64 i = 0; i < slice->shape().rank(); ++i) {
if (slice->slice_starts(i) !=
pad->padding_config().dimensions(i).edge_padding_low()) {
const int64 start = slice->slice_starts(i);
const int64 stride = slice->slice_strides(i);
const int64 limit = slice->slice_limits(i);
const int64 size = pad->shape().dimensions(i);
const auto& dim = pad->padding_config().dimensions(i);
const int64 low = dim.edge_padding_low();
const int64 high = dim.edge_padding_high();
const int64 interior = dim.interior_padding();
const int64 edge = size - high;
if (limit <= low || start >= edge) {
slice_in_padding = true;
break;
}
if (start != low || stride - 1 != interior) {
slice_undoes_pad = false;
}
if (slice->slice_strides(i) - 1 !=
pad->padding_config().dimensions(i).interior_padding()) {
slice_undoes_pad = false;
if (start < low || limit > edge || interior != 0 || stride != 1) {
slice_inside_pad = false;
}
new_starts[i] -= low;
new_limits[i] -= low;
}
if (slice_in_padding) {
return ReplaceInstruction(
slice, MakeBroadcastHlo(pad->mutable_operand(1), {}, slice->shape()));
}
if (slice_undoes_pad && ReplaceInstructionIfSameShape(slice, pad_operand)) {
return Status::OK();
}
if (slice_inside_pad) {
TF_ASSIGN_OR_RETURN(HloInstruction * new_slice,
MakeSliceHlo(pad_operand, new_starts, new_limits,
slice->slice_strides()));
return ReplaceInstruction(slice, new_slice);
}
}
if (slice->operand(0)->opcode() == HloOpcode::kSlice &&

View File

@ -4389,7 +4389,7 @@ TEST_F(AlgebraicSimplifierTest, SliceOfPadLow) {
AlgebraicSimplifier simplifier(options);
EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie());
auto root = module->entry_computation()->root_instruction();
EXPECT_THAT(root, GmockMatch(m::Reshape(m::Constant())));
EXPECT_THAT(root, GmockMatch(m::Broadcast(m::Constant())));
}
TEST_F(AlgebraicSimplifierTest, SliceOfPadHigh) {
@ -4410,7 +4410,7 @@ TEST_F(AlgebraicSimplifierTest, SliceOfPadHigh) {
AlgebraicSimplifier simplifier(options);
EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie());
auto root = module->entry_computation()->root_instruction();
EXPECT_THAT(root, GmockMatch(m::Reshape(m::Constant())));
EXPECT_THAT(root, GmockMatch(m::Broadcast(m::Constant())));
}
TEST_F(AlgebraicSimplifierTest, SliceOfPadMidNonScalar) {
@ -4429,7 +4429,31 @@ TEST_F(AlgebraicSimplifierTest, SliceOfPadMidNonScalar) {
AlgebraicSimplifierOptions options;
AlgebraicSimplifier simplifier(options);
EXPECT_FALSE(simplifier.Run(module.get()).ValueOrDie());
EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie());
EXPECT_THAT(module->entry_computation()->root_instruction(),
GmockMatch(m::Slice(m::Parameter(0))));
}
TEST_F(AlgebraicSimplifierTest, SliceOfPad) {
const char* hlo_string = R"(
HloModule module
ENTRY test {
param = f32[3,4] parameter(0)
constant = f32[] constant(0.0)
pad = f32[8,10] pad(f32[3,4] param, f32[] constant), padding=3_2x1_5
ROOT slice = f32[2,3] slice(f32[8,10] pad), slice={[4:6],[2:5]}
}
)";
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnVerifiedModule(hlo_string));
AlgebraicSimplifierOptions options;
AlgebraicSimplifier simplifier(options);
EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie());
auto root = module->entry_computation()->root_instruction();
EXPECT_THAT(root, GmockMatch(m::Slice(m::Parameter(0))));
EXPECT_THAT(root->slice_starts(), ElementsAre(1, 1));
}
TEST_F(AlgebraicSimplifierTest, SliceOfPadMidScalarConstant) {
@ -4450,7 +4474,7 @@ TEST_F(AlgebraicSimplifierTest, SliceOfPadMidScalarConstant) {
AlgebraicSimplifier simplifier(options);
EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie());
auto root = module->entry_computation()->root_instruction();
EXPECT_THAT(root, GmockMatch(m::Reshape(m::Constant())));
EXPECT_THAT(root, GmockMatch(m::Broadcast(m::Constant())));
}
TEST_F(AlgebraicSimplifierTest, SliceOfPadMidScalar) {
@ -4494,7 +4518,7 @@ TEST_F(AlgebraicSimplifierTest, SliceOfPadSomeDimsInPadding) {
AlgebraicSimplifier simplifier(options);
EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie());
auto root = module->entry_computation()->root_instruction();
EXPECT_THAT(root, GmockMatch(m::Reshape(m::ConstantScalar(-7.0))));
EXPECT_THAT(root, GmockMatch(m::Broadcast(m::ConstantScalar(-7.0))));
}
TEST_F(AlgebraicSimplifierTest, SliceOfConcatScalarInput) {