Add a `make_input_fn_iterator` method to distribution strategies that takes an input function and (in a follow-up CL) will return an iterator object.
Add an `InputContext` class, which is passed as an argument to the `input_fn` given to `make_input_fn_iterator`.

PiperOrigin-RevId: 221581165
parent 387a062dd3
commit fa5e5ecd88
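The intended end-to-end usage, once the follow-up CL makes `make_input_fn_iterator` return an iterator, is roughly the sketch below. It is adapted from the `make_input_fn_iterator` docstring added in this change; `distribution_strategy`, `replica_fn`, and the global batch size of 64 are illustrative placeholders, not part of the diff.

import tensorflow as tf

def input_fn(input_context):
  # Derive the per-replica batch size from a desired global batch size.
  batch_size = input_context.get_per_replica_batch_size(global_batch_size=64)
  d = tf.data.Dataset.from_tensors([[1.]]).repeat().batch(batch_size)
  # Each input pipeline (one per worker) reads only its own shard of the data.
  return d.shard(input_context.num_input_pipelines,
                 input_context.input_pipeline_id)

with distribution_strategy.scope():
  iterator = distribution_strategy.make_input_fn_iterator(input_fn)
  replica_results = distribution_strategy.call_for_each_replica(
      replica_fn, iterator.get_next())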
@@ -118,6 +118,7 @@ cuda_py_test(
         ":combinations",
         ":multi_worker_test_base",
         ":parameter_server_strategy",
+        ":strategy_test_lib",
         ":values",
         "@absl_py//absl/testing:parameterized",
         "//tensorflow/core:protos_all_py",
@@ -29,6 +29,7 @@ from tensorflow.python.framework import ops
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import collective_ops
 from tensorflow.python.platform import tf_logging as logging
+from tensorflow.python.training import distribute as distribute_lib


 # TODO(yuefengz): support in-graph replication.
@@ -94,8 +95,7 @@ class CollectiveAllReduceStrategy(mirrored_strategy.MirroredStrategy):
           "Unrecognized task_type: %r, valid task types are: \"chief\", "
           "\"worker\"." % task_type)
     cluster_spec = multi_worker_util.normalize_cluster_spec(cluster_spec)
-    self._num_workers = len(cluster_spec.as_dict().get("worker", [])) + len(
-        cluster_spec.as_dict().get("chief", []))
+    self._num_workers = multi_worker_util.worker_count(cluster_spec, task_type)
     if not self._num_workers:
       raise ValueError("No `worker` or `chief` tasks can be found in "
                        "`cluster_spec`.")
@@ -208,6 +208,23 @@ class CollectiveAllReduceStrategy(mirrored_strategy.MirroredStrategy):
     return values.PerReplicaDataset(
         self._call_dataset_fn(dataset_fn), self._devices, True)

+  def _make_input_fn_iterator(
+      self,
+      input_fn,
+      replication_mode=distribute_lib.InputReplicationMode.PER_WORKER):
+    """Distributes the dataset to each local GPU."""
+    if self._cluster_spec is None:
+      input_pipeline_id = 0
+    else:
+      input_pipeline_id = multi_worker_util.id_in_cluster(
+          self._cluster_spec, self._task_type, self._task_id)
+    input_context = distribute_lib.InputContext(
+        num_input_pipelines=self._num_workers,
+        input_pipeline_id=input_pipeline_id,
+        num_replicas_in_sync=self.num_replicas_in_sync)
+    return values.PerReplicaDataset(
+        self._call_dataset_fn(input_fn, input_context), self._devices, True)
+
   def configure(self,
                 session_config=None,
                 cluster_spec=None,
@@ -289,4 +306,3 @@ class CollectiveAllReduceStrategy(mirrored_strategy.MirroredStrategy):
   @property
   def num_replicas_in_sync(self):
     return len(self._devices) * self._num_workers
-
@@ -25,6 +25,7 @@ from tensorflow.contrib.distribute.python import collective_all_reduce_strategy
 from tensorflow.contrib.distribute.python import combinations
 from tensorflow.contrib.distribute.python import cross_tower_utils
 from tensorflow.contrib.distribute.python import multi_worker_test_base
+from tensorflow.contrib.distribute.python import strategy_test_lib
 from tensorflow.core.protobuf import config_pb2
 from tensorflow.python import keras
 from tensorflow.python.distribute import reduce_util
@@ -339,5 +340,30 @@ class LocalCollectiveAllReduceStrategy(CollectiveAllReduceStrategyTestBase,
     self._test_complex_model(None, None, num_gpus)


+class InputContextTest(strategy_test_lib.DistributionTestBase):
+
+  def testInputContextPropertyLocal(self):
+    d = collective_all_reduce_strategy.CollectiveAllReduceStrategy(
+        num_gpus_per_worker=2)
+    with context.graph_mode():
+      input_fn = self._input_fn_to_test_input_context(
+          expected_num_replicas_in_sync=2,
+          expected_num_input_pipelines=1,
+          expected_input_pipeline_id=0)
+      d.make_input_fn_iterator(input_fn)
+
+  def testInputContextPropertyMultiWorker(self):
+    d = collective_all_reduce_strategy.CollectiveAllReduceStrategy(
+        num_gpus_per_worker=2)
+    cluster_spec = {'worker': ['worker1', 'worker2', 'worker3'], 'ps': ['ps1']}
+    d.configure(cluster_spec=cluster_spec, task_type='worker', task_id=1)
+    with context.graph_mode():
+      input_fn = self._input_fn_to_test_input_context(
+          expected_num_replicas_in_sync=6,
+          expected_num_input_pipelines=3,
+          expected_input_pipeline_id=1)  # because task_id = 1
+      d.make_input_fn_iterator(input_fn)
+
+
 if __name__ == '__main__':
   test.main()
@@ -485,6 +485,30 @@ class MirroredStrategy(distribute_lib.DistributionStrategy):
     return values.PerReplicaDataset(
         self._call_dataset_fn(dataset_fn), self._devices)

+  def _make_input_fn_iterator(
+      self,
+      input_fn,
+      replication_mode=distribute_lib.InputReplicationMode.PER_WORKER):
+    if self._cluster_spec:
+      input_fns = []
+      for i in range(len(self._worker_devices)):
+        input_context = distribute_lib.InputContext(
+            num_input_pipelines=len(self._worker_devices),
+            input_pipeline_id=i,
+            num_replicas_in_sync=self.num_replicas_in_sync)
+        input_fns.append(
+            partial(self._call_dataset_fn, input_fn, input_context))
+
+      return values.MultiWorkerDataset(input_fns, self._worker_devices,
+                                       self._auto_shard_dataset)
+    else:
+      input_context = distribute_lib.InputContext(
+          num_input_pipelines=1,
+          input_pipeline_id=0,
+          num_replicas_in_sync=self.num_replicas_in_sync)
+      return values.PerReplicaDataset(
+          self._call_dataset_fn(input_fn, input_context), self._devices)
+
   # TODO(priyag): Deal with OutOfRange errors once b/111349762 is fixed.
   def _run_steps_on_dataset(self, fn, iterator, iterations,
                             initial_loop_values=None):
@@ -47,6 +47,28 @@ class MirroredOneCPUDistributionTest(strategy_test_lib.DistributionTestBase):
   def testCallAndMergeExceptions(self):
     self._test_call_and_merge_exceptions(self._get_distribution_strategy())
+
+  @test_util.run_in_graph_and_eager_modes
+  def testInputContextPropertyLocal(self):
+    d = mirrored_strategy.MirroredStrategy(num_gpus_per_worker=2)
+    input_fn = self._input_fn_to_test_input_context(
+        expected_num_replicas_in_sync=2,
+        expected_num_input_pipelines=1,
+        expected_input_pipeline_id=0)
+    d.make_input_fn_iterator(input_fn)
+
+  def testInputContextPropertyMultiWorker(self):
+    d = mirrored_strategy.MirroredStrategy(num_gpus_per_worker=2)
+    cluster_spec = {"worker": ["worker1", "worker2", "worker3"]}
+    d.configure(cluster_spec=cluster_spec)
+    with context.graph_mode():
+      # `expected_input_pipeline_id` is None because the input_fn will be called
+      # multiple times, each with a different input_pipeline_id.
+      input_fn = self._input_fn_to_test_input_context(
+          expected_num_replicas_in_sync=6,
+          expected_num_input_pipelines=3,
+          expected_input_pipeline_id=None)
+      d.make_input_fn_iterator(input_fn)


 class VariableCreatorStackTest(test.TestCase):

@@ -67,6 +67,14 @@ class OneDeviceStrategy(distribute_lib.DistributionStrategy):
     return values.PerReplicaDataset(
         self._call_dataset_fn(dataset_fn), [self._device])

+  def _make_input_fn_iterator(
+      self,
+      input_fn,
+      replication_mode=distribute_lib.InputReplicationMode.PER_WORKER):
+    return values.PerReplicaDataset(
+        self._call_dataset_fn(input_fn, distribute_lib.InputContext()),
+        [self._device])
+
   def _broadcast(self, tensor, destinations):
     del destinations
     return tensor
@@ -42,6 +42,15 @@ class OneDeviceStrategyTest(strategy_test_lib.DistributionTestBase):
   def testCallAndMergeExceptions(self):
     self._test_call_and_merge_exceptions(self._get_distribution_strategy())
+
+  @test_util.run_in_graph_and_eager_modes
+  def testInputContextPropertyLocal(self):
+    d = one_device_strategy.OneDeviceStrategy("/device:CPU:0")
+    input_fn = self._input_fn_to_test_input_context(
+        expected_num_replicas_in_sync=1,
+        expected_num_input_pipelines=1,
+        expected_input_pipeline_id=0)
+    d.make_input_fn_iterator(input_fn)


 if __name__ == "__main__":
   test.main()
@@ -227,6 +227,27 @@ class ParameterServerStrategy(distribute_lib.DistributionStrategy):
     return values.PerReplicaDataset(
         self._call_dataset_fn(dataset_fn), self._compute_devices, True)

+  def _make_input_fn_iterator(
+      self,
+      input_fn,
+      replication_mode=distribute_lib.InputReplicationMode.PER_WORKER):
+    """Distributes the dataset to each local GPU."""
+    if self._cluster_spec:
+      input_pipeline_id = multi_worker_util.id_in_cluster(
+          self._cluster_spec, self._task_type, self._task_id)
+      num_input_pipelines = multi_worker_util.worker_count(
+          self._cluster_spec, self._task_type)
+    else:
+      input_pipeline_id = 0
+      num_input_pipelines = 1
+    input_context = distribute_lib.InputContext(
+        num_input_pipelines=num_input_pipelines,
+        input_pipeline_id=input_pipeline_id,
+        num_replicas_in_sync=self.num_replicas_in_sync)
+    return values.PerReplicaDataset(
+        self._call_dataset_fn(input_fn, input_context), self._compute_devices,
+        True)
+
   def _broadcast(self, tensor, destinations):
     if not cross_tower_ops_lib.check_destinations(destinations):
       destinations = self._compute_devices
@@ -25,6 +25,7 @@ from absl.testing import parameterized
 from tensorflow.contrib.distribute.python import combinations
 from tensorflow.contrib.distribute.python import multi_worker_test_base
 from tensorflow.contrib.distribute.python import parameter_server_strategy
+from tensorflow.contrib.distribute.python import strategy_test_lib
 from tensorflow.contrib.distribute.python import values
 from tensorflow.core.protobuf import config_pb2
 from tensorflow.python.distribute import multi_worker_util
@@ -622,5 +623,28 @@ class ParameterServerStrategyWithChiefTest(ParameterServerStrategyTestBase,
       distribution.call_for_each_replica(f)


+class InputContextTest(strategy_test_lib.DistributionTestBase):
+
+  def testInputContextPropertyLocal(self):
+    d = parameter_server_strategy.ParameterServerStrategy(num_gpus_per_worker=2)
+    with context.graph_mode():
+      input_fn = self._input_fn_to_test_input_context(
+          expected_num_replicas_in_sync=2,
+          expected_num_input_pipelines=1,
+          expected_input_pipeline_id=0)
+      d.make_input_fn_iterator(input_fn)
+
+  def testInputContextPropertyMultiWorker(self):
+    d = parameter_server_strategy.ParameterServerStrategy(num_gpus_per_worker=2)
+    cluster_spec = {'worker': ['worker1', 'worker2', 'worker3'], 'ps': ['ps1']}
+    d.configure(cluster_spec=cluster_spec, task_type='worker', task_id=1)
+    with context.graph_mode():
+      input_fn = self._input_fn_to_test_input_context(
+          expected_num_replicas_in_sync=2,
+          expected_num_input_pipelines=3,
+          expected_input_pipeline_id=1)  # because task_id = 1
+      d.make_input_fn_iterator(input_fn)
+
+
 if __name__ == '__main__':
   test.main()
@@ -19,6 +19,7 @@ from __future__ import division
 from __future__ import print_function

 from tensorflow.core.protobuf import config_pb2
+from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.distribute import reduce_util
 from tensorflow.python.eager import backprop
 from tensorflow.python.eager import context
@@ -211,3 +212,30 @@ class DistributionTestBase(test.TestCase):
       dist.call_for_each_replica(_merge_call_raises_fn)
     with self.assertRaises(_TestException):
       dist.call_for_each_replica(_merge_call_merge_raises_fn)
+
+  def _input_fn_to_test_input_context(self, expected_num_replicas_in_sync,
+                                      expected_num_input_pipelines,
+                                      expected_input_pipeline_id):
+    # Use a list of one element as counter so that it can be captured by the
+    # `_input_fn`. This counter is incremented by 1 each time an input_fn is
+    # called. We use this counter to check whether the `input_pipeline_id`
+    # matches the counter in the in-graph replication.
+    worker_id_counter = [0]
+
+    def _input_fn(input_context):
+      """Input fn for testing."""
+      self.assertIsNotNone(input_context)
+      self.assertEqual(expected_num_replicas_in_sync,
+                       input_context.num_replicas_in_sync)
+      self.assertEqual(expected_num_input_pipelines,
+                       input_context.num_input_pipelines)
+      if expected_input_pipeline_id is not None:
+        self.assertEqual(expected_input_pipeline_id,
+                         input_context.input_pipeline_id)
+      else:
+        self.assertEqual(worker_id_counter[0], input_context.input_pipeline_id)
+        worker_id_counter[0] += 1
+
+      return dataset_ops.Dataset.from_tensors([[1.]]).repeat()
+
+    return _input_fn
@@ -1256,22 +1256,34 @@ class MultiWorkerDataset(object):
     """Initialize the MultiWorkerDataset object.

     Args:
-      dataset_fn: a function that returns a `tf.data.Dataset`.
+      dataset_fn: a function or a list of functions that returns a
+        `tf.data.Dataset`.
       worker_device_pairs: a list of (worker, list of devices on that worker)
-        pairs.
+        pairs; it must have same length with `dataset_fn` if `dataset_fn` is a
+        list.
       prefetch_on_device: whether to prefetch to devices.
       auto_shard: whether to auto-shard the dataset.
     """
+    if isinstance(dataset_fn, list):
+      if len(dataset_fn) != len(worker_device_pairs):
+        raise ValueError("If `dataset_fn` is a list, it must have same length "
+                         "as `worker_device_pairs`")
+      if auto_shard:
+        raise ValueError(
+            "If `dataset_fn` is a list, `auto_shard` is not supported.")
     self._worker_device_pairs = worker_device_pairs
     self._datasets = []
     # TODO(yuefengz, priyag): support different set of jobs for input
     # processing.
     for i, (worker, worker_devices) in enumerate(worker_device_pairs):
       with ops.device(worker):
-        worker_input = dataset_fn()
-        if auto_shard:
-          worker_input = input_ops.auto_shard_dataset(
-              worker_input, len(worker_device_pairs), i)
+        if isinstance(dataset_fn, list):
+          worker_input = dataset_fn[i]()
+        else:
+          worker_input = dataset_fn()
+          if auto_shard:
+            worker_input = input_ops.auto_shard_dataset(
+                worker_input, len(worker_device_pairs), i)
         dataset = PerReplicaDataset(
             worker_input, worker_devices, prefetch_on_device=prefetch_on_device)
         self._datasets.append((worker, dataset))
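A rough sketch of how the new list form of `dataset_fn` is populated by the `MirroredStrategy._make_input_fn_iterator` change earlier in this diff: one closure per worker, each bound to its own `InputContext`. The worker names and the `make_per_worker_fns` helper below are illustrative placeholders, not part of the change.

from functools import partial

from tensorflow.python.training import distribute as distribute_lib

worker_devices = ["/job:worker/task:0", "/job:worker/task:1"]  # placeholder workers

def _call_dataset_fn(dataset_fn, input_context):
  # Stand-in for DistributionStrategy._call_dataset_fn, which forwards the
  # InputContext to the user's input_fn when one is provided.
  return dataset_fn(input_context)

def make_per_worker_fns(input_fn, num_replicas_in_sync):
  input_fns = []
  for i in range(len(worker_devices)):
    input_context = distribute_lib.InputContext(
        num_input_pipelines=len(worker_devices),
        input_pipeline_id=i,
        num_replicas_in_sync=num_replicas_in_sync)
    input_fns.append(partial(_call_dataset_fn, input_fn, input_context))
  # This list is what MultiWorkerDataset now accepts as `dataset_fn`.
  return input_fns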
@@ -45,6 +45,33 @@ def normalize_cluster_spec(cluster_spec):
     return cluster_spec


+# TODO(yuefengz): add more validations.
+def _validate_cluster_spec(cluster_spec, task_type, task_id):
+  """Validates `cluster_spec`.
+
+  It checks
+  1) whether there is such a task type as `task_type` in the
+  `cluster_spec`.
+  2) whether there is at most one "chief" job.
+  3) whether the `task_id` is smaller than the number of `task_type`.
+
+  Args:
+    cluster_spec: a dict, `ClusterDef` or `ClusterSpec` object to be validated.
+    task_type: string indicating the type of the task.
+    task_id: task_id: the id of the `task_type` in this cluster.
+  Throws:
+    ValueError: if `cluster_spec` fails any check.
+  """
+  cluster_spec = normalize_cluster_spec(cluster_spec).as_dict()
+  if task_type and task_type not in cluster_spec:
+    raise ValueError("`task_type` %r not found in cluster_spec." % task_type)
+  if len(cluster_spec.get("chief", [])) > 1:
+    raise ValueError("There must be at most one 'chief' job.")
+  if task_id >= len(cluster_spec[task_type]):
+    raise ValueError(
+        "The `task_id` %d exceeds the maximum id of %s." % (task_id, task_type))
+
+
 def is_chief(cluster_spec, task_type, task_id):
   """Returns whether the given task is chief in the cluster.

@@ -61,20 +88,73 @@ def is_chief(cluster_spec, task_type, task_id):
     ValueError: if `task_type` is not in the `cluster_spec` or `task_id` exceeds
       the maximum id of the `task_type`.
   """
-  cluster_spec = normalize_cluster_spec(cluster_spec)
-  if task_type not in cluster_spec.jobs:
-    raise ValueError(
-        "The task_type \"%s\" is not in the `cluster_spec`." % task_type)
-  if task_id >= cluster_spec.num_tasks(task_type):
-    raise ValueError("The `task_id` %d exceeds the maximum id of %s." % (
-        task_id, task_type))
+  _validate_cluster_spec(cluster_spec, task_type, task_id)
+  cluster_spec = normalize_cluster_spec(cluster_spec).as_dict()

   if task_type == "chief":
     return True

   # If chief not in the cluster_spec, use the first worker as chief. This is
   # common in CollectiveAllReduceStrategy.
-  if ("chief" not in cluster_spec.jobs and task_type == "worker" and
-      task_id == 0):
+  if ("chief" not in cluster_spec and task_type == "worker" and task_id == 0):
     return True
   return False
+
+
+def worker_count(cluster_spec, task_type):
+  """Returns the number of workers in the cluster."""
+  _validate_cluster_spec(cluster_spec, task_type, task_id=0)
+  cluster_spec = normalize_cluster_spec(cluster_spec).as_dict()
+
+  # Other jobs such as "ps" shouldn't call this function.
+  if task_type not in ["chief", "worker", "evaluator"]:
+    raise ValueError("Unexpected `task_type` %r" % task_type)
+
+  if task_type == "evaluator":
+    # The "evaluator" is in its own cluster or its own partition of a cluster.
+    # So we don't have to count "chief" or "worker" if the current task is an
+    # "evaluator".
+    return len(cluster_spec["evaluator"])
+  else:
+    # In the non-evaluator case, we return the total number of "chief" and
+    # "worker" tasks as the "chief" is also a worker.
+    return (len(cluster_spec.get("chief", [])) + len(
+        cluster_spec.get("worker", [])))
+
+
+def id_in_cluster(cluster_spec, task_type, task_id):
+  """Returns a unique id for the task in the `task_type`'s cluster.
+
+  It returns an id ranging from [0, `worker_count(task_type, task_id)`).
+
+  Note: this function assumes that "evaluate" job is in its own cluster or its
+  own partition of a cluster.
+
+  Args:
+    cluster_spec: a dict, `ClusterDef` or `ClusterSpec` object to be validated.
+    task_type: string indicating the type of the task.
+    task_id: the id of the `task_type` in this cluster.
+
+  Returns:
+    an int indicating the unique id.
+
+  Throws:
+    ValueError: if `task_type` is not "chief", "worker" or "evaluator".
+  """
+  _validate_cluster_spec(cluster_spec, task_type, task_id)
+  cluster_spec = normalize_cluster_spec(cluster_spec).as_dict()
+
+  # The "chief" job has always id 0 and there is at most one and "worker" jobs
+  # come after it.
+  if task_type == "chief":
+    return 0
+
+  if task_type == "worker":
+    return task_id + len(cluster_spec.get("chief", []))
+
+  # The "evaluator" is in its own cluster or its own partition of a cluster.
+  if task_type == "evaluator":
+    return task_id
+
+  # We currently don't assign ids to other tasks.
+  raise ValueError("There is no id for task_type %r" % task_type)
@@ -95,7 +95,7 @@ class IsChiefTest(test.TestCase):
     self.assertFalse(multi_worker_util.is_chief(cluster_spec, "worker", 1))

     with self.assertRaisesRegexp(
-        ValueError, "The task_type \"chief\" is not in the `cluster_spec`."):
+        ValueError, "`task_type` 'chief' not found in cluster_spec."):
       multi_worker_util.is_chief(cluster_spec, "chief", 0)

     with self.assertRaisesRegexp(
@@ -103,5 +103,94 @@ class IsChiefTest(test.TestCase):
       multi_worker_util.is_chief(cluster_spec, "worker", 2)


+class NumWorkersTest(test.TestCase):
+
+  def testCountWorker(self):
+    cluster_spec = {
+        "chief": ["127.0.0.1:1234"],
+        "worker": ["127.0.0.1:8964", "127.0.0.1:2333"],
+        "ps": ["127.0.0.1:1926", "127.0.0.1:3141"]
+    }
+    self.assertEqual(
+        multi_worker_util.worker_count(cluster_spec, task_type="chief"), 3)
+    self.assertEqual(
+        multi_worker_util.worker_count(cluster_spec, task_type="worker"), 3)
+
+  def testCountEvaluator(self):
+    cluster_spec = {
+        "chief": ["127.0.0.1:1234"],
+        "worker": ["127.0.0.1:8964", "127.0.0.1:2333"],
+        "evaluator": ["127.0.0.1:7566"]
+    }
+    self.assertEqual(
+        multi_worker_util.worker_count(cluster_spec, task_type="evaluator"), 1)
+
+  def testTaskTypeNotFound(self):
+    cluster_spec = {}
+    with self.assertRaisesRegexp(
+        ValueError, "`task_type` 'worker' not found in cluster_spec."):
+      multi_worker_util.worker_count(cluster_spec, task_type="worker")
+
+  def testCountPs(self):
+    cluster_spec = {
+        "chief": ["127.0.0.1:1234"],
+        "ps": ["127.0.0.1:1926", "127.0.0.1:3141"]
+    }
+    # A "ps" job shouldn't call this method.
+    with self.assertRaisesRegexp(ValueError, "Unexpected `task_type` 'ps'"):
+      multi_worker_util.worker_count(cluster_spec, task_type="ps")
+
+
+class IdInClusterTest(test.TestCase):
+
+  def testChiefId(self):
+    cluster_spec = {
+        "chief": ["127.0.0.1:1234"],
+        "worker": ["127.0.0.1:8964", "127.0.0.1:2333"],
+        "ps": ["127.0.0.1:1926", "127.0.0.1:3141"]
+    }
+    self.assertEqual(
+        multi_worker_util.id_in_cluster(cluster_spec, "chief", 0), 0)
+
+  def testWorkerId(self):
+    cluster_spec = {
+        "chief": ["127.0.0.1:1234"],
+        "worker": ["127.0.0.1:8964", "127.0.0.1:2333"],
+        "ps": ["127.0.0.1:1926", "127.0.0.1:3141"]
+    }
+    self.assertEqual(
+        multi_worker_util.id_in_cluster(cluster_spec, "worker", 1), 2)
+
+    cluster_spec = {
+        "worker": ["127.0.0.1:8964", "127.0.0.1:2333"],
+        "ps": ["127.0.0.1:1926", "127.0.0.1:3141"]
+    }
+    self.assertEqual(
+        multi_worker_util.id_in_cluster(cluster_spec, "worker", 1), 1)
+
+  def testEvaluatorId(self):
+    cluster_spec = {
+        "chief": ["127.0.0.1:1234"],
+        "worker": ["127.0.0.1:8964", "127.0.0.1:2333"],
+        "evaluator": ["127.0.0.1:7566"]
+    }
+    self.assertEqual(
+        multi_worker_util.id_in_cluster(cluster_spec, "evaluator", 0), 0)
+
+  def testPsId(self):
+    cluster_spec = {"chief": ["127.0.0.1:1234"], "ps": ["127.0.0.1:7566"]}
+    with self.assertRaisesRegexp(ValueError,
+                                 "There is no id for task_type 'ps'"):
+      multi_worker_util.id_in_cluster(cluster_spec, "ps", 0)
+
+  def testMultipleChiefs(self):
+    cluster_spec = {
+        "chief": ["127.0.0.1:8258", "127.0.0.1:7566"],
+    }
+    with self.assertRaisesRegexp(ValueError,
+                                 "There must be at most one 'chief' job."):
+      multi_worker_util.id_in_cluster(cluster_spec, "chief", 0)
+
+
 if __name__ == "__main__":
   test.main()
@@ -19,6 +19,7 @@ from __future__ import division
 from __future__ import print_function

 import threading
+import enum

 from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.distribute import reduce_util
@@ -184,6 +185,80 @@ class _SameScopeAgainContext(object):
     del exception_type, exception_value, traceback


+# TODO(yuefengz): add more replication modes.
+class InputReplicationMode(enum.Enum):
+  """Replication mode for input function."""
+
+  # The input function will be called on each worker independently, creating as
+  # many input pipelines as number of workers. Replicas will dequeue from the
+  # local Dataset on their worker. Distribution Strategy doesn't manage any
+  # state sharing between such separate input pipelines.
+  PER_WORKER = 0
+
+
+class InputContext(object):
+  """A class wrapping information needed by an input function.
+
+  This is a context class that is passed to the user's input fn and contains
+  information about the compute replicas and input pipelines. The number of
+  compute replicas (in sync training) helps compute per input pipeline batch
+  size from the desired global batch size. Input pipeline information can be
+  used to return a different subset of the input in each input pipeline (for
+  e.g. shard the input pipeline, use a different input source etc).
+  """
+
+  def __init__(self,
+               num_input_pipelines=1,
+               input_pipeline_id=0,
+               num_replicas_in_sync=1):
+    """Initializes an InputContext object.
+
+    Args:
+      num_input_pipelines: the number of input pipelines in a cluster.
+      input_pipeline_id: the current input pipeline id, should be an int in
+        [0,`num_input_pipelines`).
+      num_replicas_in_sync: the number of replicas that are in sync.
+    """
+    self._num_input_pipelines = num_input_pipelines
+    self._input_pipeline_id = input_pipeline_id
+    self._num_replicas_in_sync = num_replicas_in_sync
+
+  @property
+  def num_replicas_in_sync(self):
+    """Returns the number of compute replicas in sync."""
+    return self._num_replicas_in_sync
+
+  @property
+  def input_pipeline_id(self):
+    """Returns the input pipeline ID."""
+    return self._input_pipeline_id
+
+  @property
+  def num_input_pipelines(self):
+    """Returns the number of input pipelines."""
+    return self._num_input_pipelines
+
+  def get_per_replica_batch_size(self, global_batch_size):
+    """Returns the per-replica batch size.
+
+    Args:
+      global_batch_size: the global batch size which should be divisible by
+        `num_replicas_in_sync`.
+
+    Returns:
+      the per-replica batch size.
+
+    Raises:
+      ValueError: if `global_batch_size` not divisible by
+        `num_replicas_in_sync`.
+    """
+    if global_batch_size % self._num_replicas_in_sync != 0:
+      raise ValueError("The `global_batch_size` %r is not divisible by "
+                       "`num_replicas_in_sync` %r " %
+                       (global_batch_size, self._num_replicas_in_sync))
+    return global_batch_size // self._num_replicas_in_sync
+
+
 # ------------------------------------------------------------------------------
 # Base classes for all distribution strategies.

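A quick sanity check of the batch-size arithmetic in `get_per_replica_batch_size`, mirroring the `InputContextTest.testPerReplicaBatchSize` case added later in this diff (the import path assumes the `tensorflow/python/training/distribute.py` layout this change targets):

from tensorflow.python.training import distribute as distribute_lib

input_context = distribute_lib.InputContext(
    num_input_pipelines=2, input_pipeline_id=1, num_replicas_in_sync=6)
assert input_context.get_per_replica_batch_size(12) == 2  # 12 // 6 replicas
try:
  input_context.get_per_replica_batch_size(13)  # 13 is not divisible by 6
except ValueError:
  pass  # expected: global batch size must be divisible by num_replicas_in_sync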
@@ -534,8 +609,15 @@ class DistributionStrategy(object):
     _require_distribution_strategy_scope(self)
     return variable_scope.variable_creator_scope(create_colocated_variable)

-  def _call_dataset_fn(self, dataset_fn):
-    result = dataset_fn()
+  def _call_dataset_fn(self, dataset_fn, input_context=None):
+    """Call the `dataset_fn` with `input_context` as argument."""
+    # This method is invoked by both `make_input_fn_iterator` and
+    # `distribute_dataset`. The `dataset_fn` for the former one accepts an
+    # input_context while the latter one doesn't.
+    if input_context:
+      result = dataset_fn(input_context)
+    else:
+      result = dataset_fn()
     if not isinstance(result, dataset_ops.Dataset):
       raise ValueError(
           "dataset_fn() must return a tf.data.Dataset when using a "
@@ -580,7 +662,7 @@ class DistributionStrategy(object):
       return tf.data.Dataset.from_tensors([[1.]]).repeat()
     with distribution_strategy.scope():
       distributed_dataset = distribution_strategy.distribute_dataset(dataset_fn)
-      iterator = distributed_dataset.make_one_shot_iterator()
+      iterator = distributed_dataset.make_initializable_iterator()
       replica_results = distribution_strategy.call_for_each_replica(
           replica_fn, args=(iterator.get_next(),))
     ```
@@ -593,6 +675,46 @@ class DistributionStrategy(object):
     """
     raise NotImplementedError("must be implemented in descendants")

+  def make_input_fn_iterator(self,
+                             input_fn,
+                             replication_mode=InputReplicationMode.PER_WORKER):
+    """Returns an iterator split across replicas created from an input function.
+
+    The `input_fn` should take an `InputContext` object where information about
+    input sharding can be accessed:
+
+    ```
+    def input_fn(input_context):
+      d = tf.data.Dataset.from_tensors([[1.]]).repeat()
+      return d.shard(input_context.num_input_pipelines,
+                     input_context.input_pipeline_id)
+    with distribution_strategy.scope():
+      iterator = distribution_strategy.make_input_fn_iterator(
+          input_fn)
+      replica_results = distribution_strategy.call_for_each_replica(
+          replica_fn, iterator.get_next())
+    ```
+
+    Args:
+      input_fn: A function that returns a `tf.data.Dataset`. This function is
+        expected to take an `InputContext` object.
+      replication_mode: an enum value of `InputReplicationMode`. Only
+        `PER_WORKER` is supported currently.
+
+    Returns:
+      An iterator object that can be initialized and fetched next element.
+    """
+    if replication_mode != InputReplicationMode.PER_WORKER:
+      raise ValueError(
+          "Input replication mode not supported: %r" % replication_mode)
+    return self._make_input_fn_iterator(
+        input_fn, replication_mode=replication_mode)
+
+  def _make_input_fn_iterator(self,
+                              input_fn,
+                              replication_mode=InputReplicationMode.PER_WORKER):
+    raise NotImplementedError("must be implemented in descendants")
+
   def broadcast(self, tensor, destinations=None):
     """Mirror a tensor on one device to all worker devices.

@@ -1193,6 +1315,11 @@ class _DefaultDistributionStrategy(DistributionStrategy):
   def distribute_dataset(self, dataset_fn):
     return self._call_dataset_fn(dataset_fn)

+  def _make_input_fn_iterator(self,
+                              input_fn,
+                              replication_mode=InputReplicationMode.PER_WORKER):
+    return self._call_dataset_fn(input_fn, InputContext())
+
   def _broadcast(self, tensor, destinations):
     if destinations is None:
       return tensor
@@ -20,11 +20,11 @@ from __future__ import print_function

 from tensorflow.python.ops import variable_scope
 from tensorflow.python.platform import test
-from tensorflow.python.training import distribute
+from tensorflow.python.training import distribute as distribute_lib
 from tensorflow.python.training import distribution_strategy_context


-class _TestReplicaContext(distribute.ReplicaContext):
+class _TestReplicaContext(distribute_lib.ReplicaContext):

   def merge_call(self, fn, *args, **kwargs):
     return kwargs["test_arg"]
@@ -38,7 +38,7 @@ def _get_test_variable(name, synchronization, aggregation):
   }


-class _TestStrategy(distribute.DistributionStrategy):
+class _TestStrategy(distribute_lib.DistributionStrategy):

   def _call_for_each_replica(self, fn, args, kwargs):
     with _TestReplicaContext(self, replica_id=0):
@@ -144,5 +144,22 @@ class DefaultDistributionStrategyTest(test.TestCase):
     _assert_in_default_state(self)


+class InputContextTest(test.TestCase):
+
+  def testProperties(self):
+    input_context = distribute_lib.InputContext(
+        num_input_pipelines=2, input_pipeline_id=1, num_replicas_in_sync=6)
+    self.assertEqual(6, input_context.num_replicas_in_sync)
+    self.assertEqual(1, input_context.input_pipeline_id)
+    self.assertEqual(2, input_context.num_input_pipelines)
+
+  def testPerReplicaBatchSize(self):
+    input_context = distribute_lib.InputContext(
+        num_input_pipelines=2, input_pipeline_id=1, num_replicas_in_sync=6)
+    self.assertEqual(2, input_context.get_per_replica_batch_size(12))
+    with self.assertRaises(ValueError):
+      input_context.get_per_replica_batch_size(13)
+
+
 if __name__ == "__main__":
   test.main()