enhancing the build_info_test to check for test.is_built_with_rocm() functionality

This commit is contained in:
Deven Desai 2019-06-26 21:21:55 +00:00
parent c9e11b03d1
commit 9a74324a66
3 changed files with 51 additions and 30 deletions
tensorflow

View File

@ -25,6 +25,7 @@ from tensorflow.python.platform import test
class BuildInfoTest(test.TestCase):
def testBuildInfo(self):
self.assertEqual(build_info.is_rocm_build, test.is_built_with_rocm())
self.assertEqual(build_info.is_cuda_build, test.is_built_with_cuda())

View File

@ -2315,8 +2315,9 @@ def tf_py_build_info_genrule():
name = "py_build_info_gen",
outs = ["platform/build_info.py"],
cmd =
"$(location //tensorflow/tools/build_info:gen_build_info) --raw_generate \"$@\" --build_config " +
if_cuda("cuda", "cpu") +
"$(location //tensorflow/tools/build_info:gen_build_info) --raw_generate \"$@\" " +
" --is_config_cuda " + if_cuda("True", "False") +
" --is_config_rocm " + if_rocm("True", "False") +
" --key_value " +
if_cuda(" cuda_version_number=$${TF_CUDA_VERSION:-} cudnn_version_number=$${TF_CUDNN_VERSION:-} ", "") +
if_windows(" msvcp_dll_name=msvcp140.dll ", "") +

View File

@ -19,8 +19,8 @@ from __future__ import print_function
import argparse
def write_build_info(filename, build_config, key_value_list):
"""Writes a Python that describes the build.
def write_build_info(filename, is_config_cuda, is_config_rocm, key_value_list):
"""Writes a Python that describes the build.
Args:
filename: filename to write to.
@ -33,24 +33,33 @@ def write_build_info(filename, build_config, key_value_list):
ValueError: If `key_value_list` includes the key "is_cuda_build", which
would clash with one of the default fields.
"""
module_docstring = "\"\"\"Generates a Python module containing information "
module_docstring += "about the build.\"\"\""
if build_config == "cuda":
build_config_bool = "True"
else:
build_config_bool = "False"
module_docstring = "\"\"\"Generates a Python module containing information "
module_docstring += "about the build.\"\"\""
key_value_pair_stmts = []
if key_value_list:
for arg in key_value_list:
key, value = arg.split("=")
if key == "is_cuda_build":
raise ValueError("The key \"is_cuda_build\" cannot be passed as one of "
"the --key_value arguments.")
key_value_pair_stmts.append("%s = %r" % (key, value))
key_value_pair_content = "\n".join(key_value_pair_stmts)
build_config_rocm_bool = "False"
build_config_cuda_bool = "False"
contents = """
if is_config_rocm == "True":
build_config_rocm_bool = "True"
elif is_config_cuda == "True":
build_config_cuda_bool = "True"
key_value_pair_stmts = []
if key_value_list:
for arg in key_value_list:
key, value = arg.split("=")
if key == "is_cuda_build":
raise ValueError(
"The key \"is_cuda_build\" cannot be passed as one of "
"the --key_value arguments.")
if key == "is_rocm_build":
raise ValueError(
"The key \"is_rocm_build\" cannot be passed as one of "
"the --key_value arguments.")
key_value_pair_stmts.append("%s = %r" % (key, value))
key_value_pair_content = "\n".join(key_value_pair_stmts)
contents = """
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@ -70,29 +79,39 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
is_rocm_build = %s
is_cuda_build = %s
%s
""" % (module_docstring, build_config_bool, key_value_pair_content)
open(filename, "w").write(contents)
""" % (module_docstring, build_config_rocm_bool, build_config_cuda_bool,
key_value_pair_content)
open(filename, "w").write(contents)
parser = argparse.ArgumentParser(
description="""Build info injection into the PIP package.""")
parser.add_argument(
"--build_config",
type=str,
help="Either 'cuda' for GPU builds or 'cpu' for CPU builds.")
parser.add_argument("--is_config_cuda",
type=str,
help="'True' for CUDA GPU builds, 'False' otherwise.")
parser.add_argument("--is_config_rocm",
type=str,
help="'True' for ROCm GPU builds, 'False' otherwise.")
parser.add_argument("--raw_generate", type=str, help="Generate build_info.py")
parser.add_argument("--key_value", type=str, nargs="*",
parser.add_argument("--key_value",
type=str,
nargs="*",
help="List of key=value pairs.")
args = parser.parse_args()
if args.raw_generate is not None and args.build_config is not None:
write_build_info(args.raw_generate, args.build_config, args.key_value)
if (args.raw_generate is not None) and (args.is_config_cuda is not None) and (
args.is_config_rocm is not None):
write_build_info(args.raw_generate, args.is_config_cuda,
args.is_config_rocm, args.key_value)
else:
raise RuntimeError("--raw_generate and --build_config must be used")
raise RuntimeError(
"--raw_generate, --is_config_cuda and --is_config_rocm must be used")