From 96b852627307d9375b2391ef6273abc78a2db5b2 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 30 Aug 2017 02:43:21 -0700 Subject: [PATCH] Don't remove trivial dimensions if they require no broadcast. Currently, the conversion from implicit broadcasts to explicit broadcasts also removes dimensions which are the same as the output shape. This means that sometimes potentially costly (on some backends) reshapes are required. This CL changes the conversion so that it will only remove trivial dimensions if they actually require a broadcast. PiperOrigin-RevId: 166970167 --- .../compiler/xla/service/user_computation.cc | 2 +- .../xla/service/user_computation_test.cc | 59 +++++++++++++++++++ 2 files changed, 60 insertions(+), 1 deletion(-) diff --git a/tensorflow/compiler/xla/service/user_computation.cc b/tensorflow/compiler/xla/service/user_computation.cc index cfa5c98f593..297bfd93d12 100644 --- a/tensorflow/compiler/xla/service/user_computation.cc +++ b/tensorflow/compiler/xla/service/user_computation.cc @@ -2478,7 +2478,7 @@ HloInstruction* ComputationLowerer::ImplicitBroadcastToExplicitBroadcast( std::vector broadcast_dimensions; std::vector reshaped_dimensions; for (int i = 0; i < ShapeUtil::Rank(operand->shape()); i++) { - if (operand->shape().dimensions(i) > 1) { + if (operand->shape().dimensions(i) == output_shape.dimensions(i)) { broadcast_dimensions.push_back(i); reshaped_dimensions.push_back(operand->shape().dimensions(i)); } diff --git a/tensorflow/compiler/xla/service/user_computation_test.cc b/tensorflow/compiler/xla/service/user_computation_test.cc index 70974959294..6b0d6b9e11c 100644 --- a/tensorflow/compiler/xla/service/user_computation_test.cc +++ b/tensorflow/compiler/xla/service/user_computation_test.cc @@ -197,6 +197,65 @@ TEST_F(UserComputationTest, EliminateScalarBroadcast) { operands[1]->opcode() == HloOpcode::kBroadcast); } +TEST_F(UserComputationTest, CheckImplicitBroadcastToExplicitBroadcast) { + auto debug_options = 
DebugOptions(); + debug_options.set_xla_eliminate_hlo_implicit_broadcast(true); + + // Build a binary computation with degenerate broadcast. + // + // %a = Param({1, 2, 3}); + // %b = Param({1, 2, 1}); + // %add = Add(%a, %b, {}); + ComputationHandle handle; + handle.set_handle(123); + UserComputation computation("TheComputation", handle); + + ParameterRequest a_request; + *a_request.mutable_shape() = ShapeUtil::MakeShape(F32, {1, 2, 3}); + a_request.set_name("a"); + a_request.set_parameter(0); + TF_ASSERT_OK_AND_ASSIGN(ComputationDataHandle a_handle, + computation.AddParameterInstruction(a_request)); + + ParameterRequest b_request; + *b_request.mutable_shape() = ShapeUtil::MakeShape(F32, {1, 2, 1}); + b_request.set_name("b"); + b_request.set_parameter(1); + TF_ASSERT_OK_AND_ASSIGN(ComputationDataHandle b_handle, + computation.AddParameterInstruction(b_request)); + + BinaryOpRequest add; + add.set_binop(BINOP_ADD); + *add.mutable_lhs() = a_handle; + *add.mutable_rhs() = b_handle; + TF_ASSERT_OK(computation.AddBinaryInstruction(add).status()); + + auto hlo_resolver = [](const VersionedComputationHandle& handle) { + return nullptr; + }; + VersionedComputationHandle latest_version = computation.GetVersionedHandle(); + + // Build the HLO computation. 
+ TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr hlo_computation, + computation.BuildHloComputation(latest_version.version, hlo_resolver, + debug_options)); + + // b a + // | | + // reshape | + // | | + // broadcast | + // \ / + // add + EXPECT_EQ(5, hlo_computation->instruction_count()); + EXPECT_THAT(hlo_computation->root_instruction(), op::Add()); + const auto& operands = hlo_computation->root_instruction()->operands(); + ASSERT_EQ(2, operands.size()); + EXPECT_TRUE(operands[0]->opcode() == HloOpcode::kParameter && + operands[1]->opcode() == HloOpcode::kBroadcast); +} + TEST_F(UserComputationTest, EliminateDegenerateBroadcastAfterIndimBroadcast) { auto debug_options = DebugOptions(); debug_options.set_xla_eliminate_hlo_implicit_broadcast(true);