[XLA] Move add across dynamic update slice to make it smaller.

PiperOrigin-RevId: 282421400
Change-Id: Ie891961c4ed573c14924c11d8923c4e07eae1ed4
This commit is contained in:
Yunxing Dai 2019-11-25 13:27:36 -08:00 committed by TensorFlower Gardener
parent 9c71c11a64
commit f39dec8996
2 changed files with 79 additions and 0 deletions

View File

@ -550,6 +550,47 @@ Status AlgebraicSimplifierVisitor::HandleAdd(HloInstruction* add) {
sum_of_constants));
}
// Convert an add over the full shape into an add over the partial shape when
// only a portion of the add is effective (the rest adds zero):
// zero (fullshape) rhs (partialshape)
// . | |
// . lhs . dynamic_update_slice (fullshape)
// . | |
// Add (fullshape)
//
// to:
// lhs
// |
// dynamic_slice (partialshape) rhs (partialshape)
// . | |
// . lhs . add (partial_shape)+----+
// . | |
// dynamic_update_slice (fullshape)
//
// This pattern is discovered in control flow V2 gradient update.
if (Match(add,
m::Add(m::Op(&lhs),
m::Op(&rhs)
.WithOpcode(HloOpcode::kDynamicUpdateSlice)
.WithOperand(
0, m::Broadcast(m::ConstantEffectiveScalar(0)))))) {
const Shape& partial_shape = rhs->operand(1)->shape();
auto sliced_lhs =
computation_->AddInstruction(HloInstruction::CreateDynamicSlice(
partial_shape, lhs, absl::MakeSpan(rhs->operands()).subspan(2),
partial_shape.dimensions()));
auto add_partial = computation_->AddInstruction(
HloInstruction::CreateBinary(rhs->operand(1)->shape(), HloOpcode::kAdd,
sliced_lhs, rhs->mutable_operand(1)));
auto dynamic_update_slice_full = HloInstruction::CreateDynamicUpdateSlice(
lhs->shape(), lhs, add_partial,
absl::MakeSpan(rhs->operands()).subspan(2));
return ReplaceWithNewInstruction(add, std::move(dynamic_update_slice_full));
}
// A*C + B*C => (A+B)*C
//
// - If A, B, and C are integers, do this unconditionally. Proof of

View File

@ -637,6 +637,44 @@ TEST_F(AlgebraicSimplifierTest, AddBroadcastZeroR1Operand) {
EXPECT_EQ(root, param0);
}
// Verifies that
//   add(full, dynamic_update_slice(broadcast(0), partial, idx...))
// is rewritten to
//   dynamic_update_slice(full, add(dynamic_slice(full, idx...), partial),
//                        idx...)
// so that the add is performed on the partial shape only.
TEST_F(AlgebraicSimplifierTest, AddBroadcastZeroWithDynamicSlice) {
  auto m = CreateNewVerifiedModule();
  Shape full_shape = ShapeUtil::MakeShape(F32, {1800, 12, 512});
  Shape partial_shape = ShapeUtil::MakeShape(F32, {1, 12, 512});

  HloComputation::Builder builder(TestName());
  HloInstruction* full_param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, full_shape, "param0"));
  HloInstruction* partial_param = builder.AddInstruction(
      HloInstruction::CreateParameter(1, partial_shape, "param1"));

  HloInstruction* zero = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0)));
  // A single scalar constant is reused as the start index of every dimension.
  HloInstruction* index = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(0)));

  HloInstruction* bcast = builder.AddInstruction(
      HloInstruction::CreateBroadcast(full_shape, zero, {}));

  HloInstruction* dynamic_update_slice =
      builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
          full_shape, bcast, partial_param, {index, index, index}));

  builder.AddInstruction(HloInstruction::CreateBinary(
      full_shape, HloOpcode::kAdd, full_param, dynamic_update_slice));

  auto computation = m->AddEntryComputation(builder.Build());
  HloInstruction* root = computation->root_instruction();
  EXPECT_EQ(root->opcode(), HloOpcode::kAdd);

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());

  // The root must now be the full-shape dynamic-update-slice writing into the
  // original full-shape parameter. Assert (rather than expect) the structure
  // so that the operand accesses below cannot run on an unexpected graph.
  root = computation->root_instruction();
  ASSERT_EQ(root->opcode(), HloOpcode::kDynamicUpdateSlice);
  ASSERT_EQ(root->operand_count(), 5);
  EXPECT_EQ(root->operand(0), full_param);

  // The update is the partial-shape add of the sliced full parameter and the
  // original partial update.
  const HloInstruction* partial_add = root->operand(1);
  ASSERT_THAT(partial_add, GmockMatch(m::Add(m::DynamicSlice(), m::Op())));
  EXPECT_EQ(partial_add->operand(0)->operand(0), full_param);
  EXPECT_EQ(partial_add->operand(1), partial_param);

  // The start indices of the original dynamic-update-slice are carried over.
  for (int i = 2; i < 5; ++i) {
    EXPECT_EQ(root->operand(i), index);
  }
}
TEST_F(AlgebraicSimplifierTest, ConstantToBroadcast) {
auto m = CreateNewVerifiedModule();
HloComputation::Builder builder(TestName());