From dc8554d9f5356d4f4f3ae9d05c69c07b9cdf56ac Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower"
Date: Thu, 19 Mar 2020 14:07:31 -0700
Subject: [PATCH] Support tf1 models for 'convert_with_tensorrt'.

PiperOrigin-RevId: 301891524
Change-Id: If6c3f692a4763cf171c6e585c4986f52c732ee1a
---
 tensorflow/python/tools/saved_model_cli.py | 42 ++++++++++++++++------
 1 file changed, 32 insertions(+), 10 deletions(-)

diff --git a/tensorflow/python/tools/saved_model_cli.py b/tensorflow/python/tools/saved_model_cli.py
index 8b951be49db..6e60e58b345 100644
--- a/tensorflow/python/tools/saved_model_cli.py
+++ b/tensorflow/python/tools/saved_model_cli.py
@@ -772,16 +772,33 @@ def convert_with_tensorrt(args):
   # not installed
   from tensorflow.python.compiler.tensorrt import trt_convert as trt  # pylint: disable=g-import-not-at-top
 
-  params = trt.DEFAULT_TRT_CONVERSION_PARAMS._replace(
-      max_workspace_size_bytes=args.max_workspace_size_bytes,
-      precision_mode=args.precision_mode,
-      minimum_segment_size=args.minimum_segment_size)
-  converter = trt.TrtGraphConverterV2(
-      input_saved_model_dir=args.dir,
-      input_saved_model_tags=args.tag_set.split(','),
-      conversion_params=params)
-  converter.convert()
-  converter.save(output_saved_model_dir=args.output_dir)
+  if not args.convert_tf1_model:
+    params = trt.DEFAULT_TRT_CONVERSION_PARAMS._replace(
+        max_workspace_size_bytes=args.max_workspace_size_bytes,
+        precision_mode=args.precision_mode,
+        minimum_segment_size=args.minimum_segment_size)
+    converter = trt.TrtGraphConverterV2(
+        input_saved_model_dir=args.dir,
+        input_saved_model_tags=args.tag_set.split(','),
+        conversion_params=params)
+    try:
+      converter.convert()
+    except Exception as e:
+      raise RuntimeError(
+          '{}. Try passing "--convert_tf1_model=True".'.format(e))
+    converter.save(output_saved_model_dir=args.output_dir)
+  else:
+    trt.create_inference_graph(
+        None,
+        None,
+        max_batch_size=1,
+        max_workspace_size_bytes=args.max_workspace_size_bytes,
+        precision_mode=args.precision_mode,
+        minimum_segment_size=args.minimum_segment_size,
+        is_dynamic_op=True,
+        input_saved_model_dir=args.dir,
+        input_saved_model_tags=args.tag_set.split(','),
+        output_saved_model_dir=args.output_dir)
 
 
 def aot_compile_cpu(args):
@@ -1010,6 +1027,11 @@ def add_convert_subparser(subparsers):
       default=3,
       help=('the minimum number of nodes required for a subgraph to be replaced'
             'in a TensorRT node'))
+  parser_convert_with_tensorrt.add_argument(
+      '--convert_tf1_model',
+      type=bool,
+      default=False,
+      help='support TRT conversion for TF1 models')
   parser_convert_with_tensorrt.set_defaults(func=convert_with_tensorrt)
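
Illustration (not part of the patch): a minimal sketch of how the new --convert_tf1_model
flag is exercised by calling the patched function directly with an argparse-style namespace.
The SavedModel paths and parameter values below are hypothetical placeholders; in normal use
the same fields are populated by the saved_model_cli command line.

  # Sketch only: drive the patched convert_with_tensorrt() with hand-built arguments.
  from argparse import Namespace

  from tensorflow.python.tools import saved_model_cli

  args = Namespace(
      dir='/tmp/tf1_saved_model',             # hypothetical TF1 SavedModel directory
      output_dir='/tmp/tf1_saved_model_trt',  # hypothetical output directory
      tag_set='serve',
      max_workspace_size_bytes=1 << 30,
      precision_mode='FP16',
      minimum_segment_size=3,
      convert_tf1_model=True,  # new flag: take the trt.create_inference_graph fallback
  )
  saved_model_cli.convert_with_tensorrt(args)

From the command line, the equivalent invocation would look something like
"saved_model_cli convert --dir /tmp/tf1_saved_model --output_dir /tmp/tf1_saved_model_trt
--tag_set serve tensorrt --precision_mode FP16 --convert_tf1_model=True" (paths hypothetical),
mirroring the "--convert_tf1_model=True" hint in the error message added by this patch.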