From ee58e600bfcd73635912b13e131232408a9fc75d Mon Sep 17 00:00:00 2001 From: Bixia Zheng Date: Sun, 24 Jan 2021 22:47:49 -0800 Subject: [PATCH] [TF:TRT] Modify a test to workaround a bug. Allow native segment execution for BinaryTensorWeightBroadcastTest when TensorRT 7+ is used. This is to workaround b/176540862. PiperOrigin-RevId: 353581801 Change-Id: Ic494e212c012a3fb2c260da64a191899e68ec866 --- .../tensorrt/test/binary_tensor_weight_broadcast_test.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/tensorflow/python/compiler/tensorrt/test/binary_tensor_weight_broadcast_test.py b/tensorflow/python/compiler/tensorrt/test/binary_tensor_weight_broadcast_test.py index 9e31327f580..89cfa6fb651 100644 --- a/tensorflow/python/compiler/tensorrt/test/binary_tensor_weight_broadcast_test.py +++ b/tensorflow/python/compiler/tensorrt/test/binary_tensor_weight_broadcast_test.py @@ -18,6 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import os import numpy as np from tensorflow.python.compiler.tensorrt.test import tf_trt_integration_test_base as trt_test @@ -61,6 +62,12 @@ class BinaryTensorWeightBroadcastTest(trt_test.TfTrtIntegrationTestBase): """Return the expected engines to build.""" return ["TRTEngineOp_%d" % i for i in range(16)] + # TODO(b/176540862): remove this routine to disallow native segment execution + # for TensorRT 7+. + def setUp(self): + super(trt_test.TfTrtIntegrationTestBase, self).setUp() + if trt_test.IsTensorRTVersionGreaterEqual(7): + os.environ["TF_TRT_ALLOW_ENGINE_NATIVE_SEGMENT_EXECUTION"] = "True" if __name__ == "__main__": test.main()