Shorten find_cuda_config finder

PiperOrigin-RevId: 353905388
Change-Id: Iaeed674c38bc9346d26e4701f413163212217c37
This commit is contained in:
Austin Anderson 2021-01-26 11:07:14 -08:00 committed by TensorFlower Gardener
parent f5d89fe581
commit b1f946aed5

View File

@ -20,8 +20,8 @@ from __future__ import print_function
import argparse
import errno
import glob
import os
import pathlib
import platform
import re
import subprocess
@ -1239,18 +1239,12 @@ def validate_cuda_config(environ_cp):
if environ_cp.get('TF_NCCL_VERSION', None):
cuda_libraries.append('nccl')
find_cuda_path = pathlib.Path('third_party/gpus/find_cuda_config.py')
if not pathlib.Path(find_cuda_path).is_file():
find_cuda_path = pathlib.Path('.').glob('**/' + str(find_cuda_path))
try:
find_cuda_path = find_cuda_path.__next__()
except StopIteration:
raise FileNotFoundError(
"Can't find 'find_cuda_config.py' script inside working directory")
paths = glob.glob('**/third_party/gpus/find_cuda_config.py', recursive=True)
if not paths:
raise FileNotFoundError(
"Can't find 'find_cuda_config.py' script inside working directory")
proc = subprocess.Popen(
[environ_cp['PYTHON_BIN_PATH'],
str(find_cuda_path)] + cuda_libraries,
[environ_cp['PYTHON_BIN_PATH'], paths[0]] + cuda_libraries,
stdout=subprocess.PIPE,
env=maybe_encode_env(environ_cp))