[XLA] Fix HloElementTypeConverter's handling of bitcast convert.
Previously, if you asked HloElementTypeConverter to e.g. convert BF16 to F32, it would convert a BF16 input to bitcast-convert to F32. This don't work, because the output of the bitcast-convert still has width 16 bits, and the input/output widths must match. The right thing is to ignore bitcast-convert operations for the purposes of this pass. PiperOrigin-RevId: 232554328
This commit is contained in:
parent
234d42b5d4
commit
f77ec5d5ee
@ -127,6 +127,7 @@ StatusOr<bool> HloElementTypeConverter::Run(HloModule* module) {
|
||||
// These are ops where it does not make sense to convert them.
|
||||
if (opcode == HloOpcode::kParameter || opcode == HloOpcode::kConstant ||
|
||||
opcode == HloOpcode::kTuple || opcode == HloOpcode::kConvert ||
|
||||
opcode == HloOpcode::kBitcastConvert ||
|
||||
opcode == HloOpcode::kGetTupleElement ||
|
||||
opcode == HloOpcode::kInfeed || opcode == HloOpcode::kOutfeed) {
|
||||
continue;
|
||||
|
@ -176,5 +176,19 @@ ENTRY main {
|
||||
EXPECT_THAT(rng1->control_predecessors(), ElementsAre(rng0));
|
||||
}
|
||||
|
||||
TEST_F(HloElementTypeConverterTest, BitcastConvertIsUnmodified) {
|
||||
const string& hlo_string = R"(
|
||||
HloModule test
|
||||
|
||||
ENTRY test {
|
||||
p = bf16[] parameter(0)
|
||||
ROOT c = u16[] bitcast-convert(p)
|
||||
})";
|
||||
auto module = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie();
|
||||
HloElementTypeConverter converter(BF16, F32);
|
||||
TF_ASSERT_OK_AND_ASSIGN(bool converted, RunHloPass(&converter, module.get()));
|
||||
EXPECT_FALSE(converted);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace xla
|
||||
|
Loading…
x
Reference in New Issue
Block a user