Generate shared_name for resource handle ops if their shared_name is empty.
Resource handle ops with empty shared_names have undesired semantics. Two instances of resource ops with the same attributes and empty shared_name actually returns two different resources in current TF. PiperOrigin-RevId: 333758699 Change-Id: I56e6cc7138432875d25f45932d7b470495eaa680
This commit is contained in:
parent
9d268e2e49
commit
f9e9618859
@ -1025,6 +1025,16 @@ cc_library(
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "upgrade_graph",
|
||||
srcs = ["translate/upgrade_graph.cc"],
|
||||
hdrs = ["translate/upgrade_graph.h"],
|
||||
deps = [
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:graph",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "convert_graphdef",
|
||||
srcs = [
|
||||
@ -1050,6 +1060,7 @@ cc_library(
|
||||
":tensorflow_types",
|
||||
":tf_saved_model_passes",
|
||||
":translate_utils",
|
||||
":upgrade_graph",
|
||||
"//tensorflow/cc/saved_model:bundle_v2",
|
||||
"//tensorflow/cc/saved_model:constants",
|
||||
"//tensorflow/cc/saved_model:loader_lite",
|
||||
|
@ -0,0 +1,95 @@
|
||||
# RUN: tf-mlir-translate -graphdef-to-mlir -tf-upgrade-legacy %s -tf-output-arrays=hash_table_node -o - | FileCheck %s
|
||||
|
||||
node: {
|
||||
name: "hash_table_node"
|
||||
op: "HashTableV2"
|
||||
attr: {
|
||||
key: "key_dtype"
|
||||
value: {
|
||||
type: DT_INT32
|
||||
}
|
||||
}
|
||||
attr: {
|
||||
key: "shared_name"
|
||||
value: {
|
||||
s: ""
|
||||
}
|
||||
}
|
||||
attr: {
|
||||
key: "value_dtype"
|
||||
value: {
|
||||
type: DT_FLOAT
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "Call"
|
||||
op: "PartitionedCall"
|
||||
attr {
|
||||
key: "Tin"
|
||||
value {
|
||||
list {
|
||||
}
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "Tout"
|
||||
value {
|
||||
list {
|
||||
type: DT_RESOURCE
|
||||
}
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "f"
|
||||
value {
|
||||
func {
|
||||
name: "create_resource"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
library {
|
||||
function {
|
||||
signature {
|
||||
name: "create_resource"
|
||||
output_arg {
|
||||
name: "handle"
|
||||
type: DT_RESOURCE
|
||||
}
|
||||
}
|
||||
node_def: {
|
||||
name: "hash_table_node"
|
||||
op: "HashTableV2"
|
||||
attr: {
|
||||
key: "key_dtype"
|
||||
value: {
|
||||
type: DT_INT32
|
||||
}
|
||||
}
|
||||
attr: {
|
||||
key: "shared_name"
|
||||
value: {
|
||||
s: ""
|
||||
}
|
||||
}
|
||||
attr: {
|
||||
key: "value_dtype"
|
||||
value: {
|
||||
type: DT_FLOAT
|
||||
}
|
||||
}
|
||||
}
|
||||
ret {
|
||||
key: "handle"
|
||||
value: "hash_table_node:table_handle:0"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
# CHECK: tf.HashTableV2
|
||||
# CHECK-SAME: shared_name = "hash_table_node"
|
||||
|
||||
# CHECK: func @create_resource
|
||||
# CHECK: tf.HashTableV2
|
||||
# CHECK-SAME: shared_name = "create_resource_hash_table_node"
|
@ -73,6 +73,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/transforms/tf_saved_model_passes.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/translate/upgrade_graph.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h"
|
||||
@ -180,6 +181,8 @@ class NameUniquifier : public OpOrArgNameMapper {
|
||||
|
||||
Status UpgradeLegacyGraph(Graph* graph, FunctionLibraryDefinition* flib_def,
|
||||
bool restrict_functionalization_to_tpu_nodes) {
|
||||
TF_RETURN_IF_ERROR(GenerateResourceSharedNameIfEmpty(*graph, *flib_def));
|
||||
|
||||
// If `restrict_functionalization_to_tpu_nodes` is true let filter function
|
||||
// return true for `_tpu_replicate` nodes, otherwise don't set filter.
|
||||
NodeFilter node_filter =
|
||||
|
@ -0,0 +1,79 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/mlir/tensorflow/translate/upgrade_graph.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
Status GenerateResourceSharedNameIfEmpty(Graph& graph,
|
||||
FunctionLibraryDefinition& flib_def) {
|
||||
auto is_resource_op_with_empty_shared_name = [](const NodeDef& node_def,
|
||||
const OpDef& op_def) {
|
||||
// Only upgrade when it is a resource handle op.
|
||||
if (op_def.output_arg().size() != 1 ||
|
||||
op_def.output_arg(0).type() != tensorflow::DT_RESOURCE)
|
||||
return false;
|
||||
|
||||
// If the OpDef has "use_node_name_sharing" field, then it is valid to use
|
||||
// node names as shared names.
|
||||
if (!std::any_of(op_def.attr().begin(), op_def.attr().end(),
|
||||
[](const auto& attr_def) {
|
||||
return attr_def.name() == "use_node_name_sharing" &&
|
||||
attr_def.type() == "bool";
|
||||
}))
|
||||
return false;
|
||||
|
||||
if (!std::any_of(op_def.attr().begin(), op_def.attr().end(),
|
||||
[](const auto& attr_def) {
|
||||
return attr_def.name() == "shared_name" &&
|
||||
attr_def.type() == "string";
|
||||
}))
|
||||
return false;
|
||||
|
||||
auto iter = node_def.attr().find("shared_name");
|
||||
if (iter == node_def.attr().end()) return true;
|
||||
return iter->second.s().empty();
|
||||
};
|
||||
|
||||
// Upgrade nodes in the graph.
|
||||
for (auto* node : graph.nodes()) {
|
||||
if (is_resource_op_with_empty_shared_name(node->def(), node->op_def())) {
|
||||
node->AddAttr("shared_name", node->name());
|
||||
}
|
||||
}
|
||||
|
||||
// Upgrade nodes in the functions.
|
||||
auto func_names = flib_def.ListFunctionNames();
|
||||
for (const auto& func_name : func_names) {
|
||||
const FunctionDef* orig = flib_def.Find(func_name);
|
||||
DCHECK(orig);
|
||||
auto copy = *orig;
|
||||
for (auto& node_def : *copy.mutable_node_def()) {
|
||||
const OpDef* op_def = nullptr;
|
||||
TF_RETURN_IF_ERROR(flib_def.LookUpOpDef(node_def.op(), &op_def));
|
||||
if (is_resource_op_with_empty_shared_name(node_def, *op_def)) {
|
||||
// Use the concat of function name and node name for such ops in a
|
||||
// function as the shared_name.
|
||||
(*node_def.mutable_attr())["shared_name"].set_s(
|
||||
absl::StrCat(func_name, "_", node_def.name()));
|
||||
}
|
||||
}
|
||||
TF_RETURN_IF_ERROR(flib_def.ReplaceFunction(func_name, copy));
|
||||
}
|
||||
|
||||
return tensorflow::Status::OK();
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
@ -0,0 +1,32 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSLATE_UPGRADE_GRAPH_H_
|
||||
#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSLATE_UPGRADE_GRAPH_H_
|
||||
|
||||
#include "tensorflow/core/framework/function.h"
|
||||
#include "tensorflow/core/graph/graph.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// Generate the shared_name for resource handle ops in the graph and functions
|
||||
// if their shared_names are empty. Resource handle ops with empty shared_name
|
||||
// may have undesired semantics.
|
||||
Status GenerateResourceSharedNameIfEmpty(Graph& graph,
|
||||
FunctionLibraryDefinition& flib_def);
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSLATE_UPGRADE_GRAPH_H_
|
Loading…
x
Reference in New Issue
Block a user