From f39dec8996b9a812d51728f3f3767ac2bc8b34c6 Mon Sep 17 00:00:00 2001 From: Yunxing Dai Date: Mon, 25 Nov 2019 13:27:36 -0800 Subject: [PATCH] [XLA] Move add across dynamic update slice to make it smaller. PiperOrigin-RevId: 282421400 Change-Id: Ie891961c4ed573c14924c11d8923c4e07eae1ed4 --- .../xla/service/algebraic_simplifier.cc | 41 +++++++++++++++++++ .../xla/service/algebraic_simplifier_test.cc | 38 +++++++++++++++++ 2 files changed, 79 insertions(+) diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc index fbd6399da4a..4b6b91af122 100755 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc @@ -550,6 +550,47 @@ Status AlgebraicSimplifierVisitor::HandleAdd(HloInstruction* add) { sum_of_constants)); } + // Convert add with fullshape into add with partial shape when a + // portion of add is effective: + // zero (fullshape) rhs (partialshape) + // . | | + // . lhs . dynamic_update_slice (fullshape) + // . | | + // Add (fullshape) + // + // to: + // lhs + // | + // dynamic_slice (partialshape) rhs (partialshape) + // . | | + // . lhs . add (partial_shape)+----+ + // . | | + // dynamic_update_slice (fullshape) + // + // This is pattern is discovered in control flow V2 gradient update. + if (Match(add, + m::Add(m::Op(&lhs), + m::Op(&rhs) + .WithOpcode(HloOpcode::kDynamicUpdateSlice) + .WithOperand( + 0, m::Broadcast(m::ConstantEffectiveScalar(0)))))) { + const Shape& partial_shape = rhs->operand(1)->shape(); + auto sliced_lhs = + computation_->AddInstruction(HloInstruction::CreateDynamicSlice( + partial_shape, lhs, absl::MakeSpan(rhs->operands()).subspan(2), + partial_shape.dimensions())); + + auto add_partial = computation_->AddInstruction( + HloInstruction::CreateBinary(rhs->operand(1)->shape(), HloOpcode::kAdd, + sliced_lhs, rhs->mutable_operand(1))); + + auto dynamic_update_slice_full = HloInstruction::CreateDynamicUpdateSlice( + lhs->shape(), lhs, add_partial, + absl::MakeSpan(rhs->operands()).subspan(2)); + + return ReplaceWithNewInstruction(add, std::move(dynamic_update_slice_full)); + } + // A*C + B*C => (A+B)*C // // - If A, B, and C are integers, do this unconditionally. Proof of diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc index 9c84ac10796..2618a12673f 100755 --- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc @@ -637,6 +637,44 @@ TEST_F(AlgebraicSimplifierTest, AddBroadcastZeroR1Operand) { EXPECT_EQ(root, param0); } +TEST_F(AlgebraicSimplifierTest, AddBroadcastZeroWithDynamicSlice) { + auto m = CreateNewVerifiedModule(); + Shape full_shape = ShapeUtil::MakeShape(F32, {1800, 12, 512}); + + Shape partial_shape = ShapeUtil::MakeShape(F32, {1, 12, 512}); + + HloComputation::Builder builder(TestName()); + HloInstruction* full_param = builder.AddInstruction( + HloInstruction::CreateParameter(0, full_shape, "param0")); + HloInstruction* partial_param = builder.AddInstruction( + HloInstruction::CreateParameter(1, partial_shape, "param1")); + + HloInstruction* zero = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0))); + HloInstruction* index = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0))); + + HloInstruction* bcast = builder.AddInstruction( + HloInstruction::CreateBroadcast(full_shape, zero, {})); + + HloInstruction* dynamic_update_slice = + builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( + full_shape, bcast, partial_param, {index, index, index})); + + builder.AddInstruction(HloInstruction::CreateBinary( + full_shape, HloOpcode::kAdd, full_param, dynamic_update_slice)); + + auto computation = m->AddEntryComputation(builder.Build()); + HloInstruction* root = computation->root_instruction(); + EXPECT_EQ(root->opcode(), HloOpcode::kAdd); + AlgebraicSimplifier simplifier(default_options_); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); + root = computation->root_instruction(); + EXPECT_THAT(root->opcode(), HloOpcode::kDynamicUpdateSlice); + EXPECT_THAT(root->operand(0), full_param); + EXPECT_THAT(root->operand(1), GmockMatch(m::Add(m::DynamicSlice(), m::Op()))); +} + TEST_F(AlgebraicSimplifierTest, ConstantToBroadcast) { auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName());