diff --git a/tensorflow/compiler/tf2tensorrt/BUILD b/tensorflow/compiler/tf2tensorrt/BUILD index 38958d4ce69..12a51f7d32d 100644 --- a/tensorflow/compiler/tf2tensorrt/BUILD +++ b/tensorflow/compiler/tf2tensorrt/BUILD @@ -76,14 +76,11 @@ cc_library( tf_cc_shared_object( name = "python/ops/libtftrt.so", - srcs = [ - "ops/get_serialized_resource_op.cc", - "ops/trt_engine_op.cc", - ], copts = tf_copts(is_external = True), linkopts = ["-lm"], deps = [ ":trt_op_kernels", + ":trt_op_libs", "//tensorflow/core:lib_proto_parsing", ] + if_tensorrt([ "@local_config_tensorrt//:tensorrt", @@ -163,6 +160,7 @@ tf_custom_op_py_library( ], srcs_version = "PY2AND3", deps = [ + ":trt_ops", "//tensorflow/python:errors", "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:platform", diff --git a/tensorflow/compiler/tf2tensorrt/python/ops/trt_ops.py b/tensorflow/compiler/tf2tensorrt/python/ops/trt_ops.py index 62ac5a581dc..92aae7bb6b4 100644 --- a/tensorflow/compiler/tf2tensorrt/python/ops/trt_ops.py +++ b/tensorflow/compiler/tf2tensorrt/python/ops/trt_ops.py @@ -40,7 +40,9 @@ def load_trt_ops(): try: # pylint: disable=g-import-not-at-top,unused-variable - # This registers the TRT ops, it doesn't require loading TRT library. + # This will call register_op_list() in + # tensorflow/python/framework/op_def_registry.py, but it doesn't register + # the op or the op kernel in C++ runtime. from tensorflow.compiler.tf2tensorrt.ops.gen_trt_ops import trt_engine_op # pylint: enable=g-import-not-at-top,unused-variable except ImportError as e: @@ -48,14 +50,14 @@ def load_trt_ops(): "not built with CUDA or TensorRT enabled. ****") raise e - # TODO(laigd): we should load TF-TRT kernels here as well after removing the - # swig binding. try: # pylint: disable=g-import-not-at-top from tensorflow.python.framework import load_library from tensorflow.python.platform import resource_loader # pylint: enable=g-import-not-at-top + # Loading the shared object will cause registration of the op and the op + # kernel if we link TF-TRT dynamically. _tf_trt_so = load_library.load_op_library( resource_loader.get_path_to_datafile("libtftrt.so")) except errors.NotFoundError as e: diff --git a/tensorflow/python/compiler/tensorrt/BUILD b/tensorflow/python/compiler/tensorrt/BUILD index 5cb95446801..eb269b4db41 100644 --- a/tensorflow/python/compiler/tensorrt/BUILD +++ b/tensorflow/python/compiler/tensorrt/BUILD @@ -35,22 +35,13 @@ py_library( ], ) -py_library( - name = "trt_ops_py", - srcs_version = "PY2AND3", - deps = [ - "//tensorflow/compiler/tf2tensorrt:trt_ops", - "//tensorflow/compiler/tf2tensorrt:trt_ops_loader", - ], -) - py_library( name = "trt_convert_py", srcs = ["trt_convert.py"], srcs_version = "PY2AND3", deps = [ - ":trt_ops_py", ":wrap_conversion", + "//tensorflow/compiler/tf2tensorrt:trt_ops_loader", "//tensorflow/python:convert_to_constants", "//tensorflow/python:func_graph", "//tensorflow/python:graph_util", @@ -92,7 +83,6 @@ cuda_py_test( srcs = ["trt_convert_test.py"], additional_deps = [ ":trt_convert_py", - ":trt_ops_py", "//tensorflow/python:client_testlib", "//tensorflow/python:framework_test_lib", "//tensorflow/python:graph_util",