diff --git a/tensorflow/compiler/tf2tensorrt/BUILD b/tensorflow/compiler/tf2tensorrt/BUILD index 00d3c8cc5f6..76a009e1d5c 100644 --- a/tensorflow/compiler/tf2tensorrt/BUILD +++ b/tensorflow/compiler/tf2tensorrt/BUILD @@ -160,6 +160,7 @@ tf_custom_op_py_library( ], srcs_version = "PY2AND3", deps = [ + "//tensorflow/python:errors", "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:platform", "//tensorflow/python:resources", diff --git a/tensorflow/compiler/tf2tensorrt/python/ops/trt_ops.py b/tensorflow/compiler/tf2tensorrt/python/ops/trt_ops.py index 86bfabf99e0..a11150f1551 100644 --- a/tensorflow/compiler/tf2tensorrt/python/ops/trt_ops.py +++ b/tensorflow/compiler/tf2tensorrt/python/ops/trt_ops.py @@ -18,17 +18,45 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import threading + import platform +from tensorflow.python.framework import errors -if platform.system() != "Windows": - # pylint: disable=wildcard-import,unused-import,g-import-not-at-top - from tensorflow.compiler.tf2tensorrt.ops.gen_trt_ops import * +_trt_ops_so = None +_module_lock = threading.Lock() - from tensorflow.python.framework import load_library - from tensorflow.python.platform import resource_loader - # pylint: enable=wildcard-import,unused-import,g-import-not-at-top - _trt_ops = load_library.load_op_library( - resource_loader.get_path_to_datafile("_trt_ops.so")) -else: - raise RuntimeError("Windows platforms are not supported") +def load_trt_ops(): + """Load the TF-TRT op library (.so) if it hasn't been loaded already.""" + global _trt_ops_so + + if platform.system() == "Windows": + raise RuntimeError("Windows platforms are not supported") + + with _module_lock: + if _trt_ops_so: + return + + # TODO(laigd): we should load TF-TRT kernels here as well after removing the + # swig binding. + try: + # TODO(laigd): consider getting rid of these unused imports. 
+ # pylint: disable=unused-import,g-import-not-at-top,unused-variable + from tensorflow.compiler.tf2tensorrt.ops.gen_trt_ops import get_serialized_resource_op + from tensorflow.compiler.tf2tensorrt.ops.gen_trt_ops import trt_engine_op + from tensorflow.python.framework import load_library + from tensorflow.python.platform import resource_loader + # pylint: enable=unused-import,g-import-not-at-top,unused-variable + + _trt_ops_so = load_library.load_op_library( + resource_loader.get_path_to_datafile("_trt_ops.so")) + except errors.NotFoundError as e: + no_trt_message = ( + "**** Failed to initialize TensorRT. This is either because the " + "TensorRT installation path is not in LD_LIBRARY_PATH, or because " + "you do not have it installed. If not installed, please go to " + "https://developer.nvidia.com/tensorrt to download and install " + "TensorRT ****") + print(no_trt_message) + raise e diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index 36ec9e01e2e..a6d68684989 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -173,6 +173,7 @@ py_library( "//tensorflow/lite/python:lite", "//tensorflow/python/compat", "//tensorflow/python/compat:v2_compat", + "//tensorflow/python/compiler", "//tensorflow/python/data", "//tensorflow/python/distribute", "//tensorflow/python/distribute:estimator_training", diff --git a/tensorflow/python/compiler/BUILD b/tensorflow/python/compiler/BUILD new file mode 100644 index 00000000000..26eb06a10e5 --- /dev/null +++ b/tensorflow/python/compiler/BUILD @@ -0,0 +1,17 @@ +# Description: +# Python APIs for various Tensorflow backends. 
+ +package(default_visibility = ["//visibility:public"]) + +licenses(["notice"]) # Apache 2.0 + +exports_files(["LICENSE"]) + +py_library( + name = "compiler", + srcs = ["__init__.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/python/compiler/tensorrt:init_py", + ], +) diff --git a/tensorflow/python/compiler/__init__.py b/tensorflow/python/compiler/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/tensorflow/python/compiler/tensorrt/BUILD b/tensorflow/python/compiler/tensorrt/BUILD index a382579f5a1..afa25113b24 100644 --- a/tensorflow/python/compiler/tensorrt/BUILD +++ b/tensorflow/python/compiler/tensorrt/BUILD @@ -32,8 +32,6 @@ py_library( deps = [ ":tf_trt_integration_test_base", ":trt_convert_py", - ":trt_ops_py", - "//tensorflow/python:errors", ], ) @@ -51,6 +49,7 @@ py_library( srcs = ["trt_convert.py"], srcs_version = "PY2AND3", deps = [ + ":trt_ops_py", ":wrap_conversion", "//tensorflow/python:graph_util", "//tensorflow/python:session", @@ -83,7 +82,6 @@ py_library( srcs = ["test/tf_trt_integration_test_base.py"], deps = [ ":trt_convert_py", - ":trt_ops_py", "//tensorflow/python:client_testlib", "//tensorflow/python:framework_test_lib", ], diff --git a/tensorflow/python/compiler/tensorrt/__init__.py b/tensorflow/python/compiler/tensorrt/__init__.py index 88fb69101d0..db3540ba45d 100644 --- a/tensorflow/python/compiler/tensorrt/__init__.py +++ b/tensorflow/python/compiler/tensorrt/__init__.py @@ -18,25 +18,6 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.framework import errors - -# pylint: disable=unused-import,g-import-not-at-top,line-too-long -try: - from tensorflow.compiler.tf2tensorrt.python.ops import trt_ops - from tensorflow.python.compiler.tensorrt.trt_convert import add_test_value - from tensorflow.python.compiler.tensorrt.trt_convert import calib_graph_to_infer_graph - from 
tensorflow.python.compiler.tensorrt.trt_convert import clear_test_values - from tensorflow.python.compiler.tensorrt.trt_convert import create_inference_graph - from tensorflow.python.compiler.tensorrt.trt_convert import enable_test_value - from tensorflow.python.compiler.tensorrt.trt_convert import get_test_value - from tensorflow.python.compiler.tensorrt.trt_convert import is_tensorrt_enabled -except errors.NotFoundError as e: - no_trt_message = ( - '**** Failed to initialize TensorRT. This is either because the TensorRT' - ' installation path is not in LD_LIBRARY_PATH, or because you do not have' - ' it installed. If not installed, please go to' - ' https://developer.nvidia.com/tensorrt to download and install' - ' TensorRT ****') - print(no_trt_message) - raise e -# pylint: enable=unused-import,g-import-not-at-top,line-too-long +# pylint: disable=unused-import,line-too-long +from tensorflow.python.compiler.tensorrt.trt_convert import create_inference_graph +# pylint: enable=unused-import,line-too-long diff --git a/tensorflow/python/compiler/tensorrt/test/base_test.py b/tensorflow/python/compiler/tensorrt/test/base_test.py index cc3109968e2..a1199c5040a 100644 --- a/tensorflow/python/compiler/tensorrt/test/base_test.py +++ b/tensorflow/python/compiler/tensorrt/test/base_test.py @@ -20,8 +20,9 @@ from __future__ import print_function import numpy as np -from tensorflow.python.compiler.tensorrt import trt_convert from tensorflow.python.compiler.tensorrt.test import tf_trt_integration_test_base as trt_test +from tensorflow.python.compiler.tensorrt.wrap_conversion import add_test_value +from tensorflow.python.compiler.tensorrt.wrap_conversion import clear_test_values from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops @@ -154,7 +155,7 @@ class PartiallyConvertedTestA(trt_test.TfTrtIntegrationTestBase): """Setup method.""" super(PartiallyConvertedTestA, self).setUp() # Let it 
fail to build the second engine. - trt_convert.add_test_value("TRTEngineOp_1:CreateTRTNode", "fail") + add_test_value("TRTEngineOp_1:CreateTRTNode", "fail") def GetParams(self): """Create a graph containing two segment.""" @@ -209,8 +210,8 @@ class PartiallyConvertedTestB(PartiallyConvertedTestA): """Setup method.""" super(PartiallyConvertedTestB, self).setUp() # Let it fail to build the first engine. - trt_convert.clear_test_values("") - trt_convert.add_test_value("TRTEngineOp_0:CreateTRTNode", "fail") + clear_test_values("") + add_test_value("TRTEngineOp_0:CreateTRTNode", "fail") def ExpectedEnginesToBuild(self, run_params): """Return the expected engines to build.""" diff --git a/tensorflow/python/compiler/tensorrt/test/quantization_mnist_test.py b/tensorflow/python/compiler/tensorrt/test/quantization_mnist_test.py index 1d7792c01a2..e2b21d7b92f 100644 --- a/tensorflow/python/compiler/tensorrt/test/quantization_mnist_test.py +++ b/tensorflow/python/compiler/tensorrt/test/quantization_mnist_test.py @@ -18,13 +18,12 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -# pylint: disable=unused-import -from tensorflow.compiler.tf2tensorrt.python.ops import trt_ops -# pylint: enable=unused-import from tensorflow.core.protobuf import config_pb2 from tensorflow.python import data from tensorflow.python import keras from tensorflow.python.compiler.tensorrt import trt_convert +from tensorflow.python.compiler.tensorrt.wrap_conversion import get_linked_tensorrt_version +from tensorflow.python.compiler.tensorrt.wrap_conversion import is_tensorrt_enabled from tensorflow.python.estimator.estimator import Estimator from tensorflow.python.estimator.model_fn import EstimatorSpec from tensorflow.python.estimator.model_fn import ModeKeys @@ -263,7 +262,7 @@ class QuantizationAwareTrainingMNISTTest(test_util.TensorFlowTestCase): # num_epochs=100, # model_dir=model_dir) def testEval(self): - if not 
trt_convert.is_tensorrt_enabled(): + if not is_tensorrt_enabled(): return model_dir = test.test_src_dir_path('python/compiler/tensorrt/test/testdata') @@ -276,7 +275,7 @@ class QuantizationAwareTrainingMNISTTest(test_util.TensorFlowTestCase): logging.info('accuracy_tf_native: %f', accuracy_tf_native) self.assertAllClose(0.9662, accuracy_tf_native, rtol=1e-3, atol=1e-3) - if trt_convert.get_linked_tensorrt_version()[0] < 5: + if get_linked_tensorrt_version()[0] < 5: return accuracy_tf_trt = self._Run( diff --git a/tensorflow/python/compiler/tensorrt/test/quantization_test.py b/tensorflow/python/compiler/tensorrt/test/quantization_test.py index 086e070f1b1..3e1c9ff8ddc 100644 --- a/tensorflow/python/compiler/tensorrt/test/quantization_test.py +++ b/tensorflow/python/compiler/tensorrt/test/quantization_test.py @@ -20,8 +20,8 @@ from __future__ import print_function import numpy as np -from tensorflow.python.compiler.tensorrt import trt_convert from tensorflow.python.compiler.tensorrt.test import tf_trt_integration_test_base as trt_test +from tensorflow.python.compiler.tensorrt.wrap_conversion import get_linked_tensorrt_version from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops @@ -72,7 +72,7 @@ class QuantizationMissingAllRangesTest(trt_test.TfTrtIntegrationTestBase): return _GetParams(add_quantization_nodes=False) def ShouldRunTest(self, run_params): - if trt_convert.get_linked_tensorrt_version()[0] < 5: + if get_linked_tensorrt_version()[0] < 5: return False # Only test static engine mode, with or without calibration. 
return (trt_test.IsQuantizationMode(run_params.precision_mode) and @@ -96,7 +96,7 @@ class QuantizationWithRangesTest(trt_test.TfTrtIntegrationTestBase): return _GetParams(add_quantization_nodes=True) def ShouldRunTest(self, run_params): - if trt_convert.get_linked_tensorrt_version()[0] < 5: + if get_linked_tensorrt_version()[0] < 5: return False # Test static/dynamic engine with/without calibration. return (trt_test.IsQuantizationMode(run_params.precision_mode) and diff --git a/tensorflow/python/compiler/tensorrt/test/tf_trt_integration_test_base.py b/tensorflow/python/compiler/tensorrt/test/tf_trt_integration_test_base.py index 28563f09a15..5af1a970169 100644 --- a/tensorflow/python/compiler/tensorrt/test/tf_trt_integration_test_base.py +++ b/tensorflow/python/compiler/tensorrt/test/tf_trt_integration_test_base.py @@ -25,12 +25,13 @@ import warnings import numpy as np import six -# pylint: disable=unused-import -from tensorflow.compiler.tf2tensorrt.python.ops import trt_ops -# pylint: enable=unused-import from tensorflow.core.protobuf import config_pb2 from tensorflow.core.protobuf import rewriter_config_pb2 from tensorflow.python.compiler.tensorrt import trt_convert +from tensorflow.python.compiler.tensorrt.wrap_conversion import clear_test_values +from tensorflow.python.compiler.tensorrt.wrap_conversion import enable_test_value +from tensorflow.python.compiler.tensorrt.wrap_conversion import get_test_value +from tensorflow.python.compiler.tensorrt.wrap_conversion import is_tensorrt_enabled from tensorflow.python.framework import dtypes from tensorflow.python.framework import graph_io from tensorflow.python.framework import importer @@ -151,7 +152,7 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase): def setUpClass(cls): """Setup method for the module.""" super(TfTrtIntegrationTestBase, cls).setUpClass() - trt_convert.enable_test_value() + enable_test_value() def __init__(self, methodName="runTest"): # pylint: disable=invalid-name 
super(TfTrtIntegrationTestBase, self).__init__(methodName) @@ -161,7 +162,7 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase): """Setup method.""" super(TfTrtIntegrationTestBase, self).setUp() warnings.simplefilter("always") - trt_convert.clear_test_values("") + clear_test_values("") def GetParams(self): """Return a TfTrtIntegrationTestParams for test, implemented by subclass.""" @@ -246,9 +247,9 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase): def _PrepareRun(self, graph_state): """Set up necessary testing environment before calling sess.run().""" # Clear test values added by TRTEngineOp. - trt_convert.clear_test_values("TRTEngineOp_.*:ExecuteTrtEngine") - trt_convert.clear_test_values("TRTEngineOp_.*:ExecuteCalibration") - trt_convert.clear_test_values("TRTEngineOp_.*:ExecuteNativeSegment") + clear_test_values("TRTEngineOp_.*:ExecuteTrtEngine") + clear_test_values("TRTEngineOp_.*:ExecuteCalibration") + clear_test_values("TRTEngineOp_.*:ExecuteNativeSegment") def _GetGPUOptions(self): gpu_options = config_pb2.GPUOptions() @@ -282,7 +283,7 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase): def _ExpectTestValue(self, engine_name, method, expected_value): label = "%s:%s" % (engine_name, method) - actual_value = trt_convert.get_test_value(label) + actual_value = get_test_value(label) self.assertEqual( expected_value, actual_value, @@ -639,5 +640,5 @@ def _AddTests(test_class): setattr(test_class, "testTfTrt_" + test_name, _GetTest(run_params)) -if trt_convert.is_tensorrt_enabled(): +if is_tensorrt_enabled(): _AddTests(TfTrtIntegrationTestBase) diff --git a/tensorflow/python/compiler/tensorrt/trt_convert.py b/tensorflow/python/compiler/tensorrt/trt_convert.py index 33b5e50418f..2dd34a83aab 100644 --- a/tensorflow/python/compiler/tensorrt/trt_convert.py +++ b/tensorflow/python/compiler/tensorrt/trt_convert.py @@ -19,17 +19,7 @@ from __future__ import division from __future__ import print_function import six as _six -# pylint: 
disable=unused-import,line-too-long from tensorflow.compiler.tf2tensorrt.python.ops import trt_ops -from tensorflow.python.compiler.tensorrt.wrap_conversion import add_test_value -from tensorflow.python.compiler.tensorrt.wrap_conversion import calib_convert -from tensorflow.python.compiler.tensorrt.wrap_conversion import clear_test_values -from tensorflow.python.compiler.tensorrt.wrap_conversion import enable_test_value -from tensorflow.python.compiler.tensorrt.wrap_conversion import get_linked_tensorrt_version -from tensorflow.python.compiler.tensorrt.wrap_conversion import get_loaded_tensorrt_version -from tensorflow.python.compiler.tensorrt.wrap_conversion import get_test_value -from tensorflow.python.compiler.tensorrt.wrap_conversion import is_tensorrt_enabled -# pylint: enable=unused-import,line-too-long from tensorflow.core.framework import graph_pb2 from tensorflow.core.protobuf import config_pb2 from tensorflow.core.protobuf import meta_graph_pb2 @@ -357,6 +347,14 @@ class TrtGraphConverter(GraphConverter): TypeError: if any of the parameters are of unexpected type. ValueError: if any of the parameters are of unexpected value. """ + # Lazily load the TF-TRT C bindings, so `import tensorflow` doesn't complain + # even if it cannot find the TensorRT library. + trt_ops.load_trt_ops() + # pylint: disable=g-import-not-at-top,unused-import,line-too-long,unused-variable + # Import an arbitrary symbol to trigger loading of the TRT library. 
+ from tensorflow.python.compiler.tensorrt.wrap_conversion import calib_convert + # pylint: enable=g-import-not-at-top,unused-import,line-too-long,unused-variable + if rewriter_config_template is not None and not isinstance( rewriter_config_template, rewriter_config_pb2.RewriterConfig): raise TypeError( @@ -457,6 +455,14 @@ class TrtGraphConverter(GraphConverter): nodes_blacklist=nodes_blacklist, session_config=session_config) + # Lazily load the TF-TRT C bindings, so `import tensorflow` doesn't complain + # even if it cannot find TensorRT library. + trt_ops.load_trt_ops() + # pylint: disable=g-import-not-at-top,line-too-long + from tensorflow.python.compiler.tensorrt.wrap_conversion import get_linked_tensorrt_version + from tensorflow.python.compiler.tensorrt.wrap_conversion import get_loaded_tensorrt_version + # pylint: enable=g-import-not-at-top,line-too-long + # Check compatibility of TensorRT version. compiled_version = get_linked_tensorrt_version() loaded_version = get_loaded_tensorrt_version() @@ -642,6 +648,12 @@ def calib_graph_to_infer_graph(calibration_graph_def, is_dynamic_op=False): Raises: RuntimeError: if the returned status message is malformed. """ + # Lazily load the TF-TRT C bindings, so `import tensorflow` doesn't complain + # even if it cannot find TensorRT library. 
+ trt_ops.load_trt_ops() + # pylint: disable=g-import-not-at-top,line-too-long + from tensorflow.python.compiler.tensorrt.wrap_conversion import calib_convert + # pylint: enable=g-import-not-at-top,line-too-long is_calib_graph = False for n in calibration_graph_def.node: diff --git a/tensorflow/python/compiler/tensorrt/trt_convert_test.py b/tensorflow/python/compiler/tensorrt/trt_convert_test.py index 0dbc5c19708..023ac5c36e7 100644 --- a/tensorflow/python/compiler/tensorrt/trt_convert_test.py +++ b/tensorflow/python/compiler/tensorrt/trt_convert_test.py @@ -20,9 +20,10 @@ from __future__ import print_function import os -# pylint: disable=unused-import -from tensorflow.compiler.tf2tensorrt.python.ops import trt_ops -# pylint: enable=unused-import +from tensorflow.python.compiler.tensorrt.wrap_conversion import clear_test_values +from tensorflow.python.compiler.tensorrt.wrap_conversion import enable_test_value +from tensorflow.python.compiler.tensorrt.wrap_conversion import get_test_value +from tensorflow.python.compiler.tensorrt.wrap_conversion import is_tensorrt_enabled from tensorflow.core.framework import graph_pb2 from tensorflow.core.protobuf import config_pb2 from tensorflow.core.protobuf import rewriter_config_pb2 @@ -53,6 +54,8 @@ class TrtConvertTest(test_util.TensorFlowTestCase): def testGetTensorrtRewriterConfig(self): """Test case for TrtGraphConverter.get_tensorrt_rewriter_config().""" + if not is_tensorrt_enabled(): + return rewriter_cfg = trt_convert.TrtGraphConverter.get_tensorrt_rewriter_config( rewriter_config_template=None, max_batch_size=128, @@ -172,7 +175,7 @@ class TrtConvertTest(test_util.TensorFlowTestCase): def testCreateInferenceGraph_BasicConversion(self): """Test case for trt_convert.create_inference_graph().""" - if not trt_convert.is_tensorrt_enabled(): + if not is_tensorrt_enabled(): return # Use GraphDef as input. 
@@ -187,25 +190,23 @@ class TrtConvertTest(test_util.TensorFlowTestCase): output_saved_model_dir) def _TestRun(self, sess, batch_size, expect_engine_is_run): - trt_convert.clear_test_values("") + clear_test_values("") result = sess.run("output:0", feed_dict={"input:0": [[[1.0]]] * batch_size}) self.assertAllEqual([[[4.0]]] * batch_size, result) execute_engine_test_value = ("done" if expect_engine_is_run else "") execute_native_segment_test_value = ("" if expect_engine_is_run else "done") - self.assertEqual( - execute_engine_test_value, - trt_convert.get_test_value("TRTEngineOp_0:ExecuteTrtEngine")) - self.assertEqual( - execute_native_segment_test_value, - trt_convert.get_test_value("TRTEngineOp_0:ExecuteNativeSegment")) + self.assertEqual(execute_engine_test_value, + get_test_value("TRTEngineOp_0:ExecuteTrtEngine")) + self.assertEqual(execute_native_segment_test_value, + get_test_value("TRTEngineOp_0:ExecuteNativeSegment")) def testCreateInferenceGraph_MinimumSegmentSize(self): - if not trt_convert.is_tensorrt_enabled(): + if not is_tensorrt_enabled(): return output_graph_def = trt_convert.create_inference_graph( self._GetGraphDef(), ["output"], - minimum_segment_size=5, max_workspace_size_bytes=TrtConvertTest._TRT_MAX_WORKSPACE_SIZE_BYTES, + minimum_segment_size=5, is_dynamic_op=False) node_name_to_op = {node.name: node.op for node in output_graph_def.node} self.assertEqual({ @@ -218,9 +219,9 @@ class TrtConvertTest(test_util.TensorFlowTestCase): }, node_name_to_op) def testCreateInferenceGraph_DynamicOp(self): - if not trt_convert.is_tensorrt_enabled(): + if not is_tensorrt_enabled(): return - trt_convert.enable_test_value() + enable_test_value() tmp_dir = self.get_temp_dir() input_saved_model_dir = os.path.join(tmp_dir, "in_dir2") @@ -261,9 +262,9 @@ class TrtConvertTest(test_util.TensorFlowTestCase): self._TestRun(sess, 3, True) def testCreateInferenceGraph_StaticOp(self): - if not trt_convert.is_tensorrt_enabled(): + if not is_tensorrt_enabled(): return - 
trt_convert.enable_test_value() + enable_test_value() tmp_dir = self.get_temp_dir() input_saved_model_dir = os.path.join(tmp_dir, "in_dir3")