diff --git a/tensorflow/core/data/service/split_provider.cc b/tensorflow/core/data/service/split_provider.cc index b3100d52ff1..4ebb25348b6 100644 --- a/tensorflow/core/data/service/split_provider.cc +++ b/tensorflow/core/data/service/split_provider.cc @@ -22,10 +22,6 @@ limitations under the License. namespace tensorflow { namespace data { -namespace { -const int64 kRetryTimeoutMicros = 1000LL * 1000 * 60 * 60; // 60 minutes. -} // namespace - Status DataServiceSplitProvider::GetNext(Tensor* split, bool* end_of_splits) { mutex_lock l(mu_); if (!dispatcher_) { @@ -38,7 +34,8 @@ Status DataServiceSplitProvider::GetNext(Tensor* split, bool* end_of_splits) { *end_of_splits); }, "get next split", - /*deadline_micros=*/Env::Default()->NowMicros() + kRetryTimeoutMicros); + /*deadline_micros=*/Env::Default()->NowMicros() + + (timeout_ms_ * EnvTime::kMillisToMicros)); } Status DataServiceSplitProvider::Reset() { diff --git a/tensorflow/core/data/service/split_provider.h b/tensorflow/core/data/service/split_provider.h index 110b9e26ec7..57091de9db1 100644 --- a/tensorflow/core/data/service/split_provider.h +++ b/tensorflow/core/data/service/split_provider.h @@ -28,8 +28,12 @@ namespace data { class DataServiceSplitProvider : public SplitProvider { public: DataServiceSplitProvider(const std::string& address, - const std::string& protocol, int64 job_id) - : address_(address), protocol_(protocol), job_id_(job_id) {} + const std::string& protocol, int64 job_id, + int64 timeout_ms) + : address_(address), + protocol_(protocol), + job_id_(job_id), + timeout_ms_(timeout_ms) {} Status GetNext(Tensor* split, bool* end_of_splits) override; Status Reset() override; @@ -42,6 +46,7 @@ class DataServiceSplitProvider : public SplitProvider { const std::string address_; const std::string protocol_; const int64 job_id_; + const int64 timeout_ms_; mutex mu_; int64 repetition_ = 0; diff --git a/tensorflow/core/data/service/worker_impl.cc b/tensorflow/core/data/service/worker_impl.cc index 98862b1f176..4621e1e8a80 100644 --- a/tensorflow/core/data/service/worker_impl.cc +++ b/tensorflow/core/data/service/worker_impl.cc @@ -150,7 +150,7 @@ Status DataServiceWorkerImpl::EnsureTaskInitialized( case DISTRIBUTED_EPOCH: { auto split_provider = absl::make_unique( config_.dispatcher_address(), config_.protocol(), - task.task_def.job_id()); + task.task_def.job_id(), config_.dispatcher_timeout_ms()); TF_RETURN_IF_ERROR(task.dataset->MakeIterator(std::move(split_provider), &task.iterator)); break; diff --git a/tensorflow/core/protobuf/data/experimental/service_config.proto b/tensorflow/core/protobuf/data/experimental/service_config.proto index 7a0aa16e2c4..3dcd2cd48d0 100644 --- a/tensorflow/core/protobuf/data/experimental/service_config.proto +++ b/tensorflow/core/protobuf/data/experimental/service_config.proto @@ -37,4 +37,7 @@ message WorkerConfig { string worker_address = 4; // How often the worker should heartbeat to the master. int64 heartbeat_interval_ms = 5; + // How long to retry requests to the dispatcher before giving up and reporting + // an error. + int64 dispatcher_timeout_ms = 6; } diff --git a/tensorflow/python/data/experimental/kernel_tests/data_service_ops_test.py b/tensorflow/python/data/experimental/kernel_tests/data_service_ops_test.py index 8a0617f4dee..ddd301d1540 100644 --- a/tensorflow/python/data/experimental/kernel_tests/data_service_ops_test.py +++ b/tensorflow/python/data/experimental/kernel_tests/data_service_ops_test.py @@ -380,7 +380,6 @@ class DataServiceOpsTest(data_service_test_base.TestBase, @combinations.generate(test_base.eager_only_combinations()) def testDistributeDistributedEpochTensorSlices(self): - self.skipTest("b/170910141") cluster = self.create_cluster(num_workers=2) vals = [5, 1, 2, 4] ds = dataset_ops.Dataset.from_tensor_slices(vals) @@ -390,7 +389,6 @@ class DataServiceOpsTest(data_service_test_base.TestBase, @combinations.generate(test_base.eager_only_combinations()) def testDistributeDistributedEpochInterleave(self): - self.skipTest("b/170910141") cluster = self.create_cluster(num_workers=2) elements = [1, 5, 0] ds = dataset_ops.Dataset.from_tensor_slices(elements) @@ -401,7 +399,6 @@ class DataServiceOpsTest(data_service_test_base.TestBase, @combinations.generate(test_base.eager_only_combinations()) def testDistributeDistributedEpochParallelInterleave(self): - self.skipTest("b/170910141") cluster = self.create_cluster(num_workers=2) elements = [1, 5, 0] ds = dataset_ops.Dataset.from_tensor_slices(elements) @@ -414,7 +411,6 @@ class DataServiceOpsTest(data_service_test_base.TestBase, @combinations.generate(test_base.eager_only_combinations()) def testDistributeDistributedEpochFlatMap(self): - self.skipTest("b/170910141") cluster = self.create_cluster(num_workers=2) elements = [1, 5, 0] ds = dataset_ops.Dataset.from_tensor_slices(elements) @@ -425,7 +421,6 @@ class DataServiceOpsTest(data_service_test_base.TestBase, @combinations.generate(test_base.eager_only_combinations()) def testDistributeDistributedEpochRepeat(self): - self.skipTest("b/170910141") cluster = self.create_cluster(num_workers=2) num_repeats = 5 num_elements = 20 @@ -437,7 +432,6 @@ class DataServiceOpsTest(data_service_test_base.TestBase, @combinations.generate(test_base.eager_only_combinations()) def testDistributeDistributedEpochShuffleAndRepeat(self): - self.skipTest("b/170910141") cluster = self.create_cluster(num_workers=2) num_repeats = 5 num_elements = 20 @@ -462,7 +456,6 @@ class DataServiceOpsTest(data_service_test_base.TestBase, @combinations.generate(test_base.eager_only_combinations()) def testDistributeDistributedEpoch(self): - self.skipTest("b/170910141") cluster = self.create_cluster(num_workers=2) num_elements = 100 ds = dataset_ops.Dataset.range(num_elements) diff --git a/tensorflow/python/data/experimental/kernel_tests/data_service_test_base.py b/tensorflow/python/data/experimental/kernel_tests/data_service_test_base.py index 0bb1383a56b..0e48e1f4dd9 100644 --- a/tensorflow/python/data/experimental/kernel_tests/data_service_test_base.py +++ b/tensorflow/python/data/experimental/kernel_tests/data_service_test_base.py @@ -99,7 +99,8 @@ class TestCluster(object): server_lib.WorkerServer( server_lib.WorkerConfig( dispatcher_address=self.dispatcher_address(), - heartbeat_interval_ms=TEST_HEARTBEAT_INTERVAL_MS), + heartbeat_interval_ms=TEST_HEARTBEAT_INTERVAL_MS, + dispatcher_timeout_ms=1000), start=start)) def start_dispatcher(self): diff --git a/tensorflow/python/data/experimental/service/server_lib.py b/tensorflow/python/data/experimental/service/server_lib.py index 95179a4a7df..addd20fb73b 100644 --- a/tensorflow/python/data/experimental/service/server_lib.py +++ b/tensorflow/python/data/experimental/service/server_lib.py @@ -217,7 +217,7 @@ class DispatchServer(object): class WorkerConfig( collections.namedtuple("WorkerConfig", [ "dispatcher_address", "worker_address", "port", "protocol", - "heartbeat_interval_ms" + "heartbeat_interval_ms", "dispatcher_timeout_ms" ])): """Configuration class for tf.data service dispatchers. @@ -235,6 +235,8 @@ class WorkerConfig( reasonable default. A higher value will reduce the load on the dispatcher, while a lower value will reduce the time it takes to reclaim resources from finished jobs. + dispatcher_timeout_ms: How long, in milliseconds, to retry requests to the + dispatcher before giving up and reporting an error. Defaults to 1 hour. """ def __new__(cls, @@ -242,15 +244,19 @@ class WorkerConfig( worker_address=None, port=0, protocol="grpc", - heartbeat_interval_ms=None): + heartbeat_interval_ms=None, + dispatcher_timeout_ms=None): if worker_address is None: worker_address = "localhost:%port%" if heartbeat_interval_ms is None: heartbeat_interval_ms = 30 * 1000 # 30 seconds + if dispatcher_timeout_ms is None: + dispatcher_timeout_ms = 60 * 60 * 1000 # 1 hour return super(WorkerConfig, cls).__new__(cls, dispatcher_address, worker_address, port, - protocol, heartbeat_interval_ms) + protocol, heartbeat_interval_ms, + dispatcher_timeout_ms) @tf_export("data.experimental.service.WorkerServer", v1=[]) @@ -299,7 +305,8 @@ class WorkerServer(object): worker_address=config.worker_address, port=config.port, protocol=config.protocol, - heartbeat_interval_ms=config.heartbeat_interval_ms) + heartbeat_interval_ms=config.heartbeat_interval_ms, + dispatcher_timeout_ms=config.dispatcher_timeout_ms) self._server = _pywrap_server_lib.TF_DATA_NewWorkerServer( config_proto.SerializeToString()) if start: diff --git a/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.service.-worker-config.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.service.-worker-config.pbtxt index d8eaf9bc7d7..63878ebc25d 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.service.-worker-config.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.service.-worker-config.pbtxt @@ -7,6 +7,10 @@ tf_class { name: "dispatcher_address" mtype: "" } + member { + name: "dispatcher_timeout_ms" + mtype: "" + } member { name: "heartbeat_interval_ms" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.service.-worker-config.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.service.-worker-config.pbtxt index d8eaf9bc7d7..63878ebc25d 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.service.-worker-config.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.service.-worker-config.pbtxt @@ -7,6 +7,10 @@ tf_class { name: "dispatcher_address" mtype: "" } + member { + name: "dispatcher_timeout_ms" + mtype: "" + } member { name: "heartbeat_interval_ms" mtype: ""