- Fix: Save frontend attributes in while loop
- Fix: save backend / frontend attributes in ReplaceInstruction
This commit is contained in:
parent
53e79b073e
commit
d333176f5a
@ -594,6 +594,7 @@ cc_library(
|
|||||||
":functionalize_cond",
|
":functionalize_cond",
|
||||||
":functionalize_control_flow_util",
|
":functionalize_control_flow_util",
|
||||||
":tf2xla_util",
|
":tf2xla_util",
|
||||||
|
":frontend_attributes_util",
|
||||||
"//tensorflow/compiler/jit:union_find",
|
"//tensorflow/compiler/jit:union_find",
|
||||||
"//tensorflow/compiler/tf2xla/ops:xla_ops",
|
"//tensorflow/compiler/tf2xla/ops:xla_ops",
|
||||||
"//tensorflow/compiler/xla:status_macros",
|
"//tensorflow/compiler/xla:status_macros",
|
||||||
|
@ -19,13 +19,11 @@ limitations under the License.
|
|||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
|
||||||
namespace {
|
const char kXlaFrontendAttributesAttrName[] = "_XlaFrontendAttributes";
|
||||||
const char kFrontendAttributesAttribute[] = "_XlaFrontendAttributes";
|
|
||||||
} // namespace
|
|
||||||
|
|
||||||
xla::StatusOr<absl::optional<xla::FrontendAttributes>>
|
xla::StatusOr<absl::optional<xla::FrontendAttributes>>
|
||||||
GetFrontendAttributesFromAttrSlice(const AttrSlice& attrs) {
|
GetFrontendAttributesFromAttrSlice(const AttrSlice& attrs) {
|
||||||
const AttrValue *attr = attrs.Find(kFrontendAttributesAttribute);
|
const AttrValue* attr = attrs.Find(kXlaFrontendAttributesAttrName);
|
||||||
if (attr == nullptr) {
|
if (attr == nullptr) {
|
||||||
return xla::StatusOr<absl::optional<xla::FrontendAttributes>>(
|
return xla::StatusOr<absl::optional<xla::FrontendAttributes>>(
|
||||||
absl::nullopt);
|
absl::nullopt);
|
||||||
|
@ -24,6 +24,8 @@ limitations under the License.
|
|||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
|
||||||
|
// Frontend Attributes Id.
|
||||||
|
extern const char kXlaFrontendAttributesAttrName[];
|
||||||
// Return the FrontendAttributes stored in the AttrSlice if there are some.
|
// Return the FrontendAttributes stored in the AttrSlice if there are some.
|
||||||
//
|
//
|
||||||
// Return an InvalidArgument error if some attributes are present but
|
// Return an InvalidArgument error if some attributes are present but
|
||||||
|
@ -24,6 +24,7 @@ limitations under the License.
|
|||||||
#include "absl/memory/memory.h"
|
#include "absl/memory/memory.h"
|
||||||
#include "absl/types/optional.h"
|
#include "absl/types/optional.h"
|
||||||
#include "tensorflow/compiler/jit/union_find.h"
|
#include "tensorflow/compiler/jit/union_find.h"
|
||||||
|
#include "tensorflow/compiler/tf2xla/frontend_attributes_util.h"
|
||||||
#include "tensorflow/compiler/tf2xla/functionalize_cond.h"
|
#include "tensorflow/compiler/tf2xla/functionalize_cond.h"
|
||||||
#include "tensorflow/compiler/tf2xla/functionalize_control_flow_util.h"
|
#include "tensorflow/compiler/tf2xla/functionalize_control_flow_util.h"
|
||||||
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
|
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
|
||||||
@ -530,6 +531,12 @@ Status FunctionalizeLoop(const FunctionLibraryDefinition* lookup_library,
|
|||||||
builder.Attr("cond", cond_name);
|
builder.Attr("cond", cond_name);
|
||||||
builder.Attr("body", body_name);
|
builder.Attr("body", body_name);
|
||||||
string outside_compilation;
|
string outside_compilation;
|
||||||
|
string frontend_attributes;
|
||||||
|
if (GetNodeAttr(frame->loop_cond->def(), kXlaFrontendAttributesAttrName,
|
||||||
|
&frontend_attributes)
|
||||||
|
.ok()) {
|
||||||
|
builder.Attr(kXlaFrontendAttributesAttrName, frontend_attributes);
|
||||||
|
}
|
||||||
if (GetNodeAttr(frame->loop_cond->def(), kXlaOutsideCompilationAttrName,
|
if (GetNodeAttr(frame->loop_cond->def(), kXlaOutsideCompilationAttrName,
|
||||||
&outside_compilation)
|
&outside_compilation)
|
||||||
.ok()) {
|
.ok()) {
|
||||||
|
@ -837,6 +837,12 @@ Status HloComputation::ReplaceInstruction(HloInstruction* old_instruction,
|
|||||||
if (new_instruction->metadata().op_name().empty()) {
|
if (new_instruction->metadata().op_name().empty()) {
|
||||||
new_instruction->set_metadata(old_instruction->metadata());
|
new_instruction->set_metadata(old_instruction->metadata());
|
||||||
}
|
}
|
||||||
|
new_instruction->set_raw_backend_config_string(
|
||||||
|
old_instruction->raw_backend_config_string());
|
||||||
|
if (new_instruction->frontend_attributes().map().empty()) {
|
||||||
|
new_instruction->set_frontend_attributes(
|
||||||
|
old_instruction->frontend_attributes());
|
||||||
|
}
|
||||||
|
|
||||||
// Like the metadata above, if the user didn't specify any sharding
|
// Like the metadata above, if the user didn't specify any sharding
|
||||||
// information on the new instruction we should copy the old sharding
|
// information on the new instruction we should copy the old sharding
|
||||||
|
Loading…
Reference in New Issue
Block a user