Set up TensorRT configurations for external use, and add a test.
PiperOrigin-RevId: 183347199
This commit is contained in:
parent
ffda6079ed
commit
76f6938baf
119
configure.py
119
configure.py
@ -43,6 +43,7 @@ _DEFAULT_CUDA_PATH = '/usr/local/cuda'
|
||||
_DEFAULT_CUDA_PATH_LINUX = '/opt/cuda'
|
||||
_DEFAULT_CUDA_PATH_WIN = ('C:/Program Files/NVIDIA GPU Computing '
|
||||
'Toolkit/CUDA/v%s' % _DEFAULT_CUDA_VERSION)
|
||||
_DEFAULT_TENSORRT_PATH_LINUX = '/usr/lib/x86_64-linux-gnu'
|
||||
_TF_OPENCL_VERSION = '1.2'
|
||||
_DEFAULT_COMPUTECPP_TOOLKIT_PATH = '/usr/local/computecpp'
|
||||
_DEFAULT_TRISYCL_INCLUDE_DIR = '/usr/local/triSYCL/include'
|
||||
@ -959,6 +960,119 @@ def set_tf_cudnn_version(environ_cp):
|
||||
write_action_env_to_bazelrc('TF_CUDNN_VERSION', tf_cudnn_version)
|
||||
|
||||
|
||||
def set_tf_tensorrt_install_path(environ_cp):
|
||||
"""Set TENSORRT_INSTALL_PATH and TF_TENSORRT_VERSION.
|
||||
|
||||
Adapted from code contributed by Sami Kama (https://github.com/samikama).
|
||||
|
||||
Args:
|
||||
environ_cp: copy of the os.environ.
|
||||
|
||||
Raises:
|
||||
ValueError: if this method was called under non-Linux platform.
|
||||
UserInputError: if user has provided invalid input multiple times.
|
||||
"""
|
||||
if not is_linux():
|
||||
raise ValueError('Currently TensorRT is only supported on Linux platform.')
|
||||
|
||||
# Ask user whether to add TensorRT support.
|
||||
if str(int(get_var(
|
||||
environ_cp, 'TF_NEED_TENSORRT', 'TensorRT', False))) != '1':
|
||||
return
|
||||
|
||||
for _ in range(_DEFAULT_PROMPT_ASK_ATTEMPTS):
|
||||
ask_tensorrt_path = (r'Please specify the location where TensorRT is '
|
||||
'installed. [Default is %s]:') % (
|
||||
_DEFAULT_TENSORRT_PATH_LINUX)
|
||||
trt_install_path = get_from_env_or_user_or_default(
|
||||
environ_cp, 'TENSORRT_INSTALL_PATH', ask_tensorrt_path,
|
||||
_DEFAULT_TENSORRT_PATH_LINUX)
|
||||
|
||||
# Result returned from "read" will be used unexpanded. That make "~"
|
||||
# unusable. Going through one more level of expansion to handle that.
|
||||
trt_install_path = os.path.realpath(
|
||||
os.path.expanduser(trt_install_path))
|
||||
|
||||
def find_libs(search_path):
|
||||
"""Search for libnvinfer.so in "search_path"."""
|
||||
fl = set()
|
||||
if os.path.exists(search_path) and os.path.isdir(search_path):
|
||||
fl.update([os.path.realpath(os.path.join(search_path, x))
|
||||
for x in os.listdir(search_path) if 'libnvinfer.so' in x])
|
||||
return fl
|
||||
|
||||
possible_files = find_libs(trt_install_path)
|
||||
possible_files.update(find_libs(os.path.join(trt_install_path, 'lib')))
|
||||
possible_files.update(find_libs(os.path.join(trt_install_path, 'lib64')))
|
||||
|
||||
def is_compatible(tensorrt_lib, cuda_ver, cudnn_ver):
|
||||
"""Check the compatibility between tensorrt and cudnn/cudart libraries."""
|
||||
ldd_bin = which('ldd') or '/usr/bin/ldd'
|
||||
ldd_out = run_shell([ldd_bin, tensorrt_lib]).split(os.linesep)
|
||||
cudnn_pattern = re.compile('.*libcudnn.so\\.?(.*) =>.*$')
|
||||
cuda_pattern = re.compile('.*libcudart.so\\.?(.*) =>.*$')
|
||||
cudnn = None
|
||||
cudart = None
|
||||
for line in ldd_out:
|
||||
if 'libcudnn.so' in line:
|
||||
cudnn = cudnn_pattern.search(line)
|
||||
elif 'libcudart.so' in line:
|
||||
cudart = cuda_pattern.search(line)
|
||||
if cudnn and len(cudnn.group(1)):
|
||||
cudnn = convert_version_to_int(cudnn.group(1))
|
||||
if cudart and len(cudart.group(1)):
|
||||
cudart = convert_version_to_int(cudart.group(1))
|
||||
return (cudnn == cudnn_ver) and (cudart == cuda_ver)
|
||||
|
||||
cuda_ver = convert_version_to_int(environ_cp['TF_CUDA_VERSION'])
|
||||
cudnn_ver = convert_version_to_int(environ_cp['TF_CUDNN_VERSION'])
|
||||
nvinfer_pattern = re.compile('.*libnvinfer.so.?(.*)$')
|
||||
highest_ver = [0, None, None]
|
||||
|
||||
for lib_file in possible_files:
|
||||
if is_compatible(lib_file, cuda_ver, cudnn_ver):
|
||||
ver_str = nvinfer_pattern.search(lib_file).group(1)
|
||||
ver = convert_version_to_int(ver_str) if len(ver_str) else 0
|
||||
if ver > highest_ver[0]:
|
||||
highest_ver = [ver, ver_str, lib_file]
|
||||
if highest_ver[1] is not None:
|
||||
trt_install_path = os.path.dirname(highest_ver[2])
|
||||
tf_tensorrt_version = highest_ver[1]
|
||||
break
|
||||
|
||||
# Try another alternative from ldconfig.
|
||||
ldconfig_bin = which('ldconfig') or '/sbin/ldconfig'
|
||||
ldconfig_output = run_shell([ldconfig_bin, '-p'])
|
||||
search_result = re.search(
|
||||
'.*libnvinfer.so\\.?([0-9.]*).* => (.*)', ldconfig_output)
|
||||
if search_result:
|
||||
libnvinfer_path_from_ldconfig = search_result.group(2)
|
||||
if os.path.exists(libnvinfer_path_from_ldconfig):
|
||||
if is_compatible(libnvinfer_path_from_ldconfig, cuda_ver, cudnn_ver):
|
||||
trt_install_path = os.path.dirname(libnvinfer_path_from_ldconfig)
|
||||
tf_tensorrt_version = search_result.group(1)
|
||||
break
|
||||
|
||||
# Reset and Retry
|
||||
print('Invalid path to TensorRT. None of the following files can be found:')
|
||||
print(trt_install_path)
|
||||
print(os.path.join(trt_install_path, 'lib'))
|
||||
print(os.path.join(trt_install_path, 'lib64'))
|
||||
if search_result:
|
||||
print(libnvinfer_path_from_ldconfig)
|
||||
|
||||
else:
|
||||
raise UserInputError('Invalid TF_TENSORRT setting was provided %d '
|
||||
'times in a row. Assuming to be a scripting mistake.' %
|
||||
_DEFAULT_PROMPT_ASK_ATTEMPTS)
|
||||
|
||||
# Set TENSORRT_INSTALL_PATH and TF_TENSORRT_VERSION
|
||||
environ_cp['TENSORRT_INSTALL_PATH'] = trt_install_path
|
||||
write_action_env_to_bazelrc('TENSORRT_INSTALL_PATH', trt_install_path)
|
||||
environ_cp['TF_TENSORRT_VERSION'] = tf_tensorrt_version
|
||||
write_action_env_to_bazelrc('TF_TENSORRT_VERSION', tf_tensorrt_version)
|
||||
|
||||
|
||||
def get_native_cuda_compute_capabilities(environ_cp):
|
||||
"""Get native cuda compute capabilities.
|
||||
|
||||
@ -1244,9 +1358,11 @@ def main():
|
||||
environ_cp['TF_NEED_COMPUTECPP'] = '0'
|
||||
environ_cp['TF_NEED_OPENCL'] = '0'
|
||||
environ_cp['TF_CUDA_CLANG'] = '0'
|
||||
environ_cp['TF_NEED_TENSORRT'] = '0'
|
||||
|
||||
if is_macos():
|
||||
environ_cp['TF_NEED_JEMALLOC'] = '0'
|
||||
environ_cp['TF_NEED_TENSORRT'] = '0'
|
||||
|
||||
set_build_var(environ_cp, 'TF_NEED_JEMALLOC', 'jemalloc as malloc',
|
||||
'with_jemalloc', True)
|
||||
@ -1278,6 +1394,8 @@ def main():
|
||||
'TF_CUDA_CONFIG_REPO' not in environ_cp):
|
||||
set_tf_cuda_version(environ_cp)
|
||||
set_tf_cudnn_version(environ_cp)
|
||||
if is_linux():
|
||||
set_tf_tensorrt_install_path(environ_cp)
|
||||
set_tf_cuda_compute_capabilities(environ_cp)
|
||||
|
||||
set_tf_cuda_clang(environ_cp)
|
||||
@ -1332,6 +1450,7 @@ def main():
|
||||
'more details.')
|
||||
config_info_line('mkl', 'Build with MKL support.')
|
||||
config_info_line('monolithic', 'Config for mostly static monolithic build.')
|
||||
config_info_line('tensorrt', 'Build with TensorRT support.')
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
|
@ -370,6 +370,14 @@ config_setting(
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
||||
|
||||
# TODO(laigd): consider removing this option and make TensorRT enabled
|
||||
# automatically when CUDA is enabled.
|
||||
config_setting(
|
||||
name = "with_tensorrt_support",
|
||||
values = {"define": "with_tensorrt_support=true"},
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
||||
|
||||
package_group(
|
||||
name = "internal",
|
||||
packages = [
|
||||
@ -558,6 +566,7 @@ filegroup(
|
||||
"//tensorflow/contrib/tensor_forest/proto:all_files",
|
||||
"//tensorflow/contrib/tensorboard:all_files",
|
||||
"//tensorflow/contrib/tensorboard/db:all_files",
|
||||
"//tensorflow/contrib/tensorrt:all_files",
|
||||
"//tensorflow/contrib/testing:all_files",
|
||||
"//tensorflow/contrib/text:all_files",
|
||||
"//tensorflow/contrib/tfprof:all_files",
|
||||
|
45
tensorflow/contrib/tensorrt/BUILD
Normal file
45
tensorflow/contrib/tensorrt/BUILD
Normal file
@ -0,0 +1,45 @@
|
||||
# Description:
|
||||
# Wrap NVIDIA TensorRT (http://developer.nvidia.com/tensorrt) with tensorflow.
|
||||
# APIs are meant to change over time.
|
||||
|
||||
package(default_visibility = ["//tensorflow:__subpackages__"])
|
||||
|
||||
licenses(["notice"]) # Apache 2.0
|
||||
|
||||
exports_files(["LICENSE"])
|
||||
|
||||
load("//tensorflow:tensorflow.bzl", "tf_cuda_cc_test")
|
||||
load(
|
||||
"@local_config_tensorrt//:build_defs.bzl",
|
||||
"if_tensorrt",
|
||||
)
|
||||
|
||||
tf_cuda_cc_test(
|
||||
name = "tensorrt_test_cc",
|
||||
size = "small",
|
||||
srcs = ["tensorrt_test.cc"],
|
||||
tags = [
|
||||
"manual",
|
||||
"notap",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
] + if_tensorrt([
|
||||
"@local_config_cuda//cuda:cuda_headers",
|
||||
"@local_config_tensorrt//:nv_infer",
|
||||
]),
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "all_files",
|
||||
srcs = glob(
|
||||
["**/*"],
|
||||
exclude = [
|
||||
"**/METADATA",
|
||||
"**/OWNERS",
|
||||
],
|
||||
),
|
||||
visibility = ["//tensorflow:__subpackages__"],
|
||||
)
|
159
tensorflow/contrib/tensorrt/tensorrt_test.cc
Normal file
159
tensorflow/contrib/tensorrt/tensorrt_test.cc
Normal file
@ -0,0 +1,159 @@
|
||||
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
#if GOOGLE_TENSORRT
|
||||
#include "cuda/include/cuda.h"
|
||||
#include "cuda/include/cuda_runtime_api.h"
|
||||
#include "tensorrt/include/NvInfer.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
class Logger : public nvinfer1::ILogger {
|
||||
public:
|
||||
void log(nvinfer1::ILogger::Severity severity, const char* msg) override {
|
||||
switch (severity) {
|
||||
case Severity::kINFO:
|
||||
LOG(INFO) << msg;
|
||||
break;
|
||||
case Severity::kWARNING:
|
||||
LOG(WARNING) << msg;
|
||||
break;
|
||||
case Severity::kINTERNAL_ERROR:
|
||||
case Severity::kERROR:
|
||||
LOG(ERROR) << msg;
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
class ScopedWeights {
|
||||
public:
|
||||
ScopedWeights(float value) : value_(value) {
|
||||
w.type = nvinfer1::DataType::kFLOAT;
|
||||
w.values = &value_;
|
||||
w.count = 1;
|
||||
}
|
||||
const nvinfer1::Weights& get() { return w; }
|
||||
|
||||
private:
|
||||
float value_;
|
||||
nvinfer1::Weights w;
|
||||
};
|
||||
|
||||
const char* kInputTensor = "input";
|
||||
const char* kOutputTensor = "output";
|
||||
|
||||
// Creates a network to compute y=2x+3.
|
||||
nvinfer1::IHostMemory* CreateNetwork() {
|
||||
Logger logger;
|
||||
nvinfer1::IBuilder* builder = nvinfer1::createInferBuilder(logger);
|
||||
ScopedWeights weights(2.0);
|
||||
ScopedWeights bias(3.0);
|
||||
|
||||
nvinfer1::INetworkDefinition* network = builder->createNetwork();
|
||||
// Add the input.
|
||||
auto input = network->addInput(kInputTensor, nvinfer1::DataType::kFLOAT,
|
||||
nvinfer1::DimsCHW{1, 1, 1});
|
||||
EXPECT_NE(input, nullptr);
|
||||
// Add the hidden layer.
|
||||
auto layer = network->addFullyConnected(*input, 1, weights.get(), bias.get());
|
||||
EXPECT_NE(layer, nullptr);
|
||||
// Mark the output.
|
||||
auto output = layer->getOutput(0);
|
||||
output->setName(kOutputTensor);
|
||||
network->markOutput(*output);
|
||||
// Build the engine
|
||||
builder->setMaxBatchSize(1);
|
||||
builder->setMaxWorkspaceSize(1 << 10);
|
||||
auto engine = builder->buildCudaEngine(*network);
|
||||
EXPECT_NE(engine, nullptr);
|
||||
// Serialize the engine to create a model, then close everything.
|
||||
nvinfer1::IHostMemory* model = engine->serialize();
|
||||
network->destroy();
|
||||
engine->destroy();
|
||||
builder->destroy();
|
||||
return model;
|
||||
}
|
||||
|
||||
// Executes the network.
|
||||
void Execute(nvinfer1::IExecutionContext& context, const float* input,
|
||||
float* output) {
|
||||
const nvinfer1::ICudaEngine& engine = context.getEngine();
|
||||
|
||||
// We have two bindings: input and output.
|
||||
ASSERT_EQ(engine.getNbBindings(), 2);
|
||||
const int input_index = engine.getBindingIndex(kInputTensor);
|
||||
const int output_index = engine.getBindingIndex(kOutputTensor);
|
||||
|
||||
// Create GPU buffers and a stream
|
||||
void* buffers[2];
|
||||
ASSERT_EQ(0, cudaMalloc(&buffers[input_index], sizeof(float)));
|
||||
ASSERT_EQ(0, cudaMalloc(&buffers[output_index], sizeof(float)));
|
||||
cudaStream_t stream;
|
||||
ASSERT_EQ(0, cudaStreamCreate(&stream));
|
||||
|
||||
// Copy the input to the GPU, execute the network, and copy the output back.
|
||||
//
|
||||
// Note that since the host buffer was not created as pinned memory, these
|
||||
// async copies are turned into sync copies. So the following synchronization
|
||||
// could be removed.
|
||||
ASSERT_EQ(0, cudaMemcpyAsync(buffers[input_index], input, sizeof(float),
|
||||
cudaMemcpyHostToDevice, stream));
|
||||
context.enqueue(1, buffers, stream, nullptr);
|
||||
ASSERT_EQ(0, cudaMemcpyAsync(output, buffers[output_index], sizeof(float),
|
||||
cudaMemcpyDeviceToHost, stream));
|
||||
cudaStreamSynchronize(stream);
|
||||
|
||||
// Release the stream and the buffers
|
||||
cudaStreamDestroy(stream);
|
||||
ASSERT_EQ(0, cudaFree(buffers[input_index]));
|
||||
ASSERT_EQ(0, cudaFree(buffers[output_index]));
|
||||
}
|
||||
|
||||
TEST(TensorrtTest, BasicFunctions) {
|
||||
// Create the network model.
|
||||
nvinfer1::IHostMemory* model = CreateNetwork();
|
||||
// Use the model to create an engine and then an execution context.
|
||||
Logger logger;
|
||||
nvinfer1::IRuntime* runtime = nvinfer1::createInferRuntime(logger);
|
||||
nvinfer1::ICudaEngine* engine =
|
||||
runtime->deserializeCudaEngine(model->data(), model->size(), nullptr);
|
||||
model->destroy();
|
||||
nvinfer1::IExecutionContext* context = engine->createExecutionContext();
|
||||
|
||||
// Execute the network.
|
||||
float input = 1234;
|
||||
float output;
|
||||
Execute(*context, &input, &output);
|
||||
EXPECT_EQ(output, input * 2 + 3);
|
||||
|
||||
// Destroy the engine.
|
||||
context->destroy();
|
||||
engine->destroy();
|
||||
runtime->destroy();
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // GOOGLE_TENSORRT
|
||||
#endif // GOOGLE_CUDA
|
@ -10,6 +10,10 @@ load(
|
||||
"tf_additional_xla_deps_py",
|
||||
"if_static",
|
||||
)
|
||||
load(
|
||||
"@local_config_tensorrt//:build_defs.bzl",
|
||||
"if_tensorrt",
|
||||
)
|
||||
load(
|
||||
"@local_config_cuda//cuda:build_defs.bzl",
|
||||
"if_cuda",
|
||||
@ -197,6 +201,7 @@ def tf_copts(android_optimization_level_override="-O2", is_external=False):
|
||||
"-fno-exceptions",
|
||||
"-ftemplate-depth=900"])
|
||||
+ if_cuda(["-DGOOGLE_CUDA=1"])
|
||||
+ if_tensorrt(["-DGOOGLE_TENSORRT=1"])
|
||||
+ if_mkl(["-DINTEL_MKL=1", "-DEIGEN_USE_VML", "-fopenmp",])
|
||||
+ if_android_arm(["-mfpu=neon"])
|
||||
+ if_linux_x86_64(["-msse3"])
|
||||
@ -861,9 +866,11 @@ def tf_cuda_library(deps=None, cuda_deps=None, copts=tf_copts(), **kwargs):
|
||||
|
||||
When the library is built with --config=cuda:
|
||||
|
||||
- both deps and cuda_deps are used as dependencies
|
||||
- the cuda runtime is added as a dependency (if necessary)
|
||||
- The library additionally passes -DGOOGLE_CUDA=1 to the list of copts
|
||||
- Both deps and cuda_deps are used as dependencies.
|
||||
- The cuda runtime is added as a dependency (if necessary).
|
||||
- The library additionally passes -DGOOGLE_CUDA=1 to the list of copts.
|
||||
- In addition, when the library is also built with TensorRT enabled, it
|
||||
additionally passes -DGOOGLE_TENSORRT=1 to the list of copts.
|
||||
|
||||
Args:
|
||||
- cuda_deps: BUILD dependencies which will be linked if and only if:
|
||||
@ -882,7 +889,8 @@ def tf_cuda_library(deps=None, cuda_deps=None, copts=tf_copts(), **kwargs):
|
||||
clean_dep("//tensorflow/core:cuda"),
|
||||
"@local_config_cuda//cuda:cuda_headers"
|
||||
]),
|
||||
copts=copts + if_cuda(["-DGOOGLE_CUDA=1"]) + if_mkl(["-DINTEL_MKL=1"]),
|
||||
copts=(copts + if_cuda(["-DGOOGLE_CUDA=1"]) + if_mkl(["-DINTEL_MKL=1"]) +
|
||||
if_tensorrt(["-DGOOGLE_TENSORRT=1"])),
|
||||
**kwargs)
|
||||
|
||||
register_extension_info(
|
||||
|
@ -1,6 +1,7 @@
|
||||
# TensorFlow external dependencies that can be loaded in WORKSPACE files.
|
||||
|
||||
load("//third_party/gpus:cuda_configure.bzl", "cuda_configure")
|
||||
load("//third_party/tensorrt:tensorrt_configure.bzl", "tensorrt_configure")
|
||||
load("//third_party/mkl:build_defs.bzl", "mkl_repository")
|
||||
load("//third_party/git:git_configure.bzl", "git_configure")
|
||||
load("//third_party/py:python_configure.bzl", "python_configure")
|
||||
@ -68,6 +69,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""):
|
||||
check_bazel_version_at_least("0.5.4")
|
||||
clang6_configure(name="local_config_clang6")
|
||||
cuda_configure(name="local_config_cuda")
|
||||
tensorrt_configure(name="local_config_tensorrt")
|
||||
git_configure(name="local_config_git")
|
||||
sycl_configure(name="local_config_sycl")
|
||||
python_configure(name="local_config_python")
|
||||
|
100
third_party/gpus/cuda_configure.bzl
vendored
100
third_party/gpus/cuda_configure.bzl
vendored
@ -236,7 +236,7 @@ def _cudnn_install_basedir(repository_ctx):
|
||||
return cudnn_install_path
|
||||
|
||||
|
||||
def _matches_version(environ_version, detected_version):
|
||||
def matches_version(environ_version, detected_version):
|
||||
"""Checks whether the user-specified version matches the detected version.
|
||||
|
||||
This function performs a weak matching so that if the user specifies only the
|
||||
@ -317,7 +317,7 @@ def _cuda_version(repository_ctx, cuda_toolkit_path, cpu_value):
|
||||
environ_version = ""
|
||||
if _TF_CUDA_VERSION in repository_ctx.os.environ:
|
||||
environ_version = repository_ctx.os.environ[_TF_CUDA_VERSION].strip()
|
||||
if environ_version and not _matches_version(environ_version, full_version):
|
||||
if environ_version and not matches_version(environ_version, full_version):
|
||||
auto_configure_fail(
|
||||
("CUDA version detected from nvcc (%s) does not match " +
|
||||
"TF_CUDA_VERSION (%s)") % (full_version, environ_version))
|
||||
@ -338,35 +338,49 @@ _DEFINE_CUDNN_MINOR = "#define CUDNN_MINOR"
|
||||
_DEFINE_CUDNN_PATCHLEVEL = "#define CUDNN_PATCHLEVEL"
|
||||
|
||||
|
||||
def _find_cuda_define(repository_ctx, cudnn_header_dir, define):
|
||||
"""Returns the value of a #define in cudnn.h
|
||||
def find_cuda_define(repository_ctx, header_dir, header_file, define):
|
||||
"""Returns the value of a #define in a header file.
|
||||
|
||||
Greps through cudnn.h and returns the value of the specified #define. If the
|
||||
#define is not found, then raise an error.
|
||||
Greps through a header file and returns the value of the specified #define.
|
||||
If the #define is not found, then raise an error.
|
||||
|
||||
Args:
|
||||
repository_ctx: The repository context.
|
||||
cudnn_header_dir: The directory containing the cuDNN header.
|
||||
header_dir: The directory containing the header file.
|
||||
header_file: The header file name.
|
||||
define: The #define to search for.
|
||||
|
||||
Returns:
|
||||
The value of the #define found in cudnn.h.
|
||||
The value of the #define found in the header.
|
||||
"""
|
||||
# Confirm location of cudnn.h and grep for the line defining CUDNN_MAJOR.
|
||||
cudnn_h_path = repository_ctx.path("%s/cudnn.h" % cudnn_header_dir)
|
||||
if not cudnn_h_path.exists:
|
||||
auto_configure_fail("Cannot find cudnn.h at %s" % str(cudnn_h_path))
|
||||
result = repository_ctx.execute(["grep", "--color=never", "-E", define, str(cudnn_h_path)])
|
||||
# Confirm location of the header and grep for the line defining the macro.
|
||||
h_path = repository_ctx.path("%s/%s" % (header_dir, header_file))
|
||||
if not h_path.exists:
|
||||
auto_configure_fail("Cannot find %s at %s" % (header_file, str(h_path)))
|
||||
result = repository_ctx.execute(
|
||||
# Grep one more lines as some #defines are splitted into two lines.
|
||||
["grep", "--color=never", "-A1", "-E", define, str(h_path)])
|
||||
if result.stderr:
|
||||
auto_configure_fail("Error reading %s: %s" %
|
||||
(result.stderr, str(cudnn_h_path)))
|
||||
auto_configure_fail("Error reading %s: %s" % (str(h_path), result.stderr))
|
||||
|
||||
# Parse the cuDNN major version from the line defining CUDNN_MAJOR
|
||||
lines = result.stdout.splitlines()
|
||||
if len(lines) == 0 or lines[0].find(define) == -1:
|
||||
# Parse the version from the line defining the macro.
|
||||
if result.stdout.find(define) == -1:
|
||||
auto_configure_fail("Cannot find line containing '%s' in %s" %
|
||||
(define, str(cudnn_h_path)))
|
||||
return lines[0].replace(define, "").strip()
|
||||
(define, h_path))
|
||||
version = result.stdout
|
||||
# Remove the new line and '\' character if any.
|
||||
version = version.replace("\\", " ")
|
||||
version = version.replace("\n", " ")
|
||||
version = version.replace(define, "").lstrip()
|
||||
# Remove the code after the version number.
|
||||
version_end = version.find(" ")
|
||||
if version_end != -1:
|
||||
if version_end == 0:
|
||||
auto_configure_fail(
|
||||
"Cannot extract the version from line containing '%s' in %s" %
|
||||
(define, str(h_path)))
|
||||
version = version[:version_end].strip()
|
||||
return version
|
||||
|
||||
|
||||
def _cudnn_version(repository_ctx, cudnn_install_basedir, cpu_value):
|
||||
@ -382,12 +396,12 @@ def _cudnn_version(repository_ctx, cudnn_install_basedir, cpu_value):
|
||||
"""
|
||||
cudnn_header_dir = _find_cudnn_header_dir(repository_ctx,
|
||||
cudnn_install_basedir)
|
||||
major_version = _find_cuda_define(repository_ctx, cudnn_header_dir,
|
||||
_DEFINE_CUDNN_MAJOR)
|
||||
minor_version = _find_cuda_define(repository_ctx, cudnn_header_dir,
|
||||
_DEFINE_CUDNN_MINOR)
|
||||
patch_version = _find_cuda_define(repository_ctx, cudnn_header_dir,
|
||||
_DEFINE_CUDNN_PATCHLEVEL)
|
||||
major_version = find_cuda_define(
|
||||
repository_ctx, cudnn_header_dir, "cudnn.h", _DEFINE_CUDNN_MAJOR)
|
||||
minor_version = find_cuda_define(
|
||||
repository_ctx, cudnn_header_dir, "cudnn.h", _DEFINE_CUDNN_MINOR)
|
||||
patch_version = find_cuda_define(
|
||||
repository_ctx, cudnn_header_dir, "cudnn.h", _DEFINE_CUDNN_PATCHLEVEL)
|
||||
full_version = "%s.%s.%s" % (major_version, minor_version, patch_version)
|
||||
|
||||
# Check whether TF_CUDNN_VERSION was set by the user and fail if it does not
|
||||
@ -395,7 +409,7 @@ def _cudnn_version(repository_ctx, cudnn_install_basedir, cpu_value):
|
||||
environ_version = ""
|
||||
if _TF_CUDNN_VERSION in repository_ctx.os.environ:
|
||||
environ_version = repository_ctx.os.environ[_TF_CUDNN_VERSION].strip()
|
||||
if environ_version and not _matches_version(environ_version, full_version):
|
||||
if environ_version and not matches_version(environ_version, full_version):
|
||||
cudnn_h_path = repository_ctx.path("%s/include/cudnn.h" %
|
||||
cudnn_install_basedir)
|
||||
auto_configure_fail(
|
||||
@ -427,7 +441,7 @@ def _compute_capabilities(repository_ctx):
|
||||
return capabilities
|
||||
|
||||
|
||||
def _cpu_value(repository_ctx):
|
||||
def get_cpu_value(repository_ctx):
|
||||
"""Returns the name of the host operating system.
|
||||
|
||||
Args:
|
||||
@ -447,7 +461,7 @@ def _cpu_value(repository_ctx):
|
||||
|
||||
def _is_windows(repository_ctx):
|
||||
"""Returns true if the host operating system is windows."""
|
||||
return _cpu_value(repository_ctx) == "Windows"
|
||||
return get_cpu_value(repository_ctx) == "Windows"
|
||||
|
||||
def _lib_name(lib, cpu_value, version="", static=False):
|
||||
"""Constructs the platform-specific name of a library.
|
||||
@ -582,11 +596,8 @@ def _find_libs(repository_ctx, cuda_config):
|
||||
cuda_config: The CUDA config as returned by _get_cuda_config
|
||||
|
||||
Returns:
|
||||
Map of library names to structs of filename and path as returned by
|
||||
_find_cuda_lib and _find_cupti_lib.
|
||||
Map of library names to structs of filename and path.
|
||||
"""
|
||||
cudnn_version = cuda_config.cudnn_version
|
||||
cudnn_ext = ".%s" % cudnn_version if cudnn_version else ""
|
||||
cpu_value = cuda_config.cpu_value
|
||||
return {
|
||||
"cuda": _find_cuda_lib("cuda", repository_ctx, cpu_value, cuda_config.cuda_toolkit_path),
|
||||
@ -611,7 +622,7 @@ def _find_libs(repository_ctx, cuda_config):
|
||||
"cudnn": _find_cuda_lib(
|
||||
"cudnn", repository_ctx, cpu_value, cuda_config.cudnn_install_basedir,
|
||||
cuda_config.cudnn_version),
|
||||
"cupti": _find_cupti_lib(repository_ctx, cuda_config),
|
||||
"cupti": _find_cupti_lib(repository_ctx, cuda_config)
|
||||
}
|
||||
|
||||
|
||||
@ -654,7 +665,7 @@ def _get_cuda_config(repository_ctx):
|
||||
compute_capabilities: A list of the system's CUDA compute capabilities.
|
||||
cpu_value: The name of the host operating system.
|
||||
"""
|
||||
cpu_value = _cpu_value(repository_ctx)
|
||||
cpu_value = get_cpu_value(repository_ctx)
|
||||
cuda_toolkit_path = _cuda_toolkit_path(repository_ctx)
|
||||
cuda_version = _cuda_version(repository_ctx, cuda_toolkit_path, cpu_value)
|
||||
cudnn_install_basedir = _cudnn_install_basedir(repository_ctx)
|
||||
@ -712,13 +723,13 @@ error_gpu_disabled()
|
||||
|
||||
|
||||
def _create_dummy_repository(repository_ctx):
|
||||
cpu_value = _cpu_value(repository_ctx)
|
||||
cpu_value = get_cpu_value(repository_ctx)
|
||||
|
||||
# Set up BUILD file for cuda/.
|
||||
_tpl(repository_ctx, "cuda:build_defs.bzl",
|
||||
{
|
||||
"%{cuda_is_configured}": "False",
|
||||
"%{cuda_extra_copts}": "[]"
|
||||
"%{cuda_extra_copts}": "[]",
|
||||
})
|
||||
_tpl(repository_ctx, "cuda:BUILD",
|
||||
{
|
||||
@ -805,7 +816,7 @@ def _norm_path(path):
|
||||
return path
|
||||
|
||||
|
||||
def _symlink_genrule_for_dir(repository_ctx, src_dir, dest_dir, genrule_name,
|
||||
def symlink_genrule_for_dir(repository_ctx, src_dir, dest_dir, genrule_name,
|
||||
src_files = [], dest_files = []):
|
||||
"""Returns a genrule to symlink(or copy if on Windows) a set of files.
|
||||
|
||||
@ -913,11 +924,11 @@ def _create_local_cuda_repository(repository_ctx):
|
||||
# cuda_toolkit_path
|
||||
cuda_toolkit_path = cuda_config.cuda_toolkit_path
|
||||
cuda_include_path = cuda_toolkit_path + "/include"
|
||||
genrules = [_symlink_genrule_for_dir(repository_ctx,
|
||||
genrules = [symlink_genrule_for_dir(repository_ctx,
|
||||
cuda_include_path, "cuda/include", "cuda-include")]
|
||||
genrules.append(_symlink_genrule_for_dir(repository_ctx,
|
||||
genrules.append(symlink_genrule_for_dir(repository_ctx,
|
||||
cuda_toolkit_path + "/nvvm", "cuda/nvvm", "cuda-nvvm"))
|
||||
genrules.append(_symlink_genrule_for_dir(repository_ctx,
|
||||
genrules.append(symlink_genrule_for_dir(repository_ctx,
|
||||
cuda_toolkit_path + "/extras/CUPTI/include",
|
||||
"cuda/extras/CUPTI/include", "cuda-extras"))
|
||||
|
||||
@ -927,15 +938,15 @@ def _create_local_cuda_repository(repository_ctx):
|
||||
for lib in cuda_libs.values():
|
||||
cuda_lib_src.append(lib.path)
|
||||
cuda_lib_dest.append("cuda/lib/" + lib.file_name)
|
||||
genrules.append(_symlink_genrule_for_dir(repository_ctx, None, "", "cuda-lib",
|
||||
genrules.append(symlink_genrule_for_dir(repository_ctx, None, "", "cuda-lib",
|
||||
cuda_lib_src, cuda_lib_dest))
|
||||
|
||||
# Set up the symbolic links for cudnn if cudnn was was not installed to
|
||||
# Set up the symbolic links for cudnn if cndnn was not installed to
|
||||
# CUDA_TOOLKIT_PATH.
|
||||
included_files = _read_dir(repository_ctx, cuda_include_path).replace(
|
||||
cuda_include_path, '').splitlines()
|
||||
if '/cudnn.h' not in included_files:
|
||||
genrules.append(_symlink_genrule_for_dir(repository_ctx, None,
|
||||
genrules.append(symlink_genrule_for_dir(repository_ctx, None,
|
||||
"cuda/include/", "cudnn-include", [cudnn_header_dir + "/cudnn.h"],
|
||||
["cudnn.h"]))
|
||||
else:
|
||||
@ -952,7 +963,6 @@ def _create_local_cuda_repository(repository_ctx):
|
||||
"%{cuda_is_configured}": "True",
|
||||
"%{cuda_extra_copts}": _compute_cuda_extra_copts(
|
||||
repository_ctx, cuda_config.compute_capabilities),
|
||||
|
||||
})
|
||||
_tpl(repository_ctx, "cuda:BUILD",
|
||||
{
|
||||
|
0
third_party/tensorrt/BUILD
vendored
Normal file
0
third_party/tensorrt/BUILD
vendored
Normal file
67
third_party/tensorrt/BUILD.tpl
vendored
Normal file
67
third_party/tensorrt/BUILD.tpl
vendored
Normal file
@ -0,0 +1,67 @@
|
||||
# NVIDIA TensorRT
|
||||
# A high-performance deep learning inference optimizer and runtime.
|
||||
|
||||
licenses(["notice"])
|
||||
|
||||
load("@local_config_cuda//cuda:build_defs.bzl", "cuda_default_copts")
|
||||
|
||||
package(default_visibility = ["//visibility:public"])
|
||||
|
||||
cc_library(
|
||||
name = "tensorrt_headers",
|
||||
hdrs = [%{tensorrt_headers}],
|
||||
includes = [
|
||||
"include",
|
||||
],
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "nv_infer",
|
||||
srcs = [%{nv_infer}],
|
||||
data = [%{nv_infer}],
|
||||
includes = [
|
||||
"include",
|
||||
],
|
||||
copts= cuda_default_copts(),
|
||||
deps = [
|
||||
"@local_config_cuda//cuda:cuda",
|
||||
":tensorrt_headers",
|
||||
],
|
||||
linkstatic = 1,
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "nv_infer_plugin",
|
||||
srcs = [%{nv_infer_plugin}],
|
||||
data = [%{nv_infer_plugin}],
|
||||
includes = [
|
||||
"include",
|
||||
],
|
||||
copts= cuda_default_copts(),
|
||||
deps = [
|
||||
"@local_config_cuda//cuda:cuda",
|
||||
":nv_infer",
|
||||
":tensorrt_headers",
|
||||
],
|
||||
linkstatic = 1,
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "nv_parsers",
|
||||
srcs = [%{nv_parsers}],
|
||||
data = [%{nv_parsers}],
|
||||
includes = [
|
||||
"include",
|
||||
],
|
||||
copts= cuda_default_copts(),
|
||||
deps = [
|
||||
":tensorrt_headers",
|
||||
],
|
||||
linkstatic = 1,
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
||||
|
||||
%{tensorrt_genrules}
|
7
third_party/tensorrt/build_defs.bzl.tpl
vendored
Normal file
7
third_party/tensorrt/build_defs.bzl.tpl
vendored
Normal file
@ -0,0 +1,7 @@
|
||||
# Build configurations for TensorRT.
|
||||
|
||||
def if_tensorrt(if_true, if_false=[]):
|
||||
"""Tests whether TensorRT was enabled during the configure process."""
|
||||
if %{tensorrt_is_configured}:
|
||||
return if_true
|
||||
return if_false
|
224
third_party/tensorrt/tensorrt_configure.bzl
vendored
Normal file
224
third_party/tensorrt/tensorrt_configure.bzl
vendored
Normal file
@ -0,0 +1,224 @@
|
||||
# -*- Python -*-
|
||||
"""Repository rule for TensorRT configuration.
|
||||
|
||||
`tensorrt_configure` depends on the following environment variables:
|
||||
|
||||
* `TF_TENSORRT_VERSION`: The TensorRT libnvinfer version.
|
||||
* `TENSORRT_INSTALL_PATH`: The installation path of the TensorRT library.
|
||||
"""
|
||||
|
||||
load(
|
||||
"//third_party/gpus:cuda_configure.bzl",
|
||||
"auto_configure_fail",
|
||||
"get_cpu_value",
|
||||
"find_cuda_define",
|
||||
"matches_version",
|
||||
"symlink_genrule_for_dir",
|
||||
)
|
||||
|
||||
_TENSORRT_INSTALL_PATH = "TENSORRT_INSTALL_PATH"
|
||||
_TF_TENSORRT_VERSION = "TF_TENSORRT_VERSION"
|
||||
|
||||
_TF_TENSORRT_LIBS = ["nvinfer", "nvinfer_plugin", "nvparsers"]
|
||||
_TF_TENSORRT_HEADERS = [
|
||||
"NvInfer.h", "NvInferPlugin.h", "NvCaffeParser.h", "NvUffParser.h",
|
||||
"NvUtils.h"
|
||||
]
|
||||
|
||||
_DEFINE_TENSORRT_SONAME_MAJOR = "#define NV_TENSORRT_SONAME_MAJOR"
|
||||
_DEFINE_TENSORRT_SONAME_MINOR = "#define NV_TENSORRT_SONAME_MINOR"
|
||||
_DEFINE_TENSORRT_SONAME_PATCH = "#define NV_TENSORRT_SONAME_PATCH"
|
||||
|
||||
|
||||
def _headers_exist(repository_ctx, path):
|
||||
"""Returns whether all TensorRT header files could be found in 'path'.
|
||||
|
||||
Args:
|
||||
repository_ctx: The repository context.
|
||||
path: The TensorRT include path to check.
|
||||
|
||||
Returns:
|
||||
True if all TensorRT header files can be found in the path.
|
||||
"""
|
||||
for h in _TF_TENSORRT_HEADERS:
|
||||
if not repository_ctx.path("%s/%s" % (path, h)).exists:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def _find_trt_header_dir(repository_ctx, trt_install_path):
|
||||
"""Returns the path to the directory containing headers of TensorRT.
|
||||
|
||||
Args:
|
||||
repository_ctx: The repository context.
|
||||
trt_install_path: The TensorRT library install directory.
|
||||
|
||||
Returns:
|
||||
The path of the directory containing the TensorRT header.
|
||||
"""
|
||||
if trt_install_path == "/usr/lib/x86_64-linux-gnu":
|
||||
path = "/usr/include/x86_64-linux-gnu"
|
||||
if _headers_exist(repository_ctx, path):
|
||||
return path
|
||||
path = str(repository_ctx.path("%s/../include" % trt_install_path).realpath)
|
||||
if _headers_exist(repository_ctx, path):
|
||||
return path
|
||||
auto_configure_fail(
|
||||
"Cannot find NvInfer.h with TensorRT install path %s" % trt_install_path)
|
||||
|
||||
|
||||
def _trt_lib_version(repository_ctx, trt_install_path):
|
||||
"""Detects the library (e.g. libnvinfer) version of TensorRT.
|
||||
|
||||
Args:
|
||||
repository_ctx: The repository context.
|
||||
trt_install_path: The TensorRT library install directory.
|
||||
|
||||
Returns:
|
||||
A string containing the library version of TensorRT.
|
||||
"""
|
||||
trt_header_dir = _find_trt_header_dir(repository_ctx, trt_install_path)
|
||||
major_version = find_cuda_define(repository_ctx, trt_header_dir, "NvInfer.h",
|
||||
_DEFINE_TENSORRT_SONAME_MAJOR)
|
||||
minor_version = find_cuda_define(repository_ctx, trt_header_dir, "NvInfer.h",
|
||||
_DEFINE_TENSORRT_SONAME_MINOR)
|
||||
patch_version = find_cuda_define(repository_ctx, trt_header_dir, "NvInfer.h",
|
||||
_DEFINE_TENSORRT_SONAME_PATCH)
|
||||
full_version = "%s.%s.%s" % (major_version, minor_version, patch_version)
|
||||
environ_version = repository_ctx.os.environ[_TF_TENSORRT_VERSION].strip()
|
||||
if not matches_version(environ_version, full_version):
|
||||
auto_configure_fail(
|
||||
("TensorRT library version detected from %s/%s (%s) does not match " +
|
||||
"TF_TENSORRT_VERSION (%s). To fix this rerun configure again.") %
|
||||
(trt_header_dir, "NvInfer.h", full_version, environ_version))
|
||||
return environ_version
|
||||
|
||||
|
||||
def _find_trt_libs(repository_ctx, trt_install_path, trt_lib_version):
|
||||
"""Finds the given TensorRT library on the system.
|
||||
|
||||
Adapted from code contributed by Sami Kama (https://github.com/samikama).
|
||||
|
||||
Args:
|
||||
repository_ctx: The repository context.
|
||||
trt_install_path: The TensorRT library installation directory.
|
||||
trt_lib_version: The version of TensorRT library files as returned
|
||||
by _trt_lib_version.
|
||||
|
||||
Returns:
|
||||
Map of library names to structs with the following fields:
|
||||
src_file_path: The full path to the library found on the system.
|
||||
dst_file_name: The basename of the target library.
|
||||
"""
|
||||
objdump = repository_ctx.which("objdump")
|
||||
result = {}
|
||||
for lib in _TF_TENSORRT_LIBS:
|
||||
dst_file_name = "lib%s.so.%s" % (lib, trt_lib_version)
|
||||
src_file_path = repository_ctx.path("%s/%s" % (trt_install_path,
|
||||
dst_file_name))
|
||||
if not src_file_path.exists:
|
||||
auto_configure_fail(
|
||||
"Cannot find TensorRT library %s" % str(src_file_path))
|
||||
if objdump != None:
|
||||
objdump_out = repository_ctx.execute([objdump, "-p", str(src_file_path)])
|
||||
for line in objdump_out.stdout.splitlines():
|
||||
if "SONAME" in line:
|
||||
dst_file_name = line.strip().split(" ")[-1]
|
||||
result.update({
|
||||
lib:
|
||||
struct(
|
||||
dst_file_name=dst_file_name,
|
||||
src_file_path=str(src_file_path.realpath))
|
||||
})
|
||||
return result
|
||||
|
||||
|
||||
def _tpl(repository_ctx, tpl, substitutions):
|
||||
repository_ctx.template(tpl, Label("//third_party/tensorrt:%s.tpl" % tpl),
|
||||
substitutions)
|
||||
|
||||
|
||||
def _create_dummy_repository(repository_ctx):
|
||||
"""Create a dummy TensorRT repository."""
|
||||
_tpl(repository_ctx, "build_defs.bzl", {"%{tensorrt_is_configured}": "False"})
|
||||
substitutions = {
|
||||
"%{tensorrt_genrules}": "",
|
||||
"%{tensorrt_headers}": "",
|
||||
}
|
||||
for lib in _TF_TENSORRT_LIBS:
|
||||
k = "%%{%s}" % lib.replace("nv", "nv_")
|
||||
substitutions.update({k: ""})
|
||||
_tpl(repository_ctx, "BUILD", substitutions)
|
||||
|
||||
|
||||
def _tensorrt_configure_impl(repository_ctx):
|
||||
"""Implementation of the tensorrt_configure repository rule."""
|
||||
if _TENSORRT_INSTALL_PATH not in repository_ctx.os.environ:
|
||||
_create_dummy_repository(repository_ctx)
|
||||
return
|
||||
|
||||
if (get_cpu_value(repository_ctx) != "Linux"):
|
||||
auto_configure_fail("TensorRT is supported only on Linux.")
|
||||
if _TF_TENSORRT_VERSION not in repository_ctx.os.environ:
|
||||
auto_configure_fail("TensorRT library (libnvinfer) version is not set.")
|
||||
trt_install_path = repository_ctx.os.environ[_TENSORRT_INSTALL_PATH].strip()
|
||||
if not repository_ctx.path(trt_install_path).exists:
|
||||
auto_configure_fail(
|
||||
"Cannot find TensorRT install path %s." % trt_install_path)
|
||||
|
||||
# Set up the symbolic links for the library files.
|
||||
trt_lib_version = _trt_lib_version(repository_ctx, trt_install_path)
|
||||
trt_libs = _find_trt_libs(repository_ctx, trt_install_path, trt_lib_version)
|
||||
trt_lib_src = []
|
||||
trt_lib_dest = []
|
||||
for lib in trt_libs.values():
|
||||
trt_lib_src.append(lib.src_file_path)
|
||||
trt_lib_dest.append(lib.dst_file_name)
|
||||
genrules = [
|
||||
symlink_genrule_for_dir(repository_ctx, None, "tensorrt/lib/",
|
||||
"tensorrt_lib", trt_lib_src, trt_lib_dest)
|
||||
]
|
||||
|
||||
# Set up the symbolic links for the header files.
|
||||
trt_header_dir = _find_trt_header_dir(repository_ctx, trt_install_path)
|
||||
src_files = [
|
||||
"%s/%s" % (trt_header_dir, header) for header in _TF_TENSORRT_HEADERS
|
||||
]
|
||||
dest_files = _TF_TENSORRT_HEADERS
|
||||
genrules.append(
|
||||
symlink_genrule_for_dir(repository_ctx, None, "tensorrt/include/",
|
||||
"tensorrt_include", src_files, dest_files))
|
||||
|
||||
# Set up config file.
|
||||
_tpl(repository_ctx, "build_defs.bzl", {"%{tensorrt_is_configured}": "True"})
|
||||
|
||||
# Set up BUILD file.
|
||||
substitutions = {
|
||||
"%{tensorrt_genrules}": "\n".join(genrules),
|
||||
"%{tensorrt_headers}": '":tensorrt_include"',
|
||||
}
|
||||
for lib in _TF_TENSORRT_LIBS:
|
||||
k = "%%{%s}" % lib.replace("nv", "nv_")
|
||||
v = '"tensorrt/lib/%s"' % trt_libs[lib].dst_file_name
|
||||
substitutions.update({k: v})
|
||||
_tpl(repository_ctx, "BUILD", substitutions)
|
||||
|
||||
|
||||
tensorrt_configure = repository_rule(
|
||||
implementation=_tensorrt_configure_impl,
|
||||
environ=[
|
||||
_TENSORRT_INSTALL_PATH,
|
||||
_TF_TENSORRT_VERSION,
|
||||
],
|
||||
)
|
||||
"""Detects and configures the local CUDA toolchain.
|
||||
|
||||
Add the following to your WORKSPACE FILE:
|
||||
|
||||
```python
|
||||
tensorrt_configure(name = "local_config_tensorrt")
|
||||
```
|
||||
|
||||
Args:
|
||||
name: A unique name for this workspace rule.
|
||||
"""
|
Loading…
Reference in New Issue
Block a user