Remove special case handling of potential bitcasts.
In the GPU backend, this code would be never triggered anyway, because when we do fusion, we would already have replaced reshapes and transposes that can be bitcasts with bitcasts. In the CPU backend, we would have done this already for reshapes that only add or remove 1-sized dimensions, and for transposes inside fusions it is actually beneficial to not replace transposes by bitcasts because for transposes we can more easily derive the multi-dimensional index for the operand. Also remove the related test, and the now unused CouldBeBitcast method. Note that the tests also never tested what they were supposed to test, there should have been another operation *after* the bitcast as the root of the potential fusion node. PiperOrigin-RevId: 302844175 Change-Id: Ic3640246dcd75b65a2e8cbd823f7e70c25038d4d
This commit is contained in:
parent
f8f4b9e386
commit
0298dd6900
@ -3525,17 +3525,6 @@ bool HloPtrComparator::operator()(const HloInstruction* const& lhs,
|
||||
return lhs->unique_id() < rhs->unique_id();
|
||||
}
|
||||
|
||||
bool HloInstruction::CouldBeBitcast() const {
|
||||
switch (opcode_) {
|
||||
case HloOpcode::kTranspose:
|
||||
return true;
|
||||
case HloOpcode::kReshape:
|
||||
return std::get<0>(ReshapeMerelyInsertsOrDeletes1SizedDimensions());
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
Status HloInstruction::GetBackendConfigInternal(
|
||||
tensorflow::protobuf::Message* proto) const {
|
||||
proto->Clear();
|
||||
|
@ -1562,10 +1562,6 @@ class HloInstruction {
|
||||
// Returns the module for this instruction.
|
||||
HloModule* GetModule() const;
|
||||
|
||||
// Returns whether we could assign input and output layouts to this
|
||||
// instruction to make it a bitcast.
|
||||
bool CouldBeBitcast() const;
|
||||
|
||||
// Get/Set the number of partitions per outer dimension (in order, starting
|
||||
// with outer-most dimension first). Currently used by the parallel cpu
|
||||
// backend to partition HLOs into parallel tasks.
|
||||
|
@ -692,14 +692,6 @@ bool InstructionFusion::ShouldFuse(HloInstruction* consumer,
|
||||
return false;
|
||||
}
|
||||
|
||||
if (producer->CouldBeBitcast() &&
|
||||
// We can't fuse parameters anyhow, so we leave the user unfused to become
|
||||
// a bitcast. If the operand is not a parameter, we would break a
|
||||
// potential fusion to make it a bitcast, which is not so clear a win.
|
||||
producer->operand(0)->opcode() == HloOpcode::kParameter) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
|
@ -110,54 +110,6 @@ TEST_F(InstructionFusionTest, FuseInstructionsIntoMultiOutput) {
|
||||
<< module->ToString();
|
||||
}
|
||||
|
||||
TEST_F(InstructionFusionTest, PotentialBitcastReshapeOfParameterUnfused) {
|
||||
HloComputation::Builder builder(TestName());
|
||||
auto param0 = builder.AddInstruction(
|
||||
HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(S32, {}), "0"));
|
||||
auto reshape1 = builder.AddInstruction(
|
||||
HloInstruction::CreateReshape(ShapeUtil::MakeShape(S32, {1, 1}), param0));
|
||||
|
||||
auto module = CreateNewVerifiedModule();
|
||||
auto computation = module->AddEntryComputation(builder.Build());
|
||||
EXPECT_EQ(reshape1, computation->root_instruction());
|
||||
EXPECT_FALSE(
|
||||
InstructionFusion(InstructionFusion::IsExpensive, /*may_duplicate=*/true)
|
||||
.Run(module.get())
|
||||
.ValueOrDie());
|
||||
}
|
||||
|
||||
TEST_F(InstructionFusionTest, PotentialBitcastSimpleReshapeOfParameterUnfused) {
|
||||
HloComputation::Builder builder(TestName());
|
||||
auto param0 = builder.AddInstruction(
|
||||
HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(S32, {}), "0"));
|
||||
auto reshape1 = builder.AddInstruction(
|
||||
HloInstruction::CreateReshape(ShapeUtil::MakeShape(S32, {1, 1}), param0));
|
||||
|
||||
auto module = CreateNewVerifiedModule();
|
||||
auto computation = module->AddEntryComputation(builder.Build());
|
||||
EXPECT_EQ(reshape1, computation->root_instruction());
|
||||
EXPECT_FALSE(
|
||||
InstructionFusion(InstructionFusion::IsExpensive, /*may_duplicate=*/true)
|
||||
.Run(module.get())
|
||||
.ValueOrDie());
|
||||
}
|
||||
|
||||
TEST_F(InstructionFusionTest, PotentialBitcastTransposeOfParameterUnfused) {
|
||||
HloComputation::Builder builder(TestName());
|
||||
auto param0 = builder.AddInstruction(
|
||||
HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(S32, {}), "0"));
|
||||
auto transpose1 = builder.AddInstruction(HloInstruction::CreateTranspose(
|
||||
ShapeUtil::MakeShape(S32, {}), param0, {}));
|
||||
|
||||
auto module = CreateNewVerifiedModule();
|
||||
auto computation = module->AddEntryComputation(builder.Build());
|
||||
EXPECT_EQ(transpose1, computation->root_instruction());
|
||||
EXPECT_FALSE(
|
||||
InstructionFusion(InstructionFusion::IsExpensive, /*may_duplicate=*/true)
|
||||
.Run(module.get())
|
||||
.ValueOrDie());
|
||||
}
|
||||
|
||||
TEST_F(InstructionFusionTest, AvoidDuplicationIfNotAllFusible) {
|
||||
HloComputation::Builder builder(TestName());
|
||||
auto shape = ShapeUtil::MakeShape(F32, {16, 16});
|
||||
|
Loading…
x
Reference in New Issue
Block a user