- Fix: Save frontend attributes in while loop

- Fix: save backend / frontend attributes in ReplaceInstruction
This commit is contained in:
Anthony Barbier 2019-08-19 11:40:47 +01:00
parent 53e79b073e
commit d333176f5a
5 changed files with 18 additions and 4 deletions

View File

@ -594,6 +594,7 @@ cc_library(
":functionalize_cond",
":functionalize_control_flow_util",
":tf2xla_util",
":frontend_attributes_util",
"//tensorflow/compiler/jit:union_find",
"//tensorflow/compiler/tf2xla/ops:xla_ops",
"//tensorflow/compiler/xla:status_macros",

View File

@ -19,13 +19,11 @@ limitations under the License.
namespace tensorflow {
namespace {
const char kFrontendAttributesAttribute[] = "_XlaFrontendAttributes";
} // namespace
const char kXlaFrontendAttributesAttrName[] = "_XlaFrontendAttributes";
xla::StatusOr<absl::optional<xla::FrontendAttributes>>
GetFrontendAttributesFromAttrSlice(const AttrSlice& attrs) {
const AttrValue *attr = attrs.Find(kFrontendAttributesAttribute);
const AttrValue* attr = attrs.Find(kXlaFrontendAttributesAttrName);
if (attr == nullptr) {
return xla::StatusOr<absl::optional<xla::FrontendAttributes>>(
absl::nullopt);

View File

@ -24,6 +24,8 @@ limitations under the License.
namespace tensorflow {
// Frontend Attributes Id.
extern const char kXlaFrontendAttributesAttrName[];
// Return the FrontendAttributes stored in the AttrSlice if there are some.
//
// Return an InvalidArgument error if some attributes are present but

View File

@ -24,6 +24,7 @@ limitations under the License.
#include "absl/memory/memory.h"
#include "absl/types/optional.h"
#include "tensorflow/compiler/jit/union_find.h"
#include "tensorflow/compiler/tf2xla/frontend_attributes_util.h"
#include "tensorflow/compiler/tf2xla/functionalize_cond.h"
#include "tensorflow/compiler/tf2xla/functionalize_control_flow_util.h"
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
@ -530,6 +531,12 @@ Status FunctionalizeLoop(const FunctionLibraryDefinition* lookup_library,
builder.Attr("cond", cond_name);
builder.Attr("body", body_name);
string outside_compilation;
string frontend_attributes;
if (GetNodeAttr(frame->loop_cond->def(), kXlaFrontendAttributesAttrName,
&frontend_attributes)
.ok()) {
builder.Attr(kXlaFrontendAttributesAttrName, frontend_attributes);
}
if (GetNodeAttr(frame->loop_cond->def(), kXlaOutsideCompilationAttrName,
&outside_compilation)
.ok()) {

View File

@ -837,6 +837,12 @@ Status HloComputation::ReplaceInstruction(HloInstruction* old_instruction,
if (new_instruction->metadata().op_name().empty()) {
new_instruction->set_metadata(old_instruction->metadata());
}
new_instruction->set_raw_backend_config_string(
old_instruction->raw_backend_config_string());
if (new_instruction->frontend_attributes().map().empty()) {
new_instruction->set_frontend_attributes(
old_instruction->frontend_attributes());
}
// Like the metadata above, if the user didn't specify any sharding
// information on the new instruction we should copy the old sharding