Merge pull request #34945 from wdirons:fix_saved_model_cli_for_tensorrt
PiperOrigin-RevId: 289592388 Change-Id: I249c1f6f871194c5ffa059e8d44a9d32408b3b92
This commit is contained in:
commit
4d0eaa651e
@ -762,18 +762,18 @@ def convert_with_tensorrt(args):
|
|||||||
"""
|
"""
|
||||||
# Import here instead of at top, because this will crash if TensorRT is
|
# Import here instead of at top, because this will crash if TensorRT is
|
||||||
# not installed
|
# not installed
|
||||||
from tensorflow.contrib import tensorrt # pylint: disable=g-import-not-at-top
|
from tensorflow.python.compiler.tensorrt import trt_convert as trt # pylint: disable=g-import-not-at-top
|
||||||
tensorrt.create_inference_graph(
|
|
||||||
None,
|
params = trt.DEFAULT_TRT_CONVERSION_PARAMS._replace(
|
||||||
None,
|
|
||||||
max_batch_size=args.max_batch_size,
|
|
||||||
max_workspace_size_bytes=args.max_workspace_size_bytes,
|
max_workspace_size_bytes=args.max_workspace_size_bytes,
|
||||||
precision_mode=args.precision_mode,
|
precision_mode=args.precision_mode,
|
||||||
minimum_segment_size=args.minimum_segment_size,
|
minimum_segment_size=args.minimum_segment_size)
|
||||||
is_dynamic_op=args.is_dynamic_op,
|
converter = trt.TrtGraphConverterV2(
|
||||||
input_saved_model_dir=args.dir,
|
input_saved_model_dir=args.dir,
|
||||||
input_saved_model_tags=args.tag_set.split(','),
|
input_saved_model_tags=args.tag_set.split(','),
|
||||||
output_saved_model_dir=args.output_dir)
|
conversion_params=params)
|
||||||
|
converter.convert()
|
||||||
|
converter.save(output_saved_model_dir=args.output_dir)
|
||||||
|
|
||||||
|
|
||||||
def create_parser():
|
def create_parser():
|
||||||
@ -964,11 +964,6 @@ def create_parser():
|
|||||||
'tensorrt',
|
'tensorrt',
|
||||||
description='Convert the SavedModel with Tensorflow-TensorRT integration',
|
description='Convert the SavedModel with Tensorflow-TensorRT integration',
|
||||||
formatter_class=argparse.RawTextHelpFormatter)
|
formatter_class=argparse.RawTextHelpFormatter)
|
||||||
parser_convert_with_tensorrt.add_argument(
|
|
||||||
'--max_batch_size',
|
|
||||||
type=int,
|
|
||||||
default=1,
|
|
||||||
help='max size for the input batch')
|
|
||||||
parser_convert_with_tensorrt.add_argument(
|
parser_convert_with_tensorrt.add_argument(
|
||||||
'--max_workspace_size_bytes',
|
'--max_workspace_size_bytes',
|
||||||
type=int,
|
type=int,
|
||||||
@ -986,12 +981,6 @@ def create_parser():
|
|||||||
default=3,
|
default=3,
|
||||||
help=('the minimum number of nodes required for a subgraph to be replaced'
|
help=('the minimum number of nodes required for a subgraph to be replaced'
|
||||||
'in a TensorRT node'))
|
'in a TensorRT node'))
|
||||||
parser_convert_with_tensorrt.add_argument(
|
|
||||||
'--is_dynamic_op',
|
|
||||||
type=bool,
|
|
||||||
default=False,
|
|
||||||
help=('whether to generate dynamic TRT ops which will build the TRT '
|
|
||||||
'network and engine at run time'))
|
|
||||||
parser_convert_with_tensorrt.set_defaults(func=convert_with_tensorrt)
|
parser_convert_with_tensorrt.set_defaults(func=convert_with_tensorrt)
|
||||||
|
|
||||||
return parser
|
return parser
|
||||||
|
Loading…
x
Reference in New Issue
Block a user