Export tf.func attributes as AttrValue.NameAttrList.
Add support for exporting #tf.func<@name, {}> attributes as AttrValue.NameAttrList when converting TF MLIR to GraphDef. On import #tf.func may be introduced, modeling AttrValue.NameAttrList (func) attributes in TensorFlow. This updates the export path so round-tripping from Graph -> TF MLIR -> Graph will preserve such attributes properly. PiperOrigin-RevId: 321382977 Change-Id: Ica0aa2eede960e76b69074e0a8fc7f9306dc6a0c
This commit is contained in:
parent
b4cb31ff3d
commit
6c0f93a2dd
@ -0,0 +1,40 @@
|
||||
// RUN: tf-mlir-translate -mlir-to-graphdef %s | tf-mlir-translate -graphdef-to-mlir | tf-mlir-translate -mlir-to-graphdef | FileCheck %s
|
||||
|
||||
// Tests #tf.func attributes are exported as AttrValue.NameAttrList attributes
|
||||
// with its attr field populated with nested attributes.
|
||||
|
||||
module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, producer = 458 : i32}} {
|
||||
func @main() {
|
||||
tf_executor.graph {
|
||||
%control = tf_executor.island wraps "tf.NoOp"() {_f = #tf.func<@callee, {attr2 = true, attr3 = 8.0 : f32}>} : () -> ()
|
||||
tf_executor.fetch
|
||||
}
|
||||
return
|
||||
}
|
||||
func @callee() {
|
||||
tf_executor.graph {
|
||||
tf_executor.fetch
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// CHECK: op: "NoOp"
|
||||
// CHECK-NEXT: attr
|
||||
// CHECK-NEXT: key: "_f"
|
||||
// CHECK-NEXT: value
|
||||
// CHECK-NEXT: func
|
||||
// CHECK-NEXT: name: [[FUNC_NAME:".*"]]
|
||||
// CHECK-NEXT: attr
|
||||
// CHECK-NEXT: key: "attr2"
|
||||
// CHECK-NEXT: value
|
||||
// CHECK-NEXT: b: true
|
||||
// CHECK: attr
|
||||
// CHECK-NEXT: key: "attr3"
|
||||
// CHECK-NEXT: value
|
||||
// CHECK-NEXT: f: 8
|
||||
|
||||
// CHECK: library
|
||||
// CHECK-NEXT: function
|
||||
// CHECK-NEXT: signature
|
||||
// CHECK-NEXT: name: [[FUNC_NAME]]
|
@ -121,6 +121,20 @@ Status ConvertAttribute(const mlir::TF::ShapeAttr& attr, AttrValue* value) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status ConvertAttribute(const mlir::FlatSymbolRefAttr& attr, AttrValue* value) {
|
||||
value->mutable_func()->set_name(attr.getValue().str());
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status ConvertAttribute(const mlir::TF::FuncAttr& attr, AttrValue* value) {
|
||||
TF_RETURN_IF_ERROR(
|
||||
ConvertAttribute(attr.GetName().cast<mlir::FlatSymbolRefAttr>(), value));
|
||||
TF_RETURN_IF_ERROR(ConvertAttributes(attr.GetAttrs().getValue(),
|
||||
/*attrs_to_ignore=*/{},
|
||||
value->mutable_func()->mutable_attr()));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status ConvertAttribute(const mlir::StringAttr& attr, AttrValue* value) {
|
||||
absl::string_view attr_value(attr.getValue().data(), attr.getValue().size());
|
||||
switch (mangling_util::GetMangledKind(attr_value)) {
|
||||
@ -160,11 +174,6 @@ Status ConvertAttribute(const mlir::UnitAttr& attr, AttrValue* value) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status ConvertAttribute(const mlir::FlatSymbolRefAttr& attr, AttrValue* value) {
|
||||
value->mutable_func()->set_name(std::string(attr.getValue()));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status ConvertAttribute(const mlir::ArrayAttr& attr, AttrValue* value) {
|
||||
auto* list = value->mutable_list();
|
||||
for (mlir::Attribute a : attr.getValue()) {
|
||||
@ -372,8 +381,8 @@ Status ConvertAttributes(
|
||||
AttrValue value;
|
||||
switch (attr.getKind()) {
|
||||
case mlir::StandardAttributes::SymbolRef: {
|
||||
auto func_attr = attr.cast<mlir::FlatSymbolRefAttr>();
|
||||
value.mutable_func()->set_name(std::string(func_attr.getValue()));
|
||||
TF_RETURN_IF_ERROR(
|
||||
ConvertAttribute(attr.cast<mlir::FlatSymbolRefAttr>(), &value));
|
||||
func_call_attrs[string(name)] = value;
|
||||
continue;
|
||||
}
|
||||
@ -415,6 +424,12 @@ Status ConvertAttributes(
|
||||
TF_RETURN_IF_ERROR(
|
||||
ConvertAttribute(attr.cast<mlir::TF::ShapeAttr>(), &value));
|
||||
break;
|
||||
case static_cast<unsigned>(mlir::TF::AttrKind::FUNC): {
|
||||
TF_RETURN_IF_ERROR(
|
||||
ConvertAttribute(attr.cast<mlir::TF::FuncAttr>(), &value));
|
||||
func_call_attrs[string(name)] = value;
|
||||
continue;
|
||||
}
|
||||
// AffineMap kind is not implemented.
|
||||
case mlir::StandardAttributes::AffineMap:
|
||||
return errors::Unimplemented("AffineMap attribute (needed for '",
|
||||
|
Loading…
Reference in New Issue
Block a user