Remove some remaining references to the TF control dialect
This dialect was removed a few weeks ago, but some special handling remains in a few places. This cleans up part of the export path.

PiperOrigin-RevId: 318702263
Change-Id: I8ea70062bbff3d65e30a3aedb2a2bcc1efa7fc3c
parent 9b13f1b47f
commit bd006c354f
@@ -131,39 +131,14 @@ StatusOr<std::unique_ptr<NodeDef>> ConvertTFDialectOpToNodeDef(
   // Use auto generated function to populate derived attribute.
   //
   // Note: This only populates derived attributes for TensorFlow ops that are
-  // generated using the TableGen. Manually defined ops and TF ops with control
-  // edges (i.e TF op names with leading '_' in names) should have all the
+  // generated using the TableGen. Manually defined ops should have all the
   // attributes present as native MLIR op attributes.
-
-  // If the operation is in the TensorFlow control dialect, we create a
-  // temporary copy in the TensorFlow dialect. This is needed because we
-  // auto-generated the registration for TensorFlow dialect only.
-  // TODO(aminim): this is only done while we're using the TF control dialect
-  // as a temporary stage when exporting to GraphDef. Remove when we update the
-  // export.
-  auto erase_clone = [](mlir::Operation* op) { op->erase(); };
-  std::unique_ptr<mlir::Operation, decltype(erase_clone)> cloned_inst(
-      nullptr, erase_clone);
-  if (inst->getDialect() && inst->getDialect()->getNamespace() == "_tf") {
-    mlir::OperationState result(inst->getLoc(),
-                                inst->getName().getStringRef().drop_front());
-    for (mlir::Value operand : inst->getOperands())
-      result.operands.push_back(operand);
-
-    // Add a result type for each non-control result we find
-    for (mlir::Type result_type : inst->getResultTypes())
-      result.types.push_back(result_type);
-    cloned_inst.reset(mlir::Operation::create(result));
-    cloned_inst->setAttrs(inst->getAttrs());
-    inst = cloned_inst.get();
-  }
-
   // The elements are owned by the MLIRContext.
   absl::flat_hash_set<absl::string_view> attrs_to_ignore;
   if (inst->isRegistered()) {
     // We ignore attributes attached to the operation when there is already a
     // derived attribute defined in ODS.
     // TODO(aminim) replace absl::flat_hash_set with a SmallDenseSet.
     llvm::SmallDenseSet<llvm::StringRef> derived_attrs;
     CollectDerivedAttrsName(inst, &derived_attrs);
     for (auto name : derived_attrs) attrs_to_ignore.insert(name.data());
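
Illustrative sketch, not part of the change: the block deleted above recognized ops from the now-removed "_tf" control dialect by their dialect namespace and re-created them in the regular TF dialect before export. A minimal helper showing that namespace check (the helper name is made up for illustration) would look like the following; with the control dialect gone, no such check or temporary clone is needed in the export path.

#include "mlir/IR/Operation.h"

// Returns true if `op` belongs to the removed "_tf" control dialect.
static bool IsControlDialectOp(mlir::Operation* op) {
  return op->getDialect() && op->getDialect()->getNamespace() == "_tf";
}
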
@@ -198,10 +173,8 @@ StatusOr<std::unique_ptr<NodeDef>> ConvertTFDialectOpToNodeDef(
                                    inst->getName().getStringRef().str());
   }
 
-  // If the instruction is in the TF dialect, the code above already filtered
-  // results with control types. Here we only add the shapes for the leading
-  // values with ShapedType, assuming values with non-ShapedType are put at the
-  // end of the result.
+  // Here we only add the shapes for the leading values with ShapedType,
+  // assuming values with non-ShapedType are put at the end of the result.
   if (!ignore_unregistered_attrs && inst->getNumResults() > 0) {
     auto values = inst->getResults();
     auto begin = values.begin();
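
Illustrative sketch, not part of the change, of the shape-export rule described in the new comment: only the leading results whose type is a ShapedType get output shapes recorded, and iteration stops at the first non-shaped result, which is assumed to sit at the end (e.g. a control token). The helper name and the StandardTypes.h include are assumptions for the MLIR revision of this era.

#include "mlir/IR/Operation.h"
#include "mlir/IR/StandardTypes.h"  // ShapedType (header location assumed for this MLIR revision).

// Counts how many leading results carry a ShapedType; shapes would only be
// exported for these.
static unsigned CountLeadingShapedResults(mlir::Operation* inst) {
  unsigned count = 0;
  for (mlir::Type type : inst->getResultTypes()) {
    if (!type.isa<mlir::ShapedType>()) break;
    ++count;
  }
  return count;
}
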
@ -64,7 +64,6 @@ std::set<std::string>* GlobalOpPrefixes() {
|
||||
static std::set<std::string>* global_op_prefixes = [] {
|
||||
std::set<std::string>* result = new std::set<std::string>;
|
||||
result->insert("tf.");
|
||||
result->insert("_tf.");
|
||||
result->insert("tf_executor.");
|
||||
return result;
|
||||
}();
|
||||
@@ -276,7 +275,7 @@ StatusOr<llvm::StringRef> GetTensorFlowOpName(llvm::StringRef op_name) {
   // When being converted to MLIR, some prefixes and suffixes are added to the
   // operation types, and we have to remove them when converting the
   // operations back to a graph:
-  // - "_tf.", "tf." or "tf_executor." : every operation type has this prefix.
+  // - "tf." or "tf_executor." : every operation type has this prefix.
   // - ".sink" or ".Sink": only the NextIteration operation has this suffix. We
   //   don't need to consider ".source"/".Source" because the nodes with this
   //   suffix are skipped by the caller and will not be added to the graph.
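
Illustrative sketch, not part of the change, of the prefix/suffix stripping the comment above describes. The helper name is invented and the real GetTensorFlowOpName handles more cases; this only shows the llvm::StringRef manipulation, with "_tf." dropped from the accepted prefixes.

#include "llvm/ADT/StringRef.h"

// Strips the dialect prefix and the NextIteration ".sink"/".Sink" suffix from
// an MLIR operation name, e.g. "tf_executor.NextIteration.Sink" -> "NextIteration".
static llvm::StringRef StripDialectPrefixAndSinkSuffix(llvm::StringRef op_name) {
  if (!op_name.consume_front("tf.")) op_name.consume_front("tf_executor.");
  if (!op_name.consume_back(".sink")) op_name.consume_back(".Sink");
  return op_name;
}
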
@@ -313,9 +312,8 @@ StatusOr<std::unique_ptr<NodeDef>> GetOperationNodeDef(
   // Some control flow ops in TensorFlow Graph have their respective "Ref" ops
   // as well. For example there is Enter and RefEnter op. RefEnter forwards
   // the input ref buffer to output. However both Enter and RefEnter are
-  // mapped to tf_executor::EnterOp during import and then to _tf.Enter op in
-  // control dialect. Check if it is a Ref op to correctly map to the
-  // TensorFlow Graph op.
+  // mapped to tf_executor::EnterOp during import. Check if it is a Ref op to
+  // correctly map to the TensorFlow Graph op.
   if (IsRefTypeControlOp(inst)) op_name = "Ref";
   TF_ASSIGN_OR_RETURN(auto tf_name,
                       GetTensorFlowOpName(inst->getName().getStringRef()));
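
Illustrative sketch, not part of the change: since both Enter and RefEnter import as tf_executor::EnterOp, the exporter prepends "Ref" when the op actually forwards a ref tensor. A minimal example of that name assembly (helper name and signature are made up):

#include <string>
#include "llvm/ADT/StringRef.h"

// e.g. ("Enter", /*is_ref_op=*/true) -> "RefEnter", ("Enter", false) -> "Enter".
static std::string GraphOpName(llvm::StringRef tf_name, bool is_ref_op) {
  std::string op_name = is_ref_op ? "Ref" : "";
  op_name.append(tf_name.data(), tf_name.size());
  return op_name;
}
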
@@ -516,8 +514,7 @@ Status SetSizeAttribute(absl::string_view name, size_t size,
 }
 
 bool IsLegacyCallInstruction(mlir::Operation* inst) {
-  return llvm::dyn_cast<mlir::TF::LegacyCallOp>(inst) ||
-         inst->getName().getStringRef().compare("_tf.LegacyCall") == 0;
+  return llvm::dyn_cast<mlir::TF::LegacyCallOp>(inst);
 }
 
 Status AddTensorFlowOpPrefix(std::string prefix) {
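
Illustrative usage sketch, not part of the change: with the "_tf.LegacyCall" string comparison gone, a legacy call is recognized purely through the registered TF dialect op, so an equivalent caller-side check is an isa<> test (the tf_ops.h include path is an assumption).

#include "llvm/Support/Casting.h"
#include "mlir/IR/Operation.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"  // mlir::TF::LegacyCallOp (path assumed).

// True if `inst` is the TF dialect LegacyCall op.
static bool IsLegacyCall(mlir::Operation* inst) {
  return llvm::isa<mlir::TF::LegacyCallOp>(inst);
}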