From 2c71fe1ff34af5277673db7b67320e6796823e0b Mon Sep 17 00:00:00 2001 From: Austin Anderson <angerson@google.com> Date: Mon, 27 Apr 2020 16:50:08 -0700 Subject: [PATCH 1/6] Provide NVIDIA CUDA build data in metadata and API This change: First exposes //third_party/gpus:find_cuda_config as a library. Then, it extends gen_build_info.py with find_cuda_config to provide package build information within TensorFlow's API. This is accessible as a dictionary: from tensorflow.python.platform import build_info print(build_info.cuda_build_info) {'cuda_version': '10.2', 'cudnn_version': '7', 'tensorrt_version': None, 'nccl_version': None} Finally, setup.py pulls that into package metadata. The same wheel's long description ends with: TensorFlow 2.1.0 for NVIDIA GPUs was built with these platform and library versions: - NVIDIA CUDA 10.2 - NVIDIA cuDNN 7 - NVIDIA NCCL not enabled - NVIDIA TensorRT not enabled In lieu of NVIDIA CUDA classifiers [1], the same metadata is exposed in the normally-unused "platform" tag: >>> import pkginfo >>> a = pkginfo.Wheel('./tf_nightly_gpu-2.1.0-cp36-cp36m-linux_x86_64.whl') >>> a.platforms ['cuda_version:10.2', 'cudnn_version:7', 'tensorrt_version:None', 'nccl_version:None'] I'm not 100% confident this is the best way to accomplish this. It seems odd to import like this setup.py, even though it works, even in an environment with TensorFlow installed. One caveat for RBE: the contents of genrules still run on the local system, so I had to syncronize my local environment with the RBE environment I used to build TensorFlow. I'm not sure if this is going to require intervention on TensorFlow's current CI. Currently tested only on Linux GPU (Remote Build) for Python 3.6. I'd like to see more tests before merging. [1]: (https://github.com/pypa/trove-classifiers/issues/25), --- tensorflow/tools/build_info/BUILD | 1 + tensorflow/tools/build_info/gen_build_info.py | 44 +++++++++++++--- tensorflow/tools/pip_package/setup.py | 51 ++++++++++++------- third_party/gpus/BUILD | 6 +++ 4 files changed, 77 insertions(+), 25 deletions(-) diff --git a/tensorflow/tools/build_info/BUILD b/tensorflow/tools/build_info/BUILD index 556dd0c86f0..1baa16724fe 100644 --- a/tensorflow/tools/build_info/BUILD +++ b/tensorflow/tools/build_info/BUILD @@ -15,5 +15,6 @@ py_binary( tags = ["no-remote-exec"], deps = [ "@six_archive//:six", + "//third_party/gpus:find_cuda_config", ], ) diff --git a/tensorflow/tools/build_info/gen_build_info.py b/tensorflow/tools/build_info/gen_build_info.py index df9068fb3d1..3180010bb13 100755 --- a/tensorflow/tools/build_info/gen_build_info.py +++ b/tensorflow/tools/build_info/gen_build_info.py @@ -1,4 +1,4 @@ -# Lint as: python2, python3 +# Lint as: python3 # Copyright 2017 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -19,9 +19,14 @@ from __future__ import division from __future__ import print_function import argparse +import os +import platform +import sys import six +from third_party.gpus import find_cuda_config + def write_build_info(filename, is_config_cuda, is_config_rocm, key_value_list): """Writes a Python that describes the build. @@ -61,7 +66,31 @@ def write_build_info(filename, is_config_cuda, is_config_rocm, key_value_list): key_value_pair_stmts.append("%s = %r" % (key, value)) key_value_pair_content = "\n".join(key_value_pair_stmts) - contents = """ + # Generate cuda_build_info, a dict describing the CUDA component versions + # used to build TensorFlow. + cuda_build_info = "{}" + if is_config_cuda == "True": + libs = ["_", "cuda", "cudnn"] + if platform.system() == "Linux": + if os.environ.get("TF_NEED_TENSORRT", "0") == "1": + libs.append("tensorrt") + if "TF_NCCL_VERSION" in os.environ: + libs.append("nccl") + # find_cuda_config accepts libraries to inspect as argv from the command + # line. We can work around this restriction by setting argv manually + # before calling find_cuda_config. + backup_argv = sys.argv + sys.argv = libs + cuda = find_cuda_config.find_cuda_config() + cuda_build_info = str({ + "cuda_version": cuda["cuda_version"], + "cudnn_version": cuda["cudnn_version"], + "tensorrt_version": cuda.get("tensorrt_version", None), + "nccl_version": cuda.get("nccl_version", None), + }) + sys.argv = backup_argv + + contents = f""" # Copyright 2017 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -76,17 +105,16 @@ def write_build_info(filename, is_config_cuda, is_config_rocm, key_value_list): # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -%s +{module_docstring} from __future__ import absolute_import from __future__ import division from __future__ import print_function -is_rocm_build = %s -is_cuda_build = %s +is_rocm_build = {build_config_rocm_bool} +is_cuda_build = {build_config_cuda_bool} +cuda_build_info = {cuda_build_info} -%s -""" % (module_docstring, build_config_rocm_bool, build_config_cuda_bool, - key_value_pair_content) +""" open(filename, "w").write(contents) diff --git a/tensorflow/tools/pip_package/setup.py b/tensorflow/tools/pip_package/setup.py index 1c9a37bf652..1e99d659830 100644 --- a/tensorflow/tools/pip_package/setup.py +++ b/tensorflow/tools/pip_package/setup.py @@ -1,3 +1,4 @@ +# lint as: python3 # Copyright 2015 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -43,6 +44,8 @@ from setuptools import setup from setuptools.command.install import install as InstallCommandBase from setuptools.dist import Distribution +from tensorflow.python.platform import build_info + DOCLINES = __doc__.split('\n') # This version string is semver compatible, but incompatible with pip. @@ -82,6 +85,20 @@ REQUIRED_PACKAGES = [ 'scipy == 1.2.2;python_version<"3"', ] +GPU_DESCRIPTION = '' +if build_info.is_cuda_build: + gpu_header = (f'\nTensorFlow {_VERSION} for NVIDIA GPUs was built with these ' + 'platform and library versions:\n\n - ') + cbi = build_info.cuda_build_info + trt_ver = cbi['tensorrt_version'] + nccl_ver = cbi['nccl_version'] + GPU_DESCRIPTION = gpu_header + '\n - '.join([ + 'NVIDIA CUDA ' + cbi['cuda_version'], + 'NVIDIA cuDNN ' + cbi['cudnn_version'], + 'NVIDIA NCCL ' + 'not enabled' if not nccl_ver else nccl_ver, + 'NVIDIA TensorRT ' + 'not enabled' if not trt_ver else trt_ver, + ]) + if sys.byteorder == 'little': # grpcio does not build correctly on big-endian machines due to lack of # BoringSSL support. @@ -117,7 +134,8 @@ CONSOLE_SCRIPTS = [ # even though the command is not removed, just moved to a different wheel. 'tensorboard = tensorboard.main:run_main', 'tf_upgrade_v2 = tensorflow.tools.compatibility.tf_upgrade_v2_main:main', - 'estimator_ckpt_converter = tensorflow_estimator.python.estimator.tools.checkpoint_converter:main', + 'estimator_ckpt_converter = ' + 'tensorflow_estimator.python.estimator.tools.checkpoint_converter:main', ] # pylint: enable=line-too-long @@ -161,11 +179,10 @@ class InstallHeaders(Command): """ description = 'install C/C++ header files' - user_options = [('install-dir=', 'd', - 'directory to install header files to'), - ('force', 'f', - 'force installation (overwrite existing files)'), - ] + user_options = [ + ('install-dir=', 'd', 'directory to install header files to'), + ('force', 'f', 'force installation (overwrite existing files)'), + ] boolean_options = ['force'] @@ -175,8 +192,7 @@ class InstallHeaders(Command): self.outfiles = [] def finalize_options(self): - self.set_undefined_options('install', - ('install_headers', 'install_dir'), + self.set_undefined_options('install', ('install_headers', 'install_dir'), ('force', 'force')) def mkdir_and_copy_file(self, header): @@ -236,9 +252,7 @@ so_lib_paths = [ matches = [] for path in so_lib_paths: - matches.extend( - ['../' + x for x in find_files('*', path) if '.py' not in x] - ) + matches.extend(['../' + x for x in find_files('*', path) if '.py' not in x]) if os.name == 'nt': EXTENSION_NAME = 'python/_pywrap_tensorflow_internal.pyd' @@ -257,17 +271,16 @@ headers = ( list(find_files('*.h', 'tensorflow/stream_executor')) + list(find_files('*.h', 'google/com_google_protobuf/src')) + list(find_files('*.inc', 'google/com_google_protobuf/src')) + - list(find_files('*', 'third_party/eigen3')) + list( - find_files('*.h', 'tensorflow/include/external/com_google_absl')) + - list( - find_files('*.inc', 'tensorflow/include/external/com_google_absl')) - + list(find_files('*', 'tensorflow/include/external/eigen_archive'))) + list(find_files('*', 'third_party/eigen3')) + + list(find_files('*.h', 'tensorflow/include/external/com_google_absl')) + + list(find_files('*.inc', 'tensorflow/include/external/com_google_absl')) + + list(find_files('*', 'tensorflow/include/external/eigen_archive'))) setup( name=project_name, version=_VERSION.replace('-', ''), description=DOCLINES[0], - long_description='\n'.join(DOCLINES[2:]), + long_description='\n'.join(DOCLINES[2:]) + GPU_DESCRIPTION, url='https://www.tensorflow.org/', download_url='https://github.com/tensorflow/tensorflow/tags', author='Google Inc.', @@ -288,6 +301,10 @@ setup( ] + matches, }, zip_safe=False, + # Accessible with importlib.metadata.metadata('tf-pkg-name').items() + platforms=[ + f'{key}:{value}' for key, value in build_info.cuda_build_info.items() + ], distclass=BinaryDistribution, cmdclass={ 'install_headers': InstallHeaders, diff --git a/third_party/gpus/BUILD b/third_party/gpus/BUILD index e69de29bb2d..d570c4894ce 100644 --- a/third_party/gpus/BUILD +++ b/third_party/gpus/BUILD @@ -0,0 +1,6 @@ +# Expose find_cuda_config.py as a library so other tools can reference it. +py_library( + name = "find_cuda_config", + srcs = ["find_cuda_config.py"], + visibility = ["//visibility:public"], +) From 019e9fca7be020133ec8bbbfd69aa166602bae47 Mon Sep 17 00:00:00 2001 From: Austin Anderson <angerson@google.com> Date: Tue, 28 Apr 2020 12:27:04 -0700 Subject: [PATCH 2/6] Add NVIDIA CUDA and cuDNN info to tf.config --- tensorflow/python/framework/config.py | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/tensorflow/python/framework/config.py b/tensorflow/python/framework/config.py index 5361d7290e8..12feaf3d89d 100644 --- a/tensorflow/python/framework/config.py +++ b/tensorflow/python/framework/config.py @@ -19,6 +19,7 @@ from __future__ import division from __future__ import print_function from tensorflow.python.eager import context +from tensorflow.python.platform import build_info from tensorflow.python.util import deprecation from tensorflow.python.util.tf_export import tf_export @@ -653,3 +654,25 @@ def disable_mlir_bridge(): def disable_mlir_graph_optimization(): """Disables experimental MLIR-Based TensorFlow Compiler Optimizations.""" context.context().enable_mlir_graph_optimization = False + + +@tf_export('config.get_cuda_version_used_to_compile_tf') +def get_cuda_version_used_to_compile_tf(): + """Get the version of NVIDIA CUDA used to compile this TensorFlow release. + + Returns: + String representation of CUDA version number (Major.Minor) if CUDA support + is included, otherwise None. + """ + return build_info.cuda_build_info.get('cuda_version', None) + + +@tf_export('config.get_cudnn_version_used_to_compile_tf') +def get_cudnn_version_used_to_compile_tf(): + """Get the version of NVIDIA cuDNN used to compile this TensorFlow release. + + Returns: + String representation of cuDNN version number (Major only) if cuDNN support + is included, otherwise None. + """ + return build_info.cuda_build_info.get('cudnn_version', None) From df1ee5d5516413642aee6d0e1c8ef08a4eff1902 Mon Sep 17 00:00:00 2001 From: Austin Anderson <angerson@google.com> Date: Wed, 29 Apr 2020 16:04:35 -0700 Subject: [PATCH 3/6] Convert build_info to dict format and expose it. Since this module now generates a dictionary to expose in tf.config, it doesn't make much sense to store only certain values in the build_info dictionary and others as module variables. This obsoletes a lot of code in gen_build_info.py and I've removed it. I also updated all the in-code references I've found to the build_info module. I think this may break whomever used to be using the build_info library, but since it wasn't part of the API, there was no guarantee that it would continue to be available. --- tensorflow/python/framework/config.py | 35 ++++---- .../python/keras/layers/recurrent_v2.py | 4 +- tensorflow/python/platform/build_info_test.py | 4 +- tensorflow/python/platform/self_check.py | 5 +- tensorflow/tensorflow.bzl | 25 +++--- tensorflow/tools/build_info/gen_build_info.py | 83 ++++++------------- tensorflow/tools/pip_package/setup.py | 2 +- 7 files changed, 67 insertions(+), 91 deletions(-) diff --git a/tensorflow/python/framework/config.py b/tensorflow/python/framework/config.py index 12feaf3d89d..9997b9833a5 100644 --- a/tensorflow/python/framework/config.py +++ b/tensorflow/python/framework/config.py @@ -656,23 +656,26 @@ def disable_mlir_graph_optimization(): context.context().enable_mlir_graph_optimization = False -@tf_export('config.get_cuda_version_used_to_compile_tf') -def get_cuda_version_used_to_compile_tf(): - """Get the version of NVIDIA CUDA used to compile this TensorFlow release. +@tf_export('config.get_build_info()') +def get_build_info(): + """Get a dictionary describing TensorFlow's build environment. + + Values are generated when TensorFlow is compiled, and are static for each + TensorFlow package. This information is limited to a subset of the following + keys based on the platforms targeted by the package: + + - cuda_version + - cudnn_version + - tensorrt_version + - nccl_version + - is_cuda_build + - is_rocm_build + - msvcp_dll_names + - nvcuda_dll_name + - cudart_dll_name + - cudnn_dll_name Returns: - String representation of CUDA version number (Major.Minor) if CUDA support - is included, otherwise None. + A Dictionary describing TensorFlow's build environment. """ return build_info.cuda_build_info.get('cuda_version', None) - - -@tf_export('config.get_cudnn_version_used_to_compile_tf') -def get_cudnn_version_used_to_compile_tf(): - """Get the version of NVIDIA cuDNN used to compile this TensorFlow release. - - Returns: - String representation of cuDNN version number (Major only) if cuDNN support - is included, otherwise None. - """ - return build_info.cuda_build_info.get('cudnn_version', None) diff --git a/tensorflow/python/keras/layers/recurrent_v2.py b/tensorflow/python/keras/layers/recurrent_v2.py index a9d5ef8587c..d5a54a0a8c6 100644 --- a/tensorflow/python/keras/layers/recurrent_v2.py +++ b/tensorflow/python/keras/layers/recurrent_v2.py @@ -601,7 +601,7 @@ def gpu_gru(inputs, init_h, kernel, recurrent_kernel, bias, mask, time_major, # (6 * units) bias = array_ops.split(K.flatten(bias), 6) - if build_info.is_cuda_build: + if build_info.build_info["is_cuda_build"]: # Note that the gate order for CuDNN is different from the canonical format. # canonical format is [z, r, h], whereas CuDNN is [r, z, h]. The swap need # to be done for kernel, recurrent_kernel, input_bias, recurrent_bias. @@ -1361,7 +1361,7 @@ def gpu_lstm(inputs, init_h, init_c, kernel, recurrent_kernel, bias, mask, # so that mathematically it is same as the canonical LSTM implementation. full_bias = array_ops.concat((array_ops.zeros_like(bias), bias), 0) - if build_info.is_rocm_build: + if build_info.build_info["is_rocm_build"]: # ROCm MIOpen's weight sequence for LSTM is different from both canonical # and Cudnn format # MIOpen: [i, f, o, c] Cudnn/Canonical: [i, f, c, o] diff --git a/tensorflow/python/platform/build_info_test.py b/tensorflow/python/platform/build_info_test.py index f0df0b756cc..81fb5a6e1e3 100644 --- a/tensorflow/python/platform/build_info_test.py +++ b/tensorflow/python/platform/build_info_test.py @@ -25,8 +25,8 @@ from tensorflow.python.platform import test class BuildInfoTest(test.TestCase): def testBuildInfo(self): - self.assertEqual(build_info.is_rocm_build, test.is_built_with_rocm()) - self.assertEqual(build_info.is_cuda_build, test.is_built_with_cuda()) + self.assertEqual(build_info.build_info["is_rocm_build"], test.is_built_with_rocm()) + self.assertEqual(build_info.build_info["is_cuda_build"], test.is_built_with_cuda()) if __name__ == '__main__': diff --git a/tensorflow/python/platform/self_check.py b/tensorflow/python/platform/self_check.py index f6cf7705e13..c10c4108c7d 100644 --- a/tensorflow/python/platform/self_check.py +++ b/tensorflow/python/platform/self_check.py @@ -20,6 +20,7 @@ from __future__ import print_function import os +MSVCP_DLL_NAMES = "msvcp_dll_names" try: from tensorflow.python.platform import build_info @@ -42,9 +43,9 @@ def preload_check(): # we load the Python extension, so that we can raise an actionable error # message if they are not found. import ctypes # pylint: disable=g-import-not-at-top - if hasattr(build_info, "msvcp_dll_names"): + if MSVCP_DLL_NAMES in build_info.build_info: missing = [] - for dll_name in build_info.msvcp_dll_names.split(","): + for dll_name in build_info.build_info[MSVCP_DLL_NAMES].split(","): try: ctypes.WinDLL(dll_name) except OSError: diff --git a/tensorflow/tensorflow.bzl b/tensorflow/tensorflow.bzl index d9229e00306..ada75fef957 100644 --- a/tensorflow/tensorflow.bzl +++ b/tensorflow/tensorflow.bzl @@ -2593,6 +2593,10 @@ def tf_version_info_genrule(name, out): arguments = "--generate \"$@\" --git_tag_override=${GIT_TAG_OVERRIDE:-}", ) +def dict_to_kv(d): + """Convert a dictionary to a space-joined list of key=value pairs.""" + return " " + " ".join(["%s=%s" % (k,v) for k, v in d.items()]) + def tf_py_build_info_genrule(name, out): _local_genrule( name = name, @@ -2600,16 +2604,17 @@ def tf_py_build_info_genrule(name, out): exec_tool = "//tensorflow/tools/build_info:gen_build_info", arguments = "--raw_generate \"$@\" " + - " --is_config_cuda " + if_cuda("True", "False") + - " --is_config_rocm " + if_rocm("True", "False") + - " --key_value " + - if_cuda(" cuda_version_number=${TF_CUDA_VERSION:-} cudnn_version_number=${TF_CUDNN_VERSION:-} ", "") + - if_windows(" msvcp_dll_names=msvcp140.dll,msvcp140_1.dll ", "") + - if_windows_cuda(" ".join([ - "nvcuda_dll_name=nvcuda.dll", - "cudart_dll_name=cudart64_$(echo $${TF_CUDA_VERSION:-} | sed \"s/\\.//\").dll", - "cudnn_dll_name=cudnn64_${TF_CUDNN_VERSION:-}.dll", - ]), ""), + " --key_value" + + " is_rocm_build=" + if_rocm("True", "False") + + " is_cuda_build=" + if_cuda("True", "False") + # TODO(angerson) Can we reliably load CUDA compute capabilities here? + if_windows(dict_to_kv({ + "msvcp_dll_names": "msvcp140.dll,msvcp140_1.dll" + }), "") + if_windows_cuda(dict_to_kv({ + "nvcuda_dll_name": "nvcuda.dll", + "cudart_dll_name": "cudart64_$$(echo $${TF_CUDA_VERSION:-} | sed \"s/\\.//\").dll", + "cudnn_dll_name": "cudnn64_$${TF_CUDNN_VERSION:-}.dll", + }), ""), ) def cc_library_with_android_deps( diff --git a/tensorflow/tools/build_info/gen_build_info.py b/tensorflow/tools/build_info/gen_build_info.py index 3180010bb13..0757a1f57a6 100755 --- a/tensorflow/tools/build_info/gen_build_info.py +++ b/tensorflow/tools/build_info/gen_build_info.py @@ -28,48 +28,28 @@ import six from third_party.gpus import find_cuda_config -def write_build_info(filename, is_config_cuda, is_config_rocm, key_value_list): +def write_build_info(filename, key_value_list): """Writes a Python that describes the build. Args: filename: filename to write to. - is_config_cuda: Whether this build is using CUDA. - is_config_rocm: Whether this build is using ROCm. key_value_list: A list of "key=value" strings that will be added to the - module as additional fields. - - Raises: - ValueError: If `key_value_list` includes the key "is_cuda_build", which - would clash with one of the default fields. + module's "build_info" dictionary as additional entries. """ - module_docstring = "\"\"\"Generates a Python module containing information " - module_docstring += "about the build.\"\"\"" - build_config_rocm_bool = "False" - build_config_cuda_bool = "False" - - if is_config_rocm == "True": - build_config_rocm_bool = "True" - elif is_config_cuda == "True": - build_config_cuda_bool = "True" - - key_value_pair_stmts = [] - if key_value_list: - for arg in key_value_list: - key, value = six.ensure_str(arg).split("=") - if key == "is_cuda_build": - raise ValueError("The key \"is_cuda_build\" cannot be passed as one of " - "the --key_value arguments.") - if key == "is_rocm_build": - raise ValueError("The key \"is_rocm_build\" cannot be passed as one of " - "the --key_value arguments.") - key_value_pair_stmts.append("%s = %r" % (key, value)) - key_value_pair_content = "\n".join(key_value_pair_stmts) + build_info = {} + for arg in key_value_list: + key, value = six.ensure_str(arg).split("=") + if value.lower() == "true": + build_info[key] = True + elif value.lower() == "false": + build_info[key] = False + else: + build_info[key] = value # Generate cuda_build_info, a dict describing the CUDA component versions # used to build TensorFlow. - cuda_build_info = "{}" - if is_config_cuda == "True": + if build_info.get("is_cuda_build", False): libs = ["_", "cuda", "cudnn"] if platform.system() == "Linux": if os.environ.get("TF_NEED_TENSORRT", "0") == "1": @@ -82,16 +62,15 @@ def write_build_info(filename, is_config_cuda, is_config_rocm, key_value_list): backup_argv = sys.argv sys.argv = libs cuda = find_cuda_config.find_cuda_config() - cuda_build_info = str({ - "cuda_version": cuda["cuda_version"], - "cudnn_version": cuda["cudnn_version"], - "tensorrt_version": cuda.get("tensorrt_version", None), - "nccl_version": cuda.get("nccl_version", None), - }) + + build_info["cuda_version"] = cuda["cuda_version"] + build_info["cudnn_version"] = cuda["cudnn_version"] + build_info["tensorrt_version"] = cuda.get("tensorrt_version", None) + build_info["nccl_version"] = cuda.get("nccl_version", None) sys.argv = backup_argv contents = f""" -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -105,15 +84,14 @@ def write_build_info(filename, is_config_cuda, is_config_rocm, key_value_list): # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -{module_docstring} +\"\"\"Auto-generated module providing information about the build.\"\"\"" from __future__ import absolute_import from __future__ import division from __future__ import print_function -is_rocm_build = {build_config_rocm_bool} -is_cuda_build = {build_config_cuda_bool} -cuda_build_info = {cuda_build_info} +from collections import namedtuple +build_info = {build_info} """ open(filename, "w").write(contents) @@ -121,16 +99,6 @@ cuda_build_info = {cuda_build_info} parser = argparse.ArgumentParser( description="""Build info injection into the PIP package.""") -parser.add_argument( - "--is_config_cuda", - type=str, - help="'True' for CUDA GPU builds, 'False' otherwise.") - -parser.add_argument( - "--is_config_rocm", - type=str, - help="'True' for ROCm GPU builds, 'False' otherwise.") - parser.add_argument("--raw_generate", type=str, help="Generate build_info.py") parser.add_argument( @@ -138,10 +106,9 @@ parser.add_argument( args = parser.parse_args() -if (args.raw_generate is not None) and (args.is_config_cuda is not None) and ( - args.is_config_rocm is not None): - write_build_info(args.raw_generate, args.is_config_cuda, args.is_config_rocm, - args.key_value) +if args.raw_generate: + print(args.key_value) + write_build_info(args.raw_generate, args.key_value) else: raise RuntimeError( - "--raw_generate, --is_config_cuda and --is_config_rocm must be used") + "--raw_generate must be used.") diff --git a/tensorflow/tools/pip_package/setup.py b/tensorflow/tools/pip_package/setup.py index 1e99d659830..c39bd254442 100644 --- a/tensorflow/tools/pip_package/setup.py +++ b/tensorflow/tools/pip_package/setup.py @@ -86,7 +86,7 @@ REQUIRED_PACKAGES = [ ] GPU_DESCRIPTION = '' -if build_info.is_cuda_build: +if build_info.build_info['is_cuda_build']: gpu_header = (f'\nTensorFlow {_VERSION} for NVIDIA GPUs was built with these ' 'platform and library versions:\n\n - ') cbi = build_info.cuda_build_info From 777482a6269a61218a5e3a4c3468d883573261be Mon Sep 17 00:00:00 2001 From: Austin Anderson <angerson@google.com> Date: Thu, 30 Apr 2020 13:14:45 -0700 Subject: [PATCH 4/6] Fix incorrect call to tf_export --- tensorflow/python/framework/config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/python/framework/config.py b/tensorflow/python/framework/config.py index 9997b9833a5..c00157f6193 100644 --- a/tensorflow/python/framework/config.py +++ b/tensorflow/python/framework/config.py @@ -656,7 +656,7 @@ def disable_mlir_graph_optimization(): context.context().enable_mlir_graph_optimization = False -@tf_export('config.get_build_info()') +@tf_export('config.get_build_info') def get_build_info(): """Get a dictionary describing TensorFlow's build environment. From 642dc30acaedcdc3daa2558766aa66871bd1714c Mon Sep 17 00:00:00 2001 From: Austin Anderson <angerson@google.com> Date: Thu, 30 Apr 2020 17:45:36 -0700 Subject: [PATCH 5/6] Fix typos from the latest version --- tensorflow/tensorflow.bzl | 2 +- tensorflow/tools/build_info/gen_build_info.py | 2 +- tensorflow/tools/pip_package/setup.py | 14 ++++++++------ 3 files changed, 10 insertions(+), 8 deletions(-) diff --git a/tensorflow/tensorflow.bzl b/tensorflow/tensorflow.bzl index ada75fef957..501b9f9b088 100644 --- a/tensorflow/tensorflow.bzl +++ b/tensorflow/tensorflow.bzl @@ -2608,7 +2608,7 @@ def tf_py_build_info_genrule(name, out): + " is_rocm_build=" + if_rocm("True", "False") + " is_cuda_build=" + if_cuda("True", "False") # TODO(angerson) Can we reliably load CUDA compute capabilities here? - if_windows(dict_to_kv({ + + if_windows(dict_to_kv({ "msvcp_dll_names": "msvcp140.dll,msvcp140_1.dll" }), "") + if_windows_cuda(dict_to_kv({ "nvcuda_dll_name": "nvcuda.dll", diff --git a/tensorflow/tools/build_info/gen_build_info.py b/tensorflow/tools/build_info/gen_build_info.py index 0757a1f57a6..a00fc064fa5 100755 --- a/tensorflow/tools/build_info/gen_build_info.py +++ b/tensorflow/tools/build_info/gen_build_info.py @@ -84,7 +84,7 @@ def write_build_info(filename, key_value_list): # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -\"\"\"Auto-generated module providing information about the build.\"\"\"" +\"\"\"Auto-generated module providing information about the build.\"\"\" from __future__ import absolute_import from __future__ import division from __future__ import print_function diff --git a/tensorflow/tools/pip_package/setup.py b/tensorflow/tools/pip_package/setup.py index c39bd254442..abb00eecdfb 100644 --- a/tensorflow/tools/pip_package/setup.py +++ b/tensorflow/tools/pip_package/setup.py @@ -85,16 +85,18 @@ REQUIRED_PACKAGES = [ 'scipy == 1.2.2;python_version<"3"', ] +# Generate a footer describing the CUDA technology this release was built +# against. GPU_DESCRIPTION = '' if build_info.build_info['is_cuda_build']: gpu_header = (f'\nTensorFlow {_VERSION} for NVIDIA GPUs was built with these ' 'platform and library versions:\n\n - ') - cbi = build_info.cuda_build_info - trt_ver = cbi['tensorrt_version'] - nccl_ver = cbi['nccl_version'] + bi = build_info.build_info + trt_ver = bi['tensorrt_version'] + nccl_ver = bi['nccl_version'] GPU_DESCRIPTION = gpu_header + '\n - '.join([ - 'NVIDIA CUDA ' + cbi['cuda_version'], - 'NVIDIA cuDNN ' + cbi['cudnn_version'], + 'NVIDIA CUDA ' + bi['cuda_version'], + 'NVIDIA cuDNN ' + bi['cudnn_version'], 'NVIDIA NCCL ' + 'not enabled' if not nccl_ver else nccl_ver, 'NVIDIA TensorRT ' + 'not enabled' if not trt_ver else trt_ver, ]) @@ -303,7 +305,7 @@ setup( zip_safe=False, # Accessible with importlib.metadata.metadata('tf-pkg-name').items() platforms=[ - f'{key}:{value}' for key, value in build_info.cuda_build_info.items() + f'{key}:{value}' for key, value in build_info.build_info.items() ], distclass=BinaryDistribution, cmdclass={ From bf687433298f29af2ee7fc1068329b50ed310693 Mon Sep 17 00:00:00 2001 From: Austin Anderson <angerson@google.com> Date: Fri, 1 May 2020 11:56:12 -0700 Subject: [PATCH 6/6] Remove extra debug print statement --- tensorflow/tools/build_info/gen_build_info.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tensorflow/tools/build_info/gen_build_info.py b/tensorflow/tools/build_info/gen_build_info.py index a00fc064fa5..42f62f5579f 100755 --- a/tensorflow/tools/build_info/gen_build_info.py +++ b/tensorflow/tools/build_info/gen_build_info.py @@ -107,7 +107,6 @@ parser.add_argument( args = parser.parse_args() if args.raw_generate: - print(args.key_value) write_build_info(args.raw_generate, args.key_value) else: raise RuntimeError(