[XLA] Fix HloElementTypeConverter's handling of bitcast convert.

Previously, if you asked HloElementTypeConverter to e.g. convert BF16 to F32,
it would convert a BF16 input to bitcast-convert to F32.

This don't work, because the output of the bitcast-convert still has width 16
bits, and the input/output widths must match.

The right thing is to ignore bitcast-convert operations for the purposes of
this pass.

PiperOrigin-RevId: 232554328
This commit is contained in:
Justin Lebar 2019-02-05 14:05:59 -08:00 committed by TensorFlower Gardener
parent 234d42b5d4
commit f77ec5d5ee
2 changed files with 15 additions and 0 deletions

View File

@ -127,6 +127,7 @@ StatusOr<bool> HloElementTypeConverter::Run(HloModule* module) {
// These are ops where it does not make sense to convert them.
if (opcode == HloOpcode::kParameter || opcode == HloOpcode::kConstant ||
opcode == HloOpcode::kTuple || opcode == HloOpcode::kConvert ||
opcode == HloOpcode::kBitcastConvert ||
opcode == HloOpcode::kGetTupleElement ||
opcode == HloOpcode::kInfeed || opcode == HloOpcode::kOutfeed) {
continue;

View File

@ -176,5 +176,19 @@ ENTRY main {
EXPECT_THAT(rng1->control_predecessors(), ElementsAre(rng0));
}
TEST_F(HloElementTypeConverterTest, BitcastConvertIsUnmodified) {
const string& hlo_string = R"(
HloModule test
ENTRY test {
p = bf16[] parameter(0)
ROOT c = u16[] bitcast-convert(p)
})";
auto module = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie();
HloElementTypeConverter converter(BF16, F32);
TF_ASSERT_OK_AND_ASSIGN(bool converted, RunHloPass(&converter, module.get()));
EXPECT_FALSE(converted);
}
} // namespace
} // namespace xla