- Fix: Save frontend attributes in while loop
- Fix: save backend / frontend attributes in ReplaceInstruction
This commit is contained in:
parent
53e79b073e
commit
d333176f5a
@ -594,6 +594,7 @@ cc_library(
|
||||
":functionalize_cond",
|
||||
":functionalize_control_flow_util",
|
||||
":tf2xla_util",
|
||||
":frontend_attributes_util",
|
||||
"//tensorflow/compiler/jit:union_find",
|
||||
"//tensorflow/compiler/tf2xla/ops:xla_ops",
|
||||
"//tensorflow/compiler/xla:status_macros",
|
||||
|
@ -19,13 +19,11 @@ limitations under the License.
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
namespace {
|
||||
const char kFrontendAttributesAttribute[] = "_XlaFrontendAttributes";
|
||||
} // namespace
|
||||
const char kXlaFrontendAttributesAttrName[] = "_XlaFrontendAttributes";
|
||||
|
||||
xla::StatusOr<absl::optional<xla::FrontendAttributes>>
|
||||
GetFrontendAttributesFromAttrSlice(const AttrSlice& attrs) {
|
||||
const AttrValue *attr = attrs.Find(kFrontendAttributesAttribute);
|
||||
const AttrValue* attr = attrs.Find(kXlaFrontendAttributesAttrName);
|
||||
if (attr == nullptr) {
|
||||
return xla::StatusOr<absl::optional<xla::FrontendAttributes>>(
|
||||
absl::nullopt);
|
||||
|
@ -24,6 +24,8 @@ limitations under the License.
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// Frontend Attributes Id.
|
||||
extern const char kXlaFrontendAttributesAttrName[];
|
||||
// Return the FrontendAttributes stored in the AttrSlice if there are some.
|
||||
//
|
||||
// Return an InvalidArgument error if some attributes are present but
|
||||
|
@ -24,6 +24,7 @@ limitations under the License.
|
||||
#include "absl/memory/memory.h"
|
||||
#include "absl/types/optional.h"
|
||||
#include "tensorflow/compiler/jit/union_find.h"
|
||||
#include "tensorflow/compiler/tf2xla/frontend_attributes_util.h"
|
||||
#include "tensorflow/compiler/tf2xla/functionalize_cond.h"
|
||||
#include "tensorflow/compiler/tf2xla/functionalize_control_flow_util.h"
|
||||
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
|
||||
@ -530,6 +531,12 @@ Status FunctionalizeLoop(const FunctionLibraryDefinition* lookup_library,
|
||||
builder.Attr("cond", cond_name);
|
||||
builder.Attr("body", body_name);
|
||||
string outside_compilation;
|
||||
string frontend_attributes;
|
||||
if (GetNodeAttr(frame->loop_cond->def(), kXlaFrontendAttributesAttrName,
|
||||
&frontend_attributes)
|
||||
.ok()) {
|
||||
builder.Attr(kXlaFrontendAttributesAttrName, frontend_attributes);
|
||||
}
|
||||
if (GetNodeAttr(frame->loop_cond->def(), kXlaOutsideCompilationAttrName,
|
||||
&outside_compilation)
|
||||
.ok()) {
|
||||
|
@ -837,6 +837,12 @@ Status HloComputation::ReplaceInstruction(HloInstruction* old_instruction,
|
||||
if (new_instruction->metadata().op_name().empty()) {
|
||||
new_instruction->set_metadata(old_instruction->metadata());
|
||||
}
|
||||
new_instruction->set_raw_backend_config_string(
|
||||
old_instruction->raw_backend_config_string());
|
||||
if (new_instruction->frontend_attributes().map().empty()) {
|
||||
new_instruction->set_frontend_attributes(
|
||||
old_instruction->frontend_attributes());
|
||||
}
|
||||
|
||||
// Like the metadata above, if the user didn't specify any sharding
|
||||
// information on the new instruction we should copy the old sharding
|
||||
|
Loading…
Reference in New Issue
Block a user