[tf.data service] Split data_service_ops_test into ft and non-ft tests
Also, move the tests to experimental/kernel_tests where they belong. PiperOrigin-RevId: 333177431 Change-Id: I04f415c08268c0f9113d9bbb342bd00bdf911956
This commit is contained in:
parent
d3844c21b6
commit
148e93cf70
@ -139,6 +139,53 @@ tf_py_test(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
py_library(
|
||||||
|
name = "data_service_test_base",
|
||||||
|
srcs = ["data_service_test_base.py"],
|
||||||
|
deps = [
|
||||||
|
"//tensorflow/python:framework_test_lib",
|
||||||
|
"//tensorflow/python/data/experimental/service:server_lib",
|
||||||
|
"//tensorflow/python/data/kernel_tests:test_base",
|
||||||
|
"//tensorflow/python/data/ops:dataset_ops",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
tf_py_test(
|
||||||
|
name = "data_service_ops_test",
|
||||||
|
size = "medium",
|
||||||
|
srcs = ["data_service_ops_test.py"],
|
||||||
|
shard_count = 10,
|
||||||
|
deps = [
|
||||||
|
":data_service_test_base",
|
||||||
|
"//tensorflow:tensorflow_py",
|
||||||
|
"//tensorflow/python:errors",
|
||||||
|
"//tensorflow/python:framework",
|
||||||
|
"//tensorflow/python:framework_test_lib",
|
||||||
|
"//tensorflow/python:platform_test",
|
||||||
|
"//tensorflow/python/data",
|
||||||
|
"//tensorflow/python/data/experimental/ops:testing",
|
||||||
|
"//tensorflow/python/data/experimental/service:server_lib",
|
||||||
|
"//tensorflow/python/data/kernel_tests:test_base",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
tf_py_test(
|
||||||
|
name = "data_service_ops_ft_test",
|
||||||
|
srcs = ["data_service_ops_ft_test.py"],
|
||||||
|
deps = [
|
||||||
|
":data_service_test_base",
|
||||||
|
"//tensorflow:tensorflow_py",
|
||||||
|
"//tensorflow/python:errors",
|
||||||
|
"//tensorflow/python:framework",
|
||||||
|
"//tensorflow/python:framework_test_lib",
|
||||||
|
"//tensorflow/python:platform_test",
|
||||||
|
"//tensorflow/python/data",
|
||||||
|
"//tensorflow/python/data/experimental/ops:testing",
|
||||||
|
"//tensorflow/python/data/experimental/service:server_lib",
|
||||||
|
"//tensorflow/python/data/kernel_tests:test_base",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
tf_py_test(
|
tf_py_test(
|
||||||
name = "dense_to_sparse_batch_test",
|
name = "dense_to_sparse_batch_test",
|
||||||
srcs = ["dense_to_sparse_batch_test.py"],
|
srcs = ["dense_to_sparse_batch_test.py"],
|
||||||
|
@ -0,0 +1,256 @@
|
|||||||
|
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
"""Tests for tf.data service ops where servers are started late or preempted."""
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
|
|
||||||
|
from absl.testing import parameterized
|
||||||
|
|
||||||
|
from tensorflow.python.data.experimental.kernel_tests import data_service_test_base
|
||||||
|
from tensorflow.python.data.experimental.ops import data_service_ops
|
||||||
|
from tensorflow.python.data.experimental.service import server_lib
|
||||||
|
from tensorflow.python.data.kernel_tests import test_base
|
||||||
|
from tensorflow.python.data.ops import dataset_ops
|
||||||
|
from tensorflow.python.framework import combinations
|
||||||
|
from tensorflow.python.framework import dtypes
|
||||||
|
from tensorflow.python.ops import array_ops
|
||||||
|
from tensorflow.python.platform import test
|
||||||
|
|
||||||
|
TMP_WORK_DIR = data_service_test_base.TMP_WORK_DIR
|
||||||
|
NO_WORK_DIR = data_service_test_base.NO_WORK_DIR
|
||||||
|
|
||||||
|
|
||||||
|
def _all_cluster_configurations():
|
||||||
|
with_work_dir = combinations.combine(
|
||||||
|
work_dir=TMP_WORK_DIR, fault_tolerant_mode=[True, False])
|
||||||
|
without_work_dir = combinations.combine(
|
||||||
|
work_dir=NO_WORK_DIR, fault_tolerant_mode=False)
|
||||||
|
return with_work_dir + without_work_dir
|
||||||
|
|
||||||
|
|
||||||
|
class DataServiceOpsTest(data_service_test_base.TestBase,
|
||||||
|
parameterized.TestCase):
|
||||||
|
|
||||||
|
@combinations.generate(test_base.eager_only_combinations())
|
||||||
|
def testDispatcherStop(self):
|
||||||
|
dispatcher, workers = self.start_cluster(1) # to avoid gcing workers, pylint: disable=unused-variable
|
||||||
|
num_elements = 100
|
||||||
|
ds = self.make_distributed_range_dataset(num_elements, dispatcher)
|
||||||
|
iterator = iter(ds)
|
||||||
|
results = []
|
||||||
|
results.append(next(iterator).numpy())
|
||||||
|
dispatcher._stop()
|
||||||
|
# After the dispatcher dies, the worker should continue providing the rest
|
||||||
|
# of the dataset's elements.
|
||||||
|
for _ in range(num_elements - 1):
|
||||||
|
results.append(next(iterator).numpy())
|
||||||
|
self.assertEqual(results, list(range(num_elements)))
|
||||||
|
|
||||||
|
@combinations.generate(test_base.eager_only_combinations())
|
||||||
|
def testDispatcherRestartBeforeReading(self):
|
||||||
|
dispatcher, workers = self.start_cluster(1) # to avoid gcing workers, pylint: disable=unused-variable
|
||||||
|
num_elements = 100
|
||||||
|
ds = self.make_distributed_range_dataset(num_elements, dispatcher)
|
||||||
|
dispatcher = self.restart_dispatcher(dispatcher)
|
||||||
|
|
||||||
|
self.assertDatasetProduces(ds, list(range(num_elements)))
|
||||||
|
|
||||||
|
@combinations.generate(test_base.eager_only_combinations())
|
||||||
|
def testDispatcherRestartDuringReading(self):
|
||||||
|
dispatcher, workers = self.start_cluster(1) # to avoid gcing workers, pylint: disable=unused-variable
|
||||||
|
num_elements = 100
|
||||||
|
ds = self.make_distributed_range_dataset(num_elements, dispatcher)
|
||||||
|
iterator = iter(ds)
|
||||||
|
results = []
|
||||||
|
for _ in range(num_elements // 2):
|
||||||
|
results.append(next(iterator).numpy())
|
||||||
|
dispatcher = self.restart_dispatcher(dispatcher)
|
||||||
|
for elem in iterator:
|
||||||
|
results.append(elem.numpy())
|
||||||
|
|
||||||
|
self.assertEqual(list(range(num_elements)), results)
|
||||||
|
|
||||||
|
@combinations.generate(test_base.eager_only_combinations())
|
||||||
|
def testDispatcherRestartBetweenIterations(self):
|
||||||
|
dispatcher, workers = self.start_cluster(1) # to avoid gcing workers, pylint: disable=unused-variable
|
||||||
|
num_elements = 100
|
||||||
|
ds = self.make_distributed_range_dataset(100, dispatcher)
|
||||||
|
self.assertDatasetProduces(ds, list(range(num_elements)))
|
||||||
|
dispatcher = self.restart_dispatcher(dispatcher)
|
||||||
|
self.assertDatasetProduces(ds, list(range(num_elements)))
|
||||||
|
|
||||||
|
@combinations.generate(test_base.eager_only_combinations())
|
||||||
|
def testDispatcherManyRestarts(self):
|
||||||
|
dispatcher, workers = self.start_cluster(1) # to avoid gcing workers, pylint: disable=unused-variable
|
||||||
|
num_elements_start = 10
|
||||||
|
num_elements_end = 15
|
||||||
|
datasets = []
|
||||||
|
for num_elements in range(num_elements_start, num_elements_end):
|
||||||
|
datasets.append(
|
||||||
|
self.make_distributed_range_dataset(num_elements, dispatcher))
|
||||||
|
dispatcher = self.restart_dispatcher(dispatcher)
|
||||||
|
for ds, num_elements in zip(datasets,
|
||||||
|
range(num_elements_start, num_elements_end)):
|
||||||
|
self.assertDatasetProduces(ds, list(range(num_elements)))
|
||||||
|
|
||||||
|
@combinations.generate(test_base.eager_only_combinations())
|
||||||
|
def testDispatcherAndWorkerRestart(self):
|
||||||
|
dispatcher, [worker] = self.start_cluster(1) # to avoid gcing workers, pylint: disable=unused-variable
|
||||||
|
num_elements = 100
|
||||||
|
ds = dataset_ops.Dataset.range(num_elements)
|
||||||
|
|
||||||
|
def restart():
|
||||||
|
return (self.restart_dispatcher(dispatcher),
|
||||||
|
self.restart_worker(worker, dispatcher))
|
||||||
|
|
||||||
|
ds = self.make_distributed_dataset(ds, dispatcher)
|
||||||
|
dispatcher, worker = restart()
|
||||||
|
self.assertDatasetProduces(ds, list(range(num_elements)))
|
||||||
|
dispatcher, worker = restart()
|
||||||
|
self.assertDatasetProduces(ds, list(range(num_elements)))
|
||||||
|
|
||||||
|
@combinations.generate(test_base.eager_only_combinations())
|
||||||
|
def testStartServersLate(self):
|
||||||
|
# Test that the data service client performs retries instead of failing when
|
||||||
|
# the dataset is created before the master and worker are started.
|
||||||
|
try:
|
||||||
|
import portpicker # pylint: disable=g-import-not-at-top
|
||||||
|
dispatcher_port = portpicker.pick_unused_port()
|
||||||
|
except:
|
||||||
|
raise self.skipTest("Flakes in portpicker library do not represent "
|
||||||
|
"TensorFlow errors.")
|
||||||
|
dispatcher = server_lib.DispatchServer(
|
||||||
|
server_lib.DispatcherConfig(port=dispatcher_port), start=False)
|
||||||
|
worker = server_lib.WorkerServer(
|
||||||
|
server_lib.WorkerConfig(
|
||||||
|
dispatcher_address=self.dispatcher_address(dispatcher), port=0),
|
||||||
|
start=False)
|
||||||
|
|
||||||
|
def start_servers():
|
||||||
|
time.sleep(1)
|
||||||
|
dispatcher.start()
|
||||||
|
worker.start()
|
||||||
|
|
||||||
|
start_servers_thread = threading.Thread(target=start_servers, daemon=True)
|
||||||
|
start_servers_thread.start()
|
||||||
|
|
||||||
|
num_elements = 10
|
||||||
|
ds = self.make_distributed_range_dataset(num_elements, dispatcher)
|
||||||
|
results = [elem.numpy() for elem in ds]
|
||||||
|
self.assertEqual(list(range(num_elements)), results)
|
||||||
|
start_servers_thread.join()
|
||||||
|
|
||||||
|
@combinations.generate(test_base.eager_only_combinations())
|
||||||
|
def testAddWorkerMidJob(self):
|
||||||
|
dispatcher, workers = self.start_cluster(1) # to avoid gcing workers, pylint: disable=unused-variable
|
||||||
|
num_elements = 100
|
||||||
|
ds = self.make_distributed_range_dataset(num_elements, dispatcher)
|
||||||
|
iterator = iter(ds)
|
||||||
|
results = []
|
||||||
|
# Read halfway through the dataset.
|
||||||
|
for _ in range(num_elements // 2):
|
||||||
|
results.append(next(iterator).numpy())
|
||||||
|
|
||||||
|
new_worker = self.start_worker_server(dispatcher) # to avoid gcing workers, pylint: disable=unused-variable
|
||||||
|
# Wait for the new worker to register with the dispatcher.
|
||||||
|
while dispatcher._num_workers() < 2:
|
||||||
|
time.sleep(10 / 1000) # 10ms
|
||||||
|
|
||||||
|
for elem in iterator:
|
||||||
|
results.append(elem.numpy())
|
||||||
|
|
||||||
|
self.assertCountEqual(2 * list(range(num_elements)), results)
|
||||||
|
|
||||||
|
@combinations.generate(
|
||||||
|
combinations.times(test_base.eager_only_combinations(),
|
||||||
|
combinations.combine(use_same_port=[True, False]),
|
||||||
|
_all_cluster_configurations()))
|
||||||
|
def testRestartWorker(self, use_same_port, work_dir, fault_tolerant_mode):
|
||||||
|
dispatcher, [worker] = self.start_cluster(
|
||||||
|
1, work_dir=work_dir, fault_tolerant_mode=fault_tolerant_mode)
|
||||||
|
num_elements = 100
|
||||||
|
ds = self.make_distributed_range_dataset(num_elements, dispatcher)
|
||||||
|
iterator = iter(ds)
|
||||||
|
# Read halfway through the dataset.
|
||||||
|
midpoint = num_elements // 2
|
||||||
|
for i in range(midpoint):
|
||||||
|
self.assertEqual(i, next(iterator).numpy())
|
||||||
|
|
||||||
|
# Stop the original worker and start a new one.
|
||||||
|
worker = self.restart_worker(worker, dispatcher, use_same_port)
|
||||||
|
|
||||||
|
# There may have been some elements prefetched from the first worker
|
||||||
|
# before it was stopped.
|
||||||
|
while True:
|
||||||
|
val = next(iterator).numpy()
|
||||||
|
if val == 0:
|
||||||
|
break
|
||||||
|
|
||||||
|
# The dataset starts over now that we read from the new worker.
|
||||||
|
# TODO(b/157086991): Iterate until end of sequence when we support
|
||||||
|
# detecting lost workers.
|
||||||
|
for i in range(1, num_elements // 2):
|
||||||
|
val = next(iterator).numpy()
|
||||||
|
self.assertEqual(i, val)
|
||||||
|
|
||||||
|
@combinations.generate(test_base.eager_only_combinations())
|
||||||
|
def testChangeProcessingModeAfterRestart(self):
|
||||||
|
dispatcher, workers = self.start_cluster(1) # to avoid gcing workers, pylint: disable=unused-variable
|
||||||
|
num_elements = 100
|
||||||
|
range_dataset = dataset_ops.Dataset.range(num_elements)
|
||||||
|
ds = range_dataset.apply(
|
||||||
|
data_service_ops.distribute(
|
||||||
|
processing_mode="parallel_epochs",
|
||||||
|
service=dispatcher.target,
|
||||||
|
job_name="test"))
|
||||||
|
iterator = iter(ds)
|
||||||
|
for i in range(num_elements // 2):
|
||||||
|
self.assertEqual(i, next(iterator).numpy())
|
||||||
|
dispatcher = self.restart_dispatcher(dispatcher)
|
||||||
|
ds = range_dataset.apply(
|
||||||
|
data_service_ops.distribute(
|
||||||
|
processing_mode="distributed_epoch",
|
||||||
|
service=dispatcher.target,
|
||||||
|
job_name="test"))
|
||||||
|
with self.assertRaisesOpError("already an existing job with that name "
|
||||||
|
"using processing mode <parallel_epochs>"):
|
||||||
|
next(iter(ds)).numpy()
|
||||||
|
|
||||||
|
@combinations.generate(
|
||||||
|
combinations.times(
|
||||||
|
test_base.eager_only_combinations(),
|
||||||
|
combinations.combine(work_dir=[TMP_WORK_DIR, NO_WORK_DIR])))
|
||||||
|
def testDistributeLargeGraphThenRegisterWorker(self, work_dir):
|
||||||
|
dispatcher = self.start_dispatch_server(
|
||||||
|
work_dir=work_dir, fault_tolerant_mode=False)
|
||||||
|
worker = server_lib.WorkerServer(
|
||||||
|
server_lib.WorkerConfig(
|
||||||
|
dispatcher_address=self.dispatcher_address(dispatcher), port=0),
|
||||||
|
start=False)
|
||||||
|
# Larger than default OSS grpc message size limit of 4MB.
|
||||||
|
tensor = array_ops.ones((2, 1000, 1000), dtype=dtypes.float32)
|
||||||
|
ds = dataset_ops.Dataset.from_tensors(tensor)
|
||||||
|
ds = self.make_distributed_dataset(ds, dispatcher)
|
||||||
|
it = iter(ds)
|
||||||
|
worker.start()
|
||||||
|
self.assertAllEqual(next(it), tensor)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
test.main()
|
@ -17,18 +17,16 @@ from __future__ import absolute_import
|
|||||||
from __future__ import division
|
from __future__ import division
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
import os
|
|
||||||
import threading
|
|
||||||
import time
|
import time
|
||||||
|
|
||||||
from absl.testing import parameterized
|
from absl.testing import parameterized
|
||||||
|
|
||||||
|
from tensorflow.python.data.experimental.kernel_tests import data_service_test_base
|
||||||
from tensorflow.python.data.experimental.ops import batching
|
from tensorflow.python.data.experimental.ops import batching
|
||||||
from tensorflow.python.data.experimental.ops import data_service_ops
|
from tensorflow.python.data.experimental.ops import data_service_ops
|
||||||
from tensorflow.python.data.experimental.ops import distribute_options
|
from tensorflow.python.data.experimental.ops import distribute_options
|
||||||
from tensorflow.python.data.experimental.ops import grouping
|
from tensorflow.python.data.experimental.ops import grouping
|
||||||
from tensorflow.python.data.experimental.ops import testing
|
from tensorflow.python.data.experimental.ops import testing
|
||||||
from tensorflow.python.data.experimental.service import server_lib
|
|
||||||
from tensorflow.python.data.kernel_tests import test_base
|
from tensorflow.python.data.kernel_tests import test_base
|
||||||
from tensorflow.python.data.ops import dataset_ops
|
from tensorflow.python.data.ops import dataset_ops
|
||||||
from tensorflow.python.eager import def_function
|
from tensorflow.python.eager import def_function
|
||||||
@ -48,28 +46,8 @@ from tensorflow.python.ops import tensor_array_ops
|
|||||||
from tensorflow.python.platform import test
|
from tensorflow.python.platform import test
|
||||||
|
|
||||||
|
|
||||||
# This will be resolved to a tmp directory by `start_dispatch_server`.
|
TMP_WORK_DIR = data_service_test_base.TMP_WORK_DIR
|
||||||
TMP_WORK_DIR = "tmp_work_dir_placeholder"
|
NO_WORK_DIR = data_service_test_base.NO_WORK_DIR
|
||||||
# `""` indicates not to use a work directory.
|
|
||||||
NO_WORK_DIR = ""
|
|
||||||
|
|
||||||
|
|
||||||
def _address_from_target(target):
|
|
||||||
# Targets are in the format <protocol>://<address>
|
|
||||||
return target.split("://")[1]
|
|
||||||
|
|
||||||
|
|
||||||
def _make_distributed_dataset(dataset,
|
|
||||||
dispatcher,
|
|
||||||
job_name=None,
|
|
||||||
max_outstanding_requests=None):
|
|
||||||
return dataset.apply(
|
|
||||||
data_service_ops._distribute(
|
|
||||||
"parallel_epochs",
|
|
||||||
dispatcher.target,
|
|
||||||
job_name=job_name,
|
|
||||||
max_outstanding_requests=max_outstanding_requests,
|
|
||||||
task_refresh_interval_hint_ms=20))
|
|
||||||
|
|
||||||
|
|
||||||
def _all_cluster_configurations():
|
def _all_cluster_configurations():
|
||||||
@ -80,83 +58,8 @@ def _all_cluster_configurations():
|
|||||||
return with_work_dir + without_work_dir
|
return with_work_dir + without_work_dir
|
||||||
|
|
||||||
|
|
||||||
def _make_distributed_range_dataset(num_elements,
|
class DataServiceOpsTest(data_service_test_base.TestBase,
|
||||||
dispatcher,
|
parameterized.TestCase):
|
||||||
job_name=None,
|
|
||||||
max_outstanding_requests=None):
|
|
||||||
"""Creates a distributed dataset.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
num_elements: The number of elements in the range dataset that will be
|
|
||||||
distributed.
|
|
||||||
dispatcher: The dispatcher to distribute to.
|
|
||||||
job_name: Optional job name for the distributed dataset.
|
|
||||||
max_outstanding_requests: Optional limit on the number of outstanding
|
|
||||||
requests.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
The created dataset.
|
|
||||||
"""
|
|
||||||
dataset = dataset_ops.Dataset.range(num_elements)
|
|
||||||
return _make_distributed_dataset(dataset, dispatcher, job_name,
|
|
||||||
max_outstanding_requests)
|
|
||||||
|
|
||||||
|
|
||||||
class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|
||||||
|
|
||||||
def start_dispatch_server(self,
|
|
||||||
name="",
|
|
||||||
port=0,
|
|
||||||
work_dir=TMP_WORK_DIR,
|
|
||||||
fault_tolerant_mode=True,
|
|
||||||
job_gc_check_interval_ms=None,
|
|
||||||
job_gc_timeout_ms=None):
|
|
||||||
# If a test starts multiple independent dispatch servers, it should give
|
|
||||||
# them different `name` values.
|
|
||||||
work_dir = os.path.join(self.get_temp_dir(), "work_dir_",
|
|
||||||
name) if work_dir is TMP_WORK_DIR else work_dir
|
|
||||||
return server_lib.DispatchServer(
|
|
||||||
server_lib.DispatcherConfig(
|
|
||||||
port=port,
|
|
||||||
work_dir=work_dir,
|
|
||||||
fault_tolerant_mode=fault_tolerant_mode,
|
|
||||||
job_gc_check_interval_ms=job_gc_check_interval_ms,
|
|
||||||
job_gc_timeout_ms=job_gc_timeout_ms))
|
|
||||||
|
|
||||||
def start_worker_server(self, dispatcher, port=0):
|
|
||||||
return server_lib.WorkerServer(
|
|
||||||
server_lib.WorkerConfig(
|
|
||||||
dispatcher_address=_address_from_target(dispatcher.target),
|
|
||||||
port=port,
|
|
||||||
heartbeat_interval_ms=200))
|
|
||||||
|
|
||||||
def restart_dispatcher(self, dispatcher):
|
|
||||||
"""Stops `dispatcher` and returns a new dispatcher with the same port."""
|
|
||||||
port = int(_address_from_target(dispatcher.target).split(":")[1])
|
|
||||||
dispatcher._stop()
|
|
||||||
return self.start_dispatch_server(
|
|
||||||
port=port,
|
|
||||||
work_dir=dispatcher._config.work_dir,
|
|
||||||
fault_tolerant_mode=dispatcher._config.fault_tolerant_mode)
|
|
||||||
|
|
||||||
def restart_worker(self, worker, dispatcher, use_same_port=True):
|
|
||||||
"""Stops `worker` and returns a new worker."""
|
|
||||||
port = 0
|
|
||||||
if use_same_port:
|
|
||||||
port = int(worker._address.split(":")[1])
|
|
||||||
worker._stop()
|
|
||||||
return self.start_worker_server(dispatcher, port)
|
|
||||||
|
|
||||||
def start_cluster(self,
|
|
||||||
num_workers,
|
|
||||||
name="",
|
|
||||||
work_dir=TMP_WORK_DIR,
|
|
||||||
fault_tolerant_mode=True):
|
|
||||||
"""Creates and starts a tf.data service cluster."""
|
|
||||||
dispatcher = self.start_dispatch_server(
|
|
||||||
name=name, work_dir=work_dir, fault_tolerant_mode=fault_tolerant_mode)
|
|
||||||
workers = [self.start_worker_server(dispatcher) for _ in range(num_workers)]
|
|
||||||
return dispatcher, workers
|
|
||||||
|
|
||||||
@combinations.generate(
|
@combinations.generate(
|
||||||
combinations.times(test_base.eager_only_combinations(),
|
combinations.times(test_base.eager_only_combinations(),
|
||||||
@ -167,87 +70,10 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
work_dir=work_dir,
|
work_dir=work_dir,
|
||||||
fault_tolerant_mode=fault_tolerant_mode)
|
fault_tolerant_mode=fault_tolerant_mode)
|
||||||
num_elements = 10
|
num_elements = 10
|
||||||
ds = _make_distributed_range_dataset(10, dispatcher)
|
ds = self.make_distributed_range_dataset(10, dispatcher)
|
||||||
results = [elem.numpy() for elem in ds]
|
results = [elem.numpy() for elem in ds]
|
||||||
self.assertEqual(list(range(num_elements)), results)
|
self.assertEqual(list(range(num_elements)), results)
|
||||||
|
|
||||||
@combinations.generate(test_base.eager_only_combinations())
|
|
||||||
def testDispatcherStop(self):
|
|
||||||
dispatcher, workers = self.start_cluster(1) # to avoid gcing workers, pylint: disable=unused-variable
|
|
||||||
num_elements = 100
|
|
||||||
ds = _make_distributed_range_dataset(num_elements, dispatcher)
|
|
||||||
iterator = iter(ds)
|
|
||||||
results = []
|
|
||||||
results.append(next(iterator).numpy())
|
|
||||||
dispatcher._stop()
|
|
||||||
# After the dispatcher dies, the worker should continue providing the rest
|
|
||||||
# of the dataset's elements.
|
|
||||||
for _ in range(num_elements - 1):
|
|
||||||
results.append(next(iterator).numpy())
|
|
||||||
self.assertEqual(results, list(range(num_elements)))
|
|
||||||
|
|
||||||
@combinations.generate(test_base.eager_only_combinations())
|
|
||||||
def testDispatcherRestartBeforeReading(self):
|
|
||||||
dispatcher, workers = self.start_cluster(1) # to avoid gcing workers, pylint: disable=unused-variable
|
|
||||||
num_elements = 100
|
|
||||||
ds = _make_distributed_range_dataset(num_elements, dispatcher)
|
|
||||||
dispatcher = self.restart_dispatcher(dispatcher)
|
|
||||||
|
|
||||||
self.assertDatasetProduces(ds, list(range(num_elements)))
|
|
||||||
|
|
||||||
@combinations.generate(test_base.eager_only_combinations())
|
|
||||||
def testDispatcherRestartDuringReading(self):
|
|
||||||
dispatcher, workers = self.start_cluster(1) # to avoid gcing workers, pylint: disable=unused-variable
|
|
||||||
num_elements = 100
|
|
||||||
ds = _make_distributed_range_dataset(num_elements, dispatcher)
|
|
||||||
iterator = iter(ds)
|
|
||||||
results = []
|
|
||||||
for _ in range(num_elements // 2):
|
|
||||||
results.append(next(iterator).numpy())
|
|
||||||
dispatcher = self.restart_dispatcher(dispatcher)
|
|
||||||
for elem in iterator:
|
|
||||||
results.append(elem.numpy())
|
|
||||||
|
|
||||||
self.assertEqual(list(range(num_elements)), results)
|
|
||||||
|
|
||||||
@combinations.generate(test_base.eager_only_combinations())
|
|
||||||
def testDispatcherRestartBetweenIterations(self):
|
|
||||||
dispatcher, workers = self.start_cluster(1) # to avoid gcing workers, pylint: disable=unused-variable
|
|
||||||
num_elements = 100
|
|
||||||
ds = _make_distributed_range_dataset(100, dispatcher)
|
|
||||||
self.assertDatasetProduces(ds, list(range(num_elements)))
|
|
||||||
dispatcher = self.restart_dispatcher(dispatcher)
|
|
||||||
self.assertDatasetProduces(ds, list(range(num_elements)))
|
|
||||||
|
|
||||||
@combinations.generate(test_base.eager_only_combinations())
|
|
||||||
def testDispatcherManyRestarts(self):
|
|
||||||
dispatcher, workers = self.start_cluster(1) # to avoid gcing workers, pylint: disable=unused-variable
|
|
||||||
num_elements_start = 10
|
|
||||||
num_elements_end = 15
|
|
||||||
datasets = []
|
|
||||||
for num_elements in range(num_elements_start, num_elements_end):
|
|
||||||
datasets.append(_make_distributed_range_dataset(num_elements, dispatcher))
|
|
||||||
dispatcher = self.restart_dispatcher(dispatcher)
|
|
||||||
for ds, num_elements in zip(datasets,
|
|
||||||
range(num_elements_start, num_elements_end)):
|
|
||||||
self.assertDatasetProduces(ds, list(range(num_elements)))
|
|
||||||
|
|
||||||
@combinations.generate(test_base.eager_only_combinations())
|
|
||||||
def testDispatcherAndWorkerRestart(self):
|
|
||||||
dispatcher, [worker] = self.start_cluster(1) # to avoid gcing workers, pylint: disable=unused-variable
|
|
||||||
num_elements = 100
|
|
||||||
ds = dataset_ops.Dataset.range(num_elements)
|
|
||||||
|
|
||||||
def restart():
|
|
||||||
return (self.restart_dispatcher(dispatcher),
|
|
||||||
self.restart_worker(worker, dispatcher))
|
|
||||||
|
|
||||||
ds = _make_distributed_dataset(ds, dispatcher)
|
|
||||||
dispatcher, worker = restart()
|
|
||||||
self.assertDatasetProduces(ds, list(range(num_elements)))
|
|
||||||
dispatcher, worker = restart()
|
|
||||||
self.assertDatasetProduces(ds, list(range(num_elements)))
|
|
||||||
|
|
||||||
@combinations.generate(test_base.eager_only_combinations())
|
@combinations.generate(test_base.eager_only_combinations())
|
||||||
def testDistributeSparse(self):
|
def testDistributeSparse(self):
|
||||||
dispatcher, workers = self.start_cluster(1) # to avoid gcing workers, pylint: disable=unused-variable
|
dispatcher, workers = self.start_cluster(1) # to avoid gcing workers, pylint: disable=unused-variable
|
||||||
@ -256,7 +82,7 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
values=constant_op.constant([0], dtype=dtypes.int32),
|
values=constant_op.constant([0], dtype=dtypes.int32),
|
||||||
dense_shape=[1])
|
dense_shape=[1])
|
||||||
ds = dataset_ops.Dataset.from_tensors(element)
|
ds = dataset_ops.Dataset.from_tensors(element)
|
||||||
ds = _make_distributed_dataset(ds, dispatcher)
|
ds = self.make_distributed_dataset(ds, dispatcher)
|
||||||
results = [sparse_ops.sparse_tensor_to_dense(elem) for elem in ds]
|
results = [sparse_ops.sparse_tensor_to_dense(elem) for elem in ds]
|
||||||
self.assertAllEqual(results, [[0]])
|
self.assertAllEqual(results, [[0]])
|
||||||
|
|
||||||
@ -266,7 +92,7 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
ds = dataset_ops.Dataset.from_tensor_slices([1, 5, 3, 2, 8])
|
ds = dataset_ops.Dataset.from_tensor_slices([1, 5, 3, 2, 8])
|
||||||
ds = ds.map(math_ops.range)
|
ds = ds.map(math_ops.range)
|
||||||
ds = ds.apply(batching.dense_to_ragged_batch(2))
|
ds = ds.apply(batching.dense_to_ragged_batch(2))
|
||||||
ds = _make_distributed_dataset(ds, dispatcher)
|
ds = self.make_distributed_dataset(ds, dispatcher)
|
||||||
results = [elem.to_tensor() for elem in ds]
|
results = [elem.to_tensor() for elem in ds]
|
||||||
self.assertAllEqual(results[0], [[0, 0, 0, 0, 0], [0, 1, 2, 3, 4]])
|
self.assertAllEqual(results[0], [[0, 0, 0, 0, 0], [0, 1, 2, 3, 4]])
|
||||||
self.assertAllEqual(results[1], [[0, 1, 2], [0, 1, 0]])
|
self.assertAllEqual(results[1], [[0, 1, 2], [0, 1, 0]])
|
||||||
@ -279,7 +105,7 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
dispatcher, workers = self.start_cluster(2) # to avoid gcing workers, pylint: disable=unused-variable
|
dispatcher, workers = self.start_cluster(2) # to avoid gcing workers, pylint: disable=unused-variable
|
||||||
ds = dataset_ops.Dataset.range(num_elements)
|
ds = dataset_ops.Dataset.range(num_elements)
|
||||||
ds = ds.shuffle(num_elements)
|
ds = ds.shuffle(num_elements)
|
||||||
ds = _make_distributed_dataset(ds, dispatcher)
|
ds = self.make_distributed_dataset(ds, dispatcher)
|
||||||
output = [elem.numpy() for elem in ds]
|
output = [elem.numpy() for elem in ds]
|
||||||
|
|
||||||
# The output will be two sequences of range(num_elements)
|
# The output will be two sequences of range(num_elements)
|
||||||
@ -298,7 +124,7 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
def testMultipleEpochs(self):
|
def testMultipleEpochs(self):
|
||||||
dispatcher, workers = self.start_cluster(1) # to avoid gcing workers, pylint: disable=unused-variable
|
dispatcher, workers = self.start_cluster(1) # to avoid gcing workers, pylint: disable=unused-variable
|
||||||
num_elements = 3
|
num_elements = 3
|
||||||
ds = _make_distributed_range_dataset(num_elements, dispatcher)
|
ds = self.make_distributed_range_dataset(num_elements, dispatcher)
|
||||||
for _ in range(10):
|
for _ in range(10):
|
||||||
self.assertEqual(list(range(num_elements)), [elem.numpy() for elem in ds])
|
self.assertEqual(list(range(num_elements)), [elem.numpy() for elem in ds])
|
||||||
|
|
||||||
@ -307,7 +133,7 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
dispatcher, workers = self.start_cluster(1) # to avoid gcing workers, pylint: disable=unused-variable
|
dispatcher, workers = self.start_cluster(1) # to avoid gcing workers, pylint: disable=unused-variable
|
||||||
num_elements = 10
|
num_elements = 10
|
||||||
num_repetitions = 5
|
num_repetitions = 5
|
||||||
ds = _make_distributed_range_dataset(num_elements, dispatcher)
|
ds = self.make_distributed_range_dataset(num_elements, dispatcher)
|
||||||
ds = ds.repeat(num_repetitions)
|
ds = ds.repeat(num_repetitions)
|
||||||
self.assertDatasetProduces(
|
self.assertDatasetProduces(
|
||||||
ds, expected_output=num_repetitions * list(range(num_elements)))
|
ds, expected_output=num_repetitions * list(range(num_elements)))
|
||||||
@ -320,7 +146,7 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
iterators = []
|
iterators = []
|
||||||
results = []
|
results = []
|
||||||
for _ in range(num_datasets):
|
for _ in range(num_datasets):
|
||||||
ds = _make_distributed_range_dataset(num_elements, dispatcher)
|
ds = self.make_distributed_range_dataset(num_elements, dispatcher)
|
||||||
iterators.append(iter(ds))
|
iterators.append(iter(ds))
|
||||||
results.append([])
|
results.append([])
|
||||||
|
|
||||||
@ -337,7 +163,7 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
dispatcher, workers = self.start_cluster(1) # to avoid gcing workers, pylint: disable=unused-variable
|
dispatcher, workers = self.start_cluster(1) # to avoid gcing workers, pylint: disable=unused-variable
|
||||||
num_elements = 10
|
num_elements = 10
|
||||||
num_iterators = 3
|
num_iterators = 3
|
||||||
ds = _make_distributed_range_dataset(num_elements, dispatcher)
|
ds = self.make_distributed_range_dataset(num_elements, dispatcher)
|
||||||
result = []
|
result = []
|
||||||
iterators = []
|
iterators = []
|
||||||
for _ in range(num_iterators):
|
for _ in range(num_iterators):
|
||||||
@ -360,100 +186,16 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
num_workers = 3
|
num_workers = 3
|
||||||
dispatcher, workers = self.start_cluster(num_workers) # to avoid gcing workers, pylint: disable=unused-variable
|
dispatcher, workers = self.start_cluster(num_workers) # to avoid gcing workers, pylint: disable=unused-variable
|
||||||
num_elements = 10
|
num_elements = 10
|
||||||
ds = _make_distributed_range_dataset(num_elements, dispatcher)
|
ds = self.make_distributed_range_dataset(num_elements, dispatcher)
|
||||||
results = [elem.numpy() for elem in ds]
|
results = [elem.numpy() for elem in ds]
|
||||||
self.assertCountEqual(num_workers * list(range(num_elements)), results)
|
self.assertCountEqual(num_workers * list(range(num_elements)), results)
|
||||||
|
|
||||||
@combinations.generate(test_base.eager_only_combinations())
|
|
||||||
def testStartServersLate(self):
|
|
||||||
# Test that the data service client performs retries instead of failing when
|
|
||||||
# the dataset is created before the master and worker are started.
|
|
||||||
try:
|
|
||||||
import portpicker # pylint: disable=g-import-not-at-top
|
|
||||||
dispatcher_port = portpicker.pick_unused_port()
|
|
||||||
except:
|
|
||||||
raise self.skipTest("Flakes in portpicker library do not represent "
|
|
||||||
"TensorFlow errors.")
|
|
||||||
dispatcher = server_lib.DispatchServer(
|
|
||||||
server_lib.DispatcherConfig(port=dispatcher_port), start=False)
|
|
||||||
worker = server_lib.WorkerServer(
|
|
||||||
server_lib.WorkerConfig(
|
|
||||||
dispatcher_address=_address_from_target(dispatcher.target), port=0),
|
|
||||||
start=False)
|
|
||||||
|
|
||||||
def start_servers():
|
|
||||||
time.sleep(1)
|
|
||||||
dispatcher.start()
|
|
||||||
worker.start()
|
|
||||||
|
|
||||||
start_servers_thread = threading.Thread(target=start_servers, daemon=True)
|
|
||||||
start_servers_thread.start()
|
|
||||||
|
|
||||||
num_elements = 10
|
|
||||||
ds = _make_distributed_range_dataset(num_elements, dispatcher)
|
|
||||||
results = [elem.numpy() for elem in ds]
|
|
||||||
self.assertEqual(list(range(num_elements)), results)
|
|
||||||
start_servers_thread.join()
|
|
||||||
|
|
||||||
@combinations.generate(test_base.eager_only_combinations())
|
|
||||||
def testAddWorkerMidJob(self):
|
|
||||||
dispatcher, workers = self.start_cluster(1) # to avoid gcing workers, pylint: disable=unused-variable
|
|
||||||
num_elements = 100
|
|
||||||
ds = _make_distributed_range_dataset(num_elements, dispatcher)
|
|
||||||
iterator = iter(ds)
|
|
||||||
results = []
|
|
||||||
# Read halfway through the dataset.
|
|
||||||
for _ in range(num_elements // 2):
|
|
||||||
results.append(next(iterator).numpy())
|
|
||||||
|
|
||||||
new_worker = self.start_worker_server(dispatcher) # to avoid gcing workers, pylint: disable=unused-variable
|
|
||||||
# Wait for the new worker to register with the dispatcher.
|
|
||||||
while dispatcher._num_workers() < 2:
|
|
||||||
time.sleep(10 / 1000) # 10ms
|
|
||||||
|
|
||||||
for elem in iterator:
|
|
||||||
results.append(elem.numpy())
|
|
||||||
|
|
||||||
self.assertCountEqual(2 * list(range(num_elements)), results)
|
|
||||||
|
|
||||||
@combinations.generate(
|
|
||||||
combinations.times(test_base.eager_only_combinations(),
|
|
||||||
combinations.combine(use_same_port=[True, False]),
|
|
||||||
_all_cluster_configurations()))
|
|
||||||
def testRestartWorker(self, use_same_port, work_dir, fault_tolerant_mode):
|
|
||||||
dispatcher, [worker] = self.start_cluster(
|
|
||||||
1, work_dir=work_dir, fault_tolerant_mode=fault_tolerant_mode)
|
|
||||||
num_elements = 100
|
|
||||||
ds = _make_distributed_range_dataset(num_elements, dispatcher)
|
|
||||||
iterator = iter(ds)
|
|
||||||
# Read halfway through the dataset.
|
|
||||||
midpoint = num_elements // 2
|
|
||||||
for i in range(midpoint):
|
|
||||||
self.assertEqual(i, next(iterator).numpy())
|
|
||||||
|
|
||||||
# Stop the original worker and start a new one.
|
|
||||||
worker = self.restart_worker(worker, dispatcher, use_same_port)
|
|
||||||
|
|
||||||
# There may have been some elements prefetched from the first worker
|
|
||||||
# before it was stopped.
|
|
||||||
while True:
|
|
||||||
val = next(iterator).numpy()
|
|
||||||
if val == 0:
|
|
||||||
break
|
|
||||||
|
|
||||||
# The dataset starts over now that we read from the new worker.
|
|
||||||
# TODO(b/157086991): Iterate until end of sequence when we support
|
|
||||||
# detecting lost workers.
|
|
||||||
for i in range(1, num_elements // 2):
|
|
||||||
val = next(iterator).numpy()
|
|
||||||
self.assertEqual(i, val)
|
|
||||||
|
|
||||||
@combinations.generate(test_base.eager_only_combinations())
|
@combinations.generate(test_base.eager_only_combinations())
|
||||||
def testMaxOutstandingRequests(self):
|
def testMaxOutstandingRequests(self):
|
||||||
num_workers = 3
|
num_workers = 3
|
||||||
dispatcher, workers = self.start_cluster(num_workers) # to avoid gcing workers, pylint: disable=unused-variable
|
dispatcher, workers = self.start_cluster(num_workers) # to avoid gcing workers, pylint: disable=unused-variable
|
||||||
num_elements = 10
|
num_elements = 10
|
||||||
ds = _make_distributed_range_dataset(
|
ds = self.make_distributed_range_dataset(
|
||||||
num_elements, dispatcher, max_outstanding_requests=1)
|
num_elements, dispatcher, max_outstanding_requests=1)
|
||||||
self.assertCountEqual(num_workers * list(range(num_elements)),
|
self.assertCountEqual(num_workers * list(range(num_elements)),
|
||||||
self.getDatasetOutput(ds))
|
self.getDatasetOutput(ds))
|
||||||
@ -466,7 +208,7 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
|
|
||||||
@def_function.function
|
@def_function.function
|
||||||
def f():
|
def f():
|
||||||
ds = _make_distributed_range_dataset(num_elements, dispatcher)
|
ds = self.make_distributed_range_dataset(num_elements, dispatcher)
|
||||||
result = tensor_array_ops.TensorArray(
|
result = tensor_array_ops.TensorArray(
|
||||||
dtypes.int64, size=num_workers * num_elements, dynamic_size=True)
|
dtypes.int64, size=num_workers * num_elements, dynamic_size=True)
|
||||||
i = 0
|
i = 0
|
||||||
@ -486,8 +228,10 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
def make_ds():
|
def make_ds():
|
||||||
return dataset_ops.Dataset.range(num_elements).shuffle(num_elements)
|
return dataset_ops.Dataset.range(num_elements).shuffle(num_elements)
|
||||||
|
|
||||||
ds1 = _make_distributed_dataset(make_ds(), dispatcher, job_name="job_name")
|
ds1 = self.make_distributed_dataset(
|
||||||
ds2 = _make_distributed_dataset(make_ds(), dispatcher, job_name="job_name")
|
make_ds(), dispatcher, job_name="job_name")
|
||||||
|
ds2 = self.make_distributed_dataset(
|
||||||
|
make_ds(), dispatcher, job_name="job_name")
|
||||||
iter1 = iter(ds1)
|
iter1 = iter(ds1)
|
||||||
iter2 = iter(ds2)
|
iter2 = iter(ds2)
|
||||||
results = []
|
results = []
|
||||||
@ -507,9 +251,9 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
def make_ds(num_elements):
|
def make_ds(num_elements):
|
||||||
return dataset_ops.Dataset.range(num_elements)
|
return dataset_ops.Dataset.range(num_elements)
|
||||||
|
|
||||||
ds1 = _make_distributed_dataset(
|
ds1 = self.make_distributed_dataset(
|
||||||
make_ds(num_elements=10), dispatcher, job_name="job_name")
|
make_ds(num_elements=10), dispatcher, job_name="job_name")
|
||||||
ds2 = _make_distributed_dataset(
|
ds2 = self.make_distributed_dataset(
|
||||||
make_ds(num_elements=11), dispatcher, job_name="job_name")
|
make_ds(num_elements=11), dispatcher, job_name="job_name")
|
||||||
iter(ds1)
|
iter(ds1)
|
||||||
with self.assertRaisesRegex(errors.FailedPreconditionError,
|
with self.assertRaisesRegex(errors.FailedPreconditionError,
|
||||||
@ -521,8 +265,8 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
dispatcher, workers = self.start_cluster(1) # to avoid gcing workers, pylint: disable=unused-variable
|
dispatcher, workers = self.start_cluster(1) # to avoid gcing workers, pylint: disable=unused-variable
|
||||||
num_elements = 10
|
num_elements = 10
|
||||||
ds = dataset_ops.Dataset.range(num_elements)
|
ds = dataset_ops.Dataset.range(num_elements)
|
||||||
ds1 = _make_distributed_dataset(ds, dispatcher, job_name="job_name1")
|
ds1 = self.make_distributed_dataset(ds, dispatcher, job_name="job_name1")
|
||||||
ds2 = _make_distributed_dataset(ds, dispatcher, job_name="job_name2")
|
ds2 = self.make_distributed_dataset(ds, dispatcher, job_name="job_name2")
|
||||||
self.assertDatasetProduces(ds1, list(range(num_elements)))
|
self.assertDatasetProduces(ds1, list(range(num_elements)))
|
||||||
self.assertDatasetProduces(ds2, list(range(num_elements)))
|
self.assertDatasetProduces(ds2, list(range(num_elements)))
|
||||||
|
|
||||||
@ -531,8 +275,8 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
dispatcher, workers = self.start_cluster(1) # to avoid gcing workers, pylint: disable=unused-variable
|
dispatcher, workers = self.start_cluster(1) # to avoid gcing workers, pylint: disable=unused-variable
|
||||||
num_elements = 10
|
num_elements = 10
|
||||||
ds = dataset_ops.Dataset.range(num_elements)
|
ds = dataset_ops.Dataset.range(num_elements)
|
||||||
ds1 = _make_distributed_dataset(ds, dispatcher, job_name="job_name")
|
ds1 = self.make_distributed_dataset(ds, dispatcher, job_name="job_name")
|
||||||
ds2 = _make_distributed_dataset(ds, dispatcher, job_name="job_name")
|
ds2 = self.make_distributed_dataset(ds, dispatcher, job_name="job_name")
|
||||||
# iteration 1
|
# iteration 1
|
||||||
self.assertDatasetProduces(ds1, list(range(num_elements)))
|
self.assertDatasetProduces(ds1, list(range(num_elements)))
|
||||||
self.assertDatasetProduces(ds2, [])
|
self.assertDatasetProduces(ds2, [])
|
||||||
@ -546,9 +290,9 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
num_elements = 100
|
num_elements = 100
|
||||||
num_repetitions = 3
|
num_repetitions = 3
|
||||||
ds = dataset_ops.Dataset.range(num_elements)
|
ds = dataset_ops.Dataset.range(num_elements)
|
||||||
ds1 = _make_distributed_dataset(ds, dispatcher, job_name="job_name")
|
ds1 = self.make_distributed_dataset(ds, dispatcher, job_name="job_name")
|
||||||
ds1 = ds1.repeat(num_repetitions)
|
ds1 = ds1.repeat(num_repetitions)
|
||||||
ds2 = _make_distributed_dataset(ds, dispatcher, job_name="job_name")
|
ds2 = self.make_distributed_dataset(ds, dispatcher, job_name="job_name")
|
||||||
ds2 = ds2.repeat(num_repetitions)
|
ds2 = ds2.repeat(num_repetitions)
|
||||||
results = []
|
results = []
|
||||||
iter1 = iter(ds1)
|
iter1 = iter(ds1)
|
||||||
@ -571,7 +315,7 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
job_gc_check_interval_ms=50, job_gc_timeout_ms=20)
|
job_gc_check_interval_ms=50, job_gc_timeout_ms=20)
|
||||||
worker = self.start_worker_server(dispatcher) # pylint: disable=unused-variable
|
worker = self.start_worker_server(dispatcher) # pylint: disable=unused-variable
|
||||||
num_elements = 100
|
num_elements = 100
|
||||||
ds = _make_distributed_range_dataset(
|
ds = self.make_distributed_range_dataset(
|
||||||
num_elements, dispatcher, job_name=job_name)
|
num_elements, dispatcher, job_name=job_name)
|
||||||
it = iter(ds)
|
it = iter(ds)
|
||||||
self.assertEqual(next(it).numpy(), 0)
|
self.assertEqual(next(it).numpy(), 0)
|
||||||
@ -587,13 +331,13 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
worker = self.start_worker_server(dispatcher) # pylint: disable=unused-variable
|
worker = self.start_worker_server(dispatcher) # pylint: disable=unused-variable
|
||||||
num_elements = 10
|
num_elements = 10
|
||||||
it1 = iter(
|
it1 = iter(
|
||||||
_make_distributed_range_dataset(
|
self.make_distributed_range_dataset(
|
||||||
num_elements, dispatcher, job_name="test1"))
|
num_elements, dispatcher, job_name="test1"))
|
||||||
it2 = iter(
|
it2 = iter(
|
||||||
_make_distributed_range_dataset(
|
self.make_distributed_range_dataset(
|
||||||
num_elements, dispatcher, job_name="test2"))
|
num_elements, dispatcher, job_name="test2"))
|
||||||
it3 = iter( # this iterator keeps the task alive. pylint: disable=unused-variable
|
it3 = iter( # this iterator keeps the task alive. pylint: disable=unused-variable
|
||||||
_make_distributed_range_dataset(
|
self.make_distributed_range_dataset(
|
||||||
num_elements, dispatcher, job_name="test2"))
|
num_elements, dispatcher, job_name="test2"))
|
||||||
self.assertEqual(2, worker._num_tasks())
|
self.assertEqual(2, worker._num_tasks())
|
||||||
del it1
|
del it1
|
||||||
@ -624,7 +368,7 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
opts = dataset_ops.Options()
|
opts = dataset_ops.Options()
|
||||||
opts.experimental_deterministic = False
|
opts.experimental_deterministic = False
|
||||||
ds = ds.with_options(opts)
|
ds = ds.with_options(opts)
|
||||||
ds = _make_distributed_dataset(ds, dispatcher)
|
ds = self.make_distributed_dataset(ds, dispatcher)
|
||||||
return ds
|
return ds
|
||||||
|
|
||||||
self.checkDeterminism(
|
self.checkDeterminism(
|
||||||
@ -642,7 +386,7 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
ds = ds.with_options(options)
|
ds = ds.with_options(options)
|
||||||
|
|
||||||
dispatcher, workers = self.start_cluster(3) # to avoid gcing workers, pylint: disable=unused-variable
|
dispatcher, workers = self.start_cluster(3) # to avoid gcing workers, pylint: disable=unused-variable
|
||||||
ds = _make_distributed_dataset(ds, dispatcher)
|
ds = self.make_distributed_dataset(ds, dispatcher)
|
||||||
next(iter(ds))
|
next(iter(ds))
|
||||||
|
|
||||||
@combinations.generate(
|
@combinations.generate(
|
||||||
@ -701,7 +445,7 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
|
|
||||||
def interleave_fn(_):
|
def interleave_fn(_):
|
||||||
dataset = dataset_ops.Dataset.range(2)
|
dataset = dataset_ops.Dataset.range(2)
|
||||||
_make_distributed_dataset(dataset, dispatcher)
|
self.make_distributed_dataset(dataset, dispatcher)
|
||||||
return dataset
|
return dataset
|
||||||
|
|
||||||
ds = ds.interleave(interleave_fn, cycle_length=2)
|
ds = ds.interleave(interleave_fn, cycle_length=2)
|
||||||
@ -718,29 +462,6 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
self.assertDatasetProduces(
|
self.assertDatasetProduces(
|
||||||
ds, list(range(num_elements)), assert_items_equal=True)
|
ds, list(range(num_elements)), assert_items_equal=True)
|
||||||
|
|
||||||
@combinations.generate(test_base.eager_only_combinations())
|
|
||||||
def testChangeProcessingModeAfterRestart(self):
|
|
||||||
dispatcher, workers = self.start_cluster(1) # to avoid gcing workers, pylint: disable=unused-variable
|
|
||||||
num_elements = 100
|
|
||||||
range_dataset = dataset_ops.Dataset.range(num_elements)
|
|
||||||
ds = range_dataset.apply(
|
|
||||||
data_service_ops.distribute(
|
|
||||||
processing_mode="parallel_epochs",
|
|
||||||
service=dispatcher.target,
|
|
||||||
job_name="test"))
|
|
||||||
iterator = iter(ds)
|
|
||||||
for i in range(num_elements // 2):
|
|
||||||
self.assertEqual(i, next(iterator).numpy())
|
|
||||||
dispatcher = self.restart_dispatcher(dispatcher)
|
|
||||||
ds = range_dataset.apply(
|
|
||||||
data_service_ops.distribute(
|
|
||||||
processing_mode="distributed_epoch",
|
|
||||||
service=dispatcher.target,
|
|
||||||
job_name="test"))
|
|
||||||
with self.assertRaisesOpError("already an existing job with that name "
|
|
||||||
"using processing mode <parallel_epochs>"):
|
|
||||||
next(iter(ds)).numpy()
|
|
||||||
|
|
||||||
@combinations.generate(test_base.eager_only_combinations())
|
@combinations.generate(test_base.eager_only_combinations())
|
||||||
def testDistributeNonStringAddresses(self):
|
def testDistributeNonStringAddresses(self):
|
||||||
ds = dataset_ops.Dataset.range(10)
|
ds = dataset_ops.Dataset.range(10)
|
||||||
@ -831,7 +552,7 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
slow = dataset_ops.Dataset.range(1)
|
slow = dataset_ops.Dataset.range(1)
|
||||||
slow = slow.apply(testing.sleep(sleep_microseconds))
|
slow = slow.apply(testing.sleep(sleep_microseconds))
|
||||||
ds = dataset_ops.Dataset.range(1).concatenate(slow)
|
ds = dataset_ops.Dataset.range(1).concatenate(slow)
|
||||||
ds = _make_distributed_dataset(ds, dispatcher)
|
ds = self.make_distributed_dataset(ds, dispatcher)
|
||||||
ds = ds.prefetch(1)
|
ds = ds.prefetch(1)
|
||||||
get_next = self.getNext(ds, requires_initialization=True)
|
get_next = self.getNext(ds, requires_initialization=True)
|
||||||
self.assertEqual(0, self.evaluate(get_next()))
|
self.assertEqual(0, self.evaluate(get_next()))
|
||||||
@ -868,7 +589,7 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
strings = ["a" * i for i in range(num_sizes)] * size_repeats
|
strings = ["a" * i for i in range(num_sizes)] * size_repeats
|
||||||
ds = dataset_ops.Dataset.from_tensor_slices(strings)
|
ds = dataset_ops.Dataset.from_tensor_slices(strings)
|
||||||
ds = ds.shuffle(len(strings))
|
ds = ds.shuffle(len(strings))
|
||||||
ds = _make_distributed_dataset(ds, dispatcher_1)
|
ds = self.make_distributed_dataset(ds, dispatcher_1)
|
||||||
# Large enough so that all strings of the same size are windowed together.
|
# Large enough so that all strings of the same size are windowed together.
|
||||||
window_size = cluster_1_size * size_repeats
|
window_size = cluster_1_size * size_repeats
|
||||||
batch_size = size_repeats
|
batch_size = size_repeats
|
||||||
@ -881,7 +602,7 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
key_func=key_func,
|
key_func=key_func,
|
||||||
reduce_func=lambda _, x: x.batch(batch_size),
|
reduce_func=lambda _, x: x.batch(batch_size),
|
||||||
window_size=window_size))
|
window_size=window_size))
|
||||||
ds = _make_distributed_dataset(ds, dispatcher_2)
|
ds = self.make_distributed_dataset(ds, dispatcher_2)
|
||||||
|
|
||||||
it = iter(ds)
|
it = iter(ds)
|
||||||
for _ in range(num_sizes):
|
for _ in range(num_sizes):
|
||||||
@ -900,28 +621,9 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
# Larger than default OSS grpc message size limit of 4MB.
|
# Larger than default OSS grpc message size limit of 4MB.
|
||||||
tensor = array_ops.ones((2, 1000, 1000), dtype=dtypes.float32)
|
tensor = array_ops.ones((2, 1000, 1000), dtype=dtypes.float32)
|
||||||
ds = dataset_ops.Dataset.from_tensors(tensor)
|
ds = dataset_ops.Dataset.from_tensors(tensor)
|
||||||
ds = _make_distributed_dataset(ds, dispatcher)
|
ds = self.make_distributed_dataset(ds, dispatcher)
|
||||||
self.assertDatasetProduces(ds, [tensor])
|
self.assertDatasetProduces(ds, [tensor])
|
||||||
|
|
||||||
@combinations.generate(
|
|
||||||
combinations.times(
|
|
||||||
test_base.eager_only_combinations(),
|
|
||||||
combinations.combine(work_dir=[TMP_WORK_DIR, NO_WORK_DIR])))
|
|
||||||
def testDistributeLargeGraphThenRegisterWorker(self, work_dir):
|
|
||||||
dispatcher = self.start_dispatch_server(
|
|
||||||
work_dir=work_dir, fault_tolerant_mode=False)
|
|
||||||
worker = server_lib.WorkerServer(
|
|
||||||
server_lib.WorkerConfig(
|
|
||||||
dispatcher_address=_address_from_target(dispatcher.target), port=0),
|
|
||||||
start=False)
|
|
||||||
# Larger than default OSS grpc message size limit of 4MB.
|
|
||||||
tensor = array_ops.ones((2, 1000, 1000), dtype=dtypes.float32)
|
|
||||||
ds = dataset_ops.Dataset.from_tensors(tensor)
|
|
||||||
ds = _make_distributed_dataset(ds, dispatcher)
|
|
||||||
it = iter(ds)
|
|
||||||
worker.start()
|
|
||||||
self.assertAllEqual(next(it), tensor)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
test.main()
|
test.main()
|
@ -0,0 +1,178 @@
|
|||||||
|
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
"""Test base for tf.data service tests."""
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import os
|
||||||
|
|
||||||
|
from tensorflow.python.data.experimental.ops import data_service_ops
|
||||||
|
from tensorflow.python.data.experimental.service import server_lib
|
||||||
|
from tensorflow.python.data.kernel_tests import test_base
|
||||||
|
from tensorflow.python.data.ops import dataset_ops
|
||||||
|
from tensorflow.python.framework import combinations
|
||||||
|
|
||||||
|
# This will be resolved to a tmp directory by `start_dispatch_server`.
|
||||||
|
TMP_WORK_DIR = "tmp_work_dir_placeholder"
|
||||||
|
# `""` indicates not to use a work directory.
|
||||||
|
NO_WORK_DIR = ""
|
||||||
|
|
||||||
|
|
||||||
|
def _address_from_target(target):
|
||||||
|
# Targets are in the format <protocol>://<address>
|
||||||
|
return target.split("://")[1]
|
||||||
|
|
||||||
|
|
||||||
|
def _make_distributed_dataset(dataset,
|
||||||
|
dispatcher,
|
||||||
|
job_name=None,
|
||||||
|
max_outstanding_requests=None):
|
||||||
|
return dataset.apply(
|
||||||
|
data_service_ops._distribute( # pylint: disable=protected-access
|
||||||
|
"parallel_epochs",
|
||||||
|
dispatcher.target,
|
||||||
|
job_name=job_name,
|
||||||
|
max_outstanding_requests=max_outstanding_requests,
|
||||||
|
task_refresh_interval_hint_ms=20))
|
||||||
|
|
||||||
|
|
||||||
|
def _all_cluster_configurations():
|
||||||
|
with_work_dir = combinations.combine(
|
||||||
|
work_dir=TMP_WORK_DIR, fault_tolerant_mode=[True, False])
|
||||||
|
without_work_dir = combinations.combine(
|
||||||
|
work_dir=NO_WORK_DIR, fault_tolerant_mode=False)
|
||||||
|
return with_work_dir + without_work_dir
|
||||||
|
|
||||||
|
|
||||||
|
def _make_distributed_range_dataset(num_elements,
|
||||||
|
dispatcher,
|
||||||
|
job_name=None,
|
||||||
|
max_outstanding_requests=None):
|
||||||
|
"""Creates a distributed dataset.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
num_elements: The number of elements in the range dataset that will be
|
||||||
|
distributed.
|
||||||
|
dispatcher: The dispatcher to distribute to.
|
||||||
|
job_name: Optional job name for the distributed dataset.
|
||||||
|
max_outstanding_requests: Optional limit on the number of outstanding
|
||||||
|
requests.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The created dataset.
|
||||||
|
"""
|
||||||
|
dataset = dataset_ops.Dataset.range(num_elements)
|
||||||
|
return _make_distributed_dataset(dataset, dispatcher, job_name,
|
||||||
|
max_outstanding_requests)
|
||||||
|
|
||||||
|
|
||||||
|
class TestBase(test_base.DatasetTestBase):
|
||||||
|
"""Base class for tf.data service tests."""
|
||||||
|
|
||||||
|
def start_dispatch_server(self,
|
||||||
|
name="",
|
||||||
|
port=0,
|
||||||
|
work_dir=TMP_WORK_DIR,
|
||||||
|
fault_tolerant_mode=True,
|
||||||
|
job_gc_check_interval_ms=None,
|
||||||
|
job_gc_timeout_ms=None):
|
||||||
|
# If a test starts multiple independent dispatch servers, it should give
|
||||||
|
# them different `name` values.
|
||||||
|
work_dir = os.path.join(self.get_temp_dir(), "work_dir_",
|
||||||
|
name) if work_dir is TMP_WORK_DIR else work_dir
|
||||||
|
return server_lib.DispatchServer(
|
||||||
|
server_lib.DispatcherConfig(
|
||||||
|
port=port,
|
||||||
|
work_dir=work_dir,
|
||||||
|
fault_tolerant_mode=fault_tolerant_mode,
|
||||||
|
job_gc_check_interval_ms=job_gc_check_interval_ms,
|
||||||
|
job_gc_timeout_ms=job_gc_timeout_ms))
|
||||||
|
|
||||||
|
def start_worker_server(self, dispatcher, port=0):
|
||||||
|
return server_lib.WorkerServer(
|
||||||
|
server_lib.WorkerConfig(
|
||||||
|
dispatcher_address=self.dispatcher_address(dispatcher),
|
||||||
|
port=port,
|
||||||
|
heartbeat_interval_ms=200))
|
||||||
|
|
||||||
|
# pylint: disable=protected-access
|
||||||
|
def restart_dispatcher(self, dispatcher):
|
||||||
|
"""Stops `dispatcher` and returns a new dispatcher with the same port."""
|
||||||
|
port = int(self.dispatcher_address(dispatcher).split(":")[1])
|
||||||
|
dispatcher._stop()
|
||||||
|
return self.start_dispatch_server(
|
||||||
|
port=port,
|
||||||
|
work_dir=dispatcher._config.work_dir,
|
||||||
|
fault_tolerant_mode=dispatcher._config.fault_tolerant_mode)
|
||||||
|
|
||||||
|
# pylint: disable=protected-access
|
||||||
|
def restart_worker(self, worker, dispatcher, use_same_port=True):
|
||||||
|
"""Stops `worker` and returns a new worker."""
|
||||||
|
port = 0
|
||||||
|
if use_same_port:
|
||||||
|
port = int(worker._address.split(":")[1])
|
||||||
|
worker._stop()
|
||||||
|
return self.start_worker_server(dispatcher, port)
|
||||||
|
|
||||||
|
def start_cluster(self,
|
||||||
|
num_workers,
|
||||||
|
name="",
|
||||||
|
work_dir=TMP_WORK_DIR,
|
||||||
|
fault_tolerant_mode=True):
|
||||||
|
"""Creates and starts a tf.data service cluster."""
|
||||||
|
dispatcher = self.start_dispatch_server(
|
||||||
|
name=name, work_dir=work_dir, fault_tolerant_mode=fault_tolerant_mode)
|
||||||
|
workers = [self.start_worker_server(dispatcher) for _ in range(num_workers)]
|
||||||
|
return dispatcher, workers
|
||||||
|
|
||||||
|
def dispatcher_address(self, dispatcher):
|
||||||
|
# Targets are in the format <protocol>://<address>
|
||||||
|
return dispatcher.target.split("://")[1]
|
||||||
|
|
||||||
|
def make_distributed_dataset(self,
|
||||||
|
dataset,
|
||||||
|
dispatcher,
|
||||||
|
job_name=None,
|
||||||
|
max_outstanding_requests=None):
|
||||||
|
return dataset.apply(
|
||||||
|
data_service_ops._distribute(
|
||||||
|
"parallel_epochs",
|
||||||
|
dispatcher.target,
|
||||||
|
job_name=job_name,
|
||||||
|
max_outstanding_requests=max_outstanding_requests,
|
||||||
|
task_refresh_interval_hint_ms=20))
|
||||||
|
|
||||||
|
def make_distributed_range_dataset(self,
|
||||||
|
num_elements,
|
||||||
|
dispatcher,
|
||||||
|
job_name=None,
|
||||||
|
max_outstanding_requests=None):
|
||||||
|
"""Creates a distributed dataset.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
num_elements: The number of elements in the range dataset that will be
|
||||||
|
distributed.
|
||||||
|
dispatcher: The dispatcher to distribute to.
|
||||||
|
job_name: Optional job name for the distributed dataset.
|
||||||
|
max_outstanding_requests: Optional limit on the number of outstanding
|
||||||
|
requests.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The created dataset.
|
||||||
|
"""
|
||||||
|
dataset = dataset_ops.Dataset.range(num_elements)
|
||||||
|
return self.make_distributed_dataset(dataset, dispatcher, job_name,
|
||||||
|
max_outstanding_requests)
|
@ -90,25 +90,6 @@ tf_py_test(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
tf_py_test(
|
|
||||||
name = "data_service_ops_test",
|
|
||||||
size = "medium",
|
|
||||||
srcs = ["data_service_ops_test.py"],
|
|
||||||
shard_count = 10,
|
|
||||||
deps = [
|
|
||||||
"//tensorflow:tensorflow_py",
|
|
||||||
"//tensorflow/python:client_testlib",
|
|
||||||
"//tensorflow/python:errors",
|
|
||||||
"//tensorflow/python:framework",
|
|
||||||
"//tensorflow/python:framework_test_lib",
|
|
||||||
"//tensorflow/python:platform_test",
|
|
||||||
"//tensorflow/python/data",
|
|
||||||
"//tensorflow/python/data/experimental/ops:testing",
|
|
||||||
"//tensorflow/python/data/experimental/service:server_lib",
|
|
||||||
"//tensorflow/python/data/kernel_tests:test_base",
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
tf_py_test(
|
tf_py_test(
|
||||||
name = "dataset_test",
|
name = "dataset_test",
|
||||||
size = "small",
|
size = "small",
|
||||||
|
@ -117,6 +117,7 @@ COMMON_PIP_DEPS = [
|
|||||||
"//tensorflow/python:util_example_parser_configuration",
|
"//tensorflow/python:util_example_parser_configuration",
|
||||||
"//tensorflow/python/data/benchmarks:benchmark_base",
|
"//tensorflow/python/data/benchmarks:benchmark_base",
|
||||||
"//tensorflow/python/data/experimental/kernel_tests/serialization:dataset_serialization_test_base",
|
"//tensorflow/python/data/experimental/kernel_tests/serialization:dataset_serialization_test_base",
|
||||||
|
"//tensorflow/python/data/experimental/kernel_tests:data_service_test_base",
|
||||||
"//tensorflow/python/data/experimental/kernel_tests:reader_dataset_ops_test_base",
|
"//tensorflow/python/data/experimental/kernel_tests:reader_dataset_ops_test_base",
|
||||||
"//tensorflow/python/data/experimental/kernel_tests:stats_dataset_test_base",
|
"//tensorflow/python/data/experimental/kernel_tests:stats_dataset_test_base",
|
||||||
"//tensorflow/python/data/experimental/ops:testing",
|
"//tensorflow/python/data/experimental/ops:testing",
|
||||||
|
Loading…
Reference in New Issue
Block a user