Ignore extra attributes when the tf dialect op is exported to NodeDef for eager

execution

The imported GraphDef might have extra attributes which are not part of the op
registration but added by other tools. The existence of these attributes might
cause error when they are added to the eager op for constant folding. So we
want to ignore them when the op is converted to an NodeDef during the constant folding.

PiperOrigin-RevId: 264691406
This commit is contained in:
Feng Liu 2019-08-21 14:21:57 -07:00 committed by TensorFlower Gardener
parent 9dc6644a31
commit ce6364764f
10 changed files with 99 additions and 22 deletions

View File

@ -314,8 +314,8 @@ static std::unique_ptr<::tensorflow::NodeDef> getTensorFlowNodeDef(
// We pass empty string for the original node_def name since Flex runtime
// does not care about this being set correctly on node_def. There is no
// "easy" (see b/120948529) way yet to get this from MLIR inst.
auto status_or_node_def =
tensorflow::ConvertTFDialectOpToNodeDef(inst, /*name=*/"");
auto status_or_node_def = tensorflow::ConvertTFDialectOpToNodeDef(
inst, /*name=*/"", /*ignore_unregistered_attrs=*/true);
if (!status_or_node_def.ok()) {
inst->emitOpError(
Twine("failed to obtain TensorFlow nodedef with status: " +

View File

@ -290,6 +290,8 @@ cc_library(
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_proto_cc",
"//tensorflow/stream_executor/lib",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/strings",
"@llvm//:support",
"@local_config_mlir//:IR",
],

View File

@ -63,3 +63,17 @@ func @testSideEffectOp() -> tensor<3xf32> {
// CHECK: return %[[random]]
return %1: tensor<3xf32>
}
// Ops with unimplemnted attributes which couldn't be added to the TFE_Op.
// CHECK-LABEL: func @testUnimplementedOp() -> (tensor<i32>, tensor<i32>)
func @testUnimplementedOp() -> (tensor<i32>, tensor<i32>) {
%0 = constant dense<1> : tensor<i32>
%1 = constant dense<2> : tensor<i32>
%2 = "tf.Maximum"(%0, %1) {_output_shapes = ["tfshape$"]} : (tensor<i32>, tensor<i32>) -> tensor<i32>
%3 = "tf.Minimum"(%0, %1) {random_attr = "hello"} : (tensor<i32>, tensor<i32>) -> tensor<i32>
return %2, %3: tensor<i32>, tensor<i32>
// CHECK-NEXT: %[[CST:.*]] = constant
// CHECK-NEXT: %[[CST1:.*]] = constant
// CHECK-NEXT: return %[[CST]], %[[CST1]]
}

View File

@ -300,13 +300,17 @@ Status Exporter::AddInstructionNode(mlir::Operation* inst) {
// check is too conservative given we could use a OpDef.
if (auto abstract_op = inst->getAbstractOperation()) {
if (&abstract_op->dialect == tf_dialect_) {
TF_ASSIGN_OR_RETURN(node_def, ConvertTFDialectOpToNodeDef(inst, name));
TF_ASSIGN_OR_RETURN(
node_def, ConvertTFDialectOpToNodeDef(
inst, name, /*ignore_unregistered_attrs=*/false));
}
}
// Convert TF control flow dialect ops.
if (!node_def) {
TF_ASSIGN_OR_RETURN(node_def,
GetOperationNodeDef(inst, name.c_str(), getTFOpName));
absl::flat_hash_set<absl::string_view> attrs_to_ignore;
TF_ASSIGN_OR_RETURN(
node_def, GetOperationNodeDef(attrs_to_ignore, inst, name.c_str(),
getTFOpName));
}
Node* node = graph_->AddNode(*node_def, &status);
TF_RETURN_IF_ERROR(status);
@ -562,7 +566,8 @@ Status Exporter::ConvertLibFunction(const ExporterConfigs& configs,
// Ignore the gradient and is_stateful attribute on the function as they have
// been handled above.
absl::flat_hash_set<string> attrs_to_ignore = {grad_string, stateful_string};
absl::flat_hash_set<absl::string_view> attrs_to_ignore = {
grad_string.data(), stateful_string.data()};
llvm::SmallVector<mlir::NamedAttribute, 8> funcAttrs(
function.getDialectAttrs());
TF_RETURN_IF_ERROR(

View File

@ -15,6 +15,9 @@ limitations under the License.
#include "tensorflow/compiler/mlir/tensorflow/translate/export_tf_dialect_op.h"
#include "absl/container/flat_hash_set.h"
#include "absl/strings/string_view.h"
#include "llvm/ADT/StringSet.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/export_utils.h"
@ -65,7 +68,7 @@ Status SetAttribute(absl::string_view name, ContainerT types,
// definitions and isn't a header file.
#include "tensorflow/compiler/mlir/tensorflow/translate/derived_attr_populator.inc"
static StatusOr<string> getTensorFlowOpName(llvm::StringRef op_name) {
StatusOr<string> getTensorFlowOpName(llvm::StringRef op_name) {
if (!op_name.consume_front("tf.")) {
return errors::FailedPrecondition("op name not prefixed with 'tf.': " +
op_name.str());
@ -73,12 +76,54 @@ static StatusOr<string> getTensorFlowOpName(llvm::StringRef op_name) {
return op_name.str();
}
// Collect all the unregistered attributes for an TF dialect operation.
// Attributes "name" and "device" are not included because they are not part
// of an TF op attributes.
Status GetUnregisteredAttrs(
mlir::Operation* inst,
absl::flat_hash_set<absl::string_view>* attrs_to_ignore) {
TF_ASSIGN_OR_RETURN(auto op_name,
getTensorFlowOpName(inst->getName().getStringRef()));
const tensorflow::OpRegistrationData* op_reg_data;
auto status = tensorflow::OpRegistry::Global()->LookUp(op_name, &op_reg_data);
if (!status.ok()) {
// This is likely a function call node, so we should continue.
VLOG(1) << status.ToString();
return Status::OK();
}
// Collect all the registered attributes.
llvm::DenseSet<llvm::StringRef> registered_attrs;
registered_attrs.insert("name");
registered_attrs.insert("device");
for (const auto& attr_def : op_reg_data->op_def.attr()) {
registered_attrs.insert(attr_def.name());
}
// Attributes are not in the registered attributes set will be ignored.
for (auto& attr : inst->getAttrs()) {
auto attr_name = attr.first.c_str();
if (registered_attrs.find(attr_name) == registered_attrs.end()) {
attrs_to_ignore->insert(attr_name);
}
}
return Status::OK();
}
} // namespace
StatusOr<std::unique_ptr<NodeDef>> ConvertTFDialectOpToNodeDef(
mlir::Operation* inst, llvm::StringRef name) {
TF_ASSIGN_OR_RETURN(auto node_def,
GetOperationNodeDef(inst, name, getTensorFlowOpName));
mlir::Operation* inst, llvm::StringRef name,
bool ignore_unregistered_attrs) {
// The elements are owned by the MLIRContext.
absl::flat_hash_set<absl::string_view> attrs_to_ignore;
if (ignore_unregistered_attrs) {
TF_RETURN_IF_ERROR(GetUnregisteredAttrs(inst, &attrs_to_ignore));
}
TF_ASSIGN_OR_RETURN(
auto node_def,
GetOperationNodeDef(attrs_to_ignore, inst, name, getTensorFlowOpName));
// Use auto generated function to populate derived attribute.
//

View File

@ -24,9 +24,13 @@ limitations under the License.
namespace tensorflow {
// Converts an MLIR operation to TensorFlow NodeDef with given node name. This
// name should be unique to the graph it is being inserted to.
// name should be unique to the graph it is being inserted to. If the
// `ignore_unregistered_attrs` argument is set to true, the attributes which are
// not in the op registry will be ignored. Set it to true if the returned
// NodeDef will be excuted by the linked TF Eager runtime.
stream_executor::port::StatusOr<std::unique_ptr<NodeDef>>
ConvertTFDialectOpToNodeDef(mlir::Operation* inst, llvm::StringRef name);
ConvertTFDialectOpToNodeDef(mlir::Operation* inst, llvm::StringRef name,
bool ignore_unregistered_attrs);
} // namespace tensorflow

View File

@ -59,7 +59,8 @@ static LogicalResult MlirToTfNodeDef(ModuleOp module,
return failure();
}
auto node_def_or = tensorflow::ConvertTFDialectOpToNodeDef(op, "node_name");
auto node_def_or = tensorflow::ConvertTFDialectOpToNodeDef(
op, "node_name", /*ignore_unregistered_attrs=*/false);
if (!node_def_or.ok()) {
op->emitError("failed to convert to TF NodeDef:")
<< node_def_or.status().ToString();

View File

@ -78,7 +78,8 @@ mlir::LogicalResult EvaluateOperation(
if (auto attr = inst->getAttrOfType<mlir::StringAttr>("name")) {
node_name = attr.getValue();
}
auto node_def_or = ConvertTFDialectOpToNodeDef(inst, node_name.c_str());
auto node_def_or = ConvertTFDialectOpToNodeDef(
inst, node_name.c_str(), /*ignore_unregistered_attrs=*/true);
RETURN_FAILURE_IF_ERROR(node_def_or.status());
const auto& node_def = node_def_or.ValueOrDie();
TFE_Op* op = TFE_NewOp(context, node_def->op().c_str(), status);

View File

@ -187,6 +187,7 @@ void UpdateCompositeWhileOp(NodeDef* node_def) {
} // anonymous namespace
StatusOr<std::unique_ptr<NodeDef>> GetOperationNodeDef(
const absl::flat_hash_set<absl::string_view>& attrs_to_ignore,
mlir::Operation* inst, llvm::StringRef name,
OpNameMappingFunc op_name_func) {
auto node_def = absl::make_unique<NodeDef>();
@ -208,7 +209,6 @@ StatusOr<std::unique_ptr<NodeDef>> GetOperationNodeDef(
}
// Add the node attributes.
absl::flat_hash_set<string> attrs_to_ignore;
TF_RETURN_WITH_CONTEXT_IF_ERROR(
ConvertAttributes(inst->getAttrs(), attrs_to_ignore,
node_def->mutable_attr()),
@ -224,9 +224,10 @@ StatusOr<std::unique_ptr<NodeDef>> GetOperationNodeDef(
return node_def;
}
Status ConvertAttributes(const llvm::ArrayRef<mlir::NamedAttribute> attrs,
const absl::flat_hash_set<string>& attrs_to_ignore,
AttrValueMap* values) {
Status ConvertAttributes(
const llvm::ArrayRef<mlir::NamedAttribute> attrs,
const absl::flat_hash_set<absl::string_view>& attrs_to_ignore,
AttrValueMap* values) {
AttrValueMap func_call_attrs;
for (const mlir::NamedAttribute& named_attr : attrs) {
auto name_strref = named_attr.first.str();

View File

@ -43,17 +43,21 @@ using OpNameMappingFunc = std::function<StatusOr<std::string>(llvm::StringRef)>;
// Converts an MLIR operation to TensorFlow NodeDef with given node name. This
// name should be unique to the graph it is being inserted into. `op_name_func`
// is to map the op name of `inst` to its op name in TensorFlow.
// is to map the op name of `inst` to its op name in TensorFlow. "name" and
// "device" attributes are ignored by default. Use attrs_to_ignore to specify
// any other attributes that should be ignored.
StatusOr<std::unique_ptr<NodeDef>> GetOperationNodeDef(
const absl::flat_hash_set<absl::string_view>& attrs_to_ignore,
mlir::Operation* inst, llvm::StringRef name,
OpNameMappingFunc op_name_func);
// Converts MLIR attributes with values to their tensorflow equivalent.
// "name" and "device" attributes are ignored by default. Use attrs_to_ignore to
// specify any other attributes that should be ignored.
Status ConvertAttributes(const llvm::ArrayRef<mlir::NamedAttribute> attrs,
const absl::flat_hash_set<string>& attrs_to_ignore,
AttrValueMap* values);
Status ConvertAttributes(
const llvm::ArrayRef<mlir::NamedAttribute> attrs,
const absl::flat_hash_set<absl::string_view>& attrs_to_ignore,
AttrValueMap* values);
// Sets type attribute with the given name. If the attribute already exists with
// a different value, returns an error.