diff --git a/configure.py b/configure.py index cf16ef48376..580bbc0ebed 100644 --- a/configure.py +++ b/configure.py @@ -37,12 +37,14 @@ _TF_BAZELRC = os.path.join(os.path.dirname(os.path.abspath(__file__)), _TF_WORKSPACE = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'WORKSPACE') _DEFAULT_CUDA_VERSION = '9.0' +_DEFAULT_TENSORRT_VERSION = '4' _DEFAULT_CUDNN_VERSION = '7' _DEFAULT_CUDA_COMPUTE_CAPABILITIES = '3.5,5.2' _DEFAULT_CUDA_PATH = '/usr/local/cuda' _DEFAULT_CUDA_PATH_LINUX = '/opt/cuda' _DEFAULT_CUDA_PATH_WIN = ('C:/Program Files/NVIDIA GPU Computing ' 'Toolkit/CUDA/v%s' % _DEFAULT_CUDA_VERSION) +_DEFAULT_TENSORRT_PATH_LINUX = '/usr/lib/x86_64-linux-gnu' _TF_OPENCL_VERSION = '1.2' _DEFAULT_COMPUTECPP_TOOLKIT_PATH = '/usr/local/computecpp' _DEFAULT_TRISYCL_INCLUDE_DIR = '/usr/local/triSYCL/include' @@ -382,13 +384,12 @@ def set_build_var(environ_cp, var_name, query_item, option_name, var = str(int(get_var(environ_cp, var_name, query_item, enabled_by_default))) environ_cp[var_name] = var - if var == '1': - write_to_bazelrc('build --define %s=true' % option_name) - elif bazel_config_name is not None: - # TODO(mikecase): Migrate all users of configure.py to use --config Bazel - # options and not to set build configs through environment variables. - write_to_bazelrc('build:%s --define %s=true' - % (bazel_config_name, option_name)) + # TODO(mikecase): Migrate all users of configure.py to use --config Bazel + # options and not to set build configs through environment variables. + if var=='1': + setting='true' + confname=":%s"%(bazel_config_name) if bazel_config_name is not None else "" + write_to_bazelrc('build%s --define %s=%s' % (confname,option_name,setting)) def set_action_env_var(environ_cp, @@ -438,13 +439,12 @@ def convert_version_to_int(version): for seg in version_segments: if not seg.isdigit(): return None - version_str = ''.join(['%03d' % int(seg) for seg in version_segments]) return int(version_str) def check_bazel_version(min_version): - """Check installed bezel version is at least min_version. + """Check installed bazel version is at least min_version. Args: min_version: string for minimum bazel version. @@ -1056,6 +1056,108 @@ def set_other_cuda_vars(environ_cp): write_to_bazelrc('test --config=cuda') +def set_tf_trt_version(environ_cp): + """Set TENSORRT_INSTALL_PATH and TF_TENSORRT_VERSION.""" + ask_trt_version = ( + 'Please specify the TensorRT (libnvinfer) version you want to use. ' + '[Leave empty to default to libnvinfer %s]: ') % _DEFAULT_TENSORRT_VERSION + + while True: + tf_trt_version = get_from_env_or_user_or_default( + environ_cp, 'TF_TENSORRT_VERSION', ask_trt_version, + _DEFAULT_TENSORRT_VERSION) + # if library version is passed and known + default_trt_path = environ_cp.get('TENSORRT_INSTALL_PATH',_DEFAULT_TENSORRT_PATH_LINUX) + ask_trt_path = (r'Please specify the location where libnvinfer %s library is ' + 'installed. Refer to README.md for more details. [Default' + ' is %s]:') % (tf_trt_version, default_trt_path) + trt_install_path = get_from_env_or_user_or_default( + environ_cp, 'TENSORRT_INSTALL_PATH', ask_trt_path, default_trt_path) + + # Result returned from "read" will be used unexpanded. That make "~" + # unusable. Going through one more level of expansion to handle that. 
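+      # For example, a user-supplied "~/TensorRT/lib" has "~" expanded to the
+      # user's home directory and symlinks resolved before the path is searched.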
+ trt_install_path = os.path.realpath( + os.path.expanduser(trt_install_path)) + # Simple function to search for libnvinfer in install path + # it will find all libnvinfer.so* in user defined install path + # and lib64 subdirectory and return absolute paths + def find_libs(search_path): + fl=set() + if os.path.exists(search_path) and os.path.isdir(search_path): + fl.update([os.path.realpath(os.path.join(search_path,x)) \ + for x in os.listdir(search_path) if 'libnvinfer.so' in x]) + return fl + possible_files=find_libs(trt_install_path) + possible_files.update(find_libs(os.path.join(trt_install_path,'lib64'))) + if is_linux(): + cudnnpatt=re.compile(".*libcudnn.so\.?(.*) =>.*$") + cudapatt =re.compile(".*libcudart.so\.?(.*) =>.*$") + def is_compatible(lib,cudaver,cudnnver): + ldd_bin=which('ldd') or '/usr/bin/ldd' + ldd_out=run_shell([ldd_bin,lib]).split(os.linesep) + for l in ldd_out: + if 'libcudnn.so' in l: + cudnn=cudnnpatt.search(l) + elif 'libcudart.so' in l: + cudart=cudapatt.search(l) + if cudnn: + cudnn=convert_version_to_int(cudnn.group(1)) if len(cudnn.group(1)) else 0 + if cudart: + cudart=convert_version_to_int(cudart.group(1)) if len(cudart.group(1)) else 0 + return (cudnn==cudnnver) and (cudart==cudaver) + cudaver=convert_version_to_int(environ_cp['TF_CUDA_VERSION']) + cudnnver=convert_version_to_int(environ_cp['TF_CUDNN_VERSION']) + valid_libs=[] + vfinder=re.compile('.*libnvinfer.so.?(.*)$') + highest_ver=[0,None,None] + + for l in possible_files: + if is_compatible(l,cudaver,cudnnver): + valid_libs.append(l) + vstr=vfinder.search(l).group(1) + currver=convert_version_to_int(vstr) if len(vstr) else 0 + if currver > highest_ver[0]: + highest_ver= [currver,vstr,l] + if highest_ver[1] is not None: + trt_install_path=os.path.dirname(highest_ver[2]) + tf_trt_version=highest_ver[1] + break + ldconfig_bin = which('ldconfig') or '/sbin/ldconfig' + libnvinfer_path_from_ldconfig = run_shell([ldconfig_bin, '-p']) + libnvinfer_path_from_ldconfig = re.search('.*libnvinfer.so.* => (.*)', + libnvinfer_path_from_ldconfig) + if libnvinfer_path_from_ldconfig: + libnvinfer_path_from_ldconfig = libnvinfer_path_from_ldconfig.group(1) + if os.path.exists('%s.%s' % (libnvinfer_path_from_ldconfig, + tf_trt_version)): + trt_install_path = os.path.dirname(libnvinfer_path_from_ldconfig) + break + + # Reset and Retry + if len(possible_files): + print( + 'Invalid path to TensorRT %s. libnvinfer.so* files found are for incompatible cuda versions ' + % tf_trt_version) + print(trt_install_path) + print(os.path.join(trt_install_path,'lib64')) + else: + print( + 'Invalid path to TensorRT %s. 
No libnvinfer.so* files found in ' + 'found:' % tf_trt_version) + print(trt_install_path) + print(os.path.join(trt_install_path,'lib64')) + if is_linux(): + print('%s.%s' % (libnvinfer_path_from_ldconfig, tf_trt_version)) + + environ_cp['TF_TENSORRT_VERSION'] = '' + + # Set TENSORRT_INSTALL_PATH and TENSORRT_CUDNN_VERSION + environ_cp['TENSORRT_INSTALL_PATH'] = trt_install_path + write_action_env_to_bazelrc('TENSORRT_INSTALL_PATH', trt_install_path) + environ_cp['TF_TENSORRT_VERSION'] = tf_trt_version + write_action_env_to_bazelrc('TF_TENSORRT_VERSION', tf_trt_version) + write_to_bazelrc('build:tensorrt --define using_tensorrt=true') + def set_host_cxx_compiler(environ_cp): """Set HOST_CXX_COMPILER.""" default_cxx_host_compiler = which('g++') or '' @@ -1244,9 +1346,11 @@ def main(): environ_cp['TF_NEED_COMPUTECPP'] = '0' environ_cp['TF_NEED_OPENCL'] = '0' environ_cp['TF_CUDA_CLANG'] = '0' + environ_cp['TF_NEED_TENSORRT'] = '0' if is_macos(): environ_cp['TF_NEED_JEMALLOC'] = '0' + environ_cp['TF_NEED_TENSORRT'] = '0' set_build_var(environ_cp, 'TF_NEED_JEMALLOC', 'jemalloc as malloc', 'with_jemalloc', True) @@ -1301,6 +1405,10 @@ def main(): if not is_windows(): set_gcc_host_compiler_path(environ_cp) set_other_cuda_vars(environ_cp) + # enable tensorrt if desired. Disabled on non-linux + set_action_env_var(environ_cp, 'TF_NEED_TENSORRT', 'TensorRT', False) + if environ_cp.get('TF_NEED_TENSORRT') == '1': + set_tf_trt_version(environ_cp) set_build_var(environ_cp, 'TF_NEED_MPI', 'MPI', 'with_mpi_support', False) if environ_cp.get('TF_NEED_MPI') == '1': diff --git a/tensorflow/BUILD b/tensorflow/BUILD index da37564697a..b374462d324 100644 --- a/tensorflow/BUILD +++ b/tensorflow/BUILD @@ -358,6 +358,14 @@ config_setting( }, ) +config_setting( + name = "using_tensorrt", + define_values = { + "using_tensorrt":"true", + }, + visibility = ["//visibility:public"], +) + config_setting( name = "with_mpi_support", values = {"define": "with_mpi_support=true"}, diff --git a/tensorflow/contrib/BUILD b/tensorflow/contrib/BUILD index 8bed0fabd74..e5c3017426f 100644 --- a/tensorflow/contrib/BUILD +++ b/tensorflow/contrib/BUILD @@ -7,6 +7,7 @@ package(default_visibility = ["//tensorflow:__subpackages__"]) load("//third_party/mpi:mpi.bzl", "if_mpi") load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda") +load("@local_config_tensorrt//:build_defs.bzl", "if_trt") py_library( name = "contrib_py", @@ -104,7 +105,9 @@ py_library( "//tensorflow/contrib/training:training_py", "//tensorflow/contrib/util:util_py", "//tensorflow/python:util", - ] + if_mpi(["//tensorflow/contrib/mpi_collectives:mpi_collectives_py"]), + ] + if_mpi(["//tensorflow/contrib/mpi_collectives:mpi_ops_py"]) + + if_trt(["//tensorflow/contrib/tensorrt:init_py"]), + ) cc_library( diff --git a/tensorflow/contrib/tensorrt/BUILD b/tensorflow/contrib/tensorrt/BUILD new file mode 100644 index 00000000000..723c9f5434b --- /dev/null +++ b/tensorflow/contrib/tensorrt/BUILD @@ -0,0 +1,266 @@ +# -*- python -*- +# Description: +# provide tensorrt operators and converter package + +package(default_visibility = ["//tensorflow:__subpackages__"]) + +licenses(["notice"]) # Apache 2.0 + +load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda") +load( + "//tensorflow:tensorflow.bzl", + "tf_custom_op_library", + "tf_gen_op_libs", + "tf_gen_op_wrapper_py", + "tf_py_wrap_cc", + "tf_cc_test", + "tf_kernel_library", + "tf_custom_op_py_library", + "tf_copts", +) + + + +tf_custom_op_library( + name = "python/ops/_trt_engine_op.so", + srcs = [ + 
"kernels/trt_engine_op.cc", + "ops/trt_engine_op.cc", + "kernels/trt_engine_op.h", + ], + gpu_srcs = [], + deps = [ + "@local_config_tensorrt//:tensorrt", + ":trt_shape_function", + "//tensorflow/core:lib_proto_parsing", + "//tensorflow/core/kernels:bounds_check_lib", + "//tensorflow/core/kernels:ops_util_hdrs", + ], +) + +cc_library( + name = "trt_shape_function", + srcs=[ + "shape_fn/trt_shfn.cc", + ], + hdrs=["shape_fn/trt_shfn.h"], + copts=tf_copts(), + deps=[ + ":trt_logging", + "//third_party/eigen3", + "@local_config_tensorrt//:tensorrt", + "@protobuf_archive//:protobuf", + "@nsync//:nsync_headers", + "//tensorflow/core:framework_headers_lib", + ] +) + + +tf_kernel_library( + name = "trt_engine_op_kernel", + srcs = [ + "kernels/trt_engine_op.cc", + ], + hdrs=[ + "kernels/trt_engine_op.h", + ], + gpu_srcs = [ + ], + deps = [ + ":trt_logging", + ":trt_shape_function", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//third_party/eigen3", + "//tensorflow/core:gpu_headers_lib", + "@local_config_tensorrt//:tensorrt", + "//tensorflow/core:lib_proto_parsing", + ], + alwayslink=1, +) + +tf_gen_op_libs( + op_lib_names = [ + "trt_engine_op", + ], + deps=[ + "@local_config_tensorrt//:tensorrt", + ] +) + + +cc_library( + name="trt_logging", + srcs = [ + "log/trt_logger.cc", + ], + hdrs=[ + "log/trt_logger.h", + ], + deps=[ + "@local_config_tensorrt//:tensorrt", + "//tensorflow/core:lib_proto_parsing", + ], + visibility = ["//visibility:public"], +) + +tf_gen_op_wrapper_py( + name = "trt_engine_op", + deps = [ + ":trt_engine_op_op_lib", + ":trt_shape_function", + ], +) + + +tf_custom_op_py_library( + name = "trt_engine_op_loader", + srcs = ["python/ops/trt_engine_op.py"], + dso = [":python/ops/_trt_engine_op.so", + "@local_config_tensorrt//:tensorrt", + ], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:resources", + ], +) + +py_library( + name = "init_py", + srcs = [ + "__init__.py", + "python/__init__.py", + ], + srcs_version = "PY2AND3", + deps = [ + ":trt_ops_py", + ":trt_convert_py", + + ], +) + +py_library( + name="trt_ops_py", + srcs_version = "PY2AND3", + deps=[":trt_engine_op", + ":trt_engine_op_loader", + ], + +) + +py_library( + name="trt_convert_py", + srcs=["python/trt_convert.py"], + srcs_version = "PY2AND3", + deps=[ + ":wrap_conversion" + ], +) + +tf_py_wrap_cc( + name="wrap_conversion", + srcs=["trt_conversion.i"], + deps=[ + ":trt_conversion", + "//tensorflow/core:framework_lite", + "//util/python:python_headers", + ], +) + +cc_library( + name= "trt_conversion", + srcs=[ + "convert/convert_nodes.cc", + "convert/convert_graph.cc", + "segment/segment.cc", + "convert/inferShapes.cc", + ], + hdrs=[ + "convert/convert_nodes.h", + "convert/convert_graph.h", + "convert/inferShapes.h", + "segment/segment.h", + "segment/union_find.h", + ], + deps=[ + "@local_config_tensorrt//:tensorrt", + "@protobuf_archive//:protobuf_headers", + "@nsync//:nsync_headers", + ":trt_logging", + "//tensorflow/core:framework_lite", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:framework_headers_lib", + "//tensorflow/core:core_cpu_base", + #"//third_party/eigen3", + ], +) + +tf_custom_op_library( + name = "tensorrt_ops.so", + srcs = [ + "ops/tensorrt_ops.cc", + ], + deps = [ + "@local_config_tensorrt//:tensorrt", + ], +) + + +# Library for the segmenting portion of TensorRT operation creation +cc_library( + name = "segment", + srcs = [ + "segment/segment.cc", + ], + hdrs = [ + "segment/union_find.h", + 
"segment/segment.h", + ], + deps = [ + "@protobuf_archive//:protobuf_headers", + "//tensorflow/core:core_cpu", + "//tensorflow/core:lib_proto_parsing", + "//third_party/eigen3", + ], + linkstatic = 1, +) + +tf_cc_test( + name = "segment_test", + size = "small", + srcs = ["segment/segment_test.cc"], + deps = [ + ":segment", + "//tensorflow/c:c_api", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) + + +# Library for the node-level conversion portion of TensorRT operation creation + +filegroup( + name = "cppfiles", + srcs = glob(["**/*.cc"]), + visibility=["//visibility:private"], +) + +filegroup( + name = "headers", + srcs = glob(["**/*.h"]), + visibility=["//visibility:private"], +) + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) diff --git a/tensorflow/contrib/tensorrt/README.md b/tensorflow/contrib/tensorrt/README.md new file mode 100644 index 00000000000..61b348fc605 --- /dev/null +++ b/tensorflow/contrib/tensorrt/README.md @@ -0,0 +1,42 @@ +Using TensorRT in TensorFlow +============================ + +This module provides necessary bindings and introduces TRT_engine_op +operator that wraps a subgraph in TensorRT. + +Compilation +----------- + +In order to compile the module, you need to have a local TensorRT +installation (libnvinfer.so and respective include files). During the +configuration step, TensorRT should be enabled and installation path +should be set. If installed through package managers (deb,rpm), +configure script should find the necessary components from the system +automatically. If installed from tar packages, user has to set path to +location where the library is installed during configuration. + +In order to enable TensorRT support, user has to add `--config=tensorrt` to +the build flags during the compilation such as + +``` +bazel build --config=cuda --config=opt --config=tensorrt //tensorflow/tools/pip_package:build_pip_package +bazel-bin/tensorflow/tools/pip_package/build_pip_package /tmp/ +``` + +After the installation of tensorflow package, TensorRT transformation +will be available. An example use is shown below. + +```python +import tensorflow as tf +import tensorflow.contrib.tensorrt as trt +#... create and train or load model +gdef=sess.graph.as_graph_def() +trt_gdef=trt.CreateInferenceGraph(gdef, #original graph_def + ["output"], #name of output node(s) + max_batch_size, #maximum batch size to run the inference + max_workspace_size # max memory for TensorRT to use + ) +tf.reset_default_graph() +tf.import_graph_def(graph_def=trt_gdef) +#...... run inference +``` diff --git a/tensorflow/contrib/tensorrt/__init__.py b/tensorflow/contrib/tensorrt/__init__.py new file mode 100644 index 00000000000..0d69ffe4663 --- /dev/null +++ b/tensorflow/contrib/tensorrt/__init__.py @@ -0,0 +1,19 @@ +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================= +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.tensorrt.python import * diff --git a/tensorflow/contrib/tensorrt/convert/convert_graph.cc b/tensorflow/contrib/tensorrt/convert/convert_graph.cc new file mode 100644 index 00000000000..29aa5554679 --- /dev/null +++ b/tensorflow/contrib/tensorrt/convert/convert_graph.cc @@ -0,0 +1,253 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/tensorrt/convert/convert_graph.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "NvInfer.h" + +#include "tensorflow/contrib/tensorrt/convert/convert_nodes.h" +#include "tensorflow/contrib/tensorrt/convert/inferShapes.h" +#include "tensorflow/contrib/tensorrt/segment/segment.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/graph/algorithm.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/graph/graph_constructor.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/logging.h" + +#define _TF_LOG_DEBUG ::tensorflow::internal::LogMessage(__FILE__, __LINE__, -1) +//------------------------------------------------------------------------------ +namespace tensorrt { +namespace convert { + +namespace { + +static std::unordered_set output_nodes; +bool IsTensorRTCandidate(const tensorflow::NodeDef& node_def) { + static const std::set candidate_ops = { + "Identity", "Const", "Conv2D", "MaxPool", "BiasAdd", "Relu", + "Add", "Mul", "Sub", "Rsqrt", "Pad" // "Placeholder" ,"Mean" + // TODO(ben,jie): ... 
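+      // Ops outside this whitelist are left to TensorFlow; nodes named in
+      // output_nodes are rejected below so graph outputs stay in TensorFlow.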
+ }; + if (output_nodes.count(node_def.name())) return false; + return candidate_ops.count(node_def.op()); +} + +void GetSubGraphIncomingEdges(tensorflow::Graph const& graph, + std::set const& subgraph_node_ids, + tensorflow::EdgeSet* incoming_edges) { + for (int node_id : subgraph_node_ids) { + tensorflow::Node const* node = graph.FindNodeId(node_id); + LOG(DEBUG) << node->name() << " has incoming edges: "; + for (tensorflow::Edge const* edge : node->in_edges()) { + if (!subgraph_node_ids.count(edge->src()->id()) && + !edge->src()->IsSource()) { + LOG(DEBUG) << edge->src()->name() << ", "; + incoming_edges->insert(edge); + } + } + } +} + +void GetSubGraphOutgoingEdges(tensorflow::Graph const& graph, + std::set const& subgraph_node_ids, + tensorflow::EdgeSet* outgoing_edges) { + for (int node_id : subgraph_node_ids) { + tensorflow::Node const* node = graph.FindNodeId(node_id); + LOG(DEBUG) << node->name() << " has outgoing edges: "; + for (tensorflow::Edge const* edge : node->out_edges()) { + if (!subgraph_node_ids.count(edge->dst()->id()) && + !edge->dst()->IsSink()) { + outgoing_edges->insert(edge); + } + } + } +} + +std::pair ParseTensorName(std::string name, + int default_idx = 0) { + int idx = default_idx; + size_t sep = name.find_last_of(':'); + if (sep != std::string::npos) { + name = name.substr(0, sep); + idx = std::stoi(name.substr(sep + 1)); + } + return std::make_pair(name, idx); +} + +std::unordered_map> BuildTensorNameMap( + const std::vector& tensor_names) { + std::unordered_map> result; + for (std::string const& tensor_name : tensor_names) { + std::string node_name; + int index; + std::tie(node_name, index) = ParseTensorName(tensor_name); + result[node_name].push_back(index); + } + return result; +} + +tensorflow::Status ConvertSubGraphToTensorRT( + tensorflow::Graph& graph, const std::vector& output_names, + const std::set& subgraph_node_ids, size_t max_batch_size, + size_t max_workspace_size, const ShapeMap& shape_map) { + tensorflow::EdgeSet subgraph_incoming_edges; + GetSubGraphIncomingEdges(graph, subgraph_node_ids, &subgraph_incoming_edges); + + std::vector> subgraph_inputs; + + + // Collect inputs by looking for incoming edges + for (tensorflow::Edge const* edge : subgraph_incoming_edges) { + subgraph_inputs.push_back({edge->src()->id(), edge->src_output()}); + } + std::set> subgraph_outputs_set; + // Collect outputs referenced from output_names + auto output_name_to_index_map = BuildTensorNameMap(output_names); + // for (int node_id : subgraph_node_ids_no_placeholder) { + for (int node_id : subgraph_node_ids) { + tensorflow::Node* node = graph.FindNodeId(node_id); + if (output_name_to_index_map.count(node->name())) { + for (int index : output_name_to_index_map.at(node->name())) { + subgraph_outputs_set.insert({node_id, index}); + } + } + } + // Collect outputs referenced from outgoing edges + tensorflow::EdgeSet subgraph_outgoing_edges; + // GetSubGraphOutgoingEdges(graph, subgraph_node_ids_no_placeholder, + // &subgraph_outgoing_edges); + GetSubGraphOutgoingEdges(graph, subgraph_node_ids, &subgraph_outgoing_edges); + for (tensorflow::Edge const* edge : subgraph_outgoing_edges) { + subgraph_outputs_set.insert({edge->src()->id(), edge->src_output()}); + } + // Impose an ordering on the outputs + std::vector> subgraph_outputs( + subgraph_outputs_set.begin(), subgraph_outputs_set.end()); + // Build TensorRT node and add it to the graph + tensorflow::NodeDef trt_node_def; + TF_RETURN_IF_ERROR(ConvertSubGraphToTensorRTNodeDef( + graph, subgraph_node_ids, subgraph_inputs, 
subgraph_outputs, + max_batch_size, max_workspace_size, shape_map, &trt_node_def)); + tensorflow::Status status; + tensorflow::Node* trt_node = graph.AddNode(trt_node_def, &status); + + TF_RETURN_IF_ERROR(status); + + // Re-map outgoing edges to use the new TRT node instead of the orig subgraph + std::map, int> subgraph_edge_to_output_map; + for (size_t i = 0; i < subgraph_outputs.size(); ++i) { + subgraph_edge_to_output_map.insert({subgraph_outputs.at(i), i}); + } + TF_RETURN_IF_ERROR(status); + for (tensorflow::Edge const* edge : subgraph_outgoing_edges) { + std::pair old_src = {edge->src()->id(), edge->src_output()}; + int new_src_output = subgraph_edge_to_output_map.at(old_src); + graph.UpdateEdge(trt_node, new_src_output, edge->dst(), edge->dst_input()); + } + // Remove the original subgraph + for (int node_id : subgraph_node_ids) { + tensorflow::Node* node = graph.FindNodeId(node_id); + // Don't remove the input placeholders + if (node->type_string() == "Placeholder") { + continue; + } + graph.RemoveNode(node); + } + return tensorflow::Status::OK(); +} + +tensorflow::Status BuildNodeMap( + const tensorflow::Graph& graph, + std::unordered_map* node_map) { + for (auto* node : graph.op_nodes()) { + if (!node_map->insert({node->name(), node}).second) { + return tensorflow::errors::AlreadyExists( + "Node name is not unique in graph: " + node->name()); + } + } + return tensorflow::Status::OK(); +} + +} // namespace + +tensorflow::Status ConvertGraphDefToTensorRT( + const tensorflow::GraphDef& graph_def, + const std::vector& output_names, size_t max_batch_size, + size_t max_workspace_size, tensorflow::GraphDef* new_graph_def) { + ShapeMap shape_map; + TF_RETURN_IF_ERROR( + tensorflow::trt::inferShapes(graph_def, output_names, shape_map)); + std::stringstream oss; + for (auto& n : shape_map) { // nodes + oss << " Node= " << n.first << ", "; + for (auto o : n.second) { // outputs + oss << o.first.DebugString() << " T= " << o.second << ", "; + } + LOG(DEBUG) << oss.str(); + oss.str(""); + } + // Build full graph + tensorflow::FunctionLibraryDefinition flib(tensorflow::OpRegistry::Global(), + graph_def.library()); + tensorflow::Graph graph(flib); + TF_RETURN_IF_ERROR(tensorflow::ConvertGraphDefToGraph( + tensorflow::GraphConstructorOptions(), graph_def, &graph)); + + // Segment the graph into subgraphs that can be converted to TensorRT + tensorrt::segment::SegmentOptions segment_options; + // TODO(ben,jie,sami): exclude output nodes (DISCUSS IT) + for (auto node : output_names) output_nodes.insert(node); + + // TODO(sami): this should be passed as a knob!!!! 
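+  // Candidate segments with fewer nodes than this are skipped rather than
+  // converted into TensorRT engines.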
+ segment_options.minimum_segment_size = 2; + tensorrt::segment::SegmentNodesVector segments; + TF_RETURN_IF_ERROR(tensorrt::segment::SegmentGraph( + graph_def, IsTensorRTCandidate, segment_options, &segments)); + if (segments.size() > 1) { + // LOG(WARNING) << "Multiple TensorRT candidate subgraphs were found, " + //<< "but only the first can be converted."; + // segments.erase(++segments.begin(), segments.end()); + LOG(INFO) << "MULTIPLE tensorrt candidate conversion: " << segments.size(); + } + std::unordered_map node_map; + TF_RETURN_IF_ERROR(BuildNodeMap(graph, &node_map)); + for (std::set const& subgraph_node_names : segments) { + std::set subgraph_node_ids; + for (std::string const& node_name : subgraph_node_names) { + subgraph_node_ids.insert(node_map.at(node_name)->id()); + } + TF_RETURN_IF_ERROR(ConvertSubGraphToTensorRT( + graph, output_names, subgraph_node_ids, max_batch_size, + max_workspace_size, shape_map)); + } + graph.ToGraphDef(new_graph_def); + return tensorflow::Status::OK(); +} + +} // namespace convert +} // namespace tensorrt diff --git a/tensorflow/contrib/tensorrt/convert/convert_graph.h b/tensorflow/contrib/tensorrt/convert/convert_graph.h new file mode 100644 index 00000000000..cd713de8880 --- /dev/null +++ b/tensorflow/contrib/tensorrt/convert/convert_graph.h @@ -0,0 +1,34 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CONTRIB_TENSORRT_CONVERT_CONVERT_GRAPH_H_ +#define TENSORFLOW_CONTRIB_TENSORRT_CONVERT_CONVERT_GRAPH_H_ + +#include +#include + +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/lib/core/status.h" + +namespace tensorrt { +namespace convert { + +tensorflow::Status ConvertGraphDefToTensorRT( + const tensorflow::GraphDef& graph_def, + const std::vector& output_names, size_t max_batch_size, + size_t max_workspace_size, tensorflow::GraphDef* new_graph_def); +} +} // namespace tensorrt + +#endif // TENSORFLOW_CONTRIB_TENSORRT_CONVERT_CONVERT_GRAPH_H_ diff --git a/tensorflow/contrib/tensorrt/convert/convert_nodes.cc b/tensorflow/contrib/tensorrt/convert/convert_nodes.cc new file mode 100644 index 00000000000..03146b1b541 --- /dev/null +++ b/tensorflow/contrib/tensorrt/convert/convert_nodes.cc @@ -0,0 +1,1737 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/contrib/tensorrt/convert/convert_nodes.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "NvInfer.h" + +#include "tensorflow/contrib/tensorrt/log/trt_logger.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/node_def_builder.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/graph/algorithm.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/graph/graph_constructor.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/logging.h" + +#define _TF_LOG_DEBUG ::tensorflow::internal::LogMessage(__FILE__, __LINE__, -1) +// Check if the types are equal. Cast to int first so that failure log message +// would work! +#define CHECK_EQ_TYPE(val1, val2) CHECK_EQ((int)val1, (int)val2) +//------------------------------------------------------------------------------ +namespace tensorrt { +namespace convert { + +namespace { + +inline int get_dtype_size(nvinfer1::DataType trt_dtype) { + switch (trt_dtype) { + case nvinfer1::DataType::kFLOAT: + return 4; + case nvinfer1::DataType::kINT8: + return 1; + case nvinfer1::DataType::kHALF: + return 2; + default: + return -1; + } +} + +inline int get_dtype_size(tensorflow::DataType trt_dtype) { + switch (trt_dtype) { + case tensorflow::DataType::DT_FLOAT: + return 4; + case tensorflow::DataType::DT_INT8: + return 1; + case tensorflow::DataType::DT_HALF: + return 2; + case tensorflow::DataType::DT_INT32: + return 4; + default: + return -1; + } +} + +inline tensorflow::Status convert_dtype(tensorflow::DataType tf_dtype, + nvinfer1::DataType* trt_dtype) { + switch (tf_dtype) { + case tensorflow::DataType::DT_FLOAT: + *trt_dtype = nvinfer1::DataType::kFLOAT; + break; + case tensorflow::DataType::DT_INT8: + *trt_dtype = nvinfer1::DataType::kINT8; + break; + case tensorflow::DataType::DT_HALF: + *trt_dtype = nvinfer1::DataType::kHALF; + break; + default: + return tensorflow::errors::InvalidArgument("Unsupported data type"); + } + return tensorflow::Status::OK(); +} + +inline nvinfer1::Dims get_tensor_shape(const tensorflow::Tensor& tensor) { + nvinfer1::Dims dims; + dims.nbDims = tensor.dims(); + for (int i = 0; i < dims.nbDims; i++) { + dims.d[i] = tensor.dim_size(i); + } + return dims; +} + +inline int64_t get_shape_size(nvinfer1::Dims shape) { + // Returns total number of elements in shape + int64_t count = 1; + for (int d = 0; d < shape.nbDims; ++d) { + count *= shape.d[d]; + } + return count; +} + +static std::vector> createSamePadding( + nvinfer1::DimsHW& stride, nvinfer1::DimsHW& kernel, + std::vector inputDims) { + std::vector> padding(inputDims.size()); + CHECK_EQ((size_t)stride.nbDims, inputDims.size()); // TODO(jie): N+C? NC+? + + for (size_t i = 0; i < inputDims.size(); ++i) { + /* formula to calculate the padding */ + int p = ((inputDims[i] - 1) / stride.d[i]) * stride.d[i] + kernel.d[i] - + inputDims[i]; + p = (p > 0) ? 
p : 0; + + /* right precedence padding, like in TensorFlow */ + int left = p / 2; + int right = p - left; + + padding[i] = {left, right}; + } + return padding; +} + +// class TRT_ShapedWeights : public nvinfer1::Weights { +class TRT_ShapedWeights { + public: + nvinfer1::Dims shape_; + tensorflow::DataType type_; + const void* values_; + bool dummy_flag_; + int64_t count() const { + int64_t c = 1; + for (int i = 0; i < shape_.nbDims; i++) c *= shape_.d[i]; + return c; + } + TRT_ShapedWeights(tensorflow::DataType type, const void* values, + nvinfer1::Dims shape) + : shape_(shape), type_(type), values_(values), dummy_flag_(false) { + // Note: this->shape.type[] is not used + } + explicit TRT_ShapedWeights(tensorflow::DataType type) + : type_(type), values_(nullptr), dummy_flag_(true) {} + nvinfer1::Weights getWeightsForTRT() const { + nvinfer1::DataType trt_type(nvinfer1::DataType::kFLOAT); + TF_CHECK_OK(convert_dtype(type_, &trt_type)); + if (dummy_flag_) return nvinfer1::Weights{trt_type, nullptr, 0}; + + // Note: this->shape.type[] is not used + return nvinfer1::Weights{trt_type, values_, get_shape_size(shape_)}; + } + size_t size_bytes() const { + return this->count() * get_dtype_size(this->type_); + } + // default converter + operator nvinfer1::Weights() const { return getWeightsForTRT(); } +}; + +class TRT_TensorOrWeights { + union { + nvinfer1::ITensor* _tensor_; + TRT_ShapedWeights _weights_; + }; + enum { TRT_NODE_TENSOR, TRT_NODE_WEIGHTS } _variant_; + + public: + explicit TRT_TensorOrWeights(nvinfer1::ITensor* tensor) + : _tensor_(tensor), _variant_(TRT_NODE_TENSOR) {} + explicit TRT_TensorOrWeights(TRT_ShapedWeights const& weights) + : _weights_(weights), _variant_(TRT_NODE_WEIGHTS) {} + TRT_TensorOrWeights() = delete; + bool is_tensor() const { return _variant_ == TRT_NODE_TENSOR; } + bool is_weights() const { return _variant_ == TRT_NODE_WEIGHTS; } + nvinfer1::ITensor* tensor() { + CHECK_EQ(this->is_tensor(), true); + return _tensor_; + } + nvinfer1::ITensor const* tensor() const { + CHECK_EQ(this->is_tensor(), true); + return _tensor_; + } + TRT_ShapedWeights& weights() { + CHECK_EQ(this->is_weights(), true); + return _weights_; + } + TRT_ShapedWeights const& weights() const { + CHECK_EQ(this->is_weights(), true); + return _weights_; + } + nvinfer1::Dims shape() const { + if (this->is_tensor()) { + return this->tensor()->getDimensions(); + } else { + return this->weights().shape_; + } + } +}; + +class TRT_LayerOrWeights { + union { + nvinfer1::ILayer* _layer_; + TRT_ShapedWeights _weights_; + }; + enum { TRT_NODE_LAYER, TRT_NODE_WEIGHTS } _variant_; + + public: + explicit TRT_LayerOrWeights(nvinfer1::ILayer* layer) + : _layer_(layer), _variant_(TRT_NODE_LAYER) {} + explicit TRT_LayerOrWeights(TRT_ShapedWeights const& weights) + : _weights_(weights), _variant_(TRT_NODE_WEIGHTS) {} + bool is_layer() const { return _variant_ == TRT_NODE_LAYER; } + bool is_weights() const { return _variant_ == TRT_NODE_WEIGHTS; } + nvinfer1::ILayer* layer() { + CHECK_EQ(this->is_layer(), true); + return _layer_; + } + TRT_ShapedWeights& weights() { + CHECK_EQ(this->is_weights(), true); + return _weights_; + } + TRT_TensorOrWeights output(int index = 0) const { + if (this->is_layer()) { + nvinfer1::ITensor* tensor = _layer_->getOutput(index); + return TRT_TensorOrWeights(tensor); + } else { + CHECK_EQ(index, 0); + return TRT_TensorOrWeights(_weights_); + } + } +}; + +class TFAttrs { + typedef std::map AttrMap; + AttrMap _attrs; + + public: + explicit TFAttrs(tensorflow::NodeDef const& tf_node) { + 
for (auto const& attr : tf_node.attr()) { + _attrs.insert({attr.first, &attr.second}); + } + } + bool count(std::string key) const { return _attrs.count(key); } + tensorflow::AttrValue const* at(std::string key) const { + if (!_attrs.count(key)) { + throw std::out_of_range("Attribute not found: " + key); + } + return _attrs.at(key); + } + template + T get(std::string key) const; + template + T getShape(std::string key) const; + template + T get(std::string key, T const& default_value) const { + return _attrs.count(key) ? this->get(key) : default_value; + } +}; +// template <> +// float TFAttrs::get(std::string key) const { +// return this->at(key)->f(); +//} + +// template <> +// int TFAttrs::get(std::string key) const { +// return (int)this->at(key)->i(); +//} + +// template <> +// bool TFAttrs::get(std::string key) const { +// auto value = this->at(key)->i(); +// return bool(value); +//} + +template <> +std::string TFAttrs::get(std::string key) const { + return this->at(key)->s(); +} +template <> +std::vector TFAttrs::get>(std::string key) const { + auto attr = this->at(key)->list().i(); + return std::vector(attr.begin(), attr.end()); +} +template <> +nvinfer1::Dims TFAttrs::get(std::string key) const { + auto values = this->get>(key); + nvinfer1::Dims dims; + dims.nbDims = values.size(); + std::copy(values.begin(), values.end(), dims.d); + // Note: No dimension type information is included + return dims; +} +// template <> +// nvinfer1::DimsHW TFAttrs::get(std::string key) const { +// nvinfer1::Dims dims = this->get(key); +// CHECK_EQ(dims.nbDims, 2); +// return nvinfer1::DimsHW(dims.d[0], dims.d[1]); +//} +// template <> +// nvinfer1::Permutation TFAttrs::get( +// std::string key) const { +// auto values = this->get>(key); +// nvinfer1::Permutation perm; +// std::copy(values.begin(), values.end(), perm.order); +// // Fill unused values with -1 to aid debugging +// std::fill(perm.order + values.size(), perm.order + nvinfer1::Dims::MAX_DIMS, +// -1); +// return perm; +//} +// template <> +// nvinfer1::Dims TFAttrs::getShape(std::string key) const { +// auto attr = this->at(key)->shape(); +// nvinfer1::Dims dims; +// dims.nbDims = attr.dim_size(); +// for (int i = 0; i < dims.nbDims; i++) dims.d[i] = attr.dim(i).size(); +// return dims; +//} +// template<> TRT_ShapedWeights TFAttrs::get(std::string key) +// const { +// tensorflow::TensorProto const* tf_weights_tensor = &this->at(key)->tensor(); +// TODO(jie): Implement this +// return convert_tf_weights(tf_weights_tensor); +//} +template <> +nvinfer1::DataType TFAttrs::get(std::string key) const { + nvinfer1::DataType trt_dtype(nvinfer1::DataType::kFLOAT); + TF_CHECK_OK(convert_dtype(this->at(key)->type(), &trt_dtype)); + return trt_dtype; +} +template <> +tensorflow::DataType TFAttrs::get(std::string key) const { + return this->at(key)->type(); +} + +template +void reorder4(nvinfer1::DimsNCHW shape, T const* idata, + nvinfer1::DimsNCHW istrides, T* odata, + nvinfer1::DimsNCHW ostrides) { + for (int n = 0; n < shape.n(); ++n) { + for (int c = 0; c < shape.c(); ++c) { + for (int h = 0; h < shape.h(); ++h) { + for (int w = 0; w < shape.w(); ++w) { + odata[n * ostrides.n() + c * ostrides.c() + h * ostrides.h() + + w * ostrides.w()] = idata[n * istrides.n() + c * istrides.c() + + h * istrides.h() + w * istrides.w()]; + } + } + } + } +} + +void reorder_rsck_to_kcrs(TRT_ShapedWeights const& iweights, + TRT_ShapedWeights* oweights) { + CHECK_EQ(iweights.type_, oweights->type_); + CHECK_EQ(iweights.size_bytes(), oweights->size_bytes()); + int r = 
iweights.shape_.d[0]; + int s = iweights.shape_.d[1]; + int c = iweights.shape_.d[2]; + int k = iweights.shape_.d[3]; + oweights->shape_.d[0] = k; + oweights->shape_.d[1] = c; + oweights->shape_.d[2] = r; + oweights->shape_.d[3] = s; + // nvinfer1::DimsNCHW istrides = {1, s, c*r*s, r*s}; + nvinfer1::DimsNCHW istrides = {1, k, s * k * c, c * k}; + nvinfer1::DimsNCHW ostrides = {c * r * s, r * s, s, 1}; + switch (iweights.type_) { + case tensorflow::DataType::DT_FLOAT: + reorder4( + {k, c, r, s}, static_cast(iweights.values_), istrides, + static_cast(const_cast(oweights->values_)), ostrides); + break; + default: + LOG(FATAL) << "!!!!!!!!!!!!!!!!!!!!!!!!broke!!!!!!!!!!!!"; + } +} + +/* not used. clean up needed. +nvinfer1::Weights make_dummy_weights(nvinfer1::DataType +dtype=nvinfer1::DataType::kFLOAT) { nvinfer1::Weights w; w.count = 0; w.values += nullptr; w.type = dtype; return w; +} +*/ + +struct InferDeleter { + template + void operator()(T* obj) const { + if (obj) { + obj->destroy(); + } + } +}; + +template +inline std::shared_ptr infer_object(T* obj) { + return std::shared_ptr(obj, InferDeleter()); +} + +// Logger for GIE info/warning/errors +class Converter; + +using OpConverter = + std::function const&, + std::vector*)>; + +class Converter { + std::unordered_map _trt_tensors; + std::unordered_map _op_registry; + nvinfer1::INetworkDefinition* _trt_network; + std::list> _temp_bufs; + + void register_op_converters(); + + std::vector get_inputs( + tensorflow::NodeDef const& node_def) { + std::vector inputs; + for (auto const& input_name : node_def.input()) { + LOG(DEBUG) << "retrieve input: " << input_name; + inputs.push_back(_trt_tensors.at(input_name)); + } + return inputs; + } + + public: + explicit Converter(nvinfer1::INetworkDefinition* trt_network) + : _trt_network(trt_network) { + this->register_op_converters(); + } + + TRT_ShapedWeights get_temp_weights(tensorflow::DataType type, + nvinfer1::Dims shape) { + TRT_ShapedWeights weights(type, nullptr, shape); + _temp_bufs.push_back(std::vector(weights.size_bytes())); + weights.values_ = _temp_bufs.back().data(); + return weights; + } + + TRT_ShapedWeights get_temp_weights_like(TRT_ShapedWeights const& weights) { + return this->get_temp_weights(weights.type_, weights.shape_); + } + + tensorflow::Status convert_node(tensorflow::NodeDef const& node_def) { + std::vector inputs = this->get_inputs(node_def); + std::string op = node_def.op(); + if (!_op_registry.count(op)) { + return tensorflow::errors::Unimplemented( + "no converter registered for op: " + op); + } + OpConverter op_converter = _op_registry.at(op); + std::vector outputs; + TF_RETURN_IF_ERROR(op_converter(*this, node_def, inputs, &outputs)); + for (size_t i = 0; i < outputs.size(); ++i) { + TRT_TensorOrWeights output = outputs.at(i); + // TODO(jie): tf protobuf seems to be omitting the :0 suffix + std::string output_name = node_def.name(); + if (i != 0) output_name = output_name + ":" + std::to_string(i); + if (output.is_tensor()) { + output.tensor()->setName(output_name.c_str()); + } + LOG(DEBUG) << "write out tensor: " << output_name; + if (!_trt_tensors.insert({output_name, output}).second) { + return tensorflow::errors::AlreadyExists( + "output tensor already exists for op: " + op); + } + } + return tensorflow::Status::OK(); + } + + nvinfer1::INetworkDefinition* network() { return _trt_network; } + + TRT_TensorOrWeights get_tensor(std::string name) { + if (!_trt_tensors.count(name)) { + return TRT_TensorOrWeights(nullptr); + } + return _trt_tensors.at(name); + } + + bool 
insert_input_tensor(std::string name, nvinfer1::ITensor* tensor) { + return _trt_tensors.insert({name, TRT_TensorOrWeights(tensor)}).second; + } + + nvinfer1::ITensor* transposeTensor(nvinfer1::ITensor* input_tensor, + std::vector order) { + auto dims = input_tensor->getDimensions(); + + // TODO(jie): change the return to status and properly exit + if (order.size() - 1 != size_t(dims.nbDims)) + LOG(ERROR) << "dimension does not match, fail gracefully"; + + nvinfer1::IShuffleLayer* layer = this->network()->addShuffle(*input_tensor); + nvinfer1::Permutation permutation; + for (int32_t i = 0; i < dims.nbDims; ++i) { + permutation.order[i] = order[i + 1] - 1; + } + layer->setFirstTranspose(permutation); + + nvinfer1::Dims reshapeDims; + reshapeDims.nbDims = dims.nbDims; + for (int32_t i = 0; i < reshapeDims.nbDims; ++i) { + reshapeDims.d[i] = 0; + reshapeDims.type[i] = dims.type[i]; + } + layer->setReshapeDimensions(reshapeDims); + return layer->getOutput(0); + } +}; + +/******************************************************************************* + Constant folding functions + TODO(jie): once optimizer kicks in, we should have done constant folding +there. +*******************************************************************************/ +struct LambdaFactory { + enum class OP_CATEGORY : int { RSQRT = 0, NEG, ADD, MUL, SUB }; + OP_CATEGORY op; + + template + std::function unary() { + switch (op) { + case OP_CATEGORY::RSQRT: { + LOG(DEBUG) << "RSQRT GETS DONE"; + return [](T t) -> T { return 1.0 / std::sqrt(t); }; + } + case OP_CATEGORY::NEG: + return [](T t) -> T { return -t; }; + default: + LOG(DEBUG) << "not supported op for unary: " << static_cast(op); + return nullptr; + } + } + + template + std::function binary() { + switch (op) { + case OP_CATEGORY::ADD: + return [](T l, T r) -> T { return l + r; }; + case OP_CATEGORY::SUB: + return [](T l, T r) -> T { return l - r; }; + case OP_CATEGORY::MUL: + return [](T l, T r) -> T { return l * r; }; + default: + LOG(WARNING) << "not supported op for binary: " << static_cast(op); + } + return [](T l, T r) -> T { + LOG(FATAL) << "Unsupported op type "; + return l; + }; + } + + template + std::function broadcast_r(T val) { + LOG(DEBUG) << "LAMBDA VAL : " << val; + switch (op) { + case OP_CATEGORY::ADD: + return [val](T l) -> T { + LOG(DEBUG) << "LAMBDA VAL : " << val; + return l + val; + }; + // return [val](T l)-> T {return l+val;}; + case OP_CATEGORY::SUB: + return [val](T l) -> T { + LOG(DEBUG) << "LAMBDA VAL : " << val; + return l - val; + }; + case OP_CATEGORY::MUL: + return [val](T l) -> T { + LOG(DEBUG) << "LAMBDA VAL : " << val; + return l * val; + }; + default: + LOG(WARNING) << "not supported op for binary: " << static_cast(op); + } + return [val](T l) -> T { + LOG(FATAL) << "Unsupported op type "; + return l; + }; + } + + template + std::function broadcast_l(T val) { + LOG(DEBUG) << "LAMBDA VAL : " << val; + switch (op) { + case OP_CATEGORY::ADD: + return [val](T l) -> T { + LOG(DEBUG) << "LAMBDA VAL : " << val; + return val + l; + }; + case OP_CATEGORY::SUB: + return [val](T l) -> T { + LOG(DEBUG) << "LAMBDA VAL : " << val; + return val - l; + }; + case OP_CATEGORY::MUL: + return [val](T l) -> T { + LOG(DEBUG) << "LAMBDA VAL : " << val; + return val * l; + }; + default: + LOG(ERROR) << "not supported op for binary: " << static_cast(op); + } + return [val](T l) -> T { + LOG(FATAL) << "Unsupported op type "; + return l; + }; + } +}; + +tensorflow::Status UnaryCompute(TRT_ShapedWeights const& iweights, + TRT_ShapedWeights* oweights, + 
LambdaFactory unary_op) { + // assume iweights.type == oweights.type + CHECK_EQ(iweights.type_, oweights->type_); + + switch (iweights.type_) { + case tensorflow::DataType::DT_FLOAT: { + auto inp = static_cast(iweights.values_); + auto oup = static_cast(const_cast(oweights->values_)); + std::transform(inp, inp + iweights.count(), oup, unary_op.unary()); + break; + } + default: + return tensorflow::errors::Unimplemented("data type not supported: " + + iweights.type_); + } + return tensorflow::Status::OK(); +} + +tensorflow::Status BinaryCompute(TRT_ShapedWeights const& iweights_l, + TRT_ShapedWeights const& iweights_r, + TRT_ShapedWeights* oweights, + LambdaFactory binary_op) { + // assume iweights_l.type == iweight_r.type + CHECK_EQ(iweights_l.type_, oweights->type_); + CHECK_EQ(iweights_r.type_, oweights->type_); + LOG(DEBUG) << "SANITY CHECK!"; + + switch (iweights_l.type_) { + case tensorflow::DataType::DT_FLOAT: { + auto inp_l = static_cast(iweights_l.values_); + auto inp_r = static_cast(iweights_r.values_); + auto oup = static_cast(const_cast(oweights->values_)); + + if (iweights_l.count() != iweights_r.count()) { + // we only supports broadcast of RankZero + if (iweights_l.count() == 1) { + LOG(DEBUG) << "I bet it is not working!" << (*inp_l); + std::transform(inp_r, inp_r + iweights_r.count(), oup, + binary_op.broadcast_l(*inp_l)); + } else if (iweights_r.count() == 1) { + LOG(DEBUG) << "I bet it is not working!" << (*inp_r); + std::transform(inp_l, inp_l + iweights_l.count(), oup, + binary_op.broadcast_r(*inp_r)); + } else { + return tensorflow::errors::Unimplemented( + "Binary op with non-rankZero broadcast not supported"); + } + } else { + std::transform(inp_l, inp_l + iweights_l.count(), inp_r, oup, + binary_op.binary()); + } + break; + } + default: + return tensorflow::errors::Unimplemented("data type not supported: " + + iweights_l.type_); + } + + return tensorflow::Status::OK(); +} + +tensorflow::Status ConstantFoldUnary( + Converter& ctx, tensorflow::NodeDef const& node_def, + std::vector const& inputs, + std::vector* outputs) { + TRT_ShapedWeights weights_input = inputs.at(0).weights(); + + // allocate output weights + TRT_ShapedWeights weights_output = ctx.get_temp_weights_like(weights_input); + + // FIXME assume type matches input weights + // get trt type & shape + // maybe this part has to be moved into the block of rsqrt later + // check type consistency + CHECK_EQ(weights_input.type_, + TFAttrs(node_def).get("T")); + + // Maybe I should do a switch + LambdaFactory unary_op; + if (node_def.op() == "Rsqrt") { + // compute rsqrt + unary_op.op = LambdaFactory::OP_CATEGORY::RSQRT; + auto ret = UnaryCompute(weights_input, &weights_output, unary_op); + // pass the output + if (ret == tensorflow::Status::OK()) { + outputs->push_back(TRT_TensorOrWeights(weights_output)); + } + return ret; + } else { + return tensorflow::errors::Unimplemented("Binary op not supported: " + + node_def.op()); + } +} + +// TODO(jie,ben) broadcast is needed yet not implemented +// Let's get the simple stuff working first. 
Maybe we should fall bakc to TF +// approach for constant folding +tensorflow::Status ConstantFoldBinary( + Converter& ctx, tensorflow::NodeDef const& node_def, + std::vector const& inputs, + std::vector* outputs) { + TRT_ShapedWeights weights_input_l = inputs.at(0).weights(); + TRT_ShapedWeights weights_input_r = inputs.at(1).weights(); + + // check type consistency + CHECK_EQ(weights_input_l.type_, weights_input_r.type_); + + if (weights_input_l.shape_.nbDims != weights_input_r.shape_.nbDims) + return tensorflow::errors::Unimplemented( + "Binary op implicit broadcast not supported: " + node_def.op()); + + // TODO(jie): constant fold should really fall back to TF. + int nbDims = weights_input_l.shape_.nbDims; + nvinfer1::Dims output_shape; + output_shape.nbDims = nbDims; + LOG(DEBUG) << "nbDims: " << nbDims + << "the other: " << weights_input_r.shape_.nbDims; + for (int i = 0; i < nbDims; i++) { + if (weights_input_l.shape_.d[i] == weights_input_r.shape_.d[i]) { + output_shape.d[i] = weights_input_l.shape_.d[i]; + } else if (weights_input_l.shape_.d[i] == 1 || + weights_input_r.shape_.d[i] == 1) { + output_shape.d[i] = + std::max(weights_input_l.shape_.d[i], weights_input_r.shape_.d[i]); + } else { + return tensorflow::errors::Unimplemented( + "Binary op with incompatible shape at, " + node_def.op()); + } + LOG(DEBUG) << "left: " << weights_input_l.shape_.d[i] + << "right: " << weights_input_r.shape_.d[i] + << "output: " << output_shape.d[i]; + } + + // FIXME assume type matches input weights + // get trt type & shape + TFAttrs attrs(node_def); + // maybe this part has to be moved into the block of rsqrt later + tensorflow::DataType dtype = attrs.get("T"); + + // allocate output weights + TRT_ShapedWeights weights_output = ctx.get_temp_weights(dtype, output_shape); + + // Maybe I should do a switch + LambdaFactory binary_op; + if (node_def.op() == "Sub") { + binary_op.op = LambdaFactory::OP_CATEGORY::SUB; + } else if (node_def.op() == "Mul") { + binary_op.op = LambdaFactory::OP_CATEGORY::MUL; + } else if (node_def.op() == "Add") { + binary_op.op = LambdaFactory::OP_CATEGORY::ADD; + } else { + return tensorflow::errors::Unimplemented("Binary op not supported: " + + node_def.op()); + } + auto ret = BinaryCompute(weights_input_l, weights_input_r, &weights_output, + binary_op); + + // pass the output + if (ret == tensorflow::Status::OK()) { + outputs->push_back(TRT_TensorOrWeights(weights_output)); + } + + return ret; +} + +// TODO(jie): broadcast is needed yet not implemented +// only implemented channel wise for the time being +tensorflow::Status BinaryTensorOpWeight( + Converter& ctx, tensorflow::NodeDef const& node_def, + const nvinfer1::ITensor* tensor, TRT_ShapedWeights weights, + std::vector* outputs) { + // FIXME assume type matches input weights + // get trt type & shape + // maybe this part has to be moved into the block of rsqrt later + + // check type consistency + auto dtype = TFAttrs(node_def).get("T"); + CHECK_EQ_TYPE(tensor->getType(), dtype); // cast to int for error messages + nvinfer1::DataType ttype; + TF_CHECK_OK(convert_dtype(weights.type_, &ttype)); + CHECK_EQ_TYPE(ttype, dtype); // cast to int for error message + + // check scale mode + auto dims_w = weights.shape_; + auto dims_t = tensor->getDimensions(); + + // default to channel-wise + auto scale_mode = nvinfer1::ScaleMode::kELEMENTWISE; + + /* + if (weights.count() == 1) { + LOG(DEBUG) << "UNIFORM"; + scale_mode = nvinfer1::ScaleMode::kUNIFORM; + } else if (dims_w.nbDims == 1) { + // TODO(jie): should we check for 
implicit chennel wise binary op + // where weights has shape 1x1xC? + LOG(DEBUG) << "CHANNEL"; + scale_mode = nvinfer1::ScaleMode::kCHANNEL; + } else { + // TODO(jie): check weight shape. + // broadcast is not fully supported + LOG(DEBUG) << "ELEMENTWISE"; + scale_mode = nvinfer1::ScaleMode::kELEMENTWISE; + } */ + + if (weights.count() == 1) { + LOG(DEBUG) << "UNIFORM"; + scale_mode = nvinfer1::ScaleMode::kUNIFORM; + } else { + // no broadcasting on Batch dimension; + assert(dims_w.d[0]==1); + + // broadcasting on Channel dimension only allowed in kUNIFORM + assert(dims_w.d[1]==dims_t.d[0]); + assert(dims_w.nbDims==dims_t.nbDims); + + // default is element; + for (int i=2; i permutation(dims_t.nbDims + 1); + if (scale_mode == nvinfer1::ScaleMode::kCHANNEL && dims_t.nbDims > 1) { + // we swap the last dimension into channel for trt. + // because of tensorflow default broadcasting rules. + for (int i = 0; i < static_cast(permutation.size()); i++) { + permutation[i] = i; + } + permutation[1] = dims_t.nbDims; + permutation[dims_t.nbDims] = 1; + tensor = ctx.transposeTensor(const_cast(tensor), + permutation); + } + */ + + // prepare weights + TRT_ShapedWeights shiftWeights(weights.type_); + TRT_ShapedWeights scaleWeights(weights.type_); + TRT_ShapedWeights powerWeights(weights.type_); + + // Maybe I should do a switch + if (node_def.op() == "Sub") { + TRT_ShapedWeights neg_weights = ctx.get_temp_weights_like(weights); + LambdaFactory unary_op; + unary_op.op = LambdaFactory::OP_CATEGORY::NEG; + UnaryCompute(weights, &neg_weights, unary_op); + shiftWeights = neg_weights; + } else if (node_def.op() == "Mul") { + scaleWeights = weights; + } else if (node_def.op() == "Add") { + shiftWeights = weights; + } else { + return tensorflow::errors::Unimplemented("Binary op not supported: " + + node_def.op()); + } + + nvinfer1::IScaleLayer* layer = ctx.network()->addScale( + *const_cast(tensor), scale_mode, shiftWeights, + scaleWeights, powerWeights); + + nvinfer1::ITensor* output_tensor = layer->getOutput(0); + // transpose back dimension + /* + if (scale_mode == nvinfer1::ScaleMode::kCHANNEL && dims_t.nbDims > 1) { + output_tensor = ctx.transposeTensor(output_tensor, permutation); + } + */ + + // pass the output + outputs->push_back(TRT_TensorOrWeights(output_tensor)); + return tensorflow::Status::OK(); +} + +tensorflow::Status BinaryTensorOpTensor( + Converter& ctx, tensorflow::NodeDef const& node_def, + const nvinfer1::ITensor* tensor_l, const nvinfer1::ITensor* tensor_r, + std::vector* outputs) { + static const std::unordered_map + ops{ + {"Add", nvinfer1::ElementWiseOperation::kSUM}, + {"Mul", nvinfer1::ElementWiseOperation::kPROD}, + // {"max", nvinfer1::ElementWiseOperation::kMAX}, + // {"min", nvinfer1::ElementWiseOperation::kMIN}, + {"Sub", nvinfer1::ElementWiseOperation::kSUB}, + {"Div", nvinfer1::ElementWiseOperation::kDIV}, + }; + + // FIXME assume type matches input weights + // get trt type & shape + TFAttrs attrs(node_def); + // maybe this part has to be moved into the block of rsqrt later + nvinfer1::DataType dtype = attrs.get("T"); + + // check type consistency + CHECK_EQ_TYPE(tensor_l->getType(), dtype); + CHECK_EQ_TYPE(tensor_r->getType(), dtype); + auto op_pair = ops.find(node_def.op()); + if (op_pair == ops.end()) + return tensorflow::errors::Unimplemented( + "binary op: " + node_def.op() + + " not supported at: " + node_def.name()); + + nvinfer1::IElementWiseLayer* layer = ctx.network()->addElementWise( + *const_cast(tensor_l), + *const_cast(tensor_r), op_pair->second); + + 
nvinfer1::ITensor* output_tensor = layer->getOutput(0); + + // pass the output + outputs->push_back(TRT_TensorOrWeights(output_tensor)); + return tensorflow::Status::OK(); +} + +tensorflow::Status ConvertPlaceholder( + Converter& ctx, tensorflow::NodeDef const& node_def, + std::vector const& inputs, + std::vector* outputs) { + LOG(DEBUG) << "Placeholder should have been replace already"; + return tensorflow::errors::Unimplemented("cannot convert Placeholder op"); + // OK this make sense since we are supposed to replace it with input + TFAttrs attrs(node_def); + nvinfer1::DataType dtype = attrs.get("dtype"); + nvinfer1::Dims dims = attrs.get("shape"); + + dims.nbDims--; + for (int i = 0; i < dims.nbDims; i++) dims.d[i] = dims.d[i + 1]; + + nvinfer1::ITensor* output = + ctx.network()->addInput(node_def.name().c_str(), dtype, dims); + if (!output) { + return tensorflow::errors::InvalidArgument("Failed to create Input layer"); + } + outputs->push_back(TRT_TensorOrWeights(output)); + return tensorflow::Status::OK(); +} + +tensorflow::Status ConvertConv2D(Converter& ctx, + tensorflow::NodeDef const& node_def, + std::vector const& inputs, + std::vector* outputs) { + nvinfer1::ITensor const* tensor = inputs.at(0).tensor(); + // nvinfer1::ITensor* tensor = inputs.at(0).tensor(); + // TODO(jie): handle NHWC/NCHW transpose; + TRT_ShapedWeights weights_rsck = inputs.at(1).weights(); + TRT_ShapedWeights weights = ctx.get_temp_weights_like(weights_rsck); + reorder_rsck_to_kcrs(weights_rsck, &weights); + TRT_ShapedWeights biases(weights.type_); + int noutput = weights.shape_.d[0]; + nvinfer1::DimsHW kernel_size; + kernel_size.h() = weights.shape_.d[2]; + kernel_size.w() = weights.shape_.d[3]; + TFAttrs attrs(node_def); + + int h_index = 2; + int w_index = 3; + auto data_format = attrs.get("data_format"); + if (data_format == "NHWC") { + tensor = ctx.transposeTensor(const_cast(tensor), + {0, 3, 1, 2}); + h_index = 1; + w_index = 2; + // TODO(jie): transpose it + } else { + LOG(DEBUG) << "NCHW !!!!"; + } + // TODO(jie): stride. (NHWC/NCHW) + auto tf_stride = attrs.get>("strides"); + nvinfer1::DimsHW stride(tf_stride[h_index], tf_stride[w_index]); + + auto tensor_dim = tensor->getDimensions(); + std::vector> padding; + // TODO(jie): padding. + if (attrs.get("padding") == "SAME") { + // This is NCHW tensor with no batch dimension. + // 1 -> h + // 2 -> w + padding = createSamePadding(stride, kernel_size, + {static_cast(tensor_dim.d[h_index]), + static_cast(tensor_dim.d[w_index])}); + } else { + // return tensorflow::errors::Unimplemented( + // "Current Conv2D cannot support padding other than SAME"); + padding = {{0, 0}, {0, 0}}; + } + + if (padding[0].first != padding[0].second || + padding[1].first != padding[1].second) { + // TODO(jie): handle asymmetric padding + // return tensorflow::errors::Unimplemented( + // "Asymmetric padding not implemented yet"); + auto padLayer = ctx.network()->addPadding( + *const_cast(tensor), + nvinfer1::DimsHW(padding[1].first, padding[0].first), + nvinfer1::DimsHW(padding[1].second, padding[0].second)); + tensor = padLayer->getOutput(0); + } + + nvinfer1::IConvolutionLayer* layer = + ctx.network()->addConvolution(*const_cast(tensor), + noutput, kernel_size, weights, biases); + + layer->setStride(stride); + layer->setPadding({padding[0].first, padding[1].first}); + layer->setName(node_def.name().c_str()); + nvinfer1::ITensor* output_tensor = layer->getOutput(0); + + if (data_format == "NHWC") { + // TODO(jie): transpose it back! 
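+    // {0, 2, 3, 1} is the inverse of the {0, 3, 1, 2} NHWC->NCHW transpose
+    // applied to the input above, so the output is returned in NHWC layout.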
+ output_tensor = ctx.transposeTensor(output_tensor, {0, 2, 3, 1}); + } else { + LOG(DEBUG) << "NCHW !!!!"; + } + outputs->push_back(TRT_TensorOrWeights(output_tensor)); + return tensorflow::Status::OK(); +} + +tensorflow::Status ConvertPool(Converter& ctx, + tensorflow::NodeDef const& node_def, + std::vector const& inputs, + std::vector* outputs) { + nvinfer1::ITensor const* tensor = inputs.at(0).tensor(); + TFAttrs attrs(node_def); + + int h_index = 2; + int w_index = 3; + auto data_format = attrs.get("data_format"); + if (data_format == "NHWC") { + h_index = 1; + w_index = 2; + tensor = ctx.transposeTensor(const_cast(tensor), + {0, 3, 1, 2}); + } else { + LOG(DEBUG) << "NCHW !!!!"; + } + nvinfer1::PoolingType type; + // TODO(jie): support other pooling type + if (node_def.op() == "MaxPool") + type = nvinfer1::PoolingType::kMAX; + else + return tensorflow::errors::Unimplemented("only supports Max pool"); + + // TODO(jie): NCHW + auto tf_stride = attrs.get>("strides"); + nvinfer1::DimsHW stride(tf_stride[h_index], tf_stride[w_index]); + + auto tf_kernel = attrs.get>("ksize"); + nvinfer1::DimsHW ksize(tf_kernel[h_index], tf_kernel[w_index]); + + auto tensor_dim = tensor->getDimensions(); + std::vector> padding; + // TODO(jie): padding. + if (attrs.get("padding") == "SAME") { + // This is NCHW tensor with no batch dimension. + // 1 -> h + // 2 -> w + padding = createSamePadding( + stride, ksize, + {static_cast(tensor_dim.d[1]), static_cast(tensor_dim.d[2])}); + } else if (attrs.get("padding") == "VALID") { + // No padding for valid padding here + LOG(DEBUG) << "no padding added for VALID padding in pool" + << node_def.name(); + padding = {{0, 0}, {0, 0}}; + } else { + return tensorflow::errors::Unimplemented( + "Current MaxPool cannot support padding other than SAME"); + } + + if (padding[0].first != padding[0].second || + padding[1].first != padding[1].second) { + // TODO(jie): handle asymmetric padding + // return tensorflow::errors::Unimplemented( + // "Asymmetric padding not implemented yet"); + auto padLayer = ctx.network()->addPadding( + *const_cast(tensor), + nvinfer1::DimsHW(padding[1].first, padding[0].first), + nvinfer1::DimsHW(padding[1].second, padding[0].second)); + tensor = padLayer->getOutput(0); + } + + nvinfer1::IPoolingLayer* layer = ctx.network()->addPooling( + *const_cast(tensor), type, ksize); + + layer->setStride(stride); + layer->setPadding({padding[0].first, padding[1].first}); + layer->setName(node_def.name().c_str()); + nvinfer1::ITensor* output_tensor = layer->getOutput(0); + + if (data_format == "NHWC") { + // TODO(jie): transpose it back! 
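+    // Transpose the pooled result back to NHWC so downstream TensorFlow ops
+    // receive the layout they expect.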
+ output_tensor = ctx.transposeTensor(output_tensor, {0, 2, 3, 1}); + } else { + LOG(DEBUG) << "NCHW !!!!"; + } + outputs->push_back(TRT_TensorOrWeights(output_tensor)); + return tensorflow::Status::OK(); +} + +tensorflow::Status ConvertActivation( + Converter& ctx, tensorflow::NodeDef const& node_def, + std::vector const& inputs, + std::vector* outputs) { + nvinfer1::ITensor const* tensor = inputs.at(0).tensor(); + nvinfer1::IActivationLayer* layer = ctx.network()->addActivation( + *const_cast(tensor), nvinfer1::ActivationType::kRELU); + nvinfer1::ITensor* output_tensor = layer->getOutput(0); + outputs->push_back(TRT_TensorOrWeights(output_tensor)); + return tensorflow::Status::OK(); +} + +tensorflow::Status ConvertScale(Converter& ctx, + tensorflow::NodeDef const& node_def, + std::vector const& inputs, + std::vector* outputs) { + if (inputs.size() != 2 || !inputs.at(0).is_tensor() || + !inputs.at(1).is_weights()) + return tensorflow::errors::Unimplemented( + "only supports tensor op weight for now, at " + node_def.name()); + // implement tensor binaryOp weight [channel wise] for now; + nvinfer1::ITensor const* tensor = inputs.at(0).tensor(); + // nvinfer1::ITensor* tensor = inputs.at(0).tensor(); + + // TODO(jie): handle NHWC/NCHW transpose; + TRT_ShapedWeights weights = inputs.at(1).weights(); + // nvinfer1::Weights empty_weights{weights.type, nullptr, 0}; + TRT_ShapedWeights empty_weights(weights.type_); + + TFAttrs attrs(node_def); + + // transpose NHWC + auto data_format = attrs.get("data_format"); + if (data_format == "NHWC") { + tensor = ctx.transposeTensor(const_cast(tensor), + {0, 3, 1, 2}); + // TODO(jie): transpose it + } else { + LOG(DEBUG) << "NCHW !!!!"; + } + nvinfer1::IScaleLayer* layer = ctx.network()->addScale( + *const_cast(tensor), nvinfer1::ScaleMode::kCHANNEL, + weights, empty_weights, empty_weights); + + nvinfer1::ITensor* output_tensor = layer->getOutput(0); + if (data_format == "NHWC") { + // TODO(jie): transpose it back! + output_tensor = ctx.transposeTensor(output_tensor, {0, 2, 3, 1}); + } else { + LOG(DEBUG) << "NCHW !!!!"; + } + outputs->push_back(TRT_TensorOrWeights(output_tensor)); + return tensorflow::Status::OK(); +} + +tensorflow::Status ConvertConst(Converter& ctx, + tensorflow::NodeDef const& node_def, + std::vector const& inputs, + std::vector* outputs) { + auto const& weights_tensor = node_def.attr().at("value").tensor(); + + // get trt type & shape + TFAttrs attrs(node_def); + // nvinfer1::DataType dtype = attrs.get("dtype"); + tensorflow::DataType dtype = attrs.get("dtype"); + + // create shaped weights as output + tensorflow::Tensor tensor; + if (!tensor.FromProto(weights_tensor)) + return tensorflow::errors::Internal("cannot parse weight tensor proto: " + + node_def.name()); + + TRT_ShapedWeights weights(dtype); + if (!weights_tensor.float_val().empty()) { + LOG(DEBUG) << "SCALAR!!!" 
<< node_def.name(); + nvinfer1::Dims scalar_shape; + if (tensor.dims() > 0) { + LOG(DEBUG) << "dimensions: " << tensor.dims(); + weights = TRT_ShapedWeights(dtype, weights_tensor.float_val().data(), + get_tensor_shape(tensor)); + } else { + LOG(DEBUG) << "dimensions: " << tensor.dims(); + scalar_shape.nbDims = 1; + scalar_shape.d[0] = 1; + scalar_shape.type[0] = nvinfer1::DimensionType::kSPATIAL; + for (int i = 1; i < nvinfer1::Dims::MAX_DIMS; i++) { + scalar_shape.d[i] = 0; + scalar_shape.type[i] = nvinfer1::DimensionType::kSPATIAL; + } + weights = TRT_ShapedWeights(dtype, weights_tensor.float_val().data(), + scalar_shape); + } + // LOG(INFO) << " add: " << weights_tensor.float_val().data(); + // LOG(INFO) << " value: " << (*weights_tensor.float_val().data()); + + // weights = ctx.get_temp_weights(dtype, scalar_shape); + // std::memcpy(const_cast(weights.values), + // weights_tensor.float_val().data(), weights.size_bytes()); + } else if (!weights_tensor.tensor_content().empty()) { + LOG(DEBUG) << "TENSOR!!!" << node_def.name(); + weights = TRT_ShapedWeights(dtype, weights_tensor.tensor_content().data(), + get_tensor_shape(tensor)); + } else { + return tensorflow::errors::Unimplemented( + "not supported constant type, at " + node_def.name()); + } + // pass the output + outputs->push_back(TRT_TensorOrWeights(weights)); + return tensorflow::Status::OK(); +} + +tensorflow::Status ConvertIdentity( + Converter& ctx, tensorflow::NodeDef const& node_def, + std::vector const& inputs, + std::vector* outputs) { + outputs->push_back(inputs.at(0)); + return tensorflow::Status::OK(); +} + +tensorflow::Status ConvertBinary(Converter& ctx, + tensorflow::NodeDef const& node_def, + std::vector const& inputs, + std::vector* outputs) { + if (inputs.size() != 2) + return tensorflow::errors::FailedPrecondition( + "Binary ops require two tensor input, at " + node_def.name()); + + if (inputs.at(0).is_weights() && inputs.at(1).is_weights()) + return ConstantFoldBinary(ctx, node_def, inputs, outputs); + + if (inputs.at(0).is_tensor() && inputs.at(1).is_weights()) + return BinaryTensorOpWeight(ctx, node_def, inputs.at(0).tensor(), + inputs.at(1).weights(), outputs); + + if (inputs.at(0).is_weights() && inputs.at(1).is_tensor()) + return BinaryTensorOpWeight(ctx, node_def, inputs.at(1).tensor(), + inputs.at(0).weights(), outputs); + + if (inputs.at(0).is_tensor() && inputs.at(1).is_tensor()) + return BinaryTensorOpTensor(ctx, node_def, inputs.at(0).tensor(), + inputs.at(1).tensor(), outputs); + + return tensorflow::errors::Unknown("Binary op input error, at " + + node_def.name()); +} + +tensorflow::Status ConvertUnary(Converter& ctx, + tensorflow::NodeDef const& node_def, + std::vector const& inputs, + std::vector* outputs) { + if (inputs.size() != 1) + return tensorflow::errors::FailedPrecondition( + "Unary ops require single tensor input, at " + node_def.name()); + + if (inputs.at(0).is_weights()) + return ConstantFoldUnary(ctx, node_def, inputs, outputs); + else if (inputs.at(0).is_tensor()) + return tensorflow::errors::Unimplemented( + "Unary op for tensor not supported, at " + node_def.name()); + + return tensorflow::errors::Unknown("Binary op input error, at " + + node_def.name()); +} + +tensorflow::Status ConvertReduce(Converter& ctx, + tensorflow::NodeDef const& node_def, + std::vector const& inputs, + std::vector* outputs) { + if (inputs.size() != 2 || !inputs.at(0).is_tensor() || + !inputs.at(1).is_weights()) + return tensorflow::errors::InvalidArgument( + "Input expects tensor and weights, at" + 
node_def.name()); + + // implement tensor binaryOp weight [channel wise] for now; + nvinfer1::ITensor const* tensor = inputs.at(0).tensor(); + auto dims = tensor->getDimensions(); + // restore implicit batch dimension + int nbDims = dims.nbDims + 1; + + TRT_ShapedWeights index_list = inputs.at(1).weights(); + + TFAttrs attrs(node_def); + // TODO(jie): handle data type + // auto data_type = attrs.get("T"); + // index type here is done through TF type + // so I can leverage their EnumToDataType for my cast + auto index_type = attrs.get("Tidx"); + // auto keep_dims_flag = attrs.get("keep_dims"); + + // Only expect to handle INT32 as attributes for now + if (index_type != tensorflow::DataType::DT_INT32) + return tensorflow::errors::Unimplemented("Tidx supports only DT_INT32"); + // auto pad_data = const_cast::Type*> + // (pads.values); + auto index_list_data = + static_cast(const_cast(index_list.values_)); + // auto index_list_data = + // const_cast::Type*> + // (index_list.values); + + // hack warning: + // have to fall back to pool layer since reduce is not in public TRT yet. + if (nbDims != 4) + return tensorflow::errors::InvalidArgument( + "TRT only support reduce on 4 dimensional tensors, at" + + node_def.name()); + if (index_list.count() > 2) + return tensorflow::errors::InvalidArgument( + "TRT cannot support reduce on more than 2 dimensions, at" + + node_def.name()); + + std::set idx_set; + // we cannot operate on Channel. permutation flag used to transpose tensor + int permuted_index = -1; + for (int i = 0; i < index_list.count(); i++) { + if (index_list_data[i] == 0) + return tensorflow::errors::InvalidArgument("TRT cannot reduce at 0, at" + + node_def.name()); + if (index_list_data[i] == 1) permuted_index = 1; + idx_set.emplace(index_list_data[i]); + } + + std::vector permutation_order(nbDims); + nvinfer1::DimsHW pool_kernel; + if (permuted_index == 1) { + for (int i = 2; i < nbDims; i++) { + if (idx_set.count(i)) { + permuted_index = i; + break; + } + } + for (int i = 0; i < nbDims; i++) permutation_order[i] = i; + + permutation_order[permuted_index] = 1; + permutation_order[1] = permuted_index; + + // apply permutation before extracting dimension for pool_kernel + tensor = ctx.transposeTensor(const_cast(tensor), + permutation_order); + } + + // apply permutation before extracting dimension for pool_kernel + pool_kernel.d[0] = (idx_set.count(2) || permuted_index == 2) ? dims.d[1] : 1; + pool_kernel.d[1] = (idx_set.count(3) || permuted_index == 3) ? 
dims.d[2] : 1; + + nvinfer1::ITensor* output_tensor; + + if (node_def.op() == "Mean") { + nvinfer1::IPoolingLayer* layer = + ctx.network()->addPooling(*const_cast(tensor), + nvinfer1::PoolingType::kAVERAGE, pool_kernel); + output_tensor = layer->getOutput(0); + } else { + return tensorflow::errors::Unimplemented( + "Op not supported " + node_def.op() + " , at " + node_def.name()); + } + if (permuted_index != -1) { + // apply permutation before extracting dimension for pool_kernel + output_tensor = ctx.transposeTensor( + const_cast(output_tensor), permutation_order); + } + return tensorflow::Status::OK(); +} + +tensorflow::Status ConvertPad(Converter& ctx, + tensorflow::NodeDef const& node_def, + std::vector const& inputs, + std::vector* outputs) { + if (inputs.size() != 2 || !inputs.at(0).is_tensor() || + !inputs.at(1).is_weights()) + return tensorflow::errors::InvalidArgument( + "Input expects tensor and weights, at" + node_def.name()); + + // implement tensor binaryOp weight [channel wise] for now; + nvinfer1::ITensor const* tensor = inputs.at(0).tensor(); + auto dims = tensor->getDimensions(); + // restore implicit batch dimension + int nbDims = dims.nbDims + 1; + + TRT_ShapedWeights pads = inputs.at(1).weights(); + + TFAttrs attrs(node_def); + // padding type here is done through TF type + // so I can leverage their EnumToDataType for my cast + auto padding_type = attrs.get("Tpaddings"); + // TODO(jie): handle data type conversion for TRT? + // auto data_type = attrs.get("T"); + + if (pads.shape_.d[0] != nbDims || pads.shape_.d[1] != 2) + return tensorflow::errors::InvalidArgument( + "Pad only supports explicit padding on 4 dimensional tensor, at " + + node_def.name()); + + // Only expect to handle INT32 as attributes for now + if (padding_type != tensorflow::DataType::DT_INT32) + return tensorflow::errors::Unimplemented( + "Tpaddings supports only DT_INT32"); + // auto pad_data = const_cast::Type*> + // (pads.values); + auto pad_data = static_cast(const_cast(pads.values_)); + + std::vector pad_index; + for (int i = 0; i < nbDims; i++) { + if (pad_data[2 * i] != 0 || pad_data[2 * i + 1] != 0) + pad_index.push_back(i); + } + + // no padding at all, we should exit + if (pad_index.size() == 0) { + outputs->push_back(inputs.at(0)); + return tensorflow::Status::OK(); + } + + // only supports padding on less than 2 axis GIE-2579 + if (pad_index.size() > 2) + return tensorflow::errors::InvalidArgument( + "Padding layer does not support padding on > 2"); + + // padding on batch dimension is not supported + if (pad_index[0] == 0) + return tensorflow::errors::InvalidArgument( + "Padding layer does not support padding on batch dimension"); + + // not doing the legit thing here. 
ignoring padding on dim 1 and 3; + // TODO(jie): implement pad as uff parser + if (pad_index.size() == 2 && pad_index[0] == 0 && pad_index[1] == 3) + return tensorflow::errors::Unimplemented( + "Padding layer does not support padding on dimension 1 and 3 yet"); + + bool legit_pad = true; + nvinfer1::DimsHW pre_padding(0, 0); + nvinfer1::DimsHW post_padding(0, 0); + + std::vector permuted_pad_index(pad_index); + if (pad_index[0] == 1) { + legit_pad = false; + tensor = ctx.transposeTensor(const_cast(tensor), + {0, 3, 2, 1}); + permuted_pad_index[0] = 3; + } + + for (size_t i = 0; i < pad_index.size(); i++) { + int index = pad_index[i]; + if (permuted_pad_index[i] == 2) { + pre_padding.h() = pad_data[index * 2]; + post_padding.h() = pad_data[index * 2 + 1]; + } else if (permuted_pad_index[i] == 3) { + pre_padding.w() = pad_data[index * 2]; + post_padding.w() = pad_data[index * 2 + 1]; + } + } + + nvinfer1::IPaddingLayer* layer = ctx.network()->addPadding( + *const_cast(tensor), pre_padding, post_padding); + nvinfer1::ITensor* output_tensor = layer->getOutput(0); + + if (!legit_pad) + output_tensor = ctx.transposeTensor( + const_cast(output_tensor), {0, 3, 2, 1}); + + outputs->push_back(TRT_TensorOrWeights(output_tensor)); + return tensorflow::Status::OK(); +} + +void Converter::register_op_converters() { + // vgg_16 slim implementation + _op_registry["Placeholder"] = ConvertPlaceholder; + _op_registry["Conv2D"] = ConvertConv2D; + _op_registry["Relu"] = ConvertActivation; + _op_registry["MaxPool"] = ConvertPool; + // This could be really handled as ConvertBinary + _op_registry["BiasAdd"] = ConvertScale; + _op_registry["Const"] = ConvertConst; + // _op_registry["MatMul"] = ConvertFullyConnected; // not used in vgg + // TODO(ben,jie): this is a temp hack. + _op_registry["Identity"] = ConvertIdentity; // Identity should be removed + // _op_registry["AvgPool"] = ConvertPool; + + // resnet_50_v1 slim implementation + _op_registry["Add"] = ConvertBinary; + _op_registry["Mul"] = ConvertBinary; + _op_registry["Sub"] = ConvertBinary; + _op_registry["Rsqrt"] = ConvertUnary; + _op_registry["Mean"] = ConvertReduce; + _op_registry["Pad"] = ConvertPad; + // TODO(ben,jie): Add more ops +} + +} // namespace + +tensorflow::Status ConvertSubGraphToTensorRTNodeDef( + const tensorflow::Graph& graph, const std::set& subgraph_node_ids, + const std::vector>& input_inds, + const std::vector>& output_inds, size_t max_batch_size, + size_t max_workspace_size, const ShapeMap& shape_map, + tensorflow::NodeDef* trt_node) { + // Visit nodes in reverse topological order and construct the TRT network. 
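+  // The conversion proceeds in stages:
+  //   1. topologically sort the subgraph nodes,
+  //   2. register every subgraph input as a TRT network input,
+  //   3. run the per-op converters to populate the network,
+  //   4. mark the subgraph outputs, build and serialize the engine,
+  //   5. wrap the serialized plan in a single TRTEngineOp NodeDef.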
+ + // Toposort + std::vector order_vec; + tensorflow::GetPostOrder(graph, &order_vec); + // Select just the subgraph + std::list order; + for (tensorflow::Node* node : order_vec) { + if (subgraph_node_ids.count(node->id())) { + // order.push_back(node); + order.push_front(node); // we want topological order to contstruct the + // network layer by layer + } + } + // topological order is needed to build TRT network + LOG(DEBUG) << "BUILDING 1"; + + // nvinfer1::ILogger::Severity verbosity = + // nvinfer1::ILogger::Severity::kWARNING; + tensorflow::tensorrt::Logger trt_logger; + // TRT_Logger trt_logger(verbosity); + + LOG(DEBUG) << "BUILDING 2"; + + auto trt_builder = infer_object(nvinfer1::createInferBuilder(trt_logger)); + if (!trt_builder) { + return tensorflow::errors::Internal( + "failed to create TensorRT builder object"); + } + + LOG(DEBUG) << "BUILDING 3"; + + auto trt_network = infer_object(trt_builder->createNetwork()); + if (!trt_network) { + return tensorflow::errors::Internal( + "failed to create TensorRT network object"); + } + + LOG(DEBUG) << "BUILDING 4"; + + // Build the network + Converter converter(trt_network.get()); + + LOG(DEBUG) << "BUILDING 5"; + std::vector input_names; + std::vector input_dtypes; + for (std::pair const& input : input_inds) { + LOG(DEBUG) << "parsing input!!!!!"; + int node_id = input.first; + int output_idx = input.second; + tensorflow::Node* node = graph.FindNodeId(node_id); + auto node_name = node->name(); + input_names.push_back(node_name); // insert original node name without port + // TODO(jie): alternative :) + // tensorflow::DataType tf_dtype = node->output_type(output_idx); + if (shape_map.count(node_name) == 0) + return tensorflow::errors::Internal("failed to find input node: " + + node_name); + + auto input_entry_vec = shape_map.at(node_name); + if (static_cast(input_entry_vec.size()) < output_idx) + return tensorflow::errors::Internal( + "accessing output index of: " + std::to_string(output_idx) + + ", at node: " + node_name + "with output entry from shape_map: " + + std::to_string(input_entry_vec.size())); + + auto input_entry = input_entry_vec.at(output_idx); + + tensorflow::DataType tf_dtype = input_entry.second; + input_dtypes.push_back(tf_dtype); + + nvinfer1::DataType dtype(nvinfer1::DataType::kFLOAT); + TF_CHECK_OK(convert_dtype(tf_dtype, &dtype)); + + LOG(DEBUG) << "accessing output index of: " << std::to_string(output_idx) + << ", at node: " << node_name + << "with output entry from shape_map: " + << std::to_string(input_entry_vec.size()); + // TODO(ben,jie): update TRT input format/dimension + nvinfer1::DimsCHW input_dim_psuedo_chw; + for (int i = 0; i < 3; i++) input_dim_psuedo_chw.d[i] = 1; + + for (int i = 1; i < input_entry.first.dims(); i++) { + LOG(DEBUG) << "dimension: " << i + << " , size: " << input_entry.first.dim_size(i); + input_dim_psuedo_chw.d[i - 1] = input_entry.first.dim_size(i); + } + + // TODO(ben,jie): proper way to restore input tensor name? 
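+    // TensorFlow names output tensors "<op name>:<output index>", with ":0"
+    // usually elided; mirror that convention for the TRT input binding name.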
+ auto input_tensor_name = node_name; + if (output_idx != 0) + input_tensor_name = node_name + ":" + std::to_string(output_idx); + + nvinfer1::ITensor* input_tensor = converter.network()->addInput( + input_tensor_name.c_str(), dtype, input_dim_psuedo_chw); + + if (!input_tensor) + return tensorflow::errors::InvalidArgument( + "Failed to create Input layer"); + LOG(DEBUG) << "input tensor name :" << input_tensor_name; + + if (!converter.insert_input_tensor(input_tensor_name, input_tensor)) + return tensorflow::errors::AlreadyExists( + "output tensor already exists for op: " + input_tensor_name); + } + + LOG(DEBUG) << "finished sorting"; + + for (const tensorflow::Node* node : order) { + tensorflow::NodeDef const& node_def = node->def(); + LOG(DEBUG) << "converting node: " << node_def.name() << " , " + << node_def.op(); + TF_RETURN_IF_ERROR(converter.convert_node(node_def)); + } + + LOG(DEBUG) << "finished conversion"; + + // Gather output metadata + std::vector output_names; + std::vector output_dtypes; + for (std::pair const& output : output_inds) { + int node_id = output.first; + int output_idx = output.second; + tensorflow::Node* node = graph.FindNodeId(node_id); + std::string op_name = node->name(); + std::string tensor_name = op_name; + if (output_idx != 0) + tensor_name = tensor_name + ":" + std::to_string(output_idx); + LOG(DEBUG) << "output tensor name: " << tensor_name; + output_names.push_back(tensor_name); + auto tensor_or_weights = converter.get_tensor(tensor_name); + if (!tensor_or_weights.is_tensor()) { + return tensorflow::errors::InvalidArgument( + "Output node is weights not tensor"); + } + nvinfer1::ITensor* tensor = tensor_or_weights.tensor(); + if (!tensor) { + return tensorflow::errors::NotFound("Output tensor not found: " + + tensor_name); + } + converter.network()->markOutput(*tensor); + tensorflow::DataType tf_dtype = node->output_type(output_idx); + output_dtypes.push_back(tf_dtype); + nvinfer1::DataType trt_dtype = nvinfer1::DataType::kFLOAT; + TF_RETURN_IF_ERROR(convert_dtype(tf_dtype, &trt_dtype)); + tensor->setType(trt_dtype); + } + + LOG(DEBUG) << "finished output"; + + // Build the engine + trt_builder->setMaxBatchSize(max_batch_size); + trt_builder->setMaxWorkspaceSize(max_workspace_size); + LOG(INFO) << "starting build engine"; + // TODO(ben,jie): half2 and int8 mode support + std::string engine_plan_string; + { + auto trt_engine = + infer_object(trt_builder->buildCudaEngine(*converter.network())); + LOG(INFO) << "built network"; + auto engine_plan = infer_object(trt_engine->serialize()); + LOG(INFO) << "serialized engine"; + const char* engine_plan_data = + static_cast(engine_plan->data()); + engine_plan_string = std::move( + std::string(engine_plan_data, engine_plan_data + engine_plan->size())); + } + // std::ofstream engine_out("mini.engine"); + // engine_out << engine_plan_string; + // engine_out.close(); + + LOG(INFO) << "finished engine"; + + // Build the TRT op + // TODO(sami,ben,jie): proper naming! 
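+  // For now the generated op is named "my_trt_op<N>" using a process-wide
+  // counter; this is a placeholder until the naming TODO above is resolved.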
+ static int static_id = 0; + tensorflow::NodeDefBuilder op_builder( + "my_trt_op" + std::to_string(static_id++), "TRTEngineOp"); + std::vector income_edges; + for (size_t i = 0; i < input_names.size(); ++i) { + int output_idx = input_inds.at(i).second; + // we wired up the input here already, it is redundant to do it again in + // ConvertSubGraphToTensorRT(convert_graph.cc) + auto incoming_edge = tensorflow::NodeDefBuilder::NodeOut(input_names.at(i), + output_idx, input_dtypes.at(i)); + income_edges.push_back(incoming_edge); + } + tensorflow::gtl::ArraySlice + input_list(income_edges); + op_builder.Input(input_list); + + LOG(INFO) << "finished op preparation"; + + auto status = op_builder.Attr("serialized_engine", engine_plan_string) + .Attr("input_nodes", input_names) + .Attr("output_nodes", output_names) + .Attr("OutT", output_dtypes) + .Finalize(trt_node); + + LOG(INFO) << status.ToString(); + LOG(INFO) << "finished op building"; + + return tensorflow::Status::OK(); +} + +} // namespace convert +} // namespace tensorrt diff --git a/tensorflow/contrib/tensorrt/convert/convert_nodes.h b/tensorflow/contrib/tensorrt/convert/convert_nodes.h new file mode 100644 index 00000000000..a624582deca --- /dev/null +++ b/tensorflow/contrib/tensorrt/convert/convert_nodes.h @@ -0,0 +1,42 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CONTRIB_TENSORRT_CONVERT_CONVERT_NODES_H_ +#define TENSORFLOW_CONTRIB_TENSORRT_CONVERT_CONVERT_NODES_H_ + +#include +#include +#include + +#include "tensorflow/contrib/tensorrt/convert/inferShapes.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/lib/core/status.h" + +namespace tensorrt { +namespace convert { + +tensorflow::Status ConvertSubGraphToTensorRTNodeDef( + const tensorflow::Graph& graph, const std::set& subgraph_node_ids, + const std::vector>& + input_inds, // {node_id, output_idx} + const std::vector>& + output_inds, // {node_id, output_idx} + size_t max_batch_size, size_t max_workspace_size, const ShapeMap& shape_map, + tensorflow::NodeDef* trt_node); +} // namespace convert +} // namespace tensorrt + +#endif // TENSORFLOW_CONTRIB_TENSORRT_CONVERT_CONVERT_NODES_H_ diff --git a/tensorflow/contrib/tensorrt/convert/inferShapes.cc b/tensorflow/contrib/tensorrt/convert/inferShapes.cc new file mode 100644 index 00000000000..c7f0f0023d3 --- /dev/null +++ b/tensorflow/contrib/tensorrt/convert/inferShapes.cc @@ -0,0 +1,125 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/tensorrt/convert/inferShapes.h" +#include +#include "tensorflow/core/common_runtime/shape_refiner.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/shape_inference.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/types.pb_text.h" +#include "tensorflow/core/graph/algorithm.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/graph/graph_constructor.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/logging.h" + +#define _TF_LOG_DEBUG ::tensorflow::internal::LogMessage(__FILE__, __LINE__, -1) + +namespace tensorflow { +namespace trt { +std::vector getTypes(const tensorflow::OpDef& op, + const tensorflow::NodeDef& nd, + bool inp = true) { + const auto& attrMap = nd.attr(); + auto getType = [&attrMap](decltype( + op.input_arg(0)) a) -> std::vector { + std::vector tvec; + if (!a.type_list_attr().empty()) { // get the list types + const auto& tl = attrMap.at(a.type_list_attr()).list(); + int tsize = tl.type_size(); + tvec.reserve(tsize); + for (int t = 0; t < tsize; t++) { + tvec.push_back(tl.type(t)); + } + return tvec; + } + tensorflow::DataType cType = tensorflow::DT_INVALID; + if (a.type() != tensorflow::DT_INVALID) { // get defined types + cType = a.type(); + } else if (!a.type_attr().empty()) { + cType = attrMap.at(a.type_attr()).type(); + } + if (!a.number_attr().empty()) { // numbertypes + int64 nTensors = attrMap.at(a.number_attr()).i(); + tvec = std::vector(nTensors, cType); + return tvec; + } + tvec.push_back(cType); + return tvec; + }; + std::vector types; + if (inp) { + int n_inputs = op.input_arg_size(); + for (int i = 0; i < n_inputs; i++) { + auto tout = getType(op.input_arg(i)); + LOG(DEBUG) << "Node= " << nd.name() << " #inputs" << tout.size(); + types.insert(types.end(), tout.begin(), tout.end()); + } + } else { + int n_outputs = op.output_arg_size(); + // types.resize(n_outputs); + for (int i = 0; i < n_outputs; i++) { + auto tout = getType(op.output_arg(i)); + LOG(DEBUG) << "Node= " << nd.name() << " #outputs" << tout.size(); + types.insert(types.end(), tout.begin(), tout.end()); + } + } + return types; +} + +tensorflow::Status inferShapes(const tensorflow::GraphDef& graph_def, + const std::vector& output_names, + ShapeMap& shapes) { + tensorflow::Graph g(OpRegistry::Global()); + TF_RETURN_IF_ERROR(tensorflow::ConvertGraphDefToGraph( + tensorflow::GraphConstructorOptions(), graph_def, &g)); + std::vector POnodes; + tensorflow::GetPostOrder(g, &POnodes); + tensorflow::ShapeRefiner refiner(graph_def.versions().producer(), + OpRegistry::Global()); + for (auto n = POnodes.rbegin(); n != POnodes.rend(); ++n) { + TF_CHECK_OK(refiner.AddNode(*n)); + } + + auto shape2PTS = [](tensorflow::shape_inference::InferenceContext* ic, + const tensorflow::shape_inference::ShapeHandle& sh) + -> tensorflow::PartialTensorShape { + std::vector dims; + int64 rank = ic->Rank(sh); + for (int64 i = 0; i < rank; i++) { + auto 
dh = ic->Dim(sh, i); + dims.push_back(ic->Value(dh)); + } + return tensorflow::PartialTensorShape(dims); + }; + for (const auto& n : POnodes) { + auto ic = refiner.GetContext(n); + if (ic) { + int nOuts = ic->num_outputs(); + auto types = getTypes(n->op_def(), n->def(), false); + std::vector< + std::pair> + SAT; + for (int i = 0; i < nOuts; i++) { + auto PTS = shape2PTS(ic, ic->output(i)); + SAT.push_back({PTS, types.at(i)}); + } + shapes[n->name()] = SAT; + } else { + LOG(WARNING) << "Node " << n->name() << " doesn't have InferenceContext!"; + } + } + return tensorflow::Status::OK(); +} +} // namespace trt +} // namespace tensorflow diff --git a/tensorflow/contrib/tensorrt/convert/inferShapes.h b/tensorflow/contrib/tensorrt/convert/inferShapes.h new file mode 100644 index 00000000000..b94f1ee893e --- /dev/null +++ b/tensorflow/contrib/tensorrt/convert/inferShapes.h @@ -0,0 +1,39 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CONTRIB_TENSORRT_CONVERT_INFERSHAPES_H_ +#define TENSORFLOW_CONTRIB_TENSORRT_CONVERT_INFERSHAPES_H_ + +#include +#include +#include +#include + +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/lib/core/status.h" + +typedef std::unordered_map>> + ShapeMap; +namespace tensorflow { +namespace trt { +tensorflow::Status inferShapes(const tensorflow::GraphDef& graph_def, + const std::vector& output_names, + ShapeMap& shapes); +} +} // namespace tensorflow + +#endif // TENSORFLOW_CONTRIB_TENSORRT_CONVERT_INFERSHAPES_H_ diff --git a/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc b/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc new file mode 100644 index 00000000000..a1524a592a2 --- /dev/null +++ b/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc @@ -0,0 +1,183 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include "tensorflow/contrib/tensorrt/kernels/trt_engine_op.h" +#include +#include +#include "tensorflow/contrib/tensorrt/log/trt_logger.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/stream_executor.h" +// Use TF logging f + + +namespace tensorflow { +static ::tensorflow::tensorrt::Logger gLogger; + +using namespace nvinfer1; + +namespace tensorrt { + +TRTEngineOp::TRTEngineOp(OpKernelConstruction* context) : OpKernel(context) { + // char *gieModelStream{nullptr}; + // size_t size{0}; + + // read serialized_engine + std::string serialized_engine; + OP_REQUIRES_OK(context, + context->GetAttr("serialized_engine", &serialized_engine)); + + // register input output node name in trt_sub_graph + OP_REQUIRES_OK(context, context->GetAttr("input_nodes", &input_nodes_)); + OP_REQUIRES_OK(context, context->GetAttr("output_nodes", &output_nodes_)); + + // TODO(samikama) runtime should be taken from a resourcemanager as well. + // Only engine should be in the op and context and runtime should be taken + // from resourcemanager + IRuntime* infer = createInferRuntime(gLogger); + trt_engine_ptr_.reset(infer->deserializeCudaEngine( + serialized_engine.c_str(), serialized_engine.size(), nullptr)); + + trt_context_ptr_.reset(trt_engine_ptr_->createExecutionContext()); + // runtime is safe to delete after engine creation + infer->destroy(); + std::stringstream oss; + // debug iterate through all binding instances + for (int i = 0; i < trt_engine_ptr_->getNbBindings(); i++) { + LOG(INFO) << "index: " << i + << ", binding name: " << trt_engine_ptr_->getBindingName(i); + + if (trt_engine_ptr_->bindingIsInput(i)) { + LOG(INFO) << "INPUT"; + } else { + LOG(INFO) << "OUTPUT"; + } + oss << "Dimension: "; + auto dims = trt_engine_ptr_->getBindingDimensions(i); + oss << " nbDims: " << dims.nbDims << " -> "; + for (int j = 0; j < Dims::MAX_DIMS; j++) { + oss << dims.d[j] << ", "; + } + LOG(INFO) << oss.str(); + oss.str(""); + switch (trt_engine_ptr_->getBindingDataType(i)) { + case nvinfer1::DataType::kFLOAT: + LOG(INFO) << "data type float" << std::endl; + break; + case nvinfer1::DataType::kHALF: + LOG(INFO) << "data type half" << std::endl; + break; + case nvinfer1::DataType::kINT8: + LOG(INFO) << "data type int8" << std::endl; + break; + } + } + + // CHECK_NE(cudaStreamCreate(&stream_),0); // logic here is wrong + // cudaStreamCreate(&stream_); +} + +void TRTEngineOp::Compute(OpKernelContext* context) { + int nbBindings = context->num_inputs() + context->num_outputs(); + // TODO(jjsjann123) multiple input/output + std::vector buffers(nbBindings); + + size_t bindingIndex; + int nbBatch = 0; + bool valid = true; + for (int i = 0; i < context->num_inputs(); i++) { + // Grab the input tensor + bindingIndex = trt_engine_ptr_->getBindingIndex(input_nodes_[i].c_str()); + + const Tensor& input_tensor = context->input(i); + const TensorShape& input_shape = input_tensor.shape(); + if (i == 0) { + nbBatch = input_shape.dim_size(0); + } else if (nbBatch != input_shape.dim_size(0)) { + valid = false; + break; + } + // int64 input_shape.dim_size(int d) + // int input_shape.dims() + switch (trt_engine_ptr_->getBindingDataType(bindingIndex)) { + case nvinfer1::DataType::kFLOAT: + LOG(INFO) << "float"; + buffers[bindingIndex] = (void*)(input_tensor.flat().data()); + break; + case nvinfer1::DataType::kHALF: + LOG(INFO) << "half"; + // buffers[bindingIndex] = (void*)input_tensor.flat().data(); + break; + case 
nvinfer1::DataType::kINT8: + LOG(INFO) << "int8"; + // buffers[bindingIndex] = (void*)input_tensor.flat().data(); + break; + } + } + + if (!valid) LOG(WARNING) << "input data inconsistent batch size"; + + for (int i = 0; i < static_cast(output_nodes_.size()); i++) { + // This is bad that we have to reallocate output buffer every run. + // Create an output tensor + bindingIndex = trt_engine_ptr_->getBindingIndex(output_nodes_[i].c_str()); + Tensor* output_tensor = NULL; + + TensorShape output_shape; + if (bindingIndex != -1) { + LOG(INFO) << "got binding " << bindingIndex; + auto dims = trt_engine_ptr_->getBindingDimensions(bindingIndex); + std::vector trt_shape(dims.nbDims + 1); + trt_shape[0] = nbBatch; + for (int j = 0; j < dims.nbDims; j++) trt_shape[j + 1] = dims.d[j]; + TensorShapeUtils::MakeShape(trt_shape.data(), trt_shape.size(), + &output_shape); + } else { + LOG(INFO) << "no binding "; + break; + } + + OP_REQUIRES_OK(context, + context->allocate_output(i, output_shape, &output_tensor)); + // buffers[bindingIndex] = (void*)output_tensor->flat(); + // buffers[bindingIndex] = output_tensor->flat().data(); + switch (trt_engine_ptr_->getBindingDataType(bindingIndex)) { + case nvinfer1::DataType::kFLOAT: + LOG(INFO) << "float"; + buffers[bindingIndex] = + reinterpret_cast(output_tensor->flat().data()); + break; + case nvinfer1::DataType::kHALF: + LOG(INFO) << "half"; + // buffers[bindingIndex] = (void*)output_tensor->flat().data(); + break; + case nvinfer1::DataType::kINT8: + LOG(INFO) << "int8"; + // buffers[bindingIndex] = (void*)output_tensor->flat().data(); + break; + } + } + // copied from cuda_kernel_helper since it seems only valid in *.cu.cc files + const cudaStream_t* stream = CHECK_NOTNULL( + reinterpret_cast(context->op_device_context() + ->stream() + ->implementation() + ->CudaStreamMemberHack())); + + trt_context_ptr_->enqueue(nbBatch, &buffers[0], *stream, nullptr); + cudaStreamSynchronize(*stream); +} + +REGISTER_KERNEL_BUILDER(Name("TRTEngineOp").Device(DEVICE_GPU), TRTEngineOp); +} // namespace tensorrt +} // namespace tensorflow diff --git a/tensorflow/contrib/tensorrt/kernels/trt_engine_op.h b/tensorflow/contrib/tensorrt/kernels/trt_engine_op.h new file mode 100644 index 00000000000..631fc114f22 --- /dev/null +++ b/tensorflow/contrib/tensorrt/kernels/trt_engine_op.h @@ -0,0 +1,55 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CONTRIB_TENSORRT_KERNELS_TRT_ENGINE_OP_H_ +#define TENSORFLOW_CONTRIB_TENSORRT_KERNELS_TRT_ENGINE_OP_H_ + +#include +#include +#include +#include +#include +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/op_kernel.h" + +namespace tensorflow { + +namespace tensorrt { +class Logger; +class TRTEngineOp : public OpKernel { + public: + explicit TRTEngineOp(OpKernelConstruction* context); + + void Compute(OpKernelContext* context) override; + + private: + template + struct Destroyer { + void operator()(T* d) { d->destroy(); } + }; + template + using destroyed_ptr = std::unique_ptr>; + destroyed_ptr trt_engine_ptr_; + // TODO(samikama) context should go to a resource manager! + destroyed_ptr trt_context_ptr_; + std::vector input_nodes_; + std::vector output_nodes_; +}; + +} // namespace tensorrt + +} // namespace tensorflow + +#endif // TENSORFLOW_CONTRIB_TENSORRT_KERNELS_TRT_ENGINE_OP_H_ diff --git a/tensorflow/contrib/tensorrt/log/trt_logger.cc b/tensorflow/contrib/tensorrt/log/trt_logger.cc new file mode 100644 index 00000000000..545a4aac50d --- /dev/null +++ b/tensorflow/contrib/tensorrt/log/trt_logger.cc @@ -0,0 +1,56 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/tensorrt/log/trt_logger.h" +// Use TF logging for TensorRT informations +#include "tensorflow/core/platform/logging.h" + +#define _TF_LOG_DEBUG ::tensorflow::internal::LogMessage(__FILE__, __LINE__, -1) +//------------------------------------------------------------------------------ +namespace tensorflow { + +//------------------------------------------------------------------------------ +namespace tensorrt { + +void Logger::log(Severity severity, const char* msg) { + // suppress info-level messages + switch (severity) { + case Severity::kINFO: { // mark TRT info messages as debug! + LOG(DEBUG) << msg; + break; + } + case Severity::kWARNING: { + LOG(WARNING) << msg; + break; + } + case Severity::kERROR: { + LOG(ERROR) << msg; + break; + } + case Severity::kINTERNAL_ERROR: { + LOG(FATAL) << msg; + break; + } + // This is useless for now. But would catch it in future if enum changes. It + // is always good to have default case! + default: { + LOG(FATAL) << name_ << "Got unknown severity level from TRT " << msg; + break; + } + } +} + +} // namespace tensorrt + +} // namespace tensorflow diff --git a/tensorflow/contrib/tensorrt/log/trt_logger.h b/tensorflow/contrib/tensorrt/log/trt_logger.h new file mode 100644 index 00000000000..10a78b7a1da --- /dev/null +++ b/tensorflow/contrib/tensorrt/log/trt_logger.h @@ -0,0 +1,41 @@ +// -*- c++ -*- +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CONTRIB_TENSORRT_LOG_TRT_LOGGER_H_ +#define TENSORFLOW_CONTRIB_TENSORRT_LOG_TRT_LOGGER_H_ + +// Use TF logging f +#include +#include + +//------------------------------------------------------------------------------ +namespace tensorflow { + +//------------------------------------------------------------------------------ +namespace tensorrt { + +// Logger for GIE info/warning/errors +class Logger : public nvinfer1::ILogger { + void log(nvinfer1::ILogger::Severity severity, const char* msg) override; + + private: + std::string name_; +}; + +} // namespace tensorrt + +} // namespace tensorflow +#endif // TENSORFLOW_CONTRIB_TENSORRT_LOG_TRT_LOGGER_H_ diff --git a/tensorflow/contrib/tensorrt/ops/trt_engine_op.cc b/tensorflow/contrib/tensorrt/ops/trt_engine_op.cc new file mode 100644 index 00000000000..38d37071900 --- /dev/null +++ b/tensorflow/contrib/tensorrt/ops/trt_engine_op.cc @@ -0,0 +1,37 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/shape_inference.h" +#include "tensorflow/core/framework/tensor_shape.h" + +namespace tensorflow { + +namespace shape_inference { +extern Status TRTEngineOpShapeInference(InferenceContext* c); +} + +REGISTER_OP("TRTEngineOp") + .Attr("serialized_engine: string") + .Attr("input_nodes: list(string)") + .Attr("output_nodes: list(string)") + .Attr("InT: list({int8, float16, float32})") + .Attr("OutT: list({int8, float16, float32})") + .Input("in_tensor: InT") + .Output("out_tensor: OutT") + .SetShapeFn(shape_inference::TRTEngineOpShapeInference); + +} // namespace tensorflow diff --git a/tensorflow/contrib/tensorrt/python/__init__.py b/tensorflow/contrib/tensorrt/python/__init__.py new file mode 100644 index 00000000000..4aeea485151 --- /dev/null +++ b/tensorflow/contrib/tensorrt/python/__init__.py @@ -0,0 +1,8 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# pylint: disable=unused-import,wildcard-import +from tensorflow.contrib.tensorrt.python.ops import trt_engine_op +from tensorflow.contrib.tensorrt.python.trt_convert import CreateInferenceGraph +# pylint: enable=unused-import,wildcard-import diff --git a/tensorflow/contrib/tensorrt/python/ops/trt_engine_op.py b/tensorflow/contrib/tensorrt/python/ops/trt_engine_op.py new file mode 100644 index 00000000000..ce78d328de3 --- /dev/null +++ b/tensorflow/contrib/tensorrt/python/ops/trt_engine_op.py @@ -0,0 +1,35 @@ +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================= + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import platform + +if platform.system() != "Windows": + # pylint: disable=wildcard-import,unused-import,g-import-not-at-top + from tensorflow.contrib.tensorrt.ops.gen_trt_engine_op import * + + from tensorflow.contrib.util import loader + from tensorflow.python.platform import resource_loader + # pylint: enable=wildcard-import,unused-import,g-import-not-at-top + + _trt_engine_op = loader.load_op_library( + resource_loader.get_path_to_datafile("_trt_engine_op.so")) +else: + raise RuntimeError("Windows platforms are not supported") + + diff --git a/tensorflow/contrib/tensorrt/python/trt_convert.py b/tensorflow/contrib/tensorrt/python/trt_convert.py new file mode 100644 index 00000000000..a66afa8d05a --- /dev/null +++ b/tensorflow/contrib/tensorrt/python/trt_convert.py @@ -0,0 +1,91 @@ +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================= +"""Exposes the Python wrapper conversion to trt_graph.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# pylint: disable=unused-import,wildcard-import, line-too-long +from tensorflow.core.framework import graph_pb2 +from tensorflow.python.framework import errors +from tensorflow.python.framework import errors_impl as _impl +from tensorflow.contrib.tensorrt.wrap_conversion import trt_convert +from tensorflow.python.util import compat +import tensorflow as tf +from tensorflow.python.grappler import tf_optimizer +from tensorflow.core.protobuf import rewriter_config_pb2 +from tensorflow.python.framework import meta_graph +from tensorflow.python.framework import ops + + +def CreateInferenceGraph(input_graph_def, outputs,max_batch_size=1,max_workspace_size=2<<20): + """Python wrapper for the TRT transormation. + + + Args: + input_graph_def: GraphDef object containing a model to be transformed. + outputs: List of node names for the model outputs. + max_batch_size: max size for the input batch + max_workspace_size: parameter to control memory allocation (in Bytes) + + Returns: + New GraphDef with TRTEngineOps placed in graph replacing subgraphs. + """ + + # with errors.raise_exception_on_not_ok_status() as status: + # output_graph_def_string = trt_convert( + # input_graph_def_string,outputs, + # max_batch_size,max_workspace_size, status) + g = tf.Graph() + with g.as_default(): + tf.import_graph_def(input_graph_def, name="") + rewriter_config = rewriter_config_pb2.RewriterConfig() + rewriter_config.optimizers.append('layout') + rewriter_config.optimizers.append('constfold') + + # mark output nodes as fetch + train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP) + for node_name in outputs: + out_node = g.get_operation_by_name(node_name) + for i in range(0,len(out_node.outputs)): + train_op.append(out_node.outputs[0]) + + # constant folding + mg = meta_graph.create_meta_graph_def(graph=g) + meta_graph.add_collection_def(mg, ops.GraphKeys.TRAIN_OP) + optimized_graph_def_str = \ + tf_optimizer.OptimizeGraph(rewriter_config, mg).SerializeToString() + + # TODO(sami): Fix this when we can return status from C++ library + # There is a problem with the TF internal library setup that doesn't allow us to return a status object from C++. + # Thus we return a pair or strings where first one is encoded status and the second one is the + # transformed graphs protobuf string. 
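+  # The first element of the pair starts with "OK" on success, or is
+  # "<numeric error code>;<message>" on failure; it is decoded below.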
+ out = trt_convert( + optimized_graph_def_str ,outputs, + max_batch_size,max_workspace_size) + status = out[0] + output_graph_def_string = out[1] + del optimized_graph_def_str #save some memory + if len(status) < 2: + raise _impl.UnknownError(None,None,status) + if status[:2] != "OK": + msg=status.split(";") + if len(msg) == 1: + raise RuntimeError("Status message is malformed {}".format(status)) + raise _impl._make_specific_exception(None,None,";".join(msg[1:]), int(msg[0])) + output_graph_def = graph_pb2.GraphDef() + output_graph_def.ParseFromString(output_graph_def_string) + del output_graph_def_string #save some memory + return output_graph_def diff --git a/tensorflow/contrib/tensorrt/segment/segment.cc b/tensorflow/contrib/tensorrt/segment/segment.cc new file mode 100644 index 00000000000..41da528247b --- /dev/null +++ b/tensorflow/contrib/tensorrt/segment/segment.cc @@ -0,0 +1,259 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/tensorrt/segment/segment.h" + +#include +#include +#include +#include + +#include "tensorflow/contrib/tensorrt/segment/union_find.h" +#include "tensorflow/core/graph/algorithm.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/graph/graph_constructor.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status.h" + +//------------------------------------------------------------------------------ +namespace tensorrt { +namespace segment { + +//------------------------------------------------------------------------------ +namespace { + +//------------------------------------------------------------------------------ +bool CanContractEdge(const tensorflow::Edge* edge, + const tensorflow::Graph& graph) { + const tensorflow::Node* src = edge->src(); + const tensorflow::Node* dst = edge->dst(); + + // Can't contract edge if doing so would cause a cycle in the + // graph. So, if there is a directed path from 'src' to 'dst', other + // than 'edge' (or any other direct edge from 'src' to 'dst'), then + // combining 'src' and 'dst' will cause a cycle along that path. + // + // In practice, to avoid modifying the graph and to take advantage + // of existing graph functions, we perform an equivalent. + // 1. Get all nodes incoming to 'dst', excluding 'src' + // 2. Reverse DFS from those nodes + // 3. 
If reverse DFS reaches 'src' then we have a cycle + std::vector dfs_start_nodes; + for (tensorflow::Node* node : dst->in_nodes()) { + if (node != src) { + dfs_start_nodes.push_back(node); + } + } + + bool is_cycle = false; + if (!dfs_start_nodes.empty()) { + tensorflow::ReverseDFSFrom(graph, dfs_start_nodes, {}, + [&is_cycle, src](tensorflow::Node* node) { + if (node == src) { + is_cycle = true; + } + }); + } + + return !is_cycle; +} + +//------------------------------------------------------------------------------ +void ContractEdge(tensorflow::Edge* edge, tensorflow::Graph* graph, + std::vector* remove_edges) { + // Transfer all inputs and outputs of 'dst' to 'src' except edges + // connecting the two. + tensorflow::Node* src = edge->src(); + tensorflow::Node* dst = edge->dst(); + + // We can use '0' for input/output index because we don't need them + // to be accurate for the way we are using the graph. + std::vector in_edges(dst->in_edges().begin(), + dst->in_edges().end()); + for (const tensorflow::Edge* in_edge : in_edges) { + if (in_edge->src() != src) { + tensorflow::Edge* e = const_cast(in_edge); + if (e->src() == graph->source_node()) { + graph->AddEdge(e->src(), e->src_output(), src, + tensorflow::Graph::kControlSlot); + } else { + graph->AddEdge(e->src(), e->src_output(), src, 0 /* input index */); + } + } + } + + std::vector out_edges(dst->out_edges().begin(), + dst->out_edges().end()); + for (const tensorflow::Edge* out_edge : out_edges) { + tensorflow::Edge* e = const_cast(out_edge); + if (e->dst() == graph->sink_node()) { + graph->AddEdge(src, tensorflow::Graph::kControlSlot, e->dst(), + e->dst_input()); + } else { + graph->AddEdge(src, 0 /* output index */, e->dst(), e->dst_input()); + } + } + + // Return the edges that must be removed to disconnect 'dst' from + // the graph. We don't actually remove 'dst' since the caller holds + // references to all the nodes. + for (const auto& in_edge : dst->in_edges()) { + remove_edges->push_back(in_edge); + } + for (const auto& out_edge : dst->out_edges()) { + remove_edges->push_back(out_edge); + } +} + +} // namespace + +//------------------------------------------------------------------------------ +tensorflow::Status SegmentGraph( + const tensorflow::GraphDef& gdef, + const std::function& candidate_fn, + const SegmentOptions& options, SegmentNodesVector* segments) { + // Create a Graph representation of the GraphDef. + tensorflow::FunctionLibraryDefinition flib(tensorflow::OpRegistry::Global(), + gdef.library()); + tensorflow::Graph graph(flib); + TF_RETURN_IF_ERROR(tensorflow::ConvertGraphDefToGraph( + tensorflow::GraphConstructorOptions(), gdef, &graph)); + + // tensorflow::DumpGraph("Pre-Segment", &graph); + + // Use a union-find to collect the nodes that belong to the same + // segment. A node value of nullptr indicates that the node is not a + // candidate for TRT. + std::vector> node_segments; + for (int i = 0; i < graph.num_node_ids(); ++i) { + tensorflow::Node* node = graph.FindNodeId(i); + if (!candidate_fn(node->def())) { + node = nullptr; + } + node_segments.emplace_back(node); + } + + // Visit nodes in reverse topological order and use edge + // contraction to merge candidate nodes. + std::vector order; + tensorflow::GetPostOrder(graph, &order); + + for (const tensorflow::Node* node : order) { + // All output nodes of 'node' have been visited... + VLOG(2) << "Trying node " << node->name(); + + // 'node' must be a TRT candidate... + if (node_segments[node->id()].Value() == nullptr) { + VLOG(2) << "... 
not a TRT candidate"; + continue; + } + + // Contract output edges to combine 'node' with output + // nodes. Iterate since combining two nodes may unblock other + // combining. + while (true) { + std::set contract_edges; + for (const tensorflow::Edge* out_edge : node->out_edges()) { + VLOG(2) << "... out node " << out_edge->dst()->name(); + + // Out node must be TRT candidate... + if (node_segments[out_edge->dst()->id()].Value() == nullptr) { + VLOG(2) << "... ... not a TRT candidate"; + continue; + } + + if (CanContractEdge(out_edge, graph)) { + VLOG(2) << "... ... can contract"; + contract_edges.insert(out_edge); + } else { + VLOG(2) << "... ... cannot contract, would form cycle"; + } + } + + if (contract_edges.empty()) { + break; + } + + // Contract edges and collect the adjacent nodes into the same + // segment/subgraph. + while (!contract_edges.empty()) { + const tensorflow::Edge* contract_edge = *contract_edges.begin(); + const tensorflow::Node* src = contract_edge->src(); + const tensorflow::Node* dst = contract_edge->dst(); + + VLOG(2) << "Merge " << src->name() << " <- " << dst->name(); + node_segments[src->id()].Merge(&node_segments[dst->id()]); + + // Contracting the edge leaves disconnected graph edges. + // Remove these from the graph and from 'contract_edges' so we + // don't visit them again. + tensorflow::Edge* e = const_cast(contract_edge); + std::vector remove_edges; + ContractEdge(e, &graph, &remove_edges); + + for (const tensorflow::Edge* r : remove_edges) { + contract_edges.erase(r); + graph.RemoveEdge(r); + } + } + } + } + + // Collect the segments/subgraphs. Each subgraph is represented by a + // set of the names of the nodes in that subgraph. + std::unordered_map> sg_map; + for (auto& u : node_segments) { + if ((u.Value() != nullptr) && (u.ParentValue() != nullptr)) { + sg_map[u.ParentValue()->name()].insert(u.Value()->name()); + } + } + + // Cleanup the graph to remove disconnected nodes before outputting + if (VLOG_IS_ON(2)) { + for (tensorflow::Node* node : graph.nodes()) { + if ((node->in_edges().size() == 0) && (node->out_edges().size() == 0)) { + graph.RemoveNode(node); + } + } + // tensorflow::DumpGraph("Post-Segment", &graph); + } + + // Convert the segments into the expected return format + for (const auto& itr : sg_map) { + const auto& segment_node_names = itr.second; + if (VLOG_IS_ON(1)) { + std::string s; + for (const auto& name : segment_node_names) { + s += " " + name; + } + VLOG(1) << "Segment " << segments->size() << ":" << s; + } + + // Don't use small segments. + if (static_cast(segment_node_names.size()) < + options.minimum_segment_size) { + VLOG(1) << "Segment " << segments->size() << " has only " + << segment_node_names.size() << " nodes, dropping"; + continue; + } + + segments->emplace_back(segment_node_names); + } + + return tensorflow::Status::OK(); +} + +} // namespace segment +} // namespace tensorrt diff --git a/tensorflow/contrib/tensorrt/segment/segment.h b/tensorflow/contrib/tensorrt/segment/segment.h new file mode 100644 index 00000000000..b5aee5bc340 --- /dev/null +++ b/tensorflow/contrib/tensorrt/segment/segment.h @@ -0,0 +1,53 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CONTRIB_TENSORRT_SEGMENT_SEGMENT_H_ +#define TENSORFLOW_CONTRIB_TENSORRT_SEGMENT_SEGMENT_H_ + +#include +#include +#include + +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/lib/core/status.h" + +namespace tensorrt { +namespace segment { + +using SegmentNodesVector = std::vector>; + +struct SegmentOptions { + // Segment must contain at least this many nodes. + int minimum_segment_size = 2; +}; + +// Get the subgraphs of a graph that can be handled by TensorRT. +// +// @param gdef The GraphDef describing the network +// @param candidate_fn A function that returns true for a NodeDef if +// that node can be handled by TensorRT. +// @param segments Returns the TensorRT segments/subgraphs. Each entry +// in the vector describes a subgraph by giving a set of the names of +// all the NodeDefs in that subgraph. +// @return the status. +tensorflow::Status SegmentGraph( + const tensorflow::GraphDef& gdef, + const std::function& candidate_fn, + const SegmentOptions& options, SegmentNodesVector* segments); + +} // namespace segment +} // namespace tensorrt + +#endif // TENSORFLOW_CONTRIB_TENSORRT_SEGMENT_SEGMENT_H_ diff --git a/tensorflow/contrib/tensorrt/segment/segment_test.cc b/tensorflow/contrib/tensorrt/segment/segment_test.cc new file mode 100644 index 00000000000..dcd0c71ed77 --- /dev/null +++ b/tensorflow/contrib/tensorrt/segment/segment_test.cc @@ -0,0 +1,363 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/contrib/tensorrt/segment/segment.h" +#include "tensorflow/c/c_api.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/test.h" + +//------------------------------------------------------------------------------ +using namespace tensorflow; + +namespace tensorrt { +namespace segment { +namespace test { + +class SegmentTest : public ::testing::Test { + public: + bool GetGraphDef(TF_Graph* graph, tensorflow::GraphDef* graph_def); + + TF_Operation* Placeholder(TF_Graph* graph, TF_Status* s, const char* name); + TF_Operation* Add(TF_Operation* l, TF_Operation* r, TF_Graph* graph, + TF_Status* s, const char* name); + + std::function MakeCandidateFn( + const std::set& node_names); + + protected: + void PlaceholderHelper(TF_Graph* graph, TF_Status* s, const char* name, + TF_Operation** op); + void AddHelper(TF_Operation* l, TF_Operation* r, TF_Graph* graph, + TF_Status* s, const char* name, TF_Operation** op, bool check); + + SegmentOptions default_options_; +}; + +bool SegmentTest::GetGraphDef(TF_Graph* graph, + tensorflow::GraphDef* graph_def) { + TF_Status* s = TF_NewStatus(); + TF_Buffer* buffer = TF_NewBuffer(); + TF_GraphToGraphDef(graph, buffer, s); + bool ret = TF_GetCode(s) == TF_OK; + EXPECT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + if (ret) ret = graph_def->ParseFromArray(buffer->data, buffer->length); + TF_DeleteBuffer(buffer); + TF_DeleteStatus(s); + return ret; +} + +std::function SegmentTest::MakeCandidateFn( + const std::set& node_names) { + return [node_names](const NodeDef& node) -> bool { + return node_names.find(node.name()) != node_names.end(); + }; +} + +void SegmentTest::PlaceholderHelper(TF_Graph* graph, TF_Status* s, + const char* name, TF_Operation** op) { + TF_OperationDescription* desc = TF_NewOperation(graph, "Placeholder", name); + TF_SetAttrType(desc, "dtype", TF_INT32); + *op = TF_FinishOperation(desc, s); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + ASSERT_NE(*op, nullptr); +} + +TF_Operation* SegmentTest::Placeholder(TF_Graph* graph, TF_Status* s, + const char* name) { + TF_Operation* op; + PlaceholderHelper(graph, s, name, &op); + return op; +} + +void SegmentTest::AddHelper(TF_Operation* l, TF_Operation* r, TF_Graph* graph, + TF_Status* s, const char* name, TF_Operation** op, + bool check) { + TF_OperationDescription* desc = TF_NewOperation(graph, "AddN", name); + TF_Output add_inputs[2] = {{l, 0}, {r, 0}}; + TF_AddInputList(desc, add_inputs, 2); + *op = TF_FinishOperation(desc, s); + if (check) { + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + ASSERT_NE(*op, nullptr); + } +} + +TF_Operation* SegmentTest::Add(TF_Operation* l, TF_Operation* r, + TF_Graph* graph, TF_Status* s, + const char* name) { + TF_Operation* op; + AddHelper(l, r, graph, s, name, &op, true); + return op; +} + +//------------------------------------------------------------------------------ +TEST_F(SegmentTest, Empty) { + TF_Graph* graph = TF_NewGraph(); + + GraphDef graph_def; + ASSERT_TRUE(GetGraphDef(graph, &graph_def)); + + SegmentNodesVector segments; + ASSERT_EQ( + SegmentGraph(graph_def, MakeCandidateFn({}), default_options_, &segments), + tensorflow::Status::OK()); + + // Expect no segments/subgraphs. 
+ EXPECT_TRUE(segments.empty()); +} + +//------------------------------------------------------------------------------ +TEST_F(SegmentTest, Simple) { + TF_Status* s = TF_NewStatus(); + TF_Graph* graph = TF_NewGraph(); + + // feed + // // || + // add0 add1 + // | | / + // | add2 + // | / || + // add3 add4 + // | / + // + // + TF_Operation* feed = Placeholder(graph, s, "feed"); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + EXPECT_EQ(string("feed"), string(TF_OperationName(feed))); + + TF_Operation* add0 = Add(feed, feed, graph, s, "add0"); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + TF_Operation* add1 = Add(feed, feed, graph, s, "add1"); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + TF_Operation* add2 = Add(add0, add1, graph, s, "add2"); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + TF_Operation* add3 = Add(add0, add2, graph, s, "add3"); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + EXPECT_EQ(string("add3"), string(TF_OperationName(add3))); + TF_Operation* add4 = Add(add2, add2, graph, s, "add4"); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + EXPECT_EQ(string("add4"), string(TF_OperationName(add4))); + + GraphDef graph_def; + ASSERT_TRUE(GetGraphDef(graph, &graph_def)); + + SegmentNodesVector segments; + ASSERT_EQ( + SegmentGraph(graph_def, + MakeCandidateFn({"add0", "add1", "add2", "add3", "add4"}), + default_options_, &segments), + tensorflow::Status::OK()); + + // Expect all Add operations to be collapsed into a single segment + ASSERT_EQ(segments.size(), 1); + std::vector expected{"add0", "add1", "add2", "add3", "add4"}; + for (const auto& ex : expected) { + EXPECT_TRUE(segments[0].find(ex) != segments[0].end()) + << "Missing expected node " << ex; + } +} + +//------------------------------------------------------------------------------ +TEST_F(SegmentTest, AvoidCycle) { + TF_Status* s = TF_NewStatus(); + TF_Graph* graph = TF_NewGraph(); + + // add2 is not a TRT candidate so add0/add3 cannot be formed as a + // subgraph + // + // feed + // // || + // add0 add1 + // | | / + // | add2 + // | / || + // add3 add4 + // | / + // + // + TF_Operation* feed = Placeholder(graph, s, "feed"); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + EXPECT_EQ(string("feed"), string(TF_OperationName(feed))); + + TF_Operation* add0 = Add(feed, feed, graph, s, "add0"); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + TF_Operation* add1 = Add(feed, feed, graph, s, "add1"); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + TF_Operation* add2 = Add(add0, add1, graph, s, "add2"); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + TF_Operation* add3 = Add(add0, add2, graph, s, "add3"); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + EXPECT_EQ(string("add3"), string(TF_OperationName(add3))); + TF_Operation* add4 = Add(add2, add2, graph, s, "add4"); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + EXPECT_EQ(string("add4"), string(TF_OperationName(add4))); + + GraphDef graph_def; + ASSERT_TRUE(GetGraphDef(graph, &graph_def)); + + SegmentNodesVector segments; + ASSERT_EQ( + SegmentGraph(graph_def, MakeCandidateFn({"add0", "add1", "add3", "add4"}), + default_options_, &segments), + tensorflow::Status::OK()); + + // Expect no subgraphs + EXPECT_EQ(segments.size(), 0); +} + +//------------------------------------------------------------------------------ +TEST_F(SegmentTest, Multiple) { + TF_Status* s = TF_NewStatus(); + TF_Graph* graph = TF_NewGraph(); + + // add5 is not a TRT candidate so two subgraphs should be formed + // + // feed + // // 
|| || + // add0 add1 add7 + // | | / / || + // | add2-----add5 add8 + // | / | | | | + // add3 add4 add6 + // | | / + // + // + TF_Operation* feed = Placeholder(graph, s, "feed"); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + EXPECT_EQ(string("feed"), string(TF_OperationName(feed))); + + TF_Operation* add0 = Add(feed, feed, graph, s, "add0"); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + TF_Operation* add1 = Add(feed, feed, graph, s, "add1"); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + TF_Operation* add7 = Add(feed, feed, graph, s, "add7"); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + TF_Operation* add2 = Add(add0, add1, graph, s, "add2"); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + TF_Operation* add5 = Add(add2, add7, graph, s, "add5"); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + TF_Operation* add8 = Add(add7, add7, graph, s, "add8"); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + TF_Operation* add3 = Add(add0, add2, graph, s, "add3"); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + EXPECT_EQ(string("add3"), string(TF_OperationName(add3))); + TF_Operation* add4 = Add(add2, add5, graph, s, "add4"); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + EXPECT_EQ(string("add4"), string(TF_OperationName(add4))); + TF_Operation* add6 = Add(add5, add8, graph, s, "add6"); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + EXPECT_EQ(string("add6"), string(TF_OperationName(add6))); + + GraphDef graph_def; + ASSERT_TRUE(GetGraphDef(graph, &graph_def)); + + SegmentNodesVector segments; + ASSERT_EQ(SegmentGraph(graph_def, + MakeCandidateFn({"add0", "add1", "add2", "add3", + "add4", "add6", "add7", "add8"}), + default_options_, &segments), + tensorflow::Status::OK()); + + // Expect two subgraphs + EXPECT_EQ(segments.size(), 2); + + std::vector expected0{"add0", "add1", "add2", "add3"}; + for (const auto& ex : expected0) { + EXPECT_TRUE(segments[0].find(ex) != segments[0].end()) + << "Missing expected node " << ex; + } + + std::vector expected1{"add6", "add8"}; + for (const auto& ex : expected1) { + EXPECT_TRUE(segments[1].find(ex) != segments[1].end()) + << "Missing expected node " << ex; + } +} + +//------------------------------------------------------------------------------ +TEST_F(SegmentTest, BigIfElse) { + TF_Status* s = TF_NewStatus(); + TF_Graph* graph = TF_NewGraph(); + + // add2 is not a TRT candidate + // + // feed + // || + // add0 + // // || + // add1 add4 + // || || + // add2 add5 + // || || + // add3 add6 + // || // + // add7 + // || + // + // + TF_Operation* feed = Placeholder(graph, s, "feed"); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + EXPECT_EQ(string("feed"), string(TF_OperationName(feed))); + + TF_Operation* add0 = Add(feed, feed, graph, s, "add0"); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + TF_Operation* add1 = Add(add0, add0, graph, s, "add1"); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + TF_Operation* add2 = Add(add1, add1, graph, s, "add2"); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + TF_Operation* add3 = Add(add2, add2, graph, s, "add3"); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + TF_Operation* add4 = Add(add0, add0, graph, s, "add4"); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + TF_Operation* add5 = Add(add4, add4, graph, s, "add5"); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + TF_Operation* add6 = Add(add5, add5, graph, s, "add6"); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + TF_Operation* add7 = Add(add3, add6, graph, s, 
"add7"); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + EXPECT_EQ(string("add7"), string(TF_OperationName(add7))); + + GraphDef graph_def; + ASSERT_TRUE(GetGraphDef(graph, &graph_def)); + + SegmentNodesVector segments; + ASSERT_EQ(SegmentGraph(graph_def, + MakeCandidateFn({"add0", "add1", "add3", "add4", + "add5", "add6", "add7"}), + default_options_, &segments), + tensorflow::Status::OK()); + + // Expect 2 subgraphs + EXPECT_EQ(segments.size(), 2); + + std::vector expected0{"add3", "add4", "add5", "add6", "add7"}; + for (const auto& ex : expected0) { + EXPECT_TRUE(segments[0].find(ex) != segments[0].end()) + << "Missing expected node " << ex; + } + + std::vector expected1{"add0", "add1"}; + for (const auto& ex : expected1) { + EXPECT_TRUE(segments[1].find(ex) != segments[1].end()) + << "Missing expected node " << ex; + } +} + +} // namespace test +} // namespace segment +} // namespace tensorrt diff --git a/tensorflow/contrib/tensorrt/segment/union_find.h b/tensorflow/contrib/tensorrt/segment/union_find.h new file mode 100644 index 00000000000..8ae877cd051 --- /dev/null +++ b/tensorflow/contrib/tensorrt/segment/union_find.h @@ -0,0 +1,77 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CONTRIB_TENSORRT_SEGMENT_UNION_FIND_H_ +#define TENSORFLOW_CONTRIB_TENSORRT_SEGMENT_UNION_FIND_H_ + +namespace tensorrt { +namespace segment { + +// Union-Find data structure. +// Each cluster has an associated value; when merging clusters we can control +// which value becomes the representative of the merged clusters. Values must be +// copyable. +template +class UnionFind { + public: + UnionFind() : size_(1), parent_(nullptr) {} + explicit UnionFind(const T& v) : size_(1), parent_(nullptr), value_(v) {} + + // Returns the number of elements in a cluster. + int Size() { return FindRoot()->size_; } + + // Merges this cluster with 'other'. This cluster's value becomes + // the value of the merged cluster; the value of 'other' is ignored. + void Merge(UnionFind* other); + + // Each cluster has an associated value. Retrieves the value associated + // with this cluster. + T& ParentValue() { return FindRoot()->value_; } + + // Get the original value of this node. + T& Value() { return value_; } + + private: + // Finds the root element of the cluster. Performs path compression. + UnionFind* FindRoot(); + + int size_; + UnionFind* parent_; + T value_; +}; + +template +void UnionFind::Merge(UnionFind* other) { + UnionFind* a = FindRoot(); + UnionFind* b = other->FindRoot(); + if (a == b) return; + + b->parent_ = a; + a->size_ += b->size_; +} + +template +UnionFind* UnionFind::FindRoot() { + if (!parent_) return this; + // Path compression: update intermediate nodes to point to the root of the + // equivalence class. 
+ parent_ = parent_->FindRoot(); + return parent_; +} + +} // namespace segment +} // namespace tensorrt + +#endif // TENSORFLOW_CONTRIB_TENSORRT_SEGMENT_UNION_FIND_H_ diff --git a/tensorflow/contrib/tensorrt/shape_fn/trt_shfn.cc b/tensorflow/contrib/tensorrt/shape_fn/trt_shfn.cc new file mode 100644 index 00000000000..72022b99e2b --- /dev/null +++ b/tensorflow/contrib/tensorrt/shape_fn/trt_shfn.cc @@ -0,0 +1,123 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/tensorrt/shape_fn/trt_shfn.h" +#include +#include +#include "NvInfer.h" +#include "tensorflow/contrib/tensorrt/log/trt_logger.h" + +namespace tensorflow { +namespace shape_inference { +tensorflow::Status TRTEngineOpShapeInference(InferenceContext* c) { + tensorflow::tensorrt::Logger gLogger; + string serialized_engine; + c->GetAttr("serialized_engine", &serialized_engine); + nvinfer1::IRuntime* infer = nvinfer1::createInferRuntime(gLogger); + nvinfer1::ICudaEngine* trt_engine = infer->deserializeCudaEngine( + serialized_engine.c_str(), serialized_engine.size(), nullptr); + + // debug print out engine binding; + std::stringstream oss; + for (int i = 0; i < trt_engine->getNbBindings(); i++) { + LOG(INFO) << "index: " << i + << ", binding name: " << trt_engine->getBindingName(i); + + bool input_flag = trt_engine->bindingIsInput(i); + oss << "input?: " << (input_flag ? "Y" : "N"); + + oss << "Dimension: "; + auto dims = trt_engine->getBindingDimensions(i); + oss << " nbDims: " << dims.nbDims << " -> "; + for (int j = 0; j < dims.nbDims; j++) oss << dims.d[j] << ", "; + LOG(INFO) << oss.str(); + oss.str(""); + switch (trt_engine->getBindingDataType(i)) { + case nvinfer1::DataType::kFLOAT: + LOG(INFO) << "data type: float" << std::endl; + break; + case nvinfer1::DataType::kHALF: + LOG(INFO) << "data type: half" << std::endl; + break; + case nvinfer1::DataType::kINT8: + LOG(INFO) << "data type: int8" << std::endl; + break; + } + } + + int nbBatch = -1; + // debug print out input arrays + std::vector<::tensorflow::DataType> input_type; + c->GetAttr("InT", &input_type); + oss.str(""); + for (size_t i = 0; i < c->num_inputs(); i++) { + // check if input shape is legit + auto input_shape = c->input(i); + int index = i; + oss << "input:" << i << " type: " << input_type[index] << " shape: "; + for (int j = 0; j < c->Rank(input_shape); j++) { + auto dimHandler = c->Dim(input_shape, j); + if (c->ValueKnown(dimHandler)) + oss << c->Value(dimHandler) << ", "; + else + oss << "?" 
<< c->Value(dimHandler) << ", "; + if (j == 0) { + if (i == 0) + nbBatch = c->Value(dimHandler); + else if (nbBatch != c->Value(dimHandler)) + LOG(WARNING) << "!!!!!!nbBatch does not match!!!!!!"; + // assert(nbBatch == c->Value(dimHandler); + } + } + LOG(INFO) << oss.str(); + } + + // arrange input here + std::vector input_nodes; + c->GetAttr("input_nodes", &input_nodes); + for (size_t i = 0; i < input_nodes.size(); i++) { + int index = i; + LOG(INFO) << "input:" << i << " name: " << input_nodes[index]; + } + + // arrange output here + std::vector output_nodes; + c->GetAttr("output_nodes", &output_nodes); + oss.str(""); + for (size_t i = 0; i < output_nodes.size(); i++) { + int index = i; + int binding_index = + trt_engine->getBindingIndex(output_nodes[index].c_str()); + oss << "string name " << output_nodes[index]; + ShapeHandle output_shape; + std::vector vecDim; + vecDim.emplace_back(c->MakeDim(nbBatch)); + if (binding_index != -1) { + oss << "got binding " << binding_index; + auto dims = trt_engine->getBindingDimensions(binding_index); + for (int j = 0; j < dims.nbDims; j++) + vecDim.emplace_back(c->MakeDim(dims.d[j])); + } else { + oss << "no binding "; + } + output_shape = c->MakeShape(vecDim); + c->set_output(i, output_shape); + LOG(INFO) << oss.str(); + } + + return Status::OK(); +} +} // namespace shape_inference +} // namespace tensorflow diff --git a/tensorflow/contrib/tensorrt/shape_fn/trt_shfn.h b/tensorflow/contrib/tensorrt/shape_fn/trt_shfn.h new file mode 100644 index 00000000000..90a226d91d2 --- /dev/null +++ b/tensorflow/contrib/tensorrt/shape_fn/trt_shfn.h @@ -0,0 +1,28 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CONTRIB_TENSORRT_SHAPE_FN_TRT_SHFN_H_ +#define TENSORFLOW_CONTRIB_TENSORRT_SHAPE_FN_TRT_SHFN_H_ + +#include "tensorflow/core/framework/shape_inference.h" +#include "tensorflow/core/lib/core/status.h" + +namespace tensorflow { +namespace shape_inference { +Status TRTEngineOpShapeInference(InferenceContext* c); +} // namespace shape_inference +} // namespace tensorflow + +#endif // TENSORFLOW_CONTRIB_TENSORRT_SHAPE_FN_TRT_SHFN_H_ diff --git a/tensorflow/contrib/tensorrt/trt_conversion.i b/tensorflow/contrib/tensorrt/trt_conversion.i new file mode 100644 index 00000000000..5f8e73a59fc --- /dev/null +++ b/tensorflow/contrib/tensorrt/trt_conversion.i @@ -0,0 +1,84 @@ +/* + + wrap trt_conversion + + */ +%{ +#define SWIG_FILE_WITH_INIT +%} +%include "std_string.i" +%include "std_pair.i" +%include "tensorflow/python/lib/core/strings.i" +%include "tensorflow/python/platform/base.i" +%template(StringPair) std::pair<string, string>; +%template() std::pair; + +%{ +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/util/stat_summarizer.h" +#include "tensorflow/contrib/tensorrt/convert/convert_graph.h" +%} + +%ignoreall +%unignore tensorflow; +%unignore trt_convert; + +%{ + std::pair<string, string> trt_convert(string graph_def_string,  // const tensorflow::GraphDef& + std::vector<string> output_names, + size_t max_batch_size, + size_t max_workspace_size + // Unfortunately we can't use TF_Status here, since it lives in + // c/c_api and brings in a lot of other libraries which in turn + // declare ops. These ops are included statically in our library + // and cause an abort when the module is loaded, due to double + // registration. Until TensorFlow properly exposes these headers, + // we work around this by returning a status string and + // converting it to an exception on the python side.
+ //,TF_Status* out_status) { + ) { + string out_status; + + tensorflow::GraphDef graph_def; + if (!graph_def.ParseFromString(graph_def_string)) { + out_status = "InvalidArgument;Couldn't interpret input as a GraphDef"; + return std::pair<string, string>{out_status, ""}; + } + + if (output_names.empty()) { + out_status = "InvalidArgument;Size of the output_names vector is 0"; + return std::pair<string, string>{out_status, ""}; + //return ""; + } + tensorflow::GraphDef outGraph; + tensorflow::Status conversion_status = + tensorrt::convert::ConvertGraphDefToTensorRT(graph_def, + output_names, + max_batch_size, + max_workspace_size, + &outGraph); + if (!conversion_status.ok()) { + auto retCode = static_cast<int>(conversion_status.code()); + char buff[2000]; + snprintf(buff, sizeof(buff), "%d;%s", retCode, conversion_status.error_message().c_str()); + out_status = buff; + return std::pair<string, string>{out_status, ""}; + } + string result; + if (!outGraph.SerializeToString(&result)) { + out_status = "InvalidArgument;Couldn't serialize output as a GraphDef"; + return std::pair<string, string>{out_status, ""}; + } + out_status = "OK;All good!"; + return std::pair<string, string>{out_status, result}; + } +%} + +std::pair<string, string> trt_convert(string graph_def_string, + std::vector<string> output_names, + size_t max_batch_size, + size_t max_workspace_size); + +%unignoreall diff --git a/tensorflow/tensorflow.bzl b/tensorflow/tensorflow.bzl index 383c97344a0..838b1218a4f 100644 --- a/tensorflow/tensorflow.bzl +++ b/tensorflow/tensorflow.bzl @@ -279,7 +279,7 @@ def tf_cc_shared_object( linkopts=[], framework_so=tf_binary_additional_srcs(), **kwargs): - native.cc_binary( + native.cc_binary( name=name, srcs=srcs + framework_so, deps=deps, @@ -1281,6 +1281,45 @@ def tf_extension_linkopts(): def tf_extension_copts(): return [] # No extension c opts +# In libraries generated by tf_py_wrap_cc, module init functions are not +# exported unless they match one of the patterns in the version script; +# this prevents loading custom python modules.
+# This function attempts to append init_module_name to list of +# exported functions in version script +def _append_init_to_versionscript_impl(ctx): + modName=ctx.attr.module_name + isVS=ctx.attr.is_version_script + if isVS: + ctx.actions.expand_template( + template=ctx.file.template_file, + output=ctx.outputs.versionscript, + substitutions={ + "global:":"global:\n init_%s;"%modName, + }, + is_executable=False, + ) + else: + ctx.actions.expand_template( + template=ctx.file.template_file, + output=ctx.outputs.versionscript, + substitutions={ + "*tensorflow*":"*tensorflow*\ninit_%s"%modName, + }, + is_executable=False, + ) + + +_append_init_to_versionscript= rule( + implementation=_append_init_to_versionscript_impl, + attrs={ + "module_name":attr.string(mandatory=True), + "template_file":attr.label(allow_files=True,single_file=True,mandatory=True), + "is_version_script":attr.bool(default=True,doc='whether target is a ld version script or exported symbol list',mandatory=False), + }, + outputs={"versionscript":"%{name}.lds"}, +) + def tf_py_wrap_cc(name, srcs, swig_includes=[], @@ -1302,26 +1341,39 @@ def tf_py_wrap_cc(name, toolchain_deps=["//tools/defaults:crosstool"], module_name=module_name, py_module_name=name) + vscriptname=name+"_versionscript" + _append_init_to_versionscript( + name=vscriptname, + module_name=module_name, + is_version_script=select({ + "@local_config_cuda//cuda:darwin":False, + "//conditions:default":True, + }), + template_file=select({ + "@local_config_cuda//cuda:darwin":clean_dep("//tensorflow:tf_exported_symbols.lds"), + "//conditions:default":clean_dep("//tensorflow:tf_version_script.lds") + }) + ) extra_linkopts = select({ "@local_config_cuda//cuda:darwin": [ "-Wl,-exported_symbols_list", - clean_dep("//tensorflow:tf_exported_symbols.lds") + "%s.lds"%vscriptname, ], clean_dep("//tensorflow:windows"): [], clean_dep("//tensorflow:windows_msvc"): [], "//conditions:default": [ "-Wl,--version-script", - clean_dep("//tensorflow:tf_version_script.lds") + "%s.lds"%vscriptname, ] }) extra_deps += select({ "@local_config_cuda//cuda:darwin": [ - clean_dep("//tensorflow:tf_exported_symbols.lds") + "%s.lds"%vscriptname, ], clean_dep("//tensorflow:windows"): [], clean_dep("//tensorflow:windows_msvc"): [], "//conditions:default": [ - clean_dep("//tensorflow:tf_version_script.lds") + "%s.lds"%vscriptname, ] }) diff --git a/tensorflow/tools/pip_package/BUILD b/tensorflow/tools/pip_package/BUILD index ff5dd6a0b09..f47df0e25dc 100644 --- a/tensorflow/tools/pip_package/BUILD +++ b/tensorflow/tools/pip_package/BUILD @@ -11,6 +11,7 @@ load( ) load("//third_party/mkl:build_defs.bzl", "if_mkl") load("//tensorflow:tensorflow.bzl", "if_cuda") +load("@local_config_tensorrt//:build_defs.bzl", "if_trt") load("//tensorflow/core:platform/default/build_config_root.bzl", "tf_additional_license_deps") # This returns a list of headers of all public header libraries (e.g., @@ -201,7 +202,8 @@ sh_binary( "//tensorflow/python:test_ops", "//tensorflow/tools/dist_test/server:grpc_tensorflow_server", ], - }) + if_mkl(["//third_party/mkl:intel_binary_blob"]), + }) + if_mkl(["//third_party/mkl:intel_binary_blob"]) + + if_trt(["//tensorflow/contrib/tensorrt:init_py"]), ) # A genrule for generating a marker file for the pip package on Windows diff --git a/tensorflow/workspace.bzl b/tensorflow/workspace.bzl index 0ba3cca9919..8850610cdb7 100644 --- a/tensorflow/workspace.bzl +++ b/tensorflow/workspace.bzl @@ -1,6 +1,7 @@ # TensorFlow external dependencies that can be loaded in WORKSPACE files. 
load("//third_party/gpus:cuda_configure.bzl", "cuda_configure") +load("//third_party/tensorrt:build_defs.bzl", "trt_repository") load("//third_party/mkl:build_defs.bzl", "mkl_repository") load("//third_party/git:git_configure.bzl", "git_configure") load("//third_party/py:python_configure.bzl", "python_configure") @@ -66,6 +67,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): # version we require here. check_bazel_version_at_least("0.5.4") cuda_configure(name="local_config_cuda") + trt_repository(name="local_config_tensorrt") git_configure(name="local_config_git") sycl_configure(name="local_config_sycl") python_configure(name="local_config_python") diff --git a/third_party/tensorrt/BUILD b/third_party/tensorrt/BUILD new file mode 100644 index 00000000000..e69de29bb2d diff --git a/third_party/tensorrt/BUILD.tpl b/third_party/tensorrt/BUILD.tpl new file mode 100644 index 00000000000..8962751f56f --- /dev/null +++ b/third_party/tensorrt/BUILD.tpl @@ -0,0 +1,42 @@ +# -*- python -*- +# Description: +# provide tensorrt information + +#TODO(Sami) these needs to be defined + +licenses(["notice"]) + +exports_files(["LICENSE"]) + +load("@local_config_cuda//cuda:build_defs.bzl", "cuda_default_copts", "if_cuda") + +config_setting( + name = "trt_enabled", + define_values = { + "using_tensorrt":"true" + }, + visibility = ["//visibility:public"], +) + +cc_library( + name = "tensorrt", + srcs =[%{tensorrt_lib}], + hdrs = ["include/NvInfer.h", + "include/NvUtils.h", + ], + copts= cuda_default_copts(), + deps =["@local_config_cuda//cuda:cuda", + "@local_config_cuda//cuda:cudnn",], + linkstatic = 1, + #include_prefix="include/", + includes=["include/"], + visibility = ["//visibility:public"], +) + +%{tensorrt_genrules} + +# filegroup( +# name = "%{tensorrt_lib}", +# srcs = ["%{tensorrt_lib}"], +# visibility = ["//visibility:public"], +# ) diff --git a/third_party/tensorrt/LICENSE b/third_party/tensorrt/LICENSE new file mode 100644 index 00000000000..d3da228420e --- /dev/null +++ b/third_party/tensorrt/LICENSE @@ -0,0 +1,203 @@ +Copyright 2015 The TensorFlow Authors. All rights reserved. + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. 
+ + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. 
You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. 
In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright 2015, The TensorFlow Authors. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
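For reference, the Python-level conversion added earlier in this patch (the wrapper that calls trt_convert, checks the "code;message" status string and rebuilds the optimized GraphDef) is meant to be driven roughly as follows. This is a hedged usage sketch: the module path (tensorflow.contrib.tensorrt) and the public function name (create_inference_graph) are assumptions for illustration, not confirmed by this patch; only the argument list (a GraphDef, the output node names, max_batch_size, max_workspace_size) and the error behaviour come from the wrapper itself.

    # Hedged usage sketch; the import path and function name are assumptions.
    import tensorflow as tf
    from tensorflow.core.framework import graph_pb2
    from tensorflow.contrib import tensorrt as trt  # assumed import path

    # Load a frozen GraphDef produced elsewhere.
    frozen_graph = graph_pb2.GraphDef()
    with open("frozen_model.pb", "rb") as f:
        frozen_graph.ParseFromString(f.read())

    # Assumed wrapper name; it forwards to the SWIG-level trt_convert and
    # raises the TensorFlow exception encoded in the status string on failure.
    trt_graph = trt.create_inference_graph(
        frozen_graph,                # GraphDef to optimize
        outputs=["logits"],          # names of the output nodes
        max_batch_size=16,           # largest batch the TRT engine will see
        max_workspace_size=1 << 30)  # TensorRT scratch memory, in bytes

    # Supported subgraphs are replaced by TRT engine ops; the result is an
    # ordinary GraphDef that can be imported and run as usual.
    with tf.Graph().as_default():
        tf.import_graph_def(trt_graph, name="")

On failure the wrapper raises a specific TensorFlow error (or UnknownError for a malformed status) rather than returning a partially converted graph.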
diff --git a/third_party/tensorrt/build_defs.bzl b/third_party/tensorrt/build_defs.bzl new file mode 100644 index 00000000000..392c5e06214 --- /dev/null +++ b/third_party/tensorrt/build_defs.bzl @@ -0,0 +1,85 @@ +# -*- python -*- +""" + Repository rule for configuring a local TensorRT installation (local_config_tensorrt). + +""" + +_TENSORRT_INSTALLATION_PATH="TENSORRT_INSTALL_PATH" +_TF_TENSORRT_VERSION="TF_TENSORRT_VERSION" + +def _is_trt_enabled(repo_ctx): + if "TF_NEED_TENSORRT" in repo_ctx.os.environ: + enable_trt = repo_ctx.os.environ["TF_NEED_TENSORRT"].strip() + return enable_trt == "1" + return False + +def _dummy_repo(repo_ctx): + + repo_ctx.template("BUILD",Label("//third_party/tensorrt:BUILD.tpl"), + {"%{tensorrt_lib}":"","%{tensorrt_genrules}":""}, + False) + repo_ctx.template("build_defs.bzl",Label("//third_party/tensorrt:build_defs.bzl.tpl"), + {"%{trt_configured}":"False"},False) + repo_ctx.file("include/NvUtils.h","",False) + repo_ctx.file("include/NvInfer.h","",False) + +def _trt_repo_impl(repo_ctx): + """ + Implements the local_config_tensorrt repository. + """ + + if not _is_trt_enabled(repo_ctx): + _dummy_repo(repo_ctx) + return + trt_libdir = repo_ctx.os.environ[_TENSORRT_INSTALLATION_PATH] + trt_ver = repo_ctx.os.environ[_TF_TENSORRT_VERSION] +  # For a deb installation the headers live in a fixed location. Once the +  # tar and deb installation layouts are standardized, this special case can +  # be removed. + if trt_libdir == '/usr/lib/x86_64-linux-gnu': + incPath='/usr/include/x86_64-linux-gnu' + incname='/usr/include/x86_64-linux-gnu/NvInfer.h' + else: + incPath=str(repo_ctx.path("%s/../include"%trt_libdir).realpath) + incname=incPath+'/NvInfer.h' + if len(trt_ver)>0: + origLib="%s/libnvinfer.so.%s"%(trt_libdir,trt_ver) + else: + origLib="%s/libnvinfer.so"%trt_libdir +  # Default to the major-version name; refine it from the library's SONAME +  # below when objdump is available, so the symlink matches what the dynamic +  # linker will look for. + if len(trt_ver)>0: + targetlib="lib/libnvinfer.so.%s"%(trt_ver[0]) + else: + targetlib="lib/libnvinfer.so" + objdump=repo_ctx.which("objdump") + if objdump != None: + soname=repo_ctx.execute([objdump,"-p",origLib]) + for l in soname.stdout.splitlines(): + if "SONAME" in l: + lib=l.strip().split(" ")[-1] + targetlib="lib/%s"%(lib) + + repo_ctx.symlink(origLib,targetlib) + grule=('genrule(\n name = "trtlinks",\n'+ + ' outs = [\n "%s",\n "include/NvInfer.h",\n "include/NvUtils.h",\n ],\n'%targetlib + + ' cmd="""ln -sf %s $(@D)/%s '%(origLib,targetlib) + + '&&\n ln -sf %s $(@D)/include/NvInfer.h '%(incname) + + '&&\n ln -sf %s/NvUtils.h $(@D)/include/NvUtils.h""",\n)\n'%(incPath)) + repo_ctx.template("BUILD",Label("//third_party/tensorrt:BUILD.tpl"), + {"%{tensorrt_lib}":'"%s"'%targetlib,"%{tensorrt_genrules}":grule}, + False) + repo_ctx.template("build_defs.bzl",Label("//third_party/tensorrt:build_defs.bzl.tpl"), + {"%{trt_configured}":"True"},False) + +trt_repository=repository_rule( + implementation = _trt_repo_impl, + local=True, + environ=[ + "TF_NEED_TENSORRT", + _TF_TENSORRT_VERSION, + _TENSORRT_INSTALLATION_PATH, + ], + ) diff --git a/third_party/tensorrt/build_defs.bzl.tpl b/third_party/tensorrt/build_defs.bzl.tpl new file mode 100644 index 00000000000..18f354ee5a3 --- /dev/null +++ b/third_party/tensorrt/build_defs.bzl.tpl @@ -0,0 +1,18 @@ +# -*- python -*- +""" +Template for the TensorRT build helper functions; %{trt_configured} is filled in by the repository rule. + +""" + +def is_trt_enabled(): + return %{trt_configured} + +def if_trt(if_true, if_false=[]): + # if is_trt_enabled(): + # return if_true + # return if_false + + return select({ + "@local_config_tensorrt//:trt_enabled":if_true, + "//conditions:default":if_false, + })
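To show how the pieces above fit together for a dependent target, here is a hedged BUILD-file sketch. The load path, the if_trt helper, the :trt_enabled config_setting and the @local_config_tensorrt//:tensorrt library are the ones defined by the templates above; the target name and source file are placeholders, not part of this patch.

    # BUILD sketch for a hypothetical TensorRT-dependent target.
    load("@local_config_tensorrt//:build_defs.bzl", "if_trt")

    cc_library(
        name = "trt_engine_op_kernel",                # placeholder name
        srcs = if_trt(["kernels/trt_engine_op.cc"]),  # placeholder source
        deps = if_trt([
            # Library generated from BUILD.tpl above.
            "@local_config_tensorrt//:tensorrt",
        ]) + [
            "//tensorflow/core:framework",
        ],
    )

Because if_trt() expands to a select() on @local_config_tensorrt//:trt_enabled, the TensorRT source and dependency are pulled in only when the build sets --define using_tensorrt=true (the define_values key declared in BUILD.tpl); in every other configuration the target builds without TensorRT.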