[TF:TRT] Modify a test to workaround a bug.
Allow native segment execution for BinaryTensorWeightBroadcastTest when TensorRT 7+ is used. This is to workaround b/176540862. PiperOrigin-RevId: 353581801 Change-Id: Ic494e212c012a3fb2c260da64a191899e68ec866
This commit is contained in:
parent
77a9b64c88
commit
ee58e600bf
@ -18,6 +18,7 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.python.compiler.tensorrt.test import tf_trt_integration_test_base as trt_test
|
||||
@ -61,6 +62,12 @@ class BinaryTensorWeightBroadcastTest(trt_test.TfTrtIntegrationTestBase):
|
||||
"""Return the expected engines to build."""
|
||||
return ["TRTEngineOp_%d" % i for i in range(16)]
|
||||
|
||||
# TODO(b/176540862): remove this routine to disallow native segment execution
|
||||
# for TensorRT 7+.
|
||||
def setUp(self):
|
||||
super(trt_test.TfTrtIntegrationTestBase, self).setUp()
|
||||
if trt_test.IsTensorRTVersionGreaterEqual(7):
|
||||
os.environ["TF_TRT_ALLOW_ENGINE_NATIVE_SEGMENT_EXECUTION"] = "True"
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
||||
|
Loading…
Reference in New Issue
Block a user