From bf00bd654adc0bbb6ccc73a8b729e9f1d0f6037c Mon Sep 17 00:00:00 2001 From: "William D. Irons" <wdirons@us.ibm.com> Date: Sun, 8 Dec 2019 20:16:34 +0000 Subject: [PATCH 1/2] Fix saved_model_cli tensorrt conversion The existing saved_model_cli convert tensorrt script fails in 2.X with module not found "tensorflow.contrib". Updated the script to use the V2 API for TensorRT to convert a saved_model. The max_batch_size and is_dynamic_op parameters are not valid for the V2 API so they have been removed. --- tensorflow/python/tools/saved_model_cli.py | 33 +++++++--------------- 1 file changed, 10 insertions(+), 23 deletions(-) diff --git a/tensorflow/python/tools/saved_model_cli.py b/tensorflow/python/tools/saved_model_cli.py index 57ffc3f05c2..e2e5c37d83c 100644 --- a/tensorflow/python/tools/saved_model_cli.py +++ b/tensorflow/python/tools/saved_model_cli.py @@ -747,19 +747,17 @@ def convert_with_tensorrt(args): """ # Import here instead of at top, because this will crash if TensorRT is # not installed - from tensorflow.contrib import tensorrt # pylint: disable=g-import-not-at-top - tensorrt.create_inference_graph( - None, - None, - max_batch_size=args.max_batch_size, - max_workspace_size_bytes=args.max_workspace_size_bytes, - precision_mode=args.precision_mode, - minimum_segment_size=args.minimum_segment_size, - is_dynamic_op=args.is_dynamic_op, - input_saved_model_dir=args.dir, - input_saved_model_tags=args.tag_set.split(','), - output_saved_model_dir=args.output_dir) + from tensorflow.python.compiler.tensorrt import trt_convert as trt # pylint: disable=g-import-not-at-top + params = trt.DEFAULT_TRT_CONVERSION_PARAMS._replace( + max_workspace_size_bytes=args.max_workspace_size_bytes, + precision_mode=args.precision_mode, + minimum_segment_size=args.minimum_segment_size) + converter = trt.TrtGraphConverterV2(input_saved_model_dir=args.dir, + input_saved_model_tags=args.tag_set.split(','), + conversion_params=params) + converter.convert() + converter.save(output_saved_model_dir=args.output_dir) def create_parser(): """Creates a parser that parse the command line arguments. @@ -949,11 +947,6 @@ def create_parser(): 'tensorrt', description='Convert the SavedModel with Tensorflow-TensorRT integration', formatter_class=argparse.RawTextHelpFormatter) - parser_convert_with_tensorrt.add_argument( - '--max_batch_size', - type=int, - default=1, - help='max size for the input batch') parser_convert_with_tensorrt.add_argument( '--max_workspace_size_bytes', type=int, @@ -971,12 +964,6 @@ def create_parser(): default=3, help=('the minimum number of nodes required for a subgraph to be replaced' 'in a TensorRT node')) - parser_convert_with_tensorrt.add_argument( - '--is_dynamic_op', - type=bool, - default=False, - help=('whether to generate dynamic TRT ops which will build the TRT ' - 'network and engine at run time')) parser_convert_with_tensorrt.set_defaults(func=convert_with_tensorrt) return parser From 89f1f386ce9899fa835e87fd3d7d7a671aab73d9 Mon Sep 17 00:00:00 2001 From: "William D. Irons" <wdirons@us.ibm.com> Date: Sat, 11 Jan 2020 23:59:21 +0000 Subject: [PATCH 2/2] Fix lint problems --- tensorflow/python/tools/saved_model_cli.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/tensorflow/python/tools/saved_model_cli.py b/tensorflow/python/tools/saved_model_cli.py index e2e5c37d83c..16cfbb14b58 100644 --- a/tensorflow/python/tools/saved_model_cli.py +++ b/tensorflow/python/tools/saved_model_cli.py @@ -750,12 +750,13 @@ def convert_with_tensorrt(args): from tensorflow.python.compiler.tensorrt import trt_convert as trt # pylint: disable=g-import-not-at-top params = trt.DEFAULT_TRT_CONVERSION_PARAMS._replace( - max_workspace_size_bytes=args.max_workspace_size_bytes, - precision_mode=args.precision_mode, - minimum_segment_size=args.minimum_segment_size) - converter = trt.TrtGraphConverterV2(input_saved_model_dir=args.dir, - input_saved_model_tags=args.tag_set.split(','), - conversion_params=params) + max_workspace_size_bytes=args.max_workspace_size_bytes, + precision_mode=args.precision_mode, + minimum_segment_size=args.minimum_segment_size) + converter = trt.TrtGraphConverterV2( + input_saved_model_dir=args.dir, + input_saved_model_tags=args.tag_set.split(','), + conversion_params=params) converter.convert() converter.save(output_saved_model_dir=args.output_dir)