Re-implement GenerateResourceSharedNameIfEmpty to operate on GraphDef. The previous implementation modifies internal data structure of the graph and functions libarary directly, and after modifying the functions it invalidates the pointers referenced by nodes in the graph and causes a dangling pointer bug.
PiperOrigin-RevId: 351850970 Change-Id: If56b56ed602f0cb97af0c33489be9db42988a133
This commit is contained in:
parent
cfa9e61e4d
commit
e634adc15d
@ -0,0 +1,67 @@
|
||||
# RUN: tf-mlir-translate -graphdef-to-mlir -tf-upgrade-legacy %s
|
||||
|
||||
# This is a stripped down GraphDef of the model from b/175240312. To hit the
|
||||
# bug, the GraphDef needs to have functions in the library and also a Merge node
|
||||
# to go into certain part of the functionalization code where it crashes.
|
||||
|
||||
node {
|
||||
name: "input"
|
||||
op: "Placeholder"
|
||||
attr {
|
||||
key: "dtype"
|
||||
value {
|
||||
type: DT_BOOL
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "Switch0"
|
||||
op: "Switch"
|
||||
input: "input"
|
||||
input: "input"
|
||||
attr {
|
||||
key: "T"
|
||||
value {
|
||||
type: DT_BOOL
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "func0"
|
||||
op: "func_name"
|
||||
input: "Switch0:1"
|
||||
}
|
||||
node {
|
||||
name: "Merge"
|
||||
op: "Merge"
|
||||
input: "Switch0:1"
|
||||
input: "Switch0"
|
||||
input: "^func0"
|
||||
attr {
|
||||
key: "N"
|
||||
value {
|
||||
i: 2
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "T"
|
||||
value {
|
||||
type: DT_BOOL
|
||||
}
|
||||
}
|
||||
}
|
||||
library {
|
||||
function {
|
||||
signature {
|
||||
name: "func_name"
|
||||
input_arg {
|
||||
name: "arg0"
|
||||
type: DT_BOOL
|
||||
}
|
||||
}
|
||||
ret {
|
||||
key: "retval0"
|
||||
value: "arg0"
|
||||
}
|
||||
}
|
||||
}
|
@ -182,8 +182,6 @@ class NameUniquifier : public OpOrArgNameMapper {
|
||||
|
||||
Status UpgradeLegacyGraph(Graph* graph, FunctionLibraryDefinition* flib_def,
|
||||
bool restrict_functionalization_to_tpu_nodes) {
|
||||
TF_RETURN_IF_ERROR(GenerateResourceSharedNameIfEmpty(*graph, *flib_def));
|
||||
|
||||
// If `restrict_functionalization_to_tpu_nodes` is true let filter function
|
||||
// return true for `_tpu_replicate` nodes, otherwise don't set filter.
|
||||
NodeFilter node_filter =
|
||||
@ -3386,11 +3384,17 @@ Status SavedModelSignatureDefImporterLite::InitializeGraph(
|
||||
TF_RETURN_IF_ERROR(
|
||||
RunGrappler(const_cast<MetaGraphDef*>(&meta_graph_def_)));
|
||||
|
||||
GraphDef graph_def = meta_graph_def_.graph_def();
|
||||
if (import_options.upgrade_legacy) {
|
||||
TF_RETURN_IF_ERROR(GenerateResourceSharedNameIfEmpty(
|
||||
graph_def, graph_->flib_def().default_registry()));
|
||||
}
|
||||
|
||||
GraphConstructorOptions graph_ctor_options;
|
||||
graph_ctor_options.allow_internal_ops = true;
|
||||
graph_ctor_options.add_default_attributes = true;
|
||||
TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(
|
||||
graph_ctor_options, meta_graph_def_.graph_def(), graph_.get()));
|
||||
TF_RETURN_IF_ERROR(
|
||||
ConvertGraphDefToGraph(graph_ctor_options, graph_def, graph_.get()));
|
||||
|
||||
// TODO(jpienaar): Remove need to const_cast.
|
||||
if (import_options.upgrade_legacy) {
|
||||
@ -3732,6 +3736,10 @@ StatusOr<mlir::OwningModuleRef> ConvertGraphdefToMlir(
|
||||
if (add_default_attributes) {
|
||||
TF_RETURN_IF_ERROR(PreprocessGraphDef(&specs, &preprocessed_graphdef));
|
||||
}
|
||||
if (specs.upgrade_legacy) {
|
||||
TF_RETURN_IF_ERROR(GenerateResourceSharedNameIfEmpty(
|
||||
preprocessed_graphdef, graph.flib_def().default_registry()));
|
||||
}
|
||||
TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(
|
||||
options, std::move(preprocessed_graphdef), &graph));
|
||||
return ConvertGraphToMlir(graph, debug_info, graph.flib_def(), specs,
|
||||
|
@ -34,8 +34,8 @@ const llvm::StringSet<>& GetSharedNameGenerationCompatibleOps() {
|
||||
return *ops;
|
||||
}
|
||||
|
||||
Status GenerateResourceSharedNameIfEmpty(Graph& graph,
|
||||
FunctionLibraryDefinition& flib_def) {
|
||||
Status GenerateResourceSharedNameIfEmpty(
|
||||
GraphDef& gdef, const OpRegistryInterface* default_registry) {
|
||||
auto is_resource_op_with_empty_shared_name = [](const NodeDef& node_def,
|
||||
const OpDef& op_def) {
|
||||
if (!GetSharedNameGenerationCompatibleOps().contains(op_def.name())) {
|
||||
@ -64,31 +64,37 @@ Status GenerateResourceSharedNameIfEmpty(Graph& graph,
|
||||
return iter->second.s().empty();
|
||||
};
|
||||
|
||||
// Upgrade nodes in the graph.
|
||||
for (auto* node : graph.nodes()) {
|
||||
if (is_resource_op_with_empty_shared_name(node->def(), node->op_def())) {
|
||||
node->AddAttr("shared_name", node->name());
|
||||
FunctionDefLibrary* library = gdef.mutable_library();
|
||||
auto flib_def = library ? std::make_unique<FunctionLibraryDefinition>(
|
||||
default_registry, *library)
|
||||
: std::make_unique<FunctionLibraryDefinition>(
|
||||
default_registry, FunctionDefLibrary());
|
||||
|
||||
if (library) {
|
||||
// Upgrade nodes in the functions.
|
||||
for (FunctionDef& fdef : *library->mutable_function()) {
|
||||
auto func_name = fdef.signature().name();
|
||||
for (auto& node_def : *fdef.mutable_node_def()) {
|
||||
const OpDef* op_def = nullptr;
|
||||
TF_RETURN_IF_ERROR(flib_def->LookUpOpDef(node_def.op(), &op_def));
|
||||
if (is_resource_op_with_empty_shared_name(node_def, *op_def)) {
|
||||
// Use the concat of function name and node name for such ops in a
|
||||
// function as the shared_name. "@" is used as the separator because
|
||||
// it is not allowed in the function name or the node name.
|
||||
(*node_def.mutable_attr())["shared_name"].set_s(
|
||||
absl::StrCat(node_def.name(), "@", func_name));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Upgrade nodes in the functions.
|
||||
auto func_names = flib_def.ListFunctionNames();
|
||||
for (const auto& func_name : func_names) {
|
||||
const FunctionDef* orig = flib_def.Find(func_name);
|
||||
DCHECK(orig);
|
||||
auto copy = *orig;
|
||||
for (auto& node_def : *copy.mutable_node_def()) {
|
||||
const OpDef* op_def = nullptr;
|
||||
TF_RETURN_IF_ERROR(flib_def.LookUpOpDef(node_def.op(), &op_def));
|
||||
if (is_resource_op_with_empty_shared_name(node_def, *op_def)) {
|
||||
// Use the concat of function name and node name for such ops in a
|
||||
// function as the shared_name. "@" is used as the separator because it
|
||||
// is not allowed in the function name or the node name.
|
||||
(*node_def.mutable_attr())["shared_name"].set_s(
|
||||
absl::StrCat(node_def.name(), "@", func_name));
|
||||
}
|
||||
// Upgrade nodes in the GraphDef.
|
||||
for (auto& node_def : *gdef.mutable_node()) {
|
||||
const OpDef* op_def = nullptr;
|
||||
TF_RETURN_IF_ERROR(flib_def->LookUpOpDef(node_def.op(), &op_def));
|
||||
if (is_resource_op_with_empty_shared_name(node_def, *op_def)) {
|
||||
(*node_def.mutable_attr())["shared_name"].set_s(node_def.name());
|
||||
}
|
||||
TF_RETURN_IF_ERROR(flib_def.ReplaceFunction(func_name, copy));
|
||||
}
|
||||
|
||||
return tensorflow::Status::OK();
|
||||
|
@ -27,8 +27,8 @@ class MetaGraphDef;
|
||||
// Generate the shared_name for resource handle ops in the graph and functions
|
||||
// if their shared_names are empty. Resource handle ops with empty shared_name
|
||||
// may have undesired semantics.
|
||||
Status GenerateResourceSharedNameIfEmpty(Graph& graph,
|
||||
FunctionLibraryDefinition& flib_def);
|
||||
Status GenerateResourceSharedNameIfEmpty(
|
||||
GraphDef& gdef, const OpRegistryInterface* default_registry);
|
||||
|
||||
// Run grapler passes over `meta_graph_def`.graph_def(), and optimize it in
|
||||
// place.
|
||||
|
Loading…
x
Reference in New Issue
Block a user