[XLA] Add some more slice of pad optimizations.
PiperOrigin-RevId: 296361878 Change-Id: I4dbef5e94d95f3337c1004e8c3f09c7a94148075
This commit is contained in:
parent
2ca35b7a30
commit
41b6bae3d1
@ -3204,53 +3204,6 @@ StatusOr<bool> AlgebraicSimplifierVisitor::TrySimplifyScalarSlice(
|
||||
return false;
|
||||
}
|
||||
|
||||
if (slice->operand(0)->opcode() == HloOpcode::kPad) {
|
||||
VLOG(10) << "Trying to simplify scalar slice of pad";
|
||||
// Check there's no internal padding. Again, we could handle that too, since
|
||||
// everything is statically known, but it's not worth it.
|
||||
auto pad = Cast<HloPadInstruction>(slice->mutable_operand(0));
|
||||
auto padding_config = pad->padding_config();
|
||||
int64 rank = padding_config.dimensions_size();
|
||||
if (HasInteriorPadding(padding_config)) {
|
||||
VLOG(10) << "Not folding scalar slice of pad, pad has interior padding";
|
||||
return false;
|
||||
}
|
||||
|
||||
// Check whether the scalar we're slicing out falls into the padding.
|
||||
bool in_padding = [&]() {
|
||||
for (int64 i = 0; i < rank; ++i) {
|
||||
int64 start = slice->slice_starts(i);
|
||||
int64 low = padding_config.dimensions(i).edge_padding_low();
|
||||
int64 data = pad->operand(0)->shape().dimensions(i);
|
||||
if (start < low || start >= low + data) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}();
|
||||
|
||||
if (in_padding) {
|
||||
VLOG(10) << "Folding scalar slice of pad into padding value";
|
||||
TF_RETURN_IF_ERROR(ReplaceWithNewInstruction(
|
||||
slice, HloInstruction::CreateReshape(slice->shape(),
|
||||
pad->mutable_padding_value())));
|
||||
return true;
|
||||
} else {
|
||||
// We already know the output of the slice is scalar. If the padded
|
||||
// value is scalar, and it's not in the padding, then it's exactly the
|
||||
// output value.
|
||||
bool replaced =
|
||||
ReplaceInstructionIfSameShape(slice, pad->mutable_operand(0));
|
||||
if (replaced) {
|
||||
VLOG(10) << "Folding scalar slice of pad into padded value";
|
||||
} else {
|
||||
VLOG(10) << "Not folding scalar slice of pad into padded value as they "
|
||||
"have different shapes.";
|
||||
}
|
||||
return replaced;
|
||||
}
|
||||
}
|
||||
|
||||
if (slice->operand(0)->opcode() == HloOpcode::kConcatenate) {
|
||||
VLOG(10) << "Trying to simplify scalar slice of concat";
|
||||
// Only do this for R1, there's no chance of this being useful otherwise.
|
||||
@ -3356,20 +3309,54 @@ Status AlgebraicSimplifierVisitor::HandleSlice(HloInstruction* slice) {
|
||||
HloInstruction* pad;
|
||||
HloInstruction* pad_operand;
|
||||
if (Match(slice, m::Slice(m::Pad(&pad, m::Op(&pad_operand), m::Op())))) {
|
||||
// Is the result of the slice the pad operand.
|
||||
bool slice_undoes_pad = true;
|
||||
// Can the slice be moved to the pad_operand without any padding being read.
|
||||
bool slice_inside_pad = true;
|
||||
// Does this slice slice out pading only.
|
||||
bool slice_in_padding = false;
|
||||
std::vector<int64> new_starts = slice->slice_starts();
|
||||
std::vector<int64> new_limits = slice->slice_limits();
|
||||
for (int64 i = 0; i < slice->shape().rank(); ++i) {
|
||||
if (slice->slice_starts(i) !=
|
||||
pad->padding_config().dimensions(i).edge_padding_low()) {
|
||||
const int64 start = slice->slice_starts(i);
|
||||
const int64 stride = slice->slice_strides(i);
|
||||
const int64 limit = slice->slice_limits(i);
|
||||
const int64 size = pad->shape().dimensions(i);
|
||||
|
||||
const auto& dim = pad->padding_config().dimensions(i);
|
||||
const int64 low = dim.edge_padding_low();
|
||||
const int64 high = dim.edge_padding_high();
|
||||
const int64 interior = dim.interior_padding();
|
||||
const int64 edge = size - high;
|
||||
|
||||
if (limit <= low || start >= edge) {
|
||||
slice_in_padding = true;
|
||||
break;
|
||||
}
|
||||
|
||||
if (start != low || stride - 1 != interior) {
|
||||
slice_undoes_pad = false;
|
||||
}
|
||||
if (slice->slice_strides(i) - 1 !=
|
||||
pad->padding_config().dimensions(i).interior_padding()) {
|
||||
slice_undoes_pad = false;
|
||||
|
||||
if (start < low || limit > edge || interior != 0 || stride != 1) {
|
||||
slice_inside_pad = false;
|
||||
}
|
||||
new_starts[i] -= low;
|
||||
new_limits[i] -= low;
|
||||
}
|
||||
if (slice_in_padding) {
|
||||
return ReplaceInstruction(
|
||||
slice, MakeBroadcastHlo(pad->mutable_operand(1), {}, slice->shape()));
|
||||
}
|
||||
if (slice_undoes_pad && ReplaceInstructionIfSameShape(slice, pad_operand)) {
|
||||
return Status::OK();
|
||||
}
|
||||
if (slice_inside_pad) {
|
||||
TF_ASSIGN_OR_RETURN(HloInstruction * new_slice,
|
||||
MakeSliceHlo(pad_operand, new_starts, new_limits,
|
||||
slice->slice_strides()));
|
||||
return ReplaceInstruction(slice, new_slice);
|
||||
}
|
||||
}
|
||||
|
||||
if (slice->operand(0)->opcode() == HloOpcode::kSlice &&
|
||||
|
@ -4389,7 +4389,7 @@ TEST_F(AlgebraicSimplifierTest, SliceOfPadLow) {
|
||||
AlgebraicSimplifier simplifier(options);
|
||||
EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie());
|
||||
auto root = module->entry_computation()->root_instruction();
|
||||
EXPECT_THAT(root, GmockMatch(m::Reshape(m::Constant())));
|
||||
EXPECT_THAT(root, GmockMatch(m::Broadcast(m::Constant())));
|
||||
}
|
||||
|
||||
TEST_F(AlgebraicSimplifierTest, SliceOfPadHigh) {
|
||||
@ -4410,7 +4410,7 @@ TEST_F(AlgebraicSimplifierTest, SliceOfPadHigh) {
|
||||
AlgebraicSimplifier simplifier(options);
|
||||
EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie());
|
||||
auto root = module->entry_computation()->root_instruction();
|
||||
EXPECT_THAT(root, GmockMatch(m::Reshape(m::Constant())));
|
||||
EXPECT_THAT(root, GmockMatch(m::Broadcast(m::Constant())));
|
||||
}
|
||||
|
||||
TEST_F(AlgebraicSimplifierTest, SliceOfPadMidNonScalar) {
|
||||
@ -4429,7 +4429,31 @@ TEST_F(AlgebraicSimplifierTest, SliceOfPadMidNonScalar) {
|
||||
|
||||
AlgebraicSimplifierOptions options;
|
||||
AlgebraicSimplifier simplifier(options);
|
||||
EXPECT_FALSE(simplifier.Run(module.get()).ValueOrDie());
|
||||
EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie());
|
||||
EXPECT_THAT(module->entry_computation()->root_instruction(),
|
||||
GmockMatch(m::Slice(m::Parameter(0))));
|
||||
}
|
||||
|
||||
TEST_F(AlgebraicSimplifierTest, SliceOfPad) {
|
||||
const char* hlo_string = R"(
|
||||
HloModule module
|
||||
|
||||
ENTRY test {
|
||||
param = f32[3,4] parameter(0)
|
||||
constant = f32[] constant(0.0)
|
||||
pad = f32[8,10] pad(f32[3,4] param, f32[] constant), padding=3_2x1_5
|
||||
ROOT slice = f32[2,3] slice(f32[8,10] pad), slice={[4:6],[2:5]}
|
||||
}
|
||||
)";
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto module,
|
||||
ParseAndReturnVerifiedModule(hlo_string));
|
||||
|
||||
AlgebraicSimplifierOptions options;
|
||||
AlgebraicSimplifier simplifier(options);
|
||||
EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie());
|
||||
auto root = module->entry_computation()->root_instruction();
|
||||
EXPECT_THAT(root, GmockMatch(m::Slice(m::Parameter(0))));
|
||||
EXPECT_THAT(root->slice_starts(), ElementsAre(1, 1));
|
||||
}
|
||||
|
||||
TEST_F(AlgebraicSimplifierTest, SliceOfPadMidScalarConstant) {
|
||||
@ -4450,7 +4474,7 @@ TEST_F(AlgebraicSimplifierTest, SliceOfPadMidScalarConstant) {
|
||||
AlgebraicSimplifier simplifier(options);
|
||||
EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie());
|
||||
auto root = module->entry_computation()->root_instruction();
|
||||
EXPECT_THAT(root, GmockMatch(m::Reshape(m::Constant())));
|
||||
EXPECT_THAT(root, GmockMatch(m::Broadcast(m::Constant())));
|
||||
}
|
||||
|
||||
TEST_F(AlgebraicSimplifierTest, SliceOfPadMidScalar) {
|
||||
@ -4494,7 +4518,7 @@ TEST_F(AlgebraicSimplifierTest, SliceOfPadSomeDimsInPadding) {
|
||||
AlgebraicSimplifier simplifier(options);
|
||||
EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie());
|
||||
auto root = module->entry_computation()->root_instruction();
|
||||
EXPECT_THAT(root, GmockMatch(m::Reshape(m::ConstantScalar(-7.0))));
|
||||
EXPECT_THAT(root, GmockMatch(m::Broadcast(m::ConstantScalar(-7.0))));
|
||||
}
|
||||
|
||||
TEST_F(AlgebraicSimplifierTest, SliceOfConcatScalarInput) {
|
||||
|
Loading…
Reference in New Issue
Block a user