diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/invalid-output-index.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/invalid-output-index.pbtxt new file mode 100644 index 00000000000..6fec080be58 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/invalid-output-index.pbtxt @@ -0,0 +1,14 @@ +# RUN: tf-mlir-translate -graphdef-to-mlir %s -tf-input-arrays=input -tf-input-data-types=DT_FLOAT -tf-input-shapes='' -tf-output-arrays=input:1 -o - 2>&1 | FileCheck %s; test ${PIPESTATUS[0]} -ne 0 + +# CHECK: Graph import failed: Invalid argument: Invalid output index 1 specified for node: input + +node { + name: "input" + op: "Placeholder" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } +} diff --git a/tensorflow/compiler/mlir/tensorflow/translate/import_graphdef.cc b/tensorflow/compiler/mlir/tensorflow/translate/import_graphdef.cc index d713131d79e..ed1a2633eae 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/import_graphdef.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/import_graphdef.cc @@ -1368,10 +1368,9 @@ StatusOr Importer::InferMainFunctionType( // Output nodes as function returns. for (const auto& ret : *ret_nodes) { - if (ret.node->num_outputs() < 1) { - return errors::FailedPrecondition( - "Invalid output node; should have at least 1 output: " + - ret.node->name()); + if (ret.node->num_outputs() <= ret.index) { + return errors::InvalidArgument("Invalid output index ", ret.index, + " specified for node: ", ret.node->name()); } auto* shape_context = shape_refiner_->GetExtendedContext(ret.node); TF_ASSIGN_OR_RETURN(auto type,