[TF:TRT] Enforce no native segment execution for TfTrtIntegrationTestBase

tests.

Report an error if the native segment for an TRTEngineOp is executed. This
ensures that the TRTEngineOp constructed by the bridge is acceptable to
TensorRT to catch bugs.

Modify TrtModeTestBase to not build static engines for a graph that generates
dynamic shaped values. This avoids native segment execution that is not caused
by the inconsistency between the bridge and TensorRT.

Temporarily allow native segment execution for VGGBlockTest and
VGGBlockNCHWTest to workaround b/159459919.

PiperOrigin-RevId: 317597344
Change-Id: I6c268c5c912a1fddcffa5d9763399976ddc0299e
This commit is contained in:
Bixia Zheng 2020-06-21 23:18:37 -07:00 committed by TensorFlower Gardener
parent 8fc628cf78
commit 9f43ebd68c
4 changed files with 20 additions and 0 deletions

View File

@ -971,4 +971,5 @@ def _AddTests(test_class):
if is_tensorrt_enabled():
os.environ["TF_TRT_ALLOW_ENGINE_NATIVE_SEGMENT_EXECUTION"] = "False"
_AddTests(TfTrtIntegrationTestBase)

View File

@ -40,6 +40,11 @@ class TrtModeTestBase(trt_test.TfTrtIntegrationTestBase):
q = q + 5.0
return array_ops.identity(q, name="output_0")
def ShouldRunTest(self, run_params):
# Squeeze op produces dynamic shaped values. Therefore, we don't run the
# test with static engine to avoid native segment execution.
return (run_params.dynamic_engine, "test dynamic engine only")
def GetParams(self):
"""The input has 1 as a first dimension, which is removed by the squeeze.

View File

@ -18,6 +18,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import numpy as np
from tensorflow.python.compiler.tensorrt.test import tf_trt_integration_test_base as trt_test
@ -69,6 +71,11 @@ class VGGBlockNCHWTest(trt_test.TfTrtIntegrationTestBase):
"""Return the expected engines to build."""
return ["TRTEngineOp_0"]
# TODO(b/159459919): remove this routine to disallow native segment execution.
def setUp(self):
super(trt_test.TfTrtIntegrationTestBase, self).setUp()
os.environ["TF_TRT_ALLOW_ENGINE_NATIVE_SEGMENT_EXECUTION"] = "True"
if __name__ == "__main__":
test.main()

View File

@ -18,6 +18,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import numpy as np
from tensorflow.python.compiler.tensorrt.test import tf_trt_integration_test_base as trt_test
@ -60,6 +62,11 @@ class VGGBlockTest(trt_test.TfTrtIntegrationTestBase):
"""Return the expected engines to build."""
return ["TRTEngineOp_0"]
# TODO(b/159459919): remove this routine to disallow native segment execution.
def setUp(self):
super(trt_test.TfTrtIntegrationTestBase, self).setUp()
os.environ["TF_TRT_ALLOW_ENGINE_NATIVE_SEGMENT_EXECUTION"] = "True"
if __name__ == "__main__":
test.main()