Merge pull request #34945 from wdirons:fix_saved_model_cli_for_tensorrt

PiperOrigin-RevId: 289592388
Change-Id: I249c1f6f871194c5ffa059e8d44a9d32408b3b92
This commit is contained in:
TensorFlower Gardener 2020-01-13 23:34:17 -08:00
commit 4d0eaa651e

View File

@ -762,18 +762,18 @@ def convert_with_tensorrt(args):
""" """
# Import here instead of at top, because this will crash if TensorRT is # Import here instead of at top, because this will crash if TensorRT is
# not installed # not installed
from tensorflow.contrib import tensorrt # pylint: disable=g-import-not-at-top from tensorflow.python.compiler.tensorrt import trt_convert as trt # pylint: disable=g-import-not-at-top
tensorrt.create_inference_graph(
None, params = trt.DEFAULT_TRT_CONVERSION_PARAMS._replace(
None,
max_batch_size=args.max_batch_size,
max_workspace_size_bytes=args.max_workspace_size_bytes, max_workspace_size_bytes=args.max_workspace_size_bytes,
precision_mode=args.precision_mode, precision_mode=args.precision_mode,
minimum_segment_size=args.minimum_segment_size, minimum_segment_size=args.minimum_segment_size)
is_dynamic_op=args.is_dynamic_op, converter = trt.TrtGraphConverterV2(
input_saved_model_dir=args.dir, input_saved_model_dir=args.dir,
input_saved_model_tags=args.tag_set.split(','), input_saved_model_tags=args.tag_set.split(','),
output_saved_model_dir=args.output_dir) conversion_params=params)
converter.convert()
converter.save(output_saved_model_dir=args.output_dir)
def create_parser(): def create_parser():
@ -964,11 +964,6 @@ def create_parser():
'tensorrt', 'tensorrt',
description='Convert the SavedModel with Tensorflow-TensorRT integration', description='Convert the SavedModel with Tensorflow-TensorRT integration',
formatter_class=argparse.RawTextHelpFormatter) formatter_class=argparse.RawTextHelpFormatter)
parser_convert_with_tensorrt.add_argument(
'--max_batch_size',
type=int,
default=1,
help='max size for the input batch')
parser_convert_with_tensorrt.add_argument( parser_convert_with_tensorrt.add_argument(
'--max_workspace_size_bytes', '--max_workspace_size_bytes',
type=int, type=int,
@ -986,12 +981,6 @@ def create_parser():
default=3, default=3,
help=('the minimum number of nodes required for a subgraph to be replaced' help=('the minimum number of nodes required for a subgraph to be replaced'
'in a TensorRT node')) 'in a TensorRT node'))
parser_convert_with_tensorrt.add_argument(
'--is_dynamic_op',
type=bool,
default=False,
help=('whether to generate dynamic TRT ops which will build the TRT '
'network and engine at run time'))
parser_convert_with_tensorrt.set_defaults(func=convert_with_tensorrt) parser_convert_with_tensorrt.set_defaults(func=convert_with_tensorrt)
return parser return parser