diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc index db651d3c323..b04635dda03 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc +++ b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc @@ -4442,5 +4442,27 @@ TEST_F(HloEvaluatorTest, CopyStartCopyDone) { EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } +TEST_F(HloEvaluatorTest, MapBF16) { + const absl::string_view hlo_text = R"( + HloModule test + + map_computation { + p = bf16[] parameter(0) + add = bf16[] add(p, p) + ROOT conv = f32[] convert(add) + } + + ENTRY CopyStartCopyDone { + c = bf16[3] constant({1, 2, 3}) + ROOT map = f32[3] map(c), to_apply=map_computation + } + )"; + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); + Literal expected = LiteralUtil::CreateR1({2.f, 4.f, 6.f}); + TF_ASSERT_OK_AND_ASSIGN( + Literal result, HloEvaluator().Evaluate(*m_->entry_computation(), {})); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h index 6fa3f9fb34b..e105ea8ce18 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h +++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h @@ -1680,6 +1680,10 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { MapImpl(map)); break; } + case BF16: { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl(map)); + break; + } case F32: { TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl(map)); break;