Return an error if the output index is invalid while importing GraphDef to MLIR

Currently, it crashes in such cases.

PiperOrigin-RevId: 261603894
This commit is contained in:
Smit Hinsu 2019-08-04 18:43:16 -07:00 committed by TensorFlower Gardener
parent cc6e729c9e
commit df79f08595
2 changed files with 17 additions and 4 deletions

View File

@ -0,0 +1,14 @@
# RUN: tf-mlir-translate -graphdef-to-mlir %s -tf-input-arrays=input -tf-input-data-types=DT_FLOAT -tf-input-shapes='' -tf-output-arrays=input:1 -o - 2>&1 | FileCheck %s; test ${PIPESTATUS[0]} -ne 0
# CHECK: Graph import failed: Invalid argument: Invalid output index 1 specified for node: input
node {
name: "input"
op: "Placeholder"
attr {
key: "dtype"
value {
type: DT_FLOAT
}
}
}

View File

@ -1368,10 +1368,9 @@ StatusOr<mlir::FunctionType> Importer::InferMainFunctionType(
// Output nodes as function returns.
for (const auto& ret : *ret_nodes) {
if (ret.node->num_outputs() < 1) {
return errors::FailedPrecondition(
"Invalid output node; should have at least 1 output: " +
ret.node->name());
if (ret.node->num_outputs() <= ret.index) {
return errors::InvalidArgument("Invalid output index ", ret.index,
" specified for node: ", ret.node->name());
}
auto* shape_context = shape_refiner_->GetExtendedContext(ret.node);
TF_ASSIGN_OR_RETURN(auto type,