[tf.data service] Fix timeouts in distributed_epoch tests.
This fixes an issue where tests would get stuck retrying GetSplit requests to the dispatcher after the test had already shut down the dispatcher. Now the retry period is configurable, so that tests can set it to 1 second instead of the default 1 hour.

PiperOrigin-RevId: 337935450
Change-Id: I11bdfba4e731af314fc0016ee5b113dbc339cbdd
Parent: cca4ca7344
Commit: 209ede00c5
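
For orientation, a minimal sketch of the new knob from the Python side, using only the WorkerConfig/WorkerServer surface that appears in the diff below (the dispatcher address is a hypothetical placeholder, and a dispatcher is assumed to already be listening there):

    from tensorflow.python.data.experimental.service import server_lib

    # Bound dispatcher retries at 1 second instead of the default 1 hour, so a
    # worker whose dispatcher has been shut down fails fast instead of hanging.
    config = server_lib.WorkerConfig(
        dispatcher_address="localhost:5000",  # hypothetical address
        heartbeat_interval_ms=100,
        dispatcher_timeout_ms=1000)
    worker = server_lib.WorkerServer(config, start=True)
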
@@ -22,10 +22,6 @@ limitations under the License.
 namespace tensorflow {
 namespace data {
 
-namespace {
-const int64 kRetryTimeoutMicros = 1000LL * 1000 * 60 * 60;  // 60 minutes.
-}  // namespace
-
 Status DataServiceSplitProvider::GetNext(Tensor* split, bool* end_of_splits) {
   mutex_lock l(mu_);
   if (!dispatcher_) {
@@ -38,7 +34,8 @@ Status DataServiceSplitProvider::GetNext(Tensor* split, bool* end_of_splits) {
                                        *end_of_splits);
       },
       "get next split",
-      /*deadline_micros=*/Env::Default()->NowMicros() + kRetryTimeoutMicros);
+      /*deadline_micros=*/Env::Default()->NowMicros() +
+          (timeout_ms_ * EnvTime::kMillisToMicros));
 }
 
 Status DataServiceSplitProvider::Reset() {
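
The deadline above is plain unit arithmetic; a rough Python rendering of the same computation (EnvTime::kMillisToMicros is 1000; the function name here is made up for illustration):

    import time

    MILLIS_TO_MICROS = 1000  # mirrors EnvTime::kMillisToMicros

    def retry_deadline_micros(timeout_ms):
        # Env::Default()->NowMicros() + timeout_ms * kMillisToMicros
        return int(time.time() * 1e6) + timeout_ms * MILLIS_TO_MICROS

    # With the old constant: 60 * 60 * 1000 ms of retries (1 hour).
    # With the test setting below: 1000 ms, so retries stop after ~1 second.
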
@@ -28,8 +28,12 @@ namespace data {
 class DataServiceSplitProvider : public SplitProvider {
  public:
   DataServiceSplitProvider(const std::string& address,
-                           const std::string& protocol, int64 job_id)
-      : address_(address), protocol_(protocol), job_id_(job_id) {}
+                           const std::string& protocol, int64 job_id,
+                           int64 timeout_ms)
+      : address_(address),
+        protocol_(protocol),
+        job_id_(job_id),
+        timeout_ms_(timeout_ms) {}
 
   Status GetNext(Tensor* split, bool* end_of_splits) override;
   Status Reset() override;
@@ -42,6 +46,7 @@ class DataServiceSplitProvider : public SplitProvider {
   const std::string address_;
   const std::string protocol_;
   const int64 job_id_;
+  const int64 timeout_ms_;
 
   mutex mu_;
   int64 repetition_ = 0;
@@ -150,7 +150,7 @@ Status DataServiceWorkerImpl::EnsureTaskInitialized(
     case DISTRIBUTED_EPOCH: {
       auto split_provider = absl::make_unique<DataServiceSplitProvider>(
           config_.dispatcher_address(), config_.protocol(),
-          task.task_def.job_id());
+          task.task_def.job_id(), config_.dispatcher_timeout_ms());
       TF_RETURN_IF_ERROR(task.dataset->MakeIterator(std::move(split_provider),
                                                     &task.iterator));
       break;
@@ -37,4 +37,7 @@ message WorkerConfig {
   string worker_address = 4;
   // How often the worker should heartbeat to the master.
   int64 heartbeat_interval_ms = 5;
+  // How long to retry requests to the dispatcher before giving up and reporting
+  // an error.
+  int64 dispatcher_timeout_ms = 6;
 }
@@ -380,7 +380,6 @@ class DataServiceOpsTest(data_service_test_base.TestBase,
 
   @combinations.generate(test_base.eager_only_combinations())
   def testDistributeDistributedEpochTensorSlices(self):
-    self.skipTest("b/170910141")
     cluster = self.create_cluster(num_workers=2)
     vals = [5, 1, 2, 4]
     ds = dataset_ops.Dataset.from_tensor_slices(vals)
@@ -390,7 +389,6 @@ class DataServiceOpsTest(data_service_test_base.TestBase,
 
   @combinations.generate(test_base.eager_only_combinations())
   def testDistributeDistributedEpochInterleave(self):
-    self.skipTest("b/170910141")
     cluster = self.create_cluster(num_workers=2)
     elements = [1, 5, 0]
     ds = dataset_ops.Dataset.from_tensor_slices(elements)
@@ -401,7 +399,6 @@ class DataServiceOpsTest(data_service_test_base.TestBase,
 
   @combinations.generate(test_base.eager_only_combinations())
   def testDistributeDistributedEpochParallelInterleave(self):
-    self.skipTest("b/170910141")
     cluster = self.create_cluster(num_workers=2)
     elements = [1, 5, 0]
     ds = dataset_ops.Dataset.from_tensor_slices(elements)
@@ -414,7 +411,6 @@ class DataServiceOpsTest(data_service_test_base.TestBase,
 
   @combinations.generate(test_base.eager_only_combinations())
   def testDistributeDistributedEpochFlatMap(self):
-    self.skipTest("b/170910141")
     cluster = self.create_cluster(num_workers=2)
     elements = [1, 5, 0]
     ds = dataset_ops.Dataset.from_tensor_slices(elements)
@@ -425,7 +421,6 @@ class DataServiceOpsTest(data_service_test_base.TestBase,
 
   @combinations.generate(test_base.eager_only_combinations())
   def testDistributeDistributedEpochRepeat(self):
-    self.skipTest("b/170910141")
     cluster = self.create_cluster(num_workers=2)
     num_repeats = 5
     num_elements = 20
@@ -437,7 +432,6 @@ class DataServiceOpsTest(data_service_test_base.TestBase,
 
   @combinations.generate(test_base.eager_only_combinations())
   def testDistributeDistributedEpochShuffleAndRepeat(self):
-    self.skipTest("b/170910141")
     cluster = self.create_cluster(num_workers=2)
     num_repeats = 5
     num_elements = 20
@@ -462,7 +456,6 @@ class DataServiceOpsTest(data_service_test_base.TestBase,
 
   @combinations.generate(test_base.eager_only_combinations())
   def testDistributeDistributedEpoch(self):
-    self.skipTest("b/170910141")
     cluster = self.create_cluster(num_workers=2)
     num_elements = 100
     ds = dataset_ops.Dataset.range(num_elements)
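
Each re-enabled test follows the same shape; a hedged sketch of what they exercise (the distribute(...) call is not part of this diff and is reconstructed from the public tf.data service API, and cluster.target is an assumed helper producing the "protocol://address" service string):

    import tensorflow as tf

    vals = [5, 1, 2, 4]
    ds = tf.data.Dataset.from_tensor_slices(vals)
    # "distributed_epoch" makes workers request splits from the dispatcher via
    # GetSplit -- exactly the RPC whose retries used to outlive the test.
    ds = ds.apply(
        tf.data.experimental.service.distribute(
            processing_mode="distributed_epoch", service=cluster.target))
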
@@ -99,7 +99,8 @@ class TestCluster(object):
         server_lib.WorkerServer(
             server_lib.WorkerConfig(
                 dispatcher_address=self.dispatcher_address(),
-                heartbeat_interval_ms=TEST_HEARTBEAT_INTERVAL_MS),
+                heartbeat_interval_ms=TEST_HEARTBEAT_INTERVAL_MS,
+                dispatcher_timeout_ms=1000),
             start=start))
 
   def start_dispatcher(self):
@@ -217,7 +217,7 @@ class DispatchServer(object):
 class WorkerConfig(
     collections.namedtuple("WorkerConfig", [
         "dispatcher_address", "worker_address", "port", "protocol",
-        "heartbeat_interval_ms"
+        "heartbeat_interval_ms", "dispatcher_timeout_ms"
     ])):
   """Configuration class for tf.data service dispatchers.
 
@@ -235,6 +235,8 @@ class WorkerConfig(
       reasonable default. A higher value will reduce the load on the dispatcher,
       while a lower value will reduce the time it takes to reclaim resources
       from finished jobs.
+    dispatcher_timeout_ms: How long, in milliseconds, to retry requests to the
+      dispatcher before giving up and reporting an error. Defaults to 1 hour.
   """
 
   def __new__(cls,
@@ -242,15 +244,19 @@ class WorkerConfig(
               worker_address=None,
               port=0,
               protocol="grpc",
-              heartbeat_interval_ms=None):
+              heartbeat_interval_ms=None,
+              dispatcher_timeout_ms=None):
     if worker_address is None:
       worker_address = "localhost:%port%"
     if heartbeat_interval_ms is None:
       heartbeat_interval_ms = 30 * 1000  # 30 seconds
+    if dispatcher_timeout_ms is None:
+      dispatcher_timeout_ms = 60 * 60 * 1000  # 1 hour
 
     return super(WorkerConfig,
                  cls).__new__(cls, dispatcher_address, worker_address, port,
-                              protocol, heartbeat_interval_ms)
+                              protocol, heartbeat_interval_ms,
+                              dispatcher_timeout_ms)
 
 
 @tf_export("data.experimental.service.WorkerServer", v1=[])
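
The None-to-default dance mirrors heartbeat_interval_ms; a quick sketch of the observable defaults, with server_lib imported as above (values taken straight from the diff):

    config = server_lib.WorkerConfig(dispatcher_address="localhost:5000")
    assert config.heartbeat_interval_ms == 30 * 1000       # 30 seconds
    assert config.dispatcher_timeout_ms == 60 * 60 * 1000  # 1 hour
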
@@ -299,7 +305,8 @@ class WorkerServer(object):
         worker_address=config.worker_address,
         port=config.port,
         protocol=config.protocol,
-        heartbeat_interval_ms=config.heartbeat_interval_ms)
+        heartbeat_interval_ms=config.heartbeat_interval_ms,
+        dispatcher_timeout_ms=config.dispatcher_timeout_ms)
     self._server = _pywrap_server_lib.TF_DATA_NewWorkerServer(
         config_proto.SerializeToString())
     if start:
@@ -7,6 +7,10 @@ tf_class {
     name: "dispatcher_address"
     mtype: "<type \'property\'>"
   }
+  member {
+    name: "dispatcher_timeout_ms"
+    mtype: "<type \'property\'>"
+  }
   member {
     name: "heartbeat_interval_ms"
     mtype: "<type \'property\'>"
@@ -7,6 +7,10 @@ tf_class {
     name: "dispatcher_address"
     mtype: "<type \'property\'>"
  }
+  member {
+    name: "dispatcher_timeout_ms"
+    mtype: "<type \'property\'>"
+  }
   member {
     name: "heartbeat_interval_ms"
     mtype: "<type \'property\'>"