Enable last partial batch for MWMS in TF2.x

PiperOrigin-RevId: 317760674
Change-Id: Ib7e0adbf4f8f013f21faef07ed4961c078806093
Author: Xinyi Wang (2020-06-22 16:39:57 -07:00), committed by TensorFlower Gardener
Parent: 7c38468051
Commit: d2b35a7955
5 changed files with 91 additions and 40 deletions
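What this change enables: with MultiWorkerMirroredStrategy (MWMS) in TF 2.x, a distributed dataset no longer stops at the last complete global batch; the final partial batch is also delivered to the replicas. A minimal single-machine sketch of the resulting behavior, using MirroredStrategy on an illustrative one-device list rather than a real multi-worker cluster:

import tensorflow as tf

# Single-replica MirroredStrategy stands in for MultiWorkerMirroredStrategy
# here; the device list is illustrative.
strategy = tf.distribute.MirroredStrategy(["/cpu:0"])

# 14 elements with a global batch size of 8 -> one full batch plus one
# partial batch.
dataset = tf.data.Dataset.range(14).batch(8, drop_remainder=False)
dist_iterator = iter(strategy.experimental_distribute_dataset(dataset))

@tf.function
def step(iterator):
  # Identity step: just return the per-replica batch.
  return strategy.run(lambda x: x, args=(next(iterator),))

step(dist_iterator)            # full batch [0 .. 7]
partial = step(dist_iterator)  # partial batch [8 .. 13] is not dropped
print(strategy.experimental_local_results(partial)[0].numpy())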

View File

@@ -178,6 +178,7 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended):
     self._communication = communication
     self._initialize_strategy(self._cluster_resolver)
     self._cfer_fn_cache = weakref.WeakKeyDictionary()
+    self.experimental_enable_get_next_as_optional = True
     assert isinstance(self._cross_device_ops,
                       cross_device_ops_lib.CollectiveAllReduce)
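The new attribute is read back with getattr in input_lib (see the input_lib.py hunks below). A hedged way to observe the flag on a default, single-worker construction of the strategy, mirroring the library's own getattr guard since the attribute is internal:

import tensorflow as tf

# Without TF_CONFIG set, this builds a single-worker instance of the strategy.
strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()

# After this commit the extended object sets the flag to True during
# initialization, so this is expected to print True.
print(getattr(strategy.extended,
              "experimental_enable_get_next_as_optional", False))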

View File

@@ -370,7 +370,8 @@ class CollectiveAllReduceStrategyTestBase(
       else:
         self.assertEqual(list(expected_value), list(computed_value))
-    with self.assertRaises(errors.OutOfRangeError):
+    # error raised by calling optional_get_value on an Optional of None
+    with self.assertRaises(errors.InvalidArgumentError):
       next_element = iterator.get_next()
       sess.run([distribute_utils.select_replica(r, next_element)
                 for r in range(len(devices))])
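The expected error changes because exhaustion is now surfaced through Optional values: calling get_value() on an empty Optional fails with InvalidArgumentError instead of get_next() raising OutOfRangeError. A small eager-mode sketch of that primitive (not the test's graph/session path):

import tensorflow as tf

dataset = tf.data.Dataset.range(1)
iterator = iter(dataset)
next(iterator)  # consume the only element, exhausting the iterator

# The Optional returned for an exhausted iterator holds no value.
empty_optional = tf.data.experimental.get_next_as_optional(iterator)
print(empty_optional.has_value().numpy())  # False

try:
  empty_optional.get_value()
except tf.errors.InvalidArgumentError:
  print("get_value() on an empty Optional raises InvalidArgumentError")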
@@ -449,31 +450,35 @@ class DistributedCollectiveAllReduceStrategyTest(
       combinations.combine(
           mode=['graph'], required_gpus=[0, 1, 2], use_dataset=[True, False]))
   def testMakeInputFnIterator(self, required_gpus, use_dataset):
-    if use_dataset:
-      fn = lambda: dataset_ops.Dataset.range(100)
-    else:
-      def fn():
-        dataset = dataset_ops.Dataset.range(100)
-        it = dataset_ops.make_one_shot_iterator(dataset)
-        return it.get_next
-    # We use CPU as the device when required_gpus = 0
-    devices_per_worker = max(1, required_gpus)
-    expected_values = [[i+j for j in range(devices_per_worker)]
-                       for i in range(0, 100, devices_per_worker)]
-
-    input_fn = self._input_fn_to_test_input_context(
-        fn,
-        expected_num_replicas_in_sync=3*devices_per_worker,
-        expected_num_input_pipelines=3,
-        expected_input_pipeline_id=1)  # because task_id = 1
-    self._test_input_fn_iterator(
-        'worker',
-        1,
-        required_gpus,
-        input_fn,
-        expected_values,
-        test_reinitialize=use_dataset,
-        ignore_order=not use_dataset)
+    def _worker_fn(task_type, task_id, required_gpus):
+      if use_dataset:
+        fn = lambda: dataset_ops.Dataset.range(100)
+      else:
+        def fn():
+          dataset = dataset_ops.Dataset.range(100)
+          it = dataset_ops.make_one_shot_iterator(dataset)
+          return it.get_next
+      # We use CPU as the device when required_gpus = 0
+      devices_per_worker = max(1, required_gpus)
+      expected_values = [[i+j for j in range(devices_per_worker)]
+                         for i in range(0, 100, devices_per_worker)]
+
+      input_fn = self._input_fn_to_test_input_context(
+          fn,
+          expected_num_replicas_in_sync=3*devices_per_worker,
+          expected_num_input_pipelines=3,
+          expected_input_pipeline_id=task_id)
+      self._test_input_fn_iterator(
+          task_type,
+          task_id,
+          required_gpus,
+          input_fn,
+          expected_values,
+          test_reinitialize=use_dataset,
+          ignore_order=not use_dataset)
+
+    self._run_between_graph_clients(_worker_fn, self._cluster_spec,
+                                    required_gpus)

   @combinations.generate(combinations.combine(mode=['graph']))
   def testUpdateConfigProto(self):
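The refactored test runs a per-worker function between graph clients and derives expected_input_pipeline_id from task_id instead of hard-coding worker 1. For reference, the pipeline id is what an input_fn typically uses to shard data per worker; a standalone sketch with the public tf.distribute.InputContext, using hypothetical values that mirror a three-worker cluster:

import tensorflow as tf

# Hypothetical context for worker 1 in a 3-worker cluster, one replica per
# worker.
ctx = tf.distribute.InputContext(
    num_input_pipelines=3, input_pipeline_id=1, num_replicas_in_sync=3)

dataset = tf.data.Dataset.range(100).shard(
    ctx.num_input_pipelines, ctx.input_pipeline_id)
print(list(dataset.take(4).as_numpy_iterator()))  # [1, 4, 7, 10]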

View File

@ -549,7 +549,7 @@ def _get_next_as_optional(iterator, strategy, name=None):
# Collective all-reduce requires explicit devices for inputs. # Collective all-reduce requires explicit devices for inputs.
with ops.device("/cpu:0"): with ops.device("/cpu:0"):
# Converting to integers for all-reduce. # Converting to integers for all-reduce.
worker_has_value = math_ops.cast(worker_has_value, dtypes.int32) worker_has_value = math_ops.cast(worker_has_value, dtypes.int64)
worker_devices.append(worker_has_value.device) worker_devices.append(worker_has_value.device)
worker_has_values.append(worker_has_value) worker_has_values.append(worker_has_value)
# Make `replicas` a flat list of values across all replicas. # Make `replicas` a flat list of values across all replicas.
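The cast feeds a per-worker "has value" flag into a cross-worker all-reduce that decides whether any worker still has data; the dtype changes from int32 to int64, presumably to match the collective op used for these flags. A toy single-process emulation of that decision, with add_n standing in for the real all-reduce:

import tensorflow as tf

# Hypothetical per-worker flags: worker 0 still has a batch, worker 1 is
# exhausted.
worker_has_values = [
    tf.cast(True, tf.int64),
    tf.cast(False, tf.int64),
]

# The real code all-reduces these flags across workers; summing them in one
# process gives the same decision.
global_has_value = tf.add_n(worker_has_values) > 0
print(global_has_value.numpy())  # True: at least one worker still has data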
@@ -624,16 +624,12 @@ class DistributedIteratorBase(DistributedIteratorInterface):
     # get_next_as_optional(). And we only enable get_next_as_optional when the
     # output shapes are not static.
     #
-    # TODO(yuefengz): Currently `experimental_enable_get_next_as_optional` is
-    # always set to False in CollectiveAllReduceStrategy. We want to have a way
-    # to distinguish multi workers/single worker between graph, so we can enable
-    # the behavior in single worker case.
-    #
     # TODO(rxsang): We want to always enable the get_next_as_optional behavior
     # when user passed input_fn instead of dataset.
     if getattr(
         strategy.extended, "experimental_enable_get_next_as_optional", False):
-      self._enable_get_next_as_optional = not static_shape
+      self._enable_get_next_as_optional = (
+          not static_shape) or strategy.extended._in_multi_worker_mode()
     else:
       self._enable_get_next_as_optional = False
@@ -906,7 +902,8 @@ class DistributedIterator(DistributedIteratorBase,
       self._strategy = strategy
       if getattr(
           strategy.extended, "experimental_enable_get_next_as_optional", False):
-        self._enable_get_next_as_optional = not static_shape
+        self._enable_get_next_as_optional = (
+            not static_shape) or strategy.extended._in_multi_worker_mode()
       else:
         self._enable_get_next_as_optional = False
     else:
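With _enable_get_next_as_optional now also turned on in multi-worker mode, the iterator detects the end of data through Optional values rather than static shapes. For reference, the underlying tf.data primitive looks like this in a plain, non-distributed sketch:

import tensorflow as tf

dataset = tf.data.Dataset.range(3)
iterator = iter(dataset)

# Drain the iterator through Optionals instead of catching OutOfRangeError.
while True:
  optional = tf.data.experimental.get_next_as_optional(iterator)
  if not optional.has_value():
    break
  print(optional.get_value().numpy())  # 0, then 1, then 2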

View File

@@ -1144,7 +1144,6 @@ class DistributedIteratorMultiWorkerTest(
     expected_values = [[[0, 1]], [[2, 3]], [[4]]]
     input_context = None

-    strategy.extended.experimental_enable_get_next_as_optional = True
     self._test_input_iteration(
         input_type,
         api_type,

View File

@@ -19,6 +19,7 @@ from __future__ import division
 from __future__ import print_function

 from absl.testing import parameterized
+import numpy as np

 from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.distribute import combinations
@@ -29,7 +30,6 @@ from tensorflow.python.distribute import strategy_test_lib
 from tensorflow.python.eager import def_function
 from tensorflow.python.framework import dtypes
 from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import math_ops
 from tensorflow.python.platform import test
@@ -81,7 +81,7 @@ class DistributedCollectiveAllReduceStrategyTest(
       return d.shard(input_context.num_input_pipelines,
                      input_context.input_pipeline_id)

-    expected_sum_on_workers = [10, 35]
+    expected_data_on_worker = [[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]]
     input_iterator = iter(
         strategy.experimental_distribute_datasets_from_function(dataset_fn))
@@ -90,10 +90,59 @@ class DistributedCollectiveAllReduceStrategyTest(
       return strategy.experimental_local_results(iterator.get_next())

     result = run(input_iterator)
-    sum_value = math_ops.reduce_sum(result)
-    self.assertEqual(
-        sum_value.numpy(),
-        expected_sum_on_workers[multi_worker_test_base.get_task_index()])
+    self.assertTrue(
+        np.array_equal(
+            result[0].numpy(),
+            expected_data_on_worker[multi_worker_test_base.get_task_index()]))
+
+  def testSimpleInputFromDatasetLastPartialBatch(self, strategy):
+    global_batch_size = 8
+    dataset = dataset_ops.DatasetV2.range(14).batch(
+        global_batch_size, drop_remainder=False)
+    input_iterator = iter(strategy.experimental_distribute_dataset(dataset))
+
+    @def_function.function
+    def run(input_iterator):
+      return strategy.run(lambda x: x, args=(next(input_iterator),))
+
+    # Let the complete batch go.
+    run(input_iterator)
+
+    # `result` is an incomplete batch
+    result = run(input_iterator)
+    expected_data_on_worker = [[8, 9, 10], [11, 12, 13]]
+    self.assertTrue(
+        np.array_equal(
+            result.numpy(),
+            expected_data_on_worker[multi_worker_test_base.get_task_index()]))
+
+  def testSimpleInputFromFnLastPartialBatch(self, strategy):
+
+    def dataset_fn(input_context):
+      global_batch_size = 8
+      batch_size = input_context.get_per_replica_batch_size(global_batch_size)
+      dataset = dataset_ops.DatasetV2.range(14).batch(
+          batch_size, drop_remainder=False)
+      return dataset.shard(input_context.num_input_pipelines,
+                           input_context.input_pipeline_id)
+
+    input_iterator = iter(
+        strategy.experimental_distribute_datasets_from_function(dataset_fn))
+
+    @def_function.function
+    def run(input_iterator):
+      return strategy.run(lambda x: x, args=(next(input_iterator),))
+
+    # Let the complete batch go.
+    run(input_iterator)
+
+    # `result` is an incomplete batch
+    result = run(input_iterator)
+    expected_data_on_worker = [[8, 9, 10, 11], [12, 13]]
+    self.assertTrue(
+        np.array_equal(
+            result.numpy(), expected_data_on_worker[
+                multi_worker_test_base.get_task_index()]))

   def testReduceHostTensor(self, strategy):
     reduced = strategy.reduce(
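The expected values in testSimpleInputFromFnLastPartialBatch follow from batching then sharding inside its dataset_fn. The same numbers can be reproduced with the public tf.distribute.InputContext, assuming two workers with one replica each and a global batch size of 8:

import tensorflow as tf

# Two input pipelines (workers), one replica each, sharing a global batch of 8.
ctx = tf.distribute.InputContext(
    num_input_pipelines=2, input_pipeline_id=0, num_replicas_in_sync=2)
per_replica_batch = ctx.get_per_replica_batch_size(8)
print(per_replica_batch)  # 4

# range(14).batch(4) -> [0-3], [4-7], [8-11], [12, 13]; worker 0 keeps every
# other batch, so its second step is the partial batch's share [8, 9, 10, 11].
dataset = tf.data.Dataset.range(14).batch(
    per_replica_batch, drop_remainder=False).shard(
        ctx.num_input_pipelines, ctx.input_pipeline_id)
print([batch.tolist() for batch in dataset.as_numpy_iterator()])
# [[0, 1, 2, 3], [8, 9, 10, 11]]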