Merge pull request #24381 from trevor-m:tmorris_tftrt_undefined_shapes_static
PiperOrigin-RevId: 225848626
This commit is contained in:
commit
9f94b13cfa
@ -582,6 +582,18 @@ tensorflow::Status CreateTRTNode(const std::vector<EngineInfo>& infos, int pos,
|
||||
}
|
||||
input_shape_protos.at(conn.port_number) = in_shape;
|
||||
input_shapes.at(conn.port_number) = conn.outside_shape;
|
||||
// Shape must be fully defined (excluding batch dimension) for static
|
||||
// mode.
|
||||
if (info.engine_type == EngineInfo::EngineType::TRTStatic) {
|
||||
for (int i = 1; i < conn.outside_shape.dims(); i++) {
|
||||
if (conn.outside_shape.dim_size(i) <= 0) {
|
||||
return tensorflow::errors::Internal(
|
||||
"Input shapes must be fully defined when in static mode. "
|
||||
"Please try is_dynamic_op=True (shape was ",
|
||||
conn.outside_shape.DebugString(), ")");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Rewrire data input if it's not found in original graph.
|
||||
tensorflow::Node* input_node = graph->FindNodeId(conn.outside_id);
|
||||
|
Loading…
Reference in New Issue
Block a user