Add a `make_input_fn_iterator` method to distribution strategies, which takes an input function and (in a following CL) will return an iterator object.

Add an `InputContext` class, which will be passed as an argument to the `input_fn` given to the `make_input_fn_iterator` method.
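
For context, this is roughly how the new API is intended to be used once the follow-up CL makes it return an iterator. The sketch below is adapted from the docstring example added in this change; `tf`, `distribution_strategy`, `replica_fn`, and the global batch size of 64 are placeholder assumptions, not part of the change itself.

```
def input_fn(input_context):
  # Shard the input across pipelines, then batch with the per-replica size
  # derived from a (hypothetical) global batch size of 64.
  d = tf.data.Dataset.from_tensors([[1.]]).repeat()
  d = d.shard(input_context.num_input_pipelines,
              input_context.input_pipeline_id)
  return d.batch(input_context.get_per_replica_batch_size(64))

with distribution_strategy.scope():
  iterator = distribution_strategy.make_input_fn_iterator(input_fn)
  replica_results = distribution_strategy.call_for_each_replica(
      replica_fn, iterator.get_next())
```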

PiperOrigin-RevId: 221581165
Yuefeng Zhou 2018-11-15 00:58:02 -08:00 committed by TensorFlower Gardener
parent 387a062dd3
commit fa5e5ecd88
15 changed files with 529 additions and 25 deletions

View File

@@ -118,6 +118,7 @@ cuda_py_test(
         ":combinations",
         ":multi_worker_test_base",
         ":parameter_server_strategy",
+        ":strategy_test_lib",
         ":values",
         "@absl_py//absl/testing:parameterized",
         "//tensorflow/core:protos_all_py",

View File

@@ -29,6 +29,7 @@ from tensorflow.python.framework import ops
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import collective_ops
 from tensorflow.python.platform import tf_logging as logging
+from tensorflow.python.training import distribute as distribute_lib

 # TODO(yuefengz): support in-graph replication.
@@ -94,8 +95,7 @@ class CollectiveAllReduceStrategy(mirrored_strategy.MirroredStrategy):
           "Unrecognized task_type: %r, valid task types are: \"chief\", "
           "\"worker\"." % task_type)
     cluster_spec = multi_worker_util.normalize_cluster_spec(cluster_spec)
-    self._num_workers = len(cluster_spec.as_dict().get("worker", [])) + len(
-        cluster_spec.as_dict().get("chief", []))
+    self._num_workers = multi_worker_util.worker_count(cluster_spec, task_type)
     if not self._num_workers:
       raise ValueError("No `worker` or `chief` tasks can be found in "
                        "`cluster_spec`.")
@@ -208,6 +208,23 @@ class CollectiveAllReduceStrategy(mirrored_strategy.MirroredStrategy):
     return values.PerReplicaDataset(
         self._call_dataset_fn(dataset_fn), self._devices, True)

+  def _make_input_fn_iterator(
+      self,
+      input_fn,
+      replication_mode=distribute_lib.InputReplicationMode.PER_WORKER):
+    """Distributes the dataset to each local GPU."""
+    if self._cluster_spec is None:
+      input_pipeline_id = 0
+    else:
+      input_pipeline_id = multi_worker_util.id_in_cluster(
+          self._cluster_spec, self._task_type, self._task_id)
+    input_context = distribute_lib.InputContext(
+        num_input_pipelines=self._num_workers,
+        input_pipeline_id=input_pipeline_id,
+        num_replicas_in_sync=self.num_replicas_in_sync)
+    return values.PerReplicaDataset(
+        self._call_dataset_fn(input_fn, input_context), self._devices, True)
+
   def configure(self,
                 session_config=None,
                 cluster_spec=None,
@@ -289,4 +306,3 @@ class CollectiveAllReduceStrategy(mirrored_strategy.MirroredStrategy):
   @property
   def num_replicas_in_sync(self):
     return len(self._devices) * self._num_workers

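In the multi-worker branch above, `input_pipeline_id` comes from `multi_worker_util.id_in_cluster`, which is added later in this diff: the chief (if any) takes id 0 and worker ids are shifted after it. A small pure-Python restatement of that numbering (a sketch, not the real helper; the host addresses are placeholders):

```
def id_in_cluster_sketch(cluster_spec, task_type, task_id):
  # Chief always gets 0; workers come after any chief; evaluator is separate.
  if task_type == "chief":
    return 0
  if task_type == "worker":
    return task_id + len(cluster_spec.get("chief", []))
  if task_type == "evaluator":
    return task_id
  raise ValueError("There is no id for task_type %r" % task_type)

spec = {"chief": ["host:1"], "worker": ["host:2", "host:3", "host:4"]}
print(id_in_cluster_sketch(spec, "worker", 1))  # 2

spec_no_chief = {"worker": ["host:2", "host:3", "host:4"]}
print(id_in_cluster_sketch(spec_no_chief, "worker", 1))  # 1
```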
View File

@@ -25,6 +25,7 @@ from tensorflow.contrib.distribute.python import collective_all_reduce_strategy
 from tensorflow.contrib.distribute.python import combinations
 from tensorflow.contrib.distribute.python import cross_tower_utils
 from tensorflow.contrib.distribute.python import multi_worker_test_base
+from tensorflow.contrib.distribute.python import strategy_test_lib
 from tensorflow.core.protobuf import config_pb2
 from tensorflow.python import keras
 from tensorflow.python.distribute import reduce_util
@@ -339,5 +340,30 @@ class LocalCollectiveAllReduceStrategy(CollectiveAllReduceStrategyTestBase,
     self._test_complex_model(None, None, num_gpus)

+
+class InputContextTest(strategy_test_lib.DistributionTestBase):
+
+  def testInputContextPropertyLocal(self):
+    d = collective_all_reduce_strategy.CollectiveAllReduceStrategy(
+        num_gpus_per_worker=2)
+    with context.graph_mode():
+      input_fn = self._input_fn_to_test_input_context(
+          expected_num_replicas_in_sync=2,
+          expected_num_input_pipelines=1,
+          expected_input_pipeline_id=0)
+      d.make_input_fn_iterator(input_fn)
+
+  def testInputContextPropertyMultiWorker(self):
+    d = collective_all_reduce_strategy.CollectiveAllReduceStrategy(
+        num_gpus_per_worker=2)
+    cluster_spec = {'worker': ['worker1', 'worker2', 'worker3'], 'ps': ['ps1']}
+    d.configure(cluster_spec=cluster_spec, task_type='worker', task_id=1)
+    with context.graph_mode():
+      input_fn = self._input_fn_to_test_input_context(
+          expected_num_replicas_in_sync=6,
+          expected_num_input_pipelines=3,
+          expected_input_pipeline_id=1)  # because task_id = 1
+      d.make_input_fn_iterator(input_fn)
+
+
 if __name__ == '__main__':
   test.main()

View File

@@ -485,6 +485,30 @@ class MirroredStrategy(distribute_lib.DistributionStrategy):
     return values.PerReplicaDataset(
         self._call_dataset_fn(dataset_fn), self._devices)

+  def _make_input_fn_iterator(
+      self,
+      input_fn,
+      replication_mode=distribute_lib.InputReplicationMode.PER_WORKER):
+    if self._cluster_spec:
+      input_fns = []
+      for i in range(len(self._worker_devices)):
+        input_context = distribute_lib.InputContext(
+            num_input_pipelines=len(self._worker_devices),
+            input_pipeline_id=i,
+            num_replicas_in_sync=self.num_replicas_in_sync)
+        input_fns.append(
+            partial(self._call_dataset_fn, input_fn, input_context))
+      return values.MultiWorkerDataset(input_fns, self._worker_devices,
+                                       self._auto_shard_dataset)
+    else:
+      input_context = distribute_lib.InputContext(
+          num_input_pipelines=1,
+          input_pipeline_id=0,
+          num_replicas_in_sync=self.num_replicas_in_sync)
+      return values.PerReplicaDataset(
+          self._call_dataset_fn(input_fn, input_context), self._devices)
+
   # TODO(priyag): Deal with OutOfRange errors once b/111349762 is fixed.
   def _run_steps_on_dataset(self, fn, iterator, iterations,
                             initial_loop_values=None):

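When a cluster spec is configured, the code above binds the single `input_fn` to a different `InputContext` per worker with `functools.partial` and hands the resulting list to `MultiWorkerDataset`. A standalone sketch of that fan-out pattern (the `FakeInputContext` class, worker device names, and the string return value are illustrative stand-ins, not TensorFlow APIs):

```
import functools

class FakeInputContext(object):
  """Stand-in for distribute_lib.InputContext, holding only the two fields used here."""

  def __init__(self, num_input_pipelines, input_pipeline_id):
    self.num_input_pipelines = num_input_pipelines
    self.input_pipeline_id = input_pipeline_id

def input_fn(ctx):
  # A real input_fn would return a tf.data.Dataset sharded according to ctx.
  return "shard %d of %d" % (ctx.input_pipeline_id, ctx.num_input_pipelines)

worker_devices = ["/job:worker/task:0", "/job:worker/task:1", "/job:worker/task:2"]

# One zero-argument callable per worker, each with its own pipeline id bound in.
input_fns = [
    functools.partial(input_fn, FakeInputContext(len(worker_devices), i))
    for i in range(len(worker_devices))
]

# Each element is later invoked on its own worker, as MultiWorkerDataset does.
print([fn() for fn in input_fns])  # ['shard 0 of 3', 'shard 1 of 3', 'shard 2 of 3']
```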
View File

@@ -47,6 +47,28 @@ class MirroredOneCPUDistributionTest(strategy_test_lib.DistributionTestBase):
   def testCallAndMergeExceptions(self):
     self._test_call_and_merge_exceptions(self._get_distribution_strategy())

+  @test_util.run_in_graph_and_eager_modes
+  def testInputContextPropertyLocal(self):
+    d = mirrored_strategy.MirroredStrategy(num_gpus_per_worker=2)
+    input_fn = self._input_fn_to_test_input_context(
+        expected_num_replicas_in_sync=2,
+        expected_num_input_pipelines=1,
+        expected_input_pipeline_id=0)
+    d.make_input_fn_iterator(input_fn)
+
+  def testInputContextPropertyMultiWorker(self):
+    d = mirrored_strategy.MirroredStrategy(num_gpus_per_worker=2)
+    cluster_spec = {"worker": ["worker1", "worker2", "worker3"]}
+    d.configure(cluster_spec=cluster_spec)
+    with context.graph_mode():
+      # `expected_input_pipeline_id` is None because the input_fn will be
+      # called multiple times, each with a different input_pipeline_id.
+      input_fn = self._input_fn_to_test_input_context(
+          expected_num_replicas_in_sync=6,
+          expected_num_input_pipelines=3,
+          expected_input_pipeline_id=None)
+      d.make_input_fn_iterator(input_fn)
+

 class VariableCreatorStackTest(test.TestCase):

View File

@@ -67,6 +67,14 @@ class OneDeviceStrategy(distribute_lib.DistributionStrategy):
     return values.PerReplicaDataset(
         self._call_dataset_fn(dataset_fn), [self._device])

+  def _make_input_fn_iterator(
+      self,
+      input_fn,
+      replication_mode=distribute_lib.InputReplicationMode.PER_WORKER):
+    return values.PerReplicaDataset(
+        self._call_dataset_fn(input_fn, distribute_lib.InputContext()),
+        [self._device])
+
   def _broadcast(self, tensor, destinations):
     del destinations
     return tensor

View File

@@ -42,6 +42,15 @@ class OneDeviceStrategyTest(strategy_test_lib.DistributionTestBase):
   def testCallAndMergeExceptions(self):
     self._test_call_and_merge_exceptions(self._get_distribution_strategy())

+  @test_util.run_in_graph_and_eager_modes
+  def testInputContextPropertyLocal(self):
+    d = one_device_strategy.OneDeviceStrategy("/device:CPU:0")
+    input_fn = self._input_fn_to_test_input_context(
+        expected_num_replicas_in_sync=1,
+        expected_num_input_pipelines=1,
+        expected_input_pipeline_id=0)
+    d.make_input_fn_iterator(input_fn)
+

 if __name__ == "__main__":
   test.main()

View File

@@ -227,6 +227,27 @@ class ParameterServerStrategy(distribute_lib.DistributionStrategy):
     return values.PerReplicaDataset(
         self._call_dataset_fn(dataset_fn), self._compute_devices, True)

+  def _make_input_fn_iterator(
+      self,
+      input_fn,
+      replication_mode=distribute_lib.InputReplicationMode.PER_WORKER):
+    """Distributes the dataset to each local GPU."""
+    if self._cluster_spec:
+      input_pipeline_id = multi_worker_util.id_in_cluster(
+          self._cluster_spec, self._task_type, self._task_id)
+      num_input_pipelines = multi_worker_util.worker_count(
+          self._cluster_spec, self._task_type)
+    else:
+      input_pipeline_id = 0
+      num_input_pipelines = 1
+    input_context = distribute_lib.InputContext(
+        num_input_pipelines=num_input_pipelines,
+        input_pipeline_id=input_pipeline_id,
+        num_replicas_in_sync=self.num_replicas_in_sync)
+    return values.PerReplicaDataset(
+        self._call_dataset_fn(input_fn, input_context), self._compute_devices,
+        True)
+
   def _broadcast(self, tensor, destinations):
     if not cross_tower_ops_lib.check_destinations(destinations):
       destinations = self._compute_devices

View File

@@ -25,6 +25,7 @@ from absl.testing import parameterized
 from tensorflow.contrib.distribute.python import combinations
 from tensorflow.contrib.distribute.python import multi_worker_test_base
 from tensorflow.contrib.distribute.python import parameter_server_strategy
+from tensorflow.contrib.distribute.python import strategy_test_lib
 from tensorflow.contrib.distribute.python import values
 from tensorflow.core.protobuf import config_pb2
 from tensorflow.python.distribute import multi_worker_util
@@ -622,5 +623,28 @@ class ParameterServerStrategyWithChiefTest(ParameterServerStrategyTestBase,
       distribution.call_for_each_replica(f)

+
+class InputContextTest(strategy_test_lib.DistributionTestBase):
+
+  def testInputContextPropertyLocal(self):
+    d = parameter_server_strategy.ParameterServerStrategy(num_gpus_per_worker=2)
+    with context.graph_mode():
+      input_fn = self._input_fn_to_test_input_context(
+          expected_num_replicas_in_sync=2,
+          expected_num_input_pipelines=1,
+          expected_input_pipeline_id=0)
+      d.make_input_fn_iterator(input_fn)
+
+  def testInputContextPropertyMultiWorker(self):
+    d = parameter_server_strategy.ParameterServerStrategy(num_gpus_per_worker=2)
+    cluster_spec = {'worker': ['worker1', 'worker2', 'worker3'], 'ps': ['ps1']}
+    d.configure(cluster_spec=cluster_spec, task_type='worker', task_id=1)
+    with context.graph_mode():
+      input_fn = self._input_fn_to_test_input_context(
+          expected_num_replicas_in_sync=2,
+          expected_num_input_pipelines=3,
+          expected_input_pipeline_id=1)  # because task_id = 1
+      d.make_input_fn_iterator(input_fn)
+
+
 if __name__ == '__main__':
   test.main()

View File

@@ -19,6 +19,7 @@ from __future__ import division
 from __future__ import print_function

 from tensorflow.core.protobuf import config_pb2
+from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.distribute import reduce_util
 from tensorflow.python.eager import backprop
 from tensorflow.python.eager import context
@@ -211,3 +212,30 @@ class DistributionTestBase(test.TestCase):
       dist.call_for_each_replica(_merge_call_raises_fn)
     with self.assertRaises(_TestException):
       dist.call_for_each_replica(_merge_call_merge_raises_fn)
+
+  def _input_fn_to_test_input_context(self, expected_num_replicas_in_sync,
+                                      expected_num_input_pipelines,
+                                      expected_input_pipeline_id):
+    # Use a list of one element as counter so that it can be captured by the
+    # `_input_fn`. This counter is incremented by 1 each time an input_fn is
+    # called. We use this counter to check whether the `input_pipeline_id`
+    # matches the counter in the in-graph replication.
+    worker_id_counter = [0]
+
+    def _input_fn(input_context):
+      """Input fn for testing."""
+      self.assertIsNotNone(input_context)
+      self.assertEqual(expected_num_replicas_in_sync,
+                       input_context.num_replicas_in_sync)
+      self.assertEqual(expected_num_input_pipelines,
+                       input_context.num_input_pipelines)
+      if expected_input_pipeline_id is not None:
+        self.assertEqual(expected_input_pipeline_id,
+                         input_context.input_pipeline_id)
+      else:
+        self.assertEqual(worker_id_counter[0], input_context.input_pipeline_id)
+        worker_id_counter[0] += 1
+      return dataset_ops.Dataset.from_tensors([[1.]]).repeat()
+
+    return _input_fn

View File

@@ -1256,22 +1256,34 @@ class MultiWorkerDataset(object):
     """Initialize the MultiWorkerDataset object.

     Args:
-      dataset_fn: a function that returns a `tf.data.Dataset`.
+      dataset_fn: a function or a list of functions that returns a
+        `tf.data.Dataset`.
       worker_device_pairs: a list of (worker, list of devices on that worker)
-        pairs.
+        pairs; it must have same length with `dataset_fn` if `dataset_fn` is a
+        list.
       prefetch_on_device: whether to prefetch to devices.
       auto_shard: whether to auto-shard the dataset.
     """
+    if isinstance(dataset_fn, list):
+      if len(dataset_fn) != len(worker_device_pairs):
+        raise ValueError("If `dataset_fn` is a list, it must have same length "
+                         "as `worker_device_pairs`")
+      if auto_shard:
+        raise ValueError(
+            "If `dataset_fn` is a list, `auto_shard` is not supported.")
     self._worker_device_pairs = worker_device_pairs
     self._datasets = []
     # TODO(yuefengz, priyag): support different set of jobs for input
     # processing.
     for i, (worker, worker_devices) in enumerate(worker_device_pairs):
       with ops.device(worker):
-        worker_input = dataset_fn()
-        if auto_shard:
-          worker_input = input_ops.auto_shard_dataset(
-              worker_input, len(worker_device_pairs), i)
+        if isinstance(dataset_fn, list):
+          worker_input = dataset_fn[i]()
+        else:
+          worker_input = dataset_fn()
+          if auto_shard:
+            worker_input = input_ops.auto_shard_dataset(
+                worker_input, len(worker_device_pairs), i)
         dataset = PerReplicaDataset(
             worker_input, worker_devices, prefetch_on_device=prefetch_on_device)
       self._datasets.append((worker, dataset))

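With this change `MultiWorkerDataset` accepts either one `dataset_fn` shared by all workers (optionally auto-sharded) or a list with one function per worker, in which case the list length must match `worker_device_pairs` and `auto_shard` must stay off. A runnable sketch of the two accepted call shapes, using a toy stand-in class that copies only the new validation and dispatch logic (not the real `MultiWorkerDataset`):

```
class FakeMultiWorkerDataset(object):
  """Toy stand-in mirroring only the new list-vs-single dispatch."""

  def __init__(self, dataset_fn, worker_device_pairs, auto_shard=False):
    if isinstance(dataset_fn, list):
      if len(dataset_fn) != len(worker_device_pairs):
        raise ValueError("If `dataset_fn` is a list, it must have same length "
                         "as `worker_device_pairs`")
      if auto_shard:
        raise ValueError("If `dataset_fn` is a list, `auto_shard` is not supported.")
    # Call the single fn once per worker, or the i-th fn for worker i.
    self.datasets = [
        (dataset_fn[i]() if isinstance(dataset_fn, list) else dataset_fn())
        for i, _ in enumerate(worker_device_pairs)
    ]

pairs = [("/job:worker/task:0", ["GPU:0"]), ("/job:worker/task:1", ["GPU:0"])]

# One function shared by all workers (pre-existing behaviour).
single = FakeMultiWorkerDataset(lambda: "shared dataset", pairs, auto_shard=False)

# One function per worker, e.g. each already bound to its own InputContext.
per_worker = FakeMultiWorkerDataset([lambda: "ds0", lambda: "ds1"], pairs)

print(single.datasets, per_worker.datasets)
```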
View File

@@ -45,6 +45,33 @@ def normalize_cluster_spec(cluster_spec):
   return cluster_spec


+# TODO(yuefengz): add more validations.
+def _validate_cluster_spec(cluster_spec, task_type, task_id):
+  """Validates `cluster_spec`.
+
+  It checks
+  1) whether there is such a task type as `task_type` in the `cluster_spec`.
+  2) whether there is at most one "chief" job.
+  3) whether the `task_id` is smaller than the number of `task_type` tasks.
+
+  Args:
+    cluster_spec: a dict, `ClusterDef` or `ClusterSpec` object to be validated.
+    task_type: string indicating the type of the task.
+    task_id: the id of the `task_type` in this cluster.
+  Throws:
+    ValueError: if `cluster_spec` fails any check.
+  """
+  cluster_spec = normalize_cluster_spec(cluster_spec).as_dict()
+  if task_type and task_type not in cluster_spec:
+    raise ValueError("`task_type` %r not found in cluster_spec." % task_type)
+  if len(cluster_spec.get("chief", [])) > 1:
+    raise ValueError("There must be at most one 'chief' job.")
+  if task_id >= len(cluster_spec[task_type]):
+    raise ValueError(
+        "The `task_id` %d exceeds the maximum id of %s." % (task_id, task_type))
+
+
 def is_chief(cluster_spec, task_type, task_id):
   """Returns whether the given task is chief in the cluster.

@@ -61,20 +88,73 @@ def is_chief(cluster_spec, task_type, task_id):
     ValueError: if `task_type` is not in the `cluster_spec` or `task_id` exceeds
       the maximum id of the `task_type`.
   """
-  cluster_spec = normalize_cluster_spec(cluster_spec)
-  if task_type not in cluster_spec.jobs:
-    raise ValueError(
-        "The task_type \"%s\" is not in the `cluster_spec`." % task_type)
-  if task_id >= cluster_spec.num_tasks(task_type):
-    raise ValueError("The `task_id` %d exceeds the maximum id of %s." % (
-        task_id, task_type))
+  _validate_cluster_spec(cluster_spec, task_type, task_id)
+  cluster_spec = normalize_cluster_spec(cluster_spec).as_dict()

   if task_type == "chief":
     return True

   # If chief not in the cluster_spec, use the first worker as chief. This is
   # common in CollectiveAllReduceStrategy.
-  if ("chief" not in cluster_spec.jobs and task_type == "worker" and
-      task_id == 0):
+  if ("chief" not in cluster_spec and task_type == "worker" and task_id == 0):
     return True

   return False
+
+
+def worker_count(cluster_spec, task_type):
+  """Returns the number of workers in the cluster."""
+  _validate_cluster_spec(cluster_spec, task_type, task_id=0)
+  cluster_spec = normalize_cluster_spec(cluster_spec).as_dict()
+
+  # Other jobs such as "ps" shouldn't call this function.
+  if task_type not in ["chief", "worker", "evaluator"]:
+    raise ValueError("Unexpected `task_type` %r" % task_type)
+
+  if task_type == "evaluator":
+    # The "evaluator" is in its own cluster or its own partition of a cluster.
+    # So we don't have to count "chief" or "worker" if the current task is an
+    # "evaluator".
+    return len(cluster_spec["evaluator"])
+  else:
+    # In the non-evaluator case, we return the total number of "chief" and
+    # "worker" tasks as the "chief" is also a worker.
+    return (len(cluster_spec.get("chief", [])) + len(
+        cluster_spec.get("worker", [])))
+
+
+def id_in_cluster(cluster_spec, task_type, task_id):
+  """Returns a unique id for the task in the `task_type`'s cluster.
+
+  It returns an id ranging from [0, `worker_count(cluster_spec, task_type)`).
+
+  Note: this function assumes that the "evaluator" job is in its own cluster
+  or its own partition of a cluster.
+
+  Args:
+    cluster_spec: a dict, `ClusterDef` or `ClusterSpec` object to be validated.
+    task_type: string indicating the type of the task.
+    task_id: the id of the `task_type` in this cluster.
+
+  Returns:
+    an int indicating the unique id.
+
+  Throws:
+    ValueError: if `task_type` is not "chief", "worker" or "evaluator".
+  """
+  _validate_cluster_spec(cluster_spec, task_type, task_id)
+  cluster_spec = normalize_cluster_spec(cluster_spec).as_dict()
+
+  # The "chief" job always has id 0, there is at most one chief, and "worker"
+  # jobs come after it.
+  if task_type == "chief":
+    return 0
+
+  if task_type == "worker":
+    return task_id + len(cluster_spec.get("chief", []))
+
+  # The "evaluator" is in its own cluster or its own partition of a cluster.
+  if task_type == "evaluator":
+    return task_id
+
+  # We currently don't assign ids to other tasks.
+  raise ValueError("There is no id for task_type %r" % task_type)

View File

@@ -95,7 +95,7 @@ class IsChiefTest(test.TestCase):
     self.assertFalse(multi_worker_util.is_chief(cluster_spec, "worker", 1))

     with self.assertRaisesRegexp(
-        ValueError, "The task_type \"chief\" is not in the `cluster_spec`."):
+        ValueError, "`task_type` 'chief' not found in cluster_spec."):
       multi_worker_util.is_chief(cluster_spec, "chief", 0)

     with self.assertRaisesRegexp(
@@ -103,5 +103,94 @@
       multi_worker_util.is_chief(cluster_spec, "worker", 2)

+
+class NumWorkersTest(test.TestCase):
+
+  def testCountWorker(self):
+    cluster_spec = {
+        "chief": ["127.0.0.1:1234"],
+        "worker": ["127.0.0.1:8964", "127.0.0.1:2333"],
+        "ps": ["127.0.0.1:1926", "127.0.0.1:3141"]
+    }
+    self.assertEqual(
+        multi_worker_util.worker_count(cluster_spec, task_type="chief"), 3)
+    self.assertEqual(
+        multi_worker_util.worker_count(cluster_spec, task_type="worker"), 3)
+
+  def testCountEvaluator(self):
+    cluster_spec = {
+        "chief": ["127.0.0.1:1234"],
+        "worker": ["127.0.0.1:8964", "127.0.0.1:2333"],
+        "evaluator": ["127.0.0.1:7566"]
+    }
+    self.assertEqual(
+        multi_worker_util.worker_count(cluster_spec, task_type="evaluator"), 1)
+
+  def testTaskTypeNotFound(self):
+    cluster_spec = {}
+    with self.assertRaisesRegexp(
+        ValueError, "`task_type` 'worker' not found in cluster_spec."):
+      multi_worker_util.worker_count(cluster_spec, task_type="worker")
+
+  def testCountPs(self):
+    cluster_spec = {
+        "chief": ["127.0.0.1:1234"],
+        "ps": ["127.0.0.1:1926", "127.0.0.1:3141"]
+    }
+    # A "ps" job shouldn't call this method.
+    with self.assertRaisesRegexp(ValueError, "Unexpected `task_type` 'ps'"):
+      multi_worker_util.worker_count(cluster_spec, task_type="ps")
+
+
+class IdInClusterTest(test.TestCase):
+
+  def testChiefId(self):
+    cluster_spec = {
+        "chief": ["127.0.0.1:1234"],
+        "worker": ["127.0.0.1:8964", "127.0.0.1:2333"],
+        "ps": ["127.0.0.1:1926", "127.0.0.1:3141"]
+    }
+    self.assertEqual(
+        multi_worker_util.id_in_cluster(cluster_spec, "chief", 0), 0)
+
+  def testWorkerId(self):
+    cluster_spec = {
+        "chief": ["127.0.0.1:1234"],
+        "worker": ["127.0.0.1:8964", "127.0.0.1:2333"],
+        "ps": ["127.0.0.1:1926", "127.0.0.1:3141"]
+    }
+    self.assertEqual(
+        multi_worker_util.id_in_cluster(cluster_spec, "worker", 1), 2)
+
+    cluster_spec = {
+        "worker": ["127.0.0.1:8964", "127.0.0.1:2333"],
+        "ps": ["127.0.0.1:1926", "127.0.0.1:3141"]
+    }
+    self.assertEqual(
+        multi_worker_util.id_in_cluster(cluster_spec, "worker", 1), 1)
+
+  def testEvaluatorId(self):
+    cluster_spec = {
+        "chief": ["127.0.0.1:1234"],
+        "worker": ["127.0.0.1:8964", "127.0.0.1:2333"],
+        "evaluator": ["127.0.0.1:7566"]
+    }
+    self.assertEqual(
+        multi_worker_util.id_in_cluster(cluster_spec, "evaluator", 0), 0)
+
+  def testPsId(self):
+    cluster_spec = {"chief": ["127.0.0.1:1234"], "ps": ["127.0.0.1:7566"]}
+    with self.assertRaisesRegexp(ValueError,
+                                 "There is no id for task_type 'ps'"):
+      multi_worker_util.id_in_cluster(cluster_spec, "ps", 0)
+
+  def testMultipleChiefs(self):
+    cluster_spec = {
+        "chief": ["127.0.0.1:8258", "127.0.0.1:7566"],
+    }
+    with self.assertRaisesRegexp(ValueError,
+                                 "There must be at most one 'chief' job."):
+      multi_worker_util.id_in_cluster(cluster_spec, "chief", 0)
+
+
 if __name__ == "__main__":
   test.main()

View File

@@ -19,6 +19,7 @@ from __future__ import division
 from __future__ import print_function

 import threading
+import enum

 from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.distribute import reduce_util
@@ -184,6 +185,80 @@ class _SameScopeAgainContext(object):
     del exception_type, exception_value, traceback


+# TODO(yuefengz): add more replication modes.
+class InputReplicationMode(enum.Enum):
+  """Replication mode for input function."""
+
+  # The input function will be called on each worker independently, creating
+  # as many input pipelines as the number of workers. Replicas will dequeue
+  # from the local Dataset on their worker. Distribution Strategy doesn't
+  # manage any state sharing between such separate input pipelines.
+  PER_WORKER = 0
+
+
+class InputContext(object):
+  """A class wrapping information needed by an input function.
+
+  This is a context class that is passed to the user's input fn and contains
+  information about the compute replicas and input pipelines. The number of
+  compute replicas (in sync training) helps compute the per-input-pipeline
+  batch size from the desired global batch size. Input pipeline information
+  can be used to return a different subset of the input in each input pipeline
+  (e.g. shard the input pipeline, use a different input source, etc.).
+  """
+
+  def __init__(self,
+               num_input_pipelines=1,
+               input_pipeline_id=0,
+               num_replicas_in_sync=1):
+    """Initializes an InputContext object.
+
+    Args:
+      num_input_pipelines: the number of input pipelines in a cluster.
+      input_pipeline_id: the current input pipeline id, should be an int in
+        [0, `num_input_pipelines`).
+      num_replicas_in_sync: the number of replicas that are in sync.
+    """
+    self._num_input_pipelines = num_input_pipelines
+    self._input_pipeline_id = input_pipeline_id
+    self._num_replicas_in_sync = num_replicas_in_sync
+
+  @property
+  def num_replicas_in_sync(self):
+    """Returns the number of compute replicas in sync."""
+    return self._num_replicas_in_sync
+
+  @property
+  def input_pipeline_id(self):
+    """Returns the input pipeline ID."""
+    return self._input_pipeline_id
+
+  @property
+  def num_input_pipelines(self):
+    """Returns the number of input pipelines."""
+    return self._num_input_pipelines
+
+  def get_per_replica_batch_size(self, global_batch_size):
+    """Returns the per-replica batch size.
+
+    Args:
+      global_batch_size: the global batch size which should be divisible by
+        `num_replicas_in_sync`.
+
+    Returns:
+      the per-replica batch size.
+
+    Raises:
+      ValueError: if `global_batch_size` is not divisible by
+        `num_replicas_in_sync`.
+    """
+    if global_batch_size % self._num_replicas_in_sync != 0:
+      raise ValueError("The `global_batch_size` %r is not divisible by "
+                       "`num_replicas_in_sync` %r " %
+                       (global_batch_size, self._num_replicas_in_sync))
+    return global_batch_size // self._num_replicas_in_sync
+
+
 # ------------------------------------------------------------------------------
 # Base classes for all distribution strategies.
@@ -534,8 +609,15 @@ class DistributionStrategy(object):
     _require_distribution_strategy_scope(self)
     return variable_scope.variable_creator_scope(create_colocated_variable)

-  def _call_dataset_fn(self, dataset_fn):
-    result = dataset_fn()
+  def _call_dataset_fn(self, dataset_fn, input_context=None):
+    """Calls the `dataset_fn` with `input_context` as argument."""
+    # This method is invoked by both `make_input_fn_iterator` and
+    # `distribute_dataset`. The `dataset_fn` for the former accepts an
+    # `input_context` while the latter's doesn't.
+    if input_context:
+      result = dataset_fn(input_context)
+    else:
+      result = dataset_fn()
     if not isinstance(result, dataset_ops.Dataset):
       raise ValueError(
           "dataset_fn() must return a tf.data.Dataset when using a "
@@ -580,7 +662,7 @@ class DistributionStrategy(object):
        return tf.data.Dataset.from_tensors([[1.]]).repeat()
      with distribution_strategy.scope():
        distributed_dataset = distribution_strategy.distribute_dataset(dataset_fn)
-       iterator = distributed_dataset.make_one_shot_iterator()
+       iterator = distributed_dataset.make_initializable_iterator()
        replica_results = distribution_strategy.call_for_each_replica(
            replica_fn, args=(iterator.get_next(),))
     ```
@@ -593,6 +675,46 @@ class DistributionStrategy(object):
     """
     raise NotImplementedError("must be implemented in descendants")

+  def make_input_fn_iterator(self,
+                             input_fn,
+                             replication_mode=InputReplicationMode.PER_WORKER):
+    """Returns an iterator split across replicas created from an input function.
+
+    The `input_fn` should take an `InputContext` object where information about
+    input sharding can be accessed:
+
+    ```
+    def input_fn(input_context):
+      d = tf.data.Dataset.from_tensors([[1.]]).repeat()
+      return d.shard(input_context.num_input_pipelines,
+                     input_context.input_pipeline_id)
+    with distribution_strategy.scope():
+      iterator = distribution_strategy.make_input_fn_iterator(
+          input_fn)
+      replica_results = distribution_strategy.call_for_each_replica(
+          replica_fn, iterator.get_next())
+    ```
+
+    Args:
+      input_fn: A function that returns a `tf.data.Dataset`. This function is
+        expected to take an `InputContext` object.
+      replication_mode: an enum value of `InputReplicationMode`. Only
+        `PER_WORKER` is supported currently.
+
+    Returns:
+      An iterator object that can be initialized and used to fetch the next
+      element.
+    """
+    if replication_mode != InputReplicationMode.PER_WORKER:
+      raise ValueError(
+          "Input replication mode not supported: %r" % replication_mode)
+    return self._make_input_fn_iterator(
+        input_fn, replication_mode=replication_mode)
+
+  def _make_input_fn_iterator(self,
+                              input_fn,
+                              replication_mode=InputReplicationMode.PER_WORKER):
+    raise NotImplementedError("must be implemented in descendants")
+
   def broadcast(self, tensor, destinations=None):
     """Mirror a tensor on one device to all worker devices.

@@ -1193,6 +1315,11 @@ class _DefaultDistributionStrategy(DistributionStrategy):
   def distribute_dataset(self, dataset_fn):
     return self._call_dataset_fn(dataset_fn)

+  def _make_input_fn_iterator(self,
+                              input_fn,
+                              replication_mode=InputReplicationMode.PER_WORKER):
+    return self._call_dataset_fn(input_fn, InputContext())
+
   def _broadcast(self, tensor, destinations):
     if destinations is None:
       return tensor

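As a concrete illustration of the `get_per_replica_batch_size` contract added above (plain arithmetic only; the replica counts here are made-up examples): the global batch size must divide evenly by `num_replicas_in_sync`, and the quotient is what each replica's pipeline should batch with.

```
global_batch_size = 64
num_replicas_in_sync = 8  # e.g. 2 GPUs per worker across 4 workers

# Mirrors what InputContext.get_per_replica_batch_size(64) returns for this setup.
assert global_batch_size % num_replicas_in_sync == 0
per_replica_batch_size = global_batch_size // num_replicas_in_sync
print(per_replica_batch_size)  # 8

# A non-divisible global batch size (e.g. 13 with 6 replicas, as in the unit
# test below) makes the real method raise ValueError.
```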
View File

@@ -20,11 +20,11 @@ from __future__ import print_function

 from tensorflow.python.ops import variable_scope
 from tensorflow.python.platform import test
-from tensorflow.python.training import distribute
+from tensorflow.python.training import distribute as distribute_lib
 from tensorflow.python.training import distribution_strategy_context


-class _TestReplicaContext(distribute.ReplicaContext):
+class _TestReplicaContext(distribute_lib.ReplicaContext):

   def merge_call(self, fn, *args, **kwargs):
     return kwargs["test_arg"]
@@ -38,7 +38,7 @@ def _get_test_variable(name, synchronization, aggregation):
   }


-class _TestStrategy(distribute.DistributionStrategy):
+class _TestStrategy(distribute_lib.DistributionStrategy):

   def _call_for_each_replica(self, fn, args, kwargs):
     with _TestReplicaContext(self, replica_id=0):
@@ -144,5 +144,22 @@ class DefaultDistributionStrategyTest(test.TestCase):
     _assert_in_default_state(self)

+
+class InputContextTest(test.TestCase):
+
+  def testProperties(self):
+    input_context = distribute_lib.InputContext(
+        num_input_pipelines=2, input_pipeline_id=1, num_replicas_in_sync=6)
+    self.assertEqual(6, input_context.num_replicas_in_sync)
+    self.assertEqual(1, input_context.input_pipeline_id)
+    self.assertEqual(2, input_context.num_input_pipelines)
+
+  def testPerReplicaBatchSize(self):
+    input_context = distribute_lib.InputContext(
+        num_input_pipelines=2, input_pipeline_id=1, num_replicas_in_sync=6)
+    self.assertEqual(2, input_context.get_per_replica_batch_size(12))
+    with self.assertRaises(ValueError):
+      input_context.get_per_replica_batch_size(13)
+
+
 if __name__ == "__main__":
   test.main()