Fix TRT tests in OSS build by reducing the GPU memory consumption.
PiperOrigin-RevId: 263579969
This commit is contained in:
parent
676ff6bf31
commit
4850ef3125
@ -306,13 +306,16 @@ class TrtConvertTest(test_util.TensorFlowTestCase):
|
||||
output_saved_model_dir=self.mkdtemp(),
|
||||
need_calibration=need_calibration)
|
||||
|
||||
def _CreateConverterV2(self,
|
||||
input_saved_model_dir,
|
||||
precision_mode=trt_convert.TrtPrecisionMode.FP32):
|
||||
def _CreateConverterV2(
|
||||
self,
|
||||
input_saved_model_dir,
|
||||
input_saved_model_signature_key=_SAVED_MODEL_SIGNATURE_KEY,
|
||||
precision_mode=trt_convert.TrtPrecisionMode.FP32):
|
||||
return trt_convert.TrtGraphConverterV2(
|
||||
input_saved_model_dir=input_saved_model_dir,
|
||||
input_saved_model_signature_key=_SAVED_MODEL_SIGNATURE_KEY,
|
||||
input_saved_model_signature_key=input_saved_model_signature_key,
|
||||
conversion_params=trt_convert.DEFAULT_TRT_CONVERSION_PARAMS._replace(
|
||||
max_workspace_size_bytes=10 << 20, # Use a smaller workspace.
|
||||
precision_mode=precision_mode,
|
||||
is_dynamic_op=True,
|
||||
maximum_cached_engines=2))
|
||||
@ -493,8 +496,8 @@ class TrtConvertTest(test_util.TensorFlowTestCase):
|
||||
root = model_class()
|
||||
save.save(root, input_saved_model_dir)
|
||||
|
||||
converter = trt_convert.TrtGraphConverterV2(
|
||||
input_saved_model_dir=input_saved_model_dir)
|
||||
converter = self._CreateConverterV2(
|
||||
input_saved_model_dir, input_saved_model_signature_key=signature_key)
|
||||
converter.convert()
|
||||
output_saved_model_dir = self.mkdtemp()
|
||||
converter.save(output_saved_model_dir)
|
||||
|
Loading…
x
Reference in New Issue
Block a user