Don't remove trivial dimensions if they require no broadcast.
Currently, the conversion from implicit broadcasts to explicit broadcasts also removes dimensions that are the same as in the output shape. This means that reshapes, which are potentially costly on some backends, are sometimes required. This CL changes the conversion so that it only removes trivial dimensions if they actually require a broadcast.
PiperOrigin-RevId: 166970167
commit 96b8526273 (parent d78a24d407)
@@ -2478,7 +2478,7 @@ HloInstruction* ComputationLowerer::ImplicitBroadcastToExplicitBroadcast(
   std::vector<int64> broadcast_dimensions;
   std::vector<int64> reshaped_dimensions;
   for (int i = 0; i < ShapeUtil::Rank(operand->shape()); i++) {
-    if (operand->shape().dimensions(i) > 1) {
+    if (operand->shape().dimensions(i) == output_shape.dimensions(i)) {
       broadcast_dimensions.push_back(i);
       reshaped_dimensions.push_back(operand->shape().dimensions(i));
     }
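For illustration, here is a minimal standalone sketch of the dimension selection the new condition performs. The helper names and the use of plain std::vector<int64_t> in place of XLA's Shape are hypothetical, not part of this change; the shapes are the ones used by the test added below (operand {1, 2, 1} broadcast into output {1, 2, 3}).

#include <cstdint>
#include <iostream>
#include <vector>

// Hypothetical helper mirroring the selection loop above: a dimension is kept
// (and recorded as a broadcast dimension) only when it already matches the
// output shape, so the reshape strips just the dimensions that genuinely need
// broadcasting.
struct BroadcastPlan {
  std::vector<int64_t> broadcast_dimensions;  // output dims the operand maps to
  std::vector<int64_t> reshaped_dimensions;   // shape the operand is reshaped to
};

BroadcastPlan PlanExplicitBroadcast(const std::vector<int64_t>& operand_shape,
                                    const std::vector<int64_t>& output_shape) {
  BroadcastPlan plan;
  for (int64_t i = 0; i < static_cast<int64_t>(operand_shape.size()); ++i) {
    if (operand_shape[i] == output_shape[i]) {
      plan.broadcast_dimensions.push_back(i);
      plan.reshaped_dimensions.push_back(operand_shape[i]);
    }
  }
  return plan;
}

int main() {
  // Operand {1, 2, 1} broadcast into output {1, 2, 3}.
  BroadcastPlan plan = PlanExplicitBroadcast({1, 2, 1}, {1, 2, 3});
  for (int64_t d : plan.reshaped_dimensions) std::cout << d << " ";   // 1 2
  std::cout << "\n";
  for (int64_t d : plan.broadcast_dimensions) std::cout << d << " ";  // 0 1
  std::cout << "\n";
}

Under the old `> 1` condition the leading size-1 dimension would also have been stripped (reshape all the way down to {2}, broadcast dimension {1}); with the new condition the reshape only drops the single dimension that actually requires a broadcast, matching the intent stated in the commit message.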
@@ -197,6 +197,65 @@ TEST_F(UserComputationTest, EliminateScalarBroadcast) {
               operands[1]->opcode() == HloOpcode::kBroadcast);
 }
 
+TEST_F(UserComputationTest, CheckImplicitBroadcastToExplicitBroadcast) {
+  auto debug_options = DebugOptions();
+  debug_options.set_xla_eliminate_hlo_implicit_broadcast(true);
+
+  // Build a binary computation with degenerate broadcast.
+  //
+  //  %a = Param({1, 2, 3});
+  //  %b = Param({1, 2, 1});
+  //  %add = Add(%a, %b, {});
+  ComputationHandle handle;
+  handle.set_handle(123);
+  UserComputation computation("TheComputation", handle);
+
+  ParameterRequest a_request;
+  *a_request.mutable_shape() = ShapeUtil::MakeShape(F32, {1, 2, 3});
+  a_request.set_name("a");
+  a_request.set_parameter(0);
+  TF_ASSERT_OK_AND_ASSIGN(ComputationDataHandle a_handle,
+                          computation.AddParameterInstruction(a_request));
+
+  ParameterRequest b_request;
+  *b_request.mutable_shape() = ShapeUtil::MakeShape(F32, {1, 2, 1});
+  b_request.set_name("b");
+  b_request.set_parameter(1);
+  TF_ASSERT_OK_AND_ASSIGN(ComputationDataHandle b_handle,
+                          computation.AddParameterInstruction(b_request));
+
+  BinaryOpRequest add;
+  add.set_binop(BINOP_ADD);
+  *add.mutable_lhs() = a_handle;
+  *add.mutable_rhs() = b_handle;
+  TF_ASSERT_OK(computation.AddBinaryInstruction(add).status());
+
+  auto hlo_resolver = [](const VersionedComputationHandle& handle) {
+    return nullptr;
+  };
+  VersionedComputationHandle latest_version = computation.GetVersionedHandle();
+
+  // Build the HLO computation.
+  TF_ASSERT_OK_AND_ASSIGN(
+      std::unique_ptr<HloComputation> hlo_computation,
+      computation.BuildHloComputation(latest_version.version, hlo_resolver,
+                                      debug_options));
+
+  //     b         a
+  //     |         |
+  //  reshape      |
+  //     |         |
+  //  broadcast    |
+  //      \       /
+  //        add
+  EXPECT_EQ(5, hlo_computation->instruction_count());
+  EXPECT_THAT(hlo_computation->root_instruction(), op::Add());
+  const auto& operands = hlo_computation->root_instruction()->operands();
+  ASSERT_EQ(2, operands.size());
+  EXPECT_TRUE(operands[0]->opcode() == HloOpcode::kParameter &&
+              operands[1]->opcode() == HloOpcode::kBroadcast);
+}
+
 TEST_F(UserComputationTest, EliminateDegenerateBroadcastAfterIndimBroadcast) {
   auto debug_options = DebugOptions();
   debug_options.set_xla_eliminate_hlo_implicit_broadcast(true);