From 6c0f93a2ddc8221b2d6ba87c43804b476a8e1fd2 Mon Sep 17 00:00:00 2001 From: Andy Ly Date: Wed, 15 Jul 2020 10:06:27 -0700 Subject: [PATCH] Export tf.func attributes as AttrValue.NameAttrList. Add support for exporting #tf.func<@name, {}> attributes as AttrValue.NameAttrList when converting TF MLIR to GraphDef. On import #tf.func may be introduced, modeling AttrValue.NameAttrList (func) attributes in TensorFlow. This updates the export path so round-tripping from Graph -> TF MLIR -> Graph will preserve such attributes properly. PiperOrigin-RevId: 321382977 Change-Id: Ica0aa2eede960e76b69074e0a8fc7f9306dc6a0c --- .../tests/mlir2graphdef/func_attr.mlir | 40 +++++++++++++++++++ .../mlir/tensorflow/utils/export_utils.cc | 29 ++++++++++---- 2 files changed, 62 insertions(+), 7 deletions(-) create mode 100644 tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/func_attr.mlir diff --git a/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/func_attr.mlir b/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/func_attr.mlir new file mode 100644 index 00000000000..fadb62c44b8 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/func_attr.mlir @@ -0,0 +1,40 @@ +// RUN: tf-mlir-translate -mlir-to-graphdef %s | tf-mlir-translate -graphdef-to-mlir | tf-mlir-translate -mlir-to-graphdef | FileCheck %s + +// Tests #tf.func attributes are exported as AttrValue.NameAttrList attributes +// with its attr field populated with nested attributes. + +module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, producer = 458 : i32}} { + func @main() { + tf_executor.graph { + %control = tf_executor.island wraps "tf.NoOp"() {_f = #tf.func<@callee, {attr2 = true, attr3 = 8.0 : f32}>} : () -> () + tf_executor.fetch + } + return + } + func @callee() { + tf_executor.graph { + tf_executor.fetch + } + return + } +} + +// CHECK: op: "NoOp" +// CHECK-NEXT: attr +// CHECK-NEXT: key: "_f" +// CHECK-NEXT: value +// CHECK-NEXT: func +// CHECK-NEXT: name: [[FUNC_NAME:".*"]] +// CHECK-NEXT: attr +// CHECK-NEXT: key: "attr2" +// CHECK-NEXT: value +// CHECK-NEXT: b: true +// CHECK: attr +// CHECK-NEXT: key: "attr3" +// CHECK-NEXT: value +// CHECK-NEXT: f: 8 + +// CHECK: library +// CHECK-NEXT: function +// CHECK-NEXT: signature +// CHECK-NEXT: name: [[FUNC_NAME]] diff --git a/tensorflow/compiler/mlir/tensorflow/utils/export_utils.cc b/tensorflow/compiler/mlir/tensorflow/utils/export_utils.cc index 852bc72d7de..7e018966396 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/export_utils.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/export_utils.cc @@ -121,6 +121,20 @@ Status ConvertAttribute(const mlir::TF::ShapeAttr& attr, AttrValue* value) { return Status::OK(); } +Status ConvertAttribute(const mlir::FlatSymbolRefAttr& attr, AttrValue* value) { + value->mutable_func()->set_name(attr.getValue().str()); + return Status::OK(); +} + +Status ConvertAttribute(const mlir::TF::FuncAttr& attr, AttrValue* value) { + TF_RETURN_IF_ERROR( + ConvertAttribute(attr.GetName().cast(), value)); + TF_RETURN_IF_ERROR(ConvertAttributes(attr.GetAttrs().getValue(), + /*attrs_to_ignore=*/{}, + value->mutable_func()->mutable_attr())); + return Status::OK(); +} + Status ConvertAttribute(const mlir::StringAttr& attr, AttrValue* value) { absl::string_view attr_value(attr.getValue().data(), attr.getValue().size()); switch (mangling_util::GetMangledKind(attr_value)) { @@ -160,11 +174,6 @@ Status ConvertAttribute(const mlir::UnitAttr& attr, AttrValue* value) { return Status::OK(); } -Status ConvertAttribute(const mlir::FlatSymbolRefAttr& attr, AttrValue* value) { - value->mutable_func()->set_name(std::string(attr.getValue())); - return Status::OK(); -} - Status ConvertAttribute(const mlir::ArrayAttr& attr, AttrValue* value) { auto* list = value->mutable_list(); for (mlir::Attribute a : attr.getValue()) { @@ -372,8 +381,8 @@ Status ConvertAttributes( AttrValue value; switch (attr.getKind()) { case mlir::StandardAttributes::SymbolRef: { - auto func_attr = attr.cast(); - value.mutable_func()->set_name(std::string(func_attr.getValue())); + TF_RETURN_IF_ERROR( + ConvertAttribute(attr.cast(), &value)); func_call_attrs[string(name)] = value; continue; } @@ -415,6 +424,12 @@ Status ConvertAttributes( TF_RETURN_IF_ERROR( ConvertAttribute(attr.cast(), &value)); break; + case static_cast(mlir::TF::AttrKind::FUNC): { + TF_RETURN_IF_ERROR( + ConvertAttribute(attr.cast(), &value)); + func_call_attrs[string(name)] = value; + continue; + } // AffineMap kind is not implemented. case mlir::StandardAttributes::AffineMap: return errors::Unimplemented("AffineMap attribute (needed for '",