Update multi_process_lib to handle file path for OSS keras build/test.

PiperOrigin-RevId: 346188693
Change-Id: I5aa80ee4e262989666b72a529d161ccef1f6ac37
This commit is contained in:
Scott Zhu 2020-12-07 15:06:06 -08:00 committed by TensorFlower Gardener
parent 0a15fbc048
commit 0b66713efa

View File

@ -98,23 +98,27 @@ def _set_spawn_exe_path():
"""
# TODO(b/150264776): This does not work with Windows. Find a solution.
if sys.argv[0].endswith('.py'):
path = None
# If all we have is a python module path, we'll need to make a guess for
# the actual executable path.
if 'bazel-out' in sys.argv[0]:
# Guess the binary path under bazel. For target
# //tensorflow/python/distribute:input_lib_test_multiworker_gpu, the
# argv[0] is in the form of
# /.../tensorflow/python/distribute/input_lib_test.py
# and the binary is
# /.../tensorflow/python/distribute/input_lib_test_multiworker_gpu
org_tensorflow_base = sys.argv[0][:sys.argv[0].rfind('/org_tensorflow')]
binary = os.environ['TEST_TARGET'][2:].replace(':', '/', 1)
possible_path = os.path.join(org_tensorflow_base, 'org_tensorflow',
binary)
logging.info('Guessed test binary path: %s', possible_path)
if os.access(possible_path, os.X_OK):
path = possible_path
def guess_path(package_root):
# If all we have is a python module path, we'll need to make a guess for
# the actual executable path.
if 'bazel-out' in sys.argv[0] and package_root in sys.argv[0]:
# Guess the binary path under bazel. For target
# //tensorflow/python/distribute:input_lib_test_multiworker_gpu, the
# argv[0] is in the form of
# /.../tensorflow/python/distribute/input_lib_test.py
# and the binary is
# /.../tensorflow/python/distribute/input_lib_test_multiworker_gpu
package_root_base = sys.argv[0][:sys.argv[0].rfind(package_root)]
binary = os.environ['TEST_TARGET'][2:].replace(':', '/', 1)
possible_path = os.path.join(package_root_base, package_root,
binary)
logging.info('Guessed test binary path: %s', possible_path)
if os.access(possible_path, os.X_OK):
return possible_path
return None
path = guess_path('org_tensorflow')
if not path:
path = guess_path('org_keras')
if path is None:
logging.error(
'Cannot determine binary path. sys.argv[0]=%s os.environ=%s',