[tf.data service] Fix timeouts in distributed_epoch tests.

This fixes an issue where a test would get stuck retrying GetSplit requests after it had already shut down the dispatcher. The retry timeout is now configurable through the new `dispatcher_timeout_ms` worker config field, so tests can set it to 1 second instead of the default of 1 hour.

PiperOrigin-RevId: 337935450
Change-Id: I11bdfba4e731af314fc0016ee5b113dbc339cbdd
This commit is contained in:
Andrew Audibert 2020-10-19 14:33:30 -07:00 committed by TensorFlower Gardener
parent cca4ca7344
commit 209ede00c5
9 changed files with 34 additions and 20 deletions

View File

@ -22,10 +22,6 @@ limitations under the License.
namespace tensorflow { namespace tensorflow {
namespace data { namespace data {
namespace {
const int64 kRetryTimeoutMicros = 1000LL * 1000 * 60 * 60; // 60 minutes.
} // namespace
Status DataServiceSplitProvider::GetNext(Tensor* split, bool* end_of_splits) { Status DataServiceSplitProvider::GetNext(Tensor* split, bool* end_of_splits) {
mutex_lock l(mu_); mutex_lock l(mu_);
if (!dispatcher_) { if (!dispatcher_) {
@ -38,7 +34,8 @@ Status DataServiceSplitProvider::GetNext(Tensor* split, bool* end_of_splits) {
*end_of_splits); *end_of_splits);
}, },
"get next split", "get next split",
/*deadline_micros=*/Env::Default()->NowMicros() + kRetryTimeoutMicros); /*deadline_micros=*/Env::Default()->NowMicros() +
(timeout_ms_ * EnvTime::kMillisToMicros));
} }
Status DataServiceSplitProvider::Reset() { Status DataServiceSplitProvider::Reset() {

View File

@ -28,8 +28,12 @@ namespace data {
class DataServiceSplitProvider : public SplitProvider { class DataServiceSplitProvider : public SplitProvider {
public: public:
DataServiceSplitProvider(const std::string& address, DataServiceSplitProvider(const std::string& address,
const std::string& protocol, int64 job_id) const std::string& protocol, int64 job_id,
: address_(address), protocol_(protocol), job_id_(job_id) {} int64 timeout_ms)
: address_(address),
protocol_(protocol),
job_id_(job_id),
timeout_ms_(timeout_ms) {}
Status GetNext(Tensor* split, bool* end_of_splits) override; Status GetNext(Tensor* split, bool* end_of_splits) override;
Status Reset() override; Status Reset() override;
@ -42,6 +46,7 @@ class DataServiceSplitProvider : public SplitProvider {
const std::string address_; const std::string address_;
const std::string protocol_; const std::string protocol_;
const int64 job_id_; const int64 job_id_;
const int64 timeout_ms_;
mutex mu_; mutex mu_;
int64 repetition_ = 0; int64 repetition_ = 0;

View File

@ -150,7 +150,7 @@ Status DataServiceWorkerImpl::EnsureTaskInitialized(
case DISTRIBUTED_EPOCH: { case DISTRIBUTED_EPOCH: {
auto split_provider = absl::make_unique<DataServiceSplitProvider>( auto split_provider = absl::make_unique<DataServiceSplitProvider>(
config_.dispatcher_address(), config_.protocol(), config_.dispatcher_address(), config_.protocol(),
task.task_def.job_id()); task.task_def.job_id(), config_.dispatcher_timeout_ms());
TF_RETURN_IF_ERROR(task.dataset->MakeIterator(std::move(split_provider), TF_RETURN_IF_ERROR(task.dataset->MakeIterator(std::move(split_provider),
&task.iterator)); &task.iterator));
break; break;

View File

@ -37,4 +37,7 @@ message WorkerConfig {
string worker_address = 4; string worker_address = 4;
// How often the worker should heartbeat to the master. // How often the worker should heartbeat to the master.
int64 heartbeat_interval_ms = 5; int64 heartbeat_interval_ms = 5;
// How long to retry requests to the dispatcher before giving up and reporting
// an error.
int64 dispatcher_timeout_ms = 6;
} }

View File

@ -380,7 +380,6 @@ class DataServiceOpsTest(data_service_test_base.TestBase,
@combinations.generate(test_base.eager_only_combinations()) @combinations.generate(test_base.eager_only_combinations())
def testDistributeDistributedEpochTensorSlices(self): def testDistributeDistributedEpochTensorSlices(self):
self.skipTest("b/170910141")
cluster = self.create_cluster(num_workers=2) cluster = self.create_cluster(num_workers=2)
vals = [5, 1, 2, 4] vals = [5, 1, 2, 4]
ds = dataset_ops.Dataset.from_tensor_slices(vals) ds = dataset_ops.Dataset.from_tensor_slices(vals)
@ -390,7 +389,6 @@ class DataServiceOpsTest(data_service_test_base.TestBase,
@combinations.generate(test_base.eager_only_combinations()) @combinations.generate(test_base.eager_only_combinations())
def testDistributeDistributedEpochInterleave(self): def testDistributeDistributedEpochInterleave(self):
self.skipTest("b/170910141")
cluster = self.create_cluster(num_workers=2) cluster = self.create_cluster(num_workers=2)
elements = [1, 5, 0] elements = [1, 5, 0]
ds = dataset_ops.Dataset.from_tensor_slices(elements) ds = dataset_ops.Dataset.from_tensor_slices(elements)
@ -401,7 +399,6 @@ class DataServiceOpsTest(data_service_test_base.TestBase,
@combinations.generate(test_base.eager_only_combinations()) @combinations.generate(test_base.eager_only_combinations())
def testDistributeDistributedEpochParallelInterleave(self): def testDistributeDistributedEpochParallelInterleave(self):
self.skipTest("b/170910141")
cluster = self.create_cluster(num_workers=2) cluster = self.create_cluster(num_workers=2)
elements = [1, 5, 0] elements = [1, 5, 0]
ds = dataset_ops.Dataset.from_tensor_slices(elements) ds = dataset_ops.Dataset.from_tensor_slices(elements)
@ -414,7 +411,6 @@ class DataServiceOpsTest(data_service_test_base.TestBase,
@combinations.generate(test_base.eager_only_combinations()) @combinations.generate(test_base.eager_only_combinations())
def testDistributeDistributedEpochFlatMap(self): def testDistributeDistributedEpochFlatMap(self):
self.skipTest("b/170910141")
cluster = self.create_cluster(num_workers=2) cluster = self.create_cluster(num_workers=2)
elements = [1, 5, 0] elements = [1, 5, 0]
ds = dataset_ops.Dataset.from_tensor_slices(elements) ds = dataset_ops.Dataset.from_tensor_slices(elements)
@ -425,7 +421,6 @@ class DataServiceOpsTest(data_service_test_base.TestBase,
@combinations.generate(test_base.eager_only_combinations()) @combinations.generate(test_base.eager_only_combinations())
def testDistributeDistributedEpochRepeat(self): def testDistributeDistributedEpochRepeat(self):
self.skipTest("b/170910141")
cluster = self.create_cluster(num_workers=2) cluster = self.create_cluster(num_workers=2)
num_repeats = 5 num_repeats = 5
num_elements = 20 num_elements = 20
@ -437,7 +432,6 @@ class DataServiceOpsTest(data_service_test_base.TestBase,
@combinations.generate(test_base.eager_only_combinations()) @combinations.generate(test_base.eager_only_combinations())
def testDistributeDistributedEpochShuffleAndRepeat(self): def testDistributeDistributedEpochShuffleAndRepeat(self):
self.skipTest("b/170910141")
cluster = self.create_cluster(num_workers=2) cluster = self.create_cluster(num_workers=2)
num_repeats = 5 num_repeats = 5
num_elements = 20 num_elements = 20
@ -462,7 +456,6 @@ class DataServiceOpsTest(data_service_test_base.TestBase,
@combinations.generate(test_base.eager_only_combinations()) @combinations.generate(test_base.eager_only_combinations())
def testDistributeDistributedEpoch(self): def testDistributeDistributedEpoch(self):
self.skipTest("b/170910141")
cluster = self.create_cluster(num_workers=2) cluster = self.create_cluster(num_workers=2)
num_elements = 100 num_elements = 100
ds = dataset_ops.Dataset.range(num_elements) ds = dataset_ops.Dataset.range(num_elements)

View File

@ -99,7 +99,8 @@ class TestCluster(object):
server_lib.WorkerServer( server_lib.WorkerServer(
server_lib.WorkerConfig( server_lib.WorkerConfig(
dispatcher_address=self.dispatcher_address(), dispatcher_address=self.dispatcher_address(),
heartbeat_interval_ms=TEST_HEARTBEAT_INTERVAL_MS), heartbeat_interval_ms=TEST_HEARTBEAT_INTERVAL_MS,
dispatcher_timeout_ms=1000),
start=start)) start=start))
def start_dispatcher(self): def start_dispatcher(self):

View File

@ -217,7 +217,7 @@ class DispatchServer(object):
class WorkerConfig( class WorkerConfig(
collections.namedtuple("WorkerConfig", [ collections.namedtuple("WorkerConfig", [
"dispatcher_address", "worker_address", "port", "protocol", "dispatcher_address", "worker_address", "port", "protocol",
"heartbeat_interval_ms" "heartbeat_interval_ms", "dispatcher_timeout_ms"
])): ])):
"""Configuration class for tf.data service dispatchers. """Configuration class for tf.data service dispatchers.
@ -235,6 +235,8 @@ class WorkerConfig(
reasonable default. A higher value will reduce the load on the dispatcher, reasonable default. A higher value will reduce the load on the dispatcher,
while a lower value will reduce the time it takes to reclaim resources while a lower value will reduce the time it takes to reclaim resources
from finished jobs. from finished jobs.
dispatcher_timeout_ms: How long, in milliseconds, to retry requests to the
dispatcher before giving up and reporting an error. Defaults to 1 hour.
""" """
def __new__(cls, def __new__(cls,
@ -242,15 +244,19 @@ class WorkerConfig(
worker_address=None, worker_address=None,
port=0, port=0,
protocol="grpc", protocol="grpc",
heartbeat_interval_ms=None): heartbeat_interval_ms=None,
dispatcher_timeout_ms=None):
if worker_address is None: if worker_address is None:
worker_address = "localhost:%port%" worker_address = "localhost:%port%"
if heartbeat_interval_ms is None: if heartbeat_interval_ms is None:
heartbeat_interval_ms = 30 * 1000 # 30 seconds heartbeat_interval_ms = 30 * 1000 # 30 seconds
if dispatcher_timeout_ms is None:
dispatcher_timeout_ms = 60 * 60 * 1000 # 1 hour
return super(WorkerConfig, return super(WorkerConfig,
cls).__new__(cls, dispatcher_address, worker_address, port, cls).__new__(cls, dispatcher_address, worker_address, port,
protocol, heartbeat_interval_ms) protocol, heartbeat_interval_ms,
dispatcher_timeout_ms)
@tf_export("data.experimental.service.WorkerServer", v1=[]) @tf_export("data.experimental.service.WorkerServer", v1=[])
@ -299,7 +305,8 @@ class WorkerServer(object):
worker_address=config.worker_address, worker_address=config.worker_address,
port=config.port, port=config.port,
protocol=config.protocol, protocol=config.protocol,
heartbeat_interval_ms=config.heartbeat_interval_ms) heartbeat_interval_ms=config.heartbeat_interval_ms,
dispatcher_timeout_ms=config.dispatcher_timeout_ms)
self._server = _pywrap_server_lib.TF_DATA_NewWorkerServer( self._server = _pywrap_server_lib.TF_DATA_NewWorkerServer(
config_proto.SerializeToString()) config_proto.SerializeToString())
if start: if start:

View File

@ -7,6 +7,10 @@ tf_class {
name: "dispatcher_address" name: "dispatcher_address"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"
} }
member {
name: "dispatcher_timeout_ms"
mtype: "<type \'property\'>"
}
member { member {
name: "heartbeat_interval_ms" name: "heartbeat_interval_ms"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"

View File

@ -7,6 +7,10 @@ tf_class {
name: "dispatcher_address" name: "dispatcher_address"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"
} }
member {
name: "dispatcher_timeout_ms"
mtype: "<type \'property\'>"
}
member { member {
name: "heartbeat_interval_ms" name: "heartbeat_interval_ms"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"