Export tf.func attributes as AttrValue.NameAttrList.

Add support for exporting #tf.func<@name, {}> attributes as AttrValue.NameAttrList when converting TF MLIR to GraphDef. On import #tf.func may be introduced, modeling AttrValue.NameAttrList (func) attributes in TensorFlow. This updates the export path so round-tripping from Graph -> TF MLIR -> Graph will preserve such attributes properly.

PiperOrigin-RevId: 321382977
Change-Id: Ica0aa2eede960e76b69074e0a8fc7f9306dc6a0c
This commit is contained in:
Andy Ly 2020-07-15 10:06:27 -07:00 committed by TensorFlower Gardener
parent b4cb31ff3d
commit 6c0f93a2dd
2 changed files with 62 additions and 7 deletions

View File

@ -0,0 +1,40 @@
// RUN: tf-mlir-translate -mlir-to-graphdef %s | tf-mlir-translate -graphdef-to-mlir | tf-mlir-translate -mlir-to-graphdef | FileCheck %s
// Tests #tf.func attributes are exported as AttrValue.NameAttrList attributes
// with its attr field populated with nested attributes.
module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, producer = 458 : i32}} {
func @main() {
tf_executor.graph {
%control = tf_executor.island wraps "tf.NoOp"() {_f = #tf.func<@callee, {attr2 = true, attr3 = 8.0 : f32}>} : () -> ()
tf_executor.fetch
}
return
}
func @callee() {
tf_executor.graph {
tf_executor.fetch
}
return
}
}
// CHECK: op: "NoOp"
// CHECK-NEXT: attr
// CHECK-NEXT: key: "_f"
// CHECK-NEXT: value
// CHECK-NEXT: func
// CHECK-NEXT: name: [[FUNC_NAME:".*"]]
// CHECK-NEXT: attr
// CHECK-NEXT: key: "attr2"
// CHECK-NEXT: value
// CHECK-NEXT: b: true
// CHECK: attr
// CHECK-NEXT: key: "attr3"
// CHECK-NEXT: value
// CHECK-NEXT: f: 8
// CHECK: library
// CHECK-NEXT: function
// CHECK-NEXT: signature
// CHECK-NEXT: name: [[FUNC_NAME]]

View File

@ -121,6 +121,20 @@ Status ConvertAttribute(const mlir::TF::ShapeAttr& attr, AttrValue* value) {
return Status::OK(); return Status::OK();
} }
Status ConvertAttribute(const mlir::FlatSymbolRefAttr& attr, AttrValue* value) {
value->mutable_func()->set_name(attr.getValue().str());
return Status::OK();
}
Status ConvertAttribute(const mlir::TF::FuncAttr& attr, AttrValue* value) {
TF_RETURN_IF_ERROR(
ConvertAttribute(attr.GetName().cast<mlir::FlatSymbolRefAttr>(), value));
TF_RETURN_IF_ERROR(ConvertAttributes(attr.GetAttrs().getValue(),
/*attrs_to_ignore=*/{},
value->mutable_func()->mutable_attr()));
return Status::OK();
}
Status ConvertAttribute(const mlir::StringAttr& attr, AttrValue* value) { Status ConvertAttribute(const mlir::StringAttr& attr, AttrValue* value) {
absl::string_view attr_value(attr.getValue().data(), attr.getValue().size()); absl::string_view attr_value(attr.getValue().data(), attr.getValue().size());
switch (mangling_util::GetMangledKind(attr_value)) { switch (mangling_util::GetMangledKind(attr_value)) {
@ -160,11 +174,6 @@ Status ConvertAttribute(const mlir::UnitAttr& attr, AttrValue* value) {
return Status::OK(); return Status::OK();
} }
Status ConvertAttribute(const mlir::FlatSymbolRefAttr& attr, AttrValue* value) {
value->mutable_func()->set_name(std::string(attr.getValue()));
return Status::OK();
}
Status ConvertAttribute(const mlir::ArrayAttr& attr, AttrValue* value) { Status ConvertAttribute(const mlir::ArrayAttr& attr, AttrValue* value) {
auto* list = value->mutable_list(); auto* list = value->mutable_list();
for (mlir::Attribute a : attr.getValue()) { for (mlir::Attribute a : attr.getValue()) {
@ -372,8 +381,8 @@ Status ConvertAttributes(
AttrValue value; AttrValue value;
switch (attr.getKind()) { switch (attr.getKind()) {
case mlir::StandardAttributes::SymbolRef: { case mlir::StandardAttributes::SymbolRef: {
auto func_attr = attr.cast<mlir::FlatSymbolRefAttr>(); TF_RETURN_IF_ERROR(
value.mutable_func()->set_name(std::string(func_attr.getValue())); ConvertAttribute(attr.cast<mlir::FlatSymbolRefAttr>(), &value));
func_call_attrs[string(name)] = value; func_call_attrs[string(name)] = value;
continue; continue;
} }
@ -415,6 +424,12 @@ Status ConvertAttributes(
TF_RETURN_IF_ERROR( TF_RETURN_IF_ERROR(
ConvertAttribute(attr.cast<mlir::TF::ShapeAttr>(), &value)); ConvertAttribute(attr.cast<mlir::TF::ShapeAttr>(), &value));
break; break;
case static_cast<unsigned>(mlir::TF::AttrKind::FUNC): {
TF_RETURN_IF_ERROR(
ConvertAttribute(attr.cast<mlir::TF::FuncAttr>(), &value));
func_call_attrs[string(name)] = value;
continue;
}
// AffineMap kind is not implemented. // AffineMap kind is not implemented.
case mlir::StandardAttributes::AffineMap: case mlir::StandardAttributes::AffineMap:
return errors::Unimplemented("AffineMap attribute (needed for '", return errors::Unimplemented("AffineMap attribute (needed for '",