[TF:TRT] Modify a test to workaround a bug.

Allow native segment execution for BinaryTensorWeightBroadcastTest when
TensorRT 7+ is used. This is to workaround b/176540862.

PiperOrigin-RevId: 353581801
Change-Id: Ic494e212c012a3fb2c260da64a191899e68ec866
This commit is contained in:
Bixia Zheng 2021-01-24 22:47:49 -08:00 committed by TensorFlower Gardener
parent 77a9b64c88
commit ee58e600bf

View File

@ -18,6 +18,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import numpy as np
from tensorflow.python.compiler.tensorrt.test import tf_trt_integration_test_base as trt_test
@ -61,6 +62,12 @@ class BinaryTensorWeightBroadcastTest(trt_test.TfTrtIntegrationTestBase):
"""Return the expected engines to build."""
return ["TRTEngineOp_%d" % i for i in range(16)]
# TODO(b/176540862): remove this routine to disallow native segment execution
# for TensorRT 7+.
def setUp(self):
super(trt_test.TfTrtIntegrationTestBase, self).setUp()
if trt_test.IsTensorRTVersionGreaterEqual(7):
os.environ["TF_TRT_ALLOW_ENGINE_NATIVE_SEGMENT_EXECUTION"] = "True"
if __name__ == "__main__":
test.main()