Add TRT conversion support to saved_model_cli

PiperOrigin-RevId: 225318469
This commit is contained in:
Guangda Lai 2018-12-12 23:27:38 -08:00 committed by TensorFlower Gardener
parent ad26fe7015
commit b99d914cfc

View File

@ -659,6 +659,28 @@ def scan(args):
scan_meta_graph_def(meta_graph_def)
def convert_with_tensorrt(args):
"""Function triggered by 'convert tensorrt' command.
Args:
args: A namespace parsed from command line.
"""
# Import here instead of at top, because this will crash if TensorRT is
# not installed
from tensorflow.contrib import tensorrt # pylint: disable=g-import-not-at-top
tensorrt.create_inference_graph(
None,
None,
max_batch_size=args.max_batch_size,
max_workspace_size_bytes=args.max_workspace_size_bytes,
precision_mode=args.precision_mode,
minimum_segment_size=args.minimum_segment_size,
is_dynamic_op=args.is_dynamic_op,
input_saved_model_dir=args.dir,
input_saved_model_tags=args.tag_set.split(','),
output_saved_model_dir=args.output_dir)
def create_parser():
"""Creates a parser that parse the command line arguments.
@ -812,6 +834,71 @@ def create_parser():
help='tag-set of graph in SavedModel to scan, separated by \',\'')
parser_scan.set_defaults(func=scan)
# convert command
convert_msg = ('Usage example:\n'
'To convert the SavedModel to one that have TensorRT ops:\n'
'$saved_model_cli convert \\\n'
' --dir /tmp/saved_model \\\n'
' --tag_set serve \\\n'
' --output_dir /tmp/saved_model_trt \\\n'
' tensorrt \n')
parser_convert = subparsers.add_parser(
'convert',
description=convert_msg,
formatter_class=argparse.RawTextHelpFormatter)
parser_convert.add_argument(
'--dir',
type=str,
required=True,
help='directory containing the SavedModel to convert')
parser_convert.add_argument(
'--output_dir',
type=str,
required=True,
help='output directory for the converted SavedModel')
parser_convert.add_argument(
'--tag_set',
type=str,
required=True,
help='tag-set of graph in SavedModel to convert, separated by \',\'')
convert_subparsers = parser_convert.add_subparsers(
title='conversion methods',
description='valid conversion methods',
help='the conversion to run with the SavedModel')
parser_convert_with_tensorrt = convert_subparsers.add_parser(
'tensorrt',
description='Convert the SavedModel with Tensorflow-TensorRT integration',
formatter_class=argparse.RawTextHelpFormatter)
parser_convert_with_tensorrt.add_argument(
'--max_batch_size',
type=int,
default=1,
help='max size for the input batch')
parser_convert_with_tensorrt.add_argument(
'--max_workspace_size_bytes',
type=int,
default=2 << 20,
help=('the maximum GPU temporary memory which the TRT engine can use at '
'execution time'))
parser_convert_with_tensorrt.add_argument(
'--precision_mode',
type=str,
default='FP32',
help='one of FP32, FP16 and INT8')
parser_convert_with_tensorrt.add_argument(
'--minimum_segment_size',
type=int,
default=3,
help=('the minimum number of nodes required for a subgraph to be replaced'
'in a TensorRT node'))
parser_convert_with_tensorrt.add_argument(
'--is_dynamic_op',
type=bool,
default=False,
help=('whether to generate dynamic TRT ops which will build the TRT '
'network and engine at run time'))
parser_convert_with_tensorrt.set_defaults(func=convert_with_tensorrt)
return parser