[tf.data service] Fix timeouts in distributed_epoch tests.
This fixes an issue where tests would get stuck retrying GetSplit requests to the dispatcher after the test had already shut down the dispatcher. The retry timeout is now configurable, so tests can set it to 1 second instead of the default 1 hour.

PiperOrigin-RevId: 337935450
Change-Id: I11bdfba4e731af314fc0016ee5b113dbc339cbdd
This commit is contained in:
parent
cca4ca7344
commit
209ede00c5
@ -22,10 +22,6 @@ limitations under the License.
|
||||
namespace tensorflow {
|
||||
namespace data {
|
||||
|
||||
namespace {
|
||||
const int64 kRetryTimeoutMicros = 1000LL * 1000 * 60 * 60; // 60 minutes.
|
||||
} // namespace
|
||||
|
||||
Status DataServiceSplitProvider::GetNext(Tensor* split, bool* end_of_splits) {
|
||||
mutex_lock l(mu_);
|
||||
if (!dispatcher_) {
|
||||
@ -38,7 +34,8 @@ Status DataServiceSplitProvider::GetNext(Tensor* split, bool* end_of_splits) {
|
||||
*end_of_splits);
|
||||
},
|
||||
"get next split",
|
||||
/*deadline_micros=*/Env::Default()->NowMicros() + kRetryTimeoutMicros);
|
||||
/*deadline_micros=*/Env::Default()->NowMicros() +
|
||||
(timeout_ms_ * EnvTime::kMillisToMicros));
|
||||
}
|
||||
|
||||
Status DataServiceSplitProvider::Reset() {
|
||||
|
@ -28,8 +28,12 @@ namespace data {
|
||||
class DataServiceSplitProvider : public SplitProvider {
|
||||
public:
|
||||
DataServiceSplitProvider(const std::string& address,
|
||||
const std::string& protocol, int64 job_id)
|
||||
: address_(address), protocol_(protocol), job_id_(job_id) {}
|
||||
const std::string& protocol, int64 job_id,
|
||||
int64 timeout_ms)
|
||||
: address_(address),
|
||||
protocol_(protocol),
|
||||
job_id_(job_id),
|
||||
timeout_ms_(timeout_ms) {}
|
||||
|
||||
Status GetNext(Tensor* split, bool* end_of_splits) override;
|
||||
Status Reset() override;
|
||||
@ -42,6 +46,7 @@ class DataServiceSplitProvider : public SplitProvider {
|
||||
const std::string address_;
|
||||
const std::string protocol_;
|
||||
const int64 job_id_;
|
||||
const int64 timeout_ms_;
|
||||
|
||||
mutex mu_;
|
||||
int64 repetition_ = 0;
|
||||
|
@ -150,7 +150,7 @@ Status DataServiceWorkerImpl::EnsureTaskInitialized(
|
||||
case DISTRIBUTED_EPOCH: {
|
||||
auto split_provider = absl::make_unique<DataServiceSplitProvider>(
|
||||
config_.dispatcher_address(), config_.protocol(),
|
||||
task.task_def.job_id());
|
||||
task.task_def.job_id(), config_.dispatcher_timeout_ms());
|
||||
TF_RETURN_IF_ERROR(task.dataset->MakeIterator(std::move(split_provider),
|
||||
&task.iterator));
|
||||
break;
|
||||
|
@ -37,4 +37,7 @@ message WorkerConfig {
|
||||
string worker_address = 4;
|
||||
// How often the worker should heartbeat to the dispatcher.
|
||||
int64 heartbeat_interval_ms = 5;
|
||||
// How long to retry requests to the dispatcher before giving up and reporting
|
||||
// an error.
|
||||
int64 dispatcher_timeout_ms = 6;
|
||||
}
|
||||
|
@ -380,7 +380,6 @@ class DataServiceOpsTest(data_service_test_base.TestBase,
|
||||
|
||||
@combinations.generate(test_base.eager_only_combinations())
|
||||
def testDistributeDistributedEpochTensorSlices(self):
|
||||
self.skipTest("b/170910141")
|
||||
cluster = self.create_cluster(num_workers=2)
|
||||
vals = [5, 1, 2, 4]
|
||||
ds = dataset_ops.Dataset.from_tensor_slices(vals)
|
||||
@ -390,7 +389,6 @@ class DataServiceOpsTest(data_service_test_base.TestBase,
|
||||
|
||||
@combinations.generate(test_base.eager_only_combinations())
|
||||
def testDistributeDistributedEpochInterleave(self):
|
||||
self.skipTest("b/170910141")
|
||||
cluster = self.create_cluster(num_workers=2)
|
||||
elements = [1, 5, 0]
|
||||
ds = dataset_ops.Dataset.from_tensor_slices(elements)
|
||||
@ -401,7 +399,6 @@ class DataServiceOpsTest(data_service_test_base.TestBase,
|
||||
|
||||
@combinations.generate(test_base.eager_only_combinations())
|
||||
def testDistributeDistributedEpochParallelInterleave(self):
|
||||
self.skipTest("b/170910141")
|
||||
cluster = self.create_cluster(num_workers=2)
|
||||
elements = [1, 5, 0]
|
||||
ds = dataset_ops.Dataset.from_tensor_slices(elements)
|
||||
@ -414,7 +411,6 @@ class DataServiceOpsTest(data_service_test_base.TestBase,
|
||||
|
||||
@combinations.generate(test_base.eager_only_combinations())
|
||||
def testDistributeDistributedEpochFlatMap(self):
|
||||
self.skipTest("b/170910141")
|
||||
cluster = self.create_cluster(num_workers=2)
|
||||
elements = [1, 5, 0]
|
||||
ds = dataset_ops.Dataset.from_tensor_slices(elements)
|
||||
@ -425,7 +421,6 @@ class DataServiceOpsTest(data_service_test_base.TestBase,
|
||||
|
||||
@combinations.generate(test_base.eager_only_combinations())
|
||||
def testDistributeDistributedEpochRepeat(self):
|
||||
self.skipTest("b/170910141")
|
||||
cluster = self.create_cluster(num_workers=2)
|
||||
num_repeats = 5
|
||||
num_elements = 20
|
||||
@ -437,7 +432,6 @@ class DataServiceOpsTest(data_service_test_base.TestBase,
|
||||
|
||||
@combinations.generate(test_base.eager_only_combinations())
|
||||
def testDistributeDistributedEpochShuffleAndRepeat(self):
|
||||
self.skipTest("b/170910141")
|
||||
cluster = self.create_cluster(num_workers=2)
|
||||
num_repeats = 5
|
||||
num_elements = 20
|
||||
@ -462,7 +456,6 @@ class DataServiceOpsTest(data_service_test_base.TestBase,
|
||||
|
||||
@combinations.generate(test_base.eager_only_combinations())
|
||||
def testDistributeDistributedEpoch(self):
|
||||
self.skipTest("b/170910141")
|
||||
cluster = self.create_cluster(num_workers=2)
|
||||
num_elements = 100
|
||||
ds = dataset_ops.Dataset.range(num_elements)
|
||||
|
@ -99,7 +99,8 @@ class TestCluster(object):
|
||||
server_lib.WorkerServer(
|
||||
server_lib.WorkerConfig(
|
||||
dispatcher_address=self.dispatcher_address(),
|
||||
heartbeat_interval_ms=TEST_HEARTBEAT_INTERVAL_MS),
|
||||
heartbeat_interval_ms=TEST_HEARTBEAT_INTERVAL_MS,
|
||||
dispatcher_timeout_ms=1000),
|
||||
start=start))
|
||||
|
||||
def start_dispatcher(self):
|
||||
|
@ -217,7 +217,7 @@ class DispatchServer(object):
|
||||
class WorkerConfig(
|
||||
collections.namedtuple("WorkerConfig", [
|
||||
"dispatcher_address", "worker_address", "port", "protocol",
|
||||
"heartbeat_interval_ms"
|
||||
"heartbeat_interval_ms", "dispatcher_timeout_ms"
|
||||
])):
|
||||
"""Configuration class for tf.data service workers.
|
||||
|
||||
@ -235,6 +235,8 @@ class WorkerConfig(
|
||||
reasonable default. A higher value will reduce the load on the dispatcher,
|
||||
while a lower value will reduce the time it takes to reclaim resources
|
||||
from finished jobs.
|
||||
dispatcher_timeout_ms: How long, in milliseconds, to retry requests to the
|
||||
dispatcher before giving up and reporting an error. Defaults to 1 hour.
|
||||
"""
|
||||
|
||||
def __new__(cls,
|
||||
@ -242,15 +244,19 @@ class WorkerConfig(
|
||||
worker_address=None,
|
||||
port=0,
|
||||
protocol="grpc",
|
||||
heartbeat_interval_ms=None):
|
||||
heartbeat_interval_ms=None,
|
||||
dispatcher_timeout_ms=None):
|
||||
if worker_address is None:
|
||||
worker_address = "localhost:%port%"
|
||||
if heartbeat_interval_ms is None:
|
||||
heartbeat_interval_ms = 30 * 1000 # 30 seconds
|
||||
if dispatcher_timeout_ms is None:
|
||||
dispatcher_timeout_ms = 60 * 60 * 1000 # 1 hour
|
||||
|
||||
return super(WorkerConfig,
|
||||
cls).__new__(cls, dispatcher_address, worker_address, port,
|
||||
protocol, heartbeat_interval_ms)
|
||||
protocol, heartbeat_interval_ms,
|
||||
dispatcher_timeout_ms)
|
||||
|
||||
|
||||
@tf_export("data.experimental.service.WorkerServer", v1=[])
|
||||
@ -299,7 +305,8 @@ class WorkerServer(object):
|
||||
worker_address=config.worker_address,
|
||||
port=config.port,
|
||||
protocol=config.protocol,
|
||||
heartbeat_interval_ms=config.heartbeat_interval_ms)
|
||||
heartbeat_interval_ms=config.heartbeat_interval_ms,
|
||||
dispatcher_timeout_ms=config.dispatcher_timeout_ms)
|
||||
self._server = _pywrap_server_lib.TF_DATA_NewWorkerServer(
|
||||
config_proto.SerializeToString())
|
||||
if start:
|
||||
|
@ -7,6 +7,10 @@ tf_class {
|
||||
name: "dispatcher_address"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "dispatcher_timeout_ms"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "heartbeat_interval_ms"
|
||||
mtype: "<type \'property\'>"
|
||||
|
@ -7,6 +7,10 @@ tf_class {
|
||||
name: "dispatcher_address"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "dispatcher_timeout_ms"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "heartbeat_interval_ms"
|
||||
mtype: "<type \'property\'>"
|
||||
|
Loading…
Reference in New Issue
Block a user