Return an error if the output index is invalid while importing GraphDef to MLIR
Currently, it crashes in such cases. PiperOrigin-RevId: 261603894
This commit is contained in:
parent
cc6e729c9e
commit
df79f08595
@ -0,0 +1,14 @@
|
||||
# RUN: tf-mlir-translate -graphdef-to-mlir %s -tf-input-arrays=input -tf-input-data-types=DT_FLOAT -tf-input-shapes='' -tf-output-arrays=input:1 -o - 2>&1 | FileCheck %s; test ${PIPESTATUS[0]} -ne 0
|
||||
|
||||
# CHECK: Graph import failed: Invalid argument: Invalid output index 1 specified for node: input
|
||||
|
||||
node {
|
||||
name: "input"
|
||||
op: "Placeholder"
|
||||
attr {
|
||||
key: "dtype"
|
||||
value {
|
||||
type: DT_FLOAT
|
||||
}
|
||||
}
|
||||
}
|
@ -1368,10 +1368,9 @@ StatusOr<mlir::FunctionType> Importer::InferMainFunctionType(
|
||||
|
||||
// Output nodes as function returns.
|
||||
for (const auto& ret : *ret_nodes) {
|
||||
if (ret.node->num_outputs() < 1) {
|
||||
return errors::FailedPrecondition(
|
||||
"Invalid output node; should have at least 1 output: " +
|
||||
ret.node->name());
|
||||
if (ret.node->num_outputs() <= ret.index) {
|
||||
return errors::InvalidArgument("Invalid output index ", ret.index,
|
||||
" specified for node: ", ret.node->name());
|
||||
}
|
||||
auto* shape_context = shape_refiner_->GetExtendedContext(ret.node);
|
||||
TF_ASSIGN_OR_RETURN(auto type,
|
||||
|
Loading…
Reference in New Issue
Block a user