From 9f43ebd68c35a1f8f0f995352f7108b72fb41f50 Mon Sep 17 00:00:00 2001 From: Bixia Zheng Date: Sun, 21 Jun 2020 23:18:37 -0700 Subject: [PATCH] [TF:TRT] Enforce no native segment execution for TfTrtIntegrationTestBase tests. Report an error if the native segment for an TRTEngineOp is executed. This ensures that the TRTEngineOp constructed by the bridge is acceptable to TensorRT to catch bugs. Modify TrtModeTestBase to not build static engines for a graph that generates dynamic shaped values. This avoids native segment execution that is not caused by the inconsistency between the bridge and TensorRT. Temporarily allow native segment execution for VGGBlockTest and VGGBlockNCHWTest to workaround b/159459919. PiperOrigin-RevId: 317597344 Change-Id: I6c268c5c912a1fddcffa5d9763399976ddc0299e --- .../compiler/tensorrt/test/tf_trt_integration_test_base.py | 1 + tensorflow/python/compiler/tensorrt/test/trt_mode_test.py | 5 +++++ .../python/compiler/tensorrt/test/vgg_block_nchw_test.py | 7 +++++++ tensorflow/python/compiler/tensorrt/test/vgg_block_test.py | 7 +++++++ 4 files changed, 20 insertions(+) diff --git a/tensorflow/python/compiler/tensorrt/test/tf_trt_integration_test_base.py b/tensorflow/python/compiler/tensorrt/test/tf_trt_integration_test_base.py index 8b93750fde4..87fa55a32bd 100644 --- a/tensorflow/python/compiler/tensorrt/test/tf_trt_integration_test_base.py +++ b/tensorflow/python/compiler/tensorrt/test/tf_trt_integration_test_base.py @@ -971,4 +971,5 @@ def _AddTests(test_class): if is_tensorrt_enabled(): + os.environ["TF_TRT_ALLOW_ENGINE_NATIVE_SEGMENT_EXECUTION"] = "False" _AddTests(TfTrtIntegrationTestBase) diff --git a/tensorflow/python/compiler/tensorrt/test/trt_mode_test.py b/tensorflow/python/compiler/tensorrt/test/trt_mode_test.py index 878ab4cbd8e..c67de7432cd 100644 --- a/tensorflow/python/compiler/tensorrt/test/trt_mode_test.py +++ b/tensorflow/python/compiler/tensorrt/test/trt_mode_test.py @@ -40,6 +40,11 @@ class TrtModeTestBase(trt_test.TfTrtIntegrationTestBase): q = q + 5.0 return array_ops.identity(q, name="output_0") + def ShouldRunTest(self, run_params): + # Squeeze op produces dynamic shaped values. Therefore, we don't run the + # test with static engine to avoid native segment execution. + return (run_params.dynamic_engine, "test dynamic engine only") + def GetParams(self): """The input has 1 as a first dimension, which is removed by the squeeze. diff --git a/tensorflow/python/compiler/tensorrt/test/vgg_block_nchw_test.py b/tensorflow/python/compiler/tensorrt/test/vgg_block_nchw_test.py index 368ffad30a4..8fd9606812d 100644 --- a/tensorflow/python/compiler/tensorrt/test/vgg_block_nchw_test.py +++ b/tensorflow/python/compiler/tensorrt/test/vgg_block_nchw_test.py @@ -18,6 +18,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import os + import numpy as np from tensorflow.python.compiler.tensorrt.test import tf_trt_integration_test_base as trt_test @@ -69,6 +71,11 @@ class VGGBlockNCHWTest(trt_test.TfTrtIntegrationTestBase): """Return the expected engines to build.""" return ["TRTEngineOp_0"] + # TODO(b/159459919): remove this routine to disallow native segment execution. + def setUp(self): + super(trt_test.TfTrtIntegrationTestBase, self).setUp() + os.environ["TF_TRT_ALLOW_ENGINE_NATIVE_SEGMENT_EXECUTION"] = "True" + if __name__ == "__main__": test.main() diff --git a/tensorflow/python/compiler/tensorrt/test/vgg_block_test.py b/tensorflow/python/compiler/tensorrt/test/vgg_block_test.py index f1b41327a58..9d81cd6dcc3 100644 --- a/tensorflow/python/compiler/tensorrt/test/vgg_block_test.py +++ b/tensorflow/python/compiler/tensorrt/test/vgg_block_test.py @@ -18,6 +18,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import os + import numpy as np from tensorflow.python.compiler.tensorrt.test import tf_trt_integration_test_base as trt_test @@ -60,6 +62,11 @@ class VGGBlockTest(trt_test.TfTrtIntegrationTestBase): """Return the expected engines to build.""" return ["TRTEngineOp_0"] + # TODO(b/159459919): remove this routine to disallow native segment execution. + def setUp(self): + super(trt_test.TfTrtIntegrationTestBase, self).setUp() + os.environ["TF_TRT_ALLOW_ENGINE_NATIVE_SEGMENT_EXECUTION"] = "True" + if __name__ == "__main__": test.main()