Automated rollback of commit 766a1e0037

PiperOrigin-RevId: 239305410
Ruoxin Sang authored 2019-03-19 17:31:19 -07:00; committed by TensorFlower Gardener
parent 18493fb2a3
commit e18f3efa34
2 changed files with 32 additions and 233 deletions

View File

@@ -35,54 +35,36 @@ from tensorflow.python.util import nest
 class InputIteratorTestBase(test.TestCase):

-  def _create_iterator(self, input_type, dataset_fn, worker_device_pairs,
-                       devices, split_batch_by):
+  def _test_iterator(self, input_type, dataset_fn, worker_device_pairs,
+                     expected_values, sess=None, split_batch_by=None):
+    devices = nest.flatten([ds for _, ds in worker_device_pairs])
     device_map = values.ReplicaDeviceMap(devices)
     input_workers = input_lib.InputWorkers(device_map, worker_device_pairs)

     if input_type == "input_fn":
-      input_contexts = []
-      for i in range(input_workers.num_workers):
-        input_contexts.append(
-            distribute_lib.InputContext(
-                num_input_pipelines=input_workers.num_workers,
-                input_pipeline_id=i,
-                num_replicas_in_sync=len(devices)))
-      iterator = input_lib.InputFunctionIterator(dataset_fn, input_workers,
-                                                 input_contexts)
+      input_contexts = [
+          distribute_lib.InputContext() for _ in worker_device_pairs]
+      input_fn = lambda _: dataset_fn()
+      iterator = input_lib.InputFunctionIterator(
+          input_fn, input_workers, input_contexts)
     else:
       iterator = input_lib.DatasetIterator(
-          dataset_fn(distribute_lib.InputContext()), input_workers,
-          split_batch_by)
-    return iterator
-
-  def _test_iterator(self,
-                     input_type,
-                     dataset_fn,
-                     worker_device_pairs,
-                     expected_values,
-                     sess=None,
-                     split_batch_by=None):
-    devices = nest.flatten([ds for _, ds in worker_device_pairs])
-    iterator = self._create_iterator(
-        input_type, dataset_fn, worker_device_pairs, devices, split_batch_by)
+          dataset_fn(), input_workers, split_batch_by)

     evaluate = lambda x: sess.run(x) if sess else self.evaluate(x)
     evaluate(control_flow_ops.group(iterator.initialize()))

     for expected_value in expected_values:
       next_element = iterator.get_next()
       computed_value = evaluate(
           [values.select_replica(r, next_element) for r in range(len(devices))])
-      self.assertEqual(len(expected_value), len(computed_value))
-      for i in range(len(expected_value)):
-        self.assertAllEqual(expected_value[i], computed_value[i])
+      self.assertAllEqual(expected_value, computed_value)

     with self.assertRaises(errors.OutOfRangeError):
       next_element = iterator.get_next()
-      evaluate(
-          [values.select_replica(r, next_element) for r in range(len(devices))])
+      evaluate([values.select_replica(r, next_element)
+                for r in range(len(devices))])

     # After re-initializing the iterator, should be able to iterate again.
     evaluate(control_flow_ops.group(iterator.initialize()))
@@ -91,9 +73,7 @@ class InputIteratorTestBase(test.TestCase):
       next_element = iterator.get_next()
       computed_value = evaluate(
           [values.select_replica(r, next_element) for r in range(len(devices))])
-      self.assertEqual(len(expected_value), len(computed_value))
-      for i in range(len(expected_value)):
-        self.assertAllEqual(expected_value[i], computed_value[i])
+      self.assertAllEqual(expected_value, computed_value)


 class InputIteratorSingleWorkerTest(InputIteratorTestBase,
@@ -104,7 +84,7 @@ class InputIteratorSingleWorkerTest(InputIteratorTestBase,
       input_type=["input_fn", "dataset"]))
   def testOneDeviceCPU(self, input_type):
     worker_device_pairs = [("", ["/device:CPU:0"])]
-    dataset_fn = lambda _: dataset_ops.Dataset.range(10)
+    dataset_fn = lambda: dataset_ops.Dataset.range(10)

     expected_values = [[i] for i in range(10)]
@@ -117,7 +97,7 @@ class InputIteratorSingleWorkerTest(InputIteratorTestBase,
       required_gpus=1))
   def testTwoDevicesOneGPUOneCPU(self, input_type):
     worker_device_pairs = [("", ["/device:GPU:0", "/device:CPU:0"])]
-    dataset_fn = lambda _: dataset_ops.Dataset.range(10)
+    dataset_fn = lambda: dataset_ops.Dataset.range(10)

     expected_values = [[i, i+1] for i in range(0, 10, 2)]
@@ -130,9 +110,7 @@ class InputIteratorSingleWorkerTest(InputIteratorTestBase,
       required_gpus=1))
   def testTupleDataset(self, input_type):
     worker_device_pairs = [("", ["/device:GPU:0", "/device:CPU:0"])]
-    def dataset_fn(ctx):
-      del ctx
+    def dataset_fn():
       dataset1 = dataset_ops.Dataset.range(10)
       dataset2 = dataset_ops.Dataset.range(10).map(lambda x: x**2)
       return dataset_ops.Dataset.zip((dataset1, dataset2))
@@ -142,17 +120,15 @@ class InputIteratorSingleWorkerTest(InputIteratorTestBase,
     self._test_iterator(input_type, dataset_fn, worker_device_pairs,
                         expected_values)

-  @combinations.generate(
-      combinations.combine(
+  @combinations.generate(combinations.combine(
       mode=["graph", "eager"],
       input_type=["input_fn", "dataset"],
       required_gpus=1))
   def testUnevenDatasetBatches(self, input_type):
     worker_device_pairs = [("", ["/device:GPU:0", "/device:CPU:0"])]
-    dataset_fn = lambda _: dataset_ops.Dataset.range(9).batch(2)
-    # The last global batch only contains data for one replica.
-    expected_values = [[[0, 1], [2, 3]], [[4, 5], [6, 7]], [[8], []]]
+    dataset_fn = lambda: dataset_ops.Dataset.range(11)
+    expected_values = [[i, i+1] for i in range(0, 10, 2)]
     self._test_iterator(input_type, dataset_fn, worker_device_pairs,
                         expected_values)
@@ -164,7 +140,7 @@ class InputIteratorSingleWorkerTest(InputIteratorTestBase,
   def testBatchSplitting(self, input_type, split_batch_by):
     worker_device_pairs = [("", ["/device:GPU:0", "/device:CPU:0"])]
     batch_size = 10
-    dataset_fn = lambda _: dataset_ops.Dataset.range(100).batch(batch_size)
+    dataset_fn = lambda: dataset_ops.Dataset.range(100).batch(batch_size)

     updated_batch_size = (
         batch_size // split_batch_by if split_batch_by else batch_size)
@@ -206,7 +182,7 @@ class InputIteratorMultiWorkerTest(
   def testOneDevicePerWorker(self, input_type):
     worker_devices = self._cpu_devices()
     with context.graph_mode(), self.cached_session() as sess:
-      dataset_fn = lambda _: dataset_ops.Dataset.range(4)
+      dataset_fn = lambda: dataset_ops.Dataset.range(4)
       self._test_iterator(input_type, dataset_fn, worker_devices,
                           [[0, 0], [1, 1], [2, 2], [3, 3]], sess)
@@ -217,7 +193,7 @@ class InputIteratorMultiWorkerTest(
   def testTwoDevicesPerWorker(self, input_type):
     worker_devices = self._cpu_and_one_gpu_devices()
     with context.graph_mode(), self.cached_session() as sess:
-      dataset_fn = lambda _: dataset_ops.Dataset.range(4)
+      dataset_fn = lambda: dataset_ops.Dataset.range(4)
       self._test_iterator(input_type, dataset_fn, worker_devices,
                           [[0, 1, 0, 1], [2, 3, 2, 3]], sess)
@@ -227,9 +203,7 @@ class InputIteratorMultiWorkerTest(
   def testTupleDataset(self, input_type):
     worker_devices = self._cpu_devices()
     with context.graph_mode(), self.cached_session() as sess:
-      def dataset_fn(ctx):
-        del ctx
+      def dataset_fn():
         dataset1 = dataset_ops.Dataset.range(4)
         dataset2 = dataset_ops.Dataset.range(4).map(lambda x: x**2)
         return dataset_ops.Dataset.zip((dataset1, dataset2))
@@ -238,36 +212,6 @@ class InputIteratorMultiWorkerTest(
       self._test_iterator(input_type, dataset_fn, worker_devices,
                           expected_values, sess)

-  @combinations.generate(
-      combinations.combine(
-          mode=["graph"], input_type=["input_fn", "dataset"], required_gpus=1))
-  def testUnevenDatasetBatches(self, input_type):
-    worker_devices = self._cpu_and_one_gpu_devices()
-    with context.graph_mode(), self.cached_session() as sess:
-      dataset_fn = lambda _: dataset_ops.Dataset.range(9).batch(2)
-      expected_values = [[[0, 1], [2, 3], [0, 1], [2, 3]],
-                         [[4, 5], [6, 7], [4, 5], [6, 7]], [[8], [], [8], []]]
-      self._test_iterator(input_type, dataset_fn, worker_devices,
-                          expected_values, sess)
-
-  @combinations.generate(
-      combinations.combine(
-          mode=["graph"], input_type=["input_fn"], required_gpus=1))
-  def testDifferentDatasets(self, input_type):
-    worker_devices = self._cpu_and_one_gpu_devices()
-    with context.graph_mode(), self.cached_session() as sess:
-      def dataset_fn(ctx):
-        if ctx.input_pipeline_id == 0:
-          return dataset_ops.Dataset.range(8).batch(2)
-        else:
-          return dataset_ops.Dataset.range(9).batch(2)
-      expected_values = [[[0, 1], [2, 3], [0, 1], [2, 3]],
-                         [[4, 5], [6, 7], [4, 5], [6, 7]], [[], [], [8], []]]
-      self._test_iterator(input_type, dataset_fn, worker_devices,
-                          expected_values, sess)


 if __name__ == "__main__":
   test.main()
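
Note on the test changes above: the per-replica expected values in the removed uneven-batch tests come from how a batched dataset is dealt out across replica devices. The following standalone sketch is not part of the commit; it is plain tf.data in a recent TF 2.x (it assumes Dataset.as_numpy_iterator is available) and only illustrates why the pre-rollback test expected [[8], []] for the final step.

import tensorflow as tf

# range(9).batch(2) yields [0, 1], [2, 3], [4, 5], [6, 7], [8]; with two
# replica devices drawing alternate batches, the final step has data only
# for the first replica, which is why the removed test expected [[8], []].
batches = list(tf.data.Dataset.range(9).batch(2).as_numpy_iterator())
steps = [batches[i:i + 2] for i in range(0, len(batches), 2)]
for step in steps:
  replica_0 = step[0].tolist()
  replica_1 = step[1].tolist() if len(step) > 1 else []  # no data left
  print([replica_0, replica_1])
# Prints [[0, 1], [2, 3]], then [[4, 5], [6, 7]], then [[8], []].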

View File

@@ -21,21 +21,14 @@ from __future__ import print_function
 from tensorflow.python.data.experimental.ops import batching
 from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.data.ops import multi_device_iterator_ops
-from tensorflow.python.data.util import structure
 from tensorflow.python.distribute import device_util
 from tensorflow.python.distribute import distribution_strategy_context
 from tensorflow.python.distribute import input_ops
 from tensorflow.python.distribute import values
 from tensorflow.python.eager import context
-from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import device as tf_device
 from tensorflow.python.framework import ops
-from tensorflow.python.framework import tensor_shape
 from tensorflow.python.framework import tensor_util
-from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import control_flow_ops
-from tensorflow.python.ops import math_ops
-from tensorflow.python.util import nest


 class InputWorkers(object):
@@ -133,7 +126,6 @@ class InputIteratorImpl(InputIterator):
   def get_next(self, name=None):
     """Returns the next input from the iterator for all replicas."""
     replicas = []
-    worker_has_values = []
     for i, worker in enumerate(self._input_workers.worker_devices):
       if name is not None:
         d = tf_device.DeviceSpec.from_string(worker)
@@ -141,61 +133,8 @@ class InputIteratorImpl(InputIterator):
       else:
         new_name = None
       with ops.device(worker):
-        worker_has_value, next_element = (
-            self._iterators[i].get_next_as_list(new_name))
-        worker_has_values.append(worker_has_value)
         # Make `replicas` a flat list of values across all replicas.
-        replicas.append(next_element)
-
-    out_of_range_replicas = []
-
-    def out_of_range_fn(worker_index, device):
-      """This function will throw an OutOfRange error."""
-      # As this will be only called when there is no data left, so calling
-      # get_next() will trigger an OutOfRange error.
-      data = self._iterators[worker_index].get_next(device)
-      out_of_range_replicas.append(data)
-      return data
-
-    # `global_has_value` indicates whether there is data in this global batch.
-    # We do a all-reduce across all the workers in the multi-worker case.
-    # TODO(b/126259107): Do strategy.reduce for CollectiveAllReduceStrategy.
-    if len(worker_has_values) > 1:
-      with ops.device(self._input_workers.compute_devices_for_worker(0)[0]):
-        # Place the tf.reduce_any op in device 0 to minimize communication
-        # cost.
-        # TODO(b/128545270): Investigate why placing it on worker 0 will cause
-        # the entire data to copy back from device to host.
-        global_has_value = math_ops.reduce_any(worker_has_values)
-    else:
-      global_has_value = worker_has_values[0]
-
-    results = []
-    for i, worker in enumerate(self._input_workers.worker_devices):
-      with ops.device(worker):
-        devices = self._input_workers.compute_devices_for_worker(i)
-        for j, device in enumerate(devices):
-          with ops.device(device):
-            # pylint: disable=undefined-loop-variable
-            # pylint: disable=cell-var-from-loop
-            # It is fine for the lambda to capture variables from the loop as
-            # the lambda is executed in the loop as well.
-            result = control_flow_ops.cond(global_has_value,
-                                           lambda: replicas[i][j],
-                                           lambda: out_of_range_fn(i, device))
-            # pylint: enable=cell-var-from-loop
-            # pylint: enable=undefined-loop-variable
-            results.append(result)
-    replicas = results
-
-    # Some dimensions in `replicas` will become unknown after we conditionally
-    # return the real tensors or the dummy tensors. We fix the input shapes by
-    # using the shapes from `out_of_range_replicas` because it is calling
-    # get_next() inside.
-    flattened_replicas = nest.flatten(replicas)
-    for i, replica_data in enumerate(nest.flatten(out_of_range_replicas)):
-      flattened_replicas[i].set_shape(replica_data.get_shape())
-    replicas = nest.pack_sequence_as(replicas, flattened_replicas)
+        replicas.extend(self._iterators[i].get_next_as_list(new_name))

     return values.regroup(self._input_workers.device_map, replicas)
@@ -326,45 +265,6 @@ class DatasetIterator(InputIteratorImpl):
     super(DatasetIterator, self).__init__(input_workers, iterators)


-def _dummy_tensor_fn(value_structure):
-  """A function to create dummy tensors from `value_structure`."""
-
-  def create_dummy_tensor(feature_shape, feature_type):
-    """Create a dummy tensor with possible batch dimensions set to 0."""
-    # Ideally we should set the batch dimension to 0, however as in
-    # DistributionStrategy we don't know the batch dimension, we try to
-    # guess it as much as possible. If the feature has unknown dimensions, we
-    # will set them to 0. If the feature shape is already static, we guess the
-    # first dimension as batch dimension and set it to 0.
-    dims = []
-    for dim in feature_shape.dims:
-      if dim.value is None:
-        dims.append(tensor_shape.Dimension(0))
-      else:
-        dims.append(dim)
-    if feature_shape.is_fully_defined() and dims:
-      dims[0] = tensor_shape.Dimension(0)
-
-    # Create the dummy tensor.
-    dummy_tensor = array_ops.zeros(tensor_shape.TensorShape(dims), feature_type)
-    return dummy_tensor
-
-  result = []
-  # pylint: disable=protected-access
-  for feature_shape, feature_type in zip(value_structure._flat_shapes,
-                                         value_structure._flat_types):
-    result.append(create_dummy_tensor(feature_shape, feature_type))
-  if isinstance(value_structure, structure.NestedStructure):
-    result = nest.pack_sequence_as(value_structure._nested_structure, result)
-  else:
-    result = result[0]
-  # pylint: enable=protected-access
-  return result
-
-
 class _SingleWorkerDatasetIterator(object):
   """Iterator for a single `tf.data.Dataset`."""
@@ -390,51 +290,12 @@ class _SingleWorkerDatasetIterator(object):
     self._iterator = multi_device_iterator_ops.MultiDeviceIterator(
         self._dataset, self._devices)

-  def get_next(self, device, name=None):
-    """Get next element for the given device."""
-    del name
-    with ops.device(self._worker):
-      return self._iterator.get_next(device)
-
   def get_next_as_list(self, name=None):
-    """Get next element from underlying iterator.
-
-    If there is no data left, a list of dummy tensors with possible batch
-    dimensions set to 0 will be returned.
-
-    Args:
-      name: not used.
-
-    Returns:
-      A boolean tensor indicates whether there is any data in next element and
-      the real data as the next element or a list of dummy tensors if no data
-      left.
-    """
+    """Get next element from the underlying iterator."""
     del name
     with ops.device(self._worker):
-      data_list = self._iterator.get_next_as_optional()
-      result = []
-      for i, data in enumerate(data_list):
-        # Place the condition op in the same device as the data so the data
-        # doesn't need to be sent back to the worker.
-        with ops.device(self._devices[i]):
-          # As MultiDeviceIterator will fetch data in order, so we only need to
-          # check if the first replica has value to see whether there is data
-          # left for this single worker.
-          if i == 0:
-            worker_has_value = data.has_value()
-
-          # pylint: disable=unnecessary-lambda
-          # pylint: disable=cell-var-from-loop
-          real_data = control_flow_ops.cond(
-              data.has_value(),
-              lambda: data.get_value(),
-              lambda: _dummy_tensor_fn(data.value_structure))
-          result.append(real_data)
-          # pylint: enable=cell-var-from-loop
-          # pylint: enable=unnecessary-lambda
-
-      return worker_has_value, result
+      data_list = self._iterator.get_next()
+      return data_list

   def initialize(self):
     """Initialze underlying iterator.
@@ -473,18 +334,12 @@ class _SingleWorkerCallableIterator(object):
     self._worker = worker
     self._devices = devices

-  def get_next(self, device, name=None):
-    """Get next element for the given device from the callable."""
-    del device, name
-    with ops.device(self._worker):
-      return self._fn()
-
   def get_next_as_list(self, name=None):
     """Get next element from the callable."""
     del name
     with ops.device(self._worker):
       data_list = [self._fn() for _ in self._devices]
-      return constant_op.constant(True), data_list
+      return data_list

   def initialize(self):
     # TODO(petebu) Should this throw an exception instead?
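
The removed code in this file let get_next() keep producing values after a worker ran out of data by checking an Optional and falling back to a zero-row dummy tensor (the role of the deleted _dummy_tensor_fn). The sketch below is a minimal, standalone illustration of that idea, not the rolled-back implementation; it uses plain tf.data in TF 2.x eager style and assumes a version where tf.data.Iterator.get_next_as_optional is available.

import tensorflow as tf

# Batches are [0, 1], [2, 3], [4]; the final one is smaller than the rest.
dataset = tf.data.Dataset.range(5).batch(2)
iterator = iter(dataset)

for _ in range(4):
  # Wrapping the next element in an Optional turns "no data left" into a
  # boolean tensor instead of an OutOfRangeError.
  opt = iterator.get_next_as_optional()
  if opt.has_value():
    print(True, opt.get_value().numpy())
  else:
    # The deleted code substituted a dummy tensor whose batch dimension is 0,
    # so every replica still received a value of the expected dtype and rank.
    print(False, tf.zeros([0], dtype=tf.int64).numpy())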