diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/feed-control-dep.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/feed-control-dep.pbtxt new file mode 100644 index 00000000000..258d2059fc4 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/feed-control-dep.pbtxt @@ -0,0 +1,68 @@ +# RUN: tf-mlir-translate -graphdef-to-mlir %s -tf-input-arrays=input -tf-input-data-types=DT_FLOAT -tf-output-arrays=output_node -o - | FileCheck %s --dump-input=fail + +node { + name: "input" + op: "Placeholder" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 1 + } + } + } + } +} +node { + name: "variable_node" + op: "Const" + input: "^input" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 1 + } + } + } +} +node { + name: "output_node" + op: "Identity" + input: "variable_node" + input: "^input" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +library { +} +versions { +} + +# CHECK: func @main(%[[ARG_0:[a-z0-9]+]]: tensor<f32>) -> tensor<f32> +# CHECK-NEXT: tf.entry_function = {inputs = "input", outputs = "output_node"} +# CHECK: %[[GRAPH:[0-9]+]] = tf_executor.graph +# CHECK: %[[CONST:[0-9]+]]:2 = tf_executor.island wraps "tf.Const"() +# CHECK: %[[OUTPUT:[0-9]+]]:2 = tf_executor.island wraps "tf.Identity"(%[[CONST]]#0) +# CHECK: tf_executor.fetch %[[OUTPUT]]#0 +# CHECK: return %[[GRAPH]] diff --git a/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc b/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc index 4acbfe98bde..36c365577e8 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc @@ -1064,6 +1064,11 @@ Status ImporterBase::ConvertFunctionArgAndRets( // Collect mapping of OutputTensor to associated block arg. arg_nodes_to_values.try_emplace({arg_node.node, arg_node.index}, arg_def); island->getResult(0)->replaceAllUsesWith(arg_def); + // Erase control outputs from feed. + auto control_uses = island->getResult(1)->getUses(); + for (auto& control_use : llvm::make_early_inc_range(control_uses)) + control_use.getOwner()->eraseOperand(control_use.getOperandNumber()); + island->dropAllReferences(); island->erase(); continue;