Export result attributes from MLIR function to exported Graph via _Retval node attributes.
Special handling is added for devices similarly to attributes on _Arg nodes for export. PiperOrigin-RevId: 343955240 Change-Id: Ie4cd9f83ac05b012a69da083b6b35bb516ceb8de
This commit is contained in:
parent
d4de90bf96
commit
5a086ba4f1
tensorflow/compiler/mlir/tensorflow
@ -1,9 +1,10 @@
|
||||
// RUN: tf-mlir-translate -mlir-to-graphdef %s -tf-graph-as-function -o - | FileCheck %s
|
||||
|
||||
// Verify arg attributes are exported as device assignment for arg nodes.
|
||||
// Verify arg/ret attributes are exported as device assignment for arg/retval
|
||||
// nodes.
|
||||
|
||||
module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, producer = 121 : i32}} {
|
||||
func @main(%arg0: tensor<*xf32> {tf.device = "/CPU:0"}, %arg1: tensor<2x4x6x8xi32>) -> (tensor<*xf32>, tensor<2x4x6x8xi32>)
|
||||
func @main(%arg0: tensor<*xf32> {tf.device = "/CPU:0"}, %arg1: tensor<2x4x6x8xi32>) -> (tensor<*xf32>, tensor<2x4x6x8xi32> {tf.device = "/CPU:1"})
|
||||
attributes {tf.entry_function = {inputs = "args_0,args_1", outputs = "rets_0,rets_1"}} {
|
||||
%0:2 = tf_executor.graph {
|
||||
%1:3 = tf_executor.island wraps "tf.IdentityN"(%arg0, %arg1) {T = ["tfdtype$DT_FLOAT", "tfdtype$DT_INT32"], device = "", name = "identity"} : (tensor<*xf32>, tensor<2x4x6x8xi32>) -> (tensor<*xf32>, tensor<2x4x6x8xi32>)
|
||||
@ -15,18 +16,39 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, pr
|
||||
|
||||
// CHECK: node {
|
||||
// CHECK-NEXT: name: "args_0"
|
||||
// CHECK-NEXT: op: "_Arg"
|
||||
// CHECK: device: "/CPU:0"
|
||||
// CHECK: attr {
|
||||
// CHECK: key: "index"
|
||||
// CHECK-NEXT: value {
|
||||
// CHECK-NEXT: i: 0
|
||||
// CHECK_NEXT: }
|
||||
//
|
||||
// CHECK: node {
|
||||
// CHECK-NEXT: name: "args_1"
|
||||
// CHECK-NOT: device: "/CPU:0"
|
||||
// CHECK-NEXT: op: "_Arg"
|
||||
// CHECK-NOT: device
|
||||
// CHECK: attr {
|
||||
// CHECK: key: "index"
|
||||
// CHECK-NEXT: value {
|
||||
// CHECK-NEXT: i: 1
|
||||
//
|
||||
// CHECK: node {
|
||||
// CHECK: op: "IdentityN"
|
||||
//
|
||||
// CHECK: node {
|
||||
// CHECK-NEXT: name: "rets_0"
|
||||
// CHECK-NEXT: op: "_Retval"
|
||||
// CHECK-NOT: device
|
||||
// CHECK: attr {
|
||||
// CHECK: key: "index"
|
||||
// CHECK-NEXT: value {
|
||||
// CHECK-NEXT: i: 0
|
||||
//
|
||||
// CHECK: node {
|
||||
// CHECK-NEXT: name: "rets_1"
|
||||
// CHECK-NEXT: op: "_Retval"
|
||||
// CHECK: device: "/CPU:1"
|
||||
// CHECK: attr {
|
||||
// CHECK: key: "index"
|
||||
// CHECK-NEXT: value {
|
||||
// CHECK-NEXT: i: 1
|
||||
// CHECK_NEXT: }
|
@ -250,8 +250,6 @@ StatusOr<std::unique_ptr<NodeDef>> Exporter::GetArgumentNode(
|
||||
return node_def;
|
||||
}
|
||||
|
||||
// TODO(b/160014479): Support exporting function result attributes as optional
|
||||
// attributes.
|
||||
StatusOr<std::unique_ptr<NodeDef>> Exporter::GetReturnNode(
|
||||
mlir::FuncOp function, Value operand, unsigned index,
|
||||
llvm::StringRef name) {
|
||||
@ -272,6 +270,18 @@ StatusOr<std::unique_ptr<NodeDef>> Exporter::GetReturnNode(
|
||||
AttrValue index_attr;
|
||||
index_attr.set_i(index);
|
||||
(*node_def->mutable_attr())["index"] = index_attr;
|
||||
|
||||
if (auto device_attr =
|
||||
function.getResultAttrOfType<mlir::StringAttr>(index, kDeviceAttr))
|
||||
*node_def->mutable_device() = device_attr.getValue().str();
|
||||
|
||||
llvm::ArrayRef<mlir::NamedAttribute> func_res_i_attrs =
|
||||
function.getResultAttrs(index);
|
||||
absl::flat_hash_set<absl::string_view> attrs_to_ignore = {kDeviceAttr};
|
||||
TF_RETURN_IF_ERROR(ConvertAttributes(func_res_i_attrs, attrs_to_ignore,
|
||||
/*remove_ref_type=*/false,
|
||||
node_def->mutable_attr()));
|
||||
|
||||
return node_def;
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user