[XLA] Move add across dynamic update slice to make it smaller.
PiperOrigin-RevId: 282421400 Change-Id: Ie891961c4ed573c14924c11d8923c4e07eae1ed4
This commit is contained in:
parent
9c71c11a64
commit
f39dec8996
@ -550,6 +550,47 @@ Status AlgebraicSimplifierVisitor::HandleAdd(HloInstruction* add) {
|
||||
sum_of_constants));
|
||||
}
|
||||
|
||||
// Convert add with fullshape into add with partial shape when a
|
||||
// portion of add is effective:
|
||||
// zero (fullshape) rhs (partialshape)
|
||||
// . | |
|
||||
// . lhs . dynamic_update_slice (fullshape)
|
||||
// . | |
|
||||
// Add (fullshape)
|
||||
//
|
||||
// to:
|
||||
// lhs
|
||||
// |
|
||||
// dynamic_slice (partialshape) rhs (partialshape)
|
||||
// . | |
|
||||
// . lhs . add (partial_shape)+----+
|
||||
// . | |
|
||||
// dynamic_update_slice (fullshape)
|
||||
//
|
||||
// This is pattern is discovered in control flow V2 gradient update.
|
||||
if (Match(add,
|
||||
m::Add(m::Op(&lhs),
|
||||
m::Op(&rhs)
|
||||
.WithOpcode(HloOpcode::kDynamicUpdateSlice)
|
||||
.WithOperand(
|
||||
0, m::Broadcast(m::ConstantEffectiveScalar(0)))))) {
|
||||
const Shape& partial_shape = rhs->operand(1)->shape();
|
||||
auto sliced_lhs =
|
||||
computation_->AddInstruction(HloInstruction::CreateDynamicSlice(
|
||||
partial_shape, lhs, absl::MakeSpan(rhs->operands()).subspan(2),
|
||||
partial_shape.dimensions()));
|
||||
|
||||
auto add_partial = computation_->AddInstruction(
|
||||
HloInstruction::CreateBinary(rhs->operand(1)->shape(), HloOpcode::kAdd,
|
||||
sliced_lhs, rhs->mutable_operand(1)));
|
||||
|
||||
auto dynamic_update_slice_full = HloInstruction::CreateDynamicUpdateSlice(
|
||||
lhs->shape(), lhs, add_partial,
|
||||
absl::MakeSpan(rhs->operands()).subspan(2));
|
||||
|
||||
return ReplaceWithNewInstruction(add, std::move(dynamic_update_slice_full));
|
||||
}
|
||||
|
||||
// A*C + B*C => (A+B)*C
|
||||
//
|
||||
// - If A, B, and C are integers, do this unconditionally. Proof of
|
||||
|
@ -637,6 +637,44 @@ TEST_F(AlgebraicSimplifierTest, AddBroadcastZeroR1Operand) {
|
||||
EXPECT_EQ(root, param0);
|
||||
}
|
||||
|
||||
TEST_F(AlgebraicSimplifierTest, AddBroadcastZeroWithDynamicSlice) {
|
||||
auto m = CreateNewVerifiedModule();
|
||||
Shape full_shape = ShapeUtil::MakeShape(F32, {1800, 12, 512});
|
||||
|
||||
Shape partial_shape = ShapeUtil::MakeShape(F32, {1, 12, 512});
|
||||
|
||||
HloComputation::Builder builder(TestName());
|
||||
HloInstruction* full_param = builder.AddInstruction(
|
||||
HloInstruction::CreateParameter(0, full_shape, "param0"));
|
||||
HloInstruction* partial_param = builder.AddInstruction(
|
||||
HloInstruction::CreateParameter(1, partial_shape, "param1"));
|
||||
|
||||
HloInstruction* zero = builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0)));
|
||||
HloInstruction* index = builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(0)));
|
||||
|
||||
HloInstruction* bcast = builder.AddInstruction(
|
||||
HloInstruction::CreateBroadcast(full_shape, zero, {}));
|
||||
|
||||
HloInstruction* dynamic_update_slice =
|
||||
builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
|
||||
full_shape, bcast, partial_param, {index, index, index}));
|
||||
|
||||
builder.AddInstruction(HloInstruction::CreateBinary(
|
||||
full_shape, HloOpcode::kAdd, full_param, dynamic_update_slice));
|
||||
|
||||
auto computation = m->AddEntryComputation(builder.Build());
|
||||
HloInstruction* root = computation->root_instruction();
|
||||
EXPECT_EQ(root->opcode(), HloOpcode::kAdd);
|
||||
AlgebraicSimplifier simplifier(default_options_);
|
||||
ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
|
||||
root = computation->root_instruction();
|
||||
EXPECT_THAT(root->opcode(), HloOpcode::kDynamicUpdateSlice);
|
||||
EXPECT_THAT(root->operand(0), full_param);
|
||||
EXPECT_THAT(root->operand(1), GmockMatch(m::Add(m::DynamicSlice(), m::Op())));
|
||||
}
|
||||
|
||||
TEST_F(AlgebraicSimplifierTest, ConstantToBroadcast) {
|
||||
auto m = CreateNewVerifiedModule();
|
||||
HloComputation::Builder builder(TestName());
|
||||
|
Loading…
Reference in New Issue
Block a user