[XLA] Add some more slice of pad optimizations.
PiperOrigin-RevId: 296361878 Change-Id: I4dbef5e94d95f3337c1004e8c3f09c7a94148075
This commit is contained in:
parent
2ca35b7a30
commit
41b6bae3d1
tensorflow/compiler/xla/service
@ -3204,53 +3204,6 @@ StatusOr<bool> AlgebraicSimplifierVisitor::TrySimplifyScalarSlice(
|
|||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (slice->operand(0)->opcode() == HloOpcode::kPad) {
|
|
||||||
VLOG(10) << "Trying to simplify scalar slice of pad";
|
|
||||||
// Check there's no internal padding. Again, we could handle that too, since
|
|
||||||
// everything is statically known, but it's not worth it.
|
|
||||||
auto pad = Cast<HloPadInstruction>(slice->mutable_operand(0));
|
|
||||||
auto padding_config = pad->padding_config();
|
|
||||||
int64 rank = padding_config.dimensions_size();
|
|
||||||
if (HasInteriorPadding(padding_config)) {
|
|
||||||
VLOG(10) << "Not folding scalar slice of pad, pad has interior padding";
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check whether the scalar we're slicing out falls into the padding.
|
|
||||||
bool in_padding = [&]() {
|
|
||||||
for (int64 i = 0; i < rank; ++i) {
|
|
||||||
int64 start = slice->slice_starts(i);
|
|
||||||
int64 low = padding_config.dimensions(i).edge_padding_low();
|
|
||||||
int64 data = pad->operand(0)->shape().dimensions(i);
|
|
||||||
if (start < low || start >= low + data) {
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false;
|
|
||||||
}();
|
|
||||||
|
|
||||||
if (in_padding) {
|
|
||||||
VLOG(10) << "Folding scalar slice of pad into padding value";
|
|
||||||
TF_RETURN_IF_ERROR(ReplaceWithNewInstruction(
|
|
||||||
slice, HloInstruction::CreateReshape(slice->shape(),
|
|
||||||
pad->mutable_padding_value())));
|
|
||||||
return true;
|
|
||||||
} else {
|
|
||||||
// We already know the output of the slice is scalar. If the padded
|
|
||||||
// value is scalar, and it's not in the padding, then it's exactly the
|
|
||||||
// output value.
|
|
||||||
bool replaced =
|
|
||||||
ReplaceInstructionIfSameShape(slice, pad->mutable_operand(0));
|
|
||||||
if (replaced) {
|
|
||||||
VLOG(10) << "Folding scalar slice of pad into padded value";
|
|
||||||
} else {
|
|
||||||
VLOG(10) << "Not folding scalar slice of pad into padded value as they "
|
|
||||||
"have different shapes.";
|
|
||||||
}
|
|
||||||
return replaced;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (slice->operand(0)->opcode() == HloOpcode::kConcatenate) {
|
if (slice->operand(0)->opcode() == HloOpcode::kConcatenate) {
|
||||||
VLOG(10) << "Trying to simplify scalar slice of concat";
|
VLOG(10) << "Trying to simplify scalar slice of concat";
|
||||||
// Only do this for R1, there's no chance of this being useful otherwise.
|
// Only do this for R1, there's no chance of this being useful otherwise.
|
||||||
@ -3356,20 +3309,54 @@ Status AlgebraicSimplifierVisitor::HandleSlice(HloInstruction* slice) {
|
|||||||
HloInstruction* pad;
|
HloInstruction* pad;
|
||||||
HloInstruction* pad_operand;
|
HloInstruction* pad_operand;
|
||||||
if (Match(slice, m::Slice(m::Pad(&pad, m::Op(&pad_operand), m::Op())))) {
|
if (Match(slice, m::Slice(m::Pad(&pad, m::Op(&pad_operand), m::Op())))) {
|
||||||
|
// Is the result of the slice the pad operand.
|
||||||
bool slice_undoes_pad = true;
|
bool slice_undoes_pad = true;
|
||||||
|
// Can the slice be moved to the pad_operand without any padding being read.
|
||||||
|
bool slice_inside_pad = true;
|
||||||
|
// Does this slice slice out pading only.
|
||||||
|
bool slice_in_padding = false;
|
||||||
|
std::vector<int64> new_starts = slice->slice_starts();
|
||||||
|
std::vector<int64> new_limits = slice->slice_limits();
|
||||||
for (int64 i = 0; i < slice->shape().rank(); ++i) {
|
for (int64 i = 0; i < slice->shape().rank(); ++i) {
|
||||||
if (slice->slice_starts(i) !=
|
const int64 start = slice->slice_starts(i);
|
||||||
pad->padding_config().dimensions(i).edge_padding_low()) {
|
const int64 stride = slice->slice_strides(i);
|
||||||
|
const int64 limit = slice->slice_limits(i);
|
||||||
|
const int64 size = pad->shape().dimensions(i);
|
||||||
|
|
||||||
|
const auto& dim = pad->padding_config().dimensions(i);
|
||||||
|
const int64 low = dim.edge_padding_low();
|
||||||
|
const int64 high = dim.edge_padding_high();
|
||||||
|
const int64 interior = dim.interior_padding();
|
||||||
|
const int64 edge = size - high;
|
||||||
|
|
||||||
|
if (limit <= low || start >= edge) {
|
||||||
|
slice_in_padding = true;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (start != low || stride - 1 != interior) {
|
||||||
slice_undoes_pad = false;
|
slice_undoes_pad = false;
|
||||||
}
|
}
|
||||||
if (slice->slice_strides(i) - 1 !=
|
|
||||||
pad->padding_config().dimensions(i).interior_padding()) {
|
if (start < low || limit > edge || interior != 0 || stride != 1) {
|
||||||
slice_undoes_pad = false;
|
slice_inside_pad = false;
|
||||||
}
|
}
|
||||||
|
new_starts[i] -= low;
|
||||||
|
new_limits[i] -= low;
|
||||||
|
}
|
||||||
|
if (slice_in_padding) {
|
||||||
|
return ReplaceInstruction(
|
||||||
|
slice, MakeBroadcastHlo(pad->mutable_operand(1), {}, slice->shape()));
|
||||||
}
|
}
|
||||||
if (slice_undoes_pad && ReplaceInstructionIfSameShape(slice, pad_operand)) {
|
if (slice_undoes_pad && ReplaceInstructionIfSameShape(slice, pad_operand)) {
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
if (slice_inside_pad) {
|
||||||
|
TF_ASSIGN_OR_RETURN(HloInstruction * new_slice,
|
||||||
|
MakeSliceHlo(pad_operand, new_starts, new_limits,
|
||||||
|
slice->slice_strides()));
|
||||||
|
return ReplaceInstruction(slice, new_slice);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (slice->operand(0)->opcode() == HloOpcode::kSlice &&
|
if (slice->operand(0)->opcode() == HloOpcode::kSlice &&
|
||||||
|
@ -4389,7 +4389,7 @@ TEST_F(AlgebraicSimplifierTest, SliceOfPadLow) {
|
|||||||
AlgebraicSimplifier simplifier(options);
|
AlgebraicSimplifier simplifier(options);
|
||||||
EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie());
|
EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie());
|
||||||
auto root = module->entry_computation()->root_instruction();
|
auto root = module->entry_computation()->root_instruction();
|
||||||
EXPECT_THAT(root, GmockMatch(m::Reshape(m::Constant())));
|
EXPECT_THAT(root, GmockMatch(m::Broadcast(m::Constant())));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(AlgebraicSimplifierTest, SliceOfPadHigh) {
|
TEST_F(AlgebraicSimplifierTest, SliceOfPadHigh) {
|
||||||
@ -4410,7 +4410,7 @@ TEST_F(AlgebraicSimplifierTest, SliceOfPadHigh) {
|
|||||||
AlgebraicSimplifier simplifier(options);
|
AlgebraicSimplifier simplifier(options);
|
||||||
EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie());
|
EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie());
|
||||||
auto root = module->entry_computation()->root_instruction();
|
auto root = module->entry_computation()->root_instruction();
|
||||||
EXPECT_THAT(root, GmockMatch(m::Reshape(m::Constant())));
|
EXPECT_THAT(root, GmockMatch(m::Broadcast(m::Constant())));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(AlgebraicSimplifierTest, SliceOfPadMidNonScalar) {
|
TEST_F(AlgebraicSimplifierTest, SliceOfPadMidNonScalar) {
|
||||||
@ -4429,7 +4429,31 @@ TEST_F(AlgebraicSimplifierTest, SliceOfPadMidNonScalar) {
|
|||||||
|
|
||||||
AlgebraicSimplifierOptions options;
|
AlgebraicSimplifierOptions options;
|
||||||
AlgebraicSimplifier simplifier(options);
|
AlgebraicSimplifier simplifier(options);
|
||||||
EXPECT_FALSE(simplifier.Run(module.get()).ValueOrDie());
|
EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie());
|
||||||
|
EXPECT_THAT(module->entry_computation()->root_instruction(),
|
||||||
|
GmockMatch(m::Slice(m::Parameter(0))));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(AlgebraicSimplifierTest, SliceOfPad) {
|
||||||
|
const char* hlo_string = R"(
|
||||||
|
HloModule module
|
||||||
|
|
||||||
|
ENTRY test {
|
||||||
|
param = f32[3,4] parameter(0)
|
||||||
|
constant = f32[] constant(0.0)
|
||||||
|
pad = f32[8,10] pad(f32[3,4] param, f32[] constant), padding=3_2x1_5
|
||||||
|
ROOT slice = f32[2,3] slice(f32[8,10] pad), slice={[4:6],[2:5]}
|
||||||
|
}
|
||||||
|
)";
|
||||||
|
TF_ASSERT_OK_AND_ASSIGN(auto module,
|
||||||
|
ParseAndReturnVerifiedModule(hlo_string));
|
||||||
|
|
||||||
|
AlgebraicSimplifierOptions options;
|
||||||
|
AlgebraicSimplifier simplifier(options);
|
||||||
|
EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie());
|
||||||
|
auto root = module->entry_computation()->root_instruction();
|
||||||
|
EXPECT_THAT(root, GmockMatch(m::Slice(m::Parameter(0))));
|
||||||
|
EXPECT_THAT(root->slice_starts(), ElementsAre(1, 1));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(AlgebraicSimplifierTest, SliceOfPadMidScalarConstant) {
|
TEST_F(AlgebraicSimplifierTest, SliceOfPadMidScalarConstant) {
|
||||||
@ -4450,7 +4474,7 @@ TEST_F(AlgebraicSimplifierTest, SliceOfPadMidScalarConstant) {
|
|||||||
AlgebraicSimplifier simplifier(options);
|
AlgebraicSimplifier simplifier(options);
|
||||||
EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie());
|
EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie());
|
||||||
auto root = module->entry_computation()->root_instruction();
|
auto root = module->entry_computation()->root_instruction();
|
||||||
EXPECT_THAT(root, GmockMatch(m::Reshape(m::Constant())));
|
EXPECT_THAT(root, GmockMatch(m::Broadcast(m::Constant())));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(AlgebraicSimplifierTest, SliceOfPadMidScalar) {
|
TEST_F(AlgebraicSimplifierTest, SliceOfPadMidScalar) {
|
||||||
@ -4494,7 +4518,7 @@ TEST_F(AlgebraicSimplifierTest, SliceOfPadSomeDimsInPadding) {
|
|||||||
AlgebraicSimplifier simplifier(options);
|
AlgebraicSimplifier simplifier(options);
|
||||||
EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie());
|
EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie());
|
||||||
auto root = module->entry_computation()->root_instruction();
|
auto root = module->entry_computation()->root_instruction();
|
||||||
EXPECT_THAT(root, GmockMatch(m::Reshape(m::ConstantScalar(-7.0))));
|
EXPECT_THAT(root, GmockMatch(m::Broadcast(m::ConstantScalar(-7.0))));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(AlgebraicSimplifierTest, SliceOfConcatScalarInput) {
|
TEST_F(AlgebraicSimplifierTest, SliceOfConcatScalarInput) {
|
||||||
|
Loading…
Reference in New Issue
Block a user