Introducing the TensorRT operator to TF, which can run (sub)graphs in
highly optimized TensorRT engines. This commit is a merged version of many commits by benbarsdell <bbarsdell at nvidia.com>, deadeyegoodwin <davidg at nvidia.com>, jjsjann123 <jiej at nvidia.com> and samikama <skama at nvidia.com>.
This commit is contained in:
parent
e810b107d8
commit
825e7a32e9
126
configure.py
126
configure.py
@ -37,12 +37,14 @@ _TF_BAZELRC = os.path.join(os.path.dirname(os.path.abspath(__file__)),
|
|||||||
_TF_WORKSPACE = os.path.join(os.path.dirname(os.path.abspath(__file__)),
|
_TF_WORKSPACE = os.path.join(os.path.dirname(os.path.abspath(__file__)),
|
||||||
'WORKSPACE')
|
'WORKSPACE')
|
||||||
_DEFAULT_CUDA_VERSION = '9.0'
|
_DEFAULT_CUDA_VERSION = '9.0'
|
||||||
|
_DEFAULT_TENSORRT_VERSION = '4'
|
||||||
_DEFAULT_CUDNN_VERSION = '7'
|
_DEFAULT_CUDNN_VERSION = '7'
|
||||||
_DEFAULT_CUDA_COMPUTE_CAPABILITIES = '3.5,5.2'
|
_DEFAULT_CUDA_COMPUTE_CAPABILITIES = '3.5,5.2'
|
||||||
_DEFAULT_CUDA_PATH = '/usr/local/cuda'
|
_DEFAULT_CUDA_PATH = '/usr/local/cuda'
|
||||||
_DEFAULT_CUDA_PATH_LINUX = '/opt/cuda'
|
_DEFAULT_CUDA_PATH_LINUX = '/opt/cuda'
|
||||||
_DEFAULT_CUDA_PATH_WIN = ('C:/Program Files/NVIDIA GPU Computing '
|
_DEFAULT_CUDA_PATH_WIN = ('C:/Program Files/NVIDIA GPU Computing '
|
||||||
'Toolkit/CUDA/v%s' % _DEFAULT_CUDA_VERSION)
|
'Toolkit/CUDA/v%s' % _DEFAULT_CUDA_VERSION)
|
||||||
|
_DEFAULT_TENSORRT_PATH_LINUX = '/usr/lib/x86_64-linux-gnu'
|
||||||
_TF_OPENCL_VERSION = '1.2'
|
_TF_OPENCL_VERSION = '1.2'
|
||||||
_DEFAULT_COMPUTECPP_TOOLKIT_PATH = '/usr/local/computecpp'
|
_DEFAULT_COMPUTECPP_TOOLKIT_PATH = '/usr/local/computecpp'
|
||||||
_DEFAULT_TRISYCL_INCLUDE_DIR = '/usr/local/triSYCL/include'
|
_DEFAULT_TRISYCL_INCLUDE_DIR = '/usr/local/triSYCL/include'
|
||||||
@ -382,13 +384,12 @@ def set_build_var(environ_cp, var_name, query_item, option_name,
|
|||||||
|
|
||||||
var = str(int(get_var(environ_cp, var_name, query_item, enabled_by_default)))
|
var = str(int(get_var(environ_cp, var_name, query_item, enabled_by_default)))
|
||||||
environ_cp[var_name] = var
|
environ_cp[var_name] = var
|
||||||
if var == '1':
|
# TODO(mikecase): Migrate all users of configure.py to use --config Bazel
|
||||||
write_to_bazelrc('build --define %s=true' % option_name)
|
# options and not to set build configs through environment variables.
|
||||||
elif bazel_config_name is not None:
|
if var=='1':
|
||||||
# TODO(mikecase): Migrate all users of configure.py to use --config Bazel
|
setting='true'
|
||||||
# options and not to set build configs through environment variables.
|
confname=":%s"%(bazel_config_name) if bazel_config_name is not None else ""
|
||||||
write_to_bazelrc('build:%s --define %s=true'
|
write_to_bazelrc('build%s --define %s=%s' % (confname,option_name,setting))
|
||||||
% (bazel_config_name, option_name))
|
|
||||||
|
|
||||||
|
|
||||||
def set_action_env_var(environ_cp,
|
def set_action_env_var(environ_cp,
|
||||||
@ -438,13 +439,12 @@ def convert_version_to_int(version):
|
|||||||
for seg in version_segments:
|
for seg in version_segments:
|
||||||
if not seg.isdigit():
|
if not seg.isdigit():
|
||||||
return None
|
return None
|
||||||
|
|
||||||
version_str = ''.join(['%03d' % int(seg) for seg in version_segments])
|
version_str = ''.join(['%03d' % int(seg) for seg in version_segments])
|
||||||
return int(version_str)
|
return int(version_str)
|
||||||
|
|
||||||
|
|
||||||
def check_bazel_version(min_version):
|
def check_bazel_version(min_version):
|
||||||
"""Check installed bezel version is at least min_version.
|
"""Check installed bazel version is at least min_version.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
min_version: string for minimum bazel version.
|
min_version: string for minimum bazel version.
|
||||||
@ -1056,6 +1056,108 @@ def set_other_cuda_vars(environ_cp):
|
|||||||
write_to_bazelrc('test --config=cuda')
|
write_to_bazelrc('test --config=cuda')
|
||||||
|
|
||||||
|
|
||||||
|
def set_tf_trt_version(environ_cp):
  """Set TENSORRT_INSTALL_PATH and TF_TENSORRT_VERSION.

  Asks for (or reads from `environ_cp`) the TensorRT version and install
  path, then looks for libnvinfer.so* in that path and its lib64
  subdirectory. On Linux, a candidate library is accepted only when the CUDA
  runtime and cuDNN versions it links against (as reported by `ldd`) equal
  TF_CUDA_VERSION and TF_CUDNN_VERSION; the highest accepted libnvinfer
  version wins. If none is accepted, the path reported by `ldconfig -p` is
  tried. On failure the version is reset and the user is asked again.

  Args:
    environ_cp: copy of the os.environ; must already contain TF_CUDA_VERSION
      and TF_CUDNN_VERSION when running on Linux.
  """
  ask_trt_version = (
      'Please specify the TensorRT (libnvinfer) version you want to use. '
      '[Leave empty to default to libnvinfer %s]: ') % _DEFAULT_TENSORRT_VERSION

  # Patterns are loop-invariant, so build them once. Raw strings with escaped
  # dots make them match the literal library file names only.
  cudnn_pattern = re.compile(r'.*libcudnn\.so\.?(.*) =>.*$')
  cudart_pattern = re.compile(r'.*libcudart\.so\.?(.*) =>.*$')
  nvinfer_pattern = re.compile(r'.*libnvinfer\.so\.?(.*)$')

  def find_libs(search_path):
    """Return the real paths of all libnvinfer.so* files in search_path."""
    if not os.path.isdir(search_path):
      return set()
    return set(
        os.path.realpath(os.path.join(search_path, entry))
        for entry in os.listdir(search_path) if 'libnvinfer.so' in entry)

  def is_compatible(lib, cudaver, cudnnver):
    """Check that lib links against the given CUDA runtime/cuDNN versions."""
    ldd_bin = which('ldd') or '/usr/bin/ldd'
    # Start from None so that a library which does not link against cuDNN or
    # the CUDA runtime is reported as incompatible instead of crashing.
    cudnn = None
    cudart = None
    for line in run_shell([ldd_bin, lib]).split(os.linesep):
      if 'libcudnn.so' in line:
        cudnn = cudnn_pattern.search(line)
      elif 'libcudart.so' in line:
        cudart = cudart_pattern.search(line)
    if cudnn:
      cudnn = convert_version_to_int(cudnn.group(1)) if cudnn.group(1) else 0
    if cudart:
      cudart = (
          convert_version_to_int(cudart.group(1)) if cudart.group(1) else 0)
    return cudnn == cudnnver and cudart == cudaver

  while True:
    tf_trt_version = get_from_env_or_user_or_default(
        environ_cp, 'TF_TENSORRT_VERSION', ask_trt_version,
        _DEFAULT_TENSORRT_VERSION)
    # if library version is passed and known
    default_trt_path = environ_cp.get('TENSORRT_INSTALL_PATH',
                                      _DEFAULT_TENSORRT_PATH_LINUX)
    ask_trt_path = (r'Please specify the location where libnvinfer %s library is '
                    'installed. Refer to README.md for more details. [Default'
                    ' is %s]:') % (tf_trt_version, default_trt_path)
    trt_install_path = get_from_env_or_user_or_default(
        environ_cp, 'TENSORRT_INSTALL_PATH', ask_trt_path, default_trt_path)

    # Result returned from "read" will be used unexpanded. That make "~"
    # unusable. Going through one more level of expansion to handle that.
    trt_install_path = os.path.realpath(os.path.expanduser(trt_install_path))

    # All libnvinfer.so* candidates in the install path and its lib64 subdir.
    possible_files = find_libs(trt_install_path)
    possible_files.update(find_libs(os.path.join(trt_install_path, 'lib64')))

    libnvinfer_path_from_ldconfig = None
    if is_linux():
      cudaver = convert_version_to_int(environ_cp['TF_CUDA_VERSION'])
      cudnnver = convert_version_to_int(environ_cp['TF_CUDNN_VERSION'])
      # [version as int, version string, library path] of the best candidate.
      highest_ver = [0, None, None]

      for lib in possible_files:
        if not is_compatible(lib, cudaver, cudnnver):
          continue
        vstr = nvinfer_pattern.search(lib).group(1)
        currver = convert_version_to_int(vstr) if vstr else 0
        # convert_version_to_int returns None for non-numeric suffixes; skip
        # those instead of comparing None with an int.
        if currver is not None and currver > highest_ver[0]:
          highest_ver = [currver, vstr, lib]
      if highest_ver[1] is not None:
        trt_install_path = os.path.dirname(highest_ver[2])
        tf_trt_version = highest_ver[1]
        break

      # Fall back to whatever the dynamic linker cache knows about.
      ldconfig_bin = which('ldconfig') or '/sbin/ldconfig'
      libnvinfer_path_from_ldconfig = run_shell([ldconfig_bin, '-p'])
      libnvinfer_path_from_ldconfig = re.search('.*libnvinfer.so.* => (.*)',
                                                libnvinfer_path_from_ldconfig)
      if libnvinfer_path_from_ldconfig:
        libnvinfer_path_from_ldconfig = libnvinfer_path_from_ldconfig.group(1)
        if os.path.exists('%s.%s' % (libnvinfer_path_from_ldconfig,
                                     tf_trt_version)):
          trt_install_path = os.path.dirname(libnvinfer_path_from_ldconfig)
          break

    # Reset and Retry
    if possible_files:
      print(
          'Invalid path to TensorRT %s. libnvinfer.so* files found are for incompatible cuda versions '
          % tf_trt_version)
      print(trt_install_path)
      print(os.path.join(trt_install_path, 'lib64'))
    else:
      print(
          'Invalid path to TensorRT %s. No libnvinfer.so* files found in:'
          % tf_trt_version)
      print(trt_install_path)
      print(os.path.join(trt_install_path, 'lib64'))
      # Only mention the ldconfig location when ldconfig actually reported one.
      if is_linux() and libnvinfer_path_from_ldconfig:
        print('%s.%s' % (libnvinfer_path_from_ldconfig, tf_trt_version))

    # NOTE(review): only the version is cleared here; TENSORRT_INSTALL_PATH is
    # left untouched in environ_cp. Confirm that a bad path coming from the
    # environment cannot make this loop spin forever.
    environ_cp['TF_TENSORRT_VERSION'] = ''

  # Set TENSORRT_INSTALL_PATH and TF_TENSORRT_VERSION
  environ_cp['TENSORRT_INSTALL_PATH'] = trt_install_path
  write_action_env_to_bazelrc('TENSORRT_INSTALL_PATH', trt_install_path)
  environ_cp['TF_TENSORRT_VERSION'] = tf_trt_version
  write_action_env_to_bazelrc('TF_TENSORRT_VERSION', tf_trt_version)
  write_to_bazelrc('build:tensorrt --define using_tensorrt=true')
|
||||||
|
|
||||||
def set_host_cxx_compiler(environ_cp):
|
def set_host_cxx_compiler(environ_cp):
|
||||||
"""Set HOST_CXX_COMPILER."""
|
"""Set HOST_CXX_COMPILER."""
|
||||||
default_cxx_host_compiler = which('g++') or ''
|
default_cxx_host_compiler = which('g++') or ''
|
||||||
@ -1244,9 +1346,11 @@ def main():
|
|||||||
environ_cp['TF_NEED_COMPUTECPP'] = '0'
|
environ_cp['TF_NEED_COMPUTECPP'] = '0'
|
||||||
environ_cp['TF_NEED_OPENCL'] = '0'
|
environ_cp['TF_NEED_OPENCL'] = '0'
|
||||||
environ_cp['TF_CUDA_CLANG'] = '0'
|
environ_cp['TF_CUDA_CLANG'] = '0'
|
||||||
|
environ_cp['TF_NEED_TENSORRT'] = '0'
|
||||||
|
|
||||||
if is_macos():
|
if is_macos():
|
||||||
environ_cp['TF_NEED_JEMALLOC'] = '0'
|
environ_cp['TF_NEED_JEMALLOC'] = '0'
|
||||||
|
environ_cp['TF_NEED_TENSORRT'] = '0'
|
||||||
|
|
||||||
set_build_var(environ_cp, 'TF_NEED_JEMALLOC', 'jemalloc as malloc',
|
set_build_var(environ_cp, 'TF_NEED_JEMALLOC', 'jemalloc as malloc',
|
||||||
'with_jemalloc', True)
|
'with_jemalloc', True)
|
||||||
@ -1301,6 +1405,10 @@ def main():
|
|||||||
if not is_windows():
|
if not is_windows():
|
||||||
set_gcc_host_compiler_path(environ_cp)
|
set_gcc_host_compiler_path(environ_cp)
|
||||||
set_other_cuda_vars(environ_cp)
|
set_other_cuda_vars(environ_cp)
|
||||||
|
# enable tensorrt if desired. Disabled on non-linux
|
||||||
|
set_action_env_var(environ_cp, 'TF_NEED_TENSORRT', 'TensorRT', False)
|
||||||
|
if environ_cp.get('TF_NEED_TENSORRT') == '1':
|
||||||
|
set_tf_trt_version(environ_cp)
|
||||||
|
|
||||||
set_build_var(environ_cp, 'TF_NEED_MPI', 'MPI', 'with_mpi_support', False)
|
set_build_var(environ_cp, 'TF_NEED_MPI', 'MPI', 'with_mpi_support', False)
|
||||||
if environ_cp.get('TF_NEED_MPI') == '1':
|
if environ_cp.get('TF_NEED_MPI') == '1':
|
||||||
|
@ -358,6 +358,14 @@ config_setting(
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
config_setting(
|
||||||
|
name = "using_tensorrt",
|
||||||
|
define_values = {
|
||||||
|
"using_tensorrt":"true",
|
||||||
|
},
|
||||||
|
visibility = ["//visibility:public"],
|
||||||
|
)
|
||||||
|
|
||||||
config_setting(
|
config_setting(
|
||||||
name = "with_mpi_support",
|
name = "with_mpi_support",
|
||||||
values = {"define": "with_mpi_support=true"},
|
values = {"define": "with_mpi_support=true"},
|
||||||
|
@ -7,6 +7,7 @@ package(default_visibility = ["//tensorflow:__subpackages__"])
|
|||||||
|
|
||||||
load("//third_party/mpi:mpi.bzl", "if_mpi")
|
load("//third_party/mpi:mpi.bzl", "if_mpi")
|
||||||
load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")
|
load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")
|
||||||
|
load("@local_config_tensorrt//:build_defs.bzl", "if_trt")
|
||||||
|
|
||||||
py_library(
|
py_library(
|
||||||
name = "contrib_py",
|
name = "contrib_py",
|
||||||
@ -104,7 +105,9 @@ py_library(
|
|||||||
"//tensorflow/contrib/training:training_py",
|
"//tensorflow/contrib/training:training_py",
|
||||||
"//tensorflow/contrib/util:util_py",
|
"//tensorflow/contrib/util:util_py",
|
||||||
"//tensorflow/python:util",
|
"//tensorflow/python:util",
|
||||||
] + if_mpi(["//tensorflow/contrib/mpi_collectives:mpi_collectives_py"]),
|
] + if_mpi(["//tensorflow/contrib/mpi_collectives:mpi_ops_py"])
|
||||||
|
+ if_trt(["//tensorflow/contrib/tensorrt:init_py"]),
|
||||||
|
|
||||||
)
|
)
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
|
266
tensorflow/contrib/tensorrt/BUILD
Normal file
266
tensorflow/contrib/tensorrt/BUILD
Normal file
@ -0,0 +1,266 @@
|
|||||||
|
# -*- python -*-
|
||||||
|
# Description:
|
||||||
|
# provide tensorrt operators and converter package
|
||||||
|
|
||||||
|
package(default_visibility = ["//tensorflow:__subpackages__"])
|
||||||
|
|
||||||
|
licenses(["notice"]) # Apache 2.0
|
||||||
|
|
||||||
|
load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")
|
||||||
|
load(
|
||||||
|
"//tensorflow:tensorflow.bzl",
|
||||||
|
"tf_custom_op_library",
|
||||||
|
"tf_gen_op_libs",
|
||||||
|
"tf_gen_op_wrapper_py",
|
||||||
|
"tf_py_wrap_cc",
|
||||||
|
"tf_cc_test",
|
||||||
|
"tf_kernel_library",
|
||||||
|
"tf_custom_op_py_library",
|
||||||
|
"tf_copts",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
tf_custom_op_library(
|
||||||
|
name = "python/ops/_trt_engine_op.so",
|
||||||
|
srcs = [
|
||||||
|
"kernels/trt_engine_op.cc",
|
||||||
|
"ops/trt_engine_op.cc",
|
||||||
|
"kernels/trt_engine_op.h",
|
||||||
|
],
|
||||||
|
gpu_srcs = [],
|
||||||
|
deps = [
|
||||||
|
"@local_config_tensorrt//:tensorrt",
|
||||||
|
":trt_shape_function",
|
||||||
|
"//tensorflow/core:lib_proto_parsing",
|
||||||
|
"//tensorflow/core/kernels:bounds_check_lib",
|
||||||
|
"//tensorflow/core/kernels:ops_util_hdrs",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "trt_shape_function",
|
||||||
|
srcs=[
|
||||||
|
"shape_fn/trt_shfn.cc",
|
||||||
|
],
|
||||||
|
hdrs=["shape_fn/trt_shfn.h"],
|
||||||
|
copts=tf_copts(),
|
||||||
|
deps=[
|
||||||
|
":trt_logging",
|
||||||
|
"//third_party/eigen3",
|
||||||
|
"@local_config_tensorrt//:tensorrt",
|
||||||
|
"@protobuf_archive//:protobuf",
|
||||||
|
"@nsync//:nsync_headers",
|
||||||
|
"//tensorflow/core:framework_headers_lib",
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
tf_kernel_library(
|
||||||
|
name = "trt_engine_op_kernel",
|
||||||
|
srcs = [
|
||||||
|
"kernels/trt_engine_op.cc",
|
||||||
|
],
|
||||||
|
hdrs=[
|
||||||
|
"kernels/trt_engine_op.h",
|
||||||
|
],
|
||||||
|
gpu_srcs = [
|
||||||
|
],
|
||||||
|
deps = [
|
||||||
|
":trt_logging",
|
||||||
|
":trt_shape_function",
|
||||||
|
"//tensorflow/core:framework",
|
||||||
|
"//tensorflow/core:lib",
|
||||||
|
"//third_party/eigen3",
|
||||||
|
"//tensorflow/core:gpu_headers_lib",
|
||||||
|
"@local_config_tensorrt//:tensorrt",
|
||||||
|
"//tensorflow/core:lib_proto_parsing",
|
||||||
|
],
|
||||||
|
alwayslink=1,
|
||||||
|
)
|
||||||
|
|
||||||
|
tf_gen_op_libs(
|
||||||
|
op_lib_names = [
|
||||||
|
"trt_engine_op",
|
||||||
|
],
|
||||||
|
deps=[
|
||||||
|
"@local_config_tensorrt//:tensorrt",
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name="trt_logging",
|
||||||
|
srcs = [
|
||||||
|
"log/trt_logger.cc",
|
||||||
|
],
|
||||||
|
hdrs=[
|
||||||
|
"log/trt_logger.h",
|
||||||
|
],
|
||||||
|
deps=[
|
||||||
|
"@local_config_tensorrt//:tensorrt",
|
||||||
|
"//tensorflow/core:lib_proto_parsing",
|
||||||
|
],
|
||||||
|
visibility = ["//visibility:public"],
|
||||||
|
)
|
||||||
|
|
||||||
|
tf_gen_op_wrapper_py(
|
||||||
|
name = "trt_engine_op",
|
||||||
|
deps = [
|
||||||
|
":trt_engine_op_op_lib",
|
||||||
|
":trt_shape_function",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
tf_custom_op_py_library(
|
||||||
|
name = "trt_engine_op_loader",
|
||||||
|
srcs = ["python/ops/trt_engine_op.py"],
|
||||||
|
dso = [":python/ops/_trt_engine_op.so",
|
||||||
|
"@local_config_tensorrt//:tensorrt",
|
||||||
|
],
|
||||||
|
srcs_version = "PY2AND3",
|
||||||
|
deps = [
|
||||||
|
"//tensorflow/python:framework_for_generated_wrappers",
|
||||||
|
"//tensorflow/python:resources",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
py_library(
|
||||||
|
name = "init_py",
|
||||||
|
srcs = [
|
||||||
|
"__init__.py",
|
||||||
|
"python/__init__.py",
|
||||||
|
],
|
||||||
|
srcs_version = "PY2AND3",
|
||||||
|
deps = [
|
||||||
|
":trt_ops_py",
|
||||||
|
":trt_convert_py",
|
||||||
|
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
py_library(
|
||||||
|
name="trt_ops_py",
|
||||||
|
srcs_version = "PY2AND3",
|
||||||
|
deps=[":trt_engine_op",
|
||||||
|
":trt_engine_op_loader",
|
||||||
|
],
|
||||||
|
|
||||||
|
)
|
||||||
|
|
||||||
|
py_library(
|
||||||
|
name="trt_convert_py",
|
||||||
|
srcs=["python/trt_convert.py"],
|
||||||
|
srcs_version = "PY2AND3",
|
||||||
|
deps=[
|
||||||
|
":wrap_conversion"
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
tf_py_wrap_cc(
|
||||||
|
name="wrap_conversion",
|
||||||
|
srcs=["trt_conversion.i"],
|
||||||
|
deps=[
|
||||||
|
":trt_conversion",
|
||||||
|
"//tensorflow/core:framework_lite",
|
||||||
|
"//util/python:python_headers",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name= "trt_conversion",
|
||||||
|
srcs=[
|
||||||
|
"convert/convert_nodes.cc",
|
||||||
|
"convert/convert_graph.cc",
|
||||||
|
"segment/segment.cc",
|
||||||
|
"convert/inferShapes.cc",
|
||||||
|
],
|
||||||
|
hdrs=[
|
||||||
|
"convert/convert_nodes.h",
|
||||||
|
"convert/convert_graph.h",
|
||||||
|
"convert/inferShapes.h",
|
||||||
|
"segment/segment.h",
|
||||||
|
"segment/union_find.h",
|
||||||
|
],
|
||||||
|
deps=[
|
||||||
|
"@local_config_tensorrt//:tensorrt",
|
||||||
|
"@protobuf_archive//:protobuf_headers",
|
||||||
|
"@nsync//:nsync_headers",
|
||||||
|
":trt_logging",
|
||||||
|
"//tensorflow/core:framework_lite",
|
||||||
|
"//tensorflow/core:protos_all_cc",
|
||||||
|
"//tensorflow/core:framework_headers_lib",
|
||||||
|
"//tensorflow/core:core_cpu_base",
|
||||||
|
#"//third_party/eigen3",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
tf_custom_op_library(
|
||||||
|
name = "tensorrt_ops.so",
|
||||||
|
srcs = [
|
||||||
|
"ops/tensorrt_ops.cc",
|
||||||
|
],
|
||||||
|
deps = [
|
||||||
|
"@local_config_tensorrt//:tensorrt",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Library for the segmenting portion of TensorRT operation creation
|
||||||
|
cc_library(
|
||||||
|
name = "segment",
|
||||||
|
srcs = [
|
||||||
|
"segment/segment.cc",
|
||||||
|
],
|
||||||
|
hdrs = [
|
||||||
|
"segment/union_find.h",
|
||||||
|
"segment/segment.h",
|
||||||
|
],
|
||||||
|
deps = [
|
||||||
|
"@protobuf_archive//:protobuf_headers",
|
||||||
|
"//tensorflow/core:core_cpu",
|
||||||
|
"//tensorflow/core:lib_proto_parsing",
|
||||||
|
"//third_party/eigen3",
|
||||||
|
],
|
||||||
|
linkstatic = 1,
|
||||||
|
)
|
||||||
|
|
||||||
|
tf_cc_test(
|
||||||
|
name = "segment_test",
|
||||||
|
size = "small",
|
||||||
|
srcs = ["segment/segment_test.cc"],
|
||||||
|
deps = [
|
||||||
|
":segment",
|
||||||
|
"//tensorflow/c:c_api",
|
||||||
|
"//tensorflow/core:test",
|
||||||
|
"//tensorflow/core:test_main",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Library for the node-level conversion portion of TensorRT operation creation
|
||||||
|
|
||||||
|
filegroup(
|
||||||
|
name = "cppfiles",
|
||||||
|
srcs = glob(["**/*.cc"]),
|
||||||
|
visibility=["//visibility:private"],
|
||||||
|
)
|
||||||
|
|
||||||
|
filegroup(
|
||||||
|
name = "headers",
|
||||||
|
srcs = glob(["**/*.h"]),
|
||||||
|
visibility=["//visibility:private"],
|
||||||
|
)
|
||||||
|
|
||||||
|
filegroup(
|
||||||
|
name = "all_files",
|
||||||
|
srcs = glob(
|
||||||
|
["**/*"],
|
||||||
|
exclude = [
|
||||||
|
"**/METADATA",
|
||||||
|
"**/OWNERS",
|
||||||
|
],
|
||||||
|
),
|
||||||
|
visibility = ["//tensorflow:__subpackages__"],
|
||||||
|
)
|
42
tensorflow/contrib/tensorrt/README.md
Normal file
42
tensorflow/contrib/tensorrt/README.md
Normal file
@ -0,0 +1,42 @@
|
|||||||
|
Using TensorRT in TensorFlow
|
||||||
|
============================
|
||||||
|
|
||||||
|
This module provides the necessary bindings and introduces the TRT_engine_op
|
||||||
|
operator, which wraps a subgraph in TensorRT.
|
||||||
|
|
||||||
|
Compilation
|
||||||
|
-----------
|
||||||
|
|
||||||
|
In order to compile the module, you need to have a local TensorRT
|
||||||
|
installation (libnvinfer.so and the respective include files). During the
|
||||||
|
configuration step, TensorRT should be enabled and the installation path
|
||||||
|
should be set. If TensorRT was installed through a package manager (deb, rpm),
|
||||||
|
the configure script should find the necessary components on the system
|
||||||
|
automatically. If it was installed from a tar package, the user has to set the
|
||||||
|
path to the location where the library is installed during configuration.
|
||||||
|
|
||||||
|
In order to enable TensorRT support, the user has to add `--config=tensorrt` to
|
||||||
|
the build flags during compilation, for example:
|
||||||
|
|
||||||
|
```
|
||||||
|
bazel build --config=cuda --config=opt --config=tensorrt //tensorflow/tools/pip_package:build_pip_package
|
||||||
|
bazel-bin/tensorflow/tools/pip_package/build_pip_package /tmp/
|
||||||
|
```
|
||||||
|
|
||||||
|
After the installation of the tensorflow package, the TensorRT transformation
|
||||||
|
will be available. An example use is shown below.
|
||||||
|
|
||||||
|
```python
|
||||||
|
import tensorflow as tf
|
||||||
|
import tensorflow.contrib.tensorrt as trt
|
||||||
|
#... create and train or load model
|
||||||
|
gdef=sess.graph.as_graph_def()
|
||||||
|
trt_gdef=trt.CreateInferenceGraph(gdef, #original graph_def
|
||||||
|
["output"], #name of output node(s)
|
||||||
|
max_batch_size, #maximum batch size to run the inference
|
||||||
|
max_workspace_size # max memory for TensorRT to use
|
||||||
|
)
|
||||||
|
tf.reset_default_graph()
|
||||||
|
tf.import_graph_def(graph_def=trt_gdef)
|
||||||
|
#...... run inference
|
||||||
|
```
|
19
tensorflow/contrib/tensorrt/__init__.py
Normal file
19
tensorflow/contrib/tensorrt/__init__.py
Normal file
@ -0,0 +1,19 @@
|
|||||||
|
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# =============================================================================
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
from tensorflow.contrib.tensorrt.python import *
|
253
tensorflow/contrib/tensorrt/convert/convert_graph.cc
Normal file
253
tensorflow/contrib/tensorrt/convert/convert_graph.cc
Normal file
@ -0,0 +1,253 @@
|
|||||||
|
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/contrib/tensorrt/convert/convert_graph.h"
|
||||||
|
|
||||||
|
#include <list>
|
||||||
|
#include <set>
|
||||||
|
#include <sstream>
|
||||||
|
#include <string>
|
||||||
|
#include <unordered_map>
|
||||||
|
#include <unordered_set>
|
||||||
|
#include <vector>
|
||||||
|
#include <map>
|
||||||
|
#include <utility>
|
||||||
|
|
||||||
|
#include "NvInfer.h"
|
||||||
|
|
||||||
|
#include "tensorflow/contrib/tensorrt/convert/convert_nodes.h"
|
||||||
|
#include "tensorflow/contrib/tensorrt/convert/inferShapes.h"
|
||||||
|
#include "tensorflow/contrib/tensorrt/segment/segment.h"
|
||||||
|
#include "tensorflow/core/framework/graph.pb.h"
|
||||||
|
#include "tensorflow/core/framework/node_def.pb.h"
|
||||||
|
#include "tensorflow/core/graph/algorithm.h"
|
||||||
|
#include "tensorflow/core/graph/graph.h"
|
||||||
|
#include "tensorflow/core/graph/graph_constructor.h"
|
||||||
|
#include "tensorflow/core/lib/core/errors.h"
|
||||||
|
#include "tensorflow/core/lib/core/status.h"
|
||||||
|
#include "tensorflow/core/platform/logging.h"
|
||||||
|
|
||||||
|
#define _TF_LOG_DEBUG ::tensorflow::internal::LogMessage(__FILE__, __LINE__, -1)
|
||||||
|
//------------------------------------------------------------------------------
|
||||||
|
namespace tensorrt {
|
||||||
|
namespace convert {
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
static std::unordered_set<std::string> output_nodes;
|
||||||
|
// Reports whether `node_def` may be placed inside a TensorRT segment.
// A node qualifies when its name is not in the file-level `output_nodes` set
// and its op type is one of the supported ops listed below.
bool IsTensorRTCandidate(const tensorflow::NodeDef& node_def) {
  static const std::set<std::string> supported_ops = {
      "Identity", "Const", "Conv2D", "MaxPool", "BiasAdd", "Relu",
      "Add", "Mul", "Sub", "Rsqrt", "Pad"  // "Placeholder" ,"Mean"
      // TODO(ben,jie): ...
  };
  // Nodes registered as outputs are never candidates.
  const bool is_registered_output = output_nodes.count(node_def.name()) != 0;
  if (is_registered_output) {
    return false;
  }
  return supported_ops.count(node_def.op()) != 0;
}
|
||||||
|
|
||||||
|
// Collects into `incoming_edges` every edge that feeds a node listed in
// `subgraph_node_ids` from a node that is not in that set. Edges whose
// producer reports IsSource() are skipped.
void GetSubGraphIncomingEdges(tensorflow::Graph const& graph,
                              std::set<int> const& subgraph_node_ids,
                              tensorflow::EdgeSet* incoming_edges) {
  for (int node_id : subgraph_node_ids) {
    tensorflow::Node const* node = graph.FindNodeId(node_id);
    LOG(DEBUG) << node->name() << " has incoming edges: ";
    for (tensorflow::Edge const* edge : node->in_edges()) {
      tensorflow::Node const* producer = edge->src();
      // Skip edges that stay inside the subgraph or come from the source node.
      if (subgraph_node_ids.count(producer->id()) || producer->IsSource()) {
        continue;
      }
      LOG(DEBUG) << producer->name() << ", ";
      incoming_edges->insert(edge);
    }
  }
}
|
||||||
|
|
||||||
|
// Collects into `outgoing_edges` every edge that leaves a node listed in
// `subgraph_node_ids` towards a node that is not in that set. Edges whose
// consumer reports IsSink() are skipped.
void GetSubGraphOutgoingEdges(tensorflow::Graph const& graph,
                              std::set<int> const& subgraph_node_ids,
                              tensorflow::EdgeSet* outgoing_edges) {
  for (int node_id : subgraph_node_ids) {
    tensorflow::Node const* node = graph.FindNodeId(node_id);
    LOG(DEBUG) << node->name() << " has outgoing edges: ";
    for (tensorflow::Edge const* edge : node->out_edges()) {
      tensorflow::Node const* consumer = edge->dst();
      // Skip edges that stay inside the subgraph or end at the sink node.
      if (subgraph_node_ids.count(consumer->id()) || consumer->IsSink()) {
        continue;
      }
      outgoing_edges->insert(edge);
    }
  }
}
|
||||||
|
|
||||||
|
// Splits a tensor name of the form "node_name:output_index" into its node
// name and output index. The split happens at the last ':' in `name`. When
// `name` contains no ':', the whole string is returned as the node name
// together with `default_idx`.
//
// std::stoi throws if the text after the last ':' is not a valid integer.
std::pair<std::string, int> ParseTensorName(std::string name,
                                            int default_idx = 0) {
  int idx = default_idx;
  size_t sep = name.find_last_of(':');
  if (sep != std::string::npos) {
    // Read the index BEFORE truncating `name`; once the suffix has been cut
    // off, substr(sep + 1) would start past the end of the string and throw
    // std::out_of_range.
    idx = std::stoi(name.substr(sep + 1));
    name = name.substr(0, sep);
  }
  return std::make_pair(name, idx);
}
|
||||||
|
|
||||||
|
std::unordered_map<std::string, std::vector<int>> BuildTensorNameMap(
|
||||||
|
const std::vector<std::string>& tensor_names) {
|
||||||
|
std::unordered_map<std::string, std::vector<int>> result;
|
||||||
|
for (std::string const& tensor_name : tensor_names) {
|
||||||
|
std::string node_name;
|
||||||
|
int index;
|
||||||
|
std::tie(node_name, index) = ParseTensorName(tensor_name);
|
||||||
|
result[node_name].push_back(index);
|
||||||
|
}
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Replaces the nodes listed in `subgraph_node_ids` with a single node whose
// NodeDef is produced by ConvertSubGraphToTensorRTNodeDef.
//
// Steps:
//   1. Subgraph inputs are the (source node id, source output) pairs of all
//      edges entering the subgraph.
//   2. Subgraph outputs are the union of (a) outputs of subgraph nodes named
//      in `output_names` and (b) the sources of all edges leaving the
//      subgraph, ordered by (node id, output index).
//   3. The new node is added to `graph`, and every outgoing edge is rewired to
//      the matching output of the new node.
//   4. The original subgraph nodes are removed, except nodes whose type is
//      "Placeholder".
//
// Returns the first error reported while building or adding the new node,
// otherwise OK.
tensorflow::Status ConvertSubGraphToTensorRT(
    tensorflow::Graph& graph, const std::vector<std::string>& output_names,
    const std::set<int>& subgraph_node_ids, size_t max_batch_size,
    size_t max_workspace_size, const ShapeMap& shape_map) {
  tensorflow::EdgeSet subgraph_incoming_edges;
  GetSubGraphIncomingEdges(graph, subgraph_node_ids, &subgraph_incoming_edges);

  // Collect inputs by looking for incoming edges
  std::vector<std::pair<int, int>> subgraph_inputs;
  for (tensorflow::Edge const* edge : subgraph_incoming_edges) {
    subgraph_inputs.push_back({edge->src()->id(), edge->src_output()});
  }

  // Collect outputs referenced from output_names
  std::set<std::pair<int, int>> subgraph_outputs_set;
  auto output_name_to_index_map = BuildTensorNameMap(output_names);
  for (int node_id : subgraph_node_ids) {
    tensorflow::Node* node = graph.FindNodeId(node_id);
    if (output_name_to_index_map.count(node->name())) {
      for (int index : output_name_to_index_map.at(node->name())) {
        subgraph_outputs_set.insert({node_id, index});
      }
    }
  }

  // Collect outputs referenced from outgoing edges
  tensorflow::EdgeSet subgraph_outgoing_edges;
  GetSubGraphOutgoingEdges(graph, subgraph_node_ids, &subgraph_outgoing_edges);
  for (tensorflow::Edge const* edge : subgraph_outgoing_edges) {
    subgraph_outputs_set.insert({edge->src()->id(), edge->src_output()});
  }

  // Impose an ordering on the outputs
  std::vector<std::pair<int, int>> subgraph_outputs(
      subgraph_outputs_set.begin(), subgraph_outputs_set.end());

  // Build TensorRT node and add it to the graph
  tensorflow::NodeDef trt_node_def;
  TF_RETURN_IF_ERROR(ConvertSubGraphToTensorRTNodeDef(
      graph, subgraph_node_ids, subgraph_inputs, subgraph_outputs,
      max_batch_size, max_workspace_size, shape_map, &trt_node_def));
  tensorflow::Status status;
  tensorflow::Node* trt_node = graph.AddNode(trt_node_def, &status);
  TF_RETURN_IF_ERROR(status);

  // Re-map outgoing edges to use the new TRT node instead of the orig subgraph
  std::map<std::pair<int, int>, int> subgraph_edge_to_output_map;
  for (size_t i = 0; i < subgraph_outputs.size(); ++i) {
    subgraph_edge_to_output_map.insert({subgraph_outputs.at(i), i});
  }
  for (tensorflow::Edge const* edge : subgraph_outgoing_edges) {
    std::pair<int, int> old_src = {edge->src()->id(), edge->src_output()};
    int new_src_output = subgraph_edge_to_output_map.at(old_src);
    graph.UpdateEdge(trt_node, new_src_output, edge->dst(), edge->dst_input());
  }

  // Remove the original subgraph
  for (int node_id : subgraph_node_ids) {
    tensorflow::Node* node = graph.FindNodeId(node_id);
    // Don't remove the input placeholders
    if (node->type_string() == "Placeholder") {
      continue;
    }
    graph.RemoveNode(node);
  }
  return tensorflow::Status::OK();
}
|
||||||
|
|
||||||
|
// Fills `node_map` with a name -> node entry for every op node of `graph`.
//
// Returns AlreadyExists as soon as a node name is already present in the map;
// entries inserted before that point are left in place. Returns OK otherwise.
tensorflow::Status BuildNodeMap(
    const tensorflow::Graph& graph,
    std::unordered_map<std::string, tensorflow::Node*>* node_map) {
  for (auto* op_node : graph.op_nodes()) {
    const auto inserted = node_map->emplace(op_node->name(), op_node);
    if (inserted.second) continue;
    // The name was already taken: node names must be unique.
    return tensorflow::errors::AlreadyExists(
        "Node name is not unique in graph: " + op_node->name());
  }
  return tensorflow::Status::OK();
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
// Converts the TensorRT-compatible segments of `graph_def` into TensorRT nodes
// and writes the rewritten graph to `new_graph_def`.
//
// Steps: infer per-node output shapes/types, build a tensorflow::Graph from
// `graph_def`, segment it with IsTensorRTCandidate, convert every segment via
// ConvertSubGraphToTensorRT, then serialize the graph back into a GraphDef.
//
// Args:
//   graph_def:          input graph.
//   output_names:       names forwarded to shape inference and to each
//                       subgraph conversion.
//   max_batch_size:     forwarded unchanged to the subgraph conversion.
//   max_workspace_size: forwarded unchanged to the subgraph conversion.
//   new_graph_def:      receives the converted graph on success.
//
// Returns the first non-OK status produced by any step, or NotFound when a
// segment references a node name that is not present in the graph.
tensorflow::Status ConvertGraphDefToTensorRT(
    const tensorflow::GraphDef& graph_def,
    const std::vector<std::string>& output_names, size_t max_batch_size,
    size_t max_workspace_size, tensorflow::GraphDef* new_graph_def) {
  ShapeMap shape_map;
  TF_RETURN_IF_ERROR(
      tensorflow::trt::inferShapes(graph_def, output_names, shape_map));

  // Debug dump of the inferred shapes, one log line per node.
  std::stringstream oss;
  for (const auto& n : shape_map) {  // nodes
    oss << " Node= " << n.first << ", ";
    for (const auto& o : n.second) {  // outputs
      oss << o.first.DebugString() << " T= " << o.second << ", ";
    }
    LOG(DEBUG) << oss.str();
    oss.str("");
  }

  // Build full graph
  tensorflow::FunctionLibraryDefinition flib(tensorflow::OpRegistry::Global(),
                                             graph_def.library());
  tensorflow::Graph graph(flib);
  TF_RETURN_IF_ERROR(tensorflow::ConvertGraphDefToGraph(
      tensorflow::GraphConstructorOptions(), graph_def, &graph));

  // Segment the graph into subgraphs that can be converted to TensorRT
  tensorrt::segment::SegmentOptions segment_options;
  // TODO(ben,jie,sami): exclude output nodes (DISCUSS IT)
  // NOTE(review): `output_nodes` is not declared inside this function;
  // presumably it is a set visible at file scope -- confirm. The TODO above
  // suggests these names were meant to be excluded through `segment_options`.
  for (const auto& node : output_names) output_nodes.insert(node);

  // TODO(sami): this should be passed as a knob!!!!
  segment_options.minimum_segment_size = 2;
  tensorrt::segment::SegmentNodesVector segments;
  TF_RETURN_IF_ERROR(tensorrt::segment::SegmentGraph(
      graph_def, IsTensorRTCandidate, segment_options, &segments));
  if (segments.size() > 1) {
    // Every candidate segment is converted, not only the first one.
    LOG(INFO) << "MULTIPLE tensorrt candidate conversion: " << segments.size();
  }

  std::unordered_map<std::string, tensorflow::Node*> node_map;
  TF_RETURN_IF_ERROR(BuildNodeMap(graph, &node_map));
  for (const std::set<std::string>& subgraph_node_names : segments) {
    std::set<int> subgraph_node_ids;
    for (const std::string& node_name : subgraph_node_names) {
      // Report a missing node through Status rather than letting
      // unordered_map::at() throw (or abort when exceptions are disabled).
      const auto found = node_map.find(node_name);
      if (found == node_map.end()) {
        return tensorflow::errors::NotFound(
            "Segment node is not present in graph: " + node_name);
      }
      subgraph_node_ids.insert(found->second->id());
    }
    TF_RETURN_IF_ERROR(ConvertSubGraphToTensorRT(
        graph, output_names, subgraph_node_ids, max_batch_size,
        max_workspace_size, shape_map));
  }
  graph.ToGraphDef(new_graph_def);
  return tensorflow::Status::OK();
}
|
||||||
|
|
||||||
|
} // namespace convert
|
||||||
|
} // namespace tensorrt
|
34
tensorflow/contrib/tensorrt/convert/convert_graph.h
Normal file
34
tensorflow/contrib/tensorrt/convert/convert_graph.h
Normal file
@ -0,0 +1,34 @@
|
|||||||
|
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
#ifndef TENSORFLOW_CONTRIB_TENSORRT_CONVERT_CONVERT_GRAPH_H_
|
||||||
|
#define TENSORFLOW_CONTRIB_TENSORRT_CONVERT_CONVERT_GRAPH_H_
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "tensorflow/core/framework/graph.pb.h"
|
||||||
|
#include "tensorflow/core/lib/core/status.h"
|
||||||
|
|
||||||
|
namespace tensorrt {
|
||||||
|
namespace convert {
|
||||||
|
|
||||||
|
// Converts the TensorRT-compatible segments of `graph_def` into TensorRT nodes
// and writes the rewritten graph to `new_graph_def`.
//
// `output_names`, `max_batch_size` and `max_workspace_size` are forwarded to
// the shape-inference and subgraph-conversion steps of the implementation.
// Returns a non-OK status if any conversion step fails.
tensorflow::Status ConvertGraphDefToTensorRT(
|
||||||
|
const tensorflow::GraphDef& graph_def,
|
||||||
|
const std::vector<std::string>& output_names, size_t max_batch_size,
|
||||||
|
size_t max_workspace_size, tensorflow::GraphDef* new_graph_def);
|
||||||
|
}
|
||||||
|
} // namespace tensorrt
|
||||||
|
|
||||||
|
#endif // TENSORFLOW_CONTRIB_TENSORRT_CONVERT_CONVERT_GRAPH_H_
|
1737
tensorflow/contrib/tensorrt/convert/convert_nodes.cc
Normal file
1737
tensorflow/contrib/tensorrt/convert/convert_nodes.cc
Normal file
File diff suppressed because it is too large
Load Diff
42
tensorflow/contrib/tensorrt/convert/convert_nodes.h
Normal file
42
tensorflow/contrib/tensorrt/convert/convert_nodes.h
Normal file
@ -0,0 +1,42 @@
|
|||||||
|
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#ifndef TENSORFLOW_CONTRIB_TENSORRT_CONVERT_CONVERT_NODES_H_
|
||||||
|
#define TENSORFLOW_CONTRIB_TENSORRT_CONVERT_CONVERT_NODES_H_
|
||||||
|
|
||||||
|
#include <set>
|
||||||
|
#include <vector>
|
||||||
|
#include <utility>
|
||||||
|
|
||||||
|
#include "tensorflow/contrib/tensorrt/convert/inferShapes.h"
|
||||||
|
#include "tensorflow/core/framework/graph.pb.h"
|
||||||
|
#include "tensorflow/core/graph/graph.h"
|
||||||
|
#include "tensorflow/core/lib/core/status.h"
|
||||||
|
|
||||||
|
namespace tensorrt {
|
||||||
|
namespace convert {
|
||||||
|
|
||||||
|
// Produces, in `trt_node`, a NodeDef for the subgraph of `graph` selected by
// `subgraph_node_ids`.
//
// `input_inds` and `output_inds` list the subgraph's inputs and outputs as
// {node_id, output_idx} pairs. `shape_map` supplies per-node output shapes and
// types (see ShapeMap).
// NOTE(review): the definition is not visible here; `max_batch_size` and
// `max_workspace_size` are presumably TensorRT builder limits -- confirm.
tensorflow::Status ConvertSubGraphToTensorRTNodeDef(
|
||||||
|
const tensorflow::Graph& graph, const std::set<int>& subgraph_node_ids,
|
||||||
|
const std::vector<std::pair<int, int>>&
|
||||||
|
input_inds, // {node_id, output_idx}
|
||||||
|
const std::vector<std::pair<int, int>>&
|
||||||
|
output_inds, // {node_id, output_idx}
|
||||||
|
size_t max_batch_size, size_t max_workspace_size, const ShapeMap& shape_map,
|
||||||
|
tensorflow::NodeDef* trt_node);
|
||||||
|
} // namespace convert
|
||||||
|
} // namespace tensorrt
|
||||||
|
|
||||||
|
#endif // TENSORFLOW_CONTRIB_TENSORRT_CONVERT_CONVERT_NODES_H_
|
125
tensorflow/contrib/tensorrt/convert/inferShapes.cc
Normal file
125
tensorflow/contrib/tensorrt/convert/inferShapes.cc
Normal file
@ -0,0 +1,125 @@
|
|||||||
|
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
#include "tensorflow/contrib/tensorrt/convert/inferShapes.h"
|
||||||
|
#include <functional>
|
||||||
|
#include "tensorflow/core/common_runtime/shape_refiner.h"
|
||||||
|
#include "tensorflow/core/framework/node_def.pb.h"
|
||||||
|
#include "tensorflow/core/framework/shape_inference.h"
|
||||||
|
#include "tensorflow/core/framework/tensor.h"
|
||||||
|
#include "tensorflow/core/framework/types.pb_text.h"
|
||||||
|
#include "tensorflow/core/graph/algorithm.h"
|
||||||
|
#include "tensorflow/core/graph/graph.h"
|
||||||
|
#include "tensorflow/core/graph/graph_constructor.h"
|
||||||
|
#include "tensorflow/core/lib/core/errors.h"
|
||||||
|
#include "tensorflow/core/lib/core/status.h"
|
||||||
|
#include "tensorflow/core/platform/logging.h"
|
||||||
|
|
||||||
|
#define _TF_LOG_DEBUG ::tensorflow::internal::LogMessage(__FILE__, __LINE__, -1)
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
namespace trt {
|
||||||
|
// Returns the flattened list of data types of a node's arguments.
//
// Walks the input arguments of `op` when `inp` is true and the output
// arguments otherwise, resolving each argument against the attributes of `nd`:
//   * a type-list attribute expands to every type in the list;
//   * otherwise the argument's fixed type (or the type named by its type
//     attribute) is used, repeated N times when a number attribute is set.
// Attribute lookups use at(), so a referenced attribute must exist in `nd`.
std::vector<tensorflow::DataType> getTypes(const tensorflow::OpDef& op,
                                           const tensorflow::NodeDef& nd,
                                           bool inp = true) {
  const auto& attrs = nd.attr();

  // Expands a single OpDef argument into the concrete types it stands for.
  auto expandArg = [&attrs](decltype(op.input_arg(0)) arg)
      -> std::vector<tensorflow::DataType> {
    if (!arg.type_list_attr().empty()) {  // get the list types
      const auto& type_list = attrs.at(arg.type_list_attr()).list();
      const int list_size = type_list.type_size();
      std::vector<tensorflow::DataType> listed;
      listed.reserve(list_size);
      for (int t = 0; t < list_size; ++t) {
        listed.push_back(type_list.type(t));
      }
      return listed;
    }

    tensorflow::DataType elem_type = tensorflow::DT_INVALID;
    if (arg.type() != tensorflow::DT_INVALID) {  // get defined types
      elem_type = arg.type();
    } else if (!arg.type_attr().empty()) {
      elem_type = attrs.at(arg.type_attr()).type();
    }

    if (arg.number_attr().empty()) {
      // Plain argument: exactly one tensor of elem_type.
      return std::vector<tensorflow::DataType>(1, elem_type);
    }
    // numbertypes: the argument stands for N tensors of the same type.
    const int64 repeat = attrs.at(arg.number_attr()).i();
    return std::vector<tensorflow::DataType>(repeat, elem_type);
  };

  std::vector<tensorflow::DataType> types;
  const int n_args = inp ? op.input_arg_size() : op.output_arg_size();
  const char* const label = inp ? " #inputs" : " #outputs";
  for (int i = 0; i < n_args; ++i) {
    const auto expanded =
        inp ? expandArg(op.input_arg(i)) : expandArg(op.output_arg(i));
    LOG(DEBUG) << "Node= " << nd.name() << label << expanded.size();
    types.insert(types.end(), expanded.begin(), expanded.end());
  }
  return types;
}
|
||||||
|
|
||||||
|
// Runs TensorFlow shape inference over `graph_def` and stores, for every node
// that has an inference context, one {shape, dtype} pair per output in
// `shapes`, keyed by node name.
//
// Args:
//   graph_def:    graph to analyze.
//   output_names: currently unused by this function.
//   shapes:       receives the inferred entries; existing entries with the
//                 same node name are overwritten.
//
// Returns a non-OK status when the graph cannot be constructed or a node
// cannot be added to the shape refiner, and Internal when a node reports more
// outputs than output types. Nodes without an inference context are skipped
// with a warning.
tensorflow::Status inferShapes(const tensorflow::GraphDef& graph_def,
                               const std::vector<std::string>& output_names,
                               ShapeMap& shapes) {
  tensorflow::Graph g(OpRegistry::Global());
  TF_RETURN_IF_ERROR(tensorflow::ConvertGraphDefToGraph(
      tensorflow::GraphConstructorOptions(), graph_def, &g));
  std::vector<tensorflow::Node*> POnodes;
  tensorflow::GetPostOrder(g, &POnodes);
  tensorflow::ShapeRefiner refiner(graph_def.versions().producer(),
                                   OpRegistry::Global());
  // Feed the refiner in reverse post-order (presumably so that producers are
  // registered before their consumers).
  for (auto n = POnodes.rbegin(); n != POnodes.rend(); ++n) {
    // Propagate the failure to the caller instead of aborting the process.
    TF_RETURN_IF_ERROR(refiner.AddNode(*n));
  }

  // Converts an inference-context shape handle into a PartialTensorShape.
  // NOTE(review): when the rank is not known the loop does not run and an
  // empty dims vector is used -- confirm callers do not confuse this with a
  // genuine scalar shape.
  auto shape2PTS = [](tensorflow::shape_inference::InferenceContext* ic,
                      const tensorflow::shape_inference::ShapeHandle& sh)
      -> tensorflow::PartialTensorShape {
    std::vector<int64> dims;
    int64 rank = ic->Rank(sh);
    for (int64 i = 0; i < rank; i++) {
      auto dh = ic->Dim(sh, i);
      dims.push_back(ic->Value(dh));
    }
    return tensorflow::PartialTensorShape(dims);
  };

  for (const auto& n : POnodes) {
    auto ic = refiner.GetContext(n);
    if (!ic) {
      LOG(WARNING) << "Node " << n->name() << " doesn't have InferenceContext!";
      continue;
    }
    const int nOuts = ic->num_outputs();
    const auto types = getTypes(n->op_def(), n->def(), false);
    // Guard the per-output type lookup below: a mismatch would otherwise make
    // vector::at() throw.
    if (nOuts > 0 && types.size() < static_cast<size_t>(nOuts)) {
      return tensorflow::errors::Internal(
          "Node ", n->name(), " has ", nOuts, " outputs but only ",
          types.size(), " output types");
    }
    std::vector<
        std::pair<tensorflow::PartialTensorShape, tensorflow::DataType>>
        SAT;
    if (nOuts > 0) SAT.reserve(nOuts);
    for (int i = 0; i < nOuts; i++) {
      SAT.push_back({shape2PTS(ic, ic->output(i)), types[i]});
    }
    shapes[n->name()] = SAT;
  }
  return tensorflow::Status::OK();
}
|
||||||
|
} // namespace trt
|
||||||
|
} // namespace tensorflow
|
39
tensorflow/contrib/tensorrt/convert/inferShapes.h
Normal file
39
tensorflow/contrib/tensorrt/convert/inferShapes.h
Normal file
@ -0,0 +1,39 @@
|
|||||||
|
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
#ifndef TENSORFLOW_CONTRIB_TENSORRT_CONVERT_INFERSHAPES_H_
|
||||||
|
#define TENSORFLOW_CONTRIB_TENSORRT_CONVERT_INFERSHAPES_H_
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
#include <unordered_map>
|
||||||
|
#include <vector>
|
||||||
|
#include <utility>
|
||||||
|
|
||||||
|
#include "tensorflow/core/framework/graph.pb.h"
|
||||||
|
#include "tensorflow/core/framework/tensor_shape.h"
|
||||||
|
#include "tensorflow/core/lib/core/status.h"
|
||||||
|
|
||||||
|
// Maps a string key to a list of {partial shape, data type} pairs.
// inferShapes() keys it by node name, with one pair per node output.
typedef std::unordered_map<std::string,
|
||||||
|
std::vector<std::pair<tensorflow::PartialTensorShape,
|
||||||
|
tensorflow::DataType>>>
|
||||||
|
ShapeMap;
|
||||||
|
namespace tensorflow {
|
||||||
|
namespace trt {
|
||||||
|
// Runs shape inference over `graph_def` and fills `shapes` with the inferred
// {shape, dtype} of every node output, keyed by node name.
// Returns a non-OK status if the graph cannot be constructed.
// NOTE(review): `output_names` is not read by the current implementation.
tensorflow::Status inferShapes(const tensorflow::GraphDef& graph_def,
|
||||||
|
const std::vector<std::string>& output_names,
|
||||||
|
ShapeMap& shapes);
|
||||||
|
}
|
||||||
|
} // namespace tensorflow
|
||||||
|
|
||||||
|
#endif // TENSORFLOW_CONTRIB_TENSORRT_CONVERT_INFERSHAPES_H_
|
183
tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc
Normal file
183
tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc
Normal file
@ -0,0 +1,183 @@
|
|||||||
|
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
#include "tensorflow/contrib/tensorrt/kernels/trt_engine_op.h"
|
||||||
|
#include <cuda_runtime_api.h>
|
||||||
|
#include <sstream>
|
||||||
|
#include "tensorflow/contrib/tensorrt/log/trt_logger.h"
|
||||||
|
#include "tensorflow/core/platform/logging.h"
|
||||||
|
#include "tensorflow/core/platform/stream_executor.h"
|
||||||
|
// Use TF logging for TensorRT messages.
|
||||||
|
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
static ::tensorflow::tensorrt::Logger gLogger;
|
||||||
|
|
||||||
|
using namespace nvinfer1;
|
||||||
|
|
||||||
|
namespace tensorrt {
|
||||||
|
|
||||||
|
// Constructs the kernel from the op attributes.
//
// Reads `serialized_engine`, `input_nodes` and `output_nodes`, deserializes
// the TensorRT engine, creates its execution context and logs a description
// of every engine binding. Construction fails with an Internal error (instead
// of dereferencing a null pointer) when the runtime, the engine or the
// execution context cannot be created.
TRTEngineOp::TRTEngineOp(OpKernelConstruction* context) : OpKernel(context) {
  // read serialized_engine
  std::string serialized_engine;
  OP_REQUIRES_OK(context,
                 context->GetAttr("serialized_engine", &serialized_engine));

  // register input output node name in trt_sub_graph
  OP_REQUIRES_OK(context, context->GetAttr("input_nodes", &input_nodes_));
  OP_REQUIRES_OK(context, context->GetAttr("output_nodes", &output_nodes_));

  // TODO(samikama) runtime should be taken from a resourcemanager as well.
  // Only engine should be in the op and context and runtime should be taken
  // from resourcemanager
  IRuntime* infer = createInferRuntime(gLogger);
  OP_REQUIRES(context, infer != nullptr,
              errors::Internal("Failed to create TensorRT runtime"));
  trt_engine_ptr_.reset(infer->deserializeCudaEngine(
      serialized_engine.c_str(), serialized_engine.size(), nullptr));
  if (trt_engine_ptr_) {
    trt_context_ptr_.reset(trt_engine_ptr_->createExecutionContext());
  }
  // runtime is safe to delete after engine creation; release it before any
  // early return below so it is never leaked.
  infer->destroy();
  OP_REQUIRES(context, trt_engine_ptr_ != nullptr,
              errors::Internal("Failed to deserialize TensorRT engine"));
  OP_REQUIRES(
      context, trt_context_ptr_ != nullptr,
      errors::Internal("Failed to create TensorRT execution context"));

  // debug iterate through all binding instances
  const int num_bindings = trt_engine_ptr_->getNbBindings();
  for (int i = 0; i < num_bindings; i++) {
    LOG(INFO) << "index: " << i
              << ", binding name: " << trt_engine_ptr_->getBindingName(i);
    LOG(INFO) << (trt_engine_ptr_->bindingIsInput(i) ? "INPUT" : "OUTPUT");

    std::stringstream oss;
    oss << "Dimension: ";
    auto dims = trt_engine_ptr_->getBindingDimensions(i);
    oss << " nbDims: " << dims.nbDims << " -> ";
    // Only the first nbDims entries of dims.d are meaningful.
    for (int j = 0; j < dims.nbDims; j++) {
      oss << dims.d[j] << ", ";
    }
    LOG(INFO) << oss.str();

    switch (trt_engine_ptr_->getBindingDataType(i)) {
      case nvinfer1::DataType::kFLOAT:
        LOG(INFO) << "data type float";
        break;
      case nvinfer1::DataType::kHALF:
        LOG(INFO) << "data type half";
        break;
      case nvinfer1::DataType::kINT8:
        LOG(INFO) << "data type int8";
        break;
    }
  }
}
|
||||||
|
|
||||||
|
// Runs the TensorRT engine on the op's inputs.
//
// Every input tensor is bound to the engine binding named by the matching
// entry of `input_nodes_`; every entry of `output_nodes_` gets a freshly
// allocated output tensor of shape [batch, binding dims...]. The batch size
// is dimension 0 of the first input and must be the same for all inputs.
// The engine is enqueued on the op's CUDA stream, which is then synchronized.
//
// The op fails (rather than handing null or mismatched buffers to TensorRT)
// when a binding is missing, batch sizes disagree, a binding is not float, or
// TensorRT/CUDA report an error.
void TRTEngineOp::Compute(OpKernelContext* context) {
  const int num_inputs = context->num_inputs();
  const int nbBindings = num_inputs + context->num_outputs();
  OP_REQUIRES(context, num_inputs <= static_cast<int>(input_nodes_.size()),
              errors::InvalidArgument("Got ", num_inputs,
                                      " inputs but only ", input_nodes_.size(),
                                      " input node names"));
  // TODO(jjsjann123) multiple input/output
  std::vector<void*> buffers(nbBindings);

  int nbBatch = 0;
  for (int i = 0; i < num_inputs; i++) {
    // getBindingIndex() reports a missing name with -1, so keep it signed.
    const int bindingIndex =
        trt_engine_ptr_->getBindingIndex(input_nodes_[i].c_str());
    OP_REQUIRES(context, bindingIndex >= 0 && bindingIndex < nbBindings,
                errors::InvalidArgument("No usable TensorRT binding for input ",
                                        input_nodes_[i]));

    // Grab the input tensor
    const Tensor& input_tensor = context->input(i);
    const TensorShape& input_shape = input_tensor.shape();
    OP_REQUIRES(context, input_shape.dims() > 0,
                errors::InvalidArgument("Input ", input_nodes_[i],
                                        " has no batch dimension"));
    if (i == 0) {
      nbBatch = input_shape.dim_size(0);
    } else {
      OP_REQUIRES(context, nbBatch == input_shape.dim_size(0),
                  errors::InvalidArgument("input data inconsistent batch size"));
    }

    // Only float bindings have a buffer hookup implemented.
    OP_REQUIRES(context,
                trt_engine_ptr_->getBindingDataType(bindingIndex) ==
                        nvinfer1::DataType::kFLOAT &&
                    input_tensor.dtype() == DT_FLOAT,
                errors::Unimplemented("Only float inputs are supported, at ",
                                      input_nodes_[i]));
    VLOG(1) << "float";
    buffers[bindingIndex] =
        const_cast<float*>(input_tensor.flat<float>().data());
  }

  for (int i = 0; i < static_cast<int>(output_nodes_.size()); i++) {
    // This is bad that we have to reallocate output buffer every run.
    const int bindingIndex =
        trt_engine_ptr_->getBindingIndex(output_nodes_[i].c_str());
    OP_REQUIRES(context, bindingIndex >= 0 && bindingIndex < nbBindings,
                errors::InvalidArgument("No usable TensorRT binding for output ",
                                        output_nodes_[i]));
    VLOG(1) << "got binding " << bindingIndex;

    // Output shape is the batch size followed by the binding's dimensions.
    auto dims = trt_engine_ptr_->getBindingDimensions(bindingIndex);
    std::vector<int> trt_shape(dims.nbDims + 1);
    trt_shape[0] = nbBatch;
    for (int j = 0; j < dims.nbDims; j++) trt_shape[j + 1] = dims.d[j];
    TensorShape output_shape;
    OP_REQUIRES_OK(context,
                   TensorShapeUtils::MakeShape(trt_shape.data(),
                                               trt_shape.size(), &output_shape));

    // Create an output tensor
    Tensor* output_tensor = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output(i, output_shape, &output_tensor));

    OP_REQUIRES(context,
                trt_engine_ptr_->getBindingDataType(bindingIndex) ==
                        nvinfer1::DataType::kFLOAT &&
                    output_tensor->dtype() == DT_FLOAT,
                errors::Unimplemented("Only float outputs are supported, at ",
                                      output_nodes_[i]));
    VLOG(1) << "float";
    buffers[bindingIndex] = output_tensor->flat<float>().data();
  }

  // copied from cuda_kernel_helper since it seems only valid in *.cu.cc files
  const cudaStream_t* stream = CHECK_NOTNULL(
      reinterpret_cast<const cudaStream_t*>(context->op_device_context()
                                                ->stream()
                                                ->implementation()
                                                ->CudaStreamMemberHack()));

  const bool enqueued =
      trt_context_ptr_->enqueue(nbBatch, buffers.data(), *stream, nullptr);
  OP_REQUIRES(context, enqueued,
              errors::Internal("TensorRT enqueue failed"));
  const cudaError_t sync_status = cudaStreamSynchronize(*stream);
  OP_REQUIRES(context, sync_status == cudaSuccess,
              errors::Internal("cudaStreamSynchronize failed: ",
                               cudaGetErrorString(sync_status)));
}
|
||||||
|
|
||||||
|
REGISTER_KERNEL_BUILDER(Name("TRTEngineOp").Device(DEVICE_GPU), TRTEngineOp);
|
||||||
|
} // namespace tensorrt
|
||||||
|
} // namespace tensorflow
|
55
tensorflow/contrib/tensorrt/kernels/trt_engine_op.h
Normal file
55
tensorflow/contrib/tensorrt/kernels/trt_engine_op.h
Normal file
@ -0,0 +1,55 @@
|
|||||||
|
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#ifndef TENSORFLOW_CONTRIB_TENSORRT_KERNELS_TRT_ENGINE_OP_H_
|
||||||
|
#define TENSORFLOW_CONTRIB_TENSORRT_KERNELS_TRT_ENGINE_OP_H_
|
||||||
|
|
||||||
|
#include <NvInfer.h>
|
||||||
|
#include <cuda_runtime_api.h>
|
||||||
|
#include <memory>
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
#include "tensorflow/core/framework/op.h"
|
||||||
|
#include "tensorflow/core/framework/op_kernel.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
|
||||||
|
namespace tensorrt {
|
||||||
|
class Logger;
|
||||||
|
class TRTEngineOp : public OpKernel {
|
||||||
|
public:
|
||||||
|
explicit TRTEngineOp(OpKernelConstruction* context);
|
||||||
|
|
||||||
|
void Compute(OpKernelContext* context) override;
|
||||||
|
|
||||||
|
private:
|
||||||
|
template <typename T>
|
||||||
|
struct Destroyer {
|
||||||
|
void operator()(T* d) { d->destroy(); }
|
||||||
|
};
|
||||||
|
template <typename T>
|
||||||
|
using destroyed_ptr = std::unique_ptr<T, Destroyer<T>>;
|
||||||
|
destroyed_ptr<nvinfer1::ICudaEngine> trt_engine_ptr_;
|
||||||
|
// TODO(samikama) context should go to a resource manager!
|
||||||
|
destroyed_ptr<nvinfer1::IExecutionContext> trt_context_ptr_;
|
||||||
|
std::vector<string> input_nodes_;
|
||||||
|
std::vector<string> output_nodes_;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace tensorrt
|
||||||
|
|
||||||
|
} // namespace tensorflow
|
||||||
|
|
||||||
|
#endif // TENSORFLOW_CONTRIB_TENSORRT_KERNELS_TRT_ENGINE_OP_H_
|
56
tensorflow/contrib/tensorrt/log/trt_logger.cc
Normal file
56
tensorflow/contrib/tensorrt/log/trt_logger.cc
Normal file
@ -0,0 +1,56 @@
|
|||||||
|
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
#include "tensorflow/contrib/tensorrt/log/trt_logger.h"
|
||||||
|
// Use TF logging for TensorRT information
|
||||||
|
#include "tensorflow/core/platform/logging.h"
|
||||||
|
|
||||||
|
#define _TF_LOG_DEBUG ::tensorflow::internal::LogMessage(__FILE__, __LINE__, -1)
|
||||||
|
//------------------------------------------------------------------------------
|
||||||
|
namespace tensorflow {
|
||||||
|
|
||||||
|
//------------------------------------------------------------------------------
|
||||||
|
namespace tensorrt {
|
||||||
|
|
||||||
|
// Routes a TensorRT log message to TensorFlow logging by severity:
// kINFO -> DEBUG, kWARNING -> WARNING, kERROR -> ERROR and
// kINTERNAL_ERROR -> FATAL. Any other severity value is also logged as FATAL,
// prefixed with `name_`.
void Logger::log(Severity severity, const char* msg) {
  if (severity == Severity::kINFO) {
    // mark TRT info messages as debug!
    LOG(DEBUG) << msg;
  } else if (severity == Severity::kWARNING) {
    LOG(WARNING) << msg;
  } else if (severity == Severity::kERROR) {
    LOG(ERROR) << msg;
  } else if (severity == Severity::kINTERNAL_ERROR) {
    LOG(FATAL) << msg;
  } else {
    // Unreachable with the severities handled above; kept as a safety net in
    // case the enum gains new values.
    LOG(FATAL) << name_ << "Got unknown severity level from TRT " << msg;
  }
}
|
||||||
|
|
||||||
|
} // namespace tensorrt
|
||||||
|
|
||||||
|
} // namespace tensorflow
|
41
tensorflow/contrib/tensorrt/log/trt_logger.h
Normal file
41
tensorflow/contrib/tensorrt/log/trt_logger.h
Normal file
@ -0,0 +1,41 @@
|
|||||||
|
// -*- c++ -*-
|
||||||
|
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#ifndef TENSORFLOW_CONTRIB_TENSORRT_LOG_TRT_LOGGER_H_
|
||||||
|
#define TENSORFLOW_CONTRIB_TENSORRT_LOG_TRT_LOGGER_H_
|
||||||
|
|
||||||
|
// Use TF logging for TensorRT messages.
|
||||||
|
#include <NvInfer.h>
|
||||||
|
#include <string>
|
||||||
|
|
||||||
|
//------------------------------------------------------------------------------
|
||||||
|
namespace tensorflow {
|
||||||
|
|
||||||
|
//------------------------------------------------------------------------------
|
||||||
|
namespace tensorrt {
|
||||||
|
|
||||||
|
// Logger for GIE info/warning/errors
|
||||||
|
// nvinfer1::ILogger implementation; log() is the override TensorRT invokes to
// deliver its messages.
// NOTE: log() sits in the class's default (private) section, so it is only
// reachable through the nvinfer1::ILogger interface.
class Logger : public nvinfer1::ILogger {
|
||||||
|
void log(nvinfer1::ILogger::Severity severity, const char* msg) override;
|
||||||
|
|
||||||
|
private:
|
||||||
|
// Prefix written in front of the "unknown severity" message by log().
std::string name_;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace tensorrt
|
||||||
|
|
||||||
|
} // namespace tensorflow
|
||||||
|
#endif // TENSORFLOW_CONTRIB_TENSORRT_LOG_TRT_LOGGER_H_
|
37
tensorflow/contrib/tensorrt/ops/trt_engine_op.cc
Normal file
37
tensorflow/contrib/tensorrt/ops/trt_engine_op.cc
Normal file
@ -0,0 +1,37 @@
|
|||||||
|
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/core/framework/op.h"
|
||||||
|
#include "tensorflow/core/framework/op_kernel.h"
|
||||||
|
#include "tensorflow/core/framework/shape_inference.h"
|
||||||
|
#include "tensorflow/core/framework/tensor_shape.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
|
||||||
|
namespace shape_inference {
|
||||||
|
extern Status TRTEngineOpShapeInference(InferenceContext* c);
|
||||||
|
}
|
||||||
|
|
||||||
|
REGISTER_OP("TRTEngineOp")
|
||||||
|
.Attr("serialized_engine: string")
|
||||||
|
.Attr("input_nodes: list(string)")
|
||||||
|
.Attr("output_nodes: list(string)")
|
||||||
|
.Attr("InT: list({int8, float16, float32})")
|
||||||
|
.Attr("OutT: list({int8, float16, float32})")
|
||||||
|
.Input("in_tensor: InT")
|
||||||
|
.Output("out_tensor: OutT")
|
||||||
|
.SetShapeFn(shape_inference::TRTEngineOpShapeInference);
|
||||||
|
|
||||||
|
} // namespace tensorflow
|
8
tensorflow/contrib/tensorrt/python/__init__.py
Normal file
8
tensorflow/contrib/tensorrt/python/__init__.py
Normal file
@ -0,0 +1,8 @@
|
|||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
# pylint: disable=unused-import,wildcard-import
|
||||||
|
from tensorflow.contrib.tensorrt.python.ops import trt_engine_op
|
||||||
|
from tensorflow.contrib.tensorrt.python.trt_convert import CreateInferenceGraph
|
||||||
|
# pylint: enable=unused-import,wildcard-import
|
35
tensorflow/contrib/tensorrt/python/ops/trt_engine_op.py
Normal file
35
tensorflow/contrib/tensorrt/python/ops/trt_engine_op.py
Normal file
@ -0,0 +1,35 @@
|
|||||||
|
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import platform
|
||||||
|
|
||||||
|
# Refuse to load on Windows; everything below only runs on other platforms.
if platform.system() == "Windows":
  raise RuntimeError("Windows platforms are not supported")

# pylint: disable=wildcard-import,unused-import,g-import-not-at-top
from tensorflow.contrib.tensorrt.ops.gen_trt_engine_op import *

from tensorflow.contrib.util import loader
from tensorflow.python.platform import resource_loader
# pylint: enable=wildcard-import,unused-import,g-import-not-at-top

# Load the shared library that ships next to this module as a data file.
_trt_engine_op = loader.load_op_library(
    resource_loader.get_path_to_datafile("_trt_engine_op.so"))
|
||||||
|
|
||||||
|
|
91
tensorflow/contrib/tensorrt/python/trt_convert.py
Normal file
91
tensorflow/contrib/tensorrt/python/trt_convert.py
Normal file
@ -0,0 +1,91 @@
|
|||||||
|
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# =============================================================================
|
||||||
|
"""Exposes the Python wrapper conversion to trt_graph."""
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
# pylint: disable=unused-import,wildcard-import, line-too-long
|
||||||
|
from tensorflow.core.framework import graph_pb2
|
||||||
|
from tensorflow.python.framework import errors
|
||||||
|
from tensorflow.python.framework import errors_impl as _impl
|
||||||
|
from tensorflow.contrib.tensorrt.wrap_conversion import trt_convert
|
||||||
|
from tensorflow.python.util import compat
|
||||||
|
import tensorflow as tf
|
||||||
|
from tensorflow.python.grappler import tf_optimizer
|
||||||
|
from tensorflow.core.protobuf import rewriter_config_pb2
|
||||||
|
from tensorflow.python.framework import meta_graph
|
||||||
|
from tensorflow.python.framework import ops
|
||||||
|
|
||||||
|
|
||||||
|
def CreateInferenceGraph(input_graph_def,
                         outputs,
                         max_batch_size=1,
                         max_workspace_size=2 << 20):
  """Python wrapper for the TRT transformation.

  The input graph is first run through the grappler 'layout' and 'constfold'
  optimizers, then handed to the native `trt_convert` routine.

  Args:
    input_graph_def: GraphDef object containing a model to be transformed.
    outputs: List of node names for the model outputs.
    max_batch_size: max size for the input batch
    max_workspace_size: parameter to control memory allocation (in Bytes)

  Returns:
    New GraphDef with TRTEngineOps placed in graph replacing subgraphs.

  Raises:
    UnknownError: if the status returned by the converter is shorter than two
      characters.
    RuntimeError: if a non-OK status contains no ";" separator.
    OpError: (specific subclass) built from a non-OK status of the form
      "<error code>;<message>".
  """
  g = tf.Graph()
  # The collection helpers below operate on the default graph, so everything
  # up to the optimizer call runs with `g` installed as default.
  with g.as_default():
    tf.import_graph_def(input_graph_def, name="")
    rewriter_config = rewriter_config_pb2.RewriterConfig()
    rewriter_config.optimizers.append('layout')
    rewriter_config.optimizers.append('constfold')

    # Mark every output tensor of every output node as a fetch. (Previously
    # only output 0 was appended, once per output, so additional outputs of a
    # multi-output node were never marked.)
    train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
    for node_name in outputs:
      out_node = g.get_operation_by_name(node_name)
      train_op.extend(out_node.outputs)

    # constant folding
    mg = meta_graph.create_meta_graph_def(graph=g)
    meta_graph.add_collection_def(mg, ops.GraphKeys.TRAIN_OP)
    optimized_graph_def_str = tf_optimizer.OptimizeGraph(
        rewriter_config, mg).SerializeToString()

  # TODO(sami): Fix this when we can return status from C++ library
  # There is a problem with the TF internal library setup that doesn't allow
  # us to return a status object from C++. Thus we return a pair of strings
  # where the first one is the encoded status and the second one is the
  # transformed graph's protobuf string.
  out = trt_convert(optimized_graph_def_str, outputs, max_batch_size,
                    max_workspace_size)
  status = out[0]
  output_graph_def_string = out[1]
  del optimized_graph_def_str  # save some memory

  if len(status) < 2:
    raise _impl.UnknownError(None, None, status)
  if status[:2] != "OK":
    # Non-OK statuses are encoded as "<error code>;<message>".
    msg = status.split(";")
    if len(msg) == 1:
      raise RuntimeError("Status message is malformed {}".format(status))
    # pylint: disable=protected-access
    raise _impl._make_specific_exception(None, None, ";".join(msg[1:]),
                                         int(msg[0]))
    # pylint: enable=protected-access

  output_graph_def = graph_pb2.GraphDef()
  output_graph_def.ParseFromString(output_graph_def_string)
  del output_graph_def_string  # save some memory
  return output_graph_def
|
259
tensorflow/contrib/tensorrt/segment/segment.cc
Normal file
259
tensorflow/contrib/tensorrt/segment/segment.cc
Normal file
@ -0,0 +1,259 @@
|
|||||||
|
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/contrib/tensorrt/segment/segment.h"
|
||||||
|
|
||||||
|
#include <set>
|
||||||
|
#include <string>
|
||||||
|
#include <unordered_map>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "tensorflow/contrib/tensorrt/segment/union_find.h"
|
||||||
|
#include "tensorflow/core/graph/algorithm.h"
|
||||||
|
#include "tensorflow/core/graph/graph.h"
|
||||||
|
#include "tensorflow/core/graph/graph_constructor.h"
|
||||||
|
#include "tensorflow/core/lib/core/errors.h"
|
||||||
|
#include "tensorflow/core/lib/core/status.h"
|
||||||
|
|
||||||
|
//------------------------------------------------------------------------------
|
||||||
|
namespace tensorrt {
|
||||||
|
namespace segment {
|
||||||
|
|
||||||
|
//------------------------------------------------------------------------------
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
//------------------------------------------------------------------------------
|
||||||
|
// Returns true if 'edge' can be contracted (its endpoints merged) without
// introducing a cycle into 'graph'.
//
// Contracting would create a cycle exactly when there is a directed path from
// 'src' to 'dst' other than 'edge' itself (or another direct src->dst edge).
// Rather than modifying the graph, an equivalent check is performed:
//   1. Collect all nodes feeding 'dst', excluding 'src'.
//   2. Run a reverse DFS from those nodes.
//   3. If the reverse DFS reaches 'src', contraction would form a cycle.
bool CanContractEdge(const tensorflow::Edge* edge,
                     const tensorflow::Graph& graph) {
  const tensorflow::Node* src = edge->src();
  const tensorflow::Node* dst = edge->dst();

  std::vector<tensorflow::Node*> other_inputs;
  for (tensorflow::Node* input : dst->in_nodes()) {
    if (input == src) continue;
    other_inputs.push_back(input);
  }

  // With no other inputs there is no alternative path into 'dst'.
  if (other_inputs.empty()) {
    return true;
  }

  bool reaches_src = false;
  tensorflow::ReverseDFSFrom(graph, other_inputs, {},
                             [&reaches_src, src](tensorflow::Node* visited) {
                               if (visited == src) {
                                 reaches_src = true;
                               }
                             });
  return !reaches_src;
}
|
||||||
|
|
||||||
|
//------------------------------------------------------------------------------
|
||||||
|
// Contracts 'edge' by re-attaching every input and output of its destination
// node to its source node (edges directly connecting the two are not
// transferred as inputs). The edges that must afterwards be removed to
// disconnect the destination node are appended to 'remove_edges'; the node
// itself is left in the graph because the caller holds references to all
// nodes.
//
// Slot index 0 is used for the re-attached data edges: accurate indices are
// not needed for the way the graph is used here. Edges from the graph's
// source node / to its sink node are re-attached on the control slot.
void ContractEdge(tensorflow::Edge* edge, tensorflow::Graph* graph,
                  std::vector<const tensorflow::Edge*>* remove_edges) {
  tensorflow::Node* const src = edge->src();
  tensorflow::Node* const dst = edge->dst();

  // Iterate over copies of the edge sets, since edges are added to the graph
  // inside the loops.
  const std::vector<const tensorflow::Edge*> dst_inputs(
      dst->in_edges().begin(), dst->in_edges().end());
  for (const tensorflow::Edge* in_edge : dst_inputs) {
    if (in_edge->src() == src) continue;
    const int input_slot = (in_edge->src() == graph->source_node())
                               ? tensorflow::Graph::kControlSlot
                               : 0 /* input index */;
    graph->AddEdge(in_edge->src(), in_edge->src_output(), src, input_slot);
  }

  const std::vector<const tensorflow::Edge*> dst_outputs(
      dst->out_edges().begin(), dst->out_edges().end());
  for (const tensorflow::Edge* out_edge : dst_outputs) {
    const int output_slot = (out_edge->dst() == graph->sink_node())
                                ? tensorflow::Graph::kControlSlot
                                : 0 /* output index */;
    graph->AddEdge(src, output_slot, out_edge->dst(), out_edge->dst_input());
  }

  // Report every edge still attached to 'dst': inputs first, then outputs.
  remove_edges->insert(remove_edges->end(), dst->in_edges().begin(),
                       dst->in_edges().end());
  remove_edges->insert(remove_edges->end(), dst->out_edges().begin(),
                       dst->out_edges().end());
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
//------------------------------------------------------------------------------
|
||||||
|
// Partitions the graph described by 'gdef' into segments of connected nodes
// accepted by 'candidate_fn', merging nodes by edge contraction as long as
// doing so does not create a cycle. Segments with fewer than
// 'options.minimum_segment_size' nodes are dropped. Each resulting segment is
// appended to 'segments' as a set of node names.
//
// Returns the error from graph construction if 'gdef' cannot be converted,
// otherwise OK.
tensorflow::Status SegmentGraph(
    const tensorflow::GraphDef& gdef,
    const std::function<bool(const tensorflow::NodeDef&)>& candidate_fn,
    const SegmentOptions& options, SegmentNodesVector* segments) {
  // Build a Graph from the GraphDef so graph algorithms can be applied.
  tensorflow::FunctionLibraryDefinition flib(tensorflow::OpRegistry::Global(),
                                             gdef.library());
  tensorflow::Graph graph(flib);
  TF_RETURN_IF_ERROR(tensorflow::ConvertGraphDefToGraph(
      tensorflow::GraphConstructorOptions(), gdef, &graph));

  // tensorflow::DumpGraph("Pre-Segment", &graph);

  // One union-find entry per node id, indexed by id. An entry holding
  // nullptr marks a node that is not a TRT candidate.
  std::vector<UnionFind<tensorflow::Node*>> node_segments;
  for (int id = 0; id < graph.num_node_ids(); ++id) {
    tensorflow::Node* candidate = graph.FindNodeId(id);
    if (!candidate_fn(candidate->def())) {
      candidate = nullptr;
    }
    node_segments.emplace_back(candidate);
  }

  // Walk the nodes in post order (reverse topological order), so that every
  // consumer of a node has been handled before the node itself, and merge
  // candidate nodes via edge contraction.
  std::vector<tensorflow::Node*> order;
  tensorflow::GetPostOrder(graph, &order);

  for (const tensorflow::Node* node : order) {
    VLOG(2) << "Trying node " << node->name();

    // Only TRT candidates can start a merge.
    if (node_segments[node->id()].Value() == nullptr) {
      VLOG(2) << "... not a TRT candidate";
      continue;
    }

    // Repeat until no output edge can be contracted any more: merging two
    // nodes may make further merges possible.
    for (;;) {
      std::set<const tensorflow::Edge*> contract_edges;
      for (const tensorflow::Edge* out_edge : node->out_edges()) {
        VLOG(2) << "... out node " << out_edge->dst()->name();

        // The consumer must be a TRT candidate as well.
        if (node_segments[out_edge->dst()->id()].Value() == nullptr) {
          VLOG(2) << "... ... not a TRT candidate";
          continue;
        }

        if (!CanContractEdge(out_edge, graph)) {
          VLOG(2) << "... ... cannot contract, would form cycle";
          continue;
        }
        VLOG(2) << "... ... can contract";
        contract_edges.insert(out_edge);
      }

      if (contract_edges.empty()) {
        break;
      }

      // Contract the collected edges one by one, putting both endpoints into
      // the same segment.
      while (!contract_edges.empty()) {
        const tensorflow::Edge* contract_edge = *contract_edges.begin();
        const tensorflow::Node* src = contract_edge->src();
        const tensorflow::Node* dst = contract_edge->dst();

        VLOG(2) << "Merge " << src->name() << " <- " << dst->name();
        node_segments[src->id()].Merge(&node_segments[dst->id()]);

        // Contraction leaves behind the edges that used to attach 'dst'.
        // Drop them from the graph and from 'contract_edges' so they are
        // not visited again.
        std::vector<const tensorflow::Edge*> remove_edges;
        ContractEdge(const_cast<tensorflow::Edge*>(contract_edge), &graph,
                     &remove_edges);

        for (const tensorflow::Edge* stale : remove_edges) {
          contract_edges.erase(stale);
          graph.RemoveEdge(stale);
        }
      }
    }
  }

  // Group node names by the name of their union-find parent: each group is
  // one segment/subgraph.
  std::unordered_map<std::string, std::set<std::string>> sg_map;
  for (auto& u : node_segments) {
    if ((u.Value() == nullptr) || (u.ParentValue() == nullptr)) {
      continue;
    }
    sg_map[u.ParentValue()->name()].insert(u.Value()->name());
  }

  // When verbose logging is on, strip disconnected nodes so a graph dump
  // (currently commented out) would be cleaner.
  if (VLOG_IS_ON(2)) {
    for (tensorflow::Node* node : graph.nodes()) {
      if ((node->in_edges().size() == 0) && (node->out_edges().size() == 0)) {
        graph.RemoveNode(node);
      }
    }
    // tensorflow::DumpGraph("Post-Segment", &graph);
  }

  // Emit the segments in the expected return format, skipping small ones.
  for (const auto& entry : sg_map) {
    const std::set<std::string>& segment_node_names = entry.second;
    if (VLOG_IS_ON(1)) {
      std::string s;
      for (const auto& name : segment_node_names) {
        s += " " + name;
      }
      VLOG(1) << "Segment " << segments->size() << ":" << s;
    }

    const int segment_size = static_cast<int>(segment_node_names.size());
    if (segment_size >= options.minimum_segment_size) {
      segments->emplace_back(segment_node_names);
    } else {
      VLOG(1) << "Segment " << segments->size() << " has only "
              << segment_node_names.size() << " nodes, dropping";
    }
  }

  return tensorflow::Status::OK();
}
|
||||||
|
|
||||||
|
} // namespace segment
|
||||||
|
} // namespace tensorrt
|
53
tensorflow/contrib/tensorrt/segment/segment.h
Normal file
53
tensorflow/contrib/tensorrt/segment/segment.h
Normal file
@ -0,0 +1,53 @@
|
|||||||
|
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#ifndef TENSORFLOW_CONTRIB_TENSORRT_SEGMENT_SEGMENT_H_
|
||||||
|
#define TENSORFLOW_CONTRIB_TENSORRT_SEGMENT_SEGMENT_H_
|
||||||
|
|
||||||
|
#include <functional>
#include <set>
#include <string>
#include <vector>
|
||||||
|
|
||||||
|
#include "tensorflow/core/framework/graph.pb.h"
|
||||||
|
#include "tensorflow/core/lib/core/status.h"
|
||||||
|
|
||||||
|
namespace tensorrt {
|
||||||
|
namespace segment {
|
||||||
|
|
||||||
|
using SegmentNodesVector = std::vector<std::set<std::string>>;
|
||||||
|
|
||||||
|
// Options accepted by SegmentGraph.
struct SegmentOptions {
  // Minimum number of nodes a segment must contain to be reported.
  int minimum_segment_size = 2;
};
|
||||||
|
|
||||||
|
// Get the subgraphs of a graph that can be handled by TensorRT.
//
// @param gdef The GraphDef describing the network
// @param candidate_fn A function that returns true for a NodeDef if
// that node can be handled by TensorRT.
// @param options Options controlling the segmentation; see SegmentOptions.
// @param segments Returns the TensorRT segments/subgraphs. Each entry
// in the vector describes a subgraph by giving a set of the names of
// all the NodeDefs in that subgraph.
// @return the status.
|
||||||
|
tensorflow::Status SegmentGraph(
|
||||||
|
const tensorflow::GraphDef& gdef,
|
||||||
|
const std::function<bool(const tensorflow::NodeDef&)>& candidate_fn,
|
||||||
|
const SegmentOptions& options, SegmentNodesVector* segments);
|
||||||
|
|
||||||
|
} // namespace segment
|
||||||
|
} // namespace tensorrt
|
||||||
|
|
||||||
|
#endif // TENSORFLOW_CONTRIB_TENSORRT_SEGMENT_SEGMENT_H_
|
363
tensorflow/contrib/tensorrt/segment/segment_test.cc
Normal file
363
tensorflow/contrib/tensorrt/segment/segment_test.cc
Normal file
@ -0,0 +1,363 @@
|
|||||||
|
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/contrib/tensorrt/segment/segment.h"
|
||||||
|
#include "tensorflow/c/c_api.h"
|
||||||
|
#include "tensorflow/core/framework/graph.pb.h"
|
||||||
|
#include "tensorflow/core/framework/node_def.pb.h"
|
||||||
|
#include "tensorflow/core/lib/core/errors.h"
|
||||||
|
#include "tensorflow/core/lib/core/status.h"
|
||||||
|
#include "tensorflow/core/platform/test.h"
|
||||||
|
|
||||||
|
//------------------------------------------------------------------------------
|
||||||
|
using namespace tensorflow;
|
||||||
|
|
||||||
|
namespace tensorrt {
|
||||||
|
namespace segment {
|
||||||
|
namespace test {
|
||||||
|
|
||||||
|
class SegmentTest : public ::testing::Test {
|
||||||
|
public:
|
||||||
|
bool GetGraphDef(TF_Graph* graph, tensorflow::GraphDef* graph_def);
|
||||||
|
|
||||||
|
TF_Operation* Placeholder(TF_Graph* graph, TF_Status* s, const char* name);
|
||||||
|
TF_Operation* Add(TF_Operation* l, TF_Operation* r, TF_Graph* graph,
|
||||||
|
TF_Status* s, const char* name);
|
||||||
|
|
||||||
|
std::function<bool(const NodeDef&)> MakeCandidateFn(
|
||||||
|
const std::set<std::string>& node_names);
|
||||||
|
|
||||||
|
protected:
|
||||||
|
void PlaceholderHelper(TF_Graph* graph, TF_Status* s, const char* name,
|
||||||
|
TF_Operation** op);
|
||||||
|
void AddHelper(TF_Operation* l, TF_Operation* r, TF_Graph* graph,
|
||||||
|
TF_Status* s, const char* name, TF_Operation** op, bool check);
|
||||||
|
|
||||||
|
SegmentOptions default_options_;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Converts 'graph' to its serialized GraphDef form and parses that into
// 'graph_def'. Records a test failure if serialization fails; returns true
// only if both serialization and parsing succeed.
bool SegmentTest::GetGraphDef(TF_Graph* graph,
                              tensorflow::GraphDef* graph_def) {
  TF_Status* status = TF_NewStatus();
  TF_Buffer* serialized = TF_NewBuffer();

  TF_GraphToGraphDef(graph, serialized, status);
  EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

  // Parsing is only attempted when serialization succeeded.
  const bool ok =
      (TF_GetCode(status) == TF_OK) &&
      graph_def->ParseFromArray(serialized->data, serialized->length);

  TF_DeleteBuffer(serialized);
  TF_DeleteStatus(status);
  return ok;
}
|
||||||
|
|
||||||
|
std::function<bool(const NodeDef&)> SegmentTest::MakeCandidateFn(
|
||||||
|
const std::set<std::string>& node_names) {
|
||||||
|
return [node_names](const NodeDef& node) -> bool {
|
||||||
|
return node_names.find(node.name()) != node_names.end();
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
// Creates an int32 "Placeholder" operation called 'name' in 'graph' and
// stores it in '*op'. Asserts that creation succeeded; '*op' is written
// before the assertions run.
void SegmentTest::PlaceholderHelper(TF_Graph* graph, TF_Status* s,
                                    const char* name, TF_Operation** op) {
  TF_OperationDescription* builder =
      TF_NewOperation(graph, "Placeholder", name);
  TF_SetAttrType(builder, "dtype", TF_INT32);
  *op = TF_FinishOperation(builder, s);

  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
  ASSERT_NE(*op, nullptr);
}
|
||||||
|
|
||||||
|
// Returns a new placeholder operation called 'name' added to 'graph'. The
// outcome of the creation is reported through 's'.
TF_Operation* SegmentTest::Placeholder(TF_Graph* graph, TF_Status* s,
                                       const char* name) {
  TF_Operation* placeholder;
  PlaceholderHelper(graph, s, name, &placeholder);
  return placeholder;
}
|
||||||
|
|
||||||
|
// Creates an "AddN" operation called 'name' in 'graph' that takes output 0
// of 'l' and output 0 of 'r' as inputs, storing the result in '*op'. When
// 'check' is set, asserts that the operation was created successfully.
void SegmentTest::AddHelper(TF_Operation* l, TF_Operation* r, TF_Graph* graph,
                            TF_Status* s, const char* name, TF_Operation** op,
                            bool check) {
  TF_OperationDescription* builder = TF_NewOperation(graph, "AddN", name);
  TF_Output operands[2] = {{l, 0}, {r, 0}};
  TF_AddInputList(builder, operands, 2);
  *op = TF_FinishOperation(builder, s);

  if (!check) {
    return;
  }
  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
  ASSERT_NE(*op, nullptr);
}
|
||||||
|
|
||||||
|
// Returns a new, checked "AddN" operation called 'name' over 'l' and 'r'
// added to 'graph'. The outcome of the creation is reported through 's'.
TF_Operation* SegmentTest::Add(TF_Operation* l, TF_Operation* r,
                               TF_Graph* graph, TF_Status* s,
                               const char* name) {
  TF_Operation* sum;
  AddHelper(l, r, graph, s, name, &sum, /*check=*/true);
  return sum;
}
|
||||||
|
|
||||||
|
//------------------------------------------------------------------------------
|
||||||
|
// An empty graph with no candidates must yield no segments.
TEST_F(SegmentTest, Empty) {
  TF_Graph* graph = TF_NewGraph();

  GraphDef graph_def;
  ASSERT_TRUE(GetGraphDef(graph, &graph_def));

  SegmentNodesVector segments;
  ASSERT_EQ(
      SegmentGraph(graph_def, MakeCandidateFn({}), default_options_, &segments),
      tensorflow::Status::OK());

  // Expect no segments/subgraphs.
  EXPECT_TRUE(segments.empty());

  // The graph was previously leaked; release it now that the GraphDef copy
  // has been consumed.
  TF_DeleteGraph(graph);
}
|
||||||
|
|
||||||
|
//------------------------------------------------------------------------------
|
||||||
|
// All Add nodes are candidates, so they must collapse into one segment.
TEST_F(SegmentTest, Simple) {
  TF_Status* s = TF_NewStatus();
  TF_Graph* graph = TF_NewGraph();

  //           feed
  //         //     ||
  //       add0    add1
  //        | |    /
  //        |  add2
  //        | /    ||
  //       add3    add4
  //        |     /
  //       <sink>
  //
  TF_Operation* feed = Placeholder(graph, s, "feed");
  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
  EXPECT_EQ(string("feed"), string(TF_OperationName(feed)));

  TF_Operation* add0 = Add(feed, feed, graph, s, "add0");
  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
  TF_Operation* add1 = Add(feed, feed, graph, s, "add1");
  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
  TF_Operation* add2 = Add(add0, add1, graph, s, "add2");
  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
  TF_Operation* add3 = Add(add0, add2, graph, s, "add3");
  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
  EXPECT_EQ(string("add3"), string(TF_OperationName(add3)));
  TF_Operation* add4 = Add(add2, add2, graph, s, "add4");
  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
  EXPECT_EQ(string("add4"), string(TF_OperationName(add4)));

  GraphDef graph_def;
  ASSERT_TRUE(GetGraphDef(graph, &graph_def));

  SegmentNodesVector segments;
  ASSERT_EQ(
      SegmentGraph(graph_def,
                   MakeCandidateFn({"add0", "add1", "add2", "add3", "add4"}),
                   default_options_, &segments),
      tensorflow::Status::OK());

  // Expect all Add operations to be collapsed into a single segment
  ASSERT_EQ(segments.size(), 1);
  std::vector<std::string> expected{"add0", "add1", "add2", "add3", "add4"};
  for (const auto& ex : expected) {
    EXPECT_TRUE(segments[0].find(ex) != segments[0].end())
        << "Missing expected node " << ex;
  }

  // These C API objects were previously leaked.
  TF_DeleteGraph(graph);
  TF_DeleteStatus(s);
}
|
||||||
|
|
||||||
|
//------------------------------------------------------------------------------
|
||||||
|
// Merging around a non-candidate node must not create a cycle.
TEST_F(SegmentTest, AvoidCycle) {
  TF_Status* s = TF_NewStatus();
  TF_Graph* graph = TF_NewGraph();

  // add2 is not a TRT candidate so add0/add3 cannot be formed as a
  // subgraph
  //
  //           feed
  //         //     ||
  //       add0    add1
  //        | |    /
  //        |  add2
  //        | /    ||
  //       add3    add4
  //        |     /
  //       <sink>
  //
  TF_Operation* feed = Placeholder(graph, s, "feed");
  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
  EXPECT_EQ(string("feed"), string(TF_OperationName(feed)));

  TF_Operation* add0 = Add(feed, feed, graph, s, "add0");
  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
  TF_Operation* add1 = Add(feed, feed, graph, s, "add1");
  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
  TF_Operation* add2 = Add(add0, add1, graph, s, "add2");
  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
  TF_Operation* add3 = Add(add0, add2, graph, s, "add3");
  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
  EXPECT_EQ(string("add3"), string(TF_OperationName(add3)));
  TF_Operation* add4 = Add(add2, add2, graph, s, "add4");
  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
  EXPECT_EQ(string("add4"), string(TF_OperationName(add4)));

  GraphDef graph_def;
  ASSERT_TRUE(GetGraphDef(graph, &graph_def));

  SegmentNodesVector segments;
  ASSERT_EQ(
      SegmentGraph(graph_def, MakeCandidateFn({"add0", "add1", "add3", "add4"}),
                   default_options_, &segments),
      tensorflow::Status::OK());

  // Expect no subgraphs
  EXPECT_EQ(segments.size(), 0);

  // These C API objects were previously leaked.
  TF_DeleteGraph(graph);
  TF_DeleteStatus(s);
}
|
||||||
|
|
||||||
|
//------------------------------------------------------------------------------
|
||||||
|
// Verifies that a single non-candidate node (add5) splits the candidate nodes
// into two separate TRT segments.
TEST_F(SegmentTest, Multiple) {
  TF_Status* s = TF_NewStatus();
  TF_Graph* graph = TF_NewGraph();

  // add5 is not a TRT candidate so two subgraphs should be formed
  //
  //                 feed
  //           //     ||      \\
  //        add0     add1     add7
  //        |   \    /       /   ||
  //        |    add2 ---- add5  add8
  //        |   /   \     /   \   |
  //        add3     add4      add6
  //          |        |        /
  //               <sink>
  //
  // Edges (inputs -> node), as built below:
  //   feed,feed -> add0, add1, add7
  //   add0,add1 -> add2      add2,add7 -> add5      add7,add7 -> add8
  //   add0,add2 -> add3      add2,add5 -> add4      add5,add8 -> add6
  TF_Operation* feed = Placeholder(graph, s, "feed");
  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
  EXPECT_EQ(string("feed"), string(TF_OperationName(feed)));

  TF_Operation* add0 = Add(feed, feed, graph, s, "add0");
  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
  TF_Operation* add1 = Add(feed, feed, graph, s, "add1");
  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
  TF_Operation* add7 = Add(feed, feed, graph, s, "add7");
  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
  TF_Operation* add2 = Add(add0, add1, graph, s, "add2");
  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
  TF_Operation* add5 = Add(add2, add7, graph, s, "add5");
  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
  TF_Operation* add8 = Add(add7, add7, graph, s, "add8");
  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
  TF_Operation* add3 = Add(add0, add2, graph, s, "add3");
  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
  EXPECT_EQ(string("add3"), string(TF_OperationName(add3)));
  TF_Operation* add4 = Add(add2, add5, graph, s, "add4");
  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
  EXPECT_EQ(string("add4"), string(TF_OperationName(add4)));
  TF_Operation* add6 = Add(add5, add8, graph, s, "add6");
  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
  EXPECT_EQ(string("add6"), string(TF_OperationName(add6)));

  GraphDef graph_def;
  const bool got_graph_def = GetGraphDef(graph, &graph_def);
  // The C API objects are not used past this point; release them before any
  // further fatal assertion can return early and leak them.
  TF_DeleteGraph(graph);
  TF_DeleteStatus(s);
  ASSERT_TRUE(got_graph_def);

  SegmentNodesVector segments;
  ASSERT_EQ(SegmentGraph(graph_def,
                         MakeCandidateFn({"add0", "add1", "add2", "add3",
                                          "add4", "add6", "add7", "add8"}),
                         default_options_, &segments),
            tensorflow::Status::OK());

  // Expect two subgraphs. This must be fatal: segments[0] and segments[1] are
  // indexed below and would be out of bounds otherwise.
  ASSERT_EQ(segments.size(), 2u);

  std::vector<std::string> expected0{"add0", "add1", "add2", "add3"};
  for (const auto& ex : expected0) {
    EXPECT_TRUE(segments[0].find(ex) != segments[0].end())
        << "Missing expected node " << ex;
  }

  std::vector<std::string> expected1{"add6", "add8"};
  for (const auto& ex : expected1) {
    EXPECT_TRUE(segments[1].find(ex) != segments[1].end())
        << "Missing expected node " << ex;
  }
}
|
||||||
|
|
||||||
|
//------------------------------------------------------------------------------
|
||||||
|
// Verifies segmentation of a diamond-shaped graph in which one branch contains
// a non-candidate node (add2).
TEST_F(SegmentTest, BigIfElse) {
  TF_Status* s = TF_NewStatus();
  TF_Graph* graph = TF_NewGraph();

  // add2 is not a TRT candidate
  //
  //        feed
  //         ||
  //        add0
  //      //    \\
  //    add1    add4
  //     ||      ||
  //    add2    add5
  //     ||      ||
  //    add3    add6
  //      \\    //
  //        add7
  //         ||
  //       <sink>
  //
  TF_Operation* feed = Placeholder(graph, s, "feed");
  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
  EXPECT_EQ(string("feed"), string(TF_OperationName(feed)));

  TF_Operation* add0 = Add(feed, feed, graph, s, "add0");
  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
  TF_Operation* add1 = Add(add0, add0, graph, s, "add1");
  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
  TF_Operation* add2 = Add(add1, add1, graph, s, "add2");
  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
  TF_Operation* add3 = Add(add2, add2, graph, s, "add3");
  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
  TF_Operation* add4 = Add(add0, add0, graph, s, "add4");
  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
  TF_Operation* add5 = Add(add4, add4, graph, s, "add5");
  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
  TF_Operation* add6 = Add(add5, add5, graph, s, "add6");
  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
  TF_Operation* add7 = Add(add3, add6, graph, s, "add7");
  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
  EXPECT_EQ(string("add7"), string(TF_OperationName(add7)));

  GraphDef graph_def;
  const bool got_graph_def = GetGraphDef(graph, &graph_def);
  // The C API objects are not used past this point; release them before any
  // further fatal assertion can return early and leak them.
  TF_DeleteGraph(graph);
  TF_DeleteStatus(s);
  ASSERT_TRUE(got_graph_def);

  SegmentNodesVector segments;
  ASSERT_EQ(SegmentGraph(graph_def,
                         MakeCandidateFn({"add0", "add1", "add3", "add4",
                                          "add5", "add6", "add7"}),
                         default_options_, &segments),
            tensorflow::Status::OK());

  // Expect 2 subgraphs. This must be fatal: segments[0] and segments[1] are
  // indexed below and would be out of bounds otherwise.
  ASSERT_EQ(segments.size(), 2u);

  std::vector<std::string> expected0{"add3", "add4", "add5", "add6", "add7"};
  for (const auto& ex : expected0) {
    EXPECT_TRUE(segments[0].find(ex) != segments[0].end())
        << "Missing expected node " << ex;
  }

  std::vector<std::string> expected1{"add0", "add1"};
  for (const auto& ex : expected1) {
    EXPECT_TRUE(segments[1].find(ex) != segments[1].end())
        << "Missing expected node " << ex;
  }
}
|
||||||
|
|
||||||
|
} // namespace test
|
||||||
|
} // namespace segment
|
||||||
|
} // namespace tensorrt
|
77
tensorflow/contrib/tensorrt/segment/union_find.h
Normal file
77
tensorflow/contrib/tensorrt/segment/union_find.h
Normal file
@ -0,0 +1,77 @@
|
|||||||
|
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#ifndef TENSORFLOW_CONTRIB_TENSORRT_SEGMENT_UNION_FIND_H_
|
||||||
|
#define TENSORFLOW_CONTRIB_TENSORRT_SEGMENT_UNION_FIND_H_
|
||||||
|
|
||||||
|
namespace tensorrt {
|
||||||
|
namespace segment {
|
||||||
|
|
||||||
|
// Union-Find data structure.
|
||||||
|
// Each cluster has an associated value; when merging clusters we can control
|
||||||
|
// which value becomes the representative of the merged clusters. Values must be
|
||||||
|
// copyable.
|
||||||
|
template <typename T>
|
||||||
|
class UnionFind {
|
||||||
|
public:
|
||||||
|
UnionFind() : size_(1), parent_(nullptr) {}
|
||||||
|
explicit UnionFind(const T& v) : size_(1), parent_(nullptr), value_(v) {}
|
||||||
|
|
||||||
|
// Returns the number of elements in a cluster.
|
||||||
|
int Size() { return FindRoot()->size_; }
|
||||||
|
|
||||||
|
// Merges this cluster with 'other'. This cluster's value becomes
|
||||||
|
// the value of the merged cluster; the value of 'other' is ignored.
|
||||||
|
void Merge(UnionFind* other);
|
||||||
|
|
||||||
|
// Each cluster has an associated value. Retrieves the value associated
|
||||||
|
// with this cluster.
|
||||||
|
T& ParentValue() { return FindRoot()->value_; }
|
||||||
|
|
||||||
|
// Get the original value of this node.
|
||||||
|
T& Value() { return value_; }
|
||||||
|
|
||||||
|
private:
|
||||||
|
// Finds the root element of the cluster. Performs path compression.
|
||||||
|
UnionFind* FindRoot();
|
||||||
|
|
||||||
|
int size_;
|
||||||
|
UnionFind* parent_;
|
||||||
|
T value_;
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
void UnionFind<T>::Merge(UnionFind* other) {
|
||||||
|
UnionFind<T>* a = FindRoot();
|
||||||
|
UnionFind<T>* b = other->FindRoot();
|
||||||
|
if (a == b) return;
|
||||||
|
|
||||||
|
b->parent_ = a;
|
||||||
|
a->size_ += b->size_;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
UnionFind<T>* UnionFind<T>::FindRoot() {
|
||||||
|
if (!parent_) return this;
|
||||||
|
// Path compression: update intermediate nodes to point to the root of the
|
||||||
|
// equivalence class.
|
||||||
|
parent_ = parent_->FindRoot();
|
||||||
|
return parent_;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace segment
|
||||||
|
} // namespace tensorrt
|
||||||
|
|
||||||
|
#endif // TENSORFLOW_CONTRIB_TENSORRT_SEGMENT_UNION_FIND_H_
|
123
tensorflow/contrib/tensorrt/shape_fn/trt_shfn.cc
Normal file
123
tensorflow/contrib/tensorrt/shape_fn/trt_shfn.cc
Normal file
@ -0,0 +1,123 @@
|
|||||||
|
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/contrib/tensorrt/shape_fn/trt_shfn.h"
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
#include "NvInfer.h"
|
||||||
|
#include "tensorflow/contrib/tensorrt/log/trt_logger.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
namespace shape_inference {
|
||||||
|
// Shape inference function for the TRT engine op.
//
// Deserializes the TensorRT engine stored in the "serialized_engine" attr,
// logs the engine bindings and the op inputs, and sets output i (for every
// entry of the "output_nodes" attr) to the shape
//   [batch, <dimensions of the engine binding named output_nodes[i]>...]
// where batch is dimension 0 of the first input (unknown when it cannot be
// read). If the engine has no binding with that name the output shape is
// just [batch].
//
// Returns a non-OK status when one of the attrs cannot be read or when the
// TensorRT runtime/engine cannot be created.
tensorflow::Status TRTEngineOpShapeInference(InferenceContext* c) {
  // Read every attr up front so that no TensorRT object needs to be released
  // on the attr-error paths.
  string serialized_engine;
  tensorflow::Status status =
      c->GetAttr("serialized_engine", &serialized_engine);
  if (!status.ok()) return status;
  std::vector<::tensorflow::DataType> input_type;
  status = c->GetAttr("InT", &input_type);
  if (!status.ok()) return status;
  std::vector<string> input_nodes;
  status = c->GetAttr("input_nodes", &input_nodes);
  if (!status.ok()) return status;
  std::vector<string> output_nodes;
  status = c->GetAttr("output_nodes", &output_nodes);
  if (!status.ok()) return status;

  tensorflow::tensorrt::Logger gLogger;
  nvinfer1::IRuntime* infer = nvinfer1::createInferRuntime(gLogger);
  if (infer == nullptr) {
    return tensorflow::Status(tensorflow::error::INTERNAL,
                              "Failed to create TensorRT inference runtime");
  }
  nvinfer1::ICudaEngine* trt_engine = infer->deserializeCudaEngine(
      serialized_engine.c_str(), serialized_engine.size(), nullptr);
  if (trt_engine == nullptr) {
    infer->destroy();
    return tensorflow::Status(
        tensorflow::error::INVALID_ARGUMENT,
        "Failed to deserialize TensorRT engine from serialized_engine attr");
  }

  // debug print out engine binding;
  std::stringstream oss;
  for (int i = 0; i < trt_engine->getNbBindings(); i++) {
    LOG(INFO) << "index: " << i
              << ", binding name: " << trt_engine->getBindingName(i);

    bool input_flag = trt_engine->bindingIsInput(i);
    oss.str("");
    oss << "input?: " << (input_flag ? "Y" : "N");

    oss << "Dimension: ";
    auto dims = trt_engine->getBindingDimensions(i);
    oss << " nbDims: " << dims.nbDims << " -> ";
    for (int j = 0; j < dims.nbDims; j++) oss << dims.d[j] << ", ";
    LOG(INFO) << oss.str();
    switch (trt_engine->getBindingDataType(i)) {
      case nvinfer1::DataType::kFLOAT:
        LOG(INFO) << "data type: float";
        break;
      case nvinfer1::DataType::kHALF:
        LOG(INFO) << "data type: half";
        break;
      case nvinfer1::DataType::kINT8:
        LOG(INFO) << "data type: int8";
        break;
      default:
        // Data types added by newer TensorRT releases are not named here.
        LOG(INFO) << "data type: other";
        break;
    }
  }

  // Batch size, taken from dimension 0 of the first input. Stays -1 (which
  // MakeDim turns into an unknown dimension) when it cannot be read.
  int nbBatch = -1;
  // debug print out input arrays
  for (int i = 0; i < c->num_inputs(); i++) {
    // check if input shape is legit
    auto input_shape = c->input(i);
    // Start each input on a fresh buffer so log lines do not accumulate the
    // text of the previous inputs.
    oss.str("");
    oss << "input:" << i << " type: ";
    if (static_cast<size_t>(i) < input_type.size()) {
      oss << input_type[i];
    } else {
      oss << "?";  // "InT" has fewer entries than the op has inputs.
    }
    oss << " shape: ";
    for (int j = 0; j < c->Rank(input_shape); j++) {
      auto dimHandler = c->Dim(input_shape, j);
      if (c->ValueKnown(dimHandler))
        oss << c->Value(dimHandler) << ", ";
      else
        oss << "?" << c->Value(dimHandler) << ", ";
      if (j == 0) {
        if (i == 0)
          nbBatch = c->Value(dimHandler);
        else if (nbBatch != c->Value(dimHandler))
          LOG(WARNING) << "!!!!!!nbBatch does not match!!!!!!";
      }
    }
    LOG(INFO) << oss.str();
  }

  // arrange input here
  for (size_t i = 0; i < input_nodes.size(); i++) {
    LOG(INFO) << "input:" << i << " name: " << input_nodes[i];
  }

  // arrange output here
  for (size_t i = 0; i < output_nodes.size(); i++) {
    int binding_index = trt_engine->getBindingIndex(output_nodes[i].c_str());
    oss.str("");
    oss << "string name " << output_nodes[i];
    std::vector<DimensionHandle> vecDim;
    // TensorRT binding dimensions are appended after the batch dimension.
    vecDim.emplace_back(c->MakeDim(nbBatch));
    if (binding_index != -1) {
      oss << "got binding " << binding_index;
      auto dims = trt_engine->getBindingDimensions(binding_index);
      for (int j = 0; j < dims.nbDims; j++)
        vecDim.emplace_back(c->MakeDim(dims.d[j]));
    } else {
      oss << "no binding ";
    }
    ShapeHandle output_shape = c->MakeShape(vecDim);
    c->set_output(i, output_shape);
    LOG(INFO) << oss.str();
  }

  // Release the TensorRT objects; they were only needed to query bindings.
  trt_engine->destroy();
  infer->destroy();
  return Status::OK();
}
|
||||||
|
} // namespace shape_inference
|
||||||
|
} // namespace tensorflow
|
28
tensorflow/contrib/tensorrt/shape_fn/trt_shfn.h
Normal file
28
tensorflow/contrib/tensorrt/shape_fn/trt_shfn.h
Normal file
@ -0,0 +1,28 @@
|
|||||||
|
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#ifndef TENSORFLOW_CONTRIB_TENSORRT_SHAPE_FN_TRT_SHFN_H_
|
||||||
|
#define TENSORFLOW_CONTRIB_TENSORRT_SHAPE_FN_TRT_SHFN_H_
|
||||||
|
|
||||||
|
#include "tensorflow/core/framework/shape_inference.h"
|
||||||
|
#include "tensorflow/core/lib/core/status.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
namespace shape_inference {
|
||||||
|
// Shape inference function for the TRT engine op. Deserializes the TensorRT
// engine held in the op's "serialized_engine" attr and sets each output named
// in the "output_nodes" attr to [batch, <engine binding dims>...], where batch
// is dimension 0 of the first input.
Status TRTEngineOpShapeInference(InferenceContext* c);
|
||||||
|
} // namespace shape_inference
|
||||||
|
} // namespace tensorflow
|
||||||
|
|
||||||
|
#endif // TENSORFLOW_CONTRIB_TENSORRT_SHAPE_FN_TRT_SHFN_H_
|
84
tensorflow/contrib/tensorrt/trt_conversion.i
Normal file
84
tensorflow/contrib/tensorrt/trt_conversion.i
Normal file
@ -0,0 +1,84 @@
|
|||||||
|
/*
|
||||||
|
|
||||||
|
wrap trt_conversion
|
||||||
|
|
||||||
|
*/
|
||||||
|
%{
|
||||||
|
#define SWIG_FILE_WITH_INIT
|
||||||
|
%}
|
||||||
|
%include "std_string.i"
|
||||||
|
%include "std_pair.i"
|
||||||
|
%include "tensorflow/python/lib/core/strings.i"
|
||||||
|
%include "tensorflow/python/platform/base.i"
|
||||||
|
%template(StringPair) std::pair<string,string>;
|
||||||
|
%template() std::pair<swig::SwigPtr_PyObject, swig::SwigPtr_PyObject>;
|
||||||
|
|
||||||
|
%{
|
||||||
|
#include "tensorflow/core/lib/core/errors.h"
|
||||||
|
#include "tensorflow/core/lib/core/status.h"
|
||||||
|
#include "tensorflow/core/util/stat_summarizer.h"
|
||||||
|
#include "tensorflow/contrib/tensorrt/convert/convert_graph.h"
|
||||||
|
%}
|
||||||
|
|
||||||
|
%ignoreall
|
||||||
|
%unignore tensorflow;
|
||||||
|
%unignore trt_convert;
|
||||||
|
|
||||||
|
%{
|
||||||
|
// Converts a serialized GraphDef into a serialized GraphDef in which
// supported subgraphs are replaced by TensorRT engine ops.
//
// Returns a pair {status, serialized_graph}. status is "<code>;<message>";
// on success it is "OK;All good!" and serialized_graph holds the converted
// graph, otherwise serialized_graph is empty.
std::pair<string,string> trt_convert(string graph_def_string,//const tensorflow::GraphDef&
                                     std::vector<string> output_names,
                                     size_t max_batch_size,
                                     size_t max_workspace_size
    // unfortunately we can't use TF_Status here since it
    // is in c/c_api and brings in a lot of other libraries
    // which in turn declare ops. These ops are included
    // statically in our library and cause an abort when
    // module is loaded due to double registration
    // until Tensorflow properly exposes these headers
    // we have to work around this by returning a string
    // and converting it to exception on python side.
    //,TF_Status* out_status) {
                                     ) {
  string out_status;

  tensorflow::GraphDef graph_def;
  if (!graph_def.ParseFromString(graph_def_string)) {
    out_status="InvalidArgument;Couldn't interpret input as a GraphDef";
    return std::pair<string,string>{out_status,""};
  }

  if (output_names.empty()) {
    out_status="InvalidArgument;Size of the output_names vector is 0";
    return std::pair<string,string>{out_status,""};
  }
  tensorflow::GraphDef outGraph;
  tensorflow::Status conversion_status =
      tensorrt::convert::ConvertGraphDefToTensorRT(graph_def,
                                                   output_names,
                                                   max_batch_size,
                                                   max_workspace_size,
                                                   &outGraph);
  if (!conversion_status.ok()) {
    auto retCode=(int)conversion_status.code();
    // Build "<numeric code>;<message>" with string concatenation so that
    // long error messages are not truncated by a fixed-size buffer.
    out_status=std::to_string(retCode)+";"+conversion_status.error_message();
    return std::pair<string,string>{out_status,""};
  }
  string result;
  if (!outGraph.SerializeToString(&result)) {
    out_status="InvalidArgument;Couldn't serialize output as a GraphDef";
    return std::pair<string,string>{out_status,""};
  }
  out_status="OK;All good!";
  return std::pair<string,string>{out_status,result};
}
|
||||||
|
%}
|
||||||
|
|
||||||
|
// SWIG-visible declaration of the wrapper defined in the %{ %} block above.
// Returns {status, serialized converted GraphDef}.
std::pair<string,string> trt_convert(
    string graph_def_string,
    std::vector<string> output_names,
    size_t max_batch_size,
    size_t max_workspace_size);
|
||||||
|
|
||||||
|
%unignoreall
|
@ -279,7 +279,7 @@ def tf_cc_shared_object(
|
|||||||
linkopts=[],
|
linkopts=[],
|
||||||
framework_so=tf_binary_additional_srcs(),
|
framework_so=tf_binary_additional_srcs(),
|
||||||
**kwargs):
|
**kwargs):
|
||||||
native.cc_binary(
|
native.cc_binary(
|
||||||
name=name,
|
name=name,
|
||||||
srcs=srcs + framework_so,
|
srcs=srcs + framework_so,
|
||||||
deps=deps,
|
deps=deps,
|
||||||
@ -1281,6 +1281,45 @@ def tf_extension_linkopts():
|
|||||||
def tf_extension_copts():
|
def tf_extension_copts():
|
||||||
return [] # No extension c opts
|
return [] # No extension c opts
|
||||||
|
|
||||||
|
# In libraries generated by tf_py_wrap_cc, module init functions are not
# exported unless their names match one of the patterns in the version
# script, which prevents building custom Python modules.
# This rule appends init_<module_name> to the list of exported symbols in
# the version script (or exported-symbols list).
|
||||||
|
def _append_init_to_versionscript_impl(ctx):
  """Expands the template file with the module's init symbol added.

  When `is_version_script` is true the template is treated as an ld version
  script and `init_<module_name>;` is inserted after the `global:` label.
  Otherwise `init_<module_name>` is inserted on a new line after the
  `*tensorflow*` entry. The result is written to the `versionscript` output.
  """
  module = ctx.attr.module_name
  if ctx.attr.is_version_script:
    replacements = {"global:": "global:\n init_%s;" % module}
  else:
    replacements = {"*tensorflow*": "*tensorflow*\ninit_%s" % module}
  ctx.actions.expand_template(
      template=ctx.file.template_file,
      output=ctx.outputs.versionscript,
      substitutions=replacements,
      is_executable=False,
  )
|
||||||
|
|
||||||
|
|
||||||
|
# Rule that writes "<name>.lds": a copy of `template_file` with the init
# symbol of `module_name` added to the exported symbols.
_append_init_to_versionscript = rule(
    implementation=_append_init_to_versionscript_impl,
    attrs={
        # Python module whose init_<module_name> symbol must be exported.
        "module_name": attr.string(mandatory=True),
        # Linker script used as the substitution template.
        "template_file": attr.label(
            allow_files=True,
            single_file=True,
            mandatory=True,
        ),
        "is_version_script": attr.bool(
            default=True,
            doc='whether target is a ld version script or exported symbol list',
            mandatory=False,
        ),
    },
    outputs={"versionscript": "%{name}.lds"},
)
|
||||||
|
|
||||||
def tf_py_wrap_cc(name,
|
def tf_py_wrap_cc(name,
|
||||||
srcs,
|
srcs,
|
||||||
swig_includes=[],
|
swig_includes=[],
|
||||||
@ -1302,26 +1341,39 @@ def tf_py_wrap_cc(name,
|
|||||||
toolchain_deps=["//tools/defaults:crosstool"],
|
toolchain_deps=["//tools/defaults:crosstool"],
|
||||||
module_name=module_name,
|
module_name=module_name,
|
||||||
py_module_name=name)
|
py_module_name=name)
|
||||||
|
vscriptname=name+"_versionscript"
|
||||||
|
_append_init_to_versionscript(
|
||||||
|
name=vscriptname,
|
||||||
|
module_name=module_name,
|
||||||
|
is_version_script=select({
|
||||||
|
"@local_config_cuda//cuda:darwin":False,
|
||||||
|
"//conditions:default":True,
|
||||||
|
}),
|
||||||
|
template_file=select({
|
||||||
|
"@local_config_cuda//cuda:darwin":clean_dep("//tensorflow:tf_exported_symbols.lds"),
|
||||||
|
"//conditions:default":clean_dep("//tensorflow:tf_version_script.lds")
|
||||||
|
})
|
||||||
|
)
|
||||||
extra_linkopts = select({
|
extra_linkopts = select({
|
||||||
"@local_config_cuda//cuda:darwin": [
|
"@local_config_cuda//cuda:darwin": [
|
||||||
"-Wl,-exported_symbols_list",
|
"-Wl,-exported_symbols_list",
|
||||||
clean_dep("//tensorflow:tf_exported_symbols.lds")
|
"%s.lds"%vscriptname,
|
||||||
],
|
],
|
||||||
clean_dep("//tensorflow:windows"): [],
|
clean_dep("//tensorflow:windows"): [],
|
||||||
clean_dep("//tensorflow:windows_msvc"): [],
|
clean_dep("//tensorflow:windows_msvc"): [],
|
||||||
"//conditions:default": [
|
"//conditions:default": [
|
||||||
"-Wl,--version-script",
|
"-Wl,--version-script",
|
||||||
clean_dep("//tensorflow:tf_version_script.lds")
|
"%s.lds"%vscriptname,
|
||||||
]
|
]
|
||||||
})
|
})
|
||||||
extra_deps += select({
|
extra_deps += select({
|
||||||
"@local_config_cuda//cuda:darwin": [
|
"@local_config_cuda//cuda:darwin": [
|
||||||
clean_dep("//tensorflow:tf_exported_symbols.lds")
|
"%s.lds"%vscriptname,
|
||||||
],
|
],
|
||||||
clean_dep("//tensorflow:windows"): [],
|
clean_dep("//tensorflow:windows"): [],
|
||||||
clean_dep("//tensorflow:windows_msvc"): [],
|
clean_dep("//tensorflow:windows_msvc"): [],
|
||||||
"//conditions:default": [
|
"//conditions:default": [
|
||||||
clean_dep("//tensorflow:tf_version_script.lds")
|
"%s.lds"%vscriptname,
|
||||||
]
|
]
|
||||||
})
|
})
|
||||||
|
|
||||||
|
@ -11,6 +11,7 @@ load(
|
|||||||
)
|
)
|
||||||
load("//third_party/mkl:build_defs.bzl", "if_mkl")
|
load("//third_party/mkl:build_defs.bzl", "if_mkl")
|
||||||
load("//tensorflow:tensorflow.bzl", "if_cuda")
|
load("//tensorflow:tensorflow.bzl", "if_cuda")
|
||||||
|
load("@local_config_tensorrt//:build_defs.bzl", "if_trt")
|
||||||
load("//tensorflow/core:platform/default/build_config_root.bzl", "tf_additional_license_deps")
|
load("//tensorflow/core:platform/default/build_config_root.bzl", "tf_additional_license_deps")
|
||||||
|
|
||||||
# This returns a list of headers of all public header libraries (e.g.,
|
# This returns a list of headers of all public header libraries (e.g.,
|
||||||
@ -201,7 +202,8 @@ sh_binary(
|
|||||||
"//tensorflow/python:test_ops",
|
"//tensorflow/python:test_ops",
|
||||||
"//tensorflow/tools/dist_test/server:grpc_tensorflow_server",
|
"//tensorflow/tools/dist_test/server:grpc_tensorflow_server",
|
||||||
],
|
],
|
||||||
}) + if_mkl(["//third_party/mkl:intel_binary_blob"]),
|
}) + if_mkl(["//third_party/mkl:intel_binary_blob"])
|
||||||
|
+ if_trt(["//tensorflow/contrib/tensorrt:init_py"]),
|
||||||
)
|
)
|
||||||
|
|
||||||
# A genrule for generating a marker file for the pip package on Windows
|
# A genrule for generating a marker file for the pip package on Windows
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
# TensorFlow external dependencies that can be loaded in WORKSPACE files.
|
# TensorFlow external dependencies that can be loaded in WORKSPACE files.
|
||||||
|
|
||||||
load("//third_party/gpus:cuda_configure.bzl", "cuda_configure")
|
load("//third_party/gpus:cuda_configure.bzl", "cuda_configure")
|
||||||
|
load("//third_party/tensorrt:build_defs.bzl", "trt_repository")
|
||||||
load("//third_party/mkl:build_defs.bzl", "mkl_repository")
|
load("//third_party/mkl:build_defs.bzl", "mkl_repository")
|
||||||
load("//third_party/git:git_configure.bzl", "git_configure")
|
load("//third_party/git:git_configure.bzl", "git_configure")
|
||||||
load("//third_party/py:python_configure.bzl", "python_configure")
|
load("//third_party/py:python_configure.bzl", "python_configure")
|
||||||
@ -66,6 +67,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""):
|
|||||||
# version we require here.
|
# version we require here.
|
||||||
check_bazel_version_at_least("0.5.4")
|
check_bazel_version_at_least("0.5.4")
|
||||||
cuda_configure(name="local_config_cuda")
|
cuda_configure(name="local_config_cuda")
|
||||||
|
trt_repository(name="local_config_tensorrt")
|
||||||
git_configure(name="local_config_git")
|
git_configure(name="local_config_git")
|
||||||
sycl_configure(name="local_config_sycl")
|
sycl_configure(name="local_config_sycl")
|
||||||
python_configure(name="local_config_python")
|
python_configure(name="local_config_python")
|
||||||
|
0
third_party/tensorrt/BUILD
vendored
Normal file
0
third_party/tensorrt/BUILD
vendored
Normal file
42
third_party/tensorrt/BUILD.tpl
vendored
Normal file
42
third_party/tensorrt/BUILD.tpl
vendored
Normal file
@ -0,0 +1,42 @@
|
|||||||
|
# -*- python -*-
|
||||||
|
# Description:
|
||||||
|
# provide tensorrt information
|
||||||
|
|
||||||
|
# TODO(Sami): these need to be defined
|
||||||
|
|
||||||
|
licenses(["notice"])
|
||||||
|
|
||||||
|
exports_files(["LICENSE"])
|
||||||
|
|
||||||
|
load("@local_config_cuda//cuda:build_defs.bzl", "cuda_default_copts", "if_cuda")
|
||||||
|
|
||||||
|
# Matches builds invoked with --define using_tensorrt=true.
config_setting(
    name = "trt_enabled",
    define_values = {
        "using_tensorrt": "true",
    },
    visibility = ["//visibility:public"],
)
|
||||||
|
|
||||||
|
# TensorRT headers and libraries. %{tensorrt_lib} is a placeholder filled in
# when this template is expanded.
cc_library(
    name = "tensorrt",
    srcs = [%{tensorrt_lib}],
    hdrs = [
        "include/NvInfer.h",
        "include/NvUtils.h",
    ],
    copts = cuda_default_copts(),
    includes = ["include/"],
    linkstatic = 1,
    visibility = ["//visibility:public"],
    deps = [
        "@local_config_cuda//cuda:cuda",
        "@local_config_cuda//cuda:cudnn",
    ],
)
|
||||||
|
|
||||||
|
%{tensorrt_genrules}
|
||||||
|
|
||||||
|
# filegroup(
|
||||||
|
# name = "%{tensorrt_lib}",
|
||||||
|
# srcs = ["%{tensorrt_lib}"],
|
||||||
|
# visibility = ["//visibility:public"],
|
||||||
|
# )
|
203
third_party/tensorrt/LICENSE
vendored
Normal file
203
third_party/tensorrt/LICENSE
vendored
Normal file
@ -0,0 +1,203 @@
|
|||||||
|
Copyright 2015 The TensorFlow Authors. All rights reserved.
|
||||||
|
|
||||||
|
Apache License
|
||||||
|
Version 2.0, January 2004
|
||||||
|
http://www.apache.org/licenses/
|
||||||
|
|
||||||
|
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
||||||
|
|
||||||
|
1. Definitions.
|
||||||
|
|
||||||
|
"License" shall mean the terms and conditions for use, reproduction,
|
||||||
|
and distribution as defined by Sections 1 through 9 of this document.
|
||||||
|
|
||||||
|
"Licensor" shall mean the copyright owner or entity authorized by
|
||||||
|
the copyright owner that is granting the License.
|
||||||
|
|
||||||
|
"Legal Entity" shall mean the union of the acting entity and all
|
||||||
|
other entities that control, are controlled by, or are under common
|
||||||
|
control with that entity. For the purposes of this definition,
|
||||||
|
"control" means (i) the power, direct or indirect, to cause the
|
||||||
|
direction or management of such entity, whether by contract or
|
||||||
|
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
||||||
|
outstanding shares, or (iii) beneficial ownership of such entity.
|
||||||
|
|
||||||
|
"You" (or "Your") shall mean an individual or Legal Entity
|
||||||
|
exercising permissions granted by this License.
|
||||||
|
|
||||||
|
"Source" form shall mean the preferred form for making modifications,
|
||||||
|
including but not limited to software source code, documentation
|
||||||
|
source, and configuration files.
|
||||||
|
|
||||||
|
"Object" form shall mean any form resulting from mechanical
|
||||||
|
transformation or translation of a Source form, including but
|
||||||
|
not limited to compiled object code, generated documentation,
|
||||||
|
and conversions to other media types.
|
||||||
|
|
||||||
|
"Work" shall mean the work of authorship, whether in Source or
|
||||||
|
Object form, made available under the License, as indicated by a
|
||||||
|
copyright notice that is included in or attached to the work
|
||||||
|
(an example is provided in the Appendix below).
|
||||||
|
|
||||||
|
"Derivative Works" shall mean any work, whether in Source or Object
|
||||||
|
form, that is based on (or derived from) the Work and for which the
|
||||||
|
editorial revisions, annotations, elaborations, or other modifications
|
||||||
|
represent, as a whole, an original work of authorship. For the purposes
|
||||||
|
of this License, Derivative Works shall not include works that remain
|
||||||
|
separable from, or merely link (or bind by name) to the interfaces of,
|
||||||
|
the Work and Derivative Works thereof.
|
||||||
|
|
||||||
|
"Contribution" shall mean any work of authorship, including
|
||||||
|
the original version of the Work and any modifications or additions
|
||||||
|
to that Work or Derivative Works thereof, that is intentionally
|
||||||
|
submitted to Licensor for inclusion in the Work by the copyright owner
|
||||||
|
or by an individual or Legal Entity authorized to submit on behalf of
|
||||||
|
the copyright owner. For the purposes of this definition, "submitted"
|
||||||
|
means any form of electronic, verbal, or written communication sent
|
||||||
|
to the Licensor or its representatives, including but not limited to
|
||||||
|
communication on electronic mailing lists, source code control systems,
|
||||||
|
and issue tracking systems that are managed by, or on behalf of, the
|
||||||
|
Licensor for the purpose of discussing and improving the Work, but
|
||||||
|
excluding communication that is conspicuously marked or otherwise
|
||||||
|
designated in writing by the copyright owner as "Not a Contribution."
|
||||||
|
|
||||||
|
"Contributor" shall mean Licensor and any individual or Legal Entity
|
||||||
|
on behalf of whom a Contribution has been received by Licensor and
|
||||||
|
subsequently incorporated within the Work.
|
||||||
|
|
||||||
|
2. Grant of Copyright License. Subject to the terms and conditions of
|
||||||
|
this License, each Contributor hereby grants to You a perpetual,
|
||||||
|
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||||
|
copyright license to reproduce, prepare Derivative Works of,
|
||||||
|
publicly display, publicly perform, sublicense, and distribute the
|
||||||
|
Work and such Derivative Works in Source or Object form.
|
||||||
|
|
||||||
|
3. Grant of Patent License. Subject to the terms and conditions of
|
||||||
|
this License, each Contributor hereby grants to You a perpetual,
|
||||||
|
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||||
|
(except as stated in this section) patent license to make, have made,
|
||||||
|
use, offer to sell, sell, import, and otherwise transfer the Work,
|
||||||
|
where such license applies only to those patent claims licensable
|
||||||
|
by such Contributor that are necessarily infringed by their
|
||||||
|
Contribution(s) alone or by combination of their Contribution(s)
|
||||||
|
with the Work to which such Contribution(s) was submitted. If You
|
||||||
|
institute patent litigation against any entity (including a
|
||||||
|
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
||||||
|
or a Contribution incorporated within the Work constitutes direct
|
||||||
|
or contributory patent infringement, then any patent licenses
|
||||||
|
granted to You under this License for that Work shall terminate
|
||||||
|
as of the date such litigation is filed.
|
||||||
|
|
||||||
|
4. Redistribution. You may reproduce and distribute copies of the
|
||||||
|
Work or Derivative Works thereof in any medium, with or without
|
||||||
|
modifications, and in Source or Object form, provided that You
|
||||||
|
meet the following conditions:
|
||||||
|
|
||||||
|
(a) You must give any other recipients of the Work or
|
||||||
|
Derivative Works a copy of this License; and
|
||||||
|
|
||||||
|
(b) You must cause any modified files to carry prominent notices
|
||||||
|
stating that You changed the files; and
|
||||||
|
|
||||||
|
(c) You must retain, in the Source form of any Derivative Works
|
||||||
|
that You distribute, all copyright, patent, trademark, and
|
||||||
|
attribution notices from the Source form of the Work,
|
||||||
|
excluding those notices that do not pertain to any part of
|
||||||
|
the Derivative Works; and
|
||||||
|
|
||||||
|
(d) If the Work includes a "NOTICE" text file as part of its
|
||||||
|
distribution, then any Derivative Works that You distribute must
|
||||||
|
include a readable copy of the attribution notices contained
|
||||||
|
within such NOTICE file, excluding those notices that do not
|
||||||
|
pertain to any part of the Derivative Works, in at least one
|
||||||
|
of the following places: within a NOTICE text file distributed
|
||||||
|
as part of the Derivative Works; within the Source form or
|
||||||
|
documentation, if provided along with the Derivative Works; or,
|
||||||
|
within a display generated by the Derivative Works, if and
|
||||||
|
wherever such third-party notices normally appear. The contents
|
||||||
|
of the NOTICE file are for informational purposes only and
|
||||||
|
do not modify the License. You may add Your own attribution
|
||||||
|
notices within Derivative Works that You distribute, alongside
|
||||||
|
or as an addendum to the NOTICE text from the Work, provided
|
||||||
|
that such additional attribution notices cannot be construed
|
||||||
|
as modifying the License.
|
||||||
|
|
||||||
|
You may add Your own copyright statement to Your modifications and
|
||||||
|
may provide additional or different license terms and conditions
|
||||||
|
for use, reproduction, or distribution of Your modifications, or
|
||||||
|
for any such Derivative Works as a whole, provided Your use,
|
||||||
|
reproduction, and distribution of the Work otherwise complies with
|
||||||
|
the conditions stated in this License.
|
||||||
|
|
||||||
|
5. Submission of Contributions. Unless You explicitly state otherwise,
|
||||||
|
any Contribution intentionally submitted for inclusion in the Work
|
||||||
|
by You to the Licensor shall be under the terms and conditions of
|
||||||
|
this License, without any additional terms or conditions.
|
||||||
|
Notwithstanding the above, nothing herein shall supersede or modify
|
||||||
|
the terms of any separate license agreement you may have executed
|
||||||
|
with Licensor regarding such Contributions.
|
||||||
|
|
||||||
|
6. Trademarks. This License does not grant permission to use the trade
|
||||||
|
names, trademarks, service marks, or product names of the Licensor,
|
||||||
|
except as required for reasonable and customary use in describing the
|
||||||
|
origin of the Work and reproducing the content of the NOTICE file.
|
||||||
|
|
||||||
|
7. Disclaimer of Warranty. Unless required by applicable law or
|
||||||
|
agreed to in writing, Licensor provides the Work (and each
|
||||||
|
Contributor provides its Contributions) on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||||
|
implied, including, without limitation, any warranties or conditions
|
||||||
|
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
||||||
|
PARTICULAR PURPOSE. You are solely responsible for determining the
|
||||||
|
appropriateness of using or redistributing the Work and assume any
|
||||||
|
risks associated with Your exercise of permissions under this License.
|
||||||
|
|
||||||
|
8. Limitation of Liability. In no event and under no legal theory,
|
||||||
|
whether in tort (including negligence), contract, or otherwise,
|
||||||
|
unless required by applicable law (such as deliberate and grossly
|
||||||
|
negligent acts) or agreed to in writing, shall any Contributor be
|
||||||
|
liable to You for damages, including any direct, indirect, special,
|
||||||
|
incidental, or consequential damages of any character arising as a
|
||||||
|
result of this License or out of the use or inability to use the
|
||||||
|
Work (including but not limited to damages for loss of goodwill,
|
||||||
|
work stoppage, computer failure or malfunction, or any and all
|
||||||
|
other commercial damages or losses), even if such Contributor
|
||||||
|
has been advised of the possibility of such damages.
|
||||||
|
|
||||||
|
9. Accepting Warranty or Additional Liability. While redistributing
|
||||||
|
the Work or Derivative Works thereof, You may choose to offer,
|
||||||
|
and charge a fee for, acceptance of support, warranty, indemnity,
|
||||||
|
or other liability obligations and/or rights consistent with this
|
||||||
|
License. However, in accepting such obligations, You may act only
|
||||||
|
on Your own behalf and on Your sole responsibility, not on behalf
|
||||||
|
of any other Contributor, and only if You agree to indemnify,
|
||||||
|
defend, and hold each Contributor harmless for any liability
|
||||||
|
incurred by, or claims asserted against, such Contributor by reason
|
||||||
|
of your accepting any such warranty or additional liability.
|
||||||
|
|
||||||
|
END OF TERMS AND CONDITIONS
|
||||||
|
|
||||||
|
APPENDIX: How to apply the Apache License to your work.
|
||||||
|
|
||||||
|
To apply the Apache License to your work, attach the following
|
||||||
|
boilerplate notice, with the fields enclosed by brackets "[]"
|
||||||
|
replaced with your own identifying information. (Don't include
|
||||||
|
the brackets!) The text should be enclosed in the appropriate
|
||||||
|
comment syntax for the file format. We also recommend that a
|
||||||
|
file or class name and description of purpose be included on the
|
||||||
|
same "printed page" as the copyright notice for easier
|
||||||
|
identification within third-party archives.
|
||||||
|
|
||||||
|
Copyright 2015, The TensorFlow Authors.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
85
third_party/tensorrt/build_defs.bzl
vendored
Normal file
85
third_party/tensorrt/build_defs.bzl
vendored
Normal file
@ -0,0 +1,85 @@
|
|||||||
|
# -*- python -*-
|
||||||
|
"""
|
||||||
|
add a repo_generator rule for tensorrt
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
_TENSORRT_INSTALLATION_PATH="TENSORRT_INSTALL_PATH"
|
||||||
|
_TF_TENSORRT_VERSION="TF_TENSORRT_VERSION"
|
||||||
|
|
||||||
|
def _is_trt_enabled(repo_ctx):
    """Reports whether TensorRT support was requested via the environment.

    Args:
      repo_ctx: repository context; only `repo_ctx.os.environ` is read.

    Returns:
      True when TF_NEED_TENSORRT is present and, once surrounding whitespace
      is stripped, equals "1"; False in every other case.
    """
    return repo_ctx.os.environ.get("TF_NEED_TENSORRT", "").strip() == "1"
|
||||||
|
|
||||||
|
def _dummy_repo(repo_ctx):
    """Generates a placeholder repository for builds without TensorRT.

    Renders BUILD and build_defs.bzl from their templates with empty
    library/genrule substitutions and the configured flag set to False, then
    writes empty NvUtils.h and NvInfer.h files under include/ (the header
    paths listed in the BUILD template).

    Args:
      repo_ctx: repository context used to write the generated files.
    """
    repo_ctx.template(
        "BUILD",
        Label("//third_party/tensorrt:BUILD.tpl"),
        substitutions = {
            "%{tensorrt_lib}": "",
            "%{tensorrt_genrules}": "",
        },
        executable = False,
    )
    repo_ctx.template(
        "build_defs.bzl",
        Label("//third_party/tensorrt:build_defs.bzl.tpl"),
        substitutions = {"%{trt_configured}": "False"},
        executable = False,
    )

    # Empty stand-ins for the real TensorRT headers.
    for stub_header in ("include/NvUtils.h", "include/NvInfer.h"):
        repo_ctx.file(stub_header, content = "", executable = False)
|
||||||
|
|
||||||
|
def _trt_target_lib(repo_ctx, orig_lib, trt_ver):
    """Chooses the repository-relative path under which libnvinfer is linked.

    Uses the SONAME reported by `objdump -p` for orig_lib when it can be
    determined; otherwise falls back to a name derived from trt_ver.

    Args:
      repo_ctx: repository context, used to locate and run objdump.
      orig_lib: absolute path of the installed libnvinfer shared object.
      trt_ver: TensorRT version string; may be empty.

    Returns:
      A path of the form "lib/<library file name>".
    """
    fallback = "lib/libnvinfer.so"
    if trt_ver:
        # Use the whole major component rather than the first character so
        # that a version such as "10.0" maps to ".so.10", not ".so.1".
        fallback = "lib/libnvinfer.so.%s" % trt_ver.split(".")[0]

    objdump = repo_ctx.which("objdump")
    if objdump == None:
        return fallback

    # Start from the fallback so the result is always defined, even when
    # objdump fails or prints no SONAME entry.
    target = fallback
    dump = repo_ctx.execute([objdump, "-p", orig_lib])
    for line in dump.stdout.splitlines():
        if "SONAME" in line:
            target = "lib/%s" % line.strip().split(" ")[-1]
    return target

def _trt_repo_impl(repo_ctx):
    """Implements the local_config_tensorrt repository rule.

    When TensorRT is not enabled a dummy repository is generated. Otherwise
    the installed libnvinfer is symlinked into the repository and BUILD and
    build_defs.bzl are rendered from their templates, including a genrule
    that links the library and the NvInfer.h / NvUtils.h headers.

    Args:
      repo_ctx: repository context supplied by Bazel.
    """
    if not _is_trt_enabled(repo_ctx):
        _dummy_repo(repo_ctx)
        return

    trt_libdir = repo_ctx.os.environ.get(_TENSORRT_INSTALLATION_PATH)
    if trt_libdir == None:
        fail("TF_NEED_TENSORRT is enabled but %s is not set" %
             _TENSORRT_INSTALLATION_PATH)

    # An unset version is treated like an empty one (unversioned library).
    trt_ver = repo_ctx.os.environ.get(_TF_TENSORRT_VERSION, "")

    # if deb installation
    # once a standardized installation between tar and deb
    # is done, we don't need this
    if trt_libdir == '/usr/lib/x86_64-linux-gnu':
        inc_path = '/usr/include/x86_64-linux-gnu'
    else:
        inc_path = str(repo_ctx.path("%s/../include" % trt_libdir).realpath)
    inc_name = inc_path + '/NvInfer.h'

    if trt_ver:
        orig_lib = "%s/libnvinfer.so.%s" % (trt_libdir, trt_ver)
    else:
        orig_lib = "%s/libnvinfer.so" % trt_libdir

    target_lib = _trt_target_lib(repo_ctx, orig_lib, trt_ver)
    repo_ctx.symlink(orig_lib, target_lib)

    # Text of the genrule spliced into the generated BUILD file.
    grule = ('genrule(\n name = "trtlinks",\n' +
             ' outs = [\n "%s",\n "include/NvInfer.h",\n "include/NvUtils.h",\n ],\n' % target_lib +
             ' cmd="""ln -sf %s $(@D)/%s ' % (orig_lib, target_lib) +
             '&&\n ln -sf %s $(@D)/include/NvInfer.h ' % (inc_name) +
             '&&\n ln -sf %s/NvUtils.h $(@D)/include/NvUtils.h""",\n)\n' % (inc_path))

    repo_ctx.template(
        "BUILD",
        Label("//third_party/tensorrt:BUILD.tpl"),
        {"%{tensorrt_lib}": '"%s"' % target_lib, "%{tensorrt_genrules}": grule},
        False,
    )
    repo_ctx.template(
        "build_defs.bzl",
        Label("//third_party/tensorrt:build_defs.bzl.tpl"),
        {"%{trt_configured}": "True"},
        False,
    )
|
||||||
|
|
||||||
|
# Repository rule backed by _trt_repo_impl. The environ list names the
# environment variables the implementation reads.
trt_repository = repository_rule(
    implementation = _trt_repo_impl,
    local = True,
    environ = [
        "TF_NEED_TENSORRT",
        _TF_TENSORRT_VERSION,
        _TENSORRT_INSTALLATION_PATH,
    ],
)
|
18
third_party/tensorrt/build_defs.bzl.tpl
vendored
Normal file
18
third_party/tensorrt/build_defs.bzl.tpl
vendored
Normal file
@ -0,0 +1,18 @@
|
|||||||
|
# -*- python -*-
|
||||||
|
"""
|
||||||
|
template file for trt functions
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
def is_trt_enabled():
    """Returns whether TensorRT was configured for this repository.

    The returned literal is filled in when the repository rule renders this
    template file.
    """
    return %{trt_configured}
|
||||||
|
|
||||||
|
def if_trt(if_true, if_false = []):
    """Chooses between two values depending on the trt_enabled setting.

    Args:
      if_true: value used when @local_config_tensorrt//:trt_enabled matches.
      if_false: value used otherwise; defaults to an empty list.

    Returns:
      A select() keyed on the trt_enabled config_setting.
    """
    branches = {
        "@local_config_tensorrt//:trt_enabled": if_true,
        "//conditions:default": if_false,
    }
    return select(branches)
|
Loading…
Reference in New Issue
Block a user