diff --git a/tensorflow/python/distribute/multi_worker_test_base.py b/tensorflow/python/distribute/multi_worker_test_base.py index 9e56e6d1bf7..5809182b2a8 100644 --- a/tensorflow/python/distribute/multi_worker_test_base.py +++ b/tensorflow/python/distribute/multi_worker_test_base.py @@ -234,6 +234,11 @@ class MultiProcessCluster(object): server_config = config_pb2.ConfigProto() server_config.device_count['GPU'] = 0 + # Set the environment variable to prevent hanging upon job failure and + # restart. Note that it defaults to 'use_caller' at Google, but defaults + # to False in OSS. + os.environ['GRPC_FAIL_FAST'] = 'use_caller' + server_lib.Server( cluster_spec, job_name=task_type,