diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-function-resource-args.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-function-resource-args.pbtxt
new file mode 100644
index 00000000000..a3f78e282bc
--- /dev/null
+++ b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-function-resource-args.pbtxt
@@ -0,0 +1,99 @@
+# RUN: tf-mlir-translate -graphdef-to-mlir %s -tf-output-arrays=func_call -o - | FileCheck %s
+
+node {
+  name: "x"
+  op: "VarHandleOp"
+  device: "/CPU:0"
+  attr {
+    key: "container"
+    value {
+      s: "a"
+    }
+  }
+  attr {
+    key: "dtype"
+    value {
+      type: DT_INT64
+    }
+  }
+  attr {
+    key: "shape"
+    value {
+      shape {
+      }
+    }
+  }
+  attr {
+    key: "shared_name"
+    value {
+      s: "x"
+    }
+  }
+}
+node {
+  name: "func_call"
+  op: "test_func_name"
+  input: "x"
+  input: "x"
+  attr {
+    key: "_disable_call_shape_inference"
+    value {
+      b: true
+    }
+  }
+}
+library {
+  function {
+    signature {
+      name: "test_func_name"
+      input_arg {
+        name: "a_0"
+        type: DT_RESOURCE
+      }
+      input_arg {
+        name: "a_1"
+        type: DT_RESOURCE
+      }
+      output_arg {
+        name: "a"
+        type: DT_RESOURCE
+      }
+    }
+    resource_arg_unique_id {
+      key: 0
+      value: 0
+    }
+    resource_arg_unique_id {
+      key: 1
+      value: 0
+    }
+    ret {
+      key: "a"
+      value: "a_0"
+    }
+    attr {
+      key: "_disable_call_shape_inference"
+      value {
+        b: true
+      }
+    }
+  }
+}
+
+# Check that the `resource_arg_unique_id` for each argument is propagated to the
+# `tf.resource_arg_unique_id` argument attribute of the function
+# @test_func_name0.
+
+# CHECK: func @main
+# CHECK: tf_executor.graph
+# CHECK: "tf.VarHandleOp"()
+# CHECK: "tf.LegacyCall"
+# CHECK-SAME: {_disable_call_shape_inference = true, f = @test_func_name0}
+# CHECK: tf_executor.fetch
+# CHECK: return
+# CHECK: func @test_func_name0
+# CHECK-SAME: tf.resource_arg_unique_id = 0
+# CHECK-SAME: tf.resource_arg_unique_id = 0
+# CHECK: tf_executor.graph
+# CHECK: tf_executor.fetch
+# CHECK: return
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/function-resource-args.mlir b/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/function-resource-args.mlir
new file mode 100644
index 00000000000..24cb7b703c6
--- /dev/null
+++ b/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/function-resource-args.mlir
@@ -0,0 +1,62 @@
+// RUN: tf-mlir-translate -mlir-to-graphdef %s -o - | FileCheck %s
+
+func @main() -> tensor<*x!tf.resource> attributes {tf.entry_function = {inputs = "", outputs = "func_call"}} {
+  %0 = tf_executor.graph {
+    %outputs, %control = tf_executor.island wraps "tf.VarHandleOp"() {container = "a", device = "/CPU:0", dtype = i64, name = "x", shape = "tfshape$", shared_name = "x"} : () -> tensor<!tf.resource<tensor<i64>>>
+    %outputs_0, %control_1 = tf_executor.island wraps "tf.LegacyCall"(%outputs, %outputs) {_disable_call_shape_inference = true, f = @test_func_name0} : (tensor<!tf.resource<tensor<i64>>>, tensor<!tf.resource<tensor<i64>>>) -> tensor<*x!tf.resource>
+    tf_executor.fetch %outputs_0 : tensor<*x!tf.resource>
+  }
+  return %0 : tensor<*x!tf.resource>
+}
+func @test_func_name0(%arg0: tensor<*x!tf.resource> {tf.resource_arg_unique_id = 0 : i64}, %arg1: tensor<*x!tf.resource> {tf.resource_arg_unique_id = 0 : i64}) -> tensor<*x!tf.resource> attributes {tf._disable_call_shape_inference = true} {
+  %0 = tf_executor.graph {
+    tf_executor.fetch %arg0 : tensor<*x!tf.resource>
+  }
+  return %0 : tensor<*x!tf.resource>
+}
+
+// Check that the `tf.resource_arg_unique_id` argument attributes of
+// test_func_name0 are propagated to the function's arg_attr and
+// resource_arg_unique_id.
+
+// CHECK: name: "x"
+// CHECK: op: "VarHandleOp"
+
+// CHECK: name: "func_call"
+// CHECK: input: "x"
+// CHECK: input: "x"
+
+// CHECK: library
+// CHECK: function
+// CHECK: signature
+// CHECK: input_arg
+// CHECK: type: DT_RESOURCE
+// CHECK: input_arg
+// CHECK: type: DT_RESOURCE
+// CHECK: output_arg
+// CHECK: type: DT_RESOURCE
+// CHECK: ret
+
+// Check _resource_arg_unique_id for each argument. Since they alias each other,
+// both values are 0.
+// CHECK: arg_attr
+// CHECK-NEXT: key: 0
+// CHECK-NEXT: value
+// CHECK: key: "_resource_arg_unique_id"
+// CHECK-NEXT: value
+// CHECK-NEXT: i: 0
+// CHECK: arg_attr
+// CHECK-NEXT: key: 1
+// CHECK-NEXT: value
+// CHECK: key: "_resource_arg_unique_id"
+// CHECK-NEXT: value
+// CHECK-NEXT: i: 0
+
+// Check resource_arg_unique_id for each argument. Since they alias each other,
+// both values are 0.
+// CHECK: resource_arg_unique_id
+// CHECK-NEXT: key: 0
+// CHECK-NEXT: value: 0
+// CHECK: resource_arg_unique_id
+// CHECK-NEXT: key: 1
+// CHECK-NEXT: value: 0
diff --git a/tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.cc b/tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.cc
index 58242e62f1c..3ea90fc8fbb 100644
--- a/tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.cc
+++ b/tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.cc
@@ -258,6 +258,14 @@ StatusOr<std::unique_ptr<NodeDef>> Exporter::GetArgumentNode(
     *node_def->mutable_device() = device_attr.getValue().str();
   }
 
+  if (auto resource_arg_unique_id_attr =
+          func.getArgAttrOfType<mlir::IntegerAttr>(
+              index, "tf.resource_arg_unique_id")) {
+    AttrValue unique_id_attr;
+    unique_id_attr.set_i(resource_arg_unique_id_attr.getInt());
+    (*node_def->mutable_attr())["_resource_arg_unique_id"] = unique_id_attr;
+  }
+
   return node_def;
 }
 
@@ -639,6 +647,14 @@ Status Exporter::ConvertLibFunction(const GraphExportConfig& configs,
   if (auto attr = function.getAttrOfType<mlir::UnitAttr>(stateful_string)) {
     func_def.mutable_signature()->set_is_stateful(true);
   }
+  for (int64 i = 0; i < function.getNumArguments(); ++i) {
+    if (auto resource_arg_unique_id_attr =
+            function.getArgAttrOfType<mlir::IntegerAttr>(
+                i, "tf.resource_arg_unique_id")) {
+      (*func_def.mutable_resource_arg_unique_id())[i] =
+          resource_arg_unique_id_attr.getInt();
+    }
+  }
 
   // Ignore the gradient and is_stateful attribute on the function as they have
   // been handled above.
diff --git a/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc b/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc
index 3e847034a1b..17ebff7b79e 100644
--- a/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc
+++ b/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc
@@ -163,11 +163,15 @@ class ImporterBase {
   StatusOr<mlir::FunctionType> InferLibFunctionType(const FunctionBody& fbody);
 
   // Extracts arg and ret nodes from FunctionBody.
+  // `resource_arg_unique_ids` will be filled with the unique IDs of resource
+  // variables, as a list of {index, ID} pairs.
   void GetArgsAndRetsFromFunctionBody(
       const FunctionBody& fbody,
       absl::InlinedVector<OutputTensor, 4>* arg_nodes,
       absl::InlinedVector<OutputTensor, 4>* ret_nodes,
-      absl::InlinedVector<Node*, 4>* control_ret_nodes);
+      absl::InlinedVector<Node*, 4>* control_ret_nodes,
+      absl::InlinedVector<std::pair<int64, int64>, 4>*
+          resource_arg_unique_ids);
 
   // Prepares converting the graph to an MLIR module. This step removes the
   // backedges of the graph, orders the nodes and infers the shapes.
@@ -180,7 +184,9 @@ class ImporterBase {
       const absl::InlinedVector<OutputTensor, 4>& arg_nodes,
       const absl::InlinedVector<OutputTensor, 4>& ret_nodes,
       const absl::InlinedVector<Node*, 4>& control_ret_nodes,
-      llvm::ArrayRef<mlir::NamedAttribute> attrs);
+      llvm::ArrayRef<mlir::NamedAttribute> attrs,
+      const absl::InlinedVector<std::pair<int64, int64>, 4>&
+          resource_arg_unique_ids);
 
   // Finds out the function definition for the given function name from the
   // graph and converts it to a function of the module. This method is called
@@ -1000,7 +1006,9 @@ StatusOr<mlir::Attribute> ImporterBase::ConvertAttributeValue(
 void ImporterBase::GetArgsAndRetsFromFunctionBody(
     const FunctionBody& fbody, absl::InlinedVector<OutputTensor, 4>* arg_nodes,
     absl::InlinedVector<OutputTensor, 4>* ret_nodes,
-    absl::InlinedVector<Node*, 4>* control_ret_nodes) {
+    absl::InlinedVector<Node*, 4>* control_ret_nodes,
+    absl::InlinedVector<std::pair<int64, int64>, 4>*
+        resource_arg_unique_ids) {
   arg_nodes->reserve(fbody.arg_nodes.size());
   ret_nodes->reserve(fbody.ret_nodes.size());
   for (auto arg : fbody.arg_nodes) {
@@ -1009,6 +1017,9 @@ void ImporterBase::GetArgsAndRetsFromFunctionBody(
   for (auto ret : fbody.ret_nodes) {
     ret_nodes->emplace_back(ret, 0);
   }
+  for (const auto& entry : fbody.fdef.resource_arg_unique_id()) {
+    resource_arg_unique_ids->push_back(entry);
+  }
   *control_ret_nodes = fbody.control_ret_nodes;
 }
 
@@ -1101,12 +1112,14 @@ Status ImporterBase::ConvertLibFunction(llvm::StringRef func_name) {
   absl::InlinedVector<OutputTensor, 4> arg_nodes;
   absl::InlinedVector<OutputTensor, 4> ret_nodes;
   absl::InlinedVector<Node*, 4> control_ret_nodes;
+  absl::InlinedVector<std::pair<int64, int64>, 4> resource_arg_unique_ids;
   GetArgsAndRetsFromFunctionBody(*fbody, &arg_nodes, &ret_nodes,
-                                 &control_ret_nodes);
+                                 &control_ret_nodes, &resource_arg_unique_ids);
 
   TF_RETURN_IF_ERROR(child_importer.Convert(
       mlir_func_name, func_type, arg_nodes, ret_nodes, control_ret_nodes,
-      llvm::makeArrayRef(attributes.begin(), attributes.end())));
+      llvm::makeArrayRef(attributes.begin(), attributes.end()),
+      resource_arg_unique_ids));
   return Status::OK();
 }
 
@@ -1121,7 +1134,9 @@ Status ImporterBase::Convert(
     const absl::InlinedVector<OutputTensor, 4>& arg_nodes,
     const absl::InlinedVector<OutputTensor, 4>& ret_nodes,
     const absl::InlinedVector<Node*, 4>& control_ret_nodes,
-    llvm::ArrayRef<mlir::NamedAttribute> attrs) {
+    llvm::ArrayRef<mlir::NamedAttribute> attrs,
+    const absl::InlinedVector<std::pair<int64, int64>, 4>&
+        resource_arg_unique_ids) {
   // TODO(b/122040776): Uses debug info for FunctionDef.
   auto function = mlir::FuncOp::create(mlir::UnknownLoc::get(context_),
                                        func_name, func_type, attrs);
@@ -1144,8 +1159,14 @@ Status ImporterBase::Convert(
   // pairs.
   TF_RETURN_IF_ERROR(AddBackedges());
 
-  return ConvertFunctionArgAndRets(function, graph, func_type.getInputs(),
-                                   arg_nodes, ret_nodes, control_ret_nodes);
+  TF_RETURN_IF_ERROR(ConvertFunctionArgAndRets(function, graph,
+                                               func_type.getInputs(), arg_nodes,
+                                               ret_nodes, control_ret_nodes));
+  for (const auto& entry : resource_arg_unique_ids) {
+    function.setArgAttr(entry.first, "tf.resource_arg_unique_id",
+                        builder_.getI64IntegerAttr(entry.second));
+  }
+  return Status::OK();
 }
 
 Status ImporterBase::ConvertFunctionArgAndRets(
@@ -1710,10 +1731,14 @@ class GraphDefImporter : public ImporterBase {
   // output nodes, for function graphs. Arguments and return values are
   // determined by node op type. Type and shape information of the function are
   // inferred by the shape refiner in ImporterBase.
+  // `resource_arg_unique_ids` will be filled with the unique IDs of resource
+  // variables, as a list of {index, ID} pairs.
   StatusOr<mlir::FunctionType> GetArgsRetsAndTypesFromFunctionGraph(
       mlir::MLIRContext* context,
       absl::InlinedVector<OutputTensor, 4>* arg_nodes,
-      absl::InlinedVector<OutputTensor, 4>* ret_nodes);
+      absl::InlinedVector<OutputTensor, 4>* ret_nodes,
+      absl::InlinedVector<std::pair<int64, int64>, 4>*
+          resource_arg_unique_ids);
 };
 
 StatusOr<mlir::OwningModuleRef> GraphDefImporter::Convert(
@@ -1734,6 +1759,7 @@ StatusOr<mlir::OwningModuleRef> GraphDefImporter::Convert(
   absl::InlinedVector<OutputTensor, 4> arg_nodes;
   absl::InlinedVector<OutputTensor, 4> ret_nodes;
   absl::InlinedVector<Node*, 4> control_ret_nodes;
+  absl::InlinedVector<std::pair<int64, int64>, 4> resource_arg_unique_ids;
   llvm::SmallVector<mlir::NamedAttribute, 1> attrs;
   if (specs.graph_as_function) {
     if (specs.prune_unused_nodes || !specs.inputs.empty() ||
@@ -1742,9 +1768,10 @@ StatusOr<mlir::OwningModuleRef> GraphDefImporter::Convert(
           "Pruning of graph is currently unsupported when the main graph is "
           "converted to a function.");
 
-    TF_ASSIGN_OR_RETURN(func_type,
-                        importer.GetArgsRetsAndTypesFromFunctionGraph(
-                            context, &arg_nodes, &ret_nodes));
+    TF_ASSIGN_OR_RETURN(
+        func_type,
+        importer.GetArgsRetsAndTypesFromFunctionGraph(
+            context, &arg_nodes, &ret_nodes, &resource_arg_unique_ids));
 
     if (!arg_nodes.empty() || !ret_nodes.empty()) {
       mlir::Builder b(context);
@@ -1805,7 +1832,8 @@ StatusOr<mlir::OwningModuleRef> GraphDefImporter::Convert(
                       {producer, min_consumer, bad_consumers})));
 
   TF_RETURN_IF_ERROR(importer.ImporterBase::Convert(
-      "main", func_type, arg_nodes, ret_nodes, control_ret_nodes, attrs));
+      "main", func_type, arg_nodes, ret_nodes, control_ret_nodes, attrs,
+      resource_arg_unique_ids));
   return module;
 }
 
@@ -1918,7 +1946,9 @@ StatusOr<mlir::FunctionType> GraphDefImporter::InferMainFunctionType(
 StatusOr<mlir::FunctionType>
 GraphDefImporter::GetArgsRetsAndTypesFromFunctionGraph(
     mlir::MLIRContext* context, absl::InlinedVector<OutputTensor, 4>* arg_nodes,
-    absl::InlinedVector<OutputTensor, 4>* ret_nodes) {
+    absl::InlinedVector<OutputTensor, 4>* ret_nodes,
+    absl::InlinedVector<std::pair<int64, int64>, 4>*
+        resource_arg_unique_ids) {
   auto add_node = [](Node* node, absl::InlinedVector<OutputTensor, 4>* nodes) {
     auto* attr = node->attrs().Find("index");
     if (!attr)
@@ -1959,6 +1989,12 @@ GraphDefImporter::GetArgsRetsAndTypesFromFunctionGraph(
     TF_ASSIGN_OR_RETURN(auto type,
                         InferOutputType(*arg_node.node, /*idx=*/0, builder));
     arg_types.push_back(type);
+    tensorflow::int64 resource_arg_unique_id;
+    if (TryGetNodeAttr(arg_node.node->attrs(), "_resource_arg_unique_id",
+                       &resource_arg_unique_id)) {
+      resource_arg_unique_ids->emplace_back(arg_node_and_idx.index(),
+                                            resource_arg_unique_id);
+    }
   }
 
   llvm::SmallVector<mlir::Type, 4> ret_types;
diff --git a/tensorflow/core/framework/function.cc b/tensorflow/core/framework/function.cc
index 413fdbcc3ae..c06f2d148ea 100644
--- a/tensorflow/core/framework/function.cc
+++ b/tensorflow/core/framework/function.cc
@@ -167,9 +167,11 @@ class FunctionInstantiationHelper {
   }
 
   // Builds index for nodes that can be used as node's input arguments.
+  // `resource_arg_unique_id`: if non-negative, it is set as the
+  // "_resource_arg_unique_id" attribute of the arg node.
   Status BuildInputArgIndex(const OpDef::ArgDef& arg_def, AttrSlice attr_values,
                             const FunctionDef::ArgAttrs* arg_attrs,
-                            bool ints_on_device) {
+                            bool ints_on_device, int64 resource_arg_unique_id) {
     bool is_type_list;
     DataTypeVector dtypes;
     TF_RETURN_IF_ERROR(
@@ -196,6 +198,9 @@ class FunctionInstantiationHelper {
       DataType dtype = arg_def.is_ref() ? MakeRefType(dtypes[i]) : dtypes[i];
       AddAttr("T", dtype, gnode);
       AddAttr("index", arg_index, gnode);
+      if (resource_arg_unique_id >= 0) {
+        AddAttr("_resource_arg_unique_id", resource_arg_unique_id, gnode);
+      }
       if (arg_attrs) {
        for (const auto& arg_attr : arg_attrs->attr()) {
          AddAttr(arg_attr.first, arg_attr.second, gnode->mutable_attr());
@@ -729,8 +734,14 @@ Status InstantiateFunction(const FunctionDef& fdef, AttrSlice attr_values,
     auto it = fdef.arg_attr().find(i);
     const FunctionDef::ArgAttrs* arg_attrs =
         it != fdef.arg_attr().end() ? &it->second : nullptr;
+    auto resource_id_it = fdef.resource_arg_unique_id().find(i);
+    int64 resource_arg_unique_id =
+        resource_id_it != fdef.resource_arg_unique_id().end()
+            ? resource_id_it->second
+            : -1LL;
     s = helper.BuildInputArgIndex(arg_def, attr_values, arg_attrs,
-                                  ints_on_device);
+                                  ints_on_device, resource_arg_unique_id);
+
     if (!s.ok()) {
       errors::AppendToMessage(&s, "In ", Print(arg_def));
       return s;
diff --git a/tensorflow/core/framework/function.proto b/tensorflow/core/framework/function.proto
index 7b5756ed8c9..71423f9168f 100644
--- a/tensorflow/core/framework/function.proto
+++ b/tensorflow/core/framework/function.proto
@@ -37,6 +37,17 @@ message FunctionDef {
   }
   map<uint32, ArgAttrs> arg_attr = 7;
 
+  // Unique IDs for each resource argument, used to track aliasing resources.
+  // If Argument A and Argument B alias each other, then
+  // resource_arg_unique_ids[A.index] == resource_arg_unique_ids[B.index].
+  //
+  // If this field is empty, none of the arguments can alias; otherwise, every
+  // resource argument should have an entry in this field.
+  //
+  // When instantiated, the unique IDs will be attached to the _Arg nodes'
+  // "_resource_arg_unique_id" attribute.
+  map<uint32, uint32> resource_arg_unique_id = 8;
+
   // NOTE: field id 2 deleted on Jan 11, 2017, GraphDef version 21.
   reserved 2;
 
diff --git a/tensorflow/core/framework/graph_to_functiondef.cc b/tensorflow/core/framework/graph_to_functiondef.cc
index 5baf560801b..50aa4a81926 100644
--- a/tensorflow/core/framework/graph_to_functiondef.cc
+++ b/tensorflow/core/framework/graph_to_functiondef.cc
@@ -432,6 +432,7 @@ Status GraphToFunctionDef(const Graph& fn_body, const string& fn_name,
     const string& input_name = node_names.GetInputName(node->name());
     argdef->set_name(input_name);
     FunctionDef::ArgAttrs arg_attrs;
+    int64 resource_arg_unique_id = -1;
     for (const auto& attr : node->attrs()) {
       // Only copy internal attributes. These attributes will be applied to
       // _Arg/Placeholder nodes when this FunctionDef is converted to graph,
@@ -440,10 +441,16 @@ Status GraphToFunctionDef(const Graph& fn_body, const string& fn_name,
       if (absl::StartsWith(attr.first, "_")) {
        arg_attrs.mutable_attr()->insert(attr);
      }
+      if (attr.first == "_resource_arg_unique_id") {
+        resource_arg_unique_id = attr.second.i();
+      }
     }
     if (arg_attrs.attr_size() > 0) {
       (*fdef->mutable_arg_attr())[i] = std::move(arg_attrs);
     }
+    if (resource_arg_unique_id >= 0) {
+      (*fdef->mutable_resource_arg_unique_id())[idx] = resource_arg_unique_id;
+    }
     tensor_renaming[strings::StrCat(node->name(), ":", idx)] = input_name;
   }
 
diff --git a/tensorflow/core/grappler/optimizers/meta_optimizer.cc b/tensorflow/core/grappler/optimizers/meta_optimizer.cc
index 9fc9be0e5af..e0d4663f5fb 100644
--- a/tensorflow/core/grappler/optimizers/meta_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/meta_optimizer.cc
@@ -101,6 +101,7 @@ FunctionDefLibrary GetFunctionDefLibraryStub(
     *(fn_stub->mutable_signature()) = fn.signature();
     *(fn_stub->mutable_attr()) = fn.attr();
     *(fn_stub->mutable_arg_attr()) = fn.arg_attr();
+    *(fn_stub->mutable_resource_arg_unique_id()) = fn.resource_arg_unique_id();
   }
   *stub.mutable_gradient() = fdef_lib.gradient();
   return stub;
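Note (reviewer illustration, not part of the patch): the aliasing contract documented in the new function.proto comment can be exercised directly through the generated protobuf API. The snippet below is a hypothetical, standalone sketch, assuming a build with this patch applied (so FunctionDef has the resource_arg_unique_id map) and using only standard FunctionDef/OpDef accessors; it builds a function with two resource arguments that alias the same variable, mirroring the test_func_name fixture in the new tests.

// Illustrative sketch only -- not part of this patch. Shows the invariant that
// aliasing resource arguments share the same entry in resource_arg_unique_id.
#include <iostream>

#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/framework/types.pb.h"

int main() {
  tensorflow::FunctionDef fdef;
  fdef.mutable_signature()->set_name("test_func_name");

  // Two DT_RESOURCE arguments that refer to the same underlying variable.
  auto* a0 = fdef.mutable_signature()->add_input_arg();
  a0->set_name("a_0");
  a0->set_type(tensorflow::DT_RESOURCE);
  auto* a1 = fdef.mutable_signature()->add_input_arg();
  a1->set_name("a_1");
  a1->set_type(tensorflow::DT_RESOURCE);

  // Aliasing arguments get the same unique ID. During instantiation this value
  // becomes the "_resource_arg_unique_id" attr on each _Arg node (see
  // FunctionInstantiationHelper::BuildInputArgIndex in the patch above).
  (*fdef.mutable_resource_arg_unique_id())[0] = 0;
  (*fdef.mutable_resource_arg_unique_id())[1] = 0;

  std::cout << "arguments alias: "
            << (fdef.resource_arg_unique_id().at(0) ==
                fdef.resource_arg_unique_id().at(1))
            << "\n";
  return 0;
}

When such a FunctionDef is imported to MLIR, the importer changes above attach tf.resource_arg_unique_id = 0 to both arguments, and the exporter round-trips the map back into the FunctionDef.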