Enable last partial batch for MultiWorkerMirroredStrategy (MWMS) in TF2.x
PiperOrigin-RevId: 317760674
Change-Id: Ib7e0adbf4f8f013f21faef07ed4961c078806093
This commit is contained in:
parent
7c38468051
commit
d2b35a7955
@ -178,6 +178,7 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended):
|
||||
self._communication = communication
|
||||
self._initialize_strategy(self._cluster_resolver)
|
||||
self._cfer_fn_cache = weakref.WeakKeyDictionary()
|
||||
self.experimental_enable_get_next_as_optional = True
|
||||
assert isinstance(self._cross_device_ops,
|
||||
cross_device_ops_lib.CollectiveAllReduce)
|
||||
|
||||
|
@ -370,7 +370,8 @@ class CollectiveAllReduceStrategyTestBase(
|
||||
else:
|
||||
self.assertEqual(list(expected_value), list(computed_value))
|
||||
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
# error raised by calling optional_get_value on an Optional of None
|
||||
with self.assertRaises(errors.InvalidArgumentError):
|
||||
next_element = iterator.get_next()
|
||||
sess.run([distribute_utils.select_replica(r, next_element)
|
||||
for r in range(len(devices))])
|
||||
@ -449,31 +450,35 @@ class DistributedCollectiveAllReduceStrategyTest(
|
||||
combinations.combine(
|
||||
mode=['graph'], required_gpus=[0, 1, 2], use_dataset=[True, False]))
|
||||
def testMakeInputFnIterator(self, required_gpus, use_dataset):
|
||||
if use_dataset:
|
||||
fn = lambda: dataset_ops.Dataset.range(100)
|
||||
else:
|
||||
def fn():
|
||||
dataset = dataset_ops.Dataset.range(100)
|
||||
it = dataset_ops.make_one_shot_iterator(dataset)
|
||||
return it.get_next
|
||||
# We use CPU as the device when required_gpus = 0
|
||||
devices_per_worker = max(1, required_gpus)
|
||||
expected_values = [[i+j for j in range(devices_per_worker)]
|
||||
for i in range(0, 100, devices_per_worker)]
|
||||
def _worker_fn(task_type, task_id, required_gpus):
|
||||
if use_dataset:
|
||||
fn = lambda: dataset_ops.Dataset.range(100)
|
||||
else:
|
||||
def fn():
|
||||
dataset = dataset_ops.Dataset.range(100)
|
||||
it = dataset_ops.make_one_shot_iterator(dataset)
|
||||
return it.get_next
|
||||
# We use CPU as the device when required_gpus = 0
|
||||
devices_per_worker = max(1, required_gpus)
|
||||
expected_values = [[i+j for j in range(devices_per_worker)]
|
||||
for i in range(0, 100, devices_per_worker)]
|
||||
|
||||
input_fn = self._input_fn_to_test_input_context(
|
||||
fn,
|
||||
expected_num_replicas_in_sync=3*devices_per_worker,
|
||||
expected_num_input_pipelines=3,
|
||||
expected_input_pipeline_id=1) # because task_id = 1
|
||||
self._test_input_fn_iterator(
|
||||
'worker',
|
||||
1,
|
||||
required_gpus,
|
||||
input_fn,
|
||||
expected_values,
|
||||
test_reinitialize=use_dataset,
|
||||
ignore_order=not use_dataset)
|
||||
input_fn = self._input_fn_to_test_input_context(
|
||||
fn,
|
||||
expected_num_replicas_in_sync=3*devices_per_worker,
|
||||
expected_num_input_pipelines=3,
|
||||
expected_input_pipeline_id=task_id)
|
||||
self._test_input_fn_iterator(
|
||||
task_type,
|
||||
task_id,
|
||||
required_gpus,
|
||||
input_fn,
|
||||
expected_values,
|
||||
test_reinitialize=use_dataset,
|
||||
ignore_order=not use_dataset)
|
||||
|
||||
self._run_between_graph_clients(_worker_fn, self._cluster_spec,
|
||||
required_gpus)
|
||||
|
||||
@combinations.generate(combinations.combine(mode=['graph']))
|
||||
def testUpdateConfigProto(self):
|
||||
|
@ -549,7 +549,7 @@ def _get_next_as_optional(iterator, strategy, name=None):
|
||||
# Collective all-reduce requires explicit devices for inputs.
|
||||
with ops.device("/cpu:0"):
|
||||
# Converting to integers for all-reduce.
|
||||
worker_has_value = math_ops.cast(worker_has_value, dtypes.int32)
|
||||
worker_has_value = math_ops.cast(worker_has_value, dtypes.int64)
|
||||
worker_devices.append(worker_has_value.device)
|
||||
worker_has_values.append(worker_has_value)
|
||||
# Make `replicas` a flat list of values across all replicas.
|
||||
@ -624,16 +624,12 @@ class DistributedIteratorBase(DistributedIteratorInterface):
|
||||
# get_next_as_optional(). And we only enable get_next_as_optional when the
|
||||
# output shapes are not static.
|
||||
#
|
||||
# TODO(yuefengz): Currently `experimental_enable_get_next_as_optional` is
|
||||
# always set to False in CollectiveAllReduceStrategy. We want to have a way
|
||||
# to distinguish multi workers/single worker between graph, so we can enable
|
||||
# the behavior in single worker case.
|
||||
#
|
||||
# TODO(rxsang): We want to always enable the get_next_as_optional behavior
|
||||
# when user passed input_fn instead of dataset.
|
||||
if getattr(
|
||||
strategy.extended, "experimental_enable_get_next_as_optional", False):
|
||||
self._enable_get_next_as_optional = not static_shape
|
||||
self._enable_get_next_as_optional = (
|
||||
not static_shape) or strategy.extended._in_multi_worker_mode()
|
||||
else:
|
||||
self._enable_get_next_as_optional = False
|
||||
|
||||
@ -906,7 +902,8 @@ class DistributedIterator(DistributedIteratorBase,
|
||||
self._strategy = strategy
|
||||
if getattr(
|
||||
strategy.extended, "experimental_enable_get_next_as_optional", False):
|
||||
self._enable_get_next_as_optional = not static_shape
|
||||
self._enable_get_next_as_optional = (
|
||||
not static_shape) or strategy.extended._in_multi_worker_mode()
|
||||
else:
|
||||
self._enable_get_next_as_optional = False
|
||||
else:
|
||||
|
@ -1144,7 +1144,6 @@ class DistributedIteratorMultiWorkerTest(
|
||||
expected_values = [[[0, 1]], [[2, 3]], [[4]]]
|
||||
input_context = None
|
||||
|
||||
strategy.extended.experimental_enable_get_next_as_optional = True
|
||||
self._test_input_iteration(
|
||||
input_type,
|
||||
api_type,
|
||||
|
@ -19,6 +19,7 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from absl.testing import parameterized
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.distribute import combinations
|
||||
@ -29,7 +30,6 @@ from tensorflow.python.distribute import strategy_test_lib
|
||||
from tensorflow.python.eager import def_function
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
@ -81,7 +81,7 @@ class DistributedCollectiveAllReduceStrategyTest(
|
||||
return d.shard(input_context.num_input_pipelines,
|
||||
input_context.input_pipeline_id)
|
||||
|
||||
expected_sum_on_workers = [10, 35]
|
||||
expected_data_on_worker = [[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]]
|
||||
input_iterator = iter(
|
||||
strategy.experimental_distribute_datasets_from_function(dataset_fn))
|
||||
|
||||
@ -90,10 +90,59 @@ class DistributedCollectiveAllReduceStrategyTest(
|
||||
return strategy.experimental_local_results(iterator.get_next())
|
||||
|
||||
result = run(input_iterator)
|
||||
sum_value = math_ops.reduce_sum(result)
|
||||
self.assertEqual(
|
||||
sum_value.numpy(),
|
||||
expected_sum_on_workers[multi_worker_test_base.get_task_index()])
|
||||
self.assertTrue(
|
||||
np.array_equal(
|
||||
result[0].numpy(),
|
||||
expected_data_on_worker[multi_worker_test_base.get_task_index()]))
|
||||
|
||||
def testSimpleInputFromDatasetLastPartialBatch(self, strategy):
  """Checks the final, short batch produced by a distributed dataset.

  A 14-element range is batched with a global batch size of 8 and
  `drop_remainder=False`, so the second batch pulled from the iterator is
  incomplete. The slice seen by this task is compared with the entry of the
  expected table selected by the task index.

  Args:
    strategy: the distribution strategy under test.
  """
  batched = dataset_ops.DatasetV2.range(14).batch(8, drop_remainder=False)
  distributed = strategy.experimental_distribute_dataset(batched)
  dist_iterator = iter(distributed)

  def identity(x):
    return x

  @def_function.function
  def step(it):
    return strategy.run(identity, args=(next(it),))

  # The first step consumes the full batch; its output is not inspected.
  step(dist_iterator)

  # The second step yields the incomplete batch.
  partial = step(dist_iterator)

  # One expected row per task index (two rows are listed here).
  per_task_expected = [[8, 9, 10], [11, 12, 13]]
  actual = partial.numpy()
  task_index = multi_worker_test_base.get_task_index()
  self.assertTrue(np.array_equal(actual, per_task_expected[task_index]))
|
||||
|
||||
def testSimpleInputFromFnLastPartialBatch(self, strategy):
  """Checks the final, short batch when input comes from a dataset function.

  Each input pipeline batches a 14-element range by the per-replica batch
  size derived from a global batch size of 8 (`drop_remainder=False`) and then
  keeps only its own shard. The second batch pulled from the iterator is
  compared with the entry of the expected table selected by the task index.

  Args:
    strategy: the distribution strategy under test.
  """

  def dataset_fn(input_context):
    per_replica_batch = input_context.get_per_replica_batch_size(8)
    batched = dataset_ops.DatasetV2.range(14).batch(
        per_replica_batch, drop_remainder=False)
    num_shards = input_context.num_input_pipelines
    shard_index = input_context.input_pipeline_id
    return batched.shard(num_shards, shard_index)

  distributed = strategy.experimental_distribute_datasets_from_function(
      dataset_fn)
  dist_iterator = iter(distributed)

  def identity(x):
    return x

  @def_function.function
  def step(it):
    return strategy.run(identity, args=(next(it),))

  # The first step consumes the first batch; its output is not inspected.
  step(dist_iterator)

  # The second step yields the batch that is checked below.
  second = step(dist_iterator)

  # One expected row per task index (two rows are listed here).
  per_task_expected = [[8, 9, 10, 11], [12, 13]]
  actual = second.numpy()
  task_index = multi_worker_test_base.get_task_index()
  self.assertTrue(np.array_equal(actual, per_task_expected[task_index]))
|
||||
|
||||
def testReduceHostTensor(self, strategy):
|
||||
reduced = strategy.reduce(
|
||||
|
Loading…
Reference in New Issue
Block a user