diff --git a/tensorflow/compiler/xla/service/user_computation.cc b/tensorflow/compiler/xla/service/user_computation.cc index cfa5c98f593..297bfd93d12 100644 --- a/tensorflow/compiler/xla/service/user_computation.cc +++ b/tensorflow/compiler/xla/service/user_computation.cc @@ -2478,7 +2478,7 @@ HloInstruction* ComputationLowerer::ImplicitBroadcastToExplicitBroadcast( std::vector broadcast_dimensions; std::vector reshaped_dimensions; for (int i = 0; i < ShapeUtil::Rank(operand->shape()); i++) { - if (operand->shape().dimensions(i) > 1) { + if (operand->shape().dimensions(i) == output_shape.dimensions(i)) { broadcast_dimensions.push_back(i); reshaped_dimensions.push_back(operand->shape().dimensions(i)); } diff --git a/tensorflow/compiler/xla/service/user_computation_test.cc b/tensorflow/compiler/xla/service/user_computation_test.cc index 70974959294..6b0d6b9e11c 100644 --- a/tensorflow/compiler/xla/service/user_computation_test.cc +++ b/tensorflow/compiler/xla/service/user_computation_test.cc @@ -197,6 +197,65 @@ TEST_F(UserComputationTest, EliminateScalarBroadcast) { operands[1]->opcode() == HloOpcode::kBroadcast); } +TEST_F(UserComputationTest, CheckImplicitBroadcastToExplicitBroadcast) { + auto debug_options = DebugOptions(); + debug_options.set_xla_eliminate_hlo_implicit_broadcast(true); + + // Build a binary computation with degenerate broadcast. + // + // %a = Param({1, 2, 3}); + // %b = Param({1, 2, 1}); + // %add = Add(%a, %b, {}); + ComputationHandle handle; + handle.set_handle(123); + UserComputation computation("TheComputation", handle); + + ParameterRequest a_request; + *a_request.mutable_shape() = ShapeUtil::MakeShape(F32, {1, 2, 3}); + a_request.set_name("a"); + a_request.set_parameter(0); + TF_ASSERT_OK_AND_ASSIGN(ComputationDataHandle a_handle, + computation.AddParameterInstruction(a_request)); + + ParameterRequest b_request; + *b_request.mutable_shape() = ShapeUtil::MakeShape(F32, {1, 2, 1}); + b_request.set_name("b"); + b_request.set_parameter(1); + TF_ASSERT_OK_AND_ASSIGN(ComputationDataHandle b_handle, + computation.AddParameterInstruction(b_request)); + + BinaryOpRequest add; + add.set_binop(BINOP_ADD); + *add.mutable_lhs() = a_handle; + *add.mutable_rhs() = b_handle; + TF_ASSERT_OK(computation.AddBinaryInstruction(add).status()); + + auto hlo_resolver = [](const VersionedComputationHandle& handle) { + return nullptr; + }; + VersionedComputationHandle latest_version = computation.GetVersionedHandle(); + + // Build the HLO computation. + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr hlo_computation, + computation.BuildHloComputation(latest_version.version, hlo_resolver, + debug_options)); + + // b a + // | | + // reshape | + // | | + // broadcast | + // \ / + // add + EXPECT_EQ(5, hlo_computation->instruction_count()); + EXPECT_THAT(hlo_computation->root_instruction(), op::Add()); + const auto& operands = hlo_computation->root_instruction()->operands(); + ASSERT_EQ(2, operands.size()); + EXPECT_TRUE(operands[0]->opcode() == HloOpcode::kParameter && + operands[1]->opcode() == HloOpcode::kBroadcast); +} + TEST_F(UserComputationTest, EliminateDegenerateBroadcastAfterIndimBroadcast) { auto debug_options = DebugOptions(); debug_options.set_xla_eliminate_hlo_implicit_broadcast(true);