[XLA] Add support for bf16 map() to HloEvaluator

PiperOrigin-RevId: 307796916
Change-Id: I31d2f96069d095bd197aef3b741fec42074b54b5
This commit is contained in:
Benjamin Kramer 2020-04-22 05:31:05 -07:00 committed by TensorFlower Gardener
parent 83b579780d
commit e4eb0fb64a
2 changed files with 26 additions and 0 deletions

View File

@ -4442,5 +4442,27 @@ TEST_F(HloEvaluatorTest, CopyStartCopyDone) {
EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
TEST_F(HloEvaluatorTest, MapBF16) {
const absl::string_view hlo_text = R"(
HloModule test
map_computation {
p = bf16[] parameter(0)
add = bf16[] add(p, p)
ROOT conv = f32[] convert(add)
}
ENTRY CopyStartCopyDone {
c = bf16[3] constant({1, 2, 3})
ROOT map = f32[3] map(c), to_apply=map_computation
}
)";
TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
Literal expected = LiteralUtil::CreateR1<float>({2.f, 4.f, 6.f});
TF_ASSERT_OK_AND_ASSIGN(
Literal result, HloEvaluator().Evaluate(*m_->entry_computation(), {}));
EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
} // namespace
} // namespace xla

View File

@ -1680,6 +1680,10 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
MapImpl<Eigen::half>(map));
break;
}
case BF16: {
TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl<bfloat16>(map));
break;
}
case F32: {
TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl<float>(map));
break;