From bf00bd654adc0bbb6ccc73a8b729e9f1d0f6037c Mon Sep 17 00:00:00 2001
From: "William D. Irons" <wdirons@us.ibm.com>
Date: Sun, 8 Dec 2019 20:16:34 +0000
Subject: [PATCH 1/2] Fix saved_model_cli tensorrt conversion

The existing saved_model_cli convert tensorrt script fails in 2.X with module
not found "tensorflow.contrib". Updated the script to use the V2 API for
TensorRT to convert a saved_model.

The max_batch_size and is_dynamic_op parameters are not valid for the V2 API
so they have been removed.
---
 tensorflow/python/tools/saved_model_cli.py | 33 +++++++---------------
 1 file changed, 10 insertions(+), 23 deletions(-)

diff --git a/tensorflow/python/tools/saved_model_cli.py b/tensorflow/python/tools/saved_model_cli.py
index 57ffc3f05c2..e2e5c37d83c 100644
--- a/tensorflow/python/tools/saved_model_cli.py
+++ b/tensorflow/python/tools/saved_model_cli.py
@@ -747,19 +747,17 @@ def convert_with_tensorrt(args):
   """
   # Import here instead of at top, because this will crash if TensorRT is
   # not installed
-  from tensorflow.contrib import tensorrt  # pylint: disable=g-import-not-at-top
-  tensorrt.create_inference_graph(
-      None,
-      None,
-      max_batch_size=args.max_batch_size,
-      max_workspace_size_bytes=args.max_workspace_size_bytes,
-      precision_mode=args.precision_mode,
-      minimum_segment_size=args.minimum_segment_size,
-      is_dynamic_op=args.is_dynamic_op,
-      input_saved_model_dir=args.dir,
-      input_saved_model_tags=args.tag_set.split(','),
-      output_saved_model_dir=args.output_dir)
+  from tensorflow.python.compiler.tensorrt import trt_convert as trt  # pylint: disable=g-import-not-at-top
 
+  params = trt.DEFAULT_TRT_CONVERSION_PARAMS._replace(
+        max_workspace_size_bytes=args.max_workspace_size_bytes,
+        precision_mode=args.precision_mode,
+        minimum_segment_size=args.minimum_segment_size)
+  converter = trt.TrtGraphConverterV2(input_saved_model_dir=args.dir,
+                                      input_saved_model_tags=args.tag_set.split(','),
+                                      conversion_params=params)
+  converter.convert()
+  converter.save(output_saved_model_dir=args.output_dir)
 
 def create_parser():
   """Creates a parser that parse the command line arguments.
@@ -949,11 +947,6 @@ def create_parser():
       'tensorrt',
       description='Convert the SavedModel with Tensorflow-TensorRT integration',
       formatter_class=argparse.RawTextHelpFormatter)
-  parser_convert_with_tensorrt.add_argument(
-      '--max_batch_size',
-      type=int,
-      default=1,
-      help='max size for the input batch')
   parser_convert_with_tensorrt.add_argument(
       '--max_workspace_size_bytes',
       type=int,
@@ -971,12 +964,6 @@ def create_parser():
       default=3,
       help=('the minimum number of nodes required for a subgraph to be replaced'
             'in a TensorRT node'))
-  parser_convert_with_tensorrt.add_argument(
-      '--is_dynamic_op',
-      type=bool,
-      default=False,
-      help=('whether to generate dynamic TRT ops which will build the TRT '
-            'network and engine at run time'))
   parser_convert_with_tensorrt.set_defaults(func=convert_with_tensorrt)
 
   return parser

From 89f1f386ce9899fa835e87fd3d7d7a671aab73d9 Mon Sep 17 00:00:00 2001
From: "William D. Irons" <wdirons@us.ibm.com>
Date: Sat, 11 Jan 2020 23:59:21 +0000
Subject: [PATCH 2/2] Fix lint problems

---
 tensorflow/python/tools/saved_model_cli.py | 13 +++++++------
 1 file changed, 7 insertions(+), 6 deletions(-)

diff --git a/tensorflow/python/tools/saved_model_cli.py b/tensorflow/python/tools/saved_model_cli.py
index e2e5c37d83c..16cfbb14b58 100644
--- a/tensorflow/python/tools/saved_model_cli.py
+++ b/tensorflow/python/tools/saved_model_cli.py
@@ -750,12 +750,13 @@ def convert_with_tensorrt(args):
   from tensorflow.python.compiler.tensorrt import trt_convert as trt  # pylint: disable=g-import-not-at-top
 
   params = trt.DEFAULT_TRT_CONVERSION_PARAMS._replace(
-        max_workspace_size_bytes=args.max_workspace_size_bytes,
-        precision_mode=args.precision_mode,
-        minimum_segment_size=args.minimum_segment_size)
-  converter = trt.TrtGraphConverterV2(input_saved_model_dir=args.dir,
-                                      input_saved_model_tags=args.tag_set.split(','),
-                                      conversion_params=params)
+      max_workspace_size_bytes=args.max_workspace_size_bytes,
+      precision_mode=args.precision_mode,
+      minimum_segment_size=args.minimum_segment_size)
+  converter = trt.TrtGraphConverterV2(
+      input_saved_model_dir=args.dir,
+      input_saved_model_tags=args.tag_set.split(','),
+      conversion_params=params)
   converter.convert()
   converter.save(output_saved_model_dir=args.output_dir)