Provide NVIDIA CUDA build data in metadata and API
This change is a second attempt at #38964, which was rolled back because it was fragile. First, cuda_configure.bzl templates a file with data it already pulled from get_cuda_config. gen_build_info loads that file to provide package build information within TensorFlow: from tensorflow.python.platform import build_info print(build_info.build_info) {'cuda_version': '10.2', 'cudnn_version': '7', ... } Also is exposed through tf.sysconfig.get_build_info(), a public API change. setup.py pulls build_info into package metadata. The wheel's long description ends with: TensorFlow 2.2.0 for NVIDIA GPUs was built with these platform and library versions: - NVIDIA CUDA 10.2 - NVIDIA cuDNN 7 - NVIDIA CUDA Compute Capabilities compute_30, compute_70 (etc.) I set one of the new CUDA Classifiers, and add metadata to the "platform" tag: >>> import pkginfo >>> a = pkginfo.Wheel('./tf_nightly_gpu-2.1.0-cp36-cp36m-linux_x86_64.whl') >>> a.platforms ['cuda_version:10.2', 'cudnn_version:7', ...] I'm not 100% confident this is the best way to accomplish this. It still seems odd to import like this setup.py, even though it works, even in an environment with TensorFlow installed. This method is much better than the old method as it uses data that was already gathered. It could be extended to gather tensorrt, nccl, etc. from other .bzl files, but I wanted to get feedback (and ensure this lands in 2.3) before designing something like that. Currently tested only on Linux GPU (Remote Build) for Python 3.6. I'd like to see more tests before merging. The API is the same as the earlier change. Resolves https://github.com/tensorflow/tensorflow/issues/38351. PiperOrigin-RevId: 315018663 Change-Id: Idf68a8fe4d1585164d22b5870894c879537c280d
This commit is contained in:
parent
f6db19cc34
commit
0a2db3d354
@ -96,6 +96,7 @@ tensorflow/third_party/gpus/cuda/BUILD.windows.tpl
|
||||
tensorflow/third_party/gpus/cuda/LICENSE
|
||||
tensorflow/third_party/gpus/cuda/build_defs.bzl.tpl
|
||||
tensorflow/third_party/gpus/cuda/cuda_config.h.tpl
|
||||
tensorflow/third_party/gpus/cuda/cuda_config.py.tpl
|
||||
tensorflow/third_party/gpus/cuda_configure.bzl
|
||||
tensorflow/third_party/gpus/find_cuda_config.py
|
||||
tensorflow/third_party/gpus/find_cuda_config.py.gz.base64
|
||||
|
@ -288,6 +288,7 @@ py_library(
|
||||
deps = [
|
||||
":_pywrap_util_port",
|
||||
":lib",
|
||||
":platform_build_info",
|
||||
":pywrap_tfe",
|
||||
":util",
|
||||
"//tensorflow/core:protos_all_py",
|
||||
@ -352,6 +353,24 @@ tf_py_test(
|
||||
],
|
||||
)
|
||||
|
||||
tf_py_test(
|
||||
name = "sysconfig_test",
|
||||
size = "small",
|
||||
srcs = ["platform/sysconfig_test.py"],
|
||||
data = [
|
||||
"platform/sysconfig.py",
|
||||
],
|
||||
python_version = "PY3",
|
||||
tags = [
|
||||
"no_pip",
|
||||
"no_windows",
|
||||
],
|
||||
deps = [
|
||||
":platform",
|
||||
":platform_test",
|
||||
],
|
||||
)
|
||||
|
||||
tf_py_test(
|
||||
name = "flags_test",
|
||||
size = "small",
|
||||
|
@ -603,7 +603,7 @@ def gpu_gru(inputs, init_h, kernel, recurrent_kernel, bias, mask, time_major,
|
||||
# (6 * units)
|
||||
bias = array_ops.split(K.flatten(bias), 6)
|
||||
|
||||
if build_info.is_cuda_build:
|
||||
if build_info.build_info['is_cuda_build']:
|
||||
# Note that the gate order for CuDNN is different from the canonical format.
|
||||
# canonical format is [z, r, h], whereas CuDNN is [r, z, h]. The swap need
|
||||
# to be done for kernel, recurrent_kernel, input_bias, recurrent_bias.
|
||||
@ -1365,7 +1365,7 @@ def gpu_lstm(inputs, init_h, init_c, kernel, recurrent_kernel, bias, mask,
|
||||
# so that mathematically it is same as the canonical LSTM implementation.
|
||||
full_bias = array_ops.concat((array_ops.zeros_like(bias), bias), 0)
|
||||
|
||||
if build_info.is_rocm_build:
|
||||
if build_info.build_info['is_rocm_build']:
|
||||
# ROCm MIOpen's weight sequence for LSTM is different from both canonical
|
||||
# and Cudnn format
|
||||
# MIOpen: [i, f, o, c] Cudnn/Canonical: [i, f, c, o]
|
||||
|
@ -25,8 +25,10 @@ from tensorflow.python.platform import test
|
||||
class BuildInfoTest(test.TestCase):
|
||||
|
||||
def testBuildInfo(self):
|
||||
self.assertEqual(build_info.is_rocm_build, test.is_built_with_rocm())
|
||||
self.assertEqual(build_info.is_cuda_build, test.is_built_with_cuda())
|
||||
self.assertEqual(build_info.build_info['is_rocm_build'],
|
||||
test.is_built_with_rocm())
|
||||
self.assertEqual(build_info.build_info['is_cuda_build'],
|
||||
test.is_built_with_cuda())
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
@ -20,6 +20,7 @@ from __future__ import print_function
|
||||
|
||||
import os
|
||||
|
||||
MSVCP_DLL_NAMES = "msvcp_dll_names"
|
||||
|
||||
try:
|
||||
from tensorflow.python.platform import build_info
|
||||
@ -42,9 +43,9 @@ def preload_check():
|
||||
# we load the Python extension, so that we can raise an actionable error
|
||||
# message if they are not found.
|
||||
import ctypes # pylint: disable=g-import-not-at-top
|
||||
if hasattr(build_info, "msvcp_dll_names"):
|
||||
if MSVCP_DLL_NAMES in build_info.build_info:
|
||||
missing = []
|
||||
for dll_name in build_info.msvcp_dll_names.split(","):
|
||||
for dll_name in build_info.build_info[MSVCP_DLL_NAMES].split(","):
|
||||
try:
|
||||
ctypes.WinDLL(dll_name)
|
||||
except OSError:
|
||||
|
@ -24,6 +24,7 @@ import platform as _platform
|
||||
from tensorflow.python.framework.versions import CXX11_ABI_FLAG as _CXX11_ABI_FLAG
|
||||
from tensorflow.python.framework.versions import MONOLITHIC_BUILD as _MONOLITHIC_BUILD
|
||||
from tensorflow.python.framework.versions import VERSION as _VERSION
|
||||
from tensorflow.python.platform import build_info
|
||||
from tensorflow.python.util.tf_export import tf_export
|
||||
|
||||
|
||||
@ -84,3 +85,28 @@ def get_link_flags():
|
||||
else:
|
||||
flags.append('-l:libtensorflow_framework.so.%s' % ver)
|
||||
return flags
|
||||
|
||||
|
||||
@tf_export('sysconfig.get_build_info')
|
||||
def get_build_info():
|
||||
"""Get a dictionary describing TensorFlow's build environment.
|
||||
|
||||
Values are generated when TensorFlow is compiled, and are static for each
|
||||
TensorFlow package. The return value is a dictionary with string keys such as:
|
||||
|
||||
- cuda_version
|
||||
- cudnn_version
|
||||
- is_cuda_build
|
||||
- is_rocm_build
|
||||
- msvcp_dll_names
|
||||
- nvcuda_dll_name
|
||||
- cudart_dll_name
|
||||
- cudnn_dll_name
|
||||
|
||||
Note that the actual keys and values returned by this function is subject to
|
||||
change across different versions of TensorFlow or across platforms.
|
||||
|
||||
Returns:
|
||||
A Dictionary describing TensorFlow's build environment.
|
||||
"""
|
||||
return build_info.build_info
|
||||
|
38
tensorflow/python/platform/sysconfig_test.py
Normal file
38
tensorflow/python/platform/sysconfig_test.py
Normal file
@ -0,0 +1,38 @@
|
||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python.platform import googletest
|
||||
from tensorflow.python.platform import sysconfig
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
class SysconfigTest(googletest.TestCase):
|
||||
|
||||
def test_get_build_info_works(self):
|
||||
build_info = sysconfig.get_build_info()
|
||||
self.assertIsInstance(build_info, dict)
|
||||
|
||||
def test_rocm_cuda_info_matches(self):
|
||||
build_info = sysconfig.get_build_info()
|
||||
self.assertEqual(build_info["is_rocm_build"], test.is_built_with_rocm())
|
||||
self.assertEqual(build_info["is_cuda_build"], test.is_built_with_cuda())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
googletest.main()
|
@ -2615,6 +2615,10 @@ def tf_version_info_genrule(name, out):
|
||||
arguments = "--generate \"$@\" --git_tag_override=${GIT_TAG_OVERRIDE:-}",
|
||||
)
|
||||
|
||||
def _dict_to_kv(d):
|
||||
"""Convert a dictionary to a space-joined list of key=value pairs."""
|
||||
return " " + " ".join(["%s=%s" % (k, v) for k, v in d.items()])
|
||||
|
||||
def tf_py_build_info_genrule(name, out):
|
||||
_local_genrule(
|
||||
name = name,
|
||||
@ -2622,16 +2626,16 @@ def tf_py_build_info_genrule(name, out):
|
||||
exec_tool = "//tensorflow/tools/build_info:gen_build_info",
|
||||
arguments =
|
||||
"--raw_generate \"$@\" " +
|
||||
" --is_config_cuda " + if_cuda("True", "False") +
|
||||
" --is_config_rocm " + if_rocm("True", "False") +
|
||||
" --key_value " +
|
||||
if_cuda(" cuda_version_number=${TF_CUDA_VERSION:-} cudnn_version_number=${TF_CUDNN_VERSION:-} ", "") +
|
||||
if_windows(" msvcp_dll_names=msvcp140.dll,msvcp140_1.dll ", "") +
|
||||
if_windows_cuda(" ".join([
|
||||
"nvcuda_dll_name=nvcuda.dll",
|
||||
"cudart_dll_name=cudart64_$(echo $${TF_CUDA_VERSION:-} | sed \"s/\\.//\").dll",
|
||||
"cudnn_dll_name=cudnn64_${TF_CUDNN_VERSION:-}.dll",
|
||||
]), ""),
|
||||
" --key_value" +
|
||||
" is_rocm_build=" + if_rocm("True", "False") +
|
||||
" is_cuda_build=" + if_cuda("True", "False") +
|
||||
if_windows(_dict_to_kv({
|
||||
"msvcp_dll_names": "msvcp140.dll,msvcp140_1.dll",
|
||||
}), "") + if_windows_cuda(_dict_to_kv({
|
||||
"nvcuda_dll_name": "nvcuda.dll",
|
||||
"cudart_dll_name": "cudart64_$$(echo $${TF_CUDA_VERSION:-} | sed \"s/\\.//\").dll",
|
||||
"cudnn_dll_name": "cudnn64_$${TF_CUDNN_VERSION:-}.dll",
|
||||
}), ""),
|
||||
)
|
||||
|
||||
def cc_library_with_android_deps(
|
||||
|
@ -8,6 +8,10 @@ tf_module {
|
||||
name: "MONOLITHIC_BUILD"
|
||||
mtype: "<type \'int\'>"
|
||||
}
|
||||
member_method {
|
||||
name: "get_build_info"
|
||||
argspec: "args=[], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "get_compile_flags"
|
||||
argspec: "args=[], varargs=None, keywords=None, defaults=None"
|
||||
|
@ -8,6 +8,10 @@ tf_module {
|
||||
name: "MONOLITHIC_BUILD"
|
||||
mtype: "<type \'int\'>"
|
||||
}
|
||||
member_method {
|
||||
name: "get_build_info"
|
||||
argspec: "args=[], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "get_compile_flags"
|
||||
argspec: "args=[], varargs=None, keywords=None, defaults=None"
|
||||
|
@ -14,6 +14,7 @@ py_binary(
|
||||
srcs_version = "PY2AND3",
|
||||
tags = ["no-remote-exec"],
|
||||
deps = [
|
||||
"@local_config_cuda//cuda:cuda_config_py",
|
||||
"@six_archive//:six",
|
||||
],
|
||||
)
|
||||
|
@ -1,4 +1,4 @@
|
||||
# Lint as: python2, python3
|
||||
# Lint as: python3
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
@ -22,47 +22,37 @@ import argparse
|
||||
|
||||
import six
|
||||
|
||||
# cuda.cuda is only valid in OSS
|
||||
try:
|
||||
from cuda.cuda import cuda_config # pylint: disable=g-import-not-at-top
|
||||
except ImportError:
|
||||
cuda_config = None
|
||||
|
||||
def write_build_info(filename, is_config_cuda, is_config_rocm, key_value_list):
|
||||
|
||||
def write_build_info(filename, key_value_list):
|
||||
"""Writes a Python that describes the build.
|
||||
|
||||
Args:
|
||||
filename: filename to write to.
|
||||
is_config_cuda: Whether this build is using CUDA.
|
||||
is_config_rocm: Whether this build is using ROCm.
|
||||
key_value_list: A list of "key=value" strings that will be added to the
|
||||
module as additional fields.
|
||||
|
||||
Raises:
|
||||
ValueError: If `key_value_list` includes the key "is_cuda_build", which
|
||||
would clash with one of the default fields.
|
||||
module's "build_info" dictionary as additional entries.
|
||||
"""
|
||||
module_docstring = "\"\"\"Generates a Python module containing information "
|
||||
module_docstring += "about the build.\"\"\""
|
||||
|
||||
build_config_rocm_bool = "False"
|
||||
build_config_cuda_bool = "False"
|
||||
build_info = {}
|
||||
for arg in key_value_list:
|
||||
key, value = six.ensure_str(arg).split("=")
|
||||
if value.lower() == "true":
|
||||
build_info[key] = True
|
||||
elif value.lower() == "false":
|
||||
build_info[key] = False
|
||||
else:
|
||||
build_info[key] = value
|
||||
|
||||
if is_config_rocm == "True":
|
||||
build_config_rocm_bool = "True"
|
||||
elif is_config_cuda == "True":
|
||||
build_config_cuda_bool = "True"
|
||||
|
||||
key_value_pair_stmts = []
|
||||
if key_value_list:
|
||||
for arg in key_value_list:
|
||||
key, value = six.ensure_str(arg).split("=")
|
||||
if key == "is_cuda_build":
|
||||
raise ValueError("The key \"is_cuda_build\" cannot be passed as one of "
|
||||
"the --key_value arguments.")
|
||||
if key == "is_rocm_build":
|
||||
raise ValueError("The key \"is_rocm_build\" cannot be passed as one of "
|
||||
"the --key_value arguments.")
|
||||
key_value_pair_stmts.append("%s = %r" % (key, value))
|
||||
key_value_pair_content = "\n".join(key_value_pair_stmts)
|
||||
if cuda_config:
|
||||
build_info.update(cuda_config.config)
|
||||
|
||||
contents = """
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
@ -76,33 +66,19 @@ def write_build_info(filename, is_config_cuda, is_config_rocm, key_value_list):
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
%s
|
||||
\"\"\"Auto-generated module providing information about the build.\"\"\"
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
is_rocm_build = %s
|
||||
is_cuda_build = %s
|
||||
|
||||
%s
|
||||
""" % (module_docstring, build_config_rocm_bool, build_config_cuda_bool,
|
||||
key_value_pair_content)
|
||||
build_info = {build_info}
|
||||
""".format(build_info=build_info)
|
||||
open(filename, "w").write(contents)
|
||||
|
||||
|
||||
parser = argparse.ArgumentParser(
|
||||
description="""Build info injection into the PIP package.""")
|
||||
|
||||
parser.add_argument(
|
||||
"--is_config_cuda",
|
||||
type=str,
|
||||
help="'True' for CUDA GPU builds, 'False' otherwise.")
|
||||
|
||||
parser.add_argument(
|
||||
"--is_config_rocm",
|
||||
type=str,
|
||||
help="'True' for ROCm GPU builds, 'False' otherwise.")
|
||||
|
||||
parser.add_argument("--raw_generate", type=str, help="Generate build_info.py")
|
||||
|
||||
parser.add_argument(
|
||||
@ -110,10 +86,7 @@ parser.add_argument(
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if (args.raw_generate is not None) and (args.is_config_cuda is not None) and (
|
||||
args.is_config_rocm is not None):
|
||||
write_build_info(args.raw_generate, args.is_config_cuda, args.is_config_rocm,
|
||||
args.key_value)
|
||||
if args.raw_generate:
|
||||
write_build_info(args.raw_generate, args.key_value)
|
||||
else:
|
||||
raise RuntimeError(
|
||||
"--raw_generate, --is_config_cuda and --is_config_rocm must be used")
|
||||
raise RuntimeError("--raw_generate must be used.")
|
||||
|
@ -1,3 +1,4 @@
|
||||
# lint as: python3
|
||||
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
@ -43,6 +44,8 @@ from setuptools import setup
|
||||
from setuptools.command.install import install as InstallCommandBase
|
||||
from setuptools.dist import Distribution
|
||||
|
||||
from tensorflow.python.platform import build_info
|
||||
|
||||
# This version string is semver compatible, but incompatible with pip.
|
||||
# For pip, we will remove all '-' characters from this string, and use the
|
||||
# result for pip.
|
||||
@ -70,6 +73,27 @@ REQUIRED_PACKAGES = [
|
||||
'scipy == 1.4.1',
|
||||
]
|
||||
|
||||
# Generate a footer describing the CUDA technology this release was built
|
||||
# against.
|
||||
GPU_DESCRIPTION = ''
|
||||
gpu_classifiers = []
|
||||
if build_info.build_info['is_cuda_build']:
|
||||
gpu_header = ('\nTensorFlow {} for NVIDIA GPUs was built with these '
|
||||
'platform and library versions:\n\n - ').format(_VERSION)
|
||||
bi = build_info.build_info
|
||||
desc_lines = []
|
||||
gpu_classifiers.append('Environment :: GPU :: NVIDIA CUDA :: ' +
|
||||
bi['cuda_version'])
|
||||
if 'cuda_version' in bi:
|
||||
desc_lines.append('NVIDIA CUDA ' + bi['cuda_version'])
|
||||
if 'cudnn_version' in bi:
|
||||
desc_lines.append('NVIDIA cuDNN ' + bi['cudnn_version'])
|
||||
if 'cuda_compute_capabilities' in bi:
|
||||
desc_lines.append('NVIDIA CUDA Compute Capabilities ' +
|
||||
', '.join(bi['cuda_compute_capabilities']))
|
||||
if desc_lines:
|
||||
GPU_DESCRIPTION = gpu_header + '\n - '.join(desc_lines)
|
||||
|
||||
if sys.byteorder == 'little':
|
||||
# grpcio does not build correctly on big-endian machines due to lack of
|
||||
# BoringSSL support.
|
||||
@ -113,7 +137,8 @@ CONSOLE_SCRIPTS = [
|
||||
# even though the command is not removed, just moved to a different wheel.
|
||||
'tensorboard = tensorboard.main:run_main',
|
||||
'tf_upgrade_v2 = tensorflow.tools.compatibility.tf_upgrade_v2_main:main',
|
||||
'estimator_ckpt_converter = tensorflow_estimator.python.estimator.tools.checkpoint_converter:main',
|
||||
'estimator_ckpt_converter = '
|
||||
'tensorflow_estimator.python.estimator.tools.checkpoint_converter:main',
|
||||
]
|
||||
# pylint: enable=line-too-long
|
||||
|
||||
@ -152,11 +177,10 @@ class InstallHeaders(Command):
|
||||
"""
|
||||
description = 'install C/C++ header files'
|
||||
|
||||
user_options = [('install-dir=', 'd',
|
||||
'directory to install header files to'),
|
||||
('force', 'f',
|
||||
'force installation (overwrite existing files)'),
|
||||
]
|
||||
user_options = [
|
||||
('install-dir=', 'd', 'directory to install header files to'),
|
||||
('force', 'f', 'force installation (overwrite existing files)'),
|
||||
]
|
||||
|
||||
boolean_options = ['force']
|
||||
|
||||
@ -166,8 +190,7 @@ class InstallHeaders(Command):
|
||||
self.outfiles = []
|
||||
|
||||
def finalize_options(self):
|
||||
self.set_undefined_options('install',
|
||||
('install_headers', 'install_dir'),
|
||||
self.set_undefined_options('install', ('install_headers', 'install_dir'),
|
||||
('force', 'force'))
|
||||
|
||||
def mkdir_and_copy_file(self, header):
|
||||
@ -227,9 +250,7 @@ so_lib_paths = [
|
||||
|
||||
matches = []
|
||||
for path in so_lib_paths:
|
||||
matches.extend(
|
||||
['../' + x for x in find_files('*', path) if '.py' not in x]
|
||||
)
|
||||
matches.extend(['../' + x for x in find_files('*', path) if '.py' not in x])
|
||||
|
||||
if os.name == 'nt':
|
||||
EXTENSION_NAME = 'python/_pywrap_tensorflow_internal.pyd'
|
||||
@ -250,17 +271,16 @@ headers = (
|
||||
list(find_files('*.h', 'tensorflow/stream_executor')) +
|
||||
list(find_files('*.h', 'google/com_google_protobuf/src')) +
|
||||
list(find_files('*.inc', 'google/com_google_protobuf/src')) +
|
||||
list(find_files('*', 'third_party/eigen3')) + list(
|
||||
find_files('*.h', 'tensorflow/include/external/com_google_absl')) +
|
||||
list(
|
||||
find_files('*.inc', 'tensorflow/include/external/com_google_absl'))
|
||||
+ list(find_files('*', 'tensorflow/include/external/eigen_archive')))
|
||||
list(find_files('*', 'third_party/eigen3')) +
|
||||
list(find_files('*.h', 'tensorflow/include/external/com_google_absl')) +
|
||||
list(find_files('*.inc', 'tensorflow/include/external/com_google_absl')) +
|
||||
list(find_files('*', 'tensorflow/include/external/eigen_archive')))
|
||||
|
||||
setup(
|
||||
name=project_name,
|
||||
version=_VERSION.replace('-', ''),
|
||||
description=DOCLINES[0],
|
||||
long_description='\n'.join(DOCLINES[2:]),
|
||||
long_description='\n'.join(DOCLINES[2:]) + GPU_DESCRIPTION,
|
||||
url='https://www.tensorflow.org/',
|
||||
download_url='https://github.com/tensorflow/tensorflow/tags',
|
||||
author='Google Inc.',
|
||||
@ -281,13 +301,18 @@ setup(
|
||||
] + matches,
|
||||
},
|
||||
zip_safe=False,
|
||||
# Accessible with importlib.metadata.metadata('tf-pkg-name').items()
|
||||
platforms=[
|
||||
'{}:{}'.format(key, value)
|
||||
for key, value in build_info.build_info.items()
|
||||
],
|
||||
distclass=BinaryDistribution,
|
||||
cmdclass={
|
||||
'install_headers': InstallHeaders,
|
||||
'install': InstallCommand,
|
||||
},
|
||||
# PyPI package information.
|
||||
classifiers=[
|
||||
classifiers=sorted([
|
||||
'Development Status :: 5 - Production/Stable',
|
||||
'Intended Audience :: Developers',
|
||||
'Intended Audience :: Education',
|
||||
@ -305,7 +330,7 @@ setup(
|
||||
'Topic :: Software Development',
|
||||
'Topic :: Software Development :: Libraries',
|
||||
'Topic :: Software Development :: Libraries :: Python Modules',
|
||||
],
|
||||
] + gpu_classifiers),
|
||||
license='Apache 2.0',
|
||||
keywords='tensorflow tensor machine learning',
|
||||
)
|
||||
|
5
third_party/gpus/cuda/BUILD.tpl
vendored
5
third_party/gpus/cuda/BUILD.tpl
vendored
@ -218,4 +218,9 @@ bzl_library(
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "cuda_config_py",
|
||||
srcs = ["cuda/cuda_config.py"]
|
||||
)
|
||||
|
||||
%{copy_rules}
|
||||
|
17
third_party/gpus/cuda/cuda_config.py.tpl
vendored
Normal file
17
third_party/gpus/cuda/cuda_config.py.tpl
vendored
Normal file
@ -0,0 +1,17 @@
|
||||
# Lint as: python3
|
||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
config = %{cuda_config}
|
31
third_party/gpus/cuda_configure.bzl
vendored
31
third_party/gpus/cuda_configure.bzl
vendored
@ -824,6 +824,15 @@ filegroup(name="cudnn-include")
|
||||
"cuda/cuda/cuda_config.h",
|
||||
)
|
||||
|
||||
# Set up cuda_config.py, which is used by gen_build_info to provide
|
||||
# static build environment info to the API
|
||||
_tpl(
|
||||
repository_ctx,
|
||||
"cuda:cuda_config.py",
|
||||
_py_tmpl_dict({}),
|
||||
"cuda/cuda/cuda_config.py",
|
||||
)
|
||||
|
||||
# If cuda_configure is not configured to build with GPU support, and the user
|
||||
# attempts to build with --config=cuda, add a dummy build rule to intercept
|
||||
# this and fail with an actionable error message.
|
||||
@ -938,6 +947,7 @@ def _create_local_cuda_repository(repository_ctx):
|
||||
"crosstool:BUILD",
|
||||
"crosstool:cc_toolchain_config.bzl",
|
||||
"cuda:cuda_config.h",
|
||||
"cuda:cuda_config.py",
|
||||
]}
|
||||
tpl_paths["cuda:BUILD"] = _tpl_path(repository_ctx, "cuda:BUILD.windows" if is_windows(repository_ctx) else "cuda:BUILD")
|
||||
find_cuda_config_script = repository_ctx.path(Label("@org_tensorflow//third_party/gpus:find_cuda_config.py.gz.base64"))
|
||||
@ -1273,6 +1283,22 @@ def _create_local_cuda_repository(repository_ctx):
|
||||
},
|
||||
)
|
||||
|
||||
# Set up cuda_config.py, which is used by gen_build_info to provide
|
||||
# static build environment info to the API
|
||||
repository_ctx.template(
|
||||
"cuda/cuda/cuda_config.py",
|
||||
tpl_paths["cuda:cuda_config.py"],
|
||||
_py_tmpl_dict({
|
||||
"cuda_version": cuda_config.cuda_version,
|
||||
"cudnn_version": cuda_config.cudnn_version,
|
||||
"cuda_compute_capabilities": cuda_config.compute_capabilities,
|
||||
"cpu_compiler": str(cc),
|
||||
}),
|
||||
)
|
||||
|
||||
def _py_tmpl_dict(d):
|
||||
return {"%{cuda_config}": str(d)}
|
||||
|
||||
def _create_remote_cuda_repository(repository_ctx, remote_config_repo):
|
||||
"""Creates pointers to a remotely configured repo set up to build with CUDA."""
|
||||
_tpl(
|
||||
@ -1301,6 +1327,11 @@ def _create_remote_cuda_repository(repository_ctx, remote_config_repo):
|
||||
config_repo_label(remote_config_repo, "cuda:cuda/cuda_config.h"),
|
||||
{},
|
||||
)
|
||||
repository_ctx.template(
|
||||
"cuda/cuda/cuda_config.py",
|
||||
config_repo_label(remote_config_repo, "cuda:cuda/cuda_config.py"),
|
||||
_py_tmpl_dict({}),
|
||||
)
|
||||
|
||||
repository_ctx.template(
|
||||
"crosstool/BUILD",
|
||||
|
Loading…
Reference in New Issue
Block a user