From 2d72bcdebeb9290b0afd475df758d1802547e3b3 Mon Sep 17 00:00:00 2001 From: HyoukJoong Lee <hyouklee@google.com> Date: Mon, 27 Apr 2020 13:36:33 -0700 Subject: [PATCH] Change fronend-attribute value format as a string and add missing parsing PiperOrigin-RevId: 308688157 Change-Id: Iba79b5e7ede0f5b1fa562b75ada15315b6b6cae1 --- tensorflow/compiler/xla/service/hlo_instruction.cc | 10 ++++++++-- tensorflow/compiler/xla/service/hlo_parser.cc | 5 ++++- tensorflow/compiler/xla/service/hlo_parser_test.cc | 3 ++- 3 files changed, 14 insertions(+), 4 deletions(-) diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index 22b74663087..d0501c1a26a 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -3366,8 +3366,14 @@ string FrontendAttributesToString( std::vector<std::pair<string, string>> sorted_attributes( frontend_attributes.map().begin(), frontend_attributes.map().end()); absl::c_sort(sorted_attributes); - return absl::StrFormat( - "{%s}", absl::StrJoin(sorted_attributes, ",", absl::PairFormatter("="))); + // Frontend attribute is a comma-separated list of attribute="value" pairs, + // e.g., frontend_attributes={name="value_a",type="int32"}. + const auto formatter = [](string* out, + const std::pair<string, string>& item) { + absl::StrAppend(out, item.first, "=\"", item.second, "\""); + }; + return absl::StrFormat("{%s}", + absl::StrJoin(sorted_attributes, ",", formatter)); } string PaddingConfigToString(const PaddingConfig& padding) { diff --git a/tensorflow/compiler/xla/service/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc index f41ed233ed3..76003cda002 100644 --- a/tensorflow/compiler/xla/service/hlo_parser.cc +++ b/tensorflow/compiler/xla/service/hlo_parser.cc @@ -1892,6 +1892,9 @@ bool HloParserImpl::ParseInstructionRhs(HloComputation::Builder* builder, if (outer_dimension_partitions) { instruction->set_outer_dimension_partitions(*outer_dimension_partitions); } + if (frontend_attributes) { + instruction->set_frontend_attributes(*frontend_attributes); + } return AddInstruction(name, instruction, name_loc); } // NOLINT(readability/fn_size) @@ -1946,7 +1949,7 @@ bool HloParserImpl::ParseFrontendAttributes( if (!ParseAttributeName(&attribute)) { return false; } - if (lexer_.GetKind() != TokKind::kIdent) { + if (lexer_.GetKind() != TokKind::kString) { return false; } (*frontend_attributes->mutable_map())[attribute] = lexer_.GetStrVal(); diff --git a/tensorflow/compiler/xla/service/hlo_parser_test.cc b/tensorflow/compiler/xla/service/hlo_parser_test.cc index 3d1f21ee8be..7e66b4e648d 100644 --- a/tensorflow/compiler/xla/service/hlo_parser_test.cc +++ b/tensorflow/compiler/xla/service/hlo_parser_test.cc @@ -2422,7 +2422,8 @@ TEST_F(HloParserTest, ParseSharding) { } TEST_F(HloParserTest, ParseFrontendAttributes) { - const string original = "{attr_a=test_a,attr_b=b}"; + const string original = + R"({attr_a="test_a",attr_b="b",attr_c="s64",attr_d="a/b"})"; TF_ASSERT_OK_AND_ASSIGN(FrontendAttributes frontend_attributes, ParseFrontendAttributes(original)); EXPECT_EQ(FrontendAttributesToString(frontend_attributes), original);