Merge pull request from angerson:master

PiperOrigin-RevId: 311025598
Change-Id: Ib47f014000e9183bd25b413ebeb43a4adf543d82
This commit is contained in:
TensorFlower Gardener 2020-05-11 17:21:17 -07:00
commit 94c821c0ef
13 changed files with 210 additions and 85 deletions

View File

@ -264,6 +264,7 @@ py_library(
deps = [ deps = [
":_pywrap_util_port", ":_pywrap_util_port",
":lib", ":lib",
":platform_build_info",
":pywrap_tfe", ":pywrap_tfe",
":util", ":util",
"//tensorflow/core:protos_all_py", "//tensorflow/core:protos_all_py",
@ -328,6 +329,24 @@ tf_py_test(
], ],
) )
tf_py_test(
name = "sysconfig_test",
size = "small",
srcs = ["platform/sysconfig_test.py"],
data = [
"platform/sysconfig.py",
],
python_version = "PY3",
tags = [
"no_pip",
"no_windows",
],
deps = [
":platform",
":platform_test",
],
)
tf_py_test( tf_py_test(
name = "flags_test", name = "flags_test",
size = "small", size = "small",

View File

@ -601,7 +601,7 @@ def gpu_gru(inputs, init_h, kernel, recurrent_kernel, bias, mask, time_major,
# (6 * units) # (6 * units)
bias = array_ops.split(K.flatten(bias), 6) bias = array_ops.split(K.flatten(bias), 6)
if build_info.is_cuda_build: if build_info.build_info['is_cuda_build']:
# Note that the gate order for CuDNN is different from the canonical format. # Note that the gate order for CuDNN is different from the canonical format.
# canonical format is [z, r, h], whereas CuDNN is [r, z, h]. The swap need # canonical format is [z, r, h], whereas CuDNN is [r, z, h]. The swap need
# to be done for kernel, recurrent_kernel, input_bias, recurrent_bias. # to be done for kernel, recurrent_kernel, input_bias, recurrent_bias.
@ -1361,7 +1361,7 @@ def gpu_lstm(inputs, init_h, init_c, kernel, recurrent_kernel, bias, mask,
# so that mathematically it is same as the canonical LSTM implementation. # so that mathematically it is same as the canonical LSTM implementation.
full_bias = array_ops.concat((array_ops.zeros_like(bias), bias), 0) full_bias = array_ops.concat((array_ops.zeros_like(bias), bias), 0)
if build_info.is_rocm_build: if build_info.build_info['is_rocm_build']:
# ROCm MIOpen's weight sequence for LSTM is different from both canonical # ROCm MIOpen's weight sequence for LSTM is different from both canonical
# and Cudnn format # and Cudnn format
# MIOpen: [i, f, o, c] Cudnn/Canonical: [i, f, c, o] # MIOpen: [i, f, o, c] Cudnn/Canonical: [i, f, c, o]

View File

@ -25,8 +25,10 @@ from tensorflow.python.platform import test
class BuildInfoTest(test.TestCase): class BuildInfoTest(test.TestCase):
def testBuildInfo(self): def testBuildInfo(self):
self.assertEqual(build_info.is_rocm_build, test.is_built_with_rocm()) self.assertEqual(build_info.build_info['is_rocm_build'],
self.assertEqual(build_info.is_cuda_build, test.is_built_with_cuda()) test.is_built_with_rocm())
self.assertEqual(build_info.build_info['is_cuda_build'],
test.is_built_with_cuda())
if __name__ == '__main__': if __name__ == '__main__':

View File

@ -20,6 +20,7 @@ from __future__ import print_function
import os import os
MSVCP_DLL_NAMES = "msvcp_dll_names"
try: try:
from tensorflow.python.platform import build_info from tensorflow.python.platform import build_info
@ -42,9 +43,9 @@ def preload_check():
# we load the Python extension, so that we can raise an actionable error # we load the Python extension, so that we can raise an actionable error
# message if they are not found. # message if they are not found.
import ctypes # pylint: disable=g-import-not-at-top import ctypes # pylint: disable=g-import-not-at-top
if hasattr(build_info, "msvcp_dll_names"): if MSVCP_DLL_NAMES in build_info.build_info:
missing = [] missing = []
for dll_name in build_info.msvcp_dll_names.split(","): for dll_name in build_info.build_info[MSVCP_DLL_NAMES].split(","):
try: try:
ctypes.WinDLL(dll_name) ctypes.WinDLL(dll_name)
except OSError: except OSError:

View File

@ -24,6 +24,7 @@ import platform as _platform
from tensorflow.python.framework.versions import CXX11_ABI_FLAG as _CXX11_ABI_FLAG from tensorflow.python.framework.versions import CXX11_ABI_FLAG as _CXX11_ABI_FLAG
from tensorflow.python.framework.versions import MONOLITHIC_BUILD as _MONOLITHIC_BUILD from tensorflow.python.framework.versions import MONOLITHIC_BUILD as _MONOLITHIC_BUILD
from tensorflow.python.framework.versions import VERSION as _VERSION from tensorflow.python.framework.versions import VERSION as _VERSION
from tensorflow.python.platform import build_info
from tensorflow.python.util.tf_export import tf_export from tensorflow.python.util.tf_export import tf_export
@ -84,3 +85,30 @@ def get_link_flags():
else: else:
flags.append('-l:libtensorflow_framework.so.%s' % ver) flags.append('-l:libtensorflow_framework.so.%s' % ver)
return flags return flags
@tf_export('sysconfig.get_build_info')
def get_build_info():
"""Get a dictionary describing TensorFlow's build environment.
Values are generated when TensorFlow is compiled, and are static for each
TensorFlow package. The return value is a dictionary with string keys such as:
- cuda_version
- cudnn_version
- tensorrt_version
- nccl_version
- is_cuda_build
- is_rocm_build
- msvcp_dll_names
- nvcuda_dll_name
- cudart_dll_name
- cudnn_dll_name
Note that the actual keys and values returned by this function is subject to
change across different versions of TensorFlow or across platforms.
Returns:
A Dictionary describing TensorFlow's build environment.
"""
return build_info.build_info

View File

@ -0,0 +1,38 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.platform import googletest
from tensorflow.python.platform import sysconfig
from tensorflow.python.platform import test
class SysconfigTest(googletest.TestCase):
def test_get_build_info_works(self):
build_info = sysconfig.get_build_info()
self.assertIsInstance(build_info, dict)
def test_rocm_cuda_info_matches(self):
build_info = sysconfig.get_build_info()
self.assertEqual(build_info["is_rocm_build"], test.is_built_with_rocm())
self.assertEqual(build_info["is_cuda_build"], test.is_built_with_cuda())
if __name__ == "__main__":
googletest.main()

View File

@ -2593,6 +2593,10 @@ def tf_version_info_genrule(name, out):
arguments = "--generate \"$@\" --git_tag_override=${GIT_TAG_OVERRIDE:-}", arguments = "--generate \"$@\" --git_tag_override=${GIT_TAG_OVERRIDE:-}",
) )
def dict_to_kv(d):
"""Convert a dictionary to a space-joined list of key=value pairs."""
return " " + " ".join(["%s=%s" % (k, v) for k, v in d.items()])
def tf_py_build_info_genrule(name, out): def tf_py_build_info_genrule(name, out):
_local_genrule( _local_genrule(
name = name, name = name,
@ -2600,16 +2604,17 @@ def tf_py_build_info_genrule(name, out):
exec_tool = "//tensorflow/tools/build_info:gen_build_info", exec_tool = "//tensorflow/tools/build_info:gen_build_info",
arguments = arguments =
"--raw_generate \"$@\" " + "--raw_generate \"$@\" " +
" --is_config_cuda " + if_cuda("True", "False") +
" --is_config_rocm " + if_rocm("True", "False") +
" --key_value" + " --key_value" +
if_cuda(" cuda_version_number=${TF_CUDA_VERSION:-} cudnn_version_number=${TF_CUDNN_VERSION:-} ", "") + " is_rocm_build=" + if_rocm("True", "False") +
if_windows(" msvcp_dll_names=msvcp140.dll,msvcp140_1.dll ", "") + " is_cuda_build=" + if_cuda("True", "False") +
if_windows_cuda(" ".join([ # TODO(angerson) Can we reliably load CUDA compute capabilities here?
"nvcuda_dll_name=nvcuda.dll", if_windows(dict_to_kv({
"cudart_dll_name=cudart64_$(echo $${TF_CUDA_VERSION:-} | sed \"s/\\.//\").dll", "msvcp_dll_names": "msvcp140.dll,msvcp140_1.dll",
"cudnn_dll_name=cudnn64_${TF_CUDNN_VERSION:-}.dll", }), "") + if_windows_cuda(dict_to_kv({
]), ""), "nvcuda_dll_name": "nvcuda.dll",
"cudart_dll_name": "cudart64_$$(echo $${TF_CUDA_VERSION:-} | sed \"s/\\.//\").dll",
"cudnn_dll_name": "cudnn64_$${TF_CUDNN_VERSION:-}.dll",
}), ""),
) )
def cc_library_with_android_deps( def cc_library_with_android_deps(

View File

@ -8,6 +8,10 @@ tf_module {
name: "MONOLITHIC_BUILD" name: "MONOLITHIC_BUILD"
mtype: "<type \'int\'>" mtype: "<type \'int\'>"
} }
member_method {
name: "get_build_info"
argspec: "args=[], varargs=None, keywords=None, defaults=None"
}
member_method { member_method {
name: "get_compile_flags" name: "get_compile_flags"
argspec: "args=[], varargs=None, keywords=None, defaults=None" argspec: "args=[], varargs=None, keywords=None, defaults=None"

View File

@ -8,6 +8,10 @@ tf_module {
name: "MONOLITHIC_BUILD" name: "MONOLITHIC_BUILD"
mtype: "<type \'int\'>" mtype: "<type \'int\'>"
} }
member_method {
name: "get_build_info"
argspec: "args=[], varargs=None, keywords=None, defaults=None"
}
member_method { member_method {
name: "get_compile_flags" name: "get_compile_flags"
argspec: "args=[], varargs=None, keywords=None, defaults=None" argspec: "args=[], varargs=None, keywords=None, defaults=None"

View File

@ -14,6 +14,7 @@ py_binary(
srcs_version = "PY2AND3", srcs_version = "PY2AND3",
tags = ["no-remote-exec"], tags = ["no-remote-exec"],
deps = [ deps = [
"//third_party/gpus:find_cuda_config",
"@six_archive//:six", "@six_archive//:six",
], ],
) )

View File

@ -1,4 +1,4 @@
# Lint as: python2, python3 # Lint as: python3
# Copyright 2017 The TensorFlow Authors. All Rights Reserved. # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
@ -19,50 +19,62 @@ from __future__ import division
from __future__ import print_function from __future__ import print_function
import argparse import argparse
import os
import platform
import sys
import six import six
# CUDA library gathering is only valid in OSS
try:
from third_party.gpus import find_cuda_config # pylint: disable=g-import-not-at-top
except ImportError:
find_cuda_config = None
def write_build_info(filename, is_config_cuda, is_config_rocm, key_value_list):
def write_build_info(filename, key_value_list):
"""Writes a Python that describes the build. """Writes a Python that describes the build.
Args: Args:
filename: filename to write to. filename: filename to write to.
is_config_cuda: Whether this build is using CUDA.
is_config_rocm: Whether this build is using ROCm.
key_value_list: A list of "key=value" strings that will be added to the key_value_list: A list of "key=value" strings that will be added to the
module as additional fields. module's "build_info" dictionary as additional entries.
Raises:
ValueError: If `key_value_list` includes the key "is_cuda_build", which
would clash with one of the default fields.
""" """
module_docstring = "\"\"\"Generates a Python module containing information "
module_docstring += "about the build.\"\"\""
build_config_rocm_bool = "False" build_info = {}
build_config_cuda_bool = "False"
if is_config_rocm == "True":
build_config_rocm_bool = "True"
elif is_config_cuda == "True":
build_config_cuda_bool = "True"
key_value_pair_stmts = []
if key_value_list:
for arg in key_value_list: for arg in key_value_list:
key, value = six.ensure_str(arg).split("=") key, value = six.ensure_str(arg).split("=")
if key == "is_cuda_build": if value.lower() == "true":
raise ValueError("The key \"is_cuda_build\" cannot be passed as one of " build_info[key] = True
"the --key_value arguments.") elif value.lower() == "false":
if key == "is_rocm_build": build_info[key] = False
raise ValueError("The key \"is_rocm_build\" cannot be passed as one of " else:
"the --key_value arguments.") build_info[key] = value
key_value_pair_stmts.append("%s = %r" % (key, value))
key_value_pair_content = "\n".join(key_value_pair_stmts) # Generate cuda_build_info, a dict describing the CUDA component versions
# used to build TensorFlow.
if find_cuda_config and build_info.get("is_cuda_build", False):
libs = ["_", "cuda", "cudnn"]
if platform.system() == "Linux":
if os.environ.get("TF_NEED_TENSORRT", "0") == "1":
libs.append("tensorrt")
if "TF_NCCL_VERSION" in os.environ:
libs.append("nccl")
# find_cuda_config accepts libraries to inspect as argv from the command
# line. We can work around this restriction by setting argv manually
# before calling find_cuda_config.
backup_argv = sys.argv
sys.argv = libs
cuda = find_cuda_config.find_cuda_config()
build_info["cuda_version"] = cuda["cuda_version"]
build_info["cudnn_version"] = cuda["cudnn_version"]
build_info["tensorrt_version"] = cuda.get("tensorrt_version", None)
build_info["nccl_version"] = cuda.get("nccl_version", None)
sys.argv = backup_argv
contents = """ contents = """
# Copyright 2017 The TensorFlow Authors. All Rights Reserved. # Copyright 2020 The TensorFlow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -76,33 +88,21 @@ def write_build_info(filename, is_config_cuda, is_config_rocm, key_value_list):
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
%s \"\"\"Auto-generated module providing information about the build.\"\"\"
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
is_rocm_build = %s from collections import namedtuple
is_cuda_build = %s
%s build_info = {build_info}
""" % (module_docstring, build_config_rocm_bool, build_config_cuda_bool, """.format(build_info=build_info)
key_value_pair_content)
open(filename, "w").write(contents) open(filename, "w").write(contents)
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description="""Build info injection into the PIP package.""") description="""Build info injection into the PIP package.""")
parser.add_argument(
"--is_config_cuda",
type=str,
help="'True' for CUDA GPU builds, 'False' otherwise.")
parser.add_argument(
"--is_config_rocm",
type=str,
help="'True' for ROCm GPU builds, 'False' otherwise.")
parser.add_argument("--raw_generate", type=str, help="Generate build_info.py") parser.add_argument("--raw_generate", type=str, help="Generate build_info.py")
parser.add_argument( parser.add_argument(
@ -110,10 +110,7 @@ parser.add_argument(
args = parser.parse_args() args = parser.parse_args()
if (args.raw_generate is not None) and (args.is_config_cuda is not None) and ( if args.raw_generate:
args.is_config_rocm is not None): write_build_info(args.raw_generate, args.key_value)
write_build_info(args.raw_generate, args.is_config_cuda, args.is_config_rocm,
args.key_value)
else: else:
raise RuntimeError( raise RuntimeError("--raw_generate must be used.")
"--raw_generate, --is_config_cuda and --is_config_rocm must be used")

View File

@ -1,3 +1,4 @@
# lint as: python3
# Copyright 2015 The TensorFlow Authors. All Rights Reserved. # Copyright 2015 The TensorFlow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
@ -43,6 +44,8 @@ from setuptools import setup
from setuptools.command.install import install as InstallCommandBase from setuptools.command.install import install as InstallCommandBase
from setuptools.dist import Distribution from setuptools.dist import Distribution
from tensorflow.python.platform import build_info
DOCLINES = __doc__.split('\n') DOCLINES = __doc__.split('\n')
# This version string is semver compatible, but incompatible with pip. # This version string is semver compatible, but incompatible with pip.
@ -82,6 +85,22 @@ REQUIRED_PACKAGES = [
'scipy == 1.2.2;python_version<"3"', 'scipy == 1.2.2;python_version<"3"',
] ]
# Generate a footer describing the CUDA technology this release was built
# against.
GPU_DESCRIPTION = ''
if build_info.build_info['is_cuda_build']:
gpu_header = ('\nTensorFlow {} for NVIDIA GPUs was built with these '
'platform and library versions:\n\n - ').format(_VERSION)
bi = build_info.build_info
trt_ver = bi['tensorrt_version']
nccl_ver = bi['nccl_version']
GPU_DESCRIPTION = gpu_header + '\n - '.join([
'NVIDIA CUDA ' + bi['cuda_version'],
'NVIDIA cuDNN ' + bi['cudnn_version'],
'NVIDIA NCCL ' + 'not enabled' if not nccl_ver else nccl_ver,
'NVIDIA TensorRT ' + 'not enabled' if not trt_ver else trt_ver,
])
if sys.byteorder == 'little': if sys.byteorder == 'little':
# grpcio does not build correctly on big-endian machines due to lack of # grpcio does not build correctly on big-endian machines due to lack of
# BoringSSL support. # BoringSSL support.
@ -117,7 +136,8 @@ CONSOLE_SCRIPTS = [
# even though the command is not removed, just moved to a different wheel. # even though the command is not removed, just moved to a different wheel.
'tensorboard = tensorboard.main:run_main', 'tensorboard = tensorboard.main:run_main',
'tf_upgrade_v2 = tensorflow.tools.compatibility.tf_upgrade_v2_main:main', 'tf_upgrade_v2 = tensorflow.tools.compatibility.tf_upgrade_v2_main:main',
'estimator_ckpt_converter = tensorflow_estimator.python.estimator.tools.checkpoint_converter:main', 'estimator_ckpt_converter = '
'tensorflow_estimator.python.estimator.tools.checkpoint_converter:main',
] ]
# pylint: enable=line-too-long # pylint: enable=line-too-long
@ -161,10 +181,9 @@ class InstallHeaders(Command):
""" """
description = 'install C/C++ header files' description = 'install C/C++ header files'
user_options = [('install-dir=', 'd', user_options = [
'directory to install header files to'), ('install-dir=', 'd', 'directory to install header files to'),
('force', 'f', ('force', 'f', 'force installation (overwrite existing files)'),
'force installation (overwrite existing files)'),
] ]
boolean_options = ['force'] boolean_options = ['force']
@ -175,8 +194,7 @@ class InstallHeaders(Command):
self.outfiles = [] self.outfiles = []
def finalize_options(self): def finalize_options(self):
self.set_undefined_options('install', self.set_undefined_options('install', ('install_headers', 'install_dir'),
('install_headers', 'install_dir'),
('force', 'force')) ('force', 'force'))
def mkdir_and_copy_file(self, header): def mkdir_and_copy_file(self, header):
@ -236,9 +254,7 @@ so_lib_paths = [
matches = [] matches = []
for path in so_lib_paths: for path in so_lib_paths:
matches.extend( matches.extend(['../' + x for x in find_files('*', path) if '.py' not in x])
['../' + x for x in find_files('*', path) if '.py' not in x]
)
if os.name == 'nt': if os.name == 'nt':
EXTENSION_NAME = 'python/_pywrap_tensorflow_internal.pyd' EXTENSION_NAME = 'python/_pywrap_tensorflow_internal.pyd'
@ -258,17 +274,16 @@ headers = (
list(find_files('*.h', 'tensorflow/stream_executor')) + list(find_files('*.h', 'tensorflow/stream_executor')) +
list(find_files('*.h', 'google/com_google_protobuf/src')) + list(find_files('*.h', 'google/com_google_protobuf/src')) +
list(find_files('*.inc', 'google/com_google_protobuf/src')) + list(find_files('*.inc', 'google/com_google_protobuf/src')) +
list(find_files('*', 'third_party/eigen3')) + list( list(find_files('*', 'third_party/eigen3')) +
find_files('*.h', 'tensorflow/include/external/com_google_absl')) + list(find_files('*.h', 'tensorflow/include/external/com_google_absl')) +
list( list(find_files('*.inc', 'tensorflow/include/external/com_google_absl')) +
find_files('*.inc', 'tensorflow/include/external/com_google_absl')) list(find_files('*', 'tensorflow/include/external/eigen_archive')))
+ list(find_files('*', 'tensorflow/include/external/eigen_archive')))
setup( setup(
name=project_name, name=project_name,
version=_VERSION.replace('-', ''), version=_VERSION.replace('-', ''),
description=DOCLINES[0], description=DOCLINES[0],
long_description='\n'.join(DOCLINES[2:]), long_description='\n'.join(DOCLINES[2:]) + GPU_DESCRIPTION,
url='https://www.tensorflow.org/', url='https://www.tensorflow.org/',
download_url='https://github.com/tensorflow/tensorflow/tags', download_url='https://github.com/tensorflow/tensorflow/tags',
author='Google Inc.', author='Google Inc.',
@ -289,6 +304,11 @@ setup(
] + matches, ] + matches,
}, },
zip_safe=False, zip_safe=False,
# Accessible with importlib.metadata.metadata('tf-pkg-name').items()
platforms=[
'{}:{}'.format(key, value)
for key, value in build_info.build_info.items()
],
distclass=BinaryDistribution, distclass=BinaryDistribution,
cmdclass={ cmdclass={
'install_headers': InstallHeaders, 'install_headers': InstallHeaders,

View File

@ -0,0 +1,6 @@
# Expose find_cuda_config.py as a library so other tools can reference it.
py_library(
name = "find_cuda_config",
srcs = ["find_cuda_config.py"],
visibility = ["//visibility:public"],
)