Don't remove trivial dimensions if they require no broadcast.

Currently, the conversion from implicit broadcasts to explicit broadcasts also
removes trivial dimensions that already match the output shape. This forces
reshapes that can be costly on some backends.
This CL changes the conversion so that it only removes trivial dimensions if
they actually require a broadcast.

PiperOrigin-RevId: 166970167
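
To illustrate the condition change, here is a minimal standalone sketch, assuming
plain std::vector<int64_t> shapes and a hypothetical KeptDims helper in place of
the real xla::Shape machinery: under the old rule every size-1 dimension was
dropped, even when the corresponding output dimension was also 1, while the new
rule drops a dimension only when it differs from the output.

#include <cstdint>
#include <iostream>
#include <vector>

// Standalone sketch (not the actual XLA code) of the old and new
// dimension-keeping rules in ImplicitBroadcastToExplicitBroadcast.
// Returns the indices of the operand dimensions that are kept.
std::vector<int64_t> KeptDims(const std::vector<int64_t>& operand,
                              const std::vector<int64_t>& output,
                              bool new_rule) {
  std::vector<int64_t> kept;
  for (size_t i = 0; i < operand.size(); ++i) {
    const bool keep = new_rule ? operand[i] == output[i]  // new: keep matches
                               : operand[i] > 1;          // old: drop all 1s
    if (keep) kept.push_back(static_cast<int64_t>(i));
  }
  return kept;
}

int main() {
  const std::vector<int64_t> operand = {1, 2, 3}, output = {1, 2, 3};
  // Old rule drops dimension 0 (size 1), forcing a reshape to {2, 3} and a
  // broadcast back, even though no broadcast is actually needed.
  std::cout << KeptDims(operand, output, /*new_rule=*/false).size() << "\n";  // 2
  // New rule keeps all dimensions, so the operand can be used unchanged.
  std::cout << KeptDims(operand, output, /*new_rule=*/true).size() << "\n";   // 3
  return 0;
}

For an operand whose shape already equals the output, the new rule keeps every
dimension, so the reshape/broadcast pair disappears entirely; that is exactly
what the new test below checks.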
Author: A. Unique TensorFlower
Committed by: TensorFlower Gardener
Date: 2017-08-30 02:43:21 -07:00
Commit: 96b8526273 (parent: d78a24d407)
2 changed files with 60 additions and 1 deletion

tensorflow/compiler/xla/service/user_computation.cc

@@ -2478,7 +2478,7 @@ HloInstruction* ComputationLowerer::ImplicitBroadcastToExplicitBroadcast(
   std::vector<int64> broadcast_dimensions;
   std::vector<int64> reshaped_dimensions;
   for (int i = 0; i < ShapeUtil::Rank(operand->shape()); i++) {
-    if (operand->shape().dimensions(i) > 1) {
+    if (operand->shape().dimensions(i) == output_shape.dimensions(i)) {
       broadcast_dimensions.push_back(i);
       reshaped_dimensions.push_back(operand->shape().dimensions(i));
     }
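
For the end-to-end effect of the hunk above, a hypothetical standalone sketch
(plain vectors instead of xla::Shape, and a made-up Lower helper): an operand
whose shape already equals the output keeps all of its dimensions and needs no
extra instructions, while a degenerate operand such as {1, 2, 1} against output
{1, 2, 3} is reshaped to {1, 2} and broadcast back along dimensions {0, 1}.

#include <cstdint>
#include <iostream>
#include <vector>

// Hypothetical helper (not part of XLA): reports the reshape/broadcast a
// degenerate operand needs to reach `output`, per the new rule in this CL.
void Lower(const std::vector<int64_t>& operand,
           const std::vector<int64_t>& output, const char* name) {
  std::vector<int64_t> broadcast_dimensions;
  std::vector<int64_t> reshaped_dimensions;
  for (size_t i = 0; i < operand.size(); ++i) {
    if (operand[i] == output[i]) {
      broadcast_dimensions.push_back(static_cast<int64_t>(i));
      reshaped_dimensions.push_back(operand[i]);
    }
  }
  if (reshaped_dimensions == operand) {
    // All dimensions already match: the operand feeds the op directly.
    std::cout << name << ": no reshape/broadcast needed\n";
    return;
  }
  std::cout << name << ": reshape to {";
  for (int64_t d : reshaped_dimensions) std::cout << d << ",";
  std::cout << "}, then broadcast along dims {";
  for (int64_t d : broadcast_dimensions) std::cout << d << ",";
  std::cout << "}\n";
}

int main() {
  // The shapes from the test below: %a = {1, 2, 3}, %b = {1, 2, 1},
  // output = {1, 2, 3}.
  Lower({1, 2, 3}, {1, 2, 3}, "%a");  // prints: %a: no reshape/broadcast needed
  Lower({1, 2, 1}, {1, 2, 3}, "%b");  // prints: %b: reshape to {1,2,}, then
                                      //         broadcast along dims {0,1,}
  return 0;
}

This is why the test expects exactly five instructions: two parameters, one
reshape and one broadcast for %b, and the add, with %a wired straight into the
add as a kParameter.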

tensorflow/compiler/xla/service/user_computation_test.cc

@@ -197,6 +197,65 @@ TEST_F(UserComputationTest, EliminateScalarBroadcast) {
               operands[1]->opcode() == HloOpcode::kBroadcast);
 }
 
+TEST_F(UserComputationTest, CheckImplicitBroadcastToExplicitBroadcast) {
+  auto debug_options = DebugOptions();
+  debug_options.set_xla_eliminate_hlo_implicit_broadcast(true);
+
+  // Build a binary computation with degenerate broadcast.
+  //
+  //  %a = Param({1, 2, 3});
+  //  %b = Param({1, 2, 1});
+  //  %add = Add(%a, %b, {});
+  ComputationHandle handle;
+  handle.set_handle(123);
+  UserComputation computation("TheComputation", handle);
+
+  ParameterRequest a_request;
+  *a_request.mutable_shape() = ShapeUtil::MakeShape(F32, {1, 2, 3});
+  a_request.set_name("a");
+  a_request.set_parameter(0);
+  TF_ASSERT_OK_AND_ASSIGN(ComputationDataHandle a_handle,
+                          computation.AddParameterInstruction(a_request));
+
+  ParameterRequest b_request;
+  *b_request.mutable_shape() = ShapeUtil::MakeShape(F32, {1, 2, 1});
+  b_request.set_name("b");
+  b_request.set_parameter(1);
+  TF_ASSERT_OK_AND_ASSIGN(ComputationDataHandle b_handle,
+                          computation.AddParameterInstruction(b_request));
+
+  BinaryOpRequest add;
+  add.set_binop(BINOP_ADD);
+  *add.mutable_lhs() = a_handle;
+  *add.mutable_rhs() = b_handle;
+  TF_ASSERT_OK(computation.AddBinaryInstruction(add).status());
+
+  auto hlo_resolver = [](const VersionedComputationHandle& handle) {
+    return nullptr;
+  };
+  VersionedComputationHandle latest_version = computation.GetVersionedHandle();
+
+  // Build the HLO computation.
+  TF_ASSERT_OK_AND_ASSIGN(
+      std::unique_ptr<HloComputation> hlo_computation,
+      computation.BuildHloComputation(latest_version.version, hlo_resolver,
+                                      debug_options));
+
+  //     b         a
+  //     |         |
+  //  reshape      |
+  //     |         |
+  //  broadcast    |
+  //      \       /
+  //         add
+  EXPECT_EQ(5, hlo_computation->instruction_count());
+  EXPECT_THAT(hlo_computation->root_instruction(), op::Add());
+  const auto& operands = hlo_computation->root_instruction()->operands();
+  ASSERT_EQ(2, operands.size());
+  EXPECT_TRUE(operands[0]->opcode() == HloOpcode::kParameter &&
+              operands[1]->opcode() == HloOpcode::kBroadcast);
+}
+
 TEST_F(UserComputationTest, EliminateDegenerateBroadcastAfterIndimBroadcast) {
   auto debug_options = DebugOptions();
   debug_options.set_xla_eliminate_hlo_implicit_broadcast(true);