[XLA] Add S16/U16 support for Map in the HLO evaluator
PiperOrigin-RevId: 343762801 Change-Id: I193d59f3823af604cc08c715156f97af36263aeb
This commit is contained in:
parent
9bb01ec208
commit
9367ca6fa3
@ -4567,6 +4567,50 @@ TEST_F(HloEvaluatorTest, MapBF16) {
|
||||
EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
|
||||
}
|
||||
|
||||
TEST_F(HloEvaluatorTest, MapS16) {
|
||||
const absl::string_view hlo_text = R"(
|
||||
HloModule test
|
||||
|
||||
map_computation {
|
||||
p = s16[] parameter(0)
|
||||
add = s16[] add(p, p)
|
||||
ROOT conv = f32[] convert(add)
|
||||
}
|
||||
|
||||
ENTRY CopyStartCopyDone {
|
||||
c = s16[3] constant({1, 2, 3})
|
||||
ROOT map = f32[3] map(c), to_apply=map_computation
|
||||
}
|
||||
)";
|
||||
TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
|
||||
Literal expected = LiteralUtil::CreateR1<float>({2.f, 4.f, 6.f});
|
||||
TF_ASSERT_OK_AND_ASSIGN(
|
||||
Literal result, HloEvaluator().Evaluate(*m_->entry_computation(), {}));
|
||||
EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
|
||||
}
|
||||
|
||||
TEST_F(HloEvaluatorTest, MapU16) {
|
||||
const absl::string_view hlo_text = R"(
|
||||
HloModule test
|
||||
|
||||
map_computation {
|
||||
p = u16[] parameter(0)
|
||||
add = u16[] add(p, p)
|
||||
ROOT conv = f32[] convert(add)
|
||||
}
|
||||
|
||||
ENTRY CopyStartCopyDone {
|
||||
c = u16[3] constant({1, 2, 3})
|
||||
ROOT map = f32[3] map(c), to_apply=map_computation
|
||||
}
|
||||
)";
|
||||
TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
|
||||
Literal expected = LiteralUtil::CreateR1<float>({2.f, 4.f, 6.f});
|
||||
TF_ASSERT_OK_AND_ASSIGN(
|
||||
Literal result, HloEvaluator().Evaluate(*m_->entry_computation(), {}));
|
||||
EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
|
||||
}
|
||||
|
||||
TEST_F(HloEvaluatorTest, DotUpcast) {
|
||||
const absl::string_view hlo_text = R"(
|
||||
HloModule test
|
||||
|
@ -1770,6 +1770,10 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
|
||||
TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl<uint8>(map));
|
||||
break;
|
||||
}
|
||||
case U16: {
|
||||
TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl<uint16>(map));
|
||||
break;
|
||||
}
|
||||
case U32: {
|
||||
TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl<uint32>(map));
|
||||
break;
|
||||
@ -1782,6 +1786,10 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
|
||||
TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl<int8>(map));
|
||||
break;
|
||||
}
|
||||
case S16: {
|
||||
TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl<int16>(map));
|
||||
break;
|
||||
}
|
||||
case S32: {
|
||||
TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl<int32>(map));
|
||||
break;
|
||||
|
Loading…
Reference in New Issue
Block a user