Update multi_process_lib to handle file path for OSS keras build/test.
PiperOrigin-RevId: 346188693 Change-Id: I5aa80ee4e262989666b72a529d161ccef1f6ac37
This commit is contained in:
parent
0a15fbc048
commit
0b66713efa
@ -98,23 +98,27 @@ def _set_spawn_exe_path():
|
||||
"""
|
||||
# TODO(b/150264776): This does not work with Windows. Find a solution.
|
||||
if sys.argv[0].endswith('.py'):
|
||||
path = None
|
||||
# If all we have is a python module path, we'll need to make a guess for
|
||||
# the actual executable path.
|
||||
if 'bazel-out' in sys.argv[0]:
|
||||
# Guess the binary path under bazel. For target
|
||||
# //tensorflow/python/distribute:input_lib_test_multiworker_gpu, the
|
||||
# argv[0] is in the form of
|
||||
# /.../tensorflow/python/distribute/input_lib_test.py
|
||||
# and the binary is
|
||||
# /.../tensorflow/python/distribute/input_lib_test_multiworker_gpu
|
||||
org_tensorflow_base = sys.argv[0][:sys.argv[0].rfind('/org_tensorflow')]
|
||||
binary = os.environ['TEST_TARGET'][2:].replace(':', '/', 1)
|
||||
possible_path = os.path.join(org_tensorflow_base, 'org_tensorflow',
|
||||
binary)
|
||||
logging.info('Guessed test binary path: %s', possible_path)
|
||||
if os.access(possible_path, os.X_OK):
|
||||
path = possible_path
|
||||
def guess_path(package_root):
|
||||
# If all we have is a python module path, we'll need to make a guess for
|
||||
# the actual executable path.
|
||||
if 'bazel-out' in sys.argv[0] and package_root in sys.argv[0]:
|
||||
# Guess the binary path under bazel. For target
|
||||
# //tensorflow/python/distribute:input_lib_test_multiworker_gpu, the
|
||||
# argv[0] is in the form of
|
||||
# /.../tensorflow/python/distribute/input_lib_test.py
|
||||
# and the binary is
|
||||
# /.../tensorflow/python/distribute/input_lib_test_multiworker_gpu
|
||||
package_root_base = sys.argv[0][:sys.argv[0].rfind(package_root)]
|
||||
binary = os.environ['TEST_TARGET'][2:].replace(':', '/', 1)
|
||||
possible_path = os.path.join(package_root_base, package_root,
|
||||
binary)
|
||||
logging.info('Guessed test binary path: %s', possible_path)
|
||||
if os.access(possible_path, os.X_OK):
|
||||
return possible_path
|
||||
return None
|
||||
path = guess_path('org_tensorflow')
|
||||
if not path:
|
||||
path = guess_path('org_keras')
|
||||
if path is None:
|
||||
logging.error(
|
||||
'Cannot determine binary path. sys.argv[0]=%s os.environ=%s',
|
||||
|
Loading…
Reference in New Issue
Block a user