Link TF-TRT into python:no_contrib (so it's included in TF 2.0 build) and lazily load TRT shared library.

PiperOrigin-RevId: 233494604
This commit is contained in:
Guangda Lai 2019-02-11 15:47:34 -08:00 committed by TensorFlower Gardener
parent 5a297f5efe
commit 8bb8258646
13 changed files with 124 additions and 84 deletions

View File

@ -160,6 +160,7 @@ tf_custom_op_py_library(
],
srcs_version = "PY2AND3",
deps = [
"//tensorflow/python:errors",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:platform",
"//tensorflow/python:resources",

View File

@ -18,17 +18,45 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import threading
import platform
from tensorflow.python.framework import errors
if platform.system() != "Windows":
# pylint: disable=wildcard-import,unused-import,g-import-not-at-top
from tensorflow.compiler.tf2tensorrt.ops.gen_trt_ops import *
_trt_ops_so = None
_module_lock = threading.Lock()
from tensorflow.python.framework import load_library
from tensorflow.python.platform import resource_loader
# pylint: enable=wildcard-import,unused-import,g-import-not-at-top
_trt_ops = load_library.load_op_library(
resource_loader.get_path_to_datafile("_trt_ops.so"))
else:
raise RuntimeError("Windows platforms are not supported")
def load_trt_ops():
"""Load TF-TRT op libraries so if it hasn't been loaded already."""
global _trt_ops_so
if platform.system() == "Windows":
raise RuntimeError("Windows platforms are not supported")
with _module_lock:
if _trt_ops_so:
return
# TODO(laigd): we should load TF-TRT kernels here as well after removing the
# swig binding.
try:
# TODO(lagid): consider getting rid of these unused imports.
# pylint: disable=unused-import,g-import-not-at-top,unused-variable
from tensorflow.compiler.tf2tensorrt.ops.gen_trt_ops import get_serialized_resource_op
from tensorflow.compiler.tf2tensorrt.ops.gen_trt_ops import trt_engine_op
from tensorflow.python.framework import load_library
from tensorflow.python.platform import resource_loader
# pylint: enable=unused-import,g-import-not-at-top,unused-variable
_trt_ops_so = load_library.load_op_library(
resource_loader.get_path_to_datafile("_trt_ops.so"))
except errors.NotFoundError as e:
no_trt_message = (
"**** Failed to initialize TensorRT. This is either because the "
"TensorRT installation path is not in LD_LIBRARY_PATH, or because "
"you do not have it installed. If not installed, please go to "
"https://developer.nvidia.com/tensorrt to download and install "
"TensorRT ****")
print(no_trt_message)
raise e

View File

@ -173,6 +173,7 @@ py_library(
"//tensorflow/lite/python:lite",
"//tensorflow/python/compat",
"//tensorflow/python/compat:v2_compat",
"//tensorflow/python/compiler",
"//tensorflow/python/data",
"//tensorflow/python/distribute",
"//tensorflow/python/distribute:estimator_training",

View File

@ -0,0 +1,17 @@
# Description:
# Python APIs for various Tensorflow backends.
package(default_visibility = ["//visibility:public"])
licenses(["notice"]) # Apache 2.0
exports_files(["LICENSE"])
py_library(
name = "compiler",
srcs = ["__init__.py"],
srcs_version = "PY2AND3",
deps = [
"//tensorflow/python/compiler/tensorrt:init_py",
],
)

View File

View File

@ -32,8 +32,6 @@ py_library(
deps = [
":tf_trt_integration_test_base",
":trt_convert_py",
":trt_ops_py",
"//tensorflow/python:errors",
],
)
@ -51,6 +49,7 @@ py_library(
srcs = ["trt_convert.py"],
srcs_version = "PY2AND3",
deps = [
":trt_ops_py",
":wrap_conversion",
"//tensorflow/python:graph_util",
"//tensorflow/python:session",
@ -83,7 +82,6 @@ py_library(
srcs = ["test/tf_trt_integration_test_base.py"],
deps = [
":trt_convert_py",
":trt_ops_py",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_test_lib",
],

View File

@ -18,25 +18,6 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.framework import errors
# pylint: disable=unused-import,g-import-not-at-top,line-too-long
try:
from tensorflow.compiler.tf2tensorrt.python.ops import trt_ops
from tensorflow.python.compiler.tensorrt.trt_convert import add_test_value
from tensorflow.python.compiler.tensorrt.trt_convert import calib_graph_to_infer_graph
from tensorflow.python.compiler.tensorrt.trt_convert import clear_test_values
from tensorflow.python.compiler.tensorrt.trt_convert import create_inference_graph
from tensorflow.python.compiler.tensorrt.trt_convert import enable_test_value
from tensorflow.python.compiler.tensorrt.trt_convert import get_test_value
from tensorflow.python.compiler.tensorrt.trt_convert import is_tensorrt_enabled
except errors.NotFoundError as e:
no_trt_message = (
'**** Failed to initialize TensorRT. This is either because the TensorRT'
' installation path is not in LD_LIBRARY_PATH, or because you do not have'
' it installed. If not installed, please go to'
' https://developer.nvidia.com/tensorrt to download and install'
' TensorRT ****')
print(no_trt_message)
raise e
# pylint: enable=unused-import,g-import-not-at-top,line-too-long
# pylint: disable=unused-import,line-too-long
from tensorflow.python.compiler.tensorrt.trt_convert import create_inference_graph
# pylint: enable=unused-import,line-too-long

View File

@ -20,8 +20,9 @@ from __future__ import print_function
import numpy as np
from tensorflow.python.compiler.tensorrt import trt_convert
from tensorflow.python.compiler.tensorrt.test import tf_trt_integration_test_base as trt_test
from tensorflow.python.compiler.tensorrt.wrap_conversion import add_test_value
from tensorflow.python.compiler.tensorrt.wrap_conversion import clear_test_values
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
@ -154,7 +155,7 @@ class PartiallyConvertedTestA(trt_test.TfTrtIntegrationTestBase):
"""Setup method."""
super(PartiallyConvertedTestA, self).setUp()
# Let it fail to build the second engine.
trt_convert.add_test_value("TRTEngineOp_1:CreateTRTNode", "fail")
add_test_value("TRTEngineOp_1:CreateTRTNode", "fail")
def GetParams(self):
"""Create a graph containing two segment."""
@ -209,8 +210,8 @@ class PartiallyConvertedTestB(PartiallyConvertedTestA):
"""Setup method."""
super(PartiallyConvertedTestB, self).setUp()
# Let it fail to build the first engine.
trt_convert.clear_test_values("")
trt_convert.add_test_value("TRTEngineOp_0:CreateTRTNode", "fail")
clear_test_values("")
add_test_value("TRTEngineOp_0:CreateTRTNode", "fail")
def ExpectedEnginesToBuild(self, run_params):
"""Return the expected engines to build."""

View File

@ -18,13 +18,12 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# pylint: disable=unused-import
from tensorflow.compiler.tf2tensorrt.python.ops import trt_ops
# pylint: enable=unused-import
from tensorflow.core.protobuf import config_pb2
from tensorflow.python import data
from tensorflow.python import keras
from tensorflow.python.compiler.tensorrt import trt_convert
from tensorflow.python.compiler.tensorrt.wrap_conversion import get_linked_tensorrt_version
from tensorflow.python.compiler.tensorrt.wrap_conversion import is_tensorrt_enabled
from tensorflow.python.estimator.estimator import Estimator
from tensorflow.python.estimator.model_fn import EstimatorSpec
from tensorflow.python.estimator.model_fn import ModeKeys
@ -263,7 +262,7 @@ class QuantizationAwareTrainingMNISTTest(test_util.TensorFlowTestCase):
# num_epochs=100,
# model_dir=model_dir)
def testEval(self):
if not trt_convert.is_tensorrt_enabled():
if not is_tensorrt_enabled():
return
model_dir = test.test_src_dir_path('python/compiler/tensorrt/test/testdata')
@ -276,7 +275,7 @@ class QuantizationAwareTrainingMNISTTest(test_util.TensorFlowTestCase):
logging.info('accuracy_tf_native: %f', accuracy_tf_native)
self.assertAllClose(0.9662, accuracy_tf_native, rtol=1e-3, atol=1e-3)
if trt_convert.get_linked_tensorrt_version()[0] < 5:
if get_linked_tensorrt_version()[0] < 5:
return
accuracy_tf_trt = self._Run(

View File

@ -20,8 +20,8 @@ from __future__ import print_function
import numpy as np
from tensorflow.python.compiler.tensorrt import trt_convert
from tensorflow.python.compiler.tensorrt.test import tf_trt_integration_test_base as trt_test
from tensorflow.python.compiler.tensorrt.wrap_conversion import get_linked_tensorrt_version
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
@ -72,7 +72,7 @@ class QuantizationMissingAllRangesTest(trt_test.TfTrtIntegrationTestBase):
return _GetParams(add_quantization_nodes=False)
def ShouldRunTest(self, run_params):
if trt_convert.get_linked_tensorrt_version()[0] < 5:
if get_linked_tensorrt_version()[0] < 5:
return False
# Only test static engine mode, with or without calibration.
return (trt_test.IsQuantizationMode(run_params.precision_mode) and
@ -96,7 +96,7 @@ class QuantizationWithRangesTest(trt_test.TfTrtIntegrationTestBase):
return _GetParams(add_quantization_nodes=True)
def ShouldRunTest(self, run_params):
if trt_convert.get_linked_tensorrt_version()[0] < 5:
if get_linked_tensorrt_version()[0] < 5:
return False
# Test static/dynamic engine with/without calibration.
return (trt_test.IsQuantizationMode(run_params.precision_mode) and

View File

@ -25,12 +25,13 @@ import warnings
import numpy as np
import six
# pylint: disable=unused-import
from tensorflow.compiler.tf2tensorrt.python.ops import trt_ops
# pylint: enable=unused-import
from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.python.compiler.tensorrt import trt_convert
from tensorflow.python.compiler.tensorrt.wrap_conversion import clear_test_values
from tensorflow.python.compiler.tensorrt.wrap_conversion import enable_test_value
from tensorflow.python.compiler.tensorrt.wrap_conversion import get_test_value
from tensorflow.python.compiler.tensorrt.wrap_conversion import is_tensorrt_enabled
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import graph_io
from tensorflow.python.framework import importer
@ -151,7 +152,7 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase):
def setUpClass(cls):
"""Setup method for the module."""
super(TfTrtIntegrationTestBase, cls).setUpClass()
trt_convert.enable_test_value()
enable_test_value()
def __init__(self, methodName="runTest"): # pylint: disable=invalid-name
super(TfTrtIntegrationTestBase, self).__init__(methodName)
@ -161,7 +162,7 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase):
"""Setup method."""
super(TfTrtIntegrationTestBase, self).setUp()
warnings.simplefilter("always")
trt_convert.clear_test_values("")
clear_test_values("")
def GetParams(self):
"""Return a TfTrtIntegrationTestParams for test, implemented by subclass."""
@ -246,9 +247,9 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase):
def _PrepareRun(self, graph_state):
"""Set up necessary testing environment before calling sess.run()."""
# Clear test values added by TRTEngineOp.
trt_convert.clear_test_values("TRTEngineOp_.*:ExecuteTrtEngine")
trt_convert.clear_test_values("TRTEngineOp_.*:ExecuteCalibration")
trt_convert.clear_test_values("TRTEngineOp_.*:ExecuteNativeSegment")
clear_test_values("TRTEngineOp_.*:ExecuteTrtEngine")
clear_test_values("TRTEngineOp_.*:ExecuteCalibration")
clear_test_values("TRTEngineOp_.*:ExecuteNativeSegment")
def _GetGPUOptions(self):
gpu_options = config_pb2.GPUOptions()
@ -282,7 +283,7 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase):
def _ExpectTestValue(self, engine_name, method, expected_value):
label = "%s:%s" % (engine_name, method)
actual_value = trt_convert.get_test_value(label)
actual_value = get_test_value(label)
self.assertEqual(
expected_value,
actual_value,
@ -639,5 +640,5 @@ def _AddTests(test_class):
setattr(test_class, "testTfTrt_" + test_name, _GetTest(run_params))
if trt_convert.is_tensorrt_enabled():
if is_tensorrt_enabled():
_AddTests(TfTrtIntegrationTestBase)

View File

@ -19,17 +19,7 @@ from __future__ import division
from __future__ import print_function
import six as _six
# pylint: disable=unused-import,line-too-long
from tensorflow.compiler.tf2tensorrt.python.ops import trt_ops
from tensorflow.python.compiler.tensorrt.wrap_conversion import add_test_value
from tensorflow.python.compiler.tensorrt.wrap_conversion import calib_convert
from tensorflow.python.compiler.tensorrt.wrap_conversion import clear_test_values
from tensorflow.python.compiler.tensorrt.wrap_conversion import enable_test_value
from tensorflow.python.compiler.tensorrt.wrap_conversion import get_linked_tensorrt_version
from tensorflow.python.compiler.tensorrt.wrap_conversion import get_loaded_tensorrt_version
from tensorflow.python.compiler.tensorrt.wrap_conversion import get_test_value
from tensorflow.python.compiler.tensorrt.wrap_conversion import is_tensorrt_enabled
# pylint: enable=unused-import,line-too-long
from tensorflow.core.framework import graph_pb2
from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import meta_graph_pb2
@ -357,6 +347,14 @@ class TrtGraphConverter(GraphConverter):
TypeError: if any of the parameters are of unexpected type.
ValueError: if any of the parameters are of unexpected value.
"""
# Lazily load the TF-TRT C bindings, so `import tensorflow` doesn't complain
# even if it cannot find TensorRT library.
trt_ops.load_trt_ops()
# pylint: disable=g-import-not-at-top,unused-import,line-too-long,unused-variable
# Import a random symbol to trigger loading of TRT library.
from tensorflow.python.compiler.tensorrt.wrap_conversion import calib_convert
# pylint: enable=g-import-not-at-top,unused-import,line-too-long,unused-variable
if rewriter_config_template is not None and not isinstance(
rewriter_config_template, rewriter_config_pb2.RewriterConfig):
raise TypeError(
@ -457,6 +455,14 @@ class TrtGraphConverter(GraphConverter):
nodes_blacklist=nodes_blacklist,
session_config=session_config)
# Lazily load the TF-TRT C bindings, so `import tensorflow` doesn't complain
# even if it cannot find TensorRT library.
trt_ops.load_trt_ops()
# pylint: disable=g-import-not-at-top,line-too-long
from tensorflow.python.compiler.tensorrt.wrap_conversion import get_linked_tensorrt_version
from tensorflow.python.compiler.tensorrt.wrap_conversion import get_loaded_tensorrt_version
# pylint: enable=g-import-not-at-top,line-too-long
# Check compatibility of TensorRT version.
compiled_version = get_linked_tensorrt_version()
loaded_version = get_loaded_tensorrt_version()
@ -642,6 +648,12 @@ def calib_graph_to_infer_graph(calibration_graph_def, is_dynamic_op=False):
Raises:
RuntimeError: if the returned status message is malformed.
"""
# Lazily load the TF-TRT C bindings, so `import tensorflow` doesn't complain
# even if it cannot find TensorRT library.
trt_ops.load_trt_ops()
# pylint: disable=g-import-not-at-top,line-too-long
from tensorflow.python.compiler.tensorrt.wrap_conversion import calib_convert
# pylint: enable=g-import-not-at-top,line-too-long
is_calib_graph = False
for n in calibration_graph_def.node:

View File

@ -20,9 +20,10 @@ from __future__ import print_function
import os
# pylint: disable=unused-import
from tensorflow.compiler.tf2tensorrt.python.ops import trt_ops
# pylint: enable=unused-import
from tensorflow.python.compiler.tensorrt.wrap_conversion import clear_test_values
from tensorflow.python.compiler.tensorrt.wrap_conversion import enable_test_value
from tensorflow.python.compiler.tensorrt.wrap_conversion import get_test_value
from tensorflow.python.compiler.tensorrt.wrap_conversion import is_tensorrt_enabled
from tensorflow.core.framework import graph_pb2
from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import rewriter_config_pb2
@ -53,6 +54,8 @@ class TrtConvertTest(test_util.TensorFlowTestCase):
def testGetTensorrtRewriterConfig(self):
"""Test case for TrtGraphConverter.get_tensorrt_rewriter_config()."""
if not is_tensorrt_enabled():
return
rewriter_cfg = trt_convert.TrtGraphConverter.get_tensorrt_rewriter_config(
rewriter_config_template=None,
max_batch_size=128,
@ -172,7 +175,7 @@ class TrtConvertTest(test_util.TensorFlowTestCase):
def testCreateInferenceGraph_BasicConversion(self):
"""Test case for trt_convert.create_inference_graph()."""
if not trt_convert.is_tensorrt_enabled():
if not is_tensorrt_enabled():
return
# Use GraphDef as input.
@ -187,25 +190,23 @@ class TrtConvertTest(test_util.TensorFlowTestCase):
output_saved_model_dir)
def _TestRun(self, sess, batch_size, expect_engine_is_run):
trt_convert.clear_test_values("")
clear_test_values("")
result = sess.run("output:0", feed_dict={"input:0": [[[1.0]]] * batch_size})
self.assertAllEqual([[[4.0]]] * batch_size, result)
execute_engine_test_value = ("done" if expect_engine_is_run else "")
execute_native_segment_test_value = ("" if expect_engine_is_run else "done")
self.assertEqual(
execute_engine_test_value,
trt_convert.get_test_value("TRTEngineOp_0:ExecuteTrtEngine"))
self.assertEqual(
execute_native_segment_test_value,
trt_convert.get_test_value("TRTEngineOp_0:ExecuteNativeSegment"))
self.assertEqual(execute_engine_test_value,
get_test_value("TRTEngineOp_0:ExecuteTrtEngine"))
self.assertEqual(execute_native_segment_test_value,
get_test_value("TRTEngineOp_0:ExecuteNativeSegment"))
def testCreateInferenceGraph_MinimumSegmentSize(self):
if not trt_convert.is_tensorrt_enabled():
if not is_tensorrt_enabled():
return
output_graph_def = trt_convert.create_inference_graph(
self._GetGraphDef(), ["output"],
minimum_segment_size=5,
max_workspace_size_bytes=TrtConvertTest._TRT_MAX_WORKSPACE_SIZE_BYTES,
minimum_segment_size=5,
is_dynamic_op=False)
node_name_to_op = {node.name: node.op for node in output_graph_def.node}
self.assertEqual({
@ -218,9 +219,9 @@ class TrtConvertTest(test_util.TensorFlowTestCase):
}, node_name_to_op)
def testCreateInferenceGraph_DynamicOp(self):
if not trt_convert.is_tensorrt_enabled():
if not is_tensorrt_enabled():
return
trt_convert.enable_test_value()
enable_test_value()
tmp_dir = self.get_temp_dir()
input_saved_model_dir = os.path.join(tmp_dir, "in_dir2")
@ -261,9 +262,9 @@ class TrtConvertTest(test_util.TensorFlowTestCase):
self._TestRun(sess, 3, True)
def testCreateInferenceGraph_StaticOp(self):
if not trt_convert.is_tensorrt_enabled():
if not is_tensorrt_enabled():
return
trt_convert.enable_test_value()
enable_test_value()
tmp_dir = self.get_temp_dir()
input_saved_model_dir = os.path.join(tmp_dir, "in_dir3")