Export tf.func attributes as AttrValue.NameAttrList.
Add support for exporting #tf.func<@name, {}> attributes as AttrValue.NameAttrList when converting TF MLIR to GraphDef. On import #tf.func may be introduced, modeling AttrValue.NameAttrList (func) attributes in TensorFlow. This updates the export path so round-tripping from Graph -> TF MLIR -> Graph will preserve such attributes properly. PiperOrigin-RevId: 321382977 Change-Id: Ica0aa2eede960e76b69074e0a8fc7f9306dc6a0c
This commit is contained in:
parent
b4cb31ff3d
commit
6c0f93a2dd
@ -0,0 +1,40 @@
|
|||||||
|
// RUN: tf-mlir-translate -mlir-to-graphdef %s | tf-mlir-translate -graphdef-to-mlir | tf-mlir-translate -mlir-to-graphdef | FileCheck %s
|
||||||
|
|
||||||
|
// Tests #tf.func attributes are exported as AttrValue.NameAttrList attributes
|
||||||
|
// with its attr field populated with nested attributes.
|
||||||
|
|
||||||
|
module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, producer = 458 : i32}} {
|
||||||
|
func @main() {
|
||||||
|
tf_executor.graph {
|
||||||
|
%control = tf_executor.island wraps "tf.NoOp"() {_f = #tf.func<@callee, {attr2 = true, attr3 = 8.0 : f32}>} : () -> ()
|
||||||
|
tf_executor.fetch
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
func @callee() {
|
||||||
|
tf_executor.graph {
|
||||||
|
tf_executor.fetch
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK: op: "NoOp"
|
||||||
|
// CHECK-NEXT: attr
|
||||||
|
// CHECK-NEXT: key: "_f"
|
||||||
|
// CHECK-NEXT: value
|
||||||
|
// CHECK-NEXT: func
|
||||||
|
// CHECK-NEXT: name: [[FUNC_NAME:".*"]]
|
||||||
|
// CHECK-NEXT: attr
|
||||||
|
// CHECK-NEXT: key: "attr2"
|
||||||
|
// CHECK-NEXT: value
|
||||||
|
// CHECK-NEXT: b: true
|
||||||
|
// CHECK: attr
|
||||||
|
// CHECK-NEXT: key: "attr3"
|
||||||
|
// CHECK-NEXT: value
|
||||||
|
// CHECK-NEXT: f: 8
|
||||||
|
|
||||||
|
// CHECK: library
|
||||||
|
// CHECK-NEXT: function
|
||||||
|
// CHECK-NEXT: signature
|
||||||
|
// CHECK-NEXT: name: [[FUNC_NAME]]
|
@ -121,6 +121,20 @@ Status ConvertAttribute(const mlir::TF::ShapeAttr& attr, AttrValue* value) {
|
|||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Status ConvertAttribute(const mlir::FlatSymbolRefAttr& attr, AttrValue* value) {
|
||||||
|
value->mutable_func()->set_name(attr.getValue().str());
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
Status ConvertAttribute(const mlir::TF::FuncAttr& attr, AttrValue* value) {
|
||||||
|
TF_RETURN_IF_ERROR(
|
||||||
|
ConvertAttribute(attr.GetName().cast<mlir::FlatSymbolRefAttr>(), value));
|
||||||
|
TF_RETURN_IF_ERROR(ConvertAttributes(attr.GetAttrs().getValue(),
|
||||||
|
/*attrs_to_ignore=*/{},
|
||||||
|
value->mutable_func()->mutable_attr()));
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
Status ConvertAttribute(const mlir::StringAttr& attr, AttrValue* value) {
|
Status ConvertAttribute(const mlir::StringAttr& attr, AttrValue* value) {
|
||||||
absl::string_view attr_value(attr.getValue().data(), attr.getValue().size());
|
absl::string_view attr_value(attr.getValue().data(), attr.getValue().size());
|
||||||
switch (mangling_util::GetMangledKind(attr_value)) {
|
switch (mangling_util::GetMangledKind(attr_value)) {
|
||||||
@ -160,11 +174,6 @@ Status ConvertAttribute(const mlir::UnitAttr& attr, AttrValue* value) {
|
|||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
Status ConvertAttribute(const mlir::FlatSymbolRefAttr& attr, AttrValue* value) {
|
|
||||||
value->mutable_func()->set_name(std::string(attr.getValue()));
|
|
||||||
return Status::OK();
|
|
||||||
}
|
|
||||||
|
|
||||||
Status ConvertAttribute(const mlir::ArrayAttr& attr, AttrValue* value) {
|
Status ConvertAttribute(const mlir::ArrayAttr& attr, AttrValue* value) {
|
||||||
auto* list = value->mutable_list();
|
auto* list = value->mutable_list();
|
||||||
for (mlir::Attribute a : attr.getValue()) {
|
for (mlir::Attribute a : attr.getValue()) {
|
||||||
@ -372,8 +381,8 @@ Status ConvertAttributes(
|
|||||||
AttrValue value;
|
AttrValue value;
|
||||||
switch (attr.getKind()) {
|
switch (attr.getKind()) {
|
||||||
case mlir::StandardAttributes::SymbolRef: {
|
case mlir::StandardAttributes::SymbolRef: {
|
||||||
auto func_attr = attr.cast<mlir::FlatSymbolRefAttr>();
|
TF_RETURN_IF_ERROR(
|
||||||
value.mutable_func()->set_name(std::string(func_attr.getValue()));
|
ConvertAttribute(attr.cast<mlir::FlatSymbolRefAttr>(), &value));
|
||||||
func_call_attrs[string(name)] = value;
|
func_call_attrs[string(name)] = value;
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
@ -415,6 +424,12 @@ Status ConvertAttributes(
|
|||||||
TF_RETURN_IF_ERROR(
|
TF_RETURN_IF_ERROR(
|
||||||
ConvertAttribute(attr.cast<mlir::TF::ShapeAttr>(), &value));
|
ConvertAttribute(attr.cast<mlir::TF::ShapeAttr>(), &value));
|
||||||
break;
|
break;
|
||||||
|
case static_cast<unsigned>(mlir::TF::AttrKind::FUNC): {
|
||||||
|
TF_RETURN_IF_ERROR(
|
||||||
|
ConvertAttribute(attr.cast<mlir::TF::FuncAttr>(), &value));
|
||||||
|
func_call_attrs[string(name)] = value;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
// AffineMap kind is not implemented.
|
// AffineMap kind is not implemented.
|
||||||
case mlir::StandardAttributes::AffineMap:
|
case mlir::StandardAttributes::AffineMap:
|
||||||
return errors::Unimplemented("AffineMap attribute (needed for '",
|
return errors::Unimplemented("AffineMap attribute (needed for '",
|
||||||
|
Loading…
Reference in New Issue
Block a user