Cleanup trt engine op definition.

PiperOrigin-RevId: 238484967
This commit is contained in:
Guangda Lai 2019-03-14 11:35:35 -07:00 committed by TensorFlower Gardener
parent 852d3364e6
commit 7239d65880

View File

@ -24,12 +24,9 @@ limitations under the License.
namespace tensorflow { namespace tensorflow {
namespace shape_inference { // NOTE: when making changes please follow
extern Status TRTEngineOpShapeInference(InferenceContext* c); // https://www.tensorflow.org/guide/extend/op#backwards_compatibility to not
} // break backward compatibility.
// NOTE: please try NOT to add/modify/remove attributes or inputs/outputs to the
// list below, this will break backward compatibility!
// //
// TODO(laigd): consider making this op stateful. The only problem is it uses TF // TODO(laigd): consider making this op stateful. The only problem is it uses TF
// function which has to be stateless, but we can use function library as the // function which has to be stateless, but we can use function library as the
@ -41,8 +38,6 @@ REGISTER_OP("TRTEngineOp")
.Attr("segment_funcdef_name: string") .Attr("segment_funcdef_name: string")
.Attr("InT: list({int8,float16,float32,int32})") .Attr("InT: list({int8,float16,float32,int32})")
.Attr("OutT: list({int8,float16,float32,int32})") .Attr("OutT: list({int8,float16,float32,int32})")
.Attr("static_engine: bool = true")
.Attr("fixed_input_size: bool = true")
.Attr("cached_engine_batches: list(int) >= 0 = []") .Attr("cached_engine_batches: list(int) >= 0 = []")
.Attr("max_cached_engines_count: int = 1") .Attr("max_cached_engines_count: int = 1")
.Attr("workspace_size_bytes: int") .Attr("workspace_size_bytes: int")
@ -57,8 +52,10 @@ REGISTER_OP("TRTEngineOp")
// implementation, we do require all input tensor to carry the same batch // implementation, we do require all input tensor to carry the same batch
// size, but this could change in the future). Hence we disable shape // size, but this could change in the future). Hence we disable shape
// inference function as a workaround. // inference function as a workaround.
// .SetShapeFn(shape_inference::TRTEngineOpShapeInference); .SetShapeFn(shape_inference::UnknownShape)
.SetShapeFn(shape_inference::UnknownShape); // Deprecated attributes.
.Attr("fixed_input_size: bool = true")
.Attr("static_engine: bool = true");
} // namespace tensorflow } // namespace tensorflow
#endif // GOOGLE_TENSORRT #endif // GOOGLE_TENSORRT