From d333176f5a615f23f225be7bc906a2c5d9f56b51 Mon Sep 17 00:00:00 2001 From: Anthony Barbier Date: Mon, 19 Aug 2019 11:40:47 +0100 Subject: [PATCH] - Fix: Save frontend attributes in while loop - Fix: save backend / frontend attributes in ReplaceInstruction --- tensorflow/compiler/tf2xla/BUILD | 1 + tensorflow/compiler/tf2xla/frontend_attributes_util.cc | 6 ++---- tensorflow/compiler/tf2xla/frontend_attributes_util.h | 2 ++ tensorflow/compiler/tf2xla/functionalize_while.cc | 7 +++++++ tensorflow/compiler/xla/service/hlo_computation.cc | 6 ++++++ 5 files changed, 18 insertions(+), 4 deletions(-) diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD index 1e4f2e23ef3..329c706c763 100644 --- a/tensorflow/compiler/tf2xla/BUILD +++ b/tensorflow/compiler/tf2xla/BUILD @@ -594,6 +594,7 @@ cc_library( ":functionalize_cond", ":functionalize_control_flow_util", ":tf2xla_util", + ":frontend_attributes_util", "//tensorflow/compiler/jit:union_find", "//tensorflow/compiler/tf2xla/ops:xla_ops", "//tensorflow/compiler/xla:status_macros", diff --git a/tensorflow/compiler/tf2xla/frontend_attributes_util.cc b/tensorflow/compiler/tf2xla/frontend_attributes_util.cc index 7c2564ffa99..b088001f287 100644 --- a/tensorflow/compiler/tf2xla/frontend_attributes_util.cc +++ b/tensorflow/compiler/tf2xla/frontend_attributes_util.cc @@ -19,13 +19,11 @@ limitations under the License. namespace tensorflow { -namespace { -const char kFrontendAttributesAttribute[] = "_XlaFrontendAttributes"; -} // namespace +const char kXlaFrontendAttributesAttrName[] = "_XlaFrontendAttributes"; xla::StatusOr> GetFrontendAttributesFromAttrSlice(const AttrSlice& attrs) { - const AttrValue *attr = attrs.Find(kFrontendAttributesAttribute); + const AttrValue* attr = attrs.Find(kXlaFrontendAttributesAttrName); if (attr == nullptr) { return xla::StatusOr>( absl::nullopt); diff --git a/tensorflow/compiler/tf2xla/frontend_attributes_util.h b/tensorflow/compiler/tf2xla/frontend_attributes_util.h index 1c2b1d8c1c5..421f21e71d1 100644 --- a/tensorflow/compiler/tf2xla/frontend_attributes_util.h +++ b/tensorflow/compiler/tf2xla/frontend_attributes_util.h @@ -24,6 +24,8 @@ limitations under the License. namespace tensorflow { +// Frontend Attributes Id. +extern const char kXlaFrontendAttributesAttrName[]; // Return the FrontendAttributes stored in the AttrSlice if there are some. // // Return an InvalidArgument error if some attributes are present but diff --git a/tensorflow/compiler/tf2xla/functionalize_while.cc b/tensorflow/compiler/tf2xla/functionalize_while.cc index e4a21f90598..d3d2f2ff79a 100644 --- a/tensorflow/compiler/tf2xla/functionalize_while.cc +++ b/tensorflow/compiler/tf2xla/functionalize_while.cc @@ -24,6 +24,7 @@ limitations under the License. #include "absl/memory/memory.h" #include "absl/types/optional.h" #include "tensorflow/compiler/jit/union_find.h" +#include "tensorflow/compiler/tf2xla/frontend_attributes_util.h" #include "tensorflow/compiler/tf2xla/functionalize_cond.h" #include "tensorflow/compiler/tf2xla/functionalize_control_flow_util.h" #include "tensorflow/compiler/tf2xla/tf2xla_util.h" @@ -530,6 +531,12 @@ Status FunctionalizeLoop(const FunctionLibraryDefinition* lookup_library, builder.Attr("cond", cond_name); builder.Attr("body", body_name); string outside_compilation; + string frontend_attributes; + if (GetNodeAttr(frame->loop_cond->def(), kXlaFrontendAttributesAttrName, + &frontend_attributes) + .ok()) { + builder.Attr(kXlaFrontendAttributesAttrName, frontend_attributes); + } if (GetNodeAttr(frame->loop_cond->def(), kXlaOutsideCompilationAttrName, &outside_compilation) .ok()) { diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc index 6fe91e492ed..fce60bc430e 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.cc +++ b/tensorflow/compiler/xla/service/hlo_computation.cc @@ -837,6 +837,12 @@ Status HloComputation::ReplaceInstruction(HloInstruction* old_instruction, if (new_instruction->metadata().op_name().empty()) { new_instruction->set_metadata(old_instruction->metadata()); } + new_instruction->set_raw_backend_config_string( + old_instruction->raw_backend_config_string()); + if (new_instruction->frontend_attributes().map().empty()) { + new_instruction->set_frontend_attributes( + old_instruction->frontend_attributes()); + } // Like the metadata above, if the user didn't specify any sharding // information on the new instruction we should copy the old sharding