Merge pull request #24381 from trevor-m:tmorris_tftrt_undefined_shapes_static

PiperOrigin-RevId: 225848626
This commit is contained in:
TensorFlower Gardener 2018-12-17 10:32:44 -08:00
commit 9f94b13cfa

View File

@ -582,6 +582,18 @@ tensorflow::Status CreateTRTNode(const std::vector<EngineInfo>& infos, int pos,
}
input_shape_protos.at(conn.port_number) = in_shape;
input_shapes.at(conn.port_number) = conn.outside_shape;
// Shape must be fully defined (excluding batch dimension) for static
// mode.
if (info.engine_type == EngineInfo::EngineType::TRTStatic) {
for (int i = 1; i < conn.outside_shape.dims(); i++) {
if (conn.outside_shape.dim_size(i) <= 0) {
return tensorflow::errors::Internal(
"Input shapes must be fully defined when in static mode. "
"Please try is_dynamic_op=True (shape was ",
conn.outside_shape.DebugString(), ")");
}
}
}
// Rewrire data input if it's not found in original graph.
tensorflow::Node* input_node = graph->FindNodeId(conn.outside_id);