Fix wrong integer type for resource_arg_unique_id and resubmit resource_arg_unique_id change.

PiperOrigin-RevId: 285502849
Change-Id: Id0cfe71193676fab04912dc19024fe10961370ff
This commit is contained in:
Yuanzhong Xu 2019-12-13 17:23:24 -08:00 committed by TensorFlower Gardener
parent 8bd638bc3f
commit 539eb9f8b2
8 changed files with 259 additions and 16 deletions

View File

@ -0,0 +1,99 @@
# RUN: tf-mlir-translate -graphdef-to-mlir %s -tf-output-arrays=func_call -o - | FileCheck %s
node {
name: "x"
op: "VarHandleOp"
device: "/CPU:0"
attr {
key: "container"
value {
s: "a"
}
}
attr {
key: "dtype"
value {
type: DT_INT64
}
}
attr {
key: "shape"
value {
shape {
}
}
}
attr {
key: "shared_name"
value {
s: "x"
}
}
}
node {
name: "func_call"
op: "test_func_name"
input: "x"
input: "x"
attr {
key: "_disable_call_shape_inference"
value {
b: true
}
}
}
library {
function {
signature {
name: "test_func_name"
input_arg {
name: "a_0"
type: DT_RESOURCE
}
input_arg {
name: "a_1"
type: DT_RESOURCE
}
output_arg {
name: "a"
type: DT_RESOURCE
}
}
resource_arg_unique_id {
key: 0
value: 0
}
resource_arg_unique_id {
key: 1
value: 0
}
ret {
key: "a"
value: "a_0"
}
attr {
key: "_disable_call_shape_inference"
value {
b: true
}
}
}
}
# Check that the `resource_arg_unique_id` for each argument is propagated to the
# `tf.resource_arg_unique_id` argument attribute of the function
# @test_func_name0.
# CHECK: func @main
# CHECK: tf_executor.graph
# CHECK: "tf.VarHandleOp"()
# CHECK: "tf.LegacyCall"
# CHECK-SAME: {_disable_call_shape_inference = true, f = @test_func_name0}
# CHECK: tf_executor.fetch
# CHECK: return
# CHECK: func @test_func_name0
# CHECK-SAME: tf.resource_arg_unique_id = 0
# CHECK-SAME tf.resource_arg_unique_id = 0
# CHECK: tf_executor.graph
# CHECK: tf_executor.fetch
# CHECK: return

View File

@ -0,0 +1,62 @@
// RUN: tf-mlir-translate -mlir-to-graphdef %s -o - | FileCheck %s
func @main() -> tensor<*x!tf.resource> attributes {tf.entry_function = {inputs = "", outputs = "func_call"}} {
%0 = tf_executor.graph {
%outputs, %control = tf_executor.island wraps "tf.VarHandleOp"() {container = "a", device = "/CPU:0", dtype = i64, name = "x", shape = "tfshape$", shared_name = "x"} : () -> tensor<!tf.resource<tensor<i64>>>
%outputs_0, %control_1 = tf_executor.island wraps "tf.LegacyCall"(%outputs, %outputs) {_disable_call_shape_inference = true, f = @test_func_name0} : (tensor<!tf.resource<tensor<i64>>>, tensor<!tf.resource<tensor<i64>>>) -> tensor<*x!tf.resource>
tf_executor.fetch %outputs_0 : tensor<*x!tf.resource>
}
return %0 : tensor<*x!tf.resource>
}
func @test_func_name0(%arg0: tensor<*x!tf.resource> {tf.resource_arg_unique_id = 0 : i64}, %arg1: tensor<*x!tf.resource> {tf.resource_arg_unique_id = 0 : i64}) -> tensor<*x!tf.resource> attributes {tf._disable_call_shape_inference = true} {
%0 = tf_executor.graph {
tf_executor.fetch %arg0 : tensor<*x!tf.resource>
}
return %0 : tensor<*x!tf.resource>
}
// Check that the `tf.resource_arg_unique_id` argument attributes of
// test_func_name0 are propagated to the function's arg_attr and
// resource_arg_unique_id.
// CHECK: name: "x"
// CHECK: op: "VarHandleOp"
// CHECK: name: "func_call"
// CHECK: input: "x"
// CHECK: input: "x"
// CHECK: library
// CHECK: function
// CHECK: signature
// CHECK: input_arg
// CHECK: type: DT_RESOURCE
// CHECK: input_arg
// CHECK: type: DT_RESOURCE
// CHECK: output_arg
// CHECK: type: DT_RESOURCE
// CHECK: ret
// Check _resource_arg_unique_id for each argument. Since they alias each other,
// both values are 0.
// CHECK: arg_attr
// CHECK-NEXT: key: 0
// CHECK-NEXT: value
// CHECK: key: "_resource_arg_unique_id"
// CHECK-NEXT: value
// CHECK-NEXT: i: 0
// CHECK: arg_attr
// CHECK-NEXT: key: 1
// CHECK-NEXT: value
// CHECK: key: "_resource_arg_unique_id"
// CHECK-NEXT: value
// CHECK-NEXT: i: 0
// Check resource_arg_unique_id for each argument. Since they alias each other,
// both values are 0.
// CHECK: resource_arg_unique_id
// CHECK-NEXT: key: 0
// CHECK-NEXT: value: 0
// CHECK: resource_arg_unique_id
// CHECK-NEXT: key: 1
// CHECK-NEXT: value: 0

