Export result attributes from MLIR function to exported Graph via _Retval node attributes.

Special handling is added for devices similarly to attributes on _Arg nodes for export.

PiperOrigin-RevId: 343955240
Change-Id: Ie4cd9f83ac05b012a69da083b6b35bb516ceb8de
This commit is contained in:
Andy Ly 2020-11-23 16:55:30 -08:00 committed by TensorFlower Gardener
parent d4de90bf96
commit 5a086ba4f1
2 changed files with 39 additions and 7 deletions
tensorflow/compiler/mlir/tensorflow

View File

@ -1,9 +1,10 @@
// RUN: tf-mlir-translate -mlir-to-graphdef %s -tf-graph-as-function -o - | FileCheck %s
// Verify arg attributes are exported as device assignment for arg nodes.
// Verify arg/ret attributes are exported as device assignment for arg/retval
// nodes.
module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, producer = 121 : i32}} {
func @main(%arg0: tensor<*xf32> {tf.device = "/CPU:0"}, %arg1: tensor<2x4x6x8xi32>) -> (tensor<*xf32>, tensor<2x4x6x8xi32>)
func @main(%arg0: tensor<*xf32> {tf.device = "/CPU:0"}, %arg1: tensor<2x4x6x8xi32>) -> (tensor<*xf32>, tensor<2x4x6x8xi32> {tf.device = "/CPU:1"})
attributes {tf.entry_function = {inputs = "args_0,args_1", outputs = "rets_0,rets_1"}} {
%0:2 = tf_executor.graph {
%1:3 = tf_executor.island wraps "tf.IdentityN"(%arg0, %arg1) {T = ["tfdtype$DT_FLOAT", "tfdtype$DT_INT32"], device = "", name = "identity"} : (tensor<*xf32>, tensor<2x4x6x8xi32>) -> (tensor<*xf32>, tensor<2x4x6x8xi32>)
@ -15,18 +16,39 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, pr
// CHECK: node {
// CHECK-NEXT: name: "args_0"
// CHECK-NEXT: op: "_Arg"
// CHECK: device: "/CPU:0"
// CHECK: attr {
// CHECK: key: "index"
// CHECK-NEXT: value {
// CHECK-NEXT: i: 0
// CHECK_NEXT: }
//
// CHECK: node {
// CHECK-NEXT: name: "args_1"
// CHECK-NOT: device: "/CPU:0"
// CHECK-NEXT: op: "_Arg"
// CHECK-NOT: device
// CHECK: attr {
// CHECK: key: "index"
// CHECK-NEXT: value {
// CHECK-NEXT: i: 1
//
// CHECK: node {
// CHECK: op: "IdentityN"
//
// CHECK: node {
// CHECK-NEXT: name: "rets_0"
// CHECK-NEXT: op: "_Retval"
// CHECK-NOT: device
// CHECK: attr {
// CHECK: key: "index"
// CHECK-NEXT: value {
// CHECK-NEXT: i: 0
//
// CHECK: node {
// CHECK-NEXT: name: "rets_1"
// CHECK-NEXT: op: "_Retval"
// CHECK: device: "/CPU:1"
// CHECK: attr {
// CHECK: key: "index"
// CHECK-NEXT: value {
// CHECK-NEXT: i: 1
// CHECK_NEXT: }

View File

@ -250,8 +250,6 @@ StatusOr<std::unique_ptr<NodeDef>> Exporter::GetArgumentNode(
return node_def;
}
// TODO(b/160014479): Support exporting function result attributes as optional
// attributes.
StatusOr<std::unique_ptr<NodeDef>> Exporter::GetReturnNode(
mlir::FuncOp function, Value operand, unsigned index,
llvm::StringRef name) {
@ -272,6 +270,18 @@ StatusOr<std::unique_ptr<NodeDef>> Exporter::GetReturnNode(
AttrValue index_attr;
index_attr.set_i(index);
(*node_def->mutable_attr())["index"] = index_attr;
if (auto device_attr =
function.getResultAttrOfType<mlir::StringAttr>(index, kDeviceAttr))
*node_def->mutable_device() = device_attr.getValue().str();
llvm::ArrayRef<mlir::NamedAttribute> func_res_i_attrs =
function.getResultAttrs(index);
absl::flat_hash_set<absl::string_view> attrs_to_ignore = {kDeviceAttr};
TF_RETURN_IF_ERROR(ConvertAttributes(func_res_i_attrs, attrs_to_ignore,
/*remove_ref_type=*/false,
node_def->mutable_attr()));
return node_def;
}