From fcfdbcf14abc526a7f78468c19345a60d64573b3 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 8 Jun 2020 09:49:22 -0700 Subject: [PATCH] Revert of [XLA] Add support for sinking broadcasts through ops with multiple broadcasts operands. as it is causing some internal failures. Investigation in progress. PiperOrigin-RevId: 315293975 Change-Id: If65d7aaf53f29cac52072bc14b06e3b5a8c5fc49 --- .../xla/service/algebraic_simplifier.cc | 54 ++++---------- .../xla/service/algebraic_simplifier_test.cc | 73 ------------------- 2 files changed, 14 insertions(+), 113 deletions(-) diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc index 4025cb46f18..e0a8b87c83b 100755 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc @@ -3058,20 +3058,6 @@ AlgebraicSimplifierVisitor::TryToSinkBroadcastAfterOpWithUniqueNonScalarOperand( return false; } HloInstruction* operand = broadcast->mutable_operand(0); - auto is_scalar_broadcast = [](const HloInstruction* instruction) { - return instruction->opcode() == HloOpcode::kBroadcast && - ShapeUtil::IsScalar(instruction->operand(0)->shape()); - }; - auto is_equal_broadcast = [operand, - broadcast](const HloInstruction* instruction) { - return instruction->opcode() == HloOpcode::kBroadcast && - ShapeUtil::Equal(operand->shape(), - instruction->operand(0)->shape()) && - broadcast->dimensions() == instruction->dimensions(); - }; - auto is_compatible_broadcast = [&](const HloInstruction* instruction) { - return is_scalar_broadcast(instruction) || is_equal_broadcast(instruction); - }; for (HloInstruction* user : broadcast->users()) { if (user->user_count() == 0 && user != computation_->root_instruction()) { continue; @@ -3090,20 +3076,18 @@ AlgebraicSimplifierVisitor::TryToSinkBroadcastAfterOpWithUniqueNonScalarOperand( continue; } - // Check if all the operands of the user are compatible broadcasts for - // sinking. (They are either scalar broadcasts or broadcasts casting - // from/to the same shape/dimensions) - int64 compatible_broadcast_count = 0; + // Find the unique non-scalar operand or continue if there isn't one. + int64 scalar_broadcast_count = 0; int64 broadcast_use_count = 0; for (HloInstruction* user_operand : user->operands()) { - if (is_compatible_broadcast(user_operand)) { - ++compatible_broadcast_count; + if (user_operand->opcode() == HloOpcode::kBroadcast && + ShapeUtil::IsScalar(user_operand->operand(0)->shape())) { + ++scalar_broadcast_count; } else if (broadcast == user_operand) { ++broadcast_use_count; } } - if (compatible_broadcast_count + broadcast_use_count != - user->operand_count()) { + if (scalar_broadcast_count + broadcast_use_count != user->operand_count()) { continue; } std::vector new_operands; @@ -3111,24 +3095,14 @@ AlgebraicSimplifierVisitor::TryToSinkBroadcastAfterOpWithUniqueNonScalarOperand( Shape changed_shape; for (HloInstruction* user_operand : user->operands()) { - // If this is a broadcast operand that is not our original broadcast input - // to this function then we might need to change the input. - if (is_compatible_broadcast(user_operand)) { - // If this is a broadcast from a scalar value rewrite a broadcast from - // the scalar to the new shape enforced from the other broadcast - // operands. - if (is_scalar_broadcast(user_operand)) { - changed_shape = ShapeUtil::ChangeElementType( - operand->shape(), user_operand->shape().element_type()); - simplifier_->UpdateLayout(&changed_shape); - new_operands.push_back( - computation_->AddInstruction(HloInstruction::CreateBroadcast( - changed_shape, user_operand->mutable_operand(0), {}))); - } else { - // For the non-scalar broadcasts we guarantee that the shape of the - // operand of the broadcast needs to be already a compatible shape. - new_operands.push_back(user_operand->mutable_operand(0)); - } + if (user_operand->opcode() == HloOpcode::kBroadcast && + ShapeUtil::IsScalar(user_operand->operand(0)->shape())) { + changed_shape = ShapeUtil::ChangeElementType( + operand->shape(), user_operand->shape().element_type()); + simplifier_->UpdateLayout(&changed_shape); + new_operands.push_back( + computation_->AddInstruction(HloInstruction::CreateBroadcast( + changed_shape, user_operand->mutable_operand(0), {}))); } else { CHECK_EQ(broadcast, user_operand); new_operands.push_back(operand); diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc index bcfc2fdc740..3ac47821654 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc @@ -338,79 +338,6 @@ TEST_F(AlgebraicSimplifierTest, MultiplyReassociateMergeBroadcastedConstants) { m::ConstantScalar(3.0)))))); } -TEST_F(AlgebraicSimplifierTest, ElementwiseSinkMultipleBroadcastsScalar) { - const char* kModuleStr = R"( - HloModule m - test { - p0 = f32[] parameter(0) - p1 = f32[] parameter(1) - b0 = f32[4] broadcast(p0), dimensions={} - b1 = f32[4] broadcast(p1), dimensions={} - ROOT multiply = f32[4] multiply(b1, b0) - } - )"; - TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr)); - ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie()); - EXPECT_THAT( - m->entry_computation()->root_instruction(), - GmockMatch(m::Broadcast(m::Multiply(m::Broadcast(m::Parameter(1)), - m::Broadcast(m::Parameter(0)))))); -} - -TEST_F(AlgebraicSimplifierTest, ElementwiseSinkMultipleBroadcastsConstantMix) { - const char* kModuleStr = R"( - HloModule m - test { - p0 = f32[4] parameter(0) - c0 = f32[] constant(2.0) - b0 = f32[4,2] broadcast(c0), dimensions={} - b1 = f32[4,2] broadcast(p0), dimensions={0} - ROOT multiply = f32[4,2] multiply(b1, b0) - } - )"; - TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr)); - ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie()); - EXPECT_THAT(m->entry_computation()->root_instruction(), - GmockMatch(m::Broadcast(m::Multiply( - m::Parameter(0), m::Broadcast(m::ConstantScalar(2.0)))))); -} - -TEST_F(AlgebraicSimplifierTest, ElementwiseSinkMultipleBroadcastsNonScalar) { - const char* kModuleStr = R"( - HloModule m - test { - p0 = f32[4] parameter(0) - p1 = f32[4] parameter(1) - b0 = f32[4,2] broadcast(p0), dimensions={0} - b1 = f32[4,2] broadcast(p1), dimensions={0} - ROOT multiply = f32[4,2] multiply(b1, b0) - } - )"; - TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr)); - ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie()); - EXPECT_THAT( - m->entry_computation()->root_instruction(), - GmockMatch(m::Broadcast(m::Multiply(m::Parameter(1), m::Parameter(0))))); -} - -TEST_F(AlgebraicSimplifierTest, ElementwiseNoSinkBroadcastsDifferentDims) { - const char* kModuleStr = R"( - HloModule m - test { - p0 = f32[4] parameter(0) - p1 = f32[8] parameter(1) - b0 = f32[4,8] broadcast(p0), dimensions={0} - b1 = f32[4,8] broadcast(p1), dimensions={1} - ROOT multiply = f32[4,8] multiply(b1, b0) - } - )"; - TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr)); - ASSERT_FALSE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie()); - EXPECT_THAT(m->entry_computation()->root_instruction(), - GmockMatch(m::Multiply(m::Broadcast(m::Parameter(1)), - m::Broadcast(m::Parameter(0))))); -} - TEST_F(AlgebraicSimplifierTest, MultiplyReassociateMultiplyOfConstantAndBroadcast) { const char* kModuleStr = R"(