From c05a9edacd9ca0ad9db55b17c1d98b84aed03df4 Mon Sep 17 00:00:00 2001 From: Marcello Maggioni Date: Tue, 6 Oct 2020 12:56:52 -0700 Subject: [PATCH] [XLA] Add support for parsing negative nans as constants in HLO parser. PiperOrigin-RevId: 335701644 Change-Id: Icedf22f31a99c8522e23e8ae95536a0cf1b7d767 --- tensorflow/compiler/xla/service/hlo_lexer.cc | 8 ++++++++ tensorflow/compiler/xla/service/hlo_lexer.h | 1 + tensorflow/compiler/xla/service/hlo_parser.cc | 4 ++++ tensorflow/compiler/xla/service/hlo_parser_test.cc | 13 +++++++++++++ 4 files changed, 26 insertions(+) diff --git a/tensorflow/compiler/xla/service/hlo_lexer.cc b/tensorflow/compiler/xla/service/hlo_lexer.cc index 749193a83ef..3c44b390969 100644 --- a/tensorflow/compiler/xla/service/hlo_lexer.cc +++ b/tensorflow/compiler/xla/service/hlo_lexer.cc @@ -387,6 +387,12 @@ TokKind HloLexer::LexNumberOrPattern() { return TokKind::kNegInf; } + static LazyRE2 neg_nan = {"-nan"}; + if (RE2::Consume(&consumable, *neg_nan)) { + current_ptr_ = consumable.begin(); + return TokKind::kNegNan; + } + return TokKind::kError; } @@ -502,6 +508,8 @@ string TokKindToString(TokKind kind) { return "kw_nan"; case TokKind::kw_inf: return "kw_inf"; + case TokKind::kNegNan: + return "kNegNan"; case TokKind::kNegInf: return "kNegInf"; case TokKind::kPrimitiveType: diff --git a/tensorflow/compiler/xla/service/hlo_lexer.h b/tensorflow/compiler/xla/service/hlo_lexer.h index b8c7debaab4..4068ad76581 100644 --- a/tensorflow/compiler/xla/service/hlo_lexer.h +++ b/tensorflow/compiler/xla/service/hlo_lexer.h @@ -65,6 +65,7 @@ enum class TokKind { kw_nan, kw_inf, + kNegNan, // -nan kNegInf, // -inf // Typed tokens. diff --git a/tensorflow/compiler/xla/service/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc index 37bdeaa1073..d04a7695f3c 100644 --- a/tensorflow/compiler/xla/service/hlo_parser.cc +++ b/tensorflow/compiler/xla/service/hlo_parser.cc @@ -2717,6 +2717,7 @@ bool HloParserImpl::ParseDenseLiteral(Literal* literal, const Shape& shape) { case TokKind::kInt: case TokKind::kDecimal: case TokKind::kw_nan: + case TokKind::kNegNan: case TokKind::kw_inf: case TokKind::kNegInf: { add_one_elem_seen(); @@ -4374,6 +4375,9 @@ bool HloParserImpl::ParseDouble(double* result) { case TokKind::kw_nan: *result = std::numeric_limits::quiet_NaN(); break; + case TokKind::kNegNan: + *result = -std::numeric_limits::quiet_NaN(); + break; case TokKind::kw_inf: *result = std::numeric_limits::infinity(); break; diff --git a/tensorflow/compiler/xla/service/hlo_parser_test.cc b/tensorflow/compiler/xla/service/hlo_parser_test.cc index d220d735622..3cb9a1c564b 100644 --- a/tensorflow/compiler/xla/service/hlo_parser_test.cc +++ b/tensorflow/compiler/xla/service/hlo_parser_test.cc @@ -2120,6 +2120,19 @@ ENTRY %ShortConstant.v4 () -> f32[67,89] { EXPECT_EQ(result.ValueOrDie()->ToString(HloPrintOptions()), original); } +TEST_F(HloParserTest, NegativeNan) { + const string original = R"(HloModule NegativeNan_module + +ENTRY %NegativeNan () -> bf16[2] { + ROOT %constant = bf16[2]{0} constant({-nan, -nan}) +} + +)"; + auto result = ParseAndReturnUnverifiedModule(original); + EXPECT_EQ(Status::OK(), result.status()); + EXPECT_EQ(result.ValueOrDie()->ToString(HloPrintOptions()), original); +} + TEST_F(HloParserTest, AttributesAnyOrder) { const string original = R"(HloModule any_order_module