[XLA] Add support for bf16 map() to HloEvaluator
PiperOrigin-RevId: 307796916 Change-Id: I31d2f96069d095bd197aef3b741fec42074b54b5
This commit is contained in:
parent
83b579780d
commit
e4eb0fb64a
@ -4442,5 +4442,27 @@ TEST_F(HloEvaluatorTest, CopyStartCopyDone) {
|
||||
EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
|
||||
}
|
||||
|
||||
TEST_F(HloEvaluatorTest, MapBF16) {
|
||||
const absl::string_view hlo_text = R"(
|
||||
HloModule test
|
||||
|
||||
map_computation {
|
||||
p = bf16[] parameter(0)
|
||||
add = bf16[] add(p, p)
|
||||
ROOT conv = f32[] convert(add)
|
||||
}
|
||||
|
||||
ENTRY CopyStartCopyDone {
|
||||
c = bf16[3] constant({1, 2, 3})
|
||||
ROOT map = f32[3] map(c), to_apply=map_computation
|
||||
}
|
||||
)";
|
||||
TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
|
||||
Literal expected = LiteralUtil::CreateR1<float>({2.f, 4.f, 6.f});
|
||||
TF_ASSERT_OK_AND_ASSIGN(
|
||||
Literal result, HloEvaluator().Evaluate(*m_->entry_computation(), {}));
|
||||
EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace xla
|
||||
|
||||
@ -1680,6 +1680,10 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
|
||||
MapImpl<Eigen::half>(map));
|
||||
break;
|
||||
}
|
||||
case BF16: {
|
||||
TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl<bfloat16>(map));
|
||||
break;
|
||||
}
|
||||
case F32: {
|
||||
TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl<float>(map));
|
||||
break;
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user