Fix TRT tests in OSS build by reducing the GPU memory consumption.

PiperOrigin-RevId: 263579969
This commit is contained in:
Guangda Lai 2019-08-15 09:34:50 -07:00 committed by Goldie Gadde
parent 676ff6bf31
commit 4850ef3125

View File

@ -306,13 +306,16 @@ class TrtConvertTest(test_util.TensorFlowTestCase):
output_saved_model_dir=self.mkdtemp(),
need_calibration=need_calibration)
def _CreateConverterV2(self,
input_saved_model_dir,
precision_mode=trt_convert.TrtPrecisionMode.FP32):
def _CreateConverterV2(
self,
input_saved_model_dir,
input_saved_model_signature_key=_SAVED_MODEL_SIGNATURE_KEY,
precision_mode=trt_convert.TrtPrecisionMode.FP32):
return trt_convert.TrtGraphConverterV2(
input_saved_model_dir=input_saved_model_dir,
input_saved_model_signature_key=_SAVED_MODEL_SIGNATURE_KEY,
input_saved_model_signature_key=input_saved_model_signature_key,
conversion_params=trt_convert.DEFAULT_TRT_CONVERSION_PARAMS._replace(
max_workspace_size_bytes=10 << 20, # Use a smaller workspace.
precision_mode=precision_mode,
is_dynamic_op=True,
maximum_cached_engines=2))
@ -493,8 +496,8 @@ class TrtConvertTest(test_util.TensorFlowTestCase):
root = model_class()
save.save(root, input_saved_model_dir)
converter = trt_convert.TrtGraphConverterV2(
input_saved_model_dir=input_saved_model_dir)
converter = self._CreateConverterV2(
input_saved_model_dir, input_saved_model_signature_key=signature_key)
converter.convert()
output_saved_model_dir = self.mkdtemp()
converter.save(output_saved_model_dir)