Link TF-TRT into python:no_contrib (so it's included in TF 2.0 build) and lazily load TRT shared library.
PiperOrigin-RevId: 233494604
This commit is contained in:
parent
5a297f5efe
commit
8bb8258646
@ -160,6 +160,7 @@ tf_custom_op_py_library(
|
|||||||
],
|
],
|
||||||
srcs_version = "PY2AND3",
|
srcs_version = "PY2AND3",
|
||||||
deps = [
|
deps = [
|
||||||
|
"//tensorflow/python:errors",
|
||||||
"//tensorflow/python:framework_for_generated_wrappers",
|
"//tensorflow/python:framework_for_generated_wrappers",
|
||||||
"//tensorflow/python:platform",
|
"//tensorflow/python:platform",
|
||||||
"//tensorflow/python:resources",
|
"//tensorflow/python:resources",
|
||||||
|
@ -18,17 +18,45 @@ from __future__ import absolute_import
|
|||||||
from __future__ import division
|
from __future__ import division
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import threading
|
||||||
|
|
||||||
import platform
|
import platform
|
||||||
|
from tensorflow.python.framework import errors
|
||||||
|
|
||||||
if platform.system() != "Windows":
|
_trt_ops_so = None
|
||||||
# pylint: disable=wildcard-import,unused-import,g-import-not-at-top
|
_module_lock = threading.Lock()
|
||||||
from tensorflow.compiler.tf2tensorrt.ops.gen_trt_ops import *
|
|
||||||
|
|
||||||
|
|
||||||
|
def load_trt_ops():
|
||||||
|
"""Load TF-TRT op libraries so if it hasn't been loaded already."""
|
||||||
|
global _trt_ops_so
|
||||||
|
|
||||||
|
if platform.system() == "Windows":
|
||||||
|
raise RuntimeError("Windows platforms are not supported")
|
||||||
|
|
||||||
|
with _module_lock:
|
||||||
|
if _trt_ops_so:
|
||||||
|
return
|
||||||
|
|
||||||
|
# TODO(laigd): we should load TF-TRT kernels here as well after removing the
|
||||||
|
# swig binding.
|
||||||
|
try:
|
||||||
|
# TODO(lagid): consider getting rid of these unused imports.
|
||||||
|
# pylint: disable=unused-import,g-import-not-at-top,unused-variable
|
||||||
|
from tensorflow.compiler.tf2tensorrt.ops.gen_trt_ops import get_serialized_resource_op
|
||||||
|
from tensorflow.compiler.tf2tensorrt.ops.gen_trt_ops import trt_engine_op
|
||||||
from tensorflow.python.framework import load_library
|
from tensorflow.python.framework import load_library
|
||||||
from tensorflow.python.platform import resource_loader
|
from tensorflow.python.platform import resource_loader
|
||||||
# pylint: enable=wildcard-import,unused-import,g-import-not-at-top
|
# pylint: enable=unused-import,g-import-not-at-top,unused-variable
|
||||||
|
|
||||||
_trt_ops = load_library.load_op_library(
|
_trt_ops_so = load_library.load_op_library(
|
||||||
resource_loader.get_path_to_datafile("_trt_ops.so"))
|
resource_loader.get_path_to_datafile("_trt_ops.so"))
|
||||||
else:
|
except errors.NotFoundError as e:
|
||||||
raise RuntimeError("Windows platforms are not supported")
|
no_trt_message = (
|
||||||
|
"**** Failed to initialize TensorRT. This is either because the "
|
||||||
|
"TensorRT installation path is not in LD_LIBRARY_PATH, or because "
|
||||||
|
"you do not have it installed. If not installed, please go to "
|
||||||
|
"https://developer.nvidia.com/tensorrt to download and install "
|
||||||
|
"TensorRT ****")
|
||||||
|
print(no_trt_message)
|
||||||
|
raise e
|
||||||
|
@ -173,6 +173,7 @@ py_library(
|
|||||||
"//tensorflow/lite/python:lite",
|
"//tensorflow/lite/python:lite",
|
||||||
"//tensorflow/python/compat",
|
"//tensorflow/python/compat",
|
||||||
"//tensorflow/python/compat:v2_compat",
|
"//tensorflow/python/compat:v2_compat",
|
||||||
|
"//tensorflow/python/compiler",
|
||||||
"//tensorflow/python/data",
|
"//tensorflow/python/data",
|
||||||
"//tensorflow/python/distribute",
|
"//tensorflow/python/distribute",
|
||||||
"//tensorflow/python/distribute:estimator_training",
|
"//tensorflow/python/distribute:estimator_training",
|
||||||
|
17
tensorflow/python/compiler/BUILD
Normal file
17
tensorflow/python/compiler/BUILD
Normal file
@ -0,0 +1,17 @@
|
|||||||
|
# Description:
|
||||||
|
# Python APIs for various Tensorflow backends.
|
||||||
|
|
||||||
|
package(default_visibility = ["//visibility:public"])
|
||||||
|
|
||||||
|
licenses(["notice"]) # Apache 2.0
|
||||||
|
|
||||||
|
exports_files(["LICENSE"])
|
||||||
|
|
||||||
|
py_library(
|
||||||
|
name = "compiler",
|
||||||
|
srcs = ["__init__.py"],
|
||||||
|
srcs_version = "PY2AND3",
|
||||||
|
deps = [
|
||||||
|
"//tensorflow/python/compiler/tensorrt:init_py",
|
||||||
|
],
|
||||||
|
)
|
0
tensorflow/python/compiler/__init__.py
Normal file
0
tensorflow/python/compiler/__init__.py
Normal file
@ -32,8 +32,6 @@ py_library(
|
|||||||
deps = [
|
deps = [
|
||||||
":tf_trt_integration_test_base",
|
":tf_trt_integration_test_base",
|
||||||
":trt_convert_py",
|
":trt_convert_py",
|
||||||
":trt_ops_py",
|
|
||||||
"//tensorflow/python:errors",
|
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -51,6 +49,7 @@ py_library(
|
|||||||
srcs = ["trt_convert.py"],
|
srcs = ["trt_convert.py"],
|
||||||
srcs_version = "PY2AND3",
|
srcs_version = "PY2AND3",
|
||||||
deps = [
|
deps = [
|
||||||
|
":trt_ops_py",
|
||||||
":wrap_conversion",
|
":wrap_conversion",
|
||||||
"//tensorflow/python:graph_util",
|
"//tensorflow/python:graph_util",
|
||||||
"//tensorflow/python:session",
|
"//tensorflow/python:session",
|
||||||
@ -83,7 +82,6 @@ py_library(
|
|||||||
srcs = ["test/tf_trt_integration_test_base.py"],
|
srcs = ["test/tf_trt_integration_test_base.py"],
|
||||||
deps = [
|
deps = [
|
||||||
":trt_convert_py",
|
":trt_convert_py",
|
||||||
":trt_ops_py",
|
|
||||||
"//tensorflow/python:client_testlib",
|
"//tensorflow/python:client_testlib",
|
||||||
"//tensorflow/python:framework_test_lib",
|
"//tensorflow/python:framework_test_lib",
|
||||||
],
|
],
|
||||||
|
@ -18,25 +18,6 @@ from __future__ import absolute_import
|
|||||||
from __future__ import division
|
from __future__ import division
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
from tensorflow.python.framework import errors
|
# pylint: disable=unused-import,line-too-long
|
||||||
|
|
||||||
# pylint: disable=unused-import,g-import-not-at-top,line-too-long
|
|
||||||
try:
|
|
||||||
from tensorflow.compiler.tf2tensorrt.python.ops import trt_ops
|
|
||||||
from tensorflow.python.compiler.tensorrt.trt_convert import add_test_value
|
|
||||||
from tensorflow.python.compiler.tensorrt.trt_convert import calib_graph_to_infer_graph
|
|
||||||
from tensorflow.python.compiler.tensorrt.trt_convert import clear_test_values
|
|
||||||
from tensorflow.python.compiler.tensorrt.trt_convert import create_inference_graph
|
from tensorflow.python.compiler.tensorrt.trt_convert import create_inference_graph
|
||||||
from tensorflow.python.compiler.tensorrt.trt_convert import enable_test_value
|
# pylint: enable=unused-import,line-too-long
|
||||||
from tensorflow.python.compiler.tensorrt.trt_convert import get_test_value
|
|
||||||
from tensorflow.python.compiler.tensorrt.trt_convert import is_tensorrt_enabled
|
|
||||||
except errors.NotFoundError as e:
|
|
||||||
no_trt_message = (
|
|
||||||
'**** Failed to initialize TensorRT. This is either because the TensorRT'
|
|
||||||
' installation path is not in LD_LIBRARY_PATH, or because you do not have'
|
|
||||||
' it installed. If not installed, please go to'
|
|
||||||
' https://developer.nvidia.com/tensorrt to download and install'
|
|
||||||
' TensorRT ****')
|
|
||||||
print(no_trt_message)
|
|
||||||
raise e
|
|
||||||
# pylint: enable=unused-import,g-import-not-at-top,line-too-long
|
|
||||||
|
@ -20,8 +20,9 @@ from __future__ import print_function
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from tensorflow.python.compiler.tensorrt import trt_convert
|
|
||||||
from tensorflow.python.compiler.tensorrt.test import tf_trt_integration_test_base as trt_test
|
from tensorflow.python.compiler.tensorrt.test import tf_trt_integration_test_base as trt_test
|
||||||
|
from tensorflow.python.compiler.tensorrt.wrap_conversion import add_test_value
|
||||||
|
from tensorflow.python.compiler.tensorrt.wrap_conversion import clear_test_values
|
||||||
from tensorflow.python.framework import constant_op
|
from tensorflow.python.framework import constant_op
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
@ -154,7 +155,7 @@ class PartiallyConvertedTestA(trt_test.TfTrtIntegrationTestBase):
|
|||||||
"""Setup method."""
|
"""Setup method."""
|
||||||
super(PartiallyConvertedTestA, self).setUp()
|
super(PartiallyConvertedTestA, self).setUp()
|
||||||
# Let it fail to build the second engine.
|
# Let it fail to build the second engine.
|
||||||
trt_convert.add_test_value("TRTEngineOp_1:CreateTRTNode", "fail")
|
add_test_value("TRTEngineOp_1:CreateTRTNode", "fail")
|
||||||
|
|
||||||
def GetParams(self):
|
def GetParams(self):
|
||||||
"""Create a graph containing two segment."""
|
"""Create a graph containing two segment."""
|
||||||
@ -209,8 +210,8 @@ class PartiallyConvertedTestB(PartiallyConvertedTestA):
|
|||||||
"""Setup method."""
|
"""Setup method."""
|
||||||
super(PartiallyConvertedTestB, self).setUp()
|
super(PartiallyConvertedTestB, self).setUp()
|
||||||
# Let it fail to build the first engine.
|
# Let it fail to build the first engine.
|
||||||
trt_convert.clear_test_values("")
|
clear_test_values("")
|
||||||
trt_convert.add_test_value("TRTEngineOp_0:CreateTRTNode", "fail")
|
add_test_value("TRTEngineOp_0:CreateTRTNode", "fail")
|
||||||
|
|
||||||
def ExpectedEnginesToBuild(self, run_params):
|
def ExpectedEnginesToBuild(self, run_params):
|
||||||
"""Return the expected engines to build."""
|
"""Return the expected engines to build."""
|
||||||
|
@ -18,13 +18,12 @@ from __future__ import absolute_import
|
|||||||
from __future__ import division
|
from __future__ import division
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
# pylint: disable=unused-import
|
|
||||||
from tensorflow.compiler.tf2tensorrt.python.ops import trt_ops
|
|
||||||
# pylint: enable=unused-import
|
|
||||||
from tensorflow.core.protobuf import config_pb2
|
from tensorflow.core.protobuf import config_pb2
|
||||||
from tensorflow.python import data
|
from tensorflow.python import data
|
||||||
from tensorflow.python import keras
|
from tensorflow.python import keras
|
||||||
from tensorflow.python.compiler.tensorrt import trt_convert
|
from tensorflow.python.compiler.tensorrt import trt_convert
|
||||||
|
from tensorflow.python.compiler.tensorrt.wrap_conversion import get_linked_tensorrt_version
|
||||||
|
from tensorflow.python.compiler.tensorrt.wrap_conversion import is_tensorrt_enabled
|
||||||
from tensorflow.python.estimator.estimator import Estimator
|
from tensorflow.python.estimator.estimator import Estimator
|
||||||
from tensorflow.python.estimator.model_fn import EstimatorSpec
|
from tensorflow.python.estimator.model_fn import EstimatorSpec
|
||||||
from tensorflow.python.estimator.model_fn import ModeKeys
|
from tensorflow.python.estimator.model_fn import ModeKeys
|
||||||
@ -263,7 +262,7 @@ class QuantizationAwareTrainingMNISTTest(test_util.TensorFlowTestCase):
|
|||||||
# num_epochs=100,
|
# num_epochs=100,
|
||||||
# model_dir=model_dir)
|
# model_dir=model_dir)
|
||||||
def testEval(self):
|
def testEval(self):
|
||||||
if not trt_convert.is_tensorrt_enabled():
|
if not is_tensorrt_enabled():
|
||||||
return
|
return
|
||||||
model_dir = test.test_src_dir_path('python/compiler/tensorrt/test/testdata')
|
model_dir = test.test_src_dir_path('python/compiler/tensorrt/test/testdata')
|
||||||
|
|
||||||
@ -276,7 +275,7 @@ class QuantizationAwareTrainingMNISTTest(test_util.TensorFlowTestCase):
|
|||||||
logging.info('accuracy_tf_native: %f', accuracy_tf_native)
|
logging.info('accuracy_tf_native: %f', accuracy_tf_native)
|
||||||
self.assertAllClose(0.9662, accuracy_tf_native, rtol=1e-3, atol=1e-3)
|
self.assertAllClose(0.9662, accuracy_tf_native, rtol=1e-3, atol=1e-3)
|
||||||
|
|
||||||
if trt_convert.get_linked_tensorrt_version()[0] < 5:
|
if get_linked_tensorrt_version()[0] < 5:
|
||||||
return
|
return
|
||||||
|
|
||||||
accuracy_tf_trt = self._Run(
|
accuracy_tf_trt = self._Run(
|
||||||
|
@ -20,8 +20,8 @@ from __future__ import print_function
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from tensorflow.python.compiler.tensorrt import trt_convert
|
|
||||||
from tensorflow.python.compiler.tensorrt.test import tf_trt_integration_test_base as trt_test
|
from tensorflow.python.compiler.tensorrt.test import tf_trt_integration_test_base as trt_test
|
||||||
|
from tensorflow.python.compiler.tensorrt.wrap_conversion import get_linked_tensorrt_version
|
||||||
from tensorflow.python.framework import constant_op
|
from tensorflow.python.framework import constant_op
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
@ -72,7 +72,7 @@ class QuantizationMissingAllRangesTest(trt_test.TfTrtIntegrationTestBase):
|
|||||||
return _GetParams(add_quantization_nodes=False)
|
return _GetParams(add_quantization_nodes=False)
|
||||||
|
|
||||||
def ShouldRunTest(self, run_params):
|
def ShouldRunTest(self, run_params):
|
||||||
if trt_convert.get_linked_tensorrt_version()[0] < 5:
|
if get_linked_tensorrt_version()[0] < 5:
|
||||||
return False
|
return False
|
||||||
# Only test static engine mode, with or without calibration.
|
# Only test static engine mode, with or without calibration.
|
||||||
return (trt_test.IsQuantizationMode(run_params.precision_mode) and
|
return (trt_test.IsQuantizationMode(run_params.precision_mode) and
|
||||||
@ -96,7 +96,7 @@ class QuantizationWithRangesTest(trt_test.TfTrtIntegrationTestBase):
|
|||||||
return _GetParams(add_quantization_nodes=True)
|
return _GetParams(add_quantization_nodes=True)
|
||||||
|
|
||||||
def ShouldRunTest(self, run_params):
|
def ShouldRunTest(self, run_params):
|
||||||
if trt_convert.get_linked_tensorrt_version()[0] < 5:
|
if get_linked_tensorrt_version()[0] < 5:
|
||||||
return False
|
return False
|
||||||
# Test static/dynamic engine with/without calibration.
|
# Test static/dynamic engine with/without calibration.
|
||||||
return (trt_test.IsQuantizationMode(run_params.precision_mode) and
|
return (trt_test.IsQuantizationMode(run_params.precision_mode) and
|
||||||
|
@ -25,12 +25,13 @@ import warnings
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import six
|
import six
|
||||||
|
|
||||||
# pylint: disable=unused-import
|
|
||||||
from tensorflow.compiler.tf2tensorrt.python.ops import trt_ops
|
|
||||||
# pylint: enable=unused-import
|
|
||||||
from tensorflow.core.protobuf import config_pb2
|
from tensorflow.core.protobuf import config_pb2
|
||||||
from tensorflow.core.protobuf import rewriter_config_pb2
|
from tensorflow.core.protobuf import rewriter_config_pb2
|
||||||
from tensorflow.python.compiler.tensorrt import trt_convert
|
from tensorflow.python.compiler.tensorrt import trt_convert
|
||||||
|
from tensorflow.python.compiler.tensorrt.wrap_conversion import clear_test_values
|
||||||
|
from tensorflow.python.compiler.tensorrt.wrap_conversion import enable_test_value
|
||||||
|
from tensorflow.python.compiler.tensorrt.wrap_conversion import get_test_value
|
||||||
|
from tensorflow.python.compiler.tensorrt.wrap_conversion import is_tensorrt_enabled
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
from tensorflow.python.framework import graph_io
|
from tensorflow.python.framework import graph_io
|
||||||
from tensorflow.python.framework import importer
|
from tensorflow.python.framework import importer
|
||||||
@ -151,7 +152,7 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase):
|
|||||||
def setUpClass(cls):
|
def setUpClass(cls):
|
||||||
"""Setup method for the module."""
|
"""Setup method for the module."""
|
||||||
super(TfTrtIntegrationTestBase, cls).setUpClass()
|
super(TfTrtIntegrationTestBase, cls).setUpClass()
|
||||||
trt_convert.enable_test_value()
|
enable_test_value()
|
||||||
|
|
||||||
def __init__(self, methodName="runTest"): # pylint: disable=invalid-name
|
def __init__(self, methodName="runTest"): # pylint: disable=invalid-name
|
||||||
super(TfTrtIntegrationTestBase, self).__init__(methodName)
|
super(TfTrtIntegrationTestBase, self).__init__(methodName)
|
||||||
@ -161,7 +162,7 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase):
|
|||||||
"""Setup method."""
|
"""Setup method."""
|
||||||
super(TfTrtIntegrationTestBase, self).setUp()
|
super(TfTrtIntegrationTestBase, self).setUp()
|
||||||
warnings.simplefilter("always")
|
warnings.simplefilter("always")
|
||||||
trt_convert.clear_test_values("")
|
clear_test_values("")
|
||||||
|
|
||||||
def GetParams(self):
|
def GetParams(self):
|
||||||
"""Return a TfTrtIntegrationTestParams for test, implemented by subclass."""
|
"""Return a TfTrtIntegrationTestParams for test, implemented by subclass."""
|
||||||
@ -246,9 +247,9 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase):
|
|||||||
def _PrepareRun(self, graph_state):
|
def _PrepareRun(self, graph_state):
|
||||||
"""Set up necessary testing environment before calling sess.run()."""
|
"""Set up necessary testing environment before calling sess.run()."""
|
||||||
# Clear test values added by TRTEngineOp.
|
# Clear test values added by TRTEngineOp.
|
||||||
trt_convert.clear_test_values("TRTEngineOp_.*:ExecuteTrtEngine")
|
clear_test_values("TRTEngineOp_.*:ExecuteTrtEngine")
|
||||||
trt_convert.clear_test_values("TRTEngineOp_.*:ExecuteCalibration")
|
clear_test_values("TRTEngineOp_.*:ExecuteCalibration")
|
||||||
trt_convert.clear_test_values("TRTEngineOp_.*:ExecuteNativeSegment")
|
clear_test_values("TRTEngineOp_.*:ExecuteNativeSegment")
|
||||||
|
|
||||||
def _GetGPUOptions(self):
|
def _GetGPUOptions(self):
|
||||||
gpu_options = config_pb2.GPUOptions()
|
gpu_options = config_pb2.GPUOptions()
|
||||||
@ -282,7 +283,7 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase):
|
|||||||
|
|
||||||
def _ExpectTestValue(self, engine_name, method, expected_value):
|
def _ExpectTestValue(self, engine_name, method, expected_value):
|
||||||
label = "%s:%s" % (engine_name, method)
|
label = "%s:%s" % (engine_name, method)
|
||||||
actual_value = trt_convert.get_test_value(label)
|
actual_value = get_test_value(label)
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
expected_value,
|
expected_value,
|
||||||
actual_value,
|
actual_value,
|
||||||
@ -639,5 +640,5 @@ def _AddTests(test_class):
|
|||||||
setattr(test_class, "testTfTrt_" + test_name, _GetTest(run_params))
|
setattr(test_class, "testTfTrt_" + test_name, _GetTest(run_params))
|
||||||
|
|
||||||
|
|
||||||
if trt_convert.is_tensorrt_enabled():
|
if is_tensorrt_enabled():
|
||||||
_AddTests(TfTrtIntegrationTestBase)
|
_AddTests(TfTrtIntegrationTestBase)
|
||||||
|
@ -19,17 +19,7 @@ from __future__ import division
|
|||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
import six as _six
|
import six as _six
|
||||||
# pylint: disable=unused-import,line-too-long
|
|
||||||
from tensorflow.compiler.tf2tensorrt.python.ops import trt_ops
|
from tensorflow.compiler.tf2tensorrt.python.ops import trt_ops
|
||||||
from tensorflow.python.compiler.tensorrt.wrap_conversion import add_test_value
|
|
||||||
from tensorflow.python.compiler.tensorrt.wrap_conversion import calib_convert
|
|
||||||
from tensorflow.python.compiler.tensorrt.wrap_conversion import clear_test_values
|
|
||||||
from tensorflow.python.compiler.tensorrt.wrap_conversion import enable_test_value
|
|
||||||
from tensorflow.python.compiler.tensorrt.wrap_conversion import get_linked_tensorrt_version
|
|
||||||
from tensorflow.python.compiler.tensorrt.wrap_conversion import get_loaded_tensorrt_version
|
|
||||||
from tensorflow.python.compiler.tensorrt.wrap_conversion import get_test_value
|
|
||||||
from tensorflow.python.compiler.tensorrt.wrap_conversion import is_tensorrt_enabled
|
|
||||||
# pylint: enable=unused-import,line-too-long
|
|
||||||
from tensorflow.core.framework import graph_pb2
|
from tensorflow.core.framework import graph_pb2
|
||||||
from tensorflow.core.protobuf import config_pb2
|
from tensorflow.core.protobuf import config_pb2
|
||||||
from tensorflow.core.protobuf import meta_graph_pb2
|
from tensorflow.core.protobuf import meta_graph_pb2
|
||||||
@ -357,6 +347,14 @@ class TrtGraphConverter(GraphConverter):
|
|||||||
TypeError: if any of the parameters are of unexpected type.
|
TypeError: if any of the parameters are of unexpected type.
|
||||||
ValueError: if any of the parameters are of unexpected value.
|
ValueError: if any of the parameters are of unexpected value.
|
||||||
"""
|
"""
|
||||||
|
# Lazily load the TF-TRT C bindings, so `import tensorflow` doesn't complain
|
||||||
|
# even if it cannot find TensorRT library.
|
||||||
|
trt_ops.load_trt_ops()
|
||||||
|
# pylint: disable=g-import-not-at-top,unused-import,line-too-long,unused-variable
|
||||||
|
# Import a random symbol to trigger loading of TRT library.
|
||||||
|
from tensorflow.python.compiler.tensorrt.wrap_conversion import calib_convert
|
||||||
|
# pylint: enable=g-import-not-at-top,unused-import,line-too-long,unused-variable
|
||||||
|
|
||||||
if rewriter_config_template is not None and not isinstance(
|
if rewriter_config_template is not None and not isinstance(
|
||||||
rewriter_config_template, rewriter_config_pb2.RewriterConfig):
|
rewriter_config_template, rewriter_config_pb2.RewriterConfig):
|
||||||
raise TypeError(
|
raise TypeError(
|
||||||
@ -457,6 +455,14 @@ class TrtGraphConverter(GraphConverter):
|
|||||||
nodes_blacklist=nodes_blacklist,
|
nodes_blacklist=nodes_blacklist,
|
||||||
session_config=session_config)
|
session_config=session_config)
|
||||||
|
|
||||||
|
# Lazily load the TF-TRT C bindings, so `import tensorflow` doesn't complain
|
||||||
|
# even if it cannot find TensorRT library.
|
||||||
|
trt_ops.load_trt_ops()
|
||||||
|
# pylint: disable=g-import-not-at-top,line-too-long
|
||||||
|
from tensorflow.python.compiler.tensorrt.wrap_conversion import get_linked_tensorrt_version
|
||||||
|
from tensorflow.python.compiler.tensorrt.wrap_conversion import get_loaded_tensorrt_version
|
||||||
|
# pylint: enable=g-import-not-at-top,line-too-long
|
||||||
|
|
||||||
# Check compatibility of TensorRT version.
|
# Check compatibility of TensorRT version.
|
||||||
compiled_version = get_linked_tensorrt_version()
|
compiled_version = get_linked_tensorrt_version()
|
||||||
loaded_version = get_loaded_tensorrt_version()
|
loaded_version = get_loaded_tensorrt_version()
|
||||||
@ -642,6 +648,12 @@ def calib_graph_to_infer_graph(calibration_graph_def, is_dynamic_op=False):
|
|||||||
Raises:
|
Raises:
|
||||||
RuntimeError: if the returned status message is malformed.
|
RuntimeError: if the returned status message is malformed.
|
||||||
"""
|
"""
|
||||||
|
# Lazily load the TF-TRT C bindings, so `import tensorflow` doesn't complain
|
||||||
|
# even if it cannot find TensorRT library.
|
||||||
|
trt_ops.load_trt_ops()
|
||||||
|
# pylint: disable=g-import-not-at-top,line-too-long
|
||||||
|
from tensorflow.python.compiler.tensorrt.wrap_conversion import calib_convert
|
||||||
|
# pylint: enable=g-import-not-at-top,line-too-long
|
||||||
|
|
||||||
is_calib_graph = False
|
is_calib_graph = False
|
||||||
for n in calibration_graph_def.node:
|
for n in calibration_graph_def.node:
|
||||||
|
@ -20,9 +20,10 @@ from __future__ import print_function
|
|||||||
|
|
||||||
import os
|
import os
|
||||||
|
|
||||||
# pylint: disable=unused-import
|
from tensorflow.python.compiler.tensorrt.wrap_conversion import clear_test_values
|
||||||
from tensorflow.compiler.tf2tensorrt.python.ops import trt_ops
|
from tensorflow.python.compiler.tensorrt.wrap_conversion import enable_test_value
|
||||||
# pylint: enable=unused-import
|
from tensorflow.python.compiler.tensorrt.wrap_conversion import get_test_value
|
||||||
|
from tensorflow.python.compiler.tensorrt.wrap_conversion import is_tensorrt_enabled
|
||||||
from tensorflow.core.framework import graph_pb2
|
from tensorflow.core.framework import graph_pb2
|
||||||
from tensorflow.core.protobuf import config_pb2
|
from tensorflow.core.protobuf import config_pb2
|
||||||
from tensorflow.core.protobuf import rewriter_config_pb2
|
from tensorflow.core.protobuf import rewriter_config_pb2
|
||||||
@ -53,6 +54,8 @@ class TrtConvertTest(test_util.TensorFlowTestCase):
|
|||||||
|
|
||||||
def testGetTensorrtRewriterConfig(self):
|
def testGetTensorrtRewriterConfig(self):
|
||||||
"""Test case for TrtGraphConverter.get_tensorrt_rewriter_config()."""
|
"""Test case for TrtGraphConverter.get_tensorrt_rewriter_config()."""
|
||||||
|
if not is_tensorrt_enabled():
|
||||||
|
return
|
||||||
rewriter_cfg = trt_convert.TrtGraphConverter.get_tensorrt_rewriter_config(
|
rewriter_cfg = trt_convert.TrtGraphConverter.get_tensorrt_rewriter_config(
|
||||||
rewriter_config_template=None,
|
rewriter_config_template=None,
|
||||||
max_batch_size=128,
|
max_batch_size=128,
|
||||||
@ -172,7 +175,7 @@ class TrtConvertTest(test_util.TensorFlowTestCase):
|
|||||||
|
|
||||||
def testCreateInferenceGraph_BasicConversion(self):
|
def testCreateInferenceGraph_BasicConversion(self):
|
||||||
"""Test case for trt_convert.create_inference_graph()."""
|
"""Test case for trt_convert.create_inference_graph()."""
|
||||||
if not trt_convert.is_tensorrt_enabled():
|
if not is_tensorrt_enabled():
|
||||||
return
|
return
|
||||||
|
|
||||||
# Use GraphDef as input.
|
# Use GraphDef as input.
|
||||||
@ -187,25 +190,23 @@ class TrtConvertTest(test_util.TensorFlowTestCase):
|
|||||||
output_saved_model_dir)
|
output_saved_model_dir)
|
||||||
|
|
||||||
def _TestRun(self, sess, batch_size, expect_engine_is_run):
|
def _TestRun(self, sess, batch_size, expect_engine_is_run):
|
||||||
trt_convert.clear_test_values("")
|
clear_test_values("")
|
||||||
result = sess.run("output:0", feed_dict={"input:0": [[[1.0]]] * batch_size})
|
result = sess.run("output:0", feed_dict={"input:0": [[[1.0]]] * batch_size})
|
||||||
self.assertAllEqual([[[4.0]]] * batch_size, result)
|
self.assertAllEqual([[[4.0]]] * batch_size, result)
|
||||||
execute_engine_test_value = ("done" if expect_engine_is_run else "")
|
execute_engine_test_value = ("done" if expect_engine_is_run else "")
|
||||||
execute_native_segment_test_value = ("" if expect_engine_is_run else "done")
|
execute_native_segment_test_value = ("" if expect_engine_is_run else "done")
|
||||||
self.assertEqual(
|
self.assertEqual(execute_engine_test_value,
|
||||||
execute_engine_test_value,
|
get_test_value("TRTEngineOp_0:ExecuteTrtEngine"))
|
||||||
trt_convert.get_test_value("TRTEngineOp_0:ExecuteTrtEngine"))
|
self.assertEqual(execute_native_segment_test_value,
|
||||||
self.assertEqual(
|
get_test_value("TRTEngineOp_0:ExecuteNativeSegment"))
|
||||||
execute_native_segment_test_value,
|
|
||||||
trt_convert.get_test_value("TRTEngineOp_0:ExecuteNativeSegment"))
|
|
||||||
|
|
||||||
def testCreateInferenceGraph_MinimumSegmentSize(self):
|
def testCreateInferenceGraph_MinimumSegmentSize(self):
|
||||||
if not trt_convert.is_tensorrt_enabled():
|
if not is_tensorrt_enabled():
|
||||||
return
|
return
|
||||||
output_graph_def = trt_convert.create_inference_graph(
|
output_graph_def = trt_convert.create_inference_graph(
|
||||||
self._GetGraphDef(), ["output"],
|
self._GetGraphDef(), ["output"],
|
||||||
minimum_segment_size=5,
|
|
||||||
max_workspace_size_bytes=TrtConvertTest._TRT_MAX_WORKSPACE_SIZE_BYTES,
|
max_workspace_size_bytes=TrtConvertTest._TRT_MAX_WORKSPACE_SIZE_BYTES,
|
||||||
|
minimum_segment_size=5,
|
||||||
is_dynamic_op=False)
|
is_dynamic_op=False)
|
||||||
node_name_to_op = {node.name: node.op for node in output_graph_def.node}
|
node_name_to_op = {node.name: node.op for node in output_graph_def.node}
|
||||||
self.assertEqual({
|
self.assertEqual({
|
||||||
@ -218,9 +219,9 @@ class TrtConvertTest(test_util.TensorFlowTestCase):
|
|||||||
}, node_name_to_op)
|
}, node_name_to_op)
|
||||||
|
|
||||||
def testCreateInferenceGraph_DynamicOp(self):
|
def testCreateInferenceGraph_DynamicOp(self):
|
||||||
if not trt_convert.is_tensorrt_enabled():
|
if not is_tensorrt_enabled():
|
||||||
return
|
return
|
||||||
trt_convert.enable_test_value()
|
enable_test_value()
|
||||||
|
|
||||||
tmp_dir = self.get_temp_dir()
|
tmp_dir = self.get_temp_dir()
|
||||||
input_saved_model_dir = os.path.join(tmp_dir, "in_dir2")
|
input_saved_model_dir = os.path.join(tmp_dir, "in_dir2")
|
||||||
@ -261,9 +262,9 @@ class TrtConvertTest(test_util.TensorFlowTestCase):
|
|||||||
self._TestRun(sess, 3, True)
|
self._TestRun(sess, 3, True)
|
||||||
|
|
||||||
def testCreateInferenceGraph_StaticOp(self):
|
def testCreateInferenceGraph_StaticOp(self):
|
||||||
if not trt_convert.is_tensorrt_enabled():
|
if not is_tensorrt_enabled():
|
||||||
return
|
return
|
||||||
trt_convert.enable_test_value()
|
enable_test_value()
|
||||||
|
|
||||||
tmp_dir = self.get_temp_dir()
|
tmp_dir = self.get_temp_dir()
|
||||||
input_saved_model_dir = os.path.join(tmp_dir, "in_dir3")
|
input_saved_model_dir = os.path.join(tmp_dir, "in_dir3")
|
||||||
|
Loading…
Reference in New Issue
Block a user