View File

@ -258,6 +258,14 @@ StatusOr<std::unique_ptr<NodeDef>> Exporter::GetArgumentNode(
*node_def->mutable_device() = device_attr.getValue().str();
}
if (auto resource_arg_unique_id_attr =
func.getArgAttrOfType<mlir::IntegerAttr>(
index, "tf.resource_arg_unique_id")) {
AttrValue unique_id_attr;
unique_id_attr.set_i(resource_arg_unique_id_attr.getInt());
(*node_def->mutable_attr())["_resource_arg_unique_id"] = unique_id_attr;
}
return node_def;
}
@ -639,6 +647,14 @@ Status Exporter::ConvertLibFunction(const GraphExportConfig& configs,
if (auto attr = function.getAttrOfType<mlir::UnitAttr>(stateful_string)) {
func_def.mutable_signature()->set_is_stateful(true);
}
for (int64 i = 0; i < function.getNumArguments(); ++i) {
if (auto resource_arg_unique_id_attr =
function.getArgAttrOfType<mlir::IntegerAttr>(
i, "tf.resource_arg_unique_id")) {
(*func_def.mutable_resource_arg_unique_id())[i] =
resource_arg_unique_id_attr.getInt();
}
}
// Ignore the gradient and is_stateful attribute on the function as they have
// been handled above.

View File

@ -163,11 +163,15 @@ class ImporterBase {
StatusOr<mlir::FunctionType> InferLibFunctionType(const FunctionBody& fbody);
// Extracts arg and ret nodes from FunctionBody.
// `resource_arg_unique_ids` will be filled with the unique IDs of resource
// variables, as a list of {index, ID} pairs.
void GetArgsAndRetsFromFunctionBody(
const FunctionBody& fbody,
absl::InlinedVector<OutputTensor, 4>* arg_nodes,
absl::InlinedVector<OutputTensor, 4>* ret_nodes,
absl::InlinedVector<Node*, 4>* control_ret_nodes);
absl::InlinedVector<Node*, 4>* control_ret_nodes,
absl::InlinedVector<std::pair<int64_t, int64_t>, 4>*
resource_arg_unique_ids);
// Prepares converting the graph to an MLIR module. This step removes the
// backedges of the graph, orders the nodes and infers the shapes.
@ -180,7 +184,9 @@ class ImporterBase {
const absl::InlinedVector<OutputTensor, 4>& arg_nodes,
const absl::InlinedVector<OutputTensor, 4>& ret_nodes,
const absl::InlinedVector<Node*, 4>& control_ret_nodes,
llvm::ArrayRef<mlir::NamedAttribute> attrs);
llvm::ArrayRef<mlir::NamedAttribute> attrs,
const absl::InlinedVector<std::pair<int64_t, int64_t>, 4>&
resource_arg_unique_ids);
// Finds out the function definition for the given function name from the
// graph and converts it to a function of the module. This method is called
@ -1000,7 +1006,9 @@ StatusOr<mlir::Attribute> ImporterBase::ConvertAttributeValue(
void ImporterBase::GetArgsAndRetsFromFunctionBody(
const FunctionBody& fbody, absl::InlinedVector<OutputTensor, 4>* arg_nodes,
absl::InlinedVector<OutputTensor, 4>* ret_nodes,
absl::InlinedVector<Node*, 4>* control_ret_nodes) {
absl::InlinedVector<Node*, 4>* control_ret_nodes,
absl::InlinedVector<std::pair<int64_t, int64_t>, 4>*
resource_arg_unique_ids) {
arg_nodes->reserve(fbody.arg_nodes.size());
ret_nodes->reserve(fbody.ret_nodes.size());
for (auto arg : fbody.arg_nodes) {
@ -1009,6 +1017,9 @@ void ImporterBase::GetArgsAndRetsFromFunctionBody(
for (auto ret : fbody.ret_nodes) {
ret_nodes->emplace_back(ret, 0);
}
for (const auto& entry : fbody.fdef.resource_arg_unique_id()) {
resource_arg_unique_ids->push_back(entry);
}
*control_ret_nodes = fbody.control_ret_nodes;
}
@ -1101,12 +1112,14 @@ Status ImporterBase::ConvertLibFunction(llvm::StringRef func_name) {
absl::InlinedVector<OutputTensor, 4> arg_nodes;
absl::InlinedVector<OutputTensor, 4> ret_nodes;
absl::InlinedVector<Node*, 4> control_ret_nodes;
absl::InlinedVector<std::pair<int64_t, int64_t>, 4> resource_arg_unique_ids;
GetArgsAndRetsFromFunctionBody(*fbody, &arg_nodes, &ret_nodes,
&control_ret_nodes);
&control_ret_nodes, &resource_arg_unique_ids);
TF_RETURN_IF_ERROR(child_importer.Convert(
mlir_func_name, func_type, arg_nodes, ret_nodes, control_ret_nodes,
llvm::makeArrayRef(attributes.begin(), attributes.end())));
llvm::makeArrayRef(attributes.begin(), attributes.end()),
resource_arg_unique_ids));
return Status::OK();
}
@ -1121,7 +1134,9 @@ Status ImporterBase::Convert(
const absl::InlinedVector<OutputTensor, 4>& arg_nodes,
const absl::InlinedVector<OutputTensor, 4>& ret_nodes,
const absl::InlinedVector<Node*, 4>& control_ret_nodes,
llvm::ArrayRef<mlir::NamedAttribute> attrs) {
llvm::ArrayRef<mlir::NamedAttribute> attrs,
const absl::InlinedVector<std::pair<int64_t, int64_t>, 4>&
resource_arg_unique_ids) {
// TODO(b/122040776): Uses debug info for FunctionDef.
auto function = mlir::FuncOp::create(mlir::UnknownLoc::get(context_),
func_name, func_type, attrs);
@ -1144,8 +1159,14 @@ Status ImporterBase::Convert(
// pairs.
TF_RETURN_IF_ERROR(AddBackedges());
return ConvertFunctionArgAndRets(function, graph, func_type.getInputs(),
arg_nodes, ret_nodes, control_ret_nodes);
TF_RETURN_IF_ERROR(ConvertFunctionArgAndRets(function, graph,
func_type.getInputs(), arg_nodes,
ret_nodes, control_ret_nodes));
for (const auto& entry : resource_arg_unique_ids) {
function.setArgAttr(entry.first, "tf.resource_arg_unique_id",
builder_.getI64IntegerAttr(entry.second));
}
return Status::OK();
}
Status ImporterBase::ConvertFunctionArgAndRets(
@ -1710,10 +1731,14 @@ class GraphDefImporter : public ImporterBase {
// output nodes, for function graphs. Arguments and return values are
// determined by node op type. Type and shape information of the function are
// inferred by the shape refiner in ImporterBase.
// `resource_arg_unique_ids` will be filled with the unique IDs of resource
// variables, as a list of {index, ID} pairs.
StatusOr<mlir::FunctionType> GetArgsRetsAndTypesFromFunctionGraph(
mlir::MLIRContext* context,
absl::InlinedVector<OutputTensor, 4>* arg_nodes,
absl::InlinedVector<OutputTensor, 4>* ret_nodes);
absl::InlinedVector<OutputTensor, 4>* ret_nodes,
absl::InlinedVector<std::pair<int64_t, int64_t>, 4>*
resource_arg_unique_ids);
};
StatusOr<mlir::OwningModuleRef> GraphDefImporter::Convert(
@ -1734,6 +1759,7 @@ StatusOr<mlir::OwningModuleRef> GraphDefImporter::Convert(
absl::InlinedVector<OutputTensor, 4> arg_nodes;
absl::InlinedVector<OutputTensor, 4> ret_nodes;
absl::InlinedVector<Node*, 4> control_ret_nodes;
absl::InlinedVector<std::pair<int64_t, int64_t>, 4> resource_arg_unique_ids;
llvm::SmallVector<mlir::NamedAttribute, 1> attrs;
if (specs.graph_as_function) {
if (specs.prune_unused_nodes || !specs.inputs.empty() ||
@ -1742,9 +1768,10 @@ StatusOr<mlir::OwningModuleRef> GraphDefImporter::Convert(
"Pruning of graph is currently unsupported when the main graph is "
"converted to a function.");
TF_ASSIGN_OR_RETURN(func_type,
TF_ASSIGN_OR_RETURN(
func_type,
importer.GetArgsRetsAndTypesFromFunctionGraph(
context, &arg_nodes, &ret_nodes));
context, &arg_nodes, &ret_nodes, &resource_arg_unique_ids));
if (!arg_nodes.empty() || !ret_nodes.empty()) {
mlir::Builder b(context);
@ -1805,7 +1832,8 @@ StatusOr<mlir::OwningModuleRef> GraphDefImporter::Convert(
{producer, min_consumer, bad_consumers})));
TF_RETURN_IF_ERROR(importer.ImporterBase::Convert(
"main", func_type, arg_nodes, ret_nodes, control_ret_nodes, attrs));
"main", func_type, arg_nodes, ret_nodes, control_ret_nodes, attrs,
resource_arg_unique_ids));
return module;
}
@ -1918,7 +1946,9 @@ StatusOr<mlir::FunctionType> GraphDefImporter::InferMainFunctionType(
StatusOr<mlir::FunctionType>
GraphDefImporter::GetArgsRetsAndTypesFromFunctionGraph(
mlir::MLIRContext* context, absl::InlinedVector<OutputTensor, 4>* arg_nodes,
absl::InlinedVector<OutputTensor, 4>* ret_nodes) {
absl::InlinedVector<OutputTensor, 4>* ret_nodes,
absl::InlinedVector<std::pair<int64_t, int64_t>, 4>*
resource_arg_unique_ids) {
auto add_node = [](Node* node, absl::InlinedVector<OutputTensor, 4>* nodes) {
auto* attr = node->attrs().Find("index");
if (!attr)
@ -1959,6 +1989,12 @@ GraphDefImporter::GetArgsRetsAndTypesFromFunctionGraph(
TF_ASSIGN_OR_RETURN(auto type,
InferOutputType(*arg_node.node, /*idx=*/0, builder));
arg_types.push_back(type);
tensorflow::int64 resource_arg_unique_id;
if (TryGetNodeAttr(arg_node.node->attrs(), "_resource_arg_unique_id",
&resource_arg_unique_id)) {
resource_arg_unique_ids->emplace_back(arg_node_and_idx.index(),
resource_arg_unique_id);
}
}
llvm::SmallVector<mlir::Type, 4> ret_types;

View File

@ -167,9 +167,11 @@ class FunctionInstantiationHelper {
}
// Builds index for nodes that can be used as node's input arguments.
// `resource_arg_unique_id`: if non-negative, will be populated to the
// "_resource_arg_unique_id" attribute of the arg node.
Status BuildInputArgIndex(const OpDef::ArgDef& arg_def, AttrSlice attr_values,
const FunctionDef::ArgAttrs* arg_attrs,
bool ints_on_device) {
bool ints_on_device, int64 resource_arg_unique_id) {
bool is_type_list;
DataTypeVector dtypes;
TF_RETURN_IF_ERROR(
@ -196,6 +198,9 @@ class FunctionInstantiationHelper {
DataType dtype = arg_def.is_ref() ? MakeRefType(dtypes[i]) : dtypes[i];
AddAttr("T", dtype, gnode);
AddAttr("index", arg_index, gnode);
if (resource_arg_unique_id >= 0) {
AddAttr("_resource_arg_unique_id", resource_arg_unique_id, gnode);
}
if (arg_attrs) {
for (const auto& arg_attr : arg_attrs->attr()) {
AddAttr(arg_attr.first, arg_attr.second, gnode->mutable_attr());
@ -729,8 +734,14 @@ Status InstantiateFunction(const FunctionDef& fdef, AttrSlice attr_values,
auto it = fdef.arg_attr().find(i);
const FunctionDef::ArgAttrs* arg_attrs =
it != fdef.arg_attr().end() ? &it->second : nullptr;
auto resource_id_it = fdef.resource_arg_unique_id().find(i);
int64 resource_arg_unique_id =
resource_id_it != fdef.resource_arg_unique_id().end()
? resource_id_it->second
: -1LL;
s = helper.BuildInputArgIndex(arg_def, attr_values, arg_attrs,
ints_on_device);
ints_on_device, resource_arg_unique_id);
if (!s.ok()) {
errors::AppendToMessage(&s, "In ", Print(arg_def));
return s;

View File

@ -37,6 +37,17 @@ message FunctionDef {
}
map<uint32, ArgAttrs> arg_attr = 7;
// Unique IDs for each resource argument, used to track aliasing resources. If
// Argument A and Argument B alias each other, then
// resource_arg_unique_ids[A.index] == resource_arg_unique_ids[B.index].
//
// If this field is empty, none of the arguments could alias; otherwise, every
// resource argument should have an entry in this field.
//
// When instantiated, the unique IDs will be attached to the _Arg nodes'
// "_resource_arg_unique_id" attribute.
map<uint32, uint32> resource_arg_unique_id = 8;
// NOTE: field id 2 deleted on Jan 11, 2017, GraphDef version 21.
reserved 2;

View File

@ -432,6 +432,7 @@ Status GraphToFunctionDef(const Graph& fn_body, const string& fn_name,
const string& input_name = node_names.GetInputName(node->name());
argdef->set_name(input_name);
FunctionDef::ArgAttrs arg_attrs;
int64 resource_arg_unique_id = -1;
for (const auto& attr : node->attrs()) {
// Only copy internal attributes. These attributes will be applied to
// _Arg/Placeholder nodes when this FunctionDef is converted to graph,
@ -440,10 +441,16 @@ Status GraphToFunctionDef(const Graph& fn_body, const string& fn_name,
if (absl::StartsWith(attr.first, "_")) {
arg_attrs.mutable_attr()->insert(attr);
}
if (attr.first == "_resource_arg_unique_id") {
resource_arg_unique_id = attr.second.i();
}
}
if (arg_attrs.attr_size() > 0) {
(*fdef->mutable_arg_attr())[i] = std::move(arg_attrs);
}
if (resource_arg_unique_id >= 0) {
(*fdef->mutable_resource_arg_unique_id())[idx] = resource_arg_unique_id;
}
tensor_renaming[strings::StrCat(node->name(), ":", idx)] = input_name;
}

View File

@ -101,6 +101,7 @@ FunctionDefLibrary GetFunctionDefLibraryStub(
*(fn_stub->mutable_signature()) = fn.signature();
*(fn_stub->mutable_attr()) = fn.attr();
*(fn_stub->mutable_arg_attr()) = fn.arg_attr();
*(fn_stub->mutable_resource_arg_unique_id()) = fn.resource_arg_unique_id();
}
*stub.mutable_gradient() = fdef_lib.gradient();
return stub;