Introducing TensorRT Operator to TF which can run (sub)graphs in highly
optimized TensorRT engines.

This commit is a merged version of many commits by

   benbarsdell    <bbarsdell at nvidia.com>
   deadeyegoodwin <davidg at nvidia.com>
   jjsjann123     <jiej at nvidia.com>
   samikama       <skama at nvidia.com>
Sami Kama 2018-01-19 22:58:50 +00:00
parent e810b107d8
commit 825e7a32e9
35 changed files with 4589 additions and 16 deletions


@ -37,12 +37,14 @@ _TF_BAZELRC = os.path.join(os.path.dirname(os.path.abspath(__file__)),
_TF_WORKSPACE = os.path.join(os.path.dirname(os.path.abspath(__file__)),
'WORKSPACE')
_DEFAULT_CUDA_VERSION = '9.0'
_DEFAULT_TENSORRT_VERSION = '4'
_DEFAULT_CUDNN_VERSION = '7'
_DEFAULT_CUDA_COMPUTE_CAPABILITIES = '3.5,5.2'
_DEFAULT_CUDA_PATH = '/usr/local/cuda'
_DEFAULT_CUDA_PATH_LINUX = '/opt/cuda'
_DEFAULT_CUDA_PATH_WIN = ('C:/Program Files/NVIDIA GPU Computing '
'Toolkit/CUDA/v%s' % _DEFAULT_CUDA_VERSION)
_DEFAULT_TENSORRT_PATH_LINUX = '/usr/lib/x86_64-linux-gnu'
_TF_OPENCL_VERSION = '1.2'
_DEFAULT_COMPUTECPP_TOOLKIT_PATH = '/usr/local/computecpp'
_DEFAULT_TRISYCL_INCLUDE_DIR = '/usr/local/triSYCL/include'
@ -382,13 +384,12 @@ def set_build_var(environ_cp, var_name, query_item, option_name,
var = str(int(get_var(environ_cp, var_name, query_item, enabled_by_default)))
environ_cp[var_name] = var
if var == '1':
write_to_bazelrc('build --define %s=true' % option_name)
elif bazel_config_name is not None:
# TODO(mikecase): Migrate all users of configure.py to use --config Bazel
# options and not to set build configs through environment variables.
write_to_bazelrc('build:%s --define %s=true'
% (bazel_config_name, option_name))
# TODO(mikecase): Migrate all users of configure.py to use --config Bazel
# options and not to set build configs through environment variables.
if var == '1':
setting = 'true'
confname = ':%s' % bazel_config_name if bazel_config_name is not None else ''
write_to_bazelrc('build%s --define %s=%s' % (confname, option_name, setting))
def set_action_env_var(environ_cp,
@ -438,13 +439,12 @@ def convert_version_to_int(version):
for seg in version_segments:
if not seg.isdigit():
return None
version_str = ''.join(['%03d' % int(seg) for seg in version_segments])
return int(version_str)
def check_bazel_version(min_version):
"""Check installed bezel version is at least min_version.
"""Check installed bazel version is at least min_version.
Args:
min_version: string for minimum bazel version.
@ -1056,6 +1056,108 @@ def set_other_cuda_vars(environ_cp):
write_to_bazelrc('test --config=cuda')
def set_tf_trt_version(environ_cp):
"""Set TENSORRT_INSTALL_PATH and TF_TENSORRT_VERSION."""
ask_trt_version = (
'Please specify the TensorRT (libnvinfer) version you want to use. '
'[Leave empty to default to libnvinfer %s]: ') % _DEFAULT_TENSORRT_VERSION
while True:
tf_trt_version = get_from_env_or_user_or_default(
environ_cp, 'TF_TENSORRT_VERSION', ask_trt_version,
_DEFAULT_TENSORRT_VERSION)
# if library version is passed and known
default_trt_path = environ_cp.get('TENSORRT_INSTALL_PATH',_DEFAULT_TENSORRT_PATH_LINUX)
ask_trt_path = (r'Please specify the location where libnvinfer %s library is '
'installed. Refer to README.md for more details. [Default'
' is %s]:') % (tf_trt_version, default_trt_path)
trt_install_path = get_from_env_or_user_or_default(
environ_cp, 'TENSORRT_INSTALL_PATH', ask_trt_path, default_trt_path)
# The result returned from "read" will be used unexpanded. That makes "~"
# unusable. Going through one more level of expansion to handle that.
trt_install_path = os.path.realpath(
os.path.expanduser(trt_install_path))
# Simple helper to search for libnvinfer in the given install path. It finds
# all libnvinfer.so* files in the user-defined install path and its lib64
# subdirectory and returns their absolute paths.
def find_libs(search_path):
fl=set()
if os.path.exists(search_path) and os.path.isdir(search_path):
fl.update([os.path.realpath(os.path.join(search_path,x)) \
for x in os.listdir(search_path) if 'libnvinfer.so' in x])
return fl
possible_files=find_libs(trt_install_path)
possible_files.update(find_libs(os.path.join(trt_install_path,'lib64')))
if is_linux():
cudnnpatt = re.compile(r'.*libcudnn.so\.?(.*) =>.*$')
cudapatt = re.compile(r'.*libcudart.so\.?(.*) =>.*$')
def is_compatible(lib, cudaver, cudnnver):
ldd_bin = which('ldd') or '/usr/bin/ldd'
ldd_out = run_shell([ldd_bin, lib]).split(os.linesep)
# Not every library links both; default to None so the check below fails
# cleanly instead of referencing an unbound variable.
cudnn = None
cudart = None
for l in ldd_out:
if 'libcudnn.so' in l:
cudnn=cudnnpatt.search(l)
elif 'libcudart.so' in l:
cudart=cudapatt.search(l)
if cudnn:
cudnn=convert_version_to_int(cudnn.group(1)) if len(cudnn.group(1)) else 0
if cudart:
cudart=convert_version_to_int(cudart.group(1)) if len(cudart.group(1)) else 0
return (cudnn==cudnnver) and (cudart==cudaver)
cudaver=convert_version_to_int(environ_cp['TF_CUDA_VERSION'])
cudnnver=convert_version_to_int(environ_cp['TF_CUDNN_VERSION'])
valid_libs=[]
vfinder=re.compile('.*libnvinfer.so.?(.*)$')
highest_ver=[0,None,None]
for l in possible_files:
if is_compatible(l,cudaver,cudnnver):
valid_libs.append(l)
vstr=vfinder.search(l).group(1)
currver=convert_version_to_int(vstr) if len(vstr) else 0
if currver > highest_ver[0]:
highest_ver= [currver,vstr,l]
if highest_ver[1] is not None:
trt_install_path=os.path.dirname(highest_ver[2])
tf_trt_version=highest_ver[1]
break
ldconfig_bin = which('ldconfig') or '/sbin/ldconfig'
libnvinfer_path_from_ldconfig = run_shell([ldconfig_bin, '-p'])
libnvinfer_path_from_ldconfig = re.search('.*libnvinfer.so.* => (.*)',
libnvinfer_path_from_ldconfig)
if libnvinfer_path_from_ldconfig:
libnvinfer_path_from_ldconfig = libnvinfer_path_from_ldconfig.group(1)
if os.path.exists('%s.%s' % (libnvinfer_path_from_ldconfig,
tf_trt_version)):
trt_install_path = os.path.dirname(libnvinfer_path_from_ldconfig)
break
# Reset and Retry
if len(possible_files):
print(
'Invalid path to TensorRT %s. The libnvinfer.so* files found are for '
'incompatible CUDA versions:' % tf_trt_version)
print(trt_install_path)
print(os.path.join(trt_install_path,'lib64'))
else:
print(
'Invalid path to TensorRT %s. No libnvinfer.so* files found in:'
% tf_trt_version)
print(trt_install_path)
print(os.path.join(trt_install_path,'lib64'))
if is_linux():
print('%s.%s' % (libnvinfer_path_from_ldconfig, tf_trt_version))
environ_cp['TF_TENSORRT_VERSION'] = ''
# Set TENSORRT_INSTALL_PATH and TF_TENSORRT_VERSION
environ_cp['TENSORRT_INSTALL_PATH'] = trt_install_path
write_action_env_to_bazelrc('TENSORRT_INSTALL_PATH', trt_install_path)
environ_cp['TF_TENSORRT_VERSION'] = tf_trt_version
write_action_env_to_bazelrc('TF_TENSORRT_VERSION', tf_trt_version)
write_to_bazelrc('build:tensorrt --define using_tensorrt=true')
def set_host_cxx_compiler(environ_cp):
"""Set HOST_CXX_COMPILER."""
default_cxx_host_compiler = which('g++') or ''
@ -1244,9 +1346,11 @@ def main():
environ_cp['TF_NEED_COMPUTECPP'] = '0'
environ_cp['TF_NEED_OPENCL'] = '0'
environ_cp['TF_CUDA_CLANG'] = '0'
environ_cp['TF_NEED_TENSORRT'] = '0'
if is_macos():
environ_cp['TF_NEED_JEMALLOC'] = '0'
environ_cp['TF_NEED_TENSORRT'] = '0'
set_build_var(environ_cp, 'TF_NEED_JEMALLOC', 'jemalloc as malloc',
'with_jemalloc', True)
@ -1301,6 +1405,10 @@ def main():
if not is_windows():
set_gcc_host_compiler_path(environ_cp)
set_other_cuda_vars(environ_cp)
# Enable TensorRT if desired. Disabled on non-Linux platforms.
set_action_env_var(environ_cp, 'TF_NEED_TENSORRT', 'TensorRT', False)
if environ_cp.get('TF_NEED_TENSORRT') == '1':
set_tf_trt_version(environ_cp)
set_build_var(environ_cp, 'TF_NEED_MPI', 'MPI', 'with_mpi_support', False)
if environ_cp.get('TF_NEED_MPI') == '1':


@ -358,6 +358,14 @@ config_setting(
},
)
config_setting(
name = "using_tensorrt",
define_values = {
"using_tensorrt":"true",
},
visibility = ["//visibility:public"],
)
config_setting(
name = "with_mpi_support",
values = {"define": "with_mpi_support=true"},


@ -7,6 +7,7 @@ package(default_visibility = ["//tensorflow:__subpackages__"])
load("//third_party/mpi:mpi.bzl", "if_mpi")
load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")
load("@local_config_tensorrt//:build_defs.bzl", "if_trt")
py_library(
name = "contrib_py",
@ -104,7 +105,9 @@ py_library(
"//tensorflow/contrib/training:training_py",
"//tensorflow/contrib/util:util_py",
"//tensorflow/python:util",
] + if_mpi(["//tensorflow/contrib/mpi_collectives:mpi_collectives_py"]),
] + if_mpi(["//tensorflow/contrib/mpi_collectives:mpi_ops_py"])
+ if_trt(["//tensorflow/contrib/tensorrt:init_py"]),
)
cc_library(


@ -0,0 +1,266 @@
# -*- python -*-
# Description:
# Provides TensorRT operators and the converter package.
package(default_visibility = ["//tensorflow:__subpackages__"])
licenses(["notice"]) # Apache 2.0
load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")
load(
"//tensorflow:tensorflow.bzl",
"tf_custom_op_library",
"tf_gen_op_libs",
"tf_gen_op_wrapper_py",
"tf_py_wrap_cc",
"tf_cc_test",
"tf_kernel_library",
"tf_custom_op_py_library",
"tf_copts",
)
tf_custom_op_library(
name = "python/ops/_trt_engine_op.so",
srcs = [
"kernels/trt_engine_op.cc",
"ops/trt_engine_op.cc",
"kernels/trt_engine_op.h",
],
gpu_srcs = [],
deps = [
"@local_config_tensorrt//:tensorrt",
":trt_shape_function",
"//tensorflow/core:lib_proto_parsing",
"//tensorflow/core/kernels:bounds_check_lib",
"//tensorflow/core/kernels:ops_util_hdrs",
],
)
cc_library(
name = "trt_shape_function",
srcs=[
"shape_fn/trt_shfn.cc",
],
hdrs=["shape_fn/trt_shfn.h"],
copts=tf_copts(),
deps=[
":trt_logging",
"//third_party/eigen3",
"@local_config_tensorrt//:tensorrt",
"@protobuf_archive//:protobuf",
"@nsync//:nsync_headers",
"//tensorflow/core:framework_headers_lib",
]
)
tf_kernel_library(
name = "trt_engine_op_kernel",
srcs = [
"kernels/trt_engine_op.cc",
],
hdrs=[
"kernels/trt_engine_op.h",
],
gpu_srcs = [
],
deps = [
":trt_logging",
":trt_shape_function",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//third_party/eigen3",
"//tensorflow/core:gpu_headers_lib",
"@local_config_tensorrt//:tensorrt",
"//tensorflow/core:lib_proto_parsing",
],
alwayslink=1,
)
tf_gen_op_libs(
op_lib_names = [
"trt_engine_op",
],
deps=[
"@local_config_tensorrt//:tensorrt",
]
)
cc_library(
name="trt_logging",
srcs = [
"log/trt_logger.cc",
],
hdrs=[
"log/trt_logger.h",
],
deps=[
"@local_config_tensorrt//:tensorrt",
"//tensorflow/core:lib_proto_parsing",
],
visibility = ["//visibility:public"],
)
tf_gen_op_wrapper_py(
name = "trt_engine_op",
deps = [
":trt_engine_op_op_lib",
":trt_shape_function",
],
)
tf_custom_op_py_library(
name = "trt_engine_op_loader",
srcs = ["python/ops/trt_engine_op.py"],
dso = [":python/ops/_trt_engine_op.so",
"@local_config_tensorrt//:tensorrt",
],
srcs_version = "PY2AND3",
deps = [
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:resources",
],
)
py_library(
name = "init_py",
srcs = [
"__init__.py",
"python/__init__.py",
],
srcs_version = "PY2AND3",
deps = [
":trt_ops_py",
":trt_convert_py",
],
)
py_library(
name="trt_ops_py",
srcs_version = "PY2AND3",
deps=[":trt_engine_op",
":trt_engine_op_loader",
],
)
py_library(
name="trt_convert_py",
srcs=["python/trt_convert.py"],
srcs_version = "PY2AND3",
deps=[
":wrap_conversion"
],
)
tf_py_wrap_cc(
name="wrap_conversion",
srcs=["trt_conversion.i"],
deps=[
":trt_conversion",
"//tensorflow/core:framework_lite",
"//util/python:python_headers",
],
)
cc_library(
name= "trt_conversion",
srcs=[
"convert/convert_nodes.cc",
"convert/convert_graph.cc",
"segment/segment.cc",
"convert/inferShapes.cc",
],
hdrs=[
"convert/convert_nodes.h",
"convert/convert_graph.h",
"convert/inferShapes.h",
"segment/segment.h",
"segment/union_find.h",
],
deps=[
"@local_config_tensorrt//:tensorrt",
"@protobuf_archive//:protobuf_headers",
"@nsync//:nsync_headers",
":trt_logging",
"//tensorflow/core:framework_lite",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:framework_headers_lib",
"//tensorflow/core:core_cpu_base",
#"//third_party/eigen3",
],
)
tf_custom_op_library(
name = "tensorrt_ops.so",
srcs = [
"ops/tensorrt_ops.cc",
],
deps = [
"@local_config_tensorrt//:tensorrt",
],
)
# Library for the segmenting portion of TensorRT operation creation
cc_library(
name = "segment",
srcs = [
"segment/segment.cc",
],
hdrs = [
"segment/union_find.h",
"segment/segment.h",
],
deps = [
"@protobuf_archive//:protobuf_headers",
"//tensorflow/core:core_cpu",
"//tensorflow/core:lib_proto_parsing",
"//third_party/eigen3",
],
linkstatic = 1,
)
tf_cc_test(
name = "segment_test",
size = "small",
srcs = ["segment/segment_test.cc"],
deps = [
":segment",
"//tensorflow/c:c_api",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
],
)
# Library for the node-level conversion portion of TensorRT operation creation
filegroup(
name = "cppfiles",
srcs = glob(["**/*.cc"]),
visibility=["//visibility:private"],
)
filegroup(
name = "headers",
srcs = glob(["**/*.h"]),
visibility=["//visibility:private"],
)
filegroup(
name = "all_files",
srcs = glob(
["**/*"],
exclude = [
"**/METADATA",
"**/OWNERS",
],
),
visibility = ["//tensorflow:__subpackages__"],
)


@ -0,0 +1,42 @@
Using TensorRT in TensorFlow
============================
This module provides the necessary bindings and introduces the TRT_engine_op
operator that wraps a subgraph in TensorRT.
Compilation
-----------
In order to compile the module, you need a local TensorRT installation
(libnvinfer.so and the corresponding include files). During the
configuration step, TensorRT should be enabled and its installation path
should be set. If TensorRT was installed through a package manager
(deb, rpm), the configure script should find the necessary components on
the system automatically. If it was installed from a tar package, the user
has to set the path to the location where the library is installed during
configuration.
In order to enable TensorRT support, the user has to add `--config=tensorrt`
to the build flags during compilation, for example:
```
bazel build --config=cuda --config=opt --config=tensorrt //tensorflow/tools/pip_package:build_pip_package
bazel-bin/tensorflow/tools/pip_package/build_pip_package /tmp/
```
After the TensorFlow package is installed, the TensorRT transformation
will be available. An example use is shown below.
```python
import tensorflow as tf
import tensorflow.contrib.tensorrt as trt
#... create and train or load model
gdef = sess.graph.as_graph_def()
trt_gdef = trt.CreateInferenceGraph(
    gdef,                # original graph_def
    ["output"],          # name of the output node(s)
    max_batch_size,      # maximum batch size to run inference with
    max_workspace_size)  # maximum memory (in bytes) for TensorRT to use
tf.reset_default_graph()
tf.import_graph_def(graph_def=trt_gdef)
#...... run inference
```
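Once the converted graph has been imported, inference runs through the
regular TensorFlow session API. The following is a minimal sketch, assuming
the model has a placeholder named `input` and a single output node named
`output` (imported under the default `import` scope), and that `input_data`
holds a batch no larger than `max_batch_size`; adjust the names for your
model.

```python
# Tensor names below are illustrative assumptions, not part of this commit.
with tf.Session() as sess:
    inp = sess.graph.get_tensor_by_name("import/input:0")
    out = sess.graph.get_tensor_by_name("import/output:0")
    result = sess.run(out, feed_dict={inp: input_data})
```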


@ -0,0 +1,19 @@
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.contrib.tensorrt.python import *


@ -0,0 +1,253 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/contrib/tensorrt/convert/convert_graph.h"
#include <list>
#include <set>
#include <sstream>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include <map>
#include <utility>
#include "NvInfer.h"
#include "tensorflow/contrib/tensorrt/convert/convert_nodes.h"
#include "tensorflow/contrib/tensorrt/convert/inferShapes.h"
#include "tensorflow/contrib/tensorrt/segment/segment.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"
#define _TF_LOG_DEBUG ::tensorflow::internal::LogMessage(__FILE__, __LINE__, -1)
//------------------------------------------------------------------------------
namespace tensorrt {
namespace convert {
namespace {
static std::unordered_set<std::string> output_nodes;
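// Returns true if the node's op is in the TensorRT conversion whitelist and
// the node is not one of the requested graph outputs.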
bool IsTensorRTCandidate(const tensorflow::NodeDef& node_def) {
static const std::set<std::string> candidate_ops = {
"Identity", "Const", "Conv2D", "MaxPool", "BiasAdd", "Relu",
"Add", "Mul", "Sub", "Rsqrt", "Pad" // "Placeholder" ,"Mean"
// TODO(ben,jie): ...
};
if (output_nodes.count(node_def.name())) return false;
return candidate_ops.count(node_def.op());
}
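// Collects the edges that enter the subgraph from the rest of the graph,
// ignoring the special source node.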
void GetSubGraphIncomingEdges(tensorflow::Graph const& graph,
std::set<int> const& subgraph_node_ids,
tensorflow::EdgeSet* incoming_edges) {
for (int node_id : subgraph_node_ids) {
tensorflow::Node const* node = graph.FindNodeId(node_id);
LOG(DEBUG) << node->name() << " has incoming edges: ";
for (tensorflow::Edge const* edge : node->in_edges()) {
if (!subgraph_node_ids.count(edge->src()->id()) &&
!edge->src()->IsSource()) {
LOG(DEBUG) << edge->src()->name() << ", ";
incoming_edges->insert(edge);
}
}
}
}
void GetSubGraphOutgoingEdges(tensorflow::Graph const& graph,
std::set<int> const& subgraph_node_ids,
tensorflow::EdgeSet* outgoing_edges) {
for (int node_id : subgraph_node_ids) {
tensorflow::Node const* node = graph.FindNodeId(node_id);
LOG(DEBUG) << node->name() << " has outgoing edges: ";
for (tensorflow::Edge const* edge : node->out_edges()) {
if (!subgraph_node_ids.count(edge->dst()->id()) &&
!edge->dst()->IsSink()) {
outgoing_edges->insert(edge);
}
}
}
}
std::pair<std::string, int> ParseTensorName(std::string name,
int default_idx = 0) {
int idx = default_idx;
size_t sep = name.find_last_of(':');
if (sep != std::string::npos) {
// Parse the index before truncating the name; otherwise the substring past
// 'sep' would be read from the already-truncated string.
idx = std::stoi(name.substr(sep + 1));
name = name.substr(0, sep);
}
return std::make_pair(name, idx);
}
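// Groups the requested output tensor indices by node name, e.g. "conv:1"
// becomes an entry {"conv": [1]}.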
std::unordered_map<std::string, std::vector<int>> BuildTensorNameMap(
const std::vector<std::string>& tensor_names) {
std::unordered_map<std::string, std::vector<int>> result;
for (std::string const& tensor_name : tensor_names) {
std::string node_name;
int index;
std::tie(node_name, index) = ParseTensorName(tensor_name);
result[node_name].push_back(index);
}
return result;
}
tensorflow::Status ConvertSubGraphToTensorRT(
tensorflow::Graph& graph, const std::vector<std::string>& output_names,
const std::set<int>& subgraph_node_ids, size_t max_batch_size,
size_t max_workspace_size, const ShapeMap& shape_map) {
tensorflow::EdgeSet subgraph_incoming_edges;
GetSubGraphIncomingEdges(graph, subgraph_node_ids, &subgraph_incoming_edges);
std::vector<std::pair<int, int>> subgraph_inputs;
// Collect inputs by looking for incoming edges
for (tensorflow::Edge const* edge : subgraph_incoming_edges) {
subgraph_inputs.push_back({edge->src()->id(), edge->src_output()});
}
std::set<std::pair<int, int>> subgraph_outputs_set;
// Collect outputs referenced from output_names
auto output_name_to_index_map = BuildTensorNameMap(output_names);
// for (int node_id : subgraph_node_ids_no_placeholder) {
for (int node_id : subgraph_node_ids) {
tensorflow::Node* node = graph.FindNodeId(node_id);
if (output_name_to_index_map.count(node->name())) {
for (int index : output_name_to_index_map.at(node->name())) {
subgraph_outputs_set.insert({node_id, index});
}
}
}
// Collect outputs referenced from outgoing edges
tensorflow::EdgeSet subgraph_outgoing_edges;
// GetSubGraphOutgoingEdges(graph, subgraph_node_ids_no_placeholder,
// &subgraph_outgoing_edges);
GetSubGraphOutgoingEdges(graph, subgraph_node_ids, &subgraph_outgoing_edges);
for (tensorflow::Edge const* edge : subgraph_outgoing_edges) {
subgraph_outputs_set.insert({edge->src()->id(), edge->src_output()});
}
// Impose an ordering on the outputs
std::vector<std::pair<int, int>> subgraph_outputs(
subgraph_outputs_set.begin(), subgraph_outputs_set.end());
// Build TensorRT node and add it to the graph
tensorflow::NodeDef trt_node_def;
TF_RETURN_IF_ERROR(ConvertSubGraphToTensorRTNodeDef(
graph, subgraph_node_ids, subgraph_inputs, subgraph_outputs,
max_batch_size, max_workspace_size, shape_map, &trt_node_def));
tensorflow::Status status;
tensorflow::Node* trt_node = graph.AddNode(trt_node_def, &status);
TF_RETURN_IF_ERROR(status);
// Re-map outgoing edges to use the new TRT node instead of the orig subgraph
std::map<std::pair<int, int>, int> subgraph_edge_to_output_map;
for (size_t i = 0; i < subgraph_outputs.size(); ++i) {
subgraph_edge_to_output_map.insert({subgraph_outputs.at(i), i});
}
TF_RETURN_IF_ERROR(status);
for (tensorflow::Edge const* edge : subgraph_outgoing_edges) {
std::pair<int, int> old_src = {edge->src()->id(), edge->src_output()};
int new_src_output = subgraph_edge_to_output_map.at(old_src);
graph.UpdateEdge(trt_node, new_src_output, edge->dst(), edge->dst_input());
}
// Remove the original subgraph
for (int node_id : subgraph_node_ids) {
tensorflow::Node* node = graph.FindNodeId(node_id);
// Don't remove the input placeholders
if (node->type_string() == "Placeholder") {
continue;
}
graph.RemoveNode(node);
}
return tensorflow::Status::OK();
}
tensorflow::Status BuildNodeMap(
const tensorflow::Graph& graph,
std::unordered_map<std::string, tensorflow::Node*>* node_map) {
for (auto* node : graph.op_nodes()) {
if (!node_map->insert({node->name(), node}).second) {
return tensorflow::errors::AlreadyExists(
"Node name is not unique in graph: " + node->name());
}
}
return tensorflow::Status::OK();
}
} // namespace
tensorflow::Status ConvertGraphDefToTensorRT(
const tensorflow::GraphDef& graph_def,
const std::vector<std::string>& output_names, size_t max_batch_size,
size_t max_workspace_size, tensorflow::GraphDef* new_graph_def) {
ShapeMap shape_map;
TF_RETURN_IF_ERROR(
tensorflow::trt::inferShapes(graph_def, output_names, shape_map));
std::stringstream oss;
for (auto& n : shape_map) { // nodes
oss << " Node= " << n.first << ", ";
for (auto o : n.second) { // outputs
oss << o.first.DebugString() << " T= " << o.second << ", ";
}
LOG(DEBUG) << oss.str();
oss.str("");
}
// Build full graph
tensorflow::FunctionLibraryDefinition flib(tensorflow::OpRegistry::Global(),
graph_def.library());
tensorflow::Graph graph(flib);
TF_RETURN_IF_ERROR(tensorflow::ConvertGraphDefToGraph(
tensorflow::GraphConstructorOptions(), graph_def, &graph));
// Segment the graph into subgraphs that can be converted to TensorRT
tensorrt::segment::SegmentOptions segment_options;
// TODO(ben,jie,sami): exclude output nodes (DISCUSS IT)
for (auto node : output_names) output_nodes.insert(node);
// TODO(sami): this should be passed as a knob!!!!
segment_options.minimum_segment_size = 2;
tensorrt::segment::SegmentNodesVector segments;
TF_RETURN_IF_ERROR(tensorrt::segment::SegmentGraph(
graph_def, IsTensorRTCandidate, segment_options, &segments));
if (segments.size() > 1) {
// LOG(WARNING) << "Multiple TensorRT candidate subgraphs were found, "
//<< "but only the first can be converted.";
// segments.erase(++segments.begin(), segments.end());
LOG(INFO) << "MULTIPLE tensorrt candidate conversion: " << segments.size();
}
std::unordered_map<std::string, tensorflow::Node*> node_map;
TF_RETURN_IF_ERROR(BuildNodeMap(graph, &node_map));
for (std::set<std::string> const& subgraph_node_names : segments) {
std::set<int> subgraph_node_ids;
for (std::string const& node_name : subgraph_node_names) {
subgraph_node_ids.insert(node_map.at(node_name)->id());
}
TF_RETURN_IF_ERROR(ConvertSubGraphToTensorRT(
graph, output_names, subgraph_node_ids, max_batch_size,
max_workspace_size, shape_map));
}
graph.ToGraphDef(new_graph_def);
return tensorflow::Status::OK();
}
} // namespace convert
} // namespace tensorrt


@ -0,0 +1,34 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CONTRIB_TENSORRT_CONVERT_CONVERT_GRAPH_H_
#define TENSORFLOW_CONTRIB_TENSORRT_CONVERT_CONVERT_GRAPH_H_
#include <string>
#include <vector>
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/lib/core/status.h"
namespace tensorrt {
namespace convert {
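// Runs shape inference on graph_def, segments it into TensorRT-compatible
// subgraphs and replaces each segment with a single TRTEngineOp node in
// new_graph_def.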
tensorflow::Status ConvertGraphDefToTensorRT(
const tensorflow::GraphDef& graph_def,
const std::vector<std::string>& output_names, size_t max_batch_size,
size_t max_workspace_size, tensorflow::GraphDef* new_graph_def);
}
} // namespace tensorrt
#endif // TENSORFLOW_CONTRIB_TENSORRT_CONVERT_CONVERT_GRAPH_H_

File diff suppressed because it is too large.


@ -0,0 +1,42 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CONTRIB_TENSORRT_CONVERT_CONVERT_NODES_H_
#define TENSORFLOW_CONTRIB_TENSORRT_CONVERT_CONVERT_NODES_H_
#include <set>
#include <vector>
#include <utility>
#include "tensorflow/contrib/tensorrt/convert/inferShapes.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/lib/core/status.h"
namespace tensorrt {
namespace convert {
tensorflow::Status ConvertSubGraphToTensorRTNodeDef(
const tensorflow::Graph& graph, const std::set<int>& subgraph_node_ids,
const std::vector<std::pair<int, int>>&
input_inds, // {node_id, output_idx}
const std::vector<std::pair<int, int>>&
output_inds, // {node_id, output_idx}
size_t max_batch_size, size_t max_workspace_size, const ShapeMap& shape_map,
tensorflow::NodeDef* trt_node);
} // namespace convert
} // namespace tensorrt
#endif // TENSORFLOW_CONTRIB_TENSORRT_CONVERT_CONVERT_NODES_H_


@ -0,0 +1,125 @@
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/contrib/tensorrt/convert/inferShapes.h"
#include <functional>
#include "tensorflow/core/common_runtime/shape_refiner.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.pb_text.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"
#define _TF_LOG_DEBUG ::tensorflow::internal::LogMessage(__FILE__, __LINE__, -1)
namespace tensorflow {
namespace trt {
std::vector<tensorflow::DataType> getTypes(const tensorflow::OpDef& op,
const tensorflow::NodeDef& nd,
bool inp = true) {
const auto& attrMap = nd.attr();
auto getType = [&attrMap](decltype(
op.input_arg(0)) a) -> std::vector<tensorflow::DataType> {
std::vector<tensorflow::DataType> tvec;
if (!a.type_list_attr().empty()) { // get the list types
const auto& tl = attrMap.at(a.type_list_attr()).list();
int tsize = tl.type_size();
tvec.reserve(tsize);
for (int t = 0; t < tsize; t++) {
tvec.push_back(tl.type(t));
}
return tvec;
}
tensorflow::DataType cType = tensorflow::DT_INVALID;
if (a.type() != tensorflow::DT_INVALID) { // get defined types
cType = a.type();
} else if (!a.type_attr().empty()) {
cType = attrMap.at(a.type_attr()).type();
}
if (!a.number_attr().empty()) { // numbertypes
int64 nTensors = attrMap.at(a.number_attr()).i();
tvec = std::vector<tensorflow::DataType>(nTensors, cType);
return tvec;
}
tvec.push_back(cType);
return tvec;
};
std::vector<tensorflow::DataType> types;
if (inp) {
int n_inputs = op.input_arg_size();
for (int i = 0; i < n_inputs; i++) {
auto tout = getType(op.input_arg(i));
LOG(DEBUG) << "Node= " << nd.name() << " #inputs" << tout.size();
types.insert(types.end(), tout.begin(), tout.end());
}
} else {
int n_outputs = op.output_arg_size();
// types.resize(n_outputs);
for (int i = 0; i < n_outputs; i++) {
auto tout = getType(op.output_arg(i));
LOG(DEBUG) << "Node= " << nd.name() << " #outputs" << tout.size();
types.insert(types.end(), tout.begin(), tout.end());
}
}
return types;
}
tensorflow::Status inferShapes(const tensorflow::GraphDef& graph_def,
const std::vector<std::string>& output_names,
ShapeMap& shapes) {
tensorflow::Graph g(OpRegistry::Global());
TF_RETURN_IF_ERROR(tensorflow::ConvertGraphDefToGraph(
tensorflow::GraphConstructorOptions(), graph_def, &g));
std::vector<tensorflow::Node*> POnodes;
tensorflow::GetPostOrder(g, &POnodes);
tensorflow::ShapeRefiner refiner(graph_def.versions().producer(),
OpRegistry::Global());
for (auto n = POnodes.rbegin(); n != POnodes.rend(); ++n) {
TF_CHECK_OK(refiner.AddNode(*n));
}
auto shape2PTS = [](tensorflow::shape_inference::InferenceContext* ic,
const tensorflow::shape_inference::ShapeHandle& sh)
-> tensorflow::PartialTensorShape {
std::vector<int64> dims;
int64 rank = ic->Rank(sh);
for (int64 i = 0; i < rank; i++) {
auto dh = ic->Dim(sh, i);
dims.push_back(ic->Value(dh));
}
return tensorflow::PartialTensorShape(dims);
};
for (const auto& n : POnodes) {
auto ic = refiner.GetContext(n);
if (ic) {
int nOuts = ic->num_outputs();
auto types = getTypes(n->op_def(), n->def(), false);
std::vector<
std::pair<tensorflow::PartialTensorShape, tensorflow::DataType>>
SAT;
for (int i = 0; i < nOuts; i++) {
auto PTS = shape2PTS(ic, ic->output(i));
SAT.push_back({PTS, types.at(i)});
}
shapes[n->name()] = SAT;
} else {
LOG(WARNING) << "Node " << n->name() << " doesn't have InferenceContext!";
}
}
return tensorflow::Status::OK();
}
} // namespace trt
} // namespace tensorflow


@ -0,0 +1,39 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CONTRIB_TENSORRT_CONVERT_INFERSHAPES_H_
#define TENSORFLOW_CONTRIB_TENSORRT_CONVERT_INFERSHAPES_H_
#include <string>
#include <unordered_map>
#include <vector>
#include <utility>
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/lib/core/status.h"
typedef std::unordered_map<std::string,
std::vector<std::pair<tensorflow::PartialTensorShape,
tensorflow::DataType>>>
ShapeMap;
namespace tensorflow {
namespace trt {
tensorflow::Status inferShapes(const tensorflow::GraphDef& graph_def,
const std::vector<std::string>& output_names,
ShapeMap& shapes);
}
} // namespace tensorflow
#endif // TENSORFLOW_CONTRIB_TENSORRT_CONVERT_INFERSHAPES_H_


@ -0,0 +1,183 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/contrib/tensorrt/kernels/trt_engine_op.h"
#include <cuda_runtime_api.h>
#include <sstream>
#include "tensorflow/contrib/tensorrt/log/trt_logger.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/stream_executor.h"
// Use TF logging for TensorRT information
namespace tensorflow {
static ::tensorflow::tensorrt::Logger gLogger;
using namespace nvinfer1;
namespace tensorrt {
TRTEngineOp::TRTEngineOp(OpKernelConstruction* context) : OpKernel(context) {
// char *gieModelStream{nullptr};
// size_t size{0};
// read serialized_engine
std::string serialized_engine;
OP_REQUIRES_OK(context,
context->GetAttr("serialized_engine", &serialized_engine));
// register input output node name in trt_sub_graph
OP_REQUIRES_OK(context, context->GetAttr("input_nodes", &input_nodes_));
OP_REQUIRES_OK(context, context->GetAttr("output_nodes", &output_nodes_));
// TODO(samikama) runtime should be taken from a resourcemanager as well.
// Only engine should be in the op and context and runtime should be taken
// from resourcemanager
IRuntime* infer = createInferRuntime(gLogger);
trt_engine_ptr_.reset(infer->deserializeCudaEngine(
serialized_engine.c_str(), serialized_engine.size(), nullptr));
trt_context_ptr_.reset(trt_engine_ptr_->createExecutionContext());
// runtime is safe to delete after engine creation
infer->destroy();
std::stringstream oss;
// debug iterate through all binding instances
for (int i = 0; i < trt_engine_ptr_->getNbBindings(); i++) {
LOG(INFO) << "index: " << i
<< ", binding name: " << trt_engine_ptr_->getBindingName(i);
if (trt_engine_ptr_->bindingIsInput(i)) {
LOG(INFO) << "INPUT";
} else {
LOG(INFO) << "OUTPUT";
}
oss << "Dimension: ";
auto dims = trt_engine_ptr_->getBindingDimensions(i);
oss << " nbDims: " << dims.nbDims << " -> ";
for (int j = 0; j < Dims::MAX_DIMS; j++) {
oss << dims.d[j] << ", ";
}
LOG(INFO) << oss.str();
oss.str("");
switch (trt_engine_ptr_->getBindingDataType(i)) {
case nvinfer1::DataType::kFLOAT:
LOG(INFO) << "data type float" << std::endl;
break;
case nvinfer1::DataType::kHALF:
LOG(INFO) << "data type half" << std::endl;
break;
case nvinfer1::DataType::kINT8:
LOG(INFO) << "data type int8" << std::endl;
break;
}
}
// CHECK_NE(cudaStreamCreate(&stream_),0); // logic here is wrong
// cudaStreamCreate(&stream_);
}
void TRTEngineOp::Compute(OpKernelContext* context) {
int nbBindings = context->num_inputs() + context->num_outputs();
// TODO(jjsjann123) multiple input/output
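// TensorRT expects one device pointer per engine binding; each input and
// output tensor below is placed at its binding index in this vector.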
std::vector<void*> buffers(nbBindings);
size_t bindingIndex;
int nbBatch = 0;
bool valid = true;
for (int i = 0; i < context->num_inputs(); i++) {
// Grab the input tensor
bindingIndex = trt_engine_ptr_->getBindingIndex(input_nodes_[i].c_str());
const Tensor& input_tensor = context->input(i);
const TensorShape& input_shape = input_tensor.shape();
if (i == 0) {
nbBatch = input_shape.dim_size(0);
} else if (nbBatch != input_shape.dim_size(0)) {
valid = false;
break;
}
// int64 input_shape.dim_size(int d)
// int input_shape.dims()
switch (trt_engine_ptr_->getBindingDataType(bindingIndex)) {
case nvinfer1::DataType::kFLOAT:
LOG(INFO) << "float";
buffers[bindingIndex] = (void*)(input_tensor.flat<float>().data());
break;
case nvinfer1::DataType::kHALF:
LOG(INFO) << "half";
// buffers[bindingIndex] = (void*)input_tensor.flat<float16>().data();
break;
case nvinfer1::DataType::kINT8:
LOG(INFO) << "int8";
// buffers[bindingIndex] = (void*)input_tensor.flat<int8>().data();
break;
}
}
if (!valid) LOG(WARNING) << "input data inconsistent batch size";
for (int i = 0; i < static_cast<int>(output_nodes_.size()); i++) {
// It is unfortunate that we have to reallocate the output buffer on every run.
// Create an output tensor
bindingIndex = trt_engine_ptr_->getBindingIndex(output_nodes_[i].c_str());
Tensor* output_tensor = NULL;
TensorShape output_shape;
if (bindingIndex != -1) {
LOG(INFO) << "got binding " << bindingIndex;
auto dims = trt_engine_ptr_->getBindingDimensions(bindingIndex);
std::vector<int> trt_shape(dims.nbDims + 1);
trt_shape[0] = nbBatch;
for (int j = 0; j < dims.nbDims; j++) trt_shape[j + 1] = dims.d[j];
TensorShapeUtils::MakeShape(trt_shape.data(), trt_shape.size(),
&output_shape);
} else {
LOG(INFO) << "no binding ";
break;
}
OP_REQUIRES_OK(context,
context->allocate_output(i, output_shape, &output_tensor));
// buffers[bindingIndex] = (void*)output_tensor->flat<float>();
// buffers[bindingIndex] = output_tensor->flat<float>().data();
switch (trt_engine_ptr_->getBindingDataType(bindingIndex)) {
case nvinfer1::DataType::kFLOAT:
LOG(INFO) << "float";
buffers[bindingIndex] =
reinterpret_cast<void*>(output_tensor->flat<float>().data());
break;
case nvinfer1::DataType::kHALF:
LOG(INFO) << "half";
// buffers[bindingIndex] = (void*)output_tensor->flat<float16>().data();
break;
case nvinfer1::DataType::kINT8:
LOG(INFO) << "int8";
// buffers[bindingIndex] = (void*)output_tensor->flat<int8>().data();
break;
}
}
// copied from cuda_kernel_helper since it seems only valid in *.cu.cc files
const cudaStream_t* stream = CHECK_NOTNULL(
reinterpret_cast<const cudaStream_t*>(context->op_device_context()
->stream()
->implementation()
->CudaStreamMemberHack()));
trt_context_ptr_->enqueue(nbBatch, &buffers[0], *stream, nullptr);
cudaStreamSynchronize(*stream);
}
REGISTER_KERNEL_BUILDER(Name("TRTEngineOp").Device(DEVICE_GPU), TRTEngineOp);
} // namespace tensorrt
} // namespace tensorflow


@ -0,0 +1,55 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CONTRIB_TENSORRT_KERNELS_TRT_ENGINE_OP_H_
#define TENSORFLOW_CONTRIB_TENSORRT_KERNELS_TRT_ENGINE_OP_H_
#include <NvInfer.h>
#include <cuda_runtime_api.h>
#include <memory>
#include <string>
#include <vector>
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
namespace tensorflow {
namespace tensorrt {
class Logger;
class TRTEngineOp : public OpKernel {
public:
explicit TRTEngineOp(OpKernelConstruction* context);
void Compute(OpKernelContext* context) override;
private:
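// Deleter that calls TensorRT's destroy() instead of delete, so that engine
// and execution context objects can be held in std::unique_ptr.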
template <typename T>
struct Destroyer {
void operator()(T* d) { d->destroy(); }
};
template <typename T>
using destroyed_ptr = std::unique_ptr<T, Destroyer<T>>;
destroyed_ptr<nvinfer1::ICudaEngine> trt_engine_ptr_;
// TODO(samikama) context should go to a resource manager!
destroyed_ptr<nvinfer1::IExecutionContext> trt_context_ptr_;
std::vector<string> input_nodes_;
std::vector<string> output_nodes_;
};
} // namespace tensorrt
} // namespace tensorflow
#endif // TENSORFLOW_CONTRIB_TENSORRT_KERNELS_TRT_ENGINE_OP_H_


@ -0,0 +1,56 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/contrib/tensorrt/log/trt_logger.h"
// Use TF logging for TensorRT information
#include "tensorflow/core/platform/logging.h"
#define _TF_LOG_DEBUG ::tensorflow::internal::LogMessage(__FILE__, __LINE__, -1)
//------------------------------------------------------------------------------
namespace tensorflow {
//------------------------------------------------------------------------------
namespace tensorrt {
void Logger::log(Severity severity, const char* msg) {
// suppress info-level messages
switch (severity) {
case Severity::kINFO: { // mark TRT info messages as debug!
LOG(DEBUG) << msg;
break;
}
case Severity::kWARNING: {
LOG(WARNING) << msg;
break;
}
case Severity::kERROR: {
LOG(ERROR) << msg;
break;
}
case Severity::kINTERNAL_ERROR: {
LOG(FATAL) << msg;
break;
}
// This is unused for now, but it would catch new values if the enum changes
// in the future. It is always good to have a default case!
default: {
LOG(FATAL) << name_ << "Got unknown severity level from TRT " << msg;
break;
}
}
}
} // namespace tensorrt
} // namespace tensorflow


@ -0,0 +1,41 @@
// -*- c++ -*-
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CONTRIB_TENSORRT_LOG_TRT_LOGGER_H_
#define TENSORFLOW_CONTRIB_TENSORRT_LOG_TRT_LOGGER_H_
// Use TF logging for TensorRT information
#include <NvInfer.h>
#include <string>
//------------------------------------------------------------------------------
namespace tensorflow {
//------------------------------------------------------------------------------
namespace tensorrt {
// Logger for GIE info/warning/errors
class Logger : public nvinfer1::ILogger {
void log(nvinfer1::ILogger::Severity severity, const char* msg) override;
private:
std::string name_;
};
} // namespace tensorrt
} // namespace tensorflow
#endif // TENSORFLOW_CONTRIB_TENSORRT_LOG_TRT_LOGGER_H_


@ -0,0 +1,37 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/framework/tensor_shape.h"
namespace tensorflow {
namespace shape_inference {
extern Status TRTEngineOpShapeInference(InferenceContext* c);
}
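// TRTEngineOp executes a pre-built, serialized TensorRT engine on its input
// tensors; input_nodes and output_nodes name the engine bindings that
// correspond to the op's inputs and outputs.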
REGISTER_OP("TRTEngineOp")
.Attr("serialized_engine: string")
.Attr("input_nodes: list(string)")
.Attr("output_nodes: list(string)")
.Attr("InT: list({int8, float16, float32})")
.Attr("OutT: list({int8, float16, float32})")
.Input("in_tensor: InT")
.Output("out_tensor: OutT")
.SetShapeFn(shape_inference::TRTEngineOpShapeInference);
} // namespace tensorflow


@ -0,0 +1,8 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# pylint: disable=unused-import,wildcard-import
from tensorflow.contrib.tensorrt.python.ops import trt_engine_op
from tensorflow.contrib.tensorrt.python.trt_convert import CreateInferenceGraph
# pylint: enable=unused-import,wildcard-import


@ -0,0 +1,35 @@
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import platform
if platform.system() != "Windows":
# pylint: disable=wildcard-import,unused-import,g-import-not-at-top
from tensorflow.contrib.tensorrt.ops.gen_trt_engine_op import *
from tensorflow.contrib.util import loader
from tensorflow.python.platform import resource_loader
# pylint: enable=wildcard-import,unused-import,g-import-not-at-top
_trt_engine_op = loader.load_op_library(
resource_loader.get_path_to_datafile("_trt_engine_op.so"))
else:
raise RuntimeError("Windows platforms are not supported")


@ -0,0 +1,91 @@
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""Exposes the Python wrapper conversion to trt_graph."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# pylint: disable=unused-import,wildcard-import, line-too-long
from tensorflow.core.framework import graph_pb2
from tensorflow.python.framework import errors
from tensorflow.python.framework import errors_impl as _impl
from tensorflow.contrib.tensorrt.wrap_conversion import trt_convert
from tensorflow.python.util import compat
import tensorflow as tf
from tensorflow.python.grappler import tf_optimizer
from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.python.framework import meta_graph
from tensorflow.python.framework import ops
def CreateInferenceGraph(input_graph_def, outputs, max_batch_size=1,
max_workspace_size=2 << 20):
"""Python wrapper for the TRT transformation.
Args:
input_graph_def: GraphDef object containing a model to be transformed.
outputs: List of node names for the model outputs.
max_batch_size: max size for the input batch
max_workspace_size: parameter to control memory allocation (in Bytes)
Returns:
New GraphDef with TRTEngineOps placed in graph replacing subgraphs.
"""
# with errors.raise_exception_on_not_ok_status() as status:
# output_graph_def_string = trt_convert(
# input_graph_def_string,outputs,
# max_batch_size,max_workspace_size, status)
g = tf.Graph()
with g.as_default():
tf.import_graph_def(input_graph_def, name="")
rewriter_config = rewriter_config_pb2.RewriterConfig()
rewriter_config.optimizers.append('layout')
rewriter_config.optimizers.append('constfold')
# mark output nodes as fetch
train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
for node_name in outputs:
out_node = g.get_operation_by_name(node_name)
for i in range(len(out_node.outputs)):
train_op.append(out_node.outputs[i])
# constant folding
mg = meta_graph.create_meta_graph_def(graph=g)
meta_graph.add_collection_def(mg, ops.GraphKeys.TRAIN_OP)
optimized_graph_def_str = \
tf_optimizer.OptimizeGraph(rewriter_config, mg).SerializeToString()
# TODO(sami): Fix this when we can return status from C++ library
# There is a problem with the TF internal library setup that doesn't allow us
# to return a status object from C++. Thus we return a pair of strings, where
# the first one is the encoded status and the second one is the transformed
# graph's protobuf string.
out = trt_convert(
optimized_graph_def_str, outputs,
max_batch_size, max_workspace_size)
status = out[0]
output_graph_def_string = out[1]
del optimized_graph_def_str #save some memory
if len(status) < 2:
raise _impl.UnknownError(None,None,status)
if status[:2] != "OK":
msg=status.split(";")
if len(msg) == 1:
raise RuntimeError("Status message is malformed {}".format(status))
raise _impl._make_specific_exception(None,None,";".join(msg[1:]), int(msg[0]))
output_graph_def = graph_pb2.GraphDef()
output_graph_def.ParseFromString(output_graph_def_string)
del output_graph_def_string #save some memory
return output_graph_def


@ -0,0 +1,259 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/contrib/tensorrt/segment/segment.h"
#include <set>
#include <string>
#include <unordered_map>
#include <vector>
#include "tensorflow/contrib/tensorrt/segment/union_find.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
//------------------------------------------------------------------------------
namespace tensorrt {
namespace segment {
//------------------------------------------------------------------------------
namespace {
//------------------------------------------------------------------------------
bool CanContractEdge(const tensorflow::Edge* edge,
const tensorflow::Graph& graph) {
const tensorflow::Node* src = edge->src();
const tensorflow::Node* dst = edge->dst();
// Can't contract edge if doing so would cause a cycle in the
// graph. So, if there is a directed path from 'src' to 'dst', other
// than 'edge' (or any other direct edge from 'src' to 'dst'), then
// combining 'src' and 'dst' will cause a cycle along that path.
//
// In practice, to avoid modifying the graph and to take advantage
// of existing graph functions, we perform an equivalent.
// 1. Get all nodes incoming to 'dst', excluding 'src'
// 2. Reverse DFS from those nodes
// 3. If reverse DFS reaches 'src' then we have a cycle
std::vector<tensorflow::Node*> dfs_start_nodes;
for (tensorflow::Node* node : dst->in_nodes()) {
if (node != src) {
dfs_start_nodes.push_back(node);
}
}
bool is_cycle = false;
if (!dfs_start_nodes.empty()) {
tensorflow::ReverseDFSFrom(graph, dfs_start_nodes, {},
[&is_cycle, src](tensorflow::Node* node) {
if (node == src) {
is_cycle = true;
}
});
}
return !is_cycle;
}
//------------------------------------------------------------------------------
void ContractEdge(tensorflow::Edge* edge, tensorflow::Graph* graph,
std::vector<const tensorflow::Edge*>* remove_edges) {
// Transfer all inputs and outputs of 'dst' to 'src' except edges
// connecting the two.
tensorflow::Node* src = edge->src();
tensorflow::Node* dst = edge->dst();
// We can use '0' for input/output index because we don't need them
// to be accurate for the way we are using the graph.
std::vector<const tensorflow::Edge*> in_edges(dst->in_edges().begin(),
dst->in_edges().end());
for (const tensorflow::Edge* in_edge : in_edges) {
if (in_edge->src() != src) {
tensorflow::Edge* e = const_cast<tensorflow::Edge*>(in_edge);
if (e->src() == graph->source_node()) {
graph->AddEdge(e->src(), e->src_output(), src,
tensorflow::Graph::kControlSlot);
} else {
graph->AddEdge(e->src(), e->src_output(), src, 0 /* input index */);
}
}
}
std::vector<const tensorflow::Edge*> out_edges(dst->out_edges().begin(),
dst->out_edges().end());
for (const tensorflow::Edge* out_edge : out_edges) {
tensorflow::Edge* e = const_cast<tensorflow::Edge*>(out_edge);
if (e->dst() == graph->sink_node()) {
graph->AddEdge(src, tensorflow::Graph::kControlSlot, e->dst(),
e->dst_input());
} else {
graph->AddEdge(src, 0 /* output index */, e->dst(), e->dst_input());
}
}
// Return the edges that must be removed to disconnect 'dst' from
// the graph. We don't actually remove 'dst' since the caller holds
// references to all the nodes.
for (const auto& in_edge : dst->in_edges()) {
remove_edges->push_back(in_edge);
}
for (const auto& out_edge : dst->out_edges()) {
remove_edges->push_back(out_edge);
}
}
} // namespace
//------------------------------------------------------------------------------
tensorflow::Status SegmentGraph(
const tensorflow::GraphDef& gdef,
const std::function<bool(const tensorflow::NodeDef&)>& candidate_fn,
const SegmentOptions& options, SegmentNodesVector* segments) {
// Create a Graph representation of the GraphDef.
tensorflow::FunctionLibraryDefinition flib(tensorflow::OpRegistry::Global(),
gdef.library());
tensorflow::Graph graph(flib);
TF_RETURN_IF_ERROR(tensorflow::ConvertGraphDefToGraph(
tensorflow::GraphConstructorOptions(), gdef, &graph));
// tensorflow::DumpGraph("Pre-Segment", &graph);
// Use a union-find to collect the nodes that belong to the same
// segment. A node value of nullptr indicates that the node is not a
// candidate for TRT.
std::vector<UnionFind<tensorflow::Node*>> node_segments;
for (int i = 0; i < graph.num_node_ids(); ++i) {
tensorflow::Node* node = graph.FindNodeId(i);
if (!candidate_fn(node->def())) {
node = nullptr;
}
node_segments.emplace_back(node);
}
// Visit nodes in reverse topological order and use edge
// contraction to merge candidate nodes.
std::vector<tensorflow::Node*> order;
tensorflow::GetPostOrder(graph, &order);
for (const tensorflow::Node* node : order) {
// All output nodes of 'node' have been visited...
VLOG(2) << "Trying node " << node->name();
// 'node' must be a TRT candidate...
if (node_segments[node->id()].Value() == nullptr) {
VLOG(2) << "... not a TRT candidate";
continue;
}
// Contract output edges to combine 'node' with output
// nodes. Iterate since combining two nodes may unblock other
// combining.
while (true) {
std::set<const tensorflow::Edge*> contract_edges;
for (const tensorflow::Edge* out_edge : node->out_edges()) {
VLOG(2) << "... out node " << out_edge->dst()->name();
// Out node must be TRT candidate...
if (node_segments[out_edge->dst()->id()].Value() == nullptr) {
VLOG(2) << "... ... not a TRT candidate";
continue;
}
if (CanContractEdge(out_edge, graph)) {
VLOG(2) << "... ... can contract";
contract_edges.insert(out_edge);
} else {
VLOG(2) << "... ... cannot contract, would form cycle";
}
}
if (contract_edges.empty()) {
break;
}
// Contract edges and collect the adjacent nodes into the same
// segment/subgraph.
while (!contract_edges.empty()) {
const tensorflow::Edge* contract_edge = *contract_edges.begin();
const tensorflow::Node* src = contract_edge->src();
const tensorflow::Node* dst = contract_edge->dst();
VLOG(2) << "Merge " << src->name() << " <- " << dst->name();
node_segments[src->id()].Merge(&node_segments[dst->id()]);
// Contracting the edge leaves disconnected graph edges.
// Remove these from the graph and from 'contract_edges' so we
// don't visit them again.
tensorflow::Edge* e = const_cast<tensorflow::Edge*>(contract_edge);
std::vector<const tensorflow::Edge*> remove_edges;
ContractEdge(e, &graph, &remove_edges);
for (const tensorflow::Edge* r : remove_edges) {
contract_edges.erase(r);
graph.RemoveEdge(r);
}
}
}
}
// Collect the segments/subgraphs. Each subgraph is represented by a
// set of the names of the nodes in that subgraph.
std::unordered_map<std::string, std::set<std::string>> sg_map;
for (auto& u : node_segments) {
if ((u.Value() != nullptr) && (u.ParentValue() != nullptr)) {
sg_map[u.ParentValue()->name()].insert(u.Value()->name());
}
}
  // Clean up the graph by removing disconnected nodes before dumping it.
if (VLOG_IS_ON(2)) {
for (tensorflow::Node* node : graph.nodes()) {
if ((node->in_edges().size() == 0) && (node->out_edges().size() == 0)) {
graph.RemoveNode(node);
}
}
// tensorflow::DumpGraph("Post-Segment", &graph);
}
// Convert the segments into the expected return format
for (const auto& itr : sg_map) {
const auto& segment_node_names = itr.second;
if (VLOG_IS_ON(1)) {
std::string s;
for (const auto& name : segment_node_names) {
s += " " + name;
}
VLOG(1) << "Segment " << segments->size() << ":" << s;
}
// Don't use small segments.
if (static_cast<int>(segment_node_names.size()) <
options.minimum_segment_size) {
VLOG(1) << "Segment " << segments->size() << " has only "
<< segment_node_names.size() << " nodes, dropping";
continue;
}
segments->emplace_back(segment_node_names);
}
return tensorflow::Status::OK();
}
} // namespace segment
} // namespace tensorrt

View File

@ -0,0 +1,53 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CONTRIB_TENSORRT_SEGMENT_SEGMENT_H_
#define TENSORFLOW_CONTRIB_TENSORRT_SEGMENT_SEGMENT_H_
#include <functional>
#include <set>
#include <string>
#include <vector>
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/lib/core/status.h"
namespace tensorrt {
namespace segment {
using SegmentNodesVector = std::vector<std::set<std::string>>;
struct SegmentOptions {
// Segment must contain at least this many nodes.
int minimum_segment_size = 2;
};
// Get the subgraphs of a graph that can be handled by TensorRT.
//
// @param gdef The GraphDef describing the network
// @param candidate_fn A function that returns true for a NodeDef if
// that node can be handled by TensorRT.
// @param segments Returns the TensorRT segments/subgraphs. Each entry
// in the vector describes a subgraph by giving a set of the names of
// all the NodeDefs in that subgraph.
// @return the status.
tensorflow::Status SegmentGraph(
const tensorflow::GraphDef& gdef,
const std::function<bool(const tensorflow::NodeDef&)>& candidate_fn,
const SegmentOptions& options, SegmentNodesVector* segments);
} // namespace segment
} // namespace tensorrt
#endif // TENSORFLOW_CONTRIB_TENSORRT_SEGMENT_SEGMENT_H_
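A minimal usage sketch of the SegmentGraph API declared above. The helper name, the op whitelist, and the minimum segment size chosen here are illustrative assumptions and not part of this commit; in the actual converter the candidate function is supplied by the conversion pass.
// Sketch: collect TensorRT-compatible subgraphs from a GraphDef, assuming a
// hypothetical whitelist of convertible ops.
#include <functional>
#include "tensorflow/contrib/tensorrt/segment/segment.h"
tensorflow::Status FindTrtSegments(const tensorflow::GraphDef& gdef,
                                   tensorrt::segment::SegmentNodesVector* segments) {
  tensorrt::segment::SegmentOptions options;
  options.minimum_segment_size = 3;  // Drop segments smaller than 3 nodes.
  auto candidate_fn = [](const tensorflow::NodeDef& node) {
    // Hypothetical whitelist; the real per-op check lives in the conversion code.
    return node.op() == "Conv2D" || node.op() == "BiasAdd" || node.op() == "Relu";
  };
  return tensorrt::segment::SegmentGraph(gdef, candidate_fn, options, segments);
}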

View File

@ -0,0 +1,363 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/contrib/tensorrt/segment/segment.h"
#include "tensorflow/c/c_api.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/test.h"
//------------------------------------------------------------------------------
using namespace tensorflow;
namespace tensorrt {
namespace segment {
namespace test {
class SegmentTest : public ::testing::Test {
public:
bool GetGraphDef(TF_Graph* graph, tensorflow::GraphDef* graph_def);
TF_Operation* Placeholder(TF_Graph* graph, TF_Status* s, const char* name);
TF_Operation* Add(TF_Operation* l, TF_Operation* r, TF_Graph* graph,
TF_Status* s, const char* name);
std::function<bool(const NodeDef&)> MakeCandidateFn(
const std::set<std::string>& node_names);
protected:
void PlaceholderHelper(TF_Graph* graph, TF_Status* s, const char* name,
TF_Operation** op);
void AddHelper(TF_Operation* l, TF_Operation* r, TF_Graph* graph,
TF_Status* s, const char* name, TF_Operation** op, bool check);
SegmentOptions default_options_;
};
bool SegmentTest::GetGraphDef(TF_Graph* graph,
tensorflow::GraphDef* graph_def) {
TF_Status* s = TF_NewStatus();
TF_Buffer* buffer = TF_NewBuffer();
TF_GraphToGraphDef(graph, buffer, s);
bool ret = TF_GetCode(s) == TF_OK;
EXPECT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
if (ret) ret = graph_def->ParseFromArray(buffer->data, buffer->length);
TF_DeleteBuffer(buffer);
TF_DeleteStatus(s);
return ret;
}
std::function<bool(const NodeDef&)> SegmentTest::MakeCandidateFn(
const std::set<std::string>& node_names) {
return [node_names](const NodeDef& node) -> bool {
return node_names.find(node.name()) != node_names.end();
};
}
void SegmentTest::PlaceholderHelper(TF_Graph* graph, TF_Status* s,
const char* name, TF_Operation** op) {
TF_OperationDescription* desc = TF_NewOperation(graph, "Placeholder", name);
TF_SetAttrType(desc, "dtype", TF_INT32);
*op = TF_FinishOperation(desc, s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
ASSERT_NE(*op, nullptr);
}
TF_Operation* SegmentTest::Placeholder(TF_Graph* graph, TF_Status* s,
const char* name) {
TF_Operation* op;
PlaceholderHelper(graph, s, name, &op);
return op;
}
void SegmentTest::AddHelper(TF_Operation* l, TF_Operation* r, TF_Graph* graph,
TF_Status* s, const char* name, TF_Operation** op,
bool check) {
TF_OperationDescription* desc = TF_NewOperation(graph, "AddN", name);
TF_Output add_inputs[2] = {{l, 0}, {r, 0}};
TF_AddInputList(desc, add_inputs, 2);
*op = TF_FinishOperation(desc, s);
if (check) {
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
ASSERT_NE(*op, nullptr);
}
}
TF_Operation* SegmentTest::Add(TF_Operation* l, TF_Operation* r,
TF_Graph* graph, TF_Status* s,
const char* name) {
TF_Operation* op;
AddHelper(l, r, graph, s, name, &op, true);
return op;
}
//------------------------------------------------------------------------------
TEST_F(SegmentTest, Empty) {
TF_Graph* graph = TF_NewGraph();
GraphDef graph_def;
ASSERT_TRUE(GetGraphDef(graph, &graph_def));
SegmentNodesVector segments;
ASSERT_EQ(
SegmentGraph(graph_def, MakeCandidateFn({}), default_options_, &segments),
tensorflow::Status::OK());
// Expect no segments/subgraphs.
EXPECT_TRUE(segments.empty());
}
//------------------------------------------------------------------------------
TEST_F(SegmentTest, Simple) {
TF_Status* s = TF_NewStatus();
TF_Graph* graph = TF_NewGraph();
// feed
// // ||
// add0 add1
// | | /
// | add2
// | / ||
// add3 add4
// | /
// <sink>
//
TF_Operation* feed = Placeholder(graph, s, "feed");
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
EXPECT_EQ(string("feed"), string(TF_OperationName(feed)));
TF_Operation* add0 = Add(feed, feed, graph, s, "add0");
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
TF_Operation* add1 = Add(feed, feed, graph, s, "add1");
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
TF_Operation* add2 = Add(add0, add1, graph, s, "add2");
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
TF_Operation* add3 = Add(add0, add2, graph, s, "add3");
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
EXPECT_EQ(string("add3"), string(TF_OperationName(add3)));
TF_Operation* add4 = Add(add2, add2, graph, s, "add4");
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
EXPECT_EQ(string("add4"), string(TF_OperationName(add4)));
GraphDef graph_def;
ASSERT_TRUE(GetGraphDef(graph, &graph_def));
SegmentNodesVector segments;
ASSERT_EQ(
SegmentGraph(graph_def,
MakeCandidateFn({"add0", "add1", "add2", "add3", "add4"}),
default_options_, &segments),
tensorflow::Status::OK());
// Expect all Add operations to be collapsed into a single segment
ASSERT_EQ(segments.size(), 1);
std::vector<std::string> expected{"add0", "add1", "add2", "add3", "add4"};
for (const auto& ex : expected) {
EXPECT_TRUE(segments[0].find(ex) != segments[0].end())
<< "Missing expected node " << ex;
}
}
//------------------------------------------------------------------------------
TEST_F(SegmentTest, AvoidCycle) {
TF_Status* s = TF_NewStatus();
TF_Graph* graph = TF_NewGraph();
// add2 is not a TRT candidate so add0/add3 cannot be formed as a
// subgraph
//
// feed
// // ||
// add0 add1
// | | /
// | add2
// | / ||
// add3 add4
// | /
// <sink>
//
TF_Operation* feed = Placeholder(graph, s, "feed");
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
EXPECT_EQ(string("feed"), string(TF_OperationName(feed)));
TF_Operation* add0 = Add(feed, feed, graph, s, "add0");
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
TF_Operation* add1 = Add(feed, feed, graph, s, "add1");
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
TF_Operation* add2 = Add(add0, add1, graph, s, "add2");
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
TF_Operation* add3 = Add(add0, add2, graph, s, "add3");
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
EXPECT_EQ(string("add3"), string(TF_OperationName(add3)));
TF_Operation* add4 = Add(add2, add2, graph, s, "add4");
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
EXPECT_EQ(string("add4"), string(TF_OperationName(add4)));
GraphDef graph_def;
ASSERT_TRUE(GetGraphDef(graph, &graph_def));
SegmentNodesVector segments;
ASSERT_EQ(
SegmentGraph(graph_def, MakeCandidateFn({"add0", "add1", "add3", "add4"}),
default_options_, &segments),
tensorflow::Status::OK());
// Expect no subgraphs
EXPECT_EQ(segments.size(), 0);
}
//------------------------------------------------------------------------------
TEST_F(SegmentTest, Multiple) {
TF_Status* s = TF_NewStatus();
TF_Graph* graph = TF_NewGraph();
// add5 is not a TRT candidate so two subgraphs should be formed
//
// feed
// // || ||
// add0 add1 add7
// | | / / ||
// | add2-----add5 add8
// | / | | | |
// add3 add4 add6
// | | /
// <sink>
//
TF_Operation* feed = Placeholder(graph, s, "feed");
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
EXPECT_EQ(string("feed"), string(TF_OperationName(feed)));
TF_Operation* add0 = Add(feed, feed, graph, s, "add0");
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
TF_Operation* add1 = Add(feed, feed, graph, s, "add1");
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
TF_Operation* add7 = Add(feed, feed, graph, s, "add7");
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
TF_Operation* add2 = Add(add0, add1, graph, s, "add2");
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
TF_Operation* add5 = Add(add2, add7, graph, s, "add5");
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
TF_Operation* add8 = Add(add7, add7, graph, s, "add8");
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
TF_Operation* add3 = Add(add0, add2, graph, s, "add3");
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
EXPECT_EQ(string("add3"), string(TF_OperationName(add3)));
TF_Operation* add4 = Add(add2, add5, graph, s, "add4");
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
EXPECT_EQ(string("add4"), string(TF_OperationName(add4)));
TF_Operation* add6 = Add(add5, add8, graph, s, "add6");
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
EXPECT_EQ(string("add6"), string(TF_OperationName(add6)));
GraphDef graph_def;
ASSERT_TRUE(GetGraphDef(graph, &graph_def));
SegmentNodesVector segments;
ASSERT_EQ(SegmentGraph(graph_def,
MakeCandidateFn({"add0", "add1", "add2", "add3",
"add4", "add6", "add7", "add8"}),
default_options_, &segments),
tensorflow::Status::OK());
// Expect two subgraphs
EXPECT_EQ(segments.size(), 2);
std::vector<std::string> expected0{"add0", "add1", "add2", "add3"};
for (const auto& ex : expected0) {
EXPECT_TRUE(segments[0].find(ex) != segments[0].end())
<< "Missing expected node " << ex;
}
std::vector<std::string> expected1{"add6", "add8"};
for (const auto& ex : expected1) {
EXPECT_TRUE(segments[1].find(ex) != segments[1].end())
<< "Missing expected node " << ex;
}
}
//------------------------------------------------------------------------------
TEST_F(SegmentTest, BigIfElse) {
TF_Status* s = TF_NewStatus();
TF_Graph* graph = TF_NewGraph();
// add2 is not a TRT candidate
//
// feed
// ||
// add0
// // ||
// add1 add4
// || ||
// add2 add5
// || ||
// add3 add6
// || //
// add7
// ||
// <sink>
//
TF_Operation* feed = Placeholder(graph, s, "feed");
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
EXPECT_EQ(string("feed"), string(TF_OperationName(feed)));
TF_Operation* add0 = Add(feed, feed, graph, s, "add0");
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
TF_Operation* add1 = Add(add0, add0, graph, s, "add1");
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
TF_Operation* add2 = Add(add1, add1, graph, s, "add2");
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
TF_Operation* add3 = Add(add2, add2, graph, s, "add3");
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
TF_Operation* add4 = Add(add0, add0, graph, s, "add4");
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
TF_Operation* add5 = Add(add4, add4, graph, s, "add5");
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
TF_Operation* add6 = Add(add5, add5, graph, s, "add6");
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
TF_Operation* add7 = Add(add3, add6, graph, s, "add7");
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
EXPECT_EQ(string("add7"), string(TF_OperationName(add7)));
GraphDef graph_def;
ASSERT_TRUE(GetGraphDef(graph, &graph_def));
SegmentNodesVector segments;
ASSERT_EQ(SegmentGraph(graph_def,
MakeCandidateFn({"add0", "add1", "add3", "add4",
"add5", "add6", "add7"}),
default_options_, &segments),
tensorflow::Status::OK());
// Expect 2 subgraphs
EXPECT_EQ(segments.size(), 2);
std::vector<std::string> expected0{"add3", "add4", "add5", "add6", "add7"};
for (const auto& ex : expected0) {
EXPECT_TRUE(segments[0].find(ex) != segments[0].end())
<< "Missing expected node " << ex;
}
std::vector<std::string> expected1{"add0", "add1"};
for (const auto& ex : expected1) {
EXPECT_TRUE(segments[1].find(ex) != segments[1].end())
<< "Missing expected node " << ex;
}
}
} // namespace test
} // namespace segment
} // namespace tensorrt

View File

@ -0,0 +1,77 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CONTRIB_TENSORRT_SEGMENT_UNION_FIND_H_
#define TENSORFLOW_CONTRIB_TENSORRT_SEGMENT_UNION_FIND_H_
namespace tensorrt {
namespace segment {
// Union-Find data structure.
// Each cluster has an associated value; when merging clusters we can control
// which value becomes the representative of the merged clusters. Values must be
// copyable.
template <typename T>
class UnionFind {
public:
UnionFind() : size_(1), parent_(nullptr) {}
explicit UnionFind(const T& v) : size_(1), parent_(nullptr), value_(v) {}
// Returns the number of elements in a cluster.
int Size() { return FindRoot()->size_; }
// Merges this cluster with 'other'. This cluster's value becomes
// the value of the merged cluster; the value of 'other' is ignored.
void Merge(UnionFind* other);
// Each cluster has an associated value. Retrieves the value associated
// with this cluster.
T& ParentValue() { return FindRoot()->value_; }
// Get the original value of this node.
T& Value() { return value_; }
private:
// Finds the root element of the cluster. Performs path compression.
UnionFind* FindRoot();
int size_;
UnionFind* parent_;
T value_;
};
template <typename T>
void UnionFind<T>::Merge(UnionFind* other) {
UnionFind<T>* a = FindRoot();
UnionFind<T>* b = other->FindRoot();
if (a == b) return;
b->parent_ = a;
a->size_ += b->size_;
}
template <typename T>
UnionFind<T>* UnionFind<T>::FindRoot() {
if (!parent_) return this;
// Path compression: update intermediate nodes to point to the root of the
// equivalence class.
parent_ = parent_->FindRoot();
return parent_;
}
} // namespace segment
} // namespace tensorrt
#endif // TENSORFLOW_CONTRIB_TENSORRT_SEGMENT_UNION_FIND_H_
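For orientation, a small sketch of how this UnionFind behaves; the integer values are arbitrary and only illustrate the Merge/FindRoot semantics (the segmenter stores tensorflow::Node* values instead). The include path is inferred from the header guard.
// Sketch: two singleton clusters merged into one; the caller's value wins.
#include "tensorflow/contrib/tensorrt/segment/union_find.h"
void UnionFindSketch() {
  tensorrt::segment::UnionFind<int> a(1);
  tensorrt::segment::UnionFind<int> b(2);
  a.Merge(&b);
  // Now a.Size() == 2 and b.Size() == 2 (both resolve to the same root),
  // b.ParentValue() == 1 (the representative value taken from 'a'),
  // while b.Value() is still 2 (the original per-element value).
}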

View File

@ -0,0 +1,123 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/contrib/tensorrt/shape_fn/trt_shfn.h"
#include <string>
#include <vector>
#include "NvInfer.h"
#include "tensorflow/contrib/tensorrt/log/trt_logger.h"
namespace tensorflow {
namespace shape_inference {
tensorflow::Status TRTEngineOpShapeInference(InferenceContext* c) {
tensorflow::tensorrt::Logger gLogger;
string serialized_engine;
c->GetAttr("serialized_engine", &serialized_engine);
nvinfer1::IRuntime* infer = nvinfer1::createInferRuntime(gLogger);
nvinfer1::ICudaEngine* trt_engine = infer->deserializeCudaEngine(
serialized_engine.c_str(), serialized_engine.size(), nullptr);
  // Debug print of the engine bindings.
std::stringstream oss;
for (int i = 0; i < trt_engine->getNbBindings(); i++) {
LOG(INFO) << "index: " << i
<< ", binding name: " << trt_engine->getBindingName(i);
bool input_flag = trt_engine->bindingIsInput(i);
oss << "input?: " << (input_flag ? "Y" : "N");
oss << "Dimension: ";
auto dims = trt_engine->getBindingDimensions(i);
oss << " nbDims: " << dims.nbDims << " -> ";
for (int j = 0; j < dims.nbDims; j++) oss << dims.d[j] << ", ";
LOG(INFO) << oss.str();
oss.str("");
switch (trt_engine->getBindingDataType(i)) {
case nvinfer1::DataType::kFLOAT:
LOG(INFO) << "data type: float" << std::endl;
break;
case nvinfer1::DataType::kHALF:
LOG(INFO) << "data type: half" << std::endl;
break;
case nvinfer1::DataType::kINT8:
LOG(INFO) << "data type: int8" << std::endl;
break;
}
}
int nbBatch = -1;
  // Debug print of the input types and shapes.
std::vector<::tensorflow::DataType> input_type;
c->GetAttr("InT", &input_type);
oss.str("");
for (size_t i = 0; i < c->num_inputs(); i++) {
// check if input shape is legit
auto input_shape = c->input(i);
int index = i;
oss << "input:" << i << " type: " << input_type[index] << " shape: ";
for (int j = 0; j < c->Rank(input_shape); j++) {
auto dimHandler = c->Dim(input_shape, j);
if (c->ValueKnown(dimHandler))
oss << c->Value(dimHandler) << ", ";
else
        oss << "?, ";
if (j == 0) {
if (i == 0)
nbBatch = c->Value(dimHandler);
else if (nbBatch != c->Value(dimHandler))
LOG(WARNING) << "!!!!!!nbBatch does not match!!!!!!";
        // assert(nbBatch == c->Value(dimHandler));
}
}
LOG(INFO) << oss.str();
}
  // Log the input node names.
std::vector<string> input_nodes;
c->GetAttr("input_nodes", &input_nodes);
for (size_t i = 0; i < input_nodes.size(); i++) {
int index = i;
LOG(INFO) << "input:" << i << " name: " << input_nodes[index];
}
  // Compute output shapes from the engine bindings and set them on the context.
std::vector<string> output_nodes;
c->GetAttr("output_nodes", &output_nodes);
oss.str("");
for (size_t i = 0; i < output_nodes.size(); i++) {
int index = i;
int binding_index =
trt_engine->getBindingIndex(output_nodes[index].c_str());
oss << "string name " << output_nodes[index];
ShapeHandle output_shape;
std::vector<DimensionHandle> vecDim;
vecDim.emplace_back(c->MakeDim(nbBatch));
if (binding_index != -1) {
oss << "got binding " << binding_index;
auto dims = trt_engine->getBindingDimensions(binding_index);
for (int j = 0; j < dims.nbDims; j++)
vecDim.emplace_back(c->MakeDim(dims.d[j]));
} else {
oss << "no binding ";
}
output_shape = c->MakeShape(vecDim);
c->set_output(i, output_shape);
LOG(INFO) << oss.str();
}
  // Release the deserialized engine and the runtime before returning.
  trt_engine->destroy();
  infer->destroy();
  return Status::OK();
}
} // namespace shape_inference
} // namespace tensorflow

View File

@ -0,0 +1,28 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CONTRIB_TENSORRT_SHAPE_FN_TRT_SHFN_H_
#define TENSORFLOW_CONTRIB_TENSORRT_SHAPE_FN_TRT_SHFN_H_
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/lib/core/status.h"
namespace tensorflow {
namespace shape_inference {
Status TRTEngineOpShapeInference(InferenceContext* c);
} // namespace shape_inference
} // namespace tensorflow
#endif // TENSORFLOW_CONTRIB_TENSORRT_SHAPE_FN_TRT_SHFN_H_
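The shape function above is meant to be attached to the TRT engine op's registration. The snippet below is a hedged sketch of such a registration with an abbreviated attribute list; it is not the actual REGISTER_OP call from this commit, and any attribute beyond serialized_engine, input_nodes, output_nodes, and InT (which the shape function reads) is an assumption.
// Sketch only: attach TRTEngineOpShapeInference via SetShapeFn.
#include "tensorflow/contrib/tensorrt/shape_fn/trt_shfn.h"
#include "tensorflow/core/framework/op.h"
REGISTER_OP("TRTEngineOp")
    .Attr("serialized_engine: string")
    .Attr("input_nodes: list(string)")
    .Attr("output_nodes: list(string)")
    .Attr("InT: list(type)")
    .Attr("OutT: list(type)")
    .Input("in_tensor: InT")
    .Output("out_tensor: OutT")
    .SetShapeFn(tensorflow::shape_inference::TRTEngineOpShapeInference);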

View File

@ -0,0 +1,84 @@
/*
wrap trt_conversion
*/
%{
#define SWIG_FILE_WITH_INIT
%}
%include "std_string.i"
%include "std_pair.i"
%include "tensorflow/python/lib/core/strings.i"
%include "tensorflow/python/platform/base.i"
%template(StringPair) std::pair<string,string>;
%template() std::pair<swig::SwigPtr_PyObject, swig::SwigPtr_PyObject>;
%{
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/util/stat_summarizer.h"
#include "tensorflow/contrib/tensorrt/convert/convert_graph.h"
%}
%ignoreall
%unignore tensorflow;
%unignore trt_convert;
%{
std::pair<string,string> trt_convert(string graph_def_string,//const tensorflow::GraphDef&
std::vector<string> output_names,
size_t max_batch_size,
size_t max_workspace_size
                       // Unfortunately we can't use TF_Status here since it
                       // lives in c/c_api and pulls in a lot of other libraries
                       // which in turn declare ops. Those ops are also linked
                       // statically into our library and cause an abort from
                       // double registration when the module is loaded. Until
                       // TensorFlow properly exposes these headers, we work
                       // around this by returning a status string and converting
                       // it to an exception on the Python side.
//,TF_Status* out_status) {
) {
string out_status;
tensorflow::GraphDef graph_def;
if (!graph_def.ParseFromString(graph_def_string)) {
out_status="InvalidArgument;Couldn't interpret input as a GraphDef";
return std::pair<string,string>{out_status,""};
}
if (!output_names.size()) {
out_status="InvalidArgument;Size of the output_names vector is 0";
return std::pair<string,string>{out_status,""};
//return "";
}
tensorflow::GraphDef outGraph;
tensorflow::Status conversion_status =
tensorrt::convert::ConvertGraphDefToTensorRT(graph_def,
output_names,
max_batch_size,
max_workspace_size,
&outGraph);
if (!conversion_status.ok()) {
auto retCode=(int)conversion_status.code();
char buff[2000];
snprintf(buff,2000,"%d;%s",retCode,conversion_status.error_message().c_str());
out_status=buff;
return std::pair<string,string>{out_status,""};
}
string result;
if (!outGraph.SerializeToString(&result)) {
out_status="InvalidArgument;Couldn't serialize output as a GraphDef";
return std::pair<string,string>{out_status,""};
}
out_status="OK;All good!";
return std::pair<string,string>{out_status,result};
}
%}
std::pair<string,string> trt_convert(string graph_def_string,
std::vector<string> output_names,
size_t max_batch_size,
size_t max_workspace_size);
%unignoreall

View File

@ -279,7 +279,7 @@ def tf_cc_shared_object(
linkopts=[],
framework_so=tf_binary_additional_srcs(),
**kwargs):
native.cc_binary(
native.cc_binary(
name=name,
srcs=srcs + framework_so,
deps=deps,
@ -1281,6 +1281,45 @@ def tf_extension_linkopts():
def tf_extension_copts():
return [] # No extension c opts
# In libraries generated by tf_py_wrap_cc, module init functions are not
# exported unless they contain one of the keywords in the version script,
# which breaks custom Python modules. This rule appends init_<module_name>
# to the list of exported symbols in the version script.
def _append_init_to_versionscript_impl(ctx):
  modName = ctx.attr.module_name
  isVS = ctx.attr.is_version_script
  if isVS:
    ctx.actions.expand_template(
        template=ctx.file.template_file,
        output=ctx.outputs.versionscript,
        substitutions={
            "global:": "global:\n init_%s;" % modName,
        },
        is_executable=False,
    )
  else:
    ctx.actions.expand_template(
        template=ctx.file.template_file,
        output=ctx.outputs.versionscript,
        substitutions={
            "*tensorflow*": "*tensorflow*\ninit_%s" % modName,
        },
        is_executable=False,
    )
_append_init_to_versionscript = rule(
    implementation=_append_init_to_versionscript_impl,
    attrs={
        "module_name": attr.string(mandatory=True),
        "template_file": attr.label(allow_files=True, single_file=True, mandatory=True),
        "is_version_script": attr.bool(
            default=True,
            doc='whether the target is an ld version script or an exported-symbols list',
            mandatory=False),
    },
    outputs={"versionscript": "%{name}.lds"},
)
def tf_py_wrap_cc(name,
srcs,
swig_includes=[],
@ -1302,26 +1341,39 @@ def tf_py_wrap_cc(name,
toolchain_deps=["//tools/defaults:crosstool"],
module_name=module_name,
py_module_name=name)
vscriptname=name+"_versionscript"
_append_init_to_versionscript(
name=vscriptname,
module_name=module_name,
is_version_script=select({
"@local_config_cuda//cuda:darwin":False,
"//conditions:default":True,
}),
template_file=select({
"@local_config_cuda//cuda:darwin":clean_dep("//tensorflow:tf_exported_symbols.lds"),
"//conditions:default":clean_dep("//tensorflow:tf_version_script.lds")
})
)
extra_linkopts = select({
"@local_config_cuda//cuda:darwin": [
"-Wl,-exported_symbols_list",
clean_dep("//tensorflow:tf_exported_symbols.lds")
"%s.lds"%vscriptname,
],
clean_dep("//tensorflow:windows"): [],
clean_dep("//tensorflow:windows_msvc"): [],
"//conditions:default": [
"-Wl,--version-script",
clean_dep("//tensorflow:tf_version_script.lds")
"%s.lds"%vscriptname,
]
})
extra_deps += select({
"@local_config_cuda//cuda:darwin": [
clean_dep("//tensorflow:tf_exported_symbols.lds")
"%s.lds"%vscriptname,
],
clean_dep("//tensorflow:windows"): [],
clean_dep("//tensorflow:windows_msvc"): [],
"//conditions:default": [
clean_dep("//tensorflow:tf_version_script.lds")
"%s.lds"%vscriptname,
]
})

View File

@ -11,6 +11,7 @@ load(
)
load("//third_party/mkl:build_defs.bzl", "if_mkl")
load("//tensorflow:tensorflow.bzl", "if_cuda")
load("@local_config_tensorrt//:build_defs.bzl", "if_trt")
load("//tensorflow/core:platform/default/build_config_root.bzl", "tf_additional_license_deps")
# This returns a list of headers of all public header libraries (e.g.,
@ -201,7 +202,8 @@ sh_binary(
"//tensorflow/python:test_ops",
"//tensorflow/tools/dist_test/server:grpc_tensorflow_server",
],
}) + if_mkl(["//third_party/mkl:intel_binary_blob"]),
}) + if_mkl(["//third_party/mkl:intel_binary_blob"])
+ if_trt(["//tensorflow/contrib/tensorrt:init_py"]),
)
# A genrule for generating a marker file for the pip package on Windows

View File

@ -1,6 +1,7 @@
# TensorFlow external dependencies that can be loaded in WORKSPACE files.
load("//third_party/gpus:cuda_configure.bzl", "cuda_configure")
load("//third_party/tensorrt:build_defs.bzl", "trt_repository")
load("//third_party/mkl:build_defs.bzl", "mkl_repository")
load("//third_party/git:git_configure.bzl", "git_configure")
load("//third_party/py:python_configure.bzl", "python_configure")
@ -66,6 +67,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""):
# version we require here.
check_bazel_version_at_least("0.5.4")
cuda_configure(name="local_config_cuda")
trt_repository(name="local_config_tensorrt")
git_configure(name="local_config_git")
sycl_configure(name="local_config_sycl")
python_configure(name="local_config_python")

0
third_party/tensorrt/BUILD vendored Normal file
View File

42
third_party/tensorrt/BUILD.tpl vendored Normal file
View File

@ -0,0 +1,42 @@
# -*- python -*-
# Description:
#   Provides the TensorRT library and headers.
# TODO(Sami): these need to be defined.
licenses(["notice"])
exports_files(["LICENSE"])
load("@local_config_cuda//cuda:build_defs.bzl", "cuda_default_copts", "if_cuda")
config_setting(
    name = "trt_enabled",
    define_values = {
        "using_tensorrt": "true",
    },
    visibility = ["//visibility:public"],
)
cc_library(
    name = "tensorrt",
    srcs = [%{tensorrt_lib}],
    hdrs = [
        "include/NvInfer.h",
        "include/NvUtils.h",
    ],
    copts = cuda_default_copts(),
    deps = [
        "@local_config_cuda//cuda:cuda",
        "@local_config_cuda//cuda:cudnn",
    ],
    linkstatic = 1,
    # include_prefix = "include/",
    includes = ["include/"],
    visibility = ["//visibility:public"],
)
%{tensorrt_genrules}
# filegroup(
# name = "%{tensorrt_lib}",
# srcs = ["%{tensorrt_lib}"],
# visibility = ["//visibility:public"],
# )

203
third_party/tensorrt/LICENSE vendored Normal file
View File

@ -0,0 +1,203 @@
Copyright 2015 The TensorFlow Authors. All rights reserved.
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright 2015, The TensorFlow Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

85
third_party/tensorrt/build_defs.bzl vendored Normal file
View File

@ -0,0 +1,85 @@
# -*- python -*-
"""
Adds a repository rule for configuring TensorRT (local_config_tensorrt).
"""
_TENSORRT_INSTALLATION_PATH="TENSORRT_INSTALL_PATH"
_TF_TENSORRT_VERSION="TF_TENSORRT_VERSION"
def _is_trt_enabled(repo_ctx):
  if "TF_NEED_TENSORRT" in repo_ctx.os.environ:
    enable_trt = repo_ctx.os.environ["TF_NEED_TENSORRT"].strip()
    return enable_trt == "1"
  return False
def _dummy_repo(repo_ctx):
  repo_ctx.template("BUILD", Label("//third_party/tensorrt:BUILD.tpl"),
                    {"%{tensorrt_lib}": "", "%{tensorrt_genrules}": ""},
                    False)
  repo_ctx.template("build_defs.bzl", Label("//third_party/tensorrt:build_defs.bzl.tpl"),
                    {"%{trt_configured}": "False"}, False)
  repo_ctx.file("include/NvUtils.h", "", False)
  repo_ctx.file("include/NvInfer.h", "", False)
def _trt_repo_impl(repo_ctx):
  """
  Implements local_config_tensorrt
  """
  if not _is_trt_enabled(repo_ctx):
    _dummy_repo(repo_ctx)
    return
  trt_libdir = repo_ctx.os.environ[_TENSORRT_INSTALLATION_PATH]
  trt_ver = repo_ctx.os.environ[_TF_TENSORRT_VERSION]
  # Special case for deb installations; once installation is standardized
  # between the tar and deb packages, this branch can be removed.
  if trt_libdir == '/usr/lib/x86_64-linux-gnu':
    incPath = '/usr/include/x86_64-linux-gnu'
    incname = '/usr/include/x86_64-linux-gnu/NvInfer.h'
  else:
    incPath = str(repo_ctx.path("%s/../include" % trt_libdir).realpath)
    incname = incPath + '/NvInfer.h'
  if len(trt_ver) > 0:
    origLib = "%s/libnvinfer.so.%s" % (trt_libdir, trt_ver)
  else:
    origLib = "%s/libnvinfer.so" % trt_libdir
  objdump = repo_ctx.which("objdump")
  if objdump == None:
    if len(trt_ver) > 0:
      targetlib = "lib/libnvinfer.so.%s" % (trt_ver[0])
    else:
      targetlib = "lib/libnvinfer.so"
  else:
    soname = repo_ctx.execute([objdump, "-p", origLib])
    for l in soname.stdout.splitlines():
      if "SONAME" in l:
        lib = l.strip().split(" ")[-1]
        targetlib = "lib/%s" % (lib)
  repo_ctx.symlink(origLib, targetlib)
  grule = ('genrule(\n name = "trtlinks",\n' +
           ' outs = [\n "%s",\n "include/NvInfer.h",\n "include/NvUtils.h",\n ],\n' % targetlib +
           ' cmd="""ln -sf %s $(@D)/%s ' % (origLib, targetlib) +
           '&&\n ln -sf %s $(@D)/include/NvInfer.h ' % (incname) +
           '&&\n ln -sf %s/NvUtils.h $(@D)/include/NvUtils.h""",\n)\n' % (incPath))
  repo_ctx.template("BUILD", Label("//third_party/tensorrt:BUILD.tpl"),
                    {"%{tensorrt_lib}": '"%s"' % targetlib, "%{tensorrt_genrules}": grule},
                    False)
  repo_ctx.template("build_defs.bzl", Label("//third_party/tensorrt:build_defs.bzl.tpl"),
                    {"%{trt_configured}": "True"}, False)
trt_repository = repository_rule(
    implementation=_trt_repo_impl,
    local=True,
    environ=[
        "TF_NEED_TENSORRT",
        _TF_TENSORRT_VERSION,
        _TENSORRT_INSTALLATION_PATH,
    ],
)

18
third_party/tensorrt/build_defs.bzl.tpl vendored Normal file
View File

@ -0,0 +1,18 @@
# -*- python -*-
"""
Template of TensorRT-related build helper functions.
"""
def is_trt_enabled():
  return %{trt_configured}
def if_trt(if_true, if_false=[]):
  # if is_trt_enabled():
  #   return if_true
  # return if_false
  return select({
      "@local_config_tensorrt//:trt_enabled": if_true,
      "//conditions:default": if_false,
  })