Support tf1 models for 'convert_with_tensorrt'.
PiperOrigin-RevId: 301891524 Change-Id: If6c3f692a4763cf171c6e585c4986f52c732ee1a
This commit is contained in:
parent
b8cdd18825
commit
dc8554d9f5
@ -772,16 +772,33 @@ def convert_with_tensorrt(args):
|
|||||||
# not installed
|
# not installed
|
||||||
from tensorflow.python.compiler.tensorrt import trt_convert as trt # pylint: disable=g-import-not-at-top
|
from tensorflow.python.compiler.tensorrt import trt_convert as trt # pylint: disable=g-import-not-at-top
|
||||||
|
|
||||||
params = trt.DEFAULT_TRT_CONVERSION_PARAMS._replace(
|
if not args.convert_tf1_model:
|
||||||
max_workspace_size_bytes=args.max_workspace_size_bytes,
|
params = trt.DEFAULT_TRT_CONVERSION_PARAMS._replace(
|
||||||
precision_mode=args.precision_mode,
|
max_workspace_size_bytes=args.max_workspace_size_bytes,
|
||||||
minimum_segment_size=args.minimum_segment_size)
|
precision_mode=args.precision_mode,
|
||||||
converter = trt.TrtGraphConverterV2(
|
minimum_segment_size=args.minimum_segment_size)
|
||||||
input_saved_model_dir=args.dir,
|
converter = trt.TrtGraphConverterV2(
|
||||||
input_saved_model_tags=args.tag_set.split(','),
|
input_saved_model_dir=args.dir,
|
||||||
conversion_params=params)
|
input_saved_model_tags=args.tag_set.split(','),
|
||||||
converter.convert()
|
conversion_params=params)
|
||||||
converter.save(output_saved_model_dir=args.output_dir)
|
try:
|
||||||
|
converter.convert()
|
||||||
|
except Exception as e:
|
||||||
|
raise RuntimeError(
|
||||||
|
'{}. Try passing "--convert_tf1_model=True".'.format(e))
|
||||||
|
converter.save(output_saved_model_dir=args.output_dir)
|
||||||
|
else:
|
||||||
|
trt.create_inference_graph(
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
max_batch_size=1,
|
||||||
|
max_workspace_size_bytes=args.max_workspace_size_bytes,
|
||||||
|
precision_mode=args.precision_mode,
|
||||||
|
minimum_segment_size=args.minimum_segment_size,
|
||||||
|
is_dynamic_op=True,
|
||||||
|
input_saved_model_dir=args.dir,
|
||||||
|
input_saved_model_tags=args.tag_set.split(','),
|
||||||
|
output_saved_model_dir=args.output_dir)
|
||||||
|
|
||||||
|
|
||||||
def aot_compile_cpu(args):
|
def aot_compile_cpu(args):
|
||||||
@ -1010,6 +1027,11 @@ def add_convert_subparser(subparsers):
|
|||||||
default=3,
|
default=3,
|
||||||
help=('the minimum number of nodes required for a subgraph to be replaced'
|
help=('the minimum number of nodes required for a subgraph to be replaced'
|
||||||
'in a TensorRT node'))
|
'in a TensorRT node'))
|
||||||
|
parser_convert_with_tensorrt.add_argument(
|
||||||
|
'--convert_tf1_model',
|
||||||
|
type=bool,
|
||||||
|
default=False,
|
||||||
|
help='support TRT conversion for TF1 models')
|
||||||
parser_convert_with_tensorrt.set_defaults(func=convert_with_tensorrt)
|
parser_convert_with_tensorrt.set_defaults(func=convert_with_tensorrt)
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user