[XLA] Add S16/U16 support for Map in the HLO evaluator

PiperOrigin-RevId: 343762801
Change-Id: I193d59f3823af604cc08c715156f97af36263aeb
This commit is contained in:
David Majnemer 2020-11-22 16:21:29 -08:00 committed by TensorFlower Gardener
parent 9bb01ec208
commit 9367ca6fa3
2 changed files with 52 additions and 0 deletions

View File

@ -4567,6 +4567,50 @@ TEST_F(HloEvaluatorTest, MapBF16) {
EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
TEST_F(HloEvaluatorTest, MapS16) {
const absl::string_view hlo_text = R"(
HloModule test
map_computation {
p = s16[] parameter(0)
add = s16[] add(p, p)
ROOT conv = f32[] convert(add)
}
ENTRY CopyStartCopyDone {
c = s16[3] constant({1, 2, 3})
ROOT map = f32[3] map(c), to_apply=map_computation
}
)";
TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
Literal expected = LiteralUtil::CreateR1<float>({2.f, 4.f, 6.f});
TF_ASSERT_OK_AND_ASSIGN(
Literal result, HloEvaluator().Evaluate(*m_->entry_computation(), {}));
EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
TEST_F(HloEvaluatorTest, MapU16) {
const absl::string_view hlo_text = R"(
HloModule test
map_computation {
p = u16[] parameter(0)
add = u16[] add(p, p)
ROOT conv = f32[] convert(add)
}
ENTRY CopyStartCopyDone {
c = u16[3] constant({1, 2, 3})
ROOT map = f32[3] map(c), to_apply=map_computation
}
)";
TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
Literal expected = LiteralUtil::CreateR1<float>({2.f, 4.f, 6.f});
TF_ASSERT_OK_AND_ASSIGN(
Literal result, HloEvaluator().Evaluate(*m_->entry_computation(), {}));
EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
TEST_F(HloEvaluatorTest, DotUpcast) {
const absl::string_view hlo_text = R"(
HloModule test

View File

@ -1770,6 +1770,10 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl<uint8>(map));
break;
}
case U16: {
TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl<uint16>(map));
break;
}
case U32: {
TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl<uint32>(map));
break;
@ -1782,6 +1786,10 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl<int8>(map));
break;
}
case S16: {
TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl<int16>(map));
break;
}
case S32: {
TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl<int32>(map));
break;