Introducing a TensorRT operator to TF which can run (sub)graphs in
highly optimized TensorRT engines. This commit is a merged version of many
commits by benbarsdell <bbarsdell at nvidia.com>, deadeyegoodwin
<davidg at nvidia.com>, jjsjann123 <jiej at nvidia.com>, and
samikama <skama at nvidia.com>.

commit 825e7a32e9
parent e810b107d8
configure.py (126 changed lines)
@@ -37,12 +37,14 @@ _TF_BAZELRC = os.path.join(os.path.dirname(os.path.abspath(__file__)),
_TF_WORKSPACE = os.path.join(os.path.dirname(os.path.abspath(__file__)),
                             'WORKSPACE')
_DEFAULT_CUDA_VERSION = '9.0'
_DEFAULT_TENSORRT_VERSION = '4'
_DEFAULT_CUDNN_VERSION = '7'
_DEFAULT_CUDA_COMPUTE_CAPABILITIES = '3.5,5.2'
_DEFAULT_CUDA_PATH = '/usr/local/cuda'
_DEFAULT_CUDA_PATH_LINUX = '/opt/cuda'
_DEFAULT_CUDA_PATH_WIN = ('C:/Program Files/NVIDIA GPU Computing '
                          'Toolkit/CUDA/v%s' % _DEFAULT_CUDA_VERSION)
_DEFAULT_TENSORRT_PATH_LINUX = '/usr/lib/x86_64-linux-gnu'
_TF_OPENCL_VERSION = '1.2'
_DEFAULT_COMPUTECPP_TOOLKIT_PATH = '/usr/local/computecpp'
_DEFAULT_TRISYCL_INCLUDE_DIR = '/usr/local/triSYCL/include'
@@ -382,13 +384,12 @@ def set_build_var(environ_cp, var_name, query_item, option_name,

  var = str(int(get_var(environ_cp, var_name, query_item, enabled_by_default)))
  environ_cp[var_name] = var
  if var == '1':
    write_to_bazelrc('build --define %s=true' % option_name)
  elif bazel_config_name is not None:
    # TODO(mikecase): Migrate all users of configure.py to use --config Bazel
    # options and not to set build configs through environment variables.
    write_to_bazelrc('build:%s --define %s=true'
                     % (bazel_config_name, option_name))
  # TODO(mikecase): Migrate all users of configure.py to use --config Bazel
  # options and not to set build configs through environment variables.
  if var == '1':
    setting = 'true'
    confname = ':%s' % bazel_config_name if bazel_config_name is not None else ''
    write_to_bazelrc('build%s --define %s=%s' % (confname, option_name, setting))


def set_action_env_var(environ_cp,
@@ -438,13 +439,12 @@ def convert_version_to_int(version):
  for seg in version_segments:
    if not seg.isdigit():
      return None

  version_str = ''.join(['%03d' % int(seg) for seg in version_segments])
  return int(version_str)


def check_bazel_version(min_version):
  """Check installed bezel version is at least min_version.
  """Check installed bazel version is at least min_version.

  Args:
    min_version: string for minimum bazel version.
@@ -1056,6 +1056,108 @@ def set_other_cuda_vars(environ_cp):
    write_to_bazelrc('test --config=cuda')


def set_tf_trt_version(environ_cp):
  """Set TENSORRT_INSTALL_PATH and TF_TENSORRT_VERSION."""
  ask_trt_version = (
      'Please specify the TensorRT (libnvinfer) version you want to use. '
      '[Leave empty to default to libnvinfer %s]: ') % _DEFAULT_TENSORRT_VERSION

  while True:
    tf_trt_version = get_from_env_or_user_or_default(
        environ_cp, 'TF_TENSORRT_VERSION', ask_trt_version,
        _DEFAULT_TENSORRT_VERSION)
    # if library version is passed and known
    default_trt_path = environ_cp.get('TENSORRT_INSTALL_PATH',
                                      _DEFAULT_TENSORRT_PATH_LINUX)
    ask_trt_path = (r'Please specify the location where the libnvinfer %s '
                    'library is installed. Refer to README.md for more '
                    'details. [Default is %s]:') % (tf_trt_version,
                                                    default_trt_path)
    trt_install_path = get_from_env_or_user_or_default(
        environ_cp, 'TENSORRT_INSTALL_PATH', ask_trt_path, default_trt_path)

    # The result returned from "read" is used unexpanded, which makes "~"
    # unusable. Go through one more level of expansion to handle that.
    trt_install_path = os.path.realpath(os.path.expanduser(trt_install_path))

    # Simple helper that searches for libnvinfer in the install path: it
    # finds all libnvinfer.so* files in the user-defined install path and
    # its lib64 subdirectory and returns their absolute paths.
    def find_libs(search_path):
      fl = set()
      if os.path.exists(search_path) and os.path.isdir(search_path):
        fl.update([os.path.realpath(os.path.join(search_path, x))
                   for x in os.listdir(search_path) if 'libnvinfer.so' in x])
      return fl

    possible_files = find_libs(trt_install_path)
    possible_files.update(find_libs(os.path.join(trt_install_path, 'lib64')))
    if is_linux():
      cudnnpatt = re.compile('.*libcudnn.so\\.?(.*) =>.*$')
      cudapatt = re.compile('.*libcudart.so\\.?(.*) =>.*$')

      # A library is compatible if ldd resolves it against the same CUDA and
      # cuDNN versions that TensorFlow is being configured with.
      def is_compatible(lib, cudaver, cudnnver):
        ldd_bin = which('ldd') or '/usr/bin/ldd'
        ldd_out = run_shell([ldd_bin, lib]).split(os.linesep)
        for l in ldd_out:
          if 'libcudnn.so' in l:
            cudnn = cudnnpatt.search(l)
          elif 'libcudart.so' in l:
            cudart = cudapatt.search(l)
        if cudnn:
          cudnn = convert_version_to_int(cudnn.group(1)) if len(
              cudnn.group(1)) else 0
        if cudart:
          cudart = convert_version_to_int(cudart.group(1)) if len(
              cudart.group(1)) else 0
        return (cudnn == cudnnver) and (cudart == cudaver)

      cudaver = convert_version_to_int(environ_cp['TF_CUDA_VERSION'])
      cudnnver = convert_version_to_int(environ_cp['TF_CUDNN_VERSION'])
      valid_libs = []
      vfinder = re.compile('.*libnvinfer.so.?(.*)$')
      highest_ver = [0, None, None]

      for l in possible_files:
        if is_compatible(l, cudaver, cudnnver):
          valid_libs.append(l)
          vstr = vfinder.search(l).group(1)
          currver = convert_version_to_int(vstr) if len(vstr) else 0
          if currver > highest_ver[0]:
            highest_ver = [currver, vstr, l]
      if highest_ver[1] is not None:
        trt_install_path = os.path.dirname(highest_ver[2])
        tf_trt_version = highest_ver[1]
        break
    ldconfig_bin = which('ldconfig') or '/sbin/ldconfig'
    libnvinfer_path_from_ldconfig = run_shell([ldconfig_bin, '-p'])
    libnvinfer_path_from_ldconfig = re.search('.*libnvinfer.so.* => (.*)',
                                              libnvinfer_path_from_ldconfig)
    if libnvinfer_path_from_ldconfig:
      libnvinfer_path_from_ldconfig = libnvinfer_path_from_ldconfig.group(1)
      if os.path.exists('%s.%s' % (libnvinfer_path_from_ldconfig,
                                   tf_trt_version)):
        trt_install_path = os.path.dirname(libnvinfer_path_from_ldconfig)
        break

    # Reset and retry
    if len(possible_files):
      print('Invalid path to TensorRT %s. The libnvinfer.so* files found are '
            'for incompatible cuda versions:' % tf_trt_version)
      print(trt_install_path)
      print(os.path.join(trt_install_path, 'lib64'))
    else:
      print('Invalid path to TensorRT %s. No libnvinfer.so* files found in:'
            % tf_trt_version)
      print(trt_install_path)
      print(os.path.join(trt_install_path, 'lib64'))
    if is_linux():
      print('%s.%s' % (libnvinfer_path_from_ldconfig, tf_trt_version))

    environ_cp['TF_TENSORRT_VERSION'] = ''

  # Set TENSORRT_INSTALL_PATH and TF_TENSORRT_VERSION
  environ_cp['TENSORRT_INSTALL_PATH'] = trt_install_path
  write_action_env_to_bazelrc('TENSORRT_INSTALL_PATH', trt_install_path)
  environ_cp['TF_TENSORRT_VERSION'] = tf_trt_version
  write_action_env_to_bazelrc('TF_TENSORRT_VERSION', tf_trt_version)
  write_to_bazelrc('build:tensorrt --define using_tensorrt=true')


def set_host_cxx_compiler(environ_cp):
  """Set HOST_CXX_COMPILER."""
  default_cxx_host_compiler = which('g++') or ''
@@ -1244,9 +1346,11 @@ def main():
  environ_cp['TF_NEED_COMPUTECPP'] = '0'
  environ_cp['TF_NEED_OPENCL'] = '0'
  environ_cp['TF_CUDA_CLANG'] = '0'
  environ_cp['TF_NEED_TENSORRT'] = '0'

  if is_macos():
    environ_cp['TF_NEED_JEMALLOC'] = '0'
    environ_cp['TF_NEED_TENSORRT'] = '0'

  set_build_var(environ_cp, 'TF_NEED_JEMALLOC', 'jemalloc as malloc',
                'with_jemalloc', True)
@@ -1301,6 +1405,10 @@ def main():
  if not is_windows():
    set_gcc_host_compiler_path(environ_cp)
  set_other_cuda_vars(environ_cp)
  # Enable TensorRT if desired. Disabled on non-Linux platforms.
  set_action_env_var(environ_cp, 'TF_NEED_TENSORRT', 'TensorRT', False)
  if environ_cp.get('TF_NEED_TENSORRT') == '1':
    set_tf_trt_version(environ_cp)

  set_build_var(environ_cp, 'TF_NEED_MPI', 'MPI', 'with_mpi_support', False)
  if environ_cp.get('TF_NEED_MPI') == '1':
@@ -358,6 +358,14 @@ config_setting(
    },
)

config_setting(
    name = "using_tensorrt",
    define_values = {
        "using_tensorrt": "true",
    },
    visibility = ["//visibility:public"],
)

config_setting(
    name = "with_mpi_support",
    values = {"define": "with_mpi_support=true"},
@@ -7,6 +7,7 @@ package(default_visibility = ["//tensorflow:__subpackages__"])

load("//third_party/mpi:mpi.bzl", "if_mpi")
load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")
load("@local_config_tensorrt//:build_defs.bzl", "if_trt")

py_library(
    name = "contrib_py",
@@ -104,7 +105,9 @@ py_library(
        "//tensorflow/contrib/training:training_py",
        "//tensorflow/contrib/util:util_py",
        "//tensorflow/python:util",
    ] + if_mpi(["//tensorflow/contrib/mpi_collectives:mpi_collectives_py"]),
    ] + if_mpi(["//tensorflow/contrib/mpi_collectives:mpi_ops_py"])
      + if_trt(["//tensorflow/contrib/tensorrt:init_py"]),
)

cc_library(
tensorflow/contrib/tensorrt/BUILD (new file, 266 lines)
@@ -0,0 +1,266 @@
# -*- python -*-
# Description:
#   provide tensorrt operators and converter package

package(default_visibility = ["//tensorflow:__subpackages__"])

licenses(["notice"])  # Apache 2.0

load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")
load(
    "//tensorflow:tensorflow.bzl",
    "tf_custom_op_library",
    "tf_gen_op_libs",
    "tf_gen_op_wrapper_py",
    "tf_py_wrap_cc",
    "tf_cc_test",
    "tf_kernel_library",
    "tf_custom_op_py_library",
    "tf_copts",
)

tf_custom_op_library(
    name = "python/ops/_trt_engine_op.so",
    srcs = [
        "kernels/trt_engine_op.cc",
        "ops/trt_engine_op.cc",
        "kernels/trt_engine_op.h",
    ],
    gpu_srcs = [],
    deps = [
        "@local_config_tensorrt//:tensorrt",
        ":trt_shape_function",
        "//tensorflow/core:lib_proto_parsing",
        "//tensorflow/core/kernels:bounds_check_lib",
        "//tensorflow/core/kernels:ops_util_hdrs",
    ],
)

cc_library(
    name = "trt_shape_function",
    srcs = [
        "shape_fn/trt_shfn.cc",
    ],
    hdrs = ["shape_fn/trt_shfn.h"],
    copts = tf_copts(),
    deps = [
        ":trt_logging",
        "//third_party/eigen3",
        "@local_config_tensorrt//:tensorrt",
        "@protobuf_archive//:protobuf",
        "@nsync//:nsync_headers",
        "//tensorflow/core:framework_headers_lib",
    ],
)

tf_kernel_library(
    name = "trt_engine_op_kernel",
    srcs = [
        "kernels/trt_engine_op.cc",
    ],
    hdrs = [
        "kernels/trt_engine_op.h",
    ],
    gpu_srcs = [],
    deps = [
        ":trt_logging",
        ":trt_shape_function",
        "//tensorflow/core:framework",
        "//tensorflow/core:lib",
        "//third_party/eigen3",
        "//tensorflow/core:gpu_headers_lib",
        "@local_config_tensorrt//:tensorrt",
        "//tensorflow/core:lib_proto_parsing",
    ],
    alwayslink = 1,
)

tf_gen_op_libs(
    op_lib_names = [
        "trt_engine_op",
    ],
    deps = [
        "@local_config_tensorrt//:tensorrt",
    ],
)

cc_library(
    name = "trt_logging",
    srcs = [
        "log/trt_logger.cc",
    ],
    hdrs = [
        "log/trt_logger.h",
    ],
    deps = [
        "@local_config_tensorrt//:tensorrt",
        "//tensorflow/core:lib_proto_parsing",
    ],
    visibility = ["//visibility:public"],
)

tf_gen_op_wrapper_py(
    name = "trt_engine_op",
    deps = [
        ":trt_engine_op_op_lib",
        ":trt_shape_function",
    ],
)

tf_custom_op_py_library(
    name = "trt_engine_op_loader",
    srcs = ["python/ops/trt_engine_op.py"],
    dso = [
        ":python/ops/_trt_engine_op.so",
        "@local_config_tensorrt//:tensorrt",
    ],
    srcs_version = "PY2AND3",
    deps = [
        "//tensorflow/python:framework_for_generated_wrappers",
        "//tensorflow/python:resources",
    ],
)

py_library(
    name = "init_py",
    srcs = [
        "__init__.py",
        "python/__init__.py",
    ],
    srcs_version = "PY2AND3",
    deps = [
        ":trt_ops_py",
        ":trt_convert_py",
    ],
)

py_library(
    name = "trt_ops_py",
    srcs_version = "PY2AND3",
    deps = [
        ":trt_engine_op",
        ":trt_engine_op_loader",
    ],
)

py_library(
    name = "trt_convert_py",
    srcs = ["python/trt_convert.py"],
    srcs_version = "PY2AND3",
    deps = [
        ":wrap_conversion",
    ],
)

tf_py_wrap_cc(
    name = "wrap_conversion",
    srcs = ["trt_conversion.i"],
    deps = [
        ":trt_conversion",
        "//tensorflow/core:framework_lite",
        "//util/python:python_headers",
    ],
)

cc_library(
    name = "trt_conversion",
    srcs = [
        "convert/convert_nodes.cc",
        "convert/convert_graph.cc",
        "segment/segment.cc",
        "convert/inferShapes.cc",
    ],
    hdrs = [
        "convert/convert_nodes.h",
        "convert/convert_graph.h",
        "convert/inferShapes.h",
        "segment/segment.h",
        "segment/union_find.h",
    ],
    deps = [
        "@local_config_tensorrt//:tensorrt",
        "@protobuf_archive//:protobuf_headers",
        "@nsync//:nsync_headers",
        ":trt_logging",
        "//tensorflow/core:framework_lite",
        "//tensorflow/core:protos_all_cc",
        "//tensorflow/core:framework_headers_lib",
        "//tensorflow/core:core_cpu_base",
        #"//third_party/eigen3",
    ],
)

tf_custom_op_library(
    name = "tensorrt_ops.so",
    srcs = [
        "ops/tensorrt_ops.cc",
    ],
    deps = [
        "@local_config_tensorrt//:tensorrt",
    ],
)

# Library for the segmenting portion of TensorRT operation creation
cc_library(
    name = "segment",
    srcs = [
        "segment/segment.cc",
    ],
    hdrs = [
        "segment/union_find.h",
        "segment/segment.h",
    ],
    deps = [
        "@protobuf_archive//:protobuf_headers",
        "//tensorflow/core:core_cpu",
        "//tensorflow/core:lib_proto_parsing",
        "//third_party/eigen3",
    ],
    linkstatic = 1,
)

tf_cc_test(
    name = "segment_test",
    size = "small",
    srcs = ["segment/segment_test.cc"],
    deps = [
        ":segment",
        "//tensorflow/c:c_api",
        "//tensorflow/core:test",
        "//tensorflow/core:test_main",
    ],
)

# Library for the node-level conversion portion of TensorRT operation creation

filegroup(
    name = "cppfiles",
    srcs = glob(["**/*.cc"]),
    visibility = ["//visibility:private"],
)

filegroup(
    name = "headers",
    srcs = glob(["**/*.h"]),
    visibility = ["//visibility:private"],
)

filegroup(
    name = "all_files",
    srcs = glob(
        ["**/*"],
        exclude = [
            "**/METADATA",
            "**/OWNERS",
        ],
    ),
    visibility = ["//tensorflow:__subpackages__"],
)
tensorflow/contrib/tensorrt/README.md (new file, 42 lines)
@@ -0,0 +1,42 @@
Using TensorRT in TensorFlow
============================

This module provides the necessary bindings and introduces the
TRT_engine_op operator that wraps a subgraph in TensorRT.

Compilation
-----------

In order to compile the module, you need to have a local TensorRT
installation (libnvinfer.so and the respective include files). During the
configuration step, TensorRT should be enabled and its installation path
should be set. If TensorRT was installed through a package manager (deb,
rpm), the configure script should find the necessary components from the
system automatically. If it was installed from a tar package, the user has
to set the path to the installed location during configuration.

In order to enable TensorRT support, the user has to add `--config=tensorrt`
to the build flags during compilation, such as

```
bazel build --config=cuda --config=opt --config=tensorrt //tensorflow/tools/pip_package:build_pip_package
bazel-bin/tensorflow/tools/pip_package/build_pip_package /tmp/
```

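Once the wheel produced by `build_pip_package` has been installed, a quick
sanity check (a sketch; nothing here is specific to a particular model) is
to confirm that the contrib module imports and exposes the converter entry
point. Both imports fail if the package was built without
`--config=tensorrt`:

```python
import tensorflow as tf
import tensorflow.contrib.tensorrt as trt

# CreateInferenceGraph is the conversion entry point provided by this module.
print(trt.CreateInferenceGraph)
```
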
After the installation of the tensorflow package, the TensorRT
transformation will be available. An example use is shown below.

```python
import tensorflow as tf
import tensorflow.contrib.tensorrt as trt
# ... create and train or load model
gdef = sess.graph.as_graph_def()
trt_gdef = trt.CreateInferenceGraph(
    gdef,                # original graph_def
    ["output"],          # name of output node(s)
    max_batch_size,      # maximum batch size to run the inference
    max_workspace_size)  # max memory for TensorRT to use
tf.reset_default_graph()
tf.import_graph_def(graph_def=trt_gdef)
# ... run inference
```
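The example above elides the final inference step. A minimal sketch of it
follows, continuing the snippet: the tensor names `import/input:0` and
`import/output:0` and the input shape are hypothetical and depend on the
actual model.

```python
import numpy as np
import tensorflow as tf

# Hypothetical input batch; shape and dtype must match the model's input.
batch = np.random.rand(4, 224, 224, 3).astype(np.float32)

with tf.Session() as sess:
    # trt_gdef was imported into the default graph above; imported node
    # names get an "import/" prefix by default.
    result = sess.run("import/output:0",
                      feed_dict={"import/input:0": batch})
print(result.shape)
```
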
tensorflow/contrib/tensorrt/__init__.py (new file, 19 lines)
@@ -0,0 +1,19 @@
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.contrib.tensorrt.python import *
tensorflow/contrib/tensorrt/convert/convert_graph.cc (new file, 253 lines)
@@ -0,0 +1,253 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/contrib/tensorrt/convert/convert_graph.h"

#include <list>
#include <set>
#include <sstream>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include <map>
#include <utility>

#include "NvInfer.h"

#include "tensorflow/contrib/tensorrt/convert/convert_nodes.h"
#include "tensorflow/contrib/tensorrt/convert/inferShapes.h"
#include "tensorflow/contrib/tensorrt/segment/segment.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"

#define _TF_LOG_DEBUG ::tensorflow::internal::LogMessage(__FILE__, __LINE__, -1)
//------------------------------------------------------------------------------
namespace tensorrt {
namespace convert {

namespace {

static std::unordered_set<std::string> output_nodes;
bool IsTensorRTCandidate(const tensorflow::NodeDef& node_def) {
  static const std::set<std::string> candidate_ops = {
      "Identity", "Const", "Conv2D", "MaxPool", "BiasAdd", "Relu",
      "Add", "Mul", "Sub", "Rsqrt", "Pad"  // "Placeholder", "Mean"
      // TODO(ben,jie): ...
  };
  if (output_nodes.count(node_def.name())) return false;
  return candidate_ops.count(node_def.op());
}

void GetSubGraphIncomingEdges(tensorflow::Graph const& graph,
                              std::set<int> const& subgraph_node_ids,
                              tensorflow::EdgeSet* incoming_edges) {
  for (int node_id : subgraph_node_ids) {
    tensorflow::Node const* node = graph.FindNodeId(node_id);
    LOG(DEBUG) << node->name() << " has incoming edges: ";
    for (tensorflow::Edge const* edge : node->in_edges()) {
      if (!subgraph_node_ids.count(edge->src()->id()) &&
          !edge->src()->IsSource()) {
        LOG(DEBUG) << edge->src()->name() << ", ";
        incoming_edges->insert(edge);
      }
    }
  }
}

void GetSubGraphOutgoingEdges(tensorflow::Graph const& graph,
                              std::set<int> const& subgraph_node_ids,
                              tensorflow::EdgeSet* outgoing_edges) {
  for (int node_id : subgraph_node_ids) {
    tensorflow::Node const* node = graph.FindNodeId(node_id);
    LOG(DEBUG) << node->name() << " has outgoing edges: ";
    for (tensorflow::Edge const* edge : node->out_edges()) {
      if (!subgraph_node_ids.count(edge->dst()->id()) &&
          !edge->dst()->IsSink()) {
        outgoing_edges->insert(edge);
      }
    }
  }
}

std::pair<std::string, int> ParseTensorName(std::string name,
                                            int default_idx = 0) {
  int idx = default_idx;
  size_t sep = name.find_last_of(':');
  if (sep != std::string::npos) {
    // Read the numeric suffix before trimming it off the name.
    idx = std::stoi(name.substr(sep + 1));
    name = name.substr(0, sep);
  }
  return std::make_pair(name, idx);
}

std::unordered_map<std::string, std::vector<int>> BuildTensorNameMap(
    const std::vector<std::string>& tensor_names) {
  std::unordered_map<std::string, std::vector<int>> result;
  for (std::string const& tensor_name : tensor_names) {
    std::string node_name;
    int index;
    std::tie(node_name, index) = ParseTensorName(tensor_name);
    result[node_name].push_back(index);
  }
  return result;
}

tensorflow::Status ConvertSubGraphToTensorRT(
    tensorflow::Graph& graph, const std::vector<std::string>& output_names,
    const std::set<int>& subgraph_node_ids, size_t max_batch_size,
    size_t max_workspace_size, const ShapeMap& shape_map) {
  tensorflow::EdgeSet subgraph_incoming_edges;
  GetSubGraphIncomingEdges(graph, subgraph_node_ids, &subgraph_incoming_edges);

  std::vector<std::pair<int, int>> subgraph_inputs;

  // Collect inputs by looking for incoming edges
  for (tensorflow::Edge const* edge : subgraph_incoming_edges) {
    subgraph_inputs.push_back({edge->src()->id(), edge->src_output()});
  }
  std::set<std::pair<int, int>> subgraph_outputs_set;
  // Collect outputs referenced from output_names
  auto output_name_to_index_map = BuildTensorNameMap(output_names);
  // for (int node_id : subgraph_node_ids_no_placeholder) {
  for (int node_id : subgraph_node_ids) {
    tensorflow::Node* node = graph.FindNodeId(node_id);
    if (output_name_to_index_map.count(node->name())) {
      for (int index : output_name_to_index_map.at(node->name())) {
        subgraph_outputs_set.insert({node_id, index});
      }
    }
  }
  // Collect outputs referenced from outgoing edges
  tensorflow::EdgeSet subgraph_outgoing_edges;
  // GetSubGraphOutgoingEdges(graph, subgraph_node_ids_no_placeholder,
  //                          &subgraph_outgoing_edges);
  GetSubGraphOutgoingEdges(graph, subgraph_node_ids, &subgraph_outgoing_edges);
  for (tensorflow::Edge const* edge : subgraph_outgoing_edges) {
    subgraph_outputs_set.insert({edge->src()->id(), edge->src_output()});
  }
  // Impose an ordering on the outputs
  std::vector<std::pair<int, int>> subgraph_outputs(
      subgraph_outputs_set.begin(), subgraph_outputs_set.end());
  // Build the TensorRT node and add it to the graph
  tensorflow::NodeDef trt_node_def;
  TF_RETURN_IF_ERROR(ConvertSubGraphToTensorRTNodeDef(
      graph, subgraph_node_ids, subgraph_inputs, subgraph_outputs,
      max_batch_size, max_workspace_size, shape_map, &trt_node_def));
  tensorflow::Status status;
  tensorflow::Node* trt_node = graph.AddNode(trt_node_def, &status);

  TF_RETURN_IF_ERROR(status);

  // Re-map outgoing edges to use the new TRT node instead of the orig subgraph
  std::map<std::pair<int, int>, int> subgraph_edge_to_output_map;
  for (size_t i = 0; i < subgraph_outputs.size(); ++i) {
    subgraph_edge_to_output_map.insert({subgraph_outputs.at(i), i});
  }
  TF_RETURN_IF_ERROR(status);
  for (tensorflow::Edge const* edge : subgraph_outgoing_edges) {
    std::pair<int, int> old_src = {edge->src()->id(), edge->src_output()};
    int new_src_output = subgraph_edge_to_output_map.at(old_src);
    graph.UpdateEdge(trt_node, new_src_output, edge->dst(), edge->dst_input());
  }
  // Remove the original subgraph
  for (int node_id : subgraph_node_ids) {
    tensorflow::Node* node = graph.FindNodeId(node_id);
    // Don't remove the input placeholders
    if (node->type_string() == "Placeholder") {
      continue;
    }
    graph.RemoveNode(node);
  }
  return tensorflow::Status::OK();
}

tensorflow::Status BuildNodeMap(
    const tensorflow::Graph& graph,
    std::unordered_map<std::string, tensorflow::Node*>* node_map) {
  for (auto* node : graph.op_nodes()) {
    if (!node_map->insert({node->name(), node}).second) {
      return tensorflow::errors::AlreadyExists(
          "Node name is not unique in graph: " + node->name());
    }
  }
  return tensorflow::Status::OK();
}

}  // namespace

tensorflow::Status ConvertGraphDefToTensorRT(
    const tensorflow::GraphDef& graph_def,
    const std::vector<std::string>& output_names, size_t max_batch_size,
    size_t max_workspace_size, tensorflow::GraphDef* new_graph_def) {
  ShapeMap shape_map;
  TF_RETURN_IF_ERROR(
      tensorflow::trt::inferShapes(graph_def, output_names, shape_map));
  std::stringstream oss;
  for (auto& n : shape_map) {  // nodes
    oss << " Node= " << n.first << ", ";
    for (auto o : n.second) {  // outputs
      oss << o.first.DebugString() << " T= " << o.second << ", ";
    }
    LOG(DEBUG) << oss.str();
    oss.str("");
  }
  // Build the full graph
  tensorflow::FunctionLibraryDefinition flib(tensorflow::OpRegistry::Global(),
                                             graph_def.library());
  tensorflow::Graph graph(flib);
  TF_RETURN_IF_ERROR(tensorflow::ConvertGraphDefToGraph(
      tensorflow::GraphConstructorOptions(), graph_def, &graph));

  // Segment the graph into subgraphs that can be converted to TensorRT
  tensorrt::segment::SegmentOptions segment_options;
  // TODO(ben,jie,sami): exclude output nodes (DISCUSS IT)
  for (auto node : output_names) output_nodes.insert(node);

  // TODO(sami): this should be passed as a knob!!!!
  segment_options.minimum_segment_size = 2;
  tensorrt::segment::SegmentNodesVector segments;
  TF_RETURN_IF_ERROR(tensorrt::segment::SegmentGraph(
      graph_def, IsTensorRTCandidate, segment_options, &segments));
  if (segments.size() > 1) {
    // LOG(WARNING) << "Multiple TensorRT candidate subgraphs were found, "
    //              << "but only the first can be converted.";
    // segments.erase(++segments.begin(), segments.end());
    LOG(INFO) << "MULTIPLE tensorrt candidate conversion: " << segments.size();
  }
  std::unordered_map<std::string, tensorflow::Node*> node_map;
  TF_RETURN_IF_ERROR(BuildNodeMap(graph, &node_map));
  for (std::set<std::string> const& subgraph_node_names : segments) {
    std::set<int> subgraph_node_ids;
    for (std::string const& node_name : subgraph_node_names) {
      subgraph_node_ids.insert(node_map.at(node_name)->id());
    }
    TF_RETURN_IF_ERROR(ConvertSubGraphToTensorRT(
        graph, output_names, subgraph_node_ids, max_batch_size,
        max_workspace_size, shape_map));
  }
  graph.ToGraphDef(new_graph_def);
  return tensorflow::Status::OK();
}

}  // namespace convert
}  // namespace tensorrt
tensorflow/contrib/tensorrt/convert/convert_graph.h (new file, 34 lines)
@@ -0,0 +1,34 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CONTRIB_TENSORRT_CONVERT_CONVERT_GRAPH_H_
#define TENSORFLOW_CONTRIB_TENSORRT_CONVERT_CONVERT_GRAPH_H_

#include <string>
#include <vector>

#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/lib/core/status.h"

namespace tensorrt {
namespace convert {

tensorflow::Status ConvertGraphDefToTensorRT(
    const tensorflow::GraphDef& graph_def,
    const std::vector<std::string>& output_names, size_t max_batch_size,
    size_t max_workspace_size, tensorflow::GraphDef* new_graph_def);
}  // namespace convert
}  // namespace tensorrt

#endif  // TENSORFLOW_CONTRIB_TENSORRT_CONVERT_CONVERT_GRAPH_H_
tensorflow/contrib/tensorrt/convert/convert_nodes.cc (new file, 1737 lines)
(File diff suppressed because it is too large.)
tensorflow/contrib/tensorrt/convert/convert_nodes.h (new file, 42 lines)
@@ -0,0 +1,42 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CONTRIB_TENSORRT_CONVERT_CONVERT_NODES_H_
#define TENSORFLOW_CONTRIB_TENSORRT_CONVERT_CONVERT_NODES_H_

#include <set>
#include <vector>
#include <utility>

#include "tensorflow/contrib/tensorrt/convert/inferShapes.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/lib/core/status.h"

namespace tensorrt {
namespace convert {

tensorflow::Status ConvertSubGraphToTensorRTNodeDef(
    const tensorflow::Graph& graph, const std::set<int>& subgraph_node_ids,
    const std::vector<std::pair<int, int>>&
        input_inds,  // {node_id, output_idx}
    const std::vector<std::pair<int, int>>&
        output_inds,  // {node_id, output_idx}
    size_t max_batch_size, size_t max_workspace_size, const ShapeMap& shape_map,
    tensorflow::NodeDef* trt_node);
}  // namespace convert
}  // namespace tensorrt

#endif  // TENSORFLOW_CONTRIB_TENSORRT_CONVERT_CONVERT_NODES_H_
tensorflow/contrib/tensorrt/convert/inferShapes.cc (new file, 125 lines)
@@ -0,0 +1,125 @@
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
    http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/contrib/tensorrt/convert/inferShapes.h"
#include <functional>
#include "tensorflow/core/common_runtime/shape_refiner.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.pb_text.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"

#define _TF_LOG_DEBUG ::tensorflow::internal::LogMessage(__FILE__, __LINE__, -1)

namespace tensorflow {
namespace trt {
// Returns the data types of a node's inputs (inp == true) or outputs,
// resolving list-typed, attr-typed, and numbered arguments from the
// NodeDef's attributes.
std::vector<tensorflow::DataType> getTypes(const tensorflow::OpDef& op,
                                           const tensorflow::NodeDef& nd,
                                           bool inp = true) {
  const auto& attrMap = nd.attr();
  auto getType = [&attrMap](decltype(
      op.input_arg(0)) a) -> std::vector<tensorflow::DataType> {
    std::vector<tensorflow::DataType> tvec;
    if (!a.type_list_attr().empty()) {  // get the list types
      const auto& tl = attrMap.at(a.type_list_attr()).list();
      int tsize = tl.type_size();
      tvec.reserve(tsize);
      for (int t = 0; t < tsize; t++) {
        tvec.push_back(tl.type(t));
      }
      return tvec;
    }
    tensorflow::DataType cType = tensorflow::DT_INVALID;
    if (a.type() != tensorflow::DT_INVALID) {  // get defined types
      cType = a.type();
    } else if (!a.type_attr().empty()) {
      cType = attrMap.at(a.type_attr()).type();
    }
    if (!a.number_attr().empty()) {  // numbered types
      int64 nTensors = attrMap.at(a.number_attr()).i();
      tvec = std::vector<tensorflow::DataType>(nTensors, cType);
      return tvec;
    }
    tvec.push_back(cType);
    return tvec;
  };
  std::vector<tensorflow::DataType> types;
  if (inp) {
    int n_inputs = op.input_arg_size();
    for (int i = 0; i < n_inputs; i++) {
      auto tout = getType(op.input_arg(i));
      LOG(DEBUG) << "Node= " << nd.name() << " #inputs" << tout.size();
      types.insert(types.end(), tout.begin(), tout.end());
    }
  } else {
    int n_outputs = op.output_arg_size();
    // types.resize(n_outputs);
    for (int i = 0; i < n_outputs; i++) {
      auto tout = getType(op.output_arg(i));
      LOG(DEBUG) << "Node= " << nd.name() << " #outputs" << tout.size();
      types.insert(types.end(), tout.begin(), tout.end());
    }
  }
  return types;
}

tensorflow::Status inferShapes(const tensorflow::GraphDef& graph_def,
                               const std::vector<std::string>& output_names,
                               ShapeMap& shapes) {
  tensorflow::Graph g(OpRegistry::Global());
  TF_RETURN_IF_ERROR(tensorflow::ConvertGraphDefToGraph(
      tensorflow::GraphConstructorOptions(), graph_def, &g));
  std::vector<tensorflow::Node*> POnodes;
  tensorflow::GetPostOrder(g, &POnodes);
  tensorflow::ShapeRefiner refiner(graph_def.versions().producer(),
                                   OpRegistry::Global());
  // Add nodes in reverse post order (topological order) so that every
  // node's inputs are refined before the node itself.
  for (auto n = POnodes.rbegin(); n != POnodes.rend(); ++n) {
    TF_CHECK_OK(refiner.AddNode(*n));
  }

  // Converts an inference-context shape handle to a PartialTensorShape.
  auto shape2PTS = [](tensorflow::shape_inference::InferenceContext* ic,
                      const tensorflow::shape_inference::ShapeHandle& sh)
      -> tensorflow::PartialTensorShape {
    std::vector<int64> dims;
    int64 rank = ic->Rank(sh);
    for (int64 i = 0; i < rank; i++) {
      auto dh = ic->Dim(sh, i);
      dims.push_back(ic->Value(dh));
    }
    return tensorflow::PartialTensorShape(dims);
  };
  for (const auto& n : POnodes) {
    auto ic = refiner.GetContext(n);
    if (ic) {
      int nOuts = ic->num_outputs();
      auto types = getTypes(n->op_def(), n->def(), false);
      std::vector<
          std::pair<tensorflow::PartialTensorShape, tensorflow::DataType>>
          SAT;
      for (int i = 0; i < nOuts; i++) {
        auto PTS = shape2PTS(ic, ic->output(i));
        SAT.push_back({PTS, types.at(i)});
      }
      shapes[n->name()] = SAT;
    } else {
      LOG(WARNING) << "Node " << n->name()
                   << " doesn't have an InferenceContext!";
    }
  }
  return tensorflow::Status::OK();
}
}  // namespace trt
}  // namespace tensorflow
tensorflow/contrib/tensorrt/convert/inferShapes.h (new file, 39 lines)
@@ -0,0 +1,39 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CONTRIB_TENSORRT_CONVERT_INFERSHAPES_H_
#define TENSORFLOW_CONTRIB_TENSORRT_CONVERT_INFERSHAPES_H_

#include <string>
#include <unordered_map>
#include <vector>
#include <utility>

#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/lib/core/status.h"

// Maps a node name to the (shape, dtype) pair of each of its outputs.
typedef std::unordered_map<std::string,
                           std::vector<std::pair<tensorflow::PartialTensorShape,
                                                 tensorflow::DataType>>>
    ShapeMap;
namespace tensorflow {
namespace trt {
tensorflow::Status inferShapes(const tensorflow::GraphDef& graph_def,
                               const std::vector<std::string>& output_names,
                               ShapeMap& shapes);
}  // namespace trt
}  // namespace tensorflow

#endif  // TENSORFLOW_CONTRIB_TENSORRT_CONVERT_INFERSHAPES_H_
tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc (new file, 183 lines)
@@ -0,0 +1,183 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/contrib/tensorrt/kernels/trt_engine_op.h"
#include <cuda_runtime_api.h>
#include <sstream>
#include "tensorflow/contrib/tensorrt/log/trt_logger.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/stream_executor.h"
// Use TF logging for TensorRT information


namespace tensorflow {
static ::tensorflow::tensorrt::Logger gLogger;

using namespace nvinfer1;

namespace tensorrt {

TRTEngineOp::TRTEngineOp(OpKernelConstruction* context) : OpKernel(context) {
  // char* gieModelStream{nullptr};
  // size_t size{0};

  // read serialized_engine
  std::string serialized_engine;
  OP_REQUIRES_OK(context,
                 context->GetAttr("serialized_engine", &serialized_engine));

  // register input/output node names in trt_sub_graph
  OP_REQUIRES_OK(context, context->GetAttr("input_nodes", &input_nodes_));
  OP_REQUIRES_OK(context, context->GetAttr("output_nodes", &output_nodes_));

  // TODO(samikama) runtime should be taken from a resourcemanager as well.
  // Only engine should be in the op and context and runtime should be taken
  // from resourcemanager
  IRuntime* infer = createInferRuntime(gLogger);
  trt_engine_ptr_.reset(infer->deserializeCudaEngine(
      serialized_engine.c_str(), serialized_engine.size(), nullptr));

  trt_context_ptr_.reset(trt_engine_ptr_->createExecutionContext());
  // runtime is safe to delete after engine creation
  infer->destroy();
  std::stringstream oss;
  // debug: iterate through all binding instances
  for (int i = 0; i < trt_engine_ptr_->getNbBindings(); i++) {
    LOG(INFO) << "index: " << i
              << ", binding name: " << trt_engine_ptr_->getBindingName(i);

    if (trt_engine_ptr_->bindingIsInput(i)) {
      LOG(INFO) << "INPUT";
    } else {
      LOG(INFO) << "OUTPUT";
    }
    oss << "Dimension: ";
    auto dims = trt_engine_ptr_->getBindingDimensions(i);
    oss << " nbDims: " << dims.nbDims << " -> ";
    for (int j = 0; j < Dims::MAX_DIMS; j++) {
      oss << dims.d[j] << ", ";
    }
    LOG(INFO) << oss.str();
    oss.str("");
    switch (trt_engine_ptr_->getBindingDataType(i)) {
      case nvinfer1::DataType::kFLOAT:
        LOG(INFO) << "data type float" << std::endl;
        break;
      case nvinfer1::DataType::kHALF:
        LOG(INFO) << "data type half" << std::endl;
        break;
      case nvinfer1::DataType::kINT8:
        LOG(INFO) << "data type int8" << std::endl;
        break;
    }
  }

  // CHECK_NE(cudaStreamCreate(&stream_), 0);  // logic here is wrong
  // cudaStreamCreate(&stream_);
}

void TRTEngineOp::Compute(OpKernelContext* context) {
  int nbBindings = context->num_inputs() + context->num_outputs();
  // TODO(jjsjann123) multiple input/output
  std::vector<void*> buffers(nbBindings);

  size_t bindingIndex;
  int nbBatch = 0;
  bool valid = true;
  for (int i = 0; i < context->num_inputs(); i++) {
    // Grab the input tensor
    bindingIndex = trt_engine_ptr_->getBindingIndex(input_nodes_[i].c_str());

    const Tensor& input_tensor = context->input(i);
    const TensorShape& input_shape = input_tensor.shape();
    if (i == 0) {
      nbBatch = input_shape.dim_size(0);
    } else if (nbBatch != input_shape.dim_size(0)) {
      valid = false;
      break;
    }
    // int64 input_shape.dim_size(int d)
    // int input_shape.dims()
    switch (trt_engine_ptr_->getBindingDataType(bindingIndex)) {
      case nvinfer1::DataType::kFLOAT:
        LOG(INFO) << "float";
        buffers[bindingIndex] = (void*)(input_tensor.flat<float>().data());
        break;
      case nvinfer1::DataType::kHALF:
        LOG(INFO) << "half";
        // buffers[bindingIndex] = (void*)input_tensor.flat<float16>().data();
        break;
      case nvinfer1::DataType::kINT8:
        LOG(INFO) << "int8";
        // buffers[bindingIndex] = (void*)input_tensor.flat<int8>().data();
        break;
    }
  }

  if (!valid) LOG(WARNING) << "input data has an inconsistent batch size";

  for (int i = 0; i < static_cast<int>(output_nodes_.size()); i++) {
    // It is unfortunate that the output buffer has to be reallocated on
    // every run. Create an output tensor.
    bindingIndex = trt_engine_ptr_->getBindingIndex(output_nodes_[i].c_str());
    Tensor* output_tensor = NULL;

    TensorShape output_shape;
    if (bindingIndex != -1) {
      LOG(INFO) << "got binding " << bindingIndex;
      auto dims = trt_engine_ptr_->getBindingDimensions(bindingIndex);
      std::vector<int> trt_shape(dims.nbDims + 1);
      trt_shape[0] = nbBatch;
      for (int j = 0; j < dims.nbDims; j++) trt_shape[j + 1] = dims.d[j];
      TensorShapeUtils::MakeShape(trt_shape.data(), trt_shape.size(),
                                  &output_shape);
    } else {
      LOG(INFO) << "no binding ";
      break;
    }

    OP_REQUIRES_OK(context,
                   context->allocate_output(i, output_shape, &output_tensor));
    // buffers[bindingIndex] = (void*)output_tensor->flat<float>();
    // buffers[bindingIndex] = output_tensor->flat<float>().data();
    switch (trt_engine_ptr_->getBindingDataType(bindingIndex)) {
      case nvinfer1::DataType::kFLOAT:
        LOG(INFO) << "float";
        buffers[bindingIndex] =
            reinterpret_cast<void*>(output_tensor->flat<float>().data());
        break;
      case nvinfer1::DataType::kHALF:
        LOG(INFO) << "half";
        // buffers[bindingIndex] = (void*)output_tensor->flat<float16>().data();
        break;
      case nvinfer1::DataType::kINT8:
        LOG(INFO) << "int8";
        // buffers[bindingIndex] = (void*)output_tensor->flat<int8>().data();
        break;
    }
  }
  // copied from cuda_kernel_helper since it seems only valid in *.cu.cc files
  const cudaStream_t* stream = CHECK_NOTNULL(
      reinterpret_cast<const cudaStream_t*>(context->op_device_context()
                                                ->stream()
                                                ->implementation()
                                                ->CudaStreamMemberHack()));

  trt_context_ptr_->enqueue(nbBatch, &buffers[0], *stream, nullptr);
  cudaStreamSynchronize(*stream);
}

REGISTER_KERNEL_BUILDER(Name("TRTEngineOp").Device(DEVICE_GPU), TRTEngineOp);
}  // namespace tensorrt
}  // namespace tensorflow
tensorflow/contrib/tensorrt/kernels/trt_engine_op.h (new file, 55 lines)
@@ -0,0 +1,55 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CONTRIB_TENSORRT_KERNELS_TRT_ENGINE_OP_H_
#define TENSORFLOW_CONTRIB_TENSORRT_KERNELS_TRT_ENGINE_OP_H_

#include <NvInfer.h>
#include <cuda_runtime_api.h>
#include <memory>
#include <string>
#include <vector>
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"

namespace tensorflow {

namespace tensorrt {
class Logger;
// Runs a pre-built, serialized TensorRT engine over the op's inputs.
class TRTEngineOp : public OpKernel {
 public:
  explicit TRTEngineOp(OpKernelConstruction* context);

  void Compute(OpKernelContext* context) override;

 private:
  // TensorRT objects are released via destroy() rather than delete, so
  // wrap them in a unique_ptr with a custom deleter.
  template <typename T>
  struct Destroyer {
    void operator()(T* d) { d->destroy(); }
  };
  template <typename T>
  using destroyed_ptr = std::unique_ptr<T, Destroyer<T>>;
  destroyed_ptr<nvinfer1::ICudaEngine> trt_engine_ptr_;
  // TODO(samikama) context should go to a resource manager!
  destroyed_ptr<nvinfer1::IExecutionContext> trt_context_ptr_;
  std::vector<string> input_nodes_;
  std::vector<string> output_nodes_;
};

}  // namespace tensorrt

}  // namespace tensorflow

#endif  // TENSORFLOW_CONTRIB_TENSORRT_KERNELS_TRT_ENGINE_OP_H_
tensorflow/contrib/tensorrt/log/trt_logger.cc (new file, 56 lines)
@@ -0,0 +1,56 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/contrib/tensorrt/log/trt_logger.h"
// Use TF logging for TensorRT information
#include "tensorflow/core/platform/logging.h"

#define _TF_LOG_DEBUG ::tensorflow::internal::LogMessage(__FILE__, __LINE__, -1)
//------------------------------------------------------------------------------
namespace tensorflow {

//------------------------------------------------------------------------------
namespace tensorrt {

void Logger::log(Severity severity, const char* msg) {
  // Map TensorRT severities onto the corresponding TF logging levels.
  switch (severity) {
    case Severity::kINFO: {  // mark TRT info messages as debug!
      LOG(DEBUG) << msg;
      break;
    }
    case Severity::kWARNING: {
      LOG(WARNING) << msg;
      break;
    }
    case Severity::kERROR: {
      LOG(ERROR) << msg;
      break;
    }
    case Severity::kINTERNAL_ERROR: {
      LOG(FATAL) << msg;
      break;
    }
    // This is useless for now, but would catch it in the future if the enum
    // changes. It is always good to have a default case!
    default: {
      LOG(FATAL) << name_ << "Got unknown severity level from TRT: " << msg;
      break;
    }
  }
}

}  // namespace tensorrt

}  // namespace tensorflow
tensorflow/contrib/tensorrt/log/trt_logger.h (new file, 41 lines)
@@ -0,0 +1,41 @@
// -*- c++ -*-
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CONTRIB_TENSORRT_LOG_TRT_LOGGER_H_
#define TENSORFLOW_CONTRIB_TENSORRT_LOG_TRT_LOGGER_H_

// Use TF logging for TensorRT information
#include <NvInfer.h>
#include <string>

//------------------------------------------------------------------------------
namespace tensorflow {

//------------------------------------------------------------------------------
namespace tensorrt {

// Logger for TensorRT (GIE) info/warning/error messages
class Logger : public nvinfer1::ILogger {
  void log(nvinfer1::ILogger::Severity severity, const char* msg) override;

 private:
  std::string name_;
};

}  // namespace tensorrt

}  // namespace tensorflow
#endif  // TENSORFLOW_CONTRIB_TENSORRT_LOG_TRT_LOGGER_H_
tensorflow/contrib/tensorrt/ops/trt_engine_op.cc (new file, 37 lines)
@@ -0,0 +1,37 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/framework/op.h"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/framework/shape_inference.h"
|
||||
#include "tensorflow/core/framework/tensor_shape.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
namespace shape_inference {
|
||||
extern Status TRTEngineOpShapeInference(InferenceContext* c);
|
||||
}
|
||||
|
||||
REGISTER_OP("TRTEngineOp")
|
||||
.Attr("serialized_engine: string")
|
||||
.Attr("input_nodes: list(string)")
|
||||
.Attr("output_nodes: list(string)")
|
||||
.Attr("InT: list({int8, float16, float32})")
|
||||
.Attr("OutT: list({int8, float16, float32})")
|
||||
.Input("in_tensor: InT")
|
||||
.Output("out_tensor: OutT")
|
||||
.SetShapeFn(shape_inference::TRTEngineOpShapeInference);
|
||||
|
||||
} // namespace tensorflow
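
For orientation, an illustrative Python sketch (not part of this commit) of a NodeDef that matches the registration above; the engine bytes and the node names are placeholders.

# Illustrative only: builds a NodeDef matching the TRTEngineOp registration
# above. The engine string and node names are placeholders.
from tensorflow.core.framework import node_def_pb2, types_pb2

node = node_def_pb2.NodeDef(name="my_trt_engine", op="TRTEngineOp")
node.input.extend(["input0"])  # feeds "in_tensor: InT"
node.attr["serialized_engine"].s = b"<serialized TensorRT engine bytes>"
node.attr["input_nodes"].list.s.extend([b"input0"])
node.attr["output_nodes"].list.s.extend([b"output0"])
node.attr["InT"].list.type.extend([types_pb2.DT_FLOAT])
node.attr["OutT"].list.type.extend([types_pb2.DT_FLOAT])
print(node)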
8
tensorflow/contrib/tensorrt/python/__init__.py
Normal file
@ -0,0 +1,8 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

# pylint: disable=unused-import,wildcard-import
from tensorflow.contrib.tensorrt.python.ops import trt_engine_op
from tensorflow.contrib.tensorrt.python.trt_convert import CreateInferenceGraph
# pylint: enable=unused-import,wildcard-import
35
tensorflow/contrib/tensorrt/python/ops/trt_engine_op.py
Normal file
@ -0,0 +1,35 @@
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import platform

if platform.system() != "Windows":
  # pylint: disable=wildcard-import,unused-import,g-import-not-at-top
  from tensorflow.contrib.tensorrt.ops.gen_trt_engine_op import *

  from tensorflow.contrib.util import loader
  from tensorflow.python.platform import resource_loader
  # pylint: enable=wildcard-import,unused-import,g-import-not-at-top

  _trt_engine_op = loader.load_op_library(
      resource_loader.get_path_to_datafile("_trt_engine_op.so"))
else:
  raise RuntimeError("Windows platforms are not supported")
91
tensorflow/contrib/tensorrt/python/trt_convert.py
Normal file
@ -0,0 +1,91 @@
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""Exposes the Python wrapper conversion to trt_graph."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

# pylint: disable=unused-import,wildcard-import,line-too-long
from tensorflow.core.framework import graph_pb2
from tensorflow.python.framework import errors
from tensorflow.python.framework import errors_impl as _impl
from tensorflow.contrib.tensorrt.wrap_conversion import trt_convert
from tensorflow.python.util import compat
import tensorflow as tf
from tensorflow.python.grappler import tf_optimizer
from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.python.framework import meta_graph
from tensorflow.python.framework import ops


def CreateInferenceGraph(input_graph_def, outputs, max_batch_size=1,
                         max_workspace_size=2 << 20):
  """Python wrapper for the TRT transformation.

  Args:
    input_graph_def: GraphDef object containing a model to be transformed.
    outputs: list of node names for the model outputs.
    max_batch_size: maximum size for the input batch.
    max_workspace_size: parameter to control memory allocation (in bytes).

  Returns:
    New GraphDef with TRTEngineOps placed in the graph, replacing subgraphs.
  """

  # with errors.raise_exception_on_not_ok_status() as status:
  #   output_graph_def_string = trt_convert(
  #       input_graph_def_string, outputs,
  #       max_batch_size, max_workspace_size, status)
  g = tf.Graph()
  with g.as_default():
    tf.import_graph_def(input_graph_def, name="")
    rewriter_config = rewriter_config_pb2.RewriterConfig()
    rewriter_config.optimizers.append('layout')
    rewriter_config.optimizers.append('constfold')

    # Mark output nodes as fetches.
    train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
    for node_name in outputs:
      out_node = g.get_operation_by_name(node_name)
      for i in range(len(out_node.outputs)):
        train_op.append(out_node.outputs[i])

    # Constant folding.
    mg = meta_graph.create_meta_graph_def(graph=g)
    meta_graph.add_collection_def(mg, ops.GraphKeys.TRAIN_OP)
    optimized_graph_def_str = \
        tf_optimizer.OptimizeGraph(rewriter_config, mg).SerializeToString()

  # TODO(sami): Fix this when we can return status from C++ library.
  # There is a problem with the TF internal library setup that doesn't allow
  # us to return a status object from C++. Thus we return a pair of strings
  # where the first one is the encoded status and the second one is the
  # transformed graph's protobuf string.
  out = trt_convert(optimized_graph_def_str, outputs,
                    max_batch_size, max_workspace_size)
  status = out[0]
  output_graph_def_string = out[1]
  del optimized_graph_def_str  # Save some memory.
  if len(status) < 2:
    raise _impl.UnknownError(None, None, status)
  if status[:2] != "OK":
    msg = status.split(";")
    if len(msg) == 1:
      raise RuntimeError("Status message is malformed {}".format(status))
    raise _impl._make_specific_exception(None, None, ";".join(msg[1:]),
                                         int(msg[0]))
  output_graph_def = graph_pb2.GraphDef()
  output_graph_def.ParseFromString(output_graph_def_string)
  del output_graph_def_string  # Save some memory.
  return output_graph_def
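A minimal usage sketch (not part of this commit) of CreateInferenceGraph above; the frozen-graph path and the "logits" output name are hypothetical.

# Hypothetical usage of CreateInferenceGraph; the file path and output
# name are placeholders, not part of this commit.
import tensorflow as tf
from tensorflow.contrib import tensorrt as trt

graph_def = tf.GraphDef()
with tf.gfile.GFile("frozen_model.pb", "rb") as f:
  graph_def.ParseFromString(f.read())

trt_graph = trt.CreateInferenceGraph(graph_def, outputs=["logits"],
                                     max_batch_size=8,
                                     max_workspace_size=1 << 30)
with tf.Graph().as_default():
  tf.import_graph_def(trt_graph, name="")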
259
tensorflow/contrib/tensorrt/segment/segment.cc
Normal file
@ -0,0 +1,259 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/contrib/tensorrt/segment/segment.h"

#include <set>
#include <string>
#include <unordered_map>
#include <vector>

#include "tensorflow/contrib/tensorrt/segment/union_find.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"

//------------------------------------------------------------------------------
namespace tensorrt {
namespace segment {

//------------------------------------------------------------------------------
namespace {

//------------------------------------------------------------------------------
bool CanContractEdge(const tensorflow::Edge* edge,
                     const tensorflow::Graph& graph) {
  const tensorflow::Node* src = edge->src();
  const tensorflow::Node* dst = edge->dst();

  // Can't contract edge if doing so would cause a cycle in the
  // graph. So, if there is a directed path from 'src' to 'dst', other
  // than 'edge' (or any other direct edge from 'src' to 'dst'), then
  // combining 'src' and 'dst' will cause a cycle along that path.
  //
  // In practice, to avoid modifying the graph and to take advantage
  // of existing graph functions, we perform an equivalent check:
  // 1. Get all nodes incoming to 'dst', excluding 'src'.
  // 2. Reverse DFS from those nodes.
  // 3. If the reverse DFS reaches 'src' then we have a cycle.
  std::vector<tensorflow::Node*> dfs_start_nodes;
  for (tensorflow::Node* node : dst->in_nodes()) {
    if (node != src) {
      dfs_start_nodes.push_back(node);
    }
  }

  bool is_cycle = false;
  if (!dfs_start_nodes.empty()) {
    tensorflow::ReverseDFSFrom(graph, dfs_start_nodes, {},
                               [&is_cycle, src](tensorflow::Node* node) {
                                 if (node == src) {
                                   is_cycle = true;
                                 }
                               });
  }

  return !is_cycle;
}

//------------------------------------------------------------------------------
void ContractEdge(tensorflow::Edge* edge, tensorflow::Graph* graph,
                  std::vector<const tensorflow::Edge*>* remove_edges) {
  // Transfer all inputs and outputs of 'dst' to 'src' except edges
  // connecting the two.
  tensorflow::Node* src = edge->src();
  tensorflow::Node* dst = edge->dst();

  // We can use '0' for input/output index because we don't need them
  // to be accurate for the way we are using the graph.
  std::vector<const tensorflow::Edge*> in_edges(dst->in_edges().begin(),
                                                dst->in_edges().end());
  for (const tensorflow::Edge* in_edge : in_edges) {
    if (in_edge->src() != src) {
      tensorflow::Edge* e = const_cast<tensorflow::Edge*>(in_edge);
      if (e->src() == graph->source_node()) {
        graph->AddEdge(e->src(), e->src_output(), src,
                       tensorflow::Graph::kControlSlot);
      } else {
        graph->AddEdge(e->src(), e->src_output(), src, 0 /* input index */);
      }
    }
  }

  std::vector<const tensorflow::Edge*> out_edges(dst->out_edges().begin(),
                                                 dst->out_edges().end());
  for (const tensorflow::Edge* out_edge : out_edges) {
    tensorflow::Edge* e = const_cast<tensorflow::Edge*>(out_edge);
    if (e->dst() == graph->sink_node()) {
      graph->AddEdge(src, tensorflow::Graph::kControlSlot, e->dst(),
                     e->dst_input());
    } else {
      graph->AddEdge(src, 0 /* output index */, e->dst(), e->dst_input());
    }
  }

  // Return the edges that must be removed to disconnect 'dst' from
  // the graph. We don't actually remove 'dst' since the caller holds
  // references to all the nodes.
  for (const auto& in_edge : dst->in_edges()) {
    remove_edges->push_back(in_edge);
  }
  for (const auto& out_edge : dst->out_edges()) {
    remove_edges->push_back(out_edge);
  }
}

}  // namespace

//------------------------------------------------------------------------------
tensorflow::Status SegmentGraph(
    const tensorflow::GraphDef& gdef,
    const std::function<bool(const tensorflow::NodeDef&)>& candidate_fn,
    const SegmentOptions& options, SegmentNodesVector* segments) {
  // Create a Graph representation of the GraphDef.
  tensorflow::FunctionLibraryDefinition flib(tensorflow::OpRegistry::Global(),
                                             gdef.library());
  tensorflow::Graph graph(flib);
  TF_RETURN_IF_ERROR(tensorflow::ConvertGraphDefToGraph(
      tensorflow::GraphConstructorOptions(), gdef, &graph));

  // tensorflow::DumpGraph("Pre-Segment", &graph);

  // Use a union-find to collect the nodes that belong to the same
  // segment. A node value of nullptr indicates that the node is not a
  // candidate for TRT.
  std::vector<UnionFind<tensorflow::Node*>> node_segments;
  for (int i = 0; i < graph.num_node_ids(); ++i) {
    tensorflow::Node* node = graph.FindNodeId(i);
    if (!candidate_fn(node->def())) {
      node = nullptr;
    }
    node_segments.emplace_back(node);
  }

  // Visit nodes in reverse topological order and use edge
  // contraction to merge candidate nodes.
  std::vector<tensorflow::Node*> order;
  tensorflow::GetPostOrder(graph, &order);

  for (const tensorflow::Node* node : order) {
    // All output nodes of 'node' have been visited...
    VLOG(2) << "Trying node " << node->name();

    // 'node' must be a TRT candidate...
    if (node_segments[node->id()].Value() == nullptr) {
      VLOG(2) << "... not a TRT candidate";
      continue;
    }

    // Contract output edges to combine 'node' with output
    // nodes. Iterate since combining two nodes may unblock other
    // combining.
    while (true) {
      std::set<const tensorflow::Edge*> contract_edges;
      for (const tensorflow::Edge* out_edge : node->out_edges()) {
        VLOG(2) << "... out node " << out_edge->dst()->name();

        // Out node must be a TRT candidate...
        if (node_segments[out_edge->dst()->id()].Value() == nullptr) {
          VLOG(2) << "... ... not a TRT candidate";
          continue;
        }

        if (CanContractEdge(out_edge, graph)) {
          VLOG(2) << "... ... can contract";
          contract_edges.insert(out_edge);
        } else {
          VLOG(2) << "... ... cannot contract, would form cycle";
        }
      }

      if (contract_edges.empty()) {
        break;
      }

      // Contract edges and collect the adjacent nodes into the same
      // segment/subgraph.
      while (!contract_edges.empty()) {
        const tensorflow::Edge* contract_edge = *contract_edges.begin();
        const tensorflow::Node* src = contract_edge->src();
        const tensorflow::Node* dst = contract_edge->dst();

        VLOG(2) << "Merge " << src->name() << " <- " << dst->name();
        node_segments[src->id()].Merge(&node_segments[dst->id()]);

        // Contracting the edge leaves disconnected graph edges.
        // Remove these from the graph and from 'contract_edges' so we
        // don't visit them again.
        tensorflow::Edge* e = const_cast<tensorflow::Edge*>(contract_edge);
        std::vector<const tensorflow::Edge*> remove_edges;
        ContractEdge(e, &graph, &remove_edges);

        for (const tensorflow::Edge* r : remove_edges) {
          contract_edges.erase(r);
          graph.RemoveEdge(r);
        }
      }
    }
  }

  // Collect the segments/subgraphs. Each subgraph is represented by a
  // set of the names of the nodes in that subgraph.
  std::unordered_map<std::string, std::set<std::string>> sg_map;
  for (auto& u : node_segments) {
    if ((u.Value() != nullptr) && (u.ParentValue() != nullptr)) {
      sg_map[u.ParentValue()->name()].insert(u.Value()->name());
    }
  }

  // Clean up the graph to remove disconnected nodes before dumping it.
  if (VLOG_IS_ON(2)) {
    for (tensorflow::Node* node : graph.nodes()) {
      if (node->in_edges().empty() && node->out_edges().empty()) {
        graph.RemoveNode(node);
      }
    }
    // tensorflow::DumpGraph("Post-Segment", &graph);
  }

  // Convert the segments into the expected return format.
  for (const auto& itr : sg_map) {
    const auto& segment_node_names = itr.second;
    if (VLOG_IS_ON(1)) {
      std::string s;
      for (const auto& name : segment_node_names) {
        s += " " + name;
      }
      VLOG(1) << "Segment " << segments->size() << ":" << s;
    }

    // Don't use small segments.
    if (static_cast<int>(segment_node_names.size()) <
        options.minimum_segment_size) {
      VLOG(1) << "Segment " << segments->size() << " has only "
              << segment_node_names.size() << " nodes, dropping";
      continue;
    }

    segments->emplace_back(segment_node_names);
  }

  return tensorflow::Status::OK();
}

}  // namespace segment
}  // namespace tensorrt
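For orientation, an illustrative Python sketch (not part of this commit) of the idea implemented above: contract edges between candidate nodes unless another directed path would turn the merge into a cycle. The toy graph, candidate set, and helper names are placeholders.

# Illustrative sketch of cycle-checked edge contraction; placeholders only.
def would_cycle(edges, src, dst):
  # Merging src and dst creates a cycle iff another directed path
  # src ~> dst exists besides the direct edge(s).
  stack = [b for a, b in edges if a == src and b != dst]
  seen = set()
  while stack:
    n = stack.pop()
    if n == dst:
      return True
    if n not in seen:
      seen.add(n)
      stack.extend(b for a, b in edges if a == n)
  return False

edges = [("add0", "add2"), ("add1", "add2"), ("add2", "add3")]
candidates = {"add0", "add1", "add2", "add3"}
segment_of = {n: n for n in candidates}

for src, dst in edges:
  if src in candidates and dst in candidates:
    if not would_cycle(edges, src, dst):
      merged, target = segment_of[dst], segment_of[src]
      for n, seg in segment_of.items():
        if seg == merged:
          segment_of[n] = target

print(segment_of)  # all four nodes end up in one segment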
53
tensorflow/contrib/tensorrt/segment/segment.h
Normal file
@ -0,0 +1,53 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CONTRIB_TENSORRT_SEGMENT_SEGMENT_H_
#define TENSORFLOW_CONTRIB_TENSORRT_SEGMENT_SEGMENT_H_

#include <functional>
#include <set>
#include <string>
#include <vector>

#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/lib/core/status.h"

namespace tensorrt {
namespace segment {

using SegmentNodesVector = std::vector<std::set<std::string>>;

struct SegmentOptions {
  // Segment must contain at least this many nodes.
  int minimum_segment_size = 2;
};

// Get the subgraphs of a graph that can be handled by TensorRT.
//
// @param gdef The GraphDef describing the network.
// @param candidate_fn A function that returns true for a NodeDef if
// that node can be handled by TensorRT.
// @param segments Returns the TensorRT segments/subgraphs. Each entry
// in the vector describes a subgraph by giving a set of the names of
// all the NodeDefs in that subgraph.
// @return the status.
tensorflow::Status SegmentGraph(
    const tensorflow::GraphDef& gdef,
    const std::function<bool(const tensorflow::NodeDef&)>& candidate_fn,
    const SegmentOptions& options, SegmentNodesVector* segments);

}  // namespace segment
}  // namespace tensorrt

#endif  // TENSORFLOW_CONTRIB_TENSORRT_SEGMENT_SEGMENT_H_
363
tensorflow/contrib/tensorrt/segment/segment_test.cc
Normal file
@ -0,0 +1,363 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/contrib/tensorrt/segment/segment.h"
#include "tensorflow/c/c_api.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/test.h"

//------------------------------------------------------------------------------
using namespace tensorflow;

namespace tensorrt {
namespace segment {
namespace test {

class SegmentTest : public ::testing::Test {
 public:
  bool GetGraphDef(TF_Graph* graph, tensorflow::GraphDef* graph_def);

  TF_Operation* Placeholder(TF_Graph* graph, TF_Status* s, const char* name);
  TF_Operation* Add(TF_Operation* l, TF_Operation* r, TF_Graph* graph,
                    TF_Status* s, const char* name);

  std::function<bool(const NodeDef&)> MakeCandidateFn(
      const std::set<std::string>& node_names);

 protected:
  void PlaceholderHelper(TF_Graph* graph, TF_Status* s, const char* name,
                         TF_Operation** op);
  void AddHelper(TF_Operation* l, TF_Operation* r, TF_Graph* graph,
                 TF_Status* s, const char* name, TF_Operation** op,
                 bool check);

  SegmentOptions default_options_;
};

bool SegmentTest::GetGraphDef(TF_Graph* graph,
                              tensorflow::GraphDef* graph_def) {
  TF_Status* s = TF_NewStatus();
  TF_Buffer* buffer = TF_NewBuffer();
  TF_GraphToGraphDef(graph, buffer, s);
  bool ret = TF_GetCode(s) == TF_OK;
  EXPECT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
  if (ret) ret = graph_def->ParseFromArray(buffer->data, buffer->length);
  TF_DeleteBuffer(buffer);
  TF_DeleteStatus(s);
  return ret;
}

std::function<bool(const NodeDef&)> SegmentTest::MakeCandidateFn(
    const std::set<std::string>& node_names) {
  return [node_names](const NodeDef& node) -> bool {
    return node_names.find(node.name()) != node_names.end();
  };
}

void SegmentTest::PlaceholderHelper(TF_Graph* graph, TF_Status* s,
                                    const char* name, TF_Operation** op) {
  TF_OperationDescription* desc = TF_NewOperation(graph, "Placeholder", name);
  TF_SetAttrType(desc, "dtype", TF_INT32);
  *op = TF_FinishOperation(desc, s);
  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
  ASSERT_NE(*op, nullptr);
}

TF_Operation* SegmentTest::Placeholder(TF_Graph* graph, TF_Status* s,
                                       const char* name) {
  TF_Operation* op;
  PlaceholderHelper(graph, s, name, &op);
  return op;
}

void SegmentTest::AddHelper(TF_Operation* l, TF_Operation* r, TF_Graph* graph,
                            TF_Status* s, const char* name, TF_Operation** op,
                            bool check) {
  TF_OperationDescription* desc = TF_NewOperation(graph, "AddN", name);
  TF_Output add_inputs[2] = {{l, 0}, {r, 0}};
  TF_AddInputList(desc, add_inputs, 2);
  *op = TF_FinishOperation(desc, s);
  if (check) {
    ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
    ASSERT_NE(*op, nullptr);
  }
}

TF_Operation* SegmentTest::Add(TF_Operation* l, TF_Operation* r,
                               TF_Graph* graph, TF_Status* s,
                               const char* name) {
  TF_Operation* op;
  AddHelper(l, r, graph, s, name, &op, true);
  return op;
}

//------------------------------------------------------------------------------
TEST_F(SegmentTest, Empty) {
  TF_Graph* graph = TF_NewGraph();

  GraphDef graph_def;
  ASSERT_TRUE(GetGraphDef(graph, &graph_def));

  SegmentNodesVector segments;
  ASSERT_EQ(
      SegmentGraph(graph_def, MakeCandidateFn({}), default_options_,
                   &segments),
      tensorflow::Status::OK());

  // Expect no segments/subgraphs.
  EXPECT_TRUE(segments.empty());
}

//------------------------------------------------------------------------------
TEST_F(SegmentTest, Simple) {
  TF_Status* s = TF_NewStatus();
  TF_Graph* graph = TF_NewGraph();

  //          feed
  //         //  ||
  //       add0  add1
  //        |  |  /
  //        |  add2
  //        |  / ||
  //       add3  add4
  //        |   /
  //       <sink>
  //
  TF_Operation* feed = Placeholder(graph, s, "feed");
  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
  EXPECT_EQ(string("feed"), string(TF_OperationName(feed)));

  TF_Operation* add0 = Add(feed, feed, graph, s, "add0");
  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
  TF_Operation* add1 = Add(feed, feed, graph, s, "add1");
  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
  TF_Operation* add2 = Add(add0, add1, graph, s, "add2");
  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
  TF_Operation* add3 = Add(add0, add2, graph, s, "add3");
  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
  EXPECT_EQ(string("add3"), string(TF_OperationName(add3)));
  TF_Operation* add4 = Add(add2, add2, graph, s, "add4");
  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
  EXPECT_EQ(string("add4"), string(TF_OperationName(add4)));

  GraphDef graph_def;
  ASSERT_TRUE(GetGraphDef(graph, &graph_def));

  SegmentNodesVector segments;
  ASSERT_EQ(
      SegmentGraph(graph_def,
                   MakeCandidateFn({"add0", "add1", "add2", "add3", "add4"}),
                   default_options_, &segments),
      tensorflow::Status::OK());

  // Expect all Add operations to be collapsed into a single segment.
  ASSERT_EQ(segments.size(), 1);
  std::vector<std::string> expected{"add0", "add1", "add2", "add3", "add4"};
  for (const auto& ex : expected) {
    EXPECT_TRUE(segments[0].find(ex) != segments[0].end())
        << "Missing expected node " << ex;
  }
}

//------------------------------------------------------------------------------
TEST_F(SegmentTest, AvoidCycle) {
  TF_Status* s = TF_NewStatus();
  TF_Graph* graph = TF_NewGraph();

  // add2 is not a TRT candidate so add0/add3 cannot be formed as a
  // subgraph.
  //
  //          feed
  //         //  ||
  //       add0  add1
  //        |  |  /
  //        |  add2
  //        |  / ||
  //       add3  add4
  //        |   /
  //       <sink>
  //
  TF_Operation* feed = Placeholder(graph, s, "feed");
  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
  EXPECT_EQ(string("feed"), string(TF_OperationName(feed)));

  TF_Operation* add0 = Add(feed, feed, graph, s, "add0");
  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
  TF_Operation* add1 = Add(feed, feed, graph, s, "add1");
  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
  TF_Operation* add2 = Add(add0, add1, graph, s, "add2");
  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
  TF_Operation* add3 = Add(add0, add2, graph, s, "add3");
  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
  EXPECT_EQ(string("add3"), string(TF_OperationName(add3)));
  TF_Operation* add4 = Add(add2, add2, graph, s, "add4");
  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
  EXPECT_EQ(string("add4"), string(TF_OperationName(add4)));

  GraphDef graph_def;
  ASSERT_TRUE(GetGraphDef(graph, &graph_def));

  SegmentNodesVector segments;
  ASSERT_EQ(
      SegmentGraph(graph_def, MakeCandidateFn({"add0", "add1", "add3", "add4"}),
                   default_options_, &segments),
      tensorflow::Status::OK());

  // Expect no subgraphs.
  EXPECT_EQ(segments.size(), 0);
}

//------------------------------------------------------------------------------
TEST_F(SegmentTest, Multiple) {
  TF_Status* s = TF_NewStatus();
  TF_Graph* graph = TF_NewGraph();

  // add5 is not a TRT candidate so two subgraphs should be formed.
  //
  //            feed
  //          //  ||  ||
  //       add0  add1  add7
  //        |  |  /  /  ||
  //        |  add2-----add5  add8
  //        |  /    |    |    |  |
  //       add3    add4      add6
  //        |       |       /
  //            <sink>
  //
  TF_Operation* feed = Placeholder(graph, s, "feed");
  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
  EXPECT_EQ(string("feed"), string(TF_OperationName(feed)));

  TF_Operation* add0 = Add(feed, feed, graph, s, "add0");
  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
  TF_Operation* add1 = Add(feed, feed, graph, s, "add1");
  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
  TF_Operation* add7 = Add(feed, feed, graph, s, "add7");
  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
  TF_Operation* add2 = Add(add0, add1, graph, s, "add2");
  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
  TF_Operation* add5 = Add(add2, add7, graph, s, "add5");
  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
  TF_Operation* add8 = Add(add7, add7, graph, s, "add8");
  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
  TF_Operation* add3 = Add(add0, add2, graph, s, "add3");
  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
  EXPECT_EQ(string("add3"), string(TF_OperationName(add3)));
  TF_Operation* add4 = Add(add2, add5, graph, s, "add4");
  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
  EXPECT_EQ(string("add4"), string(TF_OperationName(add4)));
  TF_Operation* add6 = Add(add5, add8, graph, s, "add6");
  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
  EXPECT_EQ(string("add6"), string(TF_OperationName(add6)));

  GraphDef graph_def;
  ASSERT_TRUE(GetGraphDef(graph, &graph_def));

  SegmentNodesVector segments;
  ASSERT_EQ(SegmentGraph(graph_def,
                         MakeCandidateFn({"add0", "add1", "add2", "add3",
                                          "add4", "add6", "add7", "add8"}),
                         default_options_, &segments),
            tensorflow::Status::OK());

  // Expect two subgraphs.
  EXPECT_EQ(segments.size(), 2);

  std::vector<std::string> expected0{"add0", "add1", "add2", "add3"};
  for (const auto& ex : expected0) {
    EXPECT_TRUE(segments[0].find(ex) != segments[0].end())
        << "Missing expected node " << ex;
  }

  std::vector<std::string> expected1{"add6", "add8"};
  for (const auto& ex : expected1) {
    EXPECT_TRUE(segments[1].find(ex) != segments[1].end())
        << "Missing expected node " << ex;
  }
}

//------------------------------------------------------------------------------
TEST_F(SegmentTest, BigIfElse) {
  TF_Status* s = TF_NewStatus();
  TF_Graph* graph = TF_NewGraph();

  // add2 is not a TRT candidate.
  //
  //          feed
  //           ||
  //          add0
  //         //  ||
  //       add1  add4
  //        ||    ||
  //       add2  add5
  //        ||    ||
  //       add3  add6
  //        ||   //
  //         add7
  //          ||
  //        <sink>
  //
  TF_Operation* feed = Placeholder(graph, s, "feed");
  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
  EXPECT_EQ(string("feed"), string(TF_OperationName(feed)));

  TF_Operation* add0 = Add(feed, feed, graph, s, "add0");
  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
  TF_Operation* add1 = Add(add0, add0, graph, s, "add1");
  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
  TF_Operation* add2 = Add(add1, add1, graph, s, "add2");
  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
  TF_Operation* add3 = Add(add2, add2, graph, s, "add3");
  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
  TF_Operation* add4 = Add(add0, add0, graph, s, "add4");
  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
  TF_Operation* add5 = Add(add4, add4, graph, s, "add5");
  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
  TF_Operation* add6 = Add(add5, add5, graph, s, "add6");
  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
  TF_Operation* add7 = Add(add3, add6, graph, s, "add7");
  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
  EXPECT_EQ(string("add7"), string(TF_OperationName(add7)));

  GraphDef graph_def;
  ASSERT_TRUE(GetGraphDef(graph, &graph_def));

  SegmentNodesVector segments;
  ASSERT_EQ(SegmentGraph(graph_def,
                         MakeCandidateFn({"add0", "add1", "add3", "add4",
                                          "add5", "add6", "add7"}),
                         default_options_, &segments),
            tensorflow::Status::OK());

  // Expect 2 subgraphs.
  EXPECT_EQ(segments.size(), 2);

  std::vector<std::string> expected0{"add3", "add4", "add5", "add6", "add7"};
  for (const auto& ex : expected0) {
    EXPECT_TRUE(segments[0].find(ex) != segments[0].end())
        << "Missing expected node " << ex;
  }

  std::vector<std::string> expected1{"add0", "add1"};
  for (const auto& ex : expected1) {
    EXPECT_TRUE(segments[1].find(ex) != segments[1].end())
        << "Missing expected node " << ex;
  }
}

}  // namespace test
}  // namespace segment
}  // namespace tensorrt
77
tensorflow/contrib/tensorrt/segment/union_find.h
Normal file
@ -0,0 +1,77 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CONTRIB_TENSORRT_SEGMENT_UNION_FIND_H_
#define TENSORFLOW_CONTRIB_TENSORRT_SEGMENT_UNION_FIND_H_

namespace tensorrt {
namespace segment {

// Union-Find data structure.
// Each cluster has an associated value; when merging clusters we can control
// which value becomes the representative of the merged clusters. Values must
// be copyable.
template <typename T>
class UnionFind {
 public:
  UnionFind() : size_(1), parent_(nullptr) {}
  explicit UnionFind(const T& v) : size_(1), parent_(nullptr), value_(v) {}

  // Returns the number of elements in a cluster.
  int Size() { return FindRoot()->size_; }

  // Merges this cluster with 'other'. This cluster's value becomes
  // the value of the merged cluster; the value of 'other' is ignored.
  void Merge(UnionFind* other);

  // Each cluster has an associated value. Retrieves the value associated
  // with this cluster.
  T& ParentValue() { return FindRoot()->value_; }

  // Get the original value of this node.
  T& Value() { return value_; }

 private:
  // Finds the root element of the cluster. Performs path compression.
  UnionFind* FindRoot();

  int size_;
  UnionFind* parent_;
  T value_;
};

template <typename T>
void UnionFind<T>::Merge(UnionFind* other) {
  UnionFind<T>* a = FindRoot();
  UnionFind<T>* b = other->FindRoot();
  if (a == b) return;

  b->parent_ = a;
  a->size_ += b->size_;
}

template <typename T>
UnionFind<T>* UnionFind<T>::FindRoot() {
  if (!parent_) return this;
  // Path compression: update intermediate nodes to point to the root of the
  // equivalence class.
  parent_ = parent_->FindRoot();
  return parent_;
}

}  // namespace segment
}  // namespace tensorrt

#endif  // TENSORFLOW_CONTRIB_TENSORRT_SEGMENT_UNION_FIND_H_
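An illustrative Python counterpart (not part of this commit) of the UnionFind template above, showing the two behaviors the header documents: path compression in the root lookup, and the merge target's value winning as the representative. Class and variable names are placeholders.

# Illustrative only: Python analog of UnionFind<T> above.
class UnionFind(object):
  def __init__(self, value=None):
    self.size = 1
    self.parent = None
    self.value = value

  def find_root(self):
    if self.parent is None:
      return self
    # Path compression: point directly at the root.
    self.parent = self.parent.find_root()
    return self.parent

  def merge(self, other):
    a, b = self.find_root(), other.find_root()
    if a is b:
      return
    b.parent = a
    a.size += b.size

a, b, c = UnionFind("a"), UnionFind("b"), UnionFind("c")
a.merge(b)
a.merge(c)
print(a.find_root().value, a.find_root().size)  # "a" 3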
123
tensorflow/contrib/tensorrt/shape_fn/trt_shfn.cc
Normal file
@ -0,0 +1,123 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/contrib/tensorrt/shape_fn/trt_shfn.h"

#include <sstream>
#include <string>
#include <vector>

#include "NvInfer.h"
#include "tensorflow/contrib/tensorrt/log/trt_logger.h"

namespace tensorflow {
namespace shape_inference {

tensorflow::Status TRTEngineOpShapeInference(InferenceContext* c) {
  tensorflow::tensorrt::Logger gLogger;
  string serialized_engine;
  c->GetAttr("serialized_engine", &serialized_engine);
  nvinfer1::IRuntime* infer = nvinfer1::createInferRuntime(gLogger);
  nvinfer1::ICudaEngine* trt_engine = infer->deserializeCudaEngine(
      serialized_engine.c_str(), serialized_engine.size(), nullptr);

  // Debug print of the engine bindings.
  std::stringstream oss;
  for (int i = 0; i < trt_engine->getNbBindings(); i++) {
    LOG(INFO) << "index: " << i
              << ", binding name: " << trt_engine->getBindingName(i);

    bool input_flag = trt_engine->bindingIsInput(i);
    oss << "input?: " << (input_flag ? "Y" : "N");

    oss << "Dimension: ";
    auto dims = trt_engine->getBindingDimensions(i);
    oss << " nbDims: " << dims.nbDims << " -> ";
    for (int j = 0; j < dims.nbDims; j++) oss << dims.d[j] << ", ";
    LOG(INFO) << oss.str();
    oss.str("");
    switch (trt_engine->getBindingDataType(i)) {
      case nvinfer1::DataType::kFLOAT:
        LOG(INFO) << "data type: float" << std::endl;
        break;
      case nvinfer1::DataType::kHALF:
        LOG(INFO) << "data type: half" << std::endl;
        break;
      case nvinfer1::DataType::kINT8:
        LOG(INFO) << "data type: int8" << std::endl;
        break;
    }
  }

  int nbBatch = -1;
  // Debug print of the input arrays.
  std::vector<::tensorflow::DataType> input_type;
  c->GetAttr("InT", &input_type);
  oss.str("");
  for (size_t i = 0; i < c->num_inputs(); i++) {
    // Check whether the input shape is legitimate.
    auto input_shape = c->input(i);
    int index = i;
    oss << "input:" << i << " type: " << input_type[index] << " shape: ";
    for (int j = 0; j < c->Rank(input_shape); j++) {
      auto dimHandler = c->Dim(input_shape, j);
      if (c->ValueKnown(dimHandler))
        oss << c->Value(dimHandler) << ", ";
      else
        oss << "?" << c->Value(dimHandler) << ", ";
      if (j == 0) {
        if (i == 0)
          nbBatch = c->Value(dimHandler);
        else if (nbBatch != c->Value(dimHandler))
          LOG(WARNING) << "!!!!!!nbBatch does not match!!!!!!";
        // assert(nbBatch == c->Value(dimHandler));
      }
    }
    LOG(INFO) << oss.str();
  }

  // Arrange inputs here.
  std::vector<string> input_nodes;
  c->GetAttr("input_nodes", &input_nodes);
  for (size_t i = 0; i < input_nodes.size(); i++) {
    int index = i;
    LOG(INFO) << "input:" << i << " name: " << input_nodes[index];
  }

  // Arrange outputs here. Each output shape is the engine's binding
  // dimensions with the batch dimension prepended.
  std::vector<string> output_nodes;
  c->GetAttr("output_nodes", &output_nodes);
  oss.str("");
  for (size_t i = 0; i < output_nodes.size(); i++) {
    int index = i;
    int binding_index =
        trt_engine->getBindingIndex(output_nodes[index].c_str());
    oss << "string name " << output_nodes[index];
    ShapeHandle output_shape;
    std::vector<DimensionHandle> vecDim;
    vecDim.emplace_back(c->MakeDim(nbBatch));
    if (binding_index != -1) {
      oss << "got binding " << binding_index;
      auto dims = trt_engine->getBindingDimensions(binding_index);
      for (int j = 0; j < dims.nbDims; j++)
        vecDim.emplace_back(c->MakeDim(dims.d[j]));
    } else {
      oss << "no binding ";
    }
    output_shape = c->MakeShape(vecDim);
    c->set_output(i, output_shape);
    LOG(INFO) << oss.str();
  }

  return Status::OK();
}

}  // namespace shape_inference
}  // namespace tensorflow
28
tensorflow/contrib/tensorrt/shape_fn/trt_shfn.h
Normal file
@ -0,0 +1,28 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CONTRIB_TENSORRT_SHAPE_FN_TRT_SHFN_H_
#define TENSORFLOW_CONTRIB_TENSORRT_SHAPE_FN_TRT_SHFN_H_

#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/lib/core/status.h"

namespace tensorflow {
namespace shape_inference {
Status TRTEngineOpShapeInference(InferenceContext* c);
}  // namespace shape_inference
}  // namespace tensorflow

#endif  // TENSORFLOW_CONTRIB_TENSORRT_SHAPE_FN_TRT_SHFN_H_
84
tensorflow/contrib/tensorrt/trt_conversion.i
Normal file
@ -0,0 +1,84 @@
/*
  Wrap trt_conversion.
*/
%{
#define SWIG_FILE_WITH_INIT
%}
%include "std_string.i"
%include "std_pair.i"
%include "tensorflow/python/lib/core/strings.i"
%include "tensorflow/python/platform/base.i"
%template(StringPair) std::pair<string, string>;
%template() std::pair<swig::SwigPtr_PyObject, swig::SwigPtr_PyObject>;

%{
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/util/stat_summarizer.h"
#include "tensorflow/contrib/tensorrt/convert/convert_graph.h"
%}

%ignoreall
%unignore tensorflow;
%unignore trt_convert;

%{
std::pair<string, string> trt_convert(
    string graph_def_string,  // const tensorflow::GraphDef&
    std::vector<string> output_names,
    size_t max_batch_size,
    size_t max_workspace_size
    // Unfortunately we can't use TF_Status here since it is in c/c_api
    // and brings in a lot of other libraries which in turn declare ops.
    // These ops are included statically in our library and cause an
    // abort when the module is loaded, due to double registration.
    // Until TensorFlow properly exposes these headers we have to work
    // around this by returning a string and converting it to an
    // exception on the Python side.
    //, TF_Status* out_status) {
) {
  string out_status;

  tensorflow::GraphDef graph_def;
  if (!graph_def.ParseFromString(graph_def_string)) {
    out_status = "InvalidArgument;Couldn't interpret input as a GraphDef";
    return std::pair<string, string>{out_status, ""};
  }

  if (!output_names.size()) {
    out_status = "InvalidArgument;Size of the output_names vector is 0";
    return std::pair<string, string>{out_status, ""};
  }
  tensorflow::GraphDef outGraph;
  tensorflow::Status conversion_status =
      tensorrt::convert::ConvertGraphDefToTensorRT(graph_def,
                                                   output_names,
                                                   max_batch_size,
                                                   max_workspace_size,
                                                   &outGraph);
  if (!conversion_status.ok()) {
    auto retCode = (int)conversion_status.code();
    char buff[2000];
    snprintf(buff, 2000, "%d;%s", retCode,
             conversion_status.error_message().c_str());
    out_status = buff;
    return std::pair<string, string>{out_status, ""};
  }
  string result;
  if (!outGraph.SerializeToString(&result)) {
    out_status = "InvalidArgument;Couldn't serialize output as a GraphDef";
    return std::pair<string, string>{out_status, ""};
  }
  out_status = "OK;All good!";
  return std::pair<string, string>{out_status, result};
}
%}

std::pair<string, string> trt_convert(string graph_def_string,
                                      std::vector<string> output_names,
                                      size_t max_batch_size,
                                      size_t max_workspace_size);

%unignoreall
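A small illustrative Python decoder (not part of this commit) for the "code;message" status-string convention used above. Note the convention is mixed: the error path encodes a numeric code, while the parse-failure paths encode a symbolic name such as "InvalidArgument"; the sample values below are made up.

# Illustrative only: decoding the status string paired with the
# serialized GraphDef returned by trt_convert.
def decode_status(status):
  if status.startswith("OK"):
    return None  # success
  code, _, message = status.partition(";")
  # The code may be numeric ("3") or symbolic ("InvalidArgument").
  return code, message

print(decode_status("OK;All good!"))        # None
print(decode_status("3;Invalid argument"))  # ('3', 'Invalid argument')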
@ -279,7 +279,7 @@ def tf_cc_shared_object(
    linkopts=[],
    framework_so=tf_binary_additional_srcs(),
    **kwargs):
  native.cc_binary(
      name=name,
      srcs=srcs + framework_so,
      deps=deps,
@ -1281,6 +1281,45 @@ def tf_extension_linkopts():
def tf_extension_copts():
  return []  # No extension c opts

# In tf_py_wrap_cc generated libraries, module init functions are not
# exported unless they contain one of the keywords in the version file,
# which prevents custom Python modules from loading. This rule appends
# init_<module_name> to the list of exported functions in the version
# script (or exported-symbol list on Darwin).
def _append_init_to_versionscript_impl(ctx):
  modName = ctx.attr.module_name
  isVS = ctx.attr.is_version_script
  if isVS:
    ctx.actions.expand_template(
        template=ctx.file.template_file,
        output=ctx.outputs.versionscript,
        substitutions={
            "global:": "global:\n init_%s;" % modName,
        },
        is_executable=False,
    )
  else:
    ctx.actions.expand_template(
        template=ctx.file.template_file,
        output=ctx.outputs.versionscript,
        substitutions={
            "*tensorflow*": "*tensorflow*\ninit_%s" % modName,
        },
        is_executable=False,
    )


_append_init_to_versionscript = rule(
    implementation=_append_init_to_versionscript_impl,
    attrs={
        "module_name": attr.string(mandatory=True),
        "template_file": attr.label(allow_files=True, single_file=True,
                                    mandatory=True),
        "is_version_script": attr.bool(
            default=True,
            doc='whether target is a ld version script or exported symbol list',
            mandatory=False),
    },
    outputs={"versionscript": "%{name}.lds"},
)
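To see what the rule above does to a version script, here is a stand-alone illustrative Python rendering (not part of this commit) of the two substitutions; the template text and module name are hypothetical.

# Illustrative only: mimics the expand_template substitutions above.
mod_name = "pywrap_tensorflow_internal"  # hypothetical module name

version_script = "VERS_1.0 {\n  global:\n    *tensorflow*;\n  local: *;\n};"
print(version_script.replace("global:", "global:\n init_%s;" % mod_name))

exported_symbols = "*tensorflow*"
print(exported_symbols.replace("*tensorflow*",
                               "*tensorflow*\ninit_%s" % mod_name))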

def tf_py_wrap_cc(name,
                  srcs,
                  swig_includes=[],
@ -1302,26 +1341,39 @@ def tf_py_wrap_cc(name,
                  toolchain_deps=["//tools/defaults:crosstool"],
                  module_name=module_name,
                  py_module_name=name)
  vscriptname = name + "_versionscript"
  _append_init_to_versionscript(
      name=vscriptname,
      module_name=module_name,
      is_version_script=select({
          "@local_config_cuda//cuda:darwin": False,
          "//conditions:default": True,
      }),
      template_file=select({
          "@local_config_cuda//cuda:darwin": clean_dep("//tensorflow:tf_exported_symbols.lds"),
          "//conditions:default": clean_dep("//tensorflow:tf_version_script.lds")
      })
  )
  extra_linkopts = select({
      "@local_config_cuda//cuda:darwin": [
          "-Wl,-exported_symbols_list",
          clean_dep("//tensorflow:tf_exported_symbols.lds")
          "%s.lds" % vscriptname,
      ],
      clean_dep("//tensorflow:windows"): [],
      clean_dep("//tensorflow:windows_msvc"): [],
      "//conditions:default": [
          "-Wl,--version-script",
          clean_dep("//tensorflow:tf_version_script.lds")
          "%s.lds" % vscriptname,
      ]
  })
  extra_deps += select({
      "@local_config_cuda//cuda:darwin": [
          clean_dep("//tensorflow:tf_exported_symbols.lds")
          "%s.lds" % vscriptname,
      ],
      clean_dep("//tensorflow:windows"): [],
      clean_dep("//tensorflow:windows_msvc"): [],
      "//conditions:default": [
          clean_dep("//tensorflow:tf_version_script.lds")
          "%s.lds" % vscriptname,
      ]
  })
@ -11,6 +11,7 @@ load(
)
load("//third_party/mkl:build_defs.bzl", "if_mkl")
load("//tensorflow:tensorflow.bzl", "if_cuda")
load("@local_config_tensorrt//:build_defs.bzl", "if_trt")
load("//tensorflow/core:platform/default/build_config_root.bzl", "tf_additional_license_deps")

# This returns a list of headers of all public header libraries (e.g.,
@ -201,7 +202,8 @@ sh_binary(
        "//tensorflow/python:test_ops",
        "//tensorflow/tools/dist_test/server:grpc_tensorflow_server",
    ],
    }) + if_mkl(["//third_party/mkl:intel_binary_blob"]),
    }) + if_mkl(["//third_party/mkl:intel_binary_blob"])
       + if_trt(["//tensorflow/contrib/tensorrt:init_py"]),
)

# A genrule for generating a marker file for the pip package on Windows
@ -1,6 +1,7 @@
# TensorFlow external dependencies that can be loaded in WORKSPACE files.

load("//third_party/gpus:cuda_configure.bzl", "cuda_configure")
load("//third_party/tensorrt:build_defs.bzl", "trt_repository")
load("//third_party/mkl:build_defs.bzl", "mkl_repository")
load("//third_party/git:git_configure.bzl", "git_configure")
load("//third_party/py:python_configure.bzl", "python_configure")
@ -66,6 +67,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""):
  # version we require here.
  check_bazel_version_at_least("0.5.4")
  cuda_configure(name="local_config_cuda")
  trt_repository(name="local_config_tensorrt")
  git_configure(name="local_config_git")
  sycl_configure(name="local_config_sycl")
  python_configure(name="local_config_python")
0
third_party/tensorrt/BUILD
vendored
Normal file
42
third_party/tensorrt/BUILD.tpl
vendored
Normal file
@ -0,0 +1,42 @@
# -*- python -*-
# Description:
#   Provide TensorRT information.

# TODO(Sami): These need to be defined.

licenses(["notice"])

exports_files(["LICENSE"])

load("@local_config_cuda//cuda:build_defs.bzl", "cuda_default_copts", "if_cuda")

config_setting(
    name = "trt_enabled",
    define_values = {
        "using_tensorrt": "true"
    },
    visibility = ["//visibility:public"],
)

cc_library(
    name = "tensorrt",
    srcs = [%{tensorrt_lib}],
    hdrs = [
        "include/NvInfer.h",
        "include/NvUtils.h",
    ],
    copts = cuda_default_copts(),
    deps = [
        "@local_config_cuda//cuda:cuda",
        "@local_config_cuda//cuda:cudnn",
    ],
    linkstatic = 1,
    #include_prefix="include/",
    includes = ["include/"],
    visibility = ["//visibility:public"],
)

%{tensorrt_genrules}

# filegroup(
#     name = "%{tensorrt_lib}",
#     srcs = ["%{tensorrt_lib}"],
#     visibility = ["//visibility:public"],
# )
203
third_party/tensorrt/LICENSE
vendored
Normal file
@ -0,0 +1,203 @@
Copyright 2015 The TensorFlow Authors. All rights reserved.

                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner. For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution. You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
|
||||
any Contribution intentionally submitted for inclusion in the Work
|
||||
by You to the Licensor shall be under the terms and conditions of
|
||||
this License, without any additional terms or conditions.
|
||||
Notwithstanding the above, nothing herein shall supersede or modify
|
||||
the terms of any separate license agreement you may have executed
|
||||
with Licensor regarding such Contributions.
|
||||
|
||||
6. Trademarks. This License does not grant permission to use the trade
|
||||
names, trademarks, service marks, or product names of the Licensor,
|
||||
except as required for reasonable and customary use in describing the
|
||||
origin of the Work and reproducing the content of the NOTICE file.
|
||||
|
||||
7. Disclaimer of Warranty. Unless required by applicable law or
|
||||
agreed to in writing, Licensor provides the Work (and each
|
||||
Contributor provides its Contributions) on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
implied, including, without limitation, any warranties or conditions
|
||||
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
||||
PARTICULAR PURPOSE. You are solely responsible for determining the
|
||||
appropriateness of using or redistributing the Work and assume any
|
||||
risks associated with Your exercise of permissions under this License.
|
||||
|
||||
8. Limitation of Liability. In no event and under no legal theory,
|
||||
whether in tort (including negligence), contract, or otherwise,
|
||||
unless required by applicable law (such as deliberate and grossly
|
||||
negligent acts) or agreed to in writing, shall any Contributor be
|
||||
liable to You for damages, including any direct, indirect, special,
|
||||
incidental, or consequential damages of any character arising as a
|
||||
result of this License or out of the use or inability to use the
|
||||
Work (including but not limited to damages for loss of goodwill,
|
||||
work stoppage, computer failure or malfunction, or any and all
|
||||
other commercial damages or losses), even if such Contributor
|
||||
has been advised of the possibility of such damages.
|
||||
|
||||
9. Accepting Warranty or Additional Liability. While redistributing
|
||||
the Work or Derivative Works thereof, You may choose to offer,
|
||||
and charge a fee for, acceptance of support, warranty, indemnity,
|
||||
or other liability obligations and/or rights consistent with this
|
||||
License. However, in accepting such obligations, You may act only
|
||||
on Your own behalf and on Your sole responsibility, not on behalf
|
||||
of any other Contributor, and only if You agree to indemnify,
|
||||
defend, and hold each Contributor harmless for any liability
|
||||
incurred by, or claims asserted against, such Contributor by reason
|
||||
of your accepting any such warranty or additional liability.
|
||||
|
||||
END OF TERMS AND CONDITIONS
|
||||
|
||||
APPENDIX: How to apply the Apache License to your work.
|
||||
|
||||
To apply the Apache License to your work, attach the following
|
||||
boilerplate notice, with the fields enclosed by brackets "[]"
|
||||
replaced with your own identifying information. (Don't include
|
||||
the brackets!) The text should be enclosed in the appropriate
|
||||
comment syntax for the file format. We also recommend that a
|
||||
file or class name and description of purpose be included on the
|
||||
same "printed page" as the copyright notice for easier
|
||||
identification within third-party archives.
|
||||
|
||||
Copyright 2015, The TensorFlow Authors.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
85
third_party/tensorrt/build_defs.bzl
vendored
Normal file
@ -0,0 +1,85 @@
# -*- python -*-
"""Adds a repository rule (repo_generator) for TensorRT.

Creates the local_config_tensorrt repository, which exposes the libnvinfer
library and headers to the rest of the build.
"""

_TENSORRT_INSTALLATION_PATH = "TENSORRT_INSTALL_PATH"
_TF_TENSORRT_VERSION = "TF_TENSORRT_VERSION"

def _is_trt_enabled(repo_ctx):
  """Returns True iff TF_NEED_TENSORRT is set to "1" in the environment."""
  if "TF_NEED_TENSORRT" in repo_ctx.os.environ:
    enable_trt = repo_ctx.os.environ["TF_NEED_TENSORRT"].strip()
    return enable_trt == "1"
  return False

def _dummy_repo(repo_ctx):
  """Creates an empty repository so labels still resolve when TRT is off."""
  repo_ctx.template("BUILD", Label("//third_party/tensorrt:BUILD.tpl"),
                    {"%{tensorrt_lib}": "", "%{tensorrt_genrules}": ""},
                    False)
  repo_ctx.template("build_defs.bzl",
                    Label("//third_party/tensorrt:build_defs.bzl.tpl"),
                    {"%{trt_configured}": "False"}, False)
  repo_ctx.file("include/NvUtils.h", "", False)
  repo_ctx.file("include/NvInfer.h", "", False)

def _trt_repo_impl(repo_ctx):
  """Implements local_config_tensorrt."""
  if not _is_trt_enabled(repo_ctx):
    _dummy_repo(repo_ctx)
    return
  trt_libdir = repo_ctx.os.environ[_TENSORRT_INSTALLATION_PATH]
  trt_ver = repo_ctx.os.environ[_TF_TENSORRT_VERSION]
  # The deb installation puts headers in a fixed system include directory,
  # while the tar installation keeps them next to the libraries. Once a
  # standardized installation between tar and deb is done, we don't need
  # this special case.
  if trt_libdir == '/usr/lib/x86_64-linux-gnu':
    incPath = '/usr/include/x86_64-linux-gnu'
    incname = '/usr/include/x86_64-linux-gnu/NvInfer.h'
  else:
    incPath = str(repo_ctx.path("%s/../include" % trt_libdir).realpath)
    incname = incPath + '/NvInfer.h'
  if len(trt_ver) > 0:
    origLib = "%s/libnvinfer.so.%s" % (trt_libdir, trt_ver)
  else:
    origLib = "%s/libnvinfer.so" % trt_libdir
  # Default link name, built from the major version digit only. It is
  # overridden below by the SONAME read from the library whenever objdump
  # is available, so a missing SONAME can no longer leave targetlib unset.
  if len(trt_ver) > 0:
    targetlib = "lib/libnvinfer.so.%s" % trt_ver[0]
  else:
    targetlib = "lib/libnvinfer.so"
  objdump = repo_ctx.which("objdump")
  if objdump != None:
    soname = repo_ctx.execute([objdump, "-p", origLib])
    for l in soname.stdout.splitlines():
      if "SONAME" in l:
        lib = l.strip().split(" ")[-1]
        targetlib = "lib/%s" % lib

  # The versioned and unversioned cases previously took identical branches;
  # a single symlink covers both.
  repo_ctx.symlink(origLib, targetlib)
  # Emit a genrule that recreates the library and header symlinks inside the
  # build tree so dependents can reference them as generated files.
  grule = ('genrule(\n name = "trtlinks",\n' +
           ' outs = [\n "%s",\n "include/NvInfer.h",\n "include/NvUtils.h",\n ],\n' % targetlib +
           ' cmd="""ln -sf %s $(@D)/%s ' % (origLib, targetlib) +
           '&&\n ln -sf %s $(@D)/include/NvInfer.h ' % (incname) +
           '&&\n ln -sf %s/NvUtils.h $(@D)/include/NvUtils.h""",\n)\n' % (incPath))
  repo_ctx.template("BUILD", Label("//third_party/tensorrt:BUILD.tpl"),
                    {"%{tensorrt_lib}": '"%s"' % targetlib,
                     "%{tensorrt_genrules}": grule},
                    False)
  repo_ctx.template("build_defs.bzl",
                    Label("//third_party/tensorrt:build_defs.bzl.tpl"),
                    {"%{trt_configured}": "True"}, False)

trt_repository = repository_rule(
    implementation = _trt_repo_impl,
    local = True,
    environ = [
        "TF_NEED_TENSORRT",
        _TF_TENSORRT_VERSION,
        _TENSORRT_INSTALLATION_PATH,
    ],
)
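For reference, a minimal sketch of how this rule could be wired into a WORKSPACE file. The load path assumes the file is vendored at third_party/tensorrt/build_defs.bzl as in this commit, and the repository name matches the @local_config_tensorrt references in build_defs.bzl.tpl; neither line appears in this diff hunk itself:

# Hypothetical WORKSPACE wiring (not part of this hunk).
load("//third_party/tensorrt:build_defs.bzl", "trt_repository")

trt_repository(name = "local_config_tensorrt")

With TF_NEED_TENSORRT=1 and the TENSORRT_INSTALL_PATH / TF_TENSORRT_VERSION variables exported, the rule symlinks libnvinfer and its headers into the generated repository; otherwise _dummy_repo produces an empty but still loadable repository, so builds without TensorRT keep working.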
18
third_party/tensorrt/build_defs.bzl.tpl
vendored
Normal file
@ -0,0 +1,18 @@
# -*- python -*-
"""Template file for TensorRT build functions.

%{trt_configured} is substituted with True or False by the repository rule
in //third_party/tensorrt/build_defs.bzl.
"""

def is_trt_enabled():
  """Returns True if this repository was configured with TensorRT support."""
  return %{trt_configured}

def if_trt(if_true, if_false = []):
  """Returns if_true when TensorRT is enabled and if_false otherwise.

  Implemented as a select() so the choice is made per build configuration
  at analysis time, rather than when this file is loaded.
  """
  return select({
      "@local_config_tensorrt//:trt_enabled": if_true,
      "//conditions:default": if_false,
  })
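A hypothetical BUILD-file use of if_trt, assuming a cc_library whose TensorRT-specific sources and defines should only be built when the repository was configured with TRT enabled. The target name, source file, copts define, and dependency label below are illustrative and not taken from this commit:

# Hypothetical BUILD usage (names are illustrative).
load("@local_config_tensorrt//:build_defs.bzl", "if_trt")

cc_library(
    name = "trt_conversion",                       # illustrative target name
    srcs = if_trt(["trt_conversion.cc"]),          # compiled only when TRT is on
    copts = if_trt(["-DGOOGLE_TENSORRT=1"]),       # assumed preprocessor gate
    deps = if_trt(["@local_config_tensorrt//:trtlinks"]),  # hypothetical dep label
)

Because if_trt expands to a select() on @local_config_tensorrt//:trt_enabled, the same BUILD file loads cleanly whether or not TensorRT was configured, and the TRT-only pieces simply drop out in the default configuration.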