[TF:TRT] Enforce no native segment execution for TfTrtIntegrationTestBase
tests. Report an error if the native segment for an TRTEngineOp is executed. This ensures that the TRTEngineOp constructed by the bridge is acceptable to TensorRT to catch bugs. Modify TrtModeTestBase to not build static engines for a graph that generates dynamic shaped values. This avoids native segment execution that is not caused by the inconsistency between the bridge and TensorRT. Temporarily allow native segment execution for VGGBlockTest and VGGBlockNCHWTest to workaround b/159459919. PiperOrigin-RevId: 317597344 Change-Id: I6c268c5c912a1fddcffa5d9763399976ddc0299e
This commit is contained in:
parent
8fc628cf78
commit
9f43ebd68c
@ -971,4 +971,5 @@ def _AddTests(test_class):
|
||||
|
||||
|
||||
if is_tensorrt_enabled():
|
||||
os.environ["TF_TRT_ALLOW_ENGINE_NATIVE_SEGMENT_EXECUTION"] = "False"
|
||||
_AddTests(TfTrtIntegrationTestBase)
|
||||
|
@ -40,6 +40,11 @@ class TrtModeTestBase(trt_test.TfTrtIntegrationTestBase):
|
||||
q = q + 5.0
|
||||
return array_ops.identity(q, name="output_0")
|
||||
|
||||
def ShouldRunTest(self, run_params):
|
||||
# Squeeze op produces dynamic shaped values. Therefore, we don't run the
|
||||
# test with static engine to avoid native segment execution.
|
||||
return (run_params.dynamic_engine, "test dynamic engine only")
|
||||
|
||||
def GetParams(self):
|
||||
"""The input has 1 as a first dimension, which is removed by the squeeze.
|
||||
|
||||
|
@ -18,6 +18,8 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.python.compiler.tensorrt.test import tf_trt_integration_test_base as trt_test
|
||||
@ -69,6 +71,11 @@ class VGGBlockNCHWTest(trt_test.TfTrtIntegrationTestBase):
|
||||
"""Return the expected engines to build."""
|
||||
return ["TRTEngineOp_0"]
|
||||
|
||||
# TODO(b/159459919): remove this routine to disallow native segment execution.
|
||||
def setUp(self):
|
||||
super(trt_test.TfTrtIntegrationTestBase, self).setUp()
|
||||
os.environ["TF_TRT_ALLOW_ENGINE_NATIVE_SEGMENT_EXECUTION"] = "True"
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
||||
|
@ -18,6 +18,8 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.python.compiler.tensorrt.test import tf_trt_integration_test_base as trt_test
|
||||
@ -60,6 +62,11 @@ class VGGBlockTest(trt_test.TfTrtIntegrationTestBase):
|
||||
"""Return the expected engines to build."""
|
||||
return ["TRTEngineOp_0"]
|
||||
|
||||
# TODO(b/159459919): remove this routine to disallow native segment execution.
|
||||
def setUp(self):
|
||||
super(trt_test.TfTrtIntegrationTestBase, self).setUp()
|
||||
os.environ["TF_TRT_ALLOW_ENGINE_NATIVE_SEGMENT_EXECUTION"] = "True"
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
||||
|
Loading…
Reference in New Issue
Block a user