Generate shared_name for resource handle ops if their shared_name is empty.

Resource handle ops with empty shared_names have undesired semantics. Two
instances of resource ops with the same attributes and empty shared_name actually
returns two different resources in current TF.

PiperOrigin-RevId: 333758699
Change-Id: I56e6cc7138432875d25f45932d7b470495eaa680
This commit is contained in:
Kuangyuan Chen 2020-09-25 10:24:15 -07:00 committed by TensorFlower Gardener
parent 9d268e2e49
commit f9e9618859
5 changed files with 220 additions and 0 deletions

View File

@ -1025,6 +1025,16 @@ cc_library(
alwayslink = 1,
)
cc_library(
name = "upgrade_graph",
srcs = ["translate/upgrade_graph.cc"],
hdrs = ["translate/upgrade_graph.h"],
deps = [
"//tensorflow/core:framework",
"//tensorflow/core:graph",
],
)
cc_library(
name = "convert_graphdef",
srcs = [
@ -1050,6 +1060,7 @@ cc_library(
":tensorflow_types",
":tf_saved_model_passes",
":translate_utils",
":upgrade_graph",
"//tensorflow/cc/saved_model:bundle_v2",
"//tensorflow/cc/saved_model:constants",
"//tensorflow/cc/saved_model:loader_lite",

View File

@ -0,0 +1,95 @@
# RUN: tf-mlir-translate -graphdef-to-mlir -tf-upgrade-legacy %s -tf-output-arrays=hash_table_node -o - | FileCheck %s
node: {
name: "hash_table_node"
op: "HashTableV2"
attr: {
key: "key_dtype"
value: {
type: DT_INT32
}
}
attr: {
key: "shared_name"
value: {
s: ""
}
}
attr: {
key: "value_dtype"
value: {
type: DT_FLOAT
}
}
}
node {
name: "Call"
op: "PartitionedCall"
attr {
key: "Tin"
value {
list {
}
}
}
attr {
key: "Tout"
value {
list {
type: DT_RESOURCE
}
}
}
attr {
key: "f"
value {
func {
name: "create_resource"
}
}
}
}
library {
function {
signature {
name: "create_resource"
output_arg {
name: "handle"
type: DT_RESOURCE
}
}
node_def: {
name: "hash_table_node"
op: "HashTableV2"
attr: {
key: "key_dtype"
value: {
type: DT_INT32
}
}
attr: {
key: "shared_name"
value: {
s: ""
}
}
attr: {
key: "value_dtype"
value: {
type: DT_FLOAT
}
}
}
ret {
key: "handle"
value: "hash_table_node:table_handle:0"
}
}
}
# CHECK: tf.HashTableV2
# CHECK-SAME: shared_name = "hash_table_node"
# CHECK: func @create_resource
# CHECK: tf.HashTableV2
# CHECK-SAME: shared_name = "create_resource_hash_table_node"

View File

@ -73,6 +73,7 @@ limitations under the License.
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/tf_saved_model_passes.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/upgrade_graph.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h"
@ -180,6 +181,8 @@ class NameUniquifier : public OpOrArgNameMapper {
Status UpgradeLegacyGraph(Graph* graph, FunctionLibraryDefinition* flib_def,
bool restrict_functionalization_to_tpu_nodes) {
TF_RETURN_IF_ERROR(GenerateResourceSharedNameIfEmpty(*graph, *flib_def));
// If `restrict_functionalization_to_tpu_nodes` is true let filter function
// return true for `_tpu_replicate` nodes, otherwise don't set filter.
NodeFilter node_filter =

View File

@ -0,0 +1,79 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/mlir/tensorflow/translate/upgrade_graph.h"
namespace tensorflow {
Status GenerateResourceSharedNameIfEmpty(Graph& graph,
FunctionLibraryDefinition& flib_def) {
auto is_resource_op_with_empty_shared_name = [](const NodeDef& node_def,
const OpDef& op_def) {
// Only upgrade when it is a resource handle op.
if (op_def.output_arg().size() != 1 ||
op_def.output_arg(0).type() != tensorflow::DT_RESOURCE)
return false;
// If the OpDef has "use_node_name_sharing" field, then it is valid to use
// node names as shared names.
if (!std::any_of(op_def.attr().begin(), op_def.attr().end(),
[](const auto& attr_def) {
return attr_def.name() == "use_node_name_sharing" &&
attr_def.type() == "bool";
}))
return false;
if (!std::any_of(op_def.attr().begin(), op_def.attr().end(),
[](const auto& attr_def) {
return attr_def.name() == "shared_name" &&
attr_def.type() == "string";
}))
return false;
auto iter = node_def.attr().find("shared_name");
if (iter == node_def.attr().end()) return true;
return iter->second.s().empty();
};
// Upgrade nodes in the graph.
for (auto* node : graph.nodes()) {
if (is_resource_op_with_empty_shared_name(node->def(), node->op_def())) {
node->AddAttr("shared_name", node->name());
}
}
// Upgrade nodes in the functions.
auto func_names = flib_def.ListFunctionNames();
for (const auto& func_name : func_names) {
const FunctionDef* orig = flib_def.Find(func_name);
DCHECK(orig);
auto copy = *orig;
for (auto& node_def : *copy.mutable_node_def()) {
const OpDef* op_def = nullptr;
TF_RETURN_IF_ERROR(flib_def.LookUpOpDef(node_def.op(), &op_def));
if (is_resource_op_with_empty_shared_name(node_def, *op_def)) {
// Use the concat of function name and node name for such ops in a
// function as the shared_name.
(*node_def.mutable_attr())["shared_name"].set_s(
absl::StrCat(func_name, "_", node_def.name()));
}
}
TF_RETURN_IF_ERROR(flib_def.ReplaceFunction(func_name, copy));
}
return tensorflow::Status::OK();
}
} // namespace tensorflow

View File

@ -0,0 +1,32 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSLATE_UPGRADE_GRAPH_H_
#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSLATE_UPGRADE_GRAPH_H_
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/graph/graph.h"
namespace tensorflow {
// Generate the shared_name for resource handle ops in the graph and functions
// if their shared_names are empty. Resource handle ops with empty shared_name
// may have undesired semantics.
Status GenerateResourceSharedNameIfEmpty(Graph& graph,
FunctionLibraryDefinition& flib_def);
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSLATE_UPGRADE_GRAPH_H_