From 209ede00c524a81c3b5f9d6d71ab2603e568447a Mon Sep 17 00:00:00 2001 From: Andrew Audibert Date: Mon, 19 Oct 2020 14:33:30 -0700 Subject: [PATCH] [tf.data service] Fix timeouts in distributed_epoch tests. This fixes an issue where tests would get stuck retrying GetSplit requests to the dispatcher after the test had already shut down the dispatcher. Now the retry period is configurable, so that tests can set it to 1 second instead of the default 1 hour. PiperOrigin-RevId: 337935450 Change-Id: I11bdfba4e731af314fc0016ee5b113dbc339cbdd --- tensorflow/core/data/service/split_provider.cc | 7 ++----- tensorflow/core/data/service/split_provider.h | 9 +++++++-- tensorflow/core/data/service/worker_impl.cc | 2 +- .../data/experimental/service_config.proto | 3 +++ .../kernel_tests/data_service_ops_test.py | 7 ------- .../kernel_tests/data_service_test_base.py | 3 ++- .../data/experimental/service/server_lib.py | 15 +++++++++++---- ...data.experimental.service.-worker-config.pbtxt | 4 ++++ ...data.experimental.service.-worker-config.pbtxt | 4 ++++ 9 files changed, 34 insertions(+), 20 deletions(-) diff --git a/tensorflow/core/data/service/split_provider.cc b/tensorflow/core/data/service/split_provider.cc index b3100d52ff1..4ebb25348b6 100644 --- a/tensorflow/core/data/service/split_provider.cc +++ b/tensorflow/core/data/service/split_provider.cc @@ -22,10 +22,6 @@ limitations under the License. namespace tensorflow { namespace data { -namespace { -const int64 kRetryTimeoutMicros = 1000LL * 1000 * 60 * 60; // 60 minutes. -} // namespace - Status DataServiceSplitProvider::GetNext(Tensor* split, bool* end_of_splits) { mutex_lock l(mu_); if (!dispatcher_) { @@ -38,7 +34,8 @@ Status DataServiceSplitProvider::GetNext(Tensor* split, bool* end_of_splits) { *end_of_splits); }, "get next split", - /*deadline_micros=*/Env::Default()->NowMicros() + kRetryTimeoutMicros); + /*deadline_micros=*/Env::Default()->NowMicros() + + (timeout_ms_ * EnvTime::kMillisToMicros)); } Status DataServiceSplitProvider::Reset() { diff --git a/tensorflow/core/data/service/split_provider.h b/tensorflow/core/data/service/split_provider.h index 110b9e26ec7..57091de9db1 100644 --- a/tensorflow/core/data/service/split_provider.h +++ b/tensorflow/core/data/service/split_provider.h @@ -28,8 +28,12 @@ namespace data { class DataServiceSplitProvider : public SplitProvider { public: DataServiceSplitProvider(const std::string& address, - const std::string& protocol, int64 job_id) - : address_(address), protocol_(protocol), job_id_(job_id) {} + const std::string& protocol, int64 job_id, + int64 timeout_ms) + : address_(address), + protocol_(protocol), + job_id_(job_id), + timeout_ms_(timeout_ms) {} Status GetNext(Tensor* split, bool* end_of_splits) override; Status Reset() override; @@ -42,6 +46,7 @@ class DataServiceSplitProvider : public SplitProvider { const std::string address_; const std::string protocol_; const int64 job_id_; + const int64 timeout_ms_; mutex mu_; int64 repetition_ = 0; diff --git a/tensorflow/core/data/service/worker_impl.cc b/tensorflow/core/data/service/worker_impl.cc index 98862b1f176..4621e1e8a80 100644 --- a/tensorflow/core/data/service/worker_impl.cc +++ b/tensorflow/core/data/service/worker_impl.cc @@ -150,7 +150,7 @@ Status DataServiceWorkerImpl::EnsureTaskInitialized( case DISTRIBUTED_EPOCH: { auto split_provider = absl::make_unique( config_.dispatcher_address(), config_.protocol(), - task.task_def.job_id()); + task.task_def.job_id(), config_.dispatcher_timeout_ms()); TF_RETURN_IF_ERROR(task.dataset->MakeIterator(std::move(split_provider), &task.iterator)); break; diff --git a/tensorflow/core/protobuf/data/experimental/service_config.proto b/tensorflow/core/protobuf/data/experimental/service_config.proto index 7a0aa16e2c4..3dcd2cd48d0 100644 --- a/tensorflow/core/protobuf/data/experimental/service_config.proto +++ b/tensorflow/core/protobuf/data/experimental/service_config.proto @@ -37,4 +37,7 @@ message WorkerConfig { string worker_address = 4; // How often the worker should heartbeat to the master. int64 heartbeat_interval_ms = 5; + // How long to retry requests to the dispatcher before giving up and reporting + // an error. + int64 dispatcher_timeout_ms = 6; } diff --git a/tensorflow/python/data/experimental/kernel_tests/data_service_ops_test.py b/tensorflow/python/data/experimental/kernel_tests/data_service_ops_test.py index 8a0617f4dee..ddd301d1540 100644 --- a/tensorflow/python/data/experimental/kernel_tests/data_service_ops_test.py +++ b/tensorflow/python/data/experimental/kernel_tests/data_service_ops_test.py @@ -380,7 +380,6 @@ class DataServiceOpsTest(data_service_test_base.TestBase, @combinations.generate(test_base.eager_only_combinations()) def testDistributeDistributedEpochTensorSlices(self): - self.skipTest("b/170910141") cluster = self.create_cluster(num_workers=2) vals = [5, 1, 2, 4] ds = dataset_ops.Dataset.from_tensor_slices(vals) @@ -390,7 +389,6 @@ class DataServiceOpsTest(data_service_test_base.TestBase, @combinations.generate(test_base.eager_only_combinations()) def testDistributeDistributedEpochInterleave(self): - self.skipTest("b/170910141") cluster = self.create_cluster(num_workers=2) elements = [1, 5, 0] ds = dataset_ops.Dataset.from_tensor_slices(elements) @@ -401,7 +399,6 @@ class DataServiceOpsTest(data_service_test_base.TestBase, @combinations.generate(test_base.eager_only_combinations()) def testDistributeDistributedEpochParallelInterleave(self): - self.skipTest("b/170910141") cluster = self.create_cluster(num_workers=2) elements = [1, 5, 0] ds = dataset_ops.Dataset.from_tensor_slices(elements) @@ -414,7 +411,6 @@ class DataServiceOpsTest(data_service_test_base.TestBase, @combinations.generate(test_base.eager_only_combinations()) def testDistributeDistributedEpochFlatMap(self): - self.skipTest("b/170910141") cluster = self.create_cluster(num_workers=2) elements = [1, 5, 0] ds = dataset_ops.Dataset.from_tensor_slices(elements) @@ -425,7 +421,6 @@ class DataServiceOpsTest(data_service_test_base.TestBase, @combinations.generate(test_base.eager_only_combinations()) def testDistributeDistributedEpochRepeat(self): - self.skipTest("b/170910141") cluster = self.create_cluster(num_workers=2) num_repeats = 5 num_elements = 20 @@ -437,7 +432,6 @@ class DataServiceOpsTest(data_service_test_base.TestBase, @combinations.generate(test_base.eager_only_combinations()) def testDistributeDistributedEpochShuffleAndRepeat(self): - self.skipTest("b/170910141") cluster = self.create_cluster(num_workers=2) num_repeats = 5 num_elements = 20 @@ -462,7 +456,6 @@ class DataServiceOpsTest(data_service_test_base.TestBase, @combinations.generate(test_base.eager_only_combinations()) def testDistributeDistributedEpoch(self): - self.skipTest("b/170910141") cluster = self.create_cluster(num_workers=2) num_elements = 100 ds = dataset_ops.Dataset.range(num_elements) diff --git a/tensorflow/python/data/experimental/kernel_tests/data_service_test_base.py b/tensorflow/python/data/experimental/kernel_tests/data_service_test_base.py index 0bb1383a56b..0e48e1f4dd9 100644 --- a/tensorflow/python/data/experimental/kernel_tests/data_service_test_base.py +++ b/tensorflow/python/data/experimental/kernel_tests/data_service_test_base.py @@ -99,7 +99,8 @@ class TestCluster(object): server_lib.WorkerServer( server_lib.WorkerConfig( dispatcher_address=self.dispatcher_address(), - heartbeat_interval_ms=TEST_HEARTBEAT_INTERVAL_MS), + heartbeat_interval_ms=TEST_HEARTBEAT_INTERVAL_MS, + dispatcher_timeout_ms=1000), start=start)) def start_dispatcher(self): diff --git a/tensorflow/python/data/experimental/service/server_lib.py b/tensorflow/python/data/experimental/service/server_lib.py index 95179a4a7df..addd20fb73b 100644 --- a/tensorflow/python/data/experimental/service/server_lib.py +++ b/tensorflow/python/data/experimental/service/server_lib.py @@ -217,7 +217,7 @@ class DispatchServer(object): class WorkerConfig( collections.namedtuple("WorkerConfig", [ "dispatcher_address", "worker_address", "port", "protocol", - "heartbeat_interval_ms" + "heartbeat_interval_ms", "dispatcher_timeout_ms" ])): """Configuration class for tf.data service dispatchers. @@ -235,6 +235,8 @@ class WorkerConfig( reasonable default. A higher value will reduce the load on the dispatcher, while a lower value will reduce the time it takes to reclaim resources from finished jobs. + dispatcher_timeout_ms: How long, in milliseconds, to retry requests to the + dispatcher before giving up and reporting an error. Defaults to 1 hour. """ def __new__(cls, @@ -242,15 +244,19 @@ class WorkerConfig( worker_address=None, port=0, protocol="grpc", - heartbeat_interval_ms=None): + heartbeat_interval_ms=None, + dispatcher_timeout_ms=None): if worker_address is None: worker_address = "localhost:%port%" if heartbeat_interval_ms is None: heartbeat_interval_ms = 30 * 1000 # 30 seconds + if dispatcher_timeout_ms is None: + dispatcher_timeout_ms = 60 * 60 * 1000 # 1 hour return super(WorkerConfig, cls).__new__(cls, dispatcher_address, worker_address, port, - protocol, heartbeat_interval_ms) + protocol, heartbeat_interval_ms, + dispatcher_timeout_ms) @tf_export("data.experimental.service.WorkerServer", v1=[]) @@ -299,7 +305,8 @@ class WorkerServer(object): worker_address=config.worker_address, port=config.port, protocol=config.protocol, - heartbeat_interval_ms=config.heartbeat_interval_ms) + heartbeat_interval_ms=config.heartbeat_interval_ms, + dispatcher_timeout_ms=config.dispatcher_timeout_ms) self._server = _pywrap_server_lib.TF_DATA_NewWorkerServer( config_proto.SerializeToString()) if start: diff --git a/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.service.-worker-config.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.service.-worker-config.pbtxt index d8eaf9bc7d7..63878ebc25d 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.service.-worker-config.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.service.-worker-config.pbtxt @@ -7,6 +7,10 @@ tf_class { name: "dispatcher_address" mtype: "" } + member { + name: "dispatcher_timeout_ms" + mtype: "" + } member { name: "heartbeat_interval_ms" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.service.-worker-config.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.service.-worker-config.pbtxt index d8eaf9bc7d7..63878ebc25d 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.service.-worker-config.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.service.-worker-config.pbtxt @@ -7,6 +7,10 @@ tf_class { name: "dispatcher_address" mtype: "" } + member { + name: "dispatcher_timeout_ms" + mtype: "" + } member { name: "heartbeat_interval_ms" mtype: ""