Merge pull request #31667 from tensorflow/ggadde-cp10

Improve NumPy to Dataset performance with vectorized shuffling.
Goldie Gadde, 2019-08-15 22:26:13 -07:00, committed by GitHub
commit d96ab53c6e
4 changed files with 305 additions and 59 deletions

Changed file 1 of 4

@@ -162,6 +162,13 @@ class DataAdapter(object):
     """
     raise NotImplementedError
 
+  def should_recreate_iterator(self, steps_per_epoch):
+    """Returns whether a new iterator should be created every epoch."""
+    # Only recreate iterator when the data has a fixed length, which will be
+    # fully consumed every epoch, or has a unknown length (dataset, generator)
+    # and will be fully consumed (steps_per_epoch is None)
+    return self.get_size() is not None or steps_per_epoch is None
+
 
 class TensorLikeDataAdapter(DataAdapter):
   """Adapter that handles Tensor-like objects, e.g. EagerTensor and NumPy."""
@@ -174,20 +181,174 @@ class TensorLikeDataAdapter(DataAdapter):
     if y is not None:
       flat_inputs += nest.flatten(y)
 
-    def _is_tensor_or_composite(v):
+    def _is_tensor(v):
       if isinstance(v, (ops.Tensor, np.ndarray)):
         return True
+      return False
+
+    return all(_is_tensor(v) for v in flat_inputs)
+
+  def __init__(self,
+               x,
+               y=None,
+               sample_weights=None,
+               batch_size=None,
+               epochs=1,
+               steps=None,
+               shuffle=False,
+               **kwargs):
+    super(TensorLikeDataAdapter, self).__init__(x, y, **kwargs)
+    x = _process_numpy_inputs(x)
+    y = _process_numpy_inputs(y)
+    sample_weights = _process_numpy_inputs(sample_weights)
+
+    # If sample_weights are not specified for an output use 1.0 as weights.
+    if sample_weights is not None and None in sample_weights:
+      weight = next(s for s in sample_weights if s is not None)
+      sample_weights = training_utils.list_to_tuple([
+          array_ops.ones((weight.shape[0],)) if sw is None else sw
+          for sw in sample_weights
+      ])
+
+    if y is not None and sample_weights is not None:
+      inputs = (x, y, sample_weights)
+    elif y is not None:
+      # Sample weight is only needed for training, so if y is None, then
+      # sample_weight is ignored.
+      inputs = (x, y)
+    else:
+      inputs = (x,)
+
+    num_samples = int(nest.flatten(x)[0].shape[0])
+
+    # If batch_size is not passed but steps is, calculate from the input data.
+    if steps and not batch_size:
+      batch_size = int(math.ceil(num_samples / steps))
+
+    if not batch_size:
+      raise ValueError(
+          "`batch_size` or `steps` is required for `Tensor` or `NumPy`"
+          " input data.")
+
+    self._size = int(math.ceil(num_samples / batch_size))
+    self._batch_size = batch_size
+    self._has_partial_batch = (self._size != (num_samples // batch_size))
+    self._partial_batch_size = None
+    if self._has_partial_batch:
+      self._partial_batch_size = (
+          num_samples - (self._size - 1) * self._batch_size)
+
+    # Vectorized version of shuffle.
+    # This is a performance improvement over using `from_tensor_slices`.
+    # The indices of the data are shuffled and batched, and these indices
+    # are then zipped with the data and used to extract a batch of the data
+    # at each step. The performance improvements here come from:
+    # 1. vectorized batch using gather
+    # 2. parallelized map
+    # 3. vectorized shuffle by using reshape and unbatch
+    # 4. disabled static optimizations
+    indices_list = []
+    for _ in range(epochs):
+      indices = np.arange(num_samples)
+      if shuffle:
+        np.random.shuffle(indices)
+
+      full_batch_indices = np.reshape(
+          indices[:(num_samples // batch_size) * batch_size], [-1, batch_size])
+      partial_batch_indices = indices[(num_samples // batch_size) * batch_size:]
+
+      epoch_indices_ds = dataset_ops.DatasetV2.from_tensors(
+          full_batch_indices).unbatch()
+      if partial_batch_indices.size:
+        epoch_indices_ds = epoch_indices_ds.concatenate(
+            dataset_ops.DatasetV2.from_tensors(partial_batch_indices))
+
+      indices_list.append(epoch_indices_ds)
+
+    indices_ds = dataset_ops.DatasetV2.from_tensor_slices(
+        indices_list).flat_map(lambda x: x)
+
+    data_ds = dataset_ops.DatasetV2.from_tensors(inputs).repeat()
+    dataset = dataset_ops.DatasetV2.zip((data_ds, indices_ds))
+
+    def _nested_grab_batch(data, indices):
+      """Grabs batches of Tensors in `data` based on `indices`."""
+
+      def _grab_batch(x):
+        """Grabs a batch of `x`."""
+        x_batch = array_ops.gather(x, indices)
+        x_shape = x.shape.as_list()
+
+        if not self._has_partial_batch:
+          # Recover the batch shape info.
+          x_shape[0] = self._batch_size
+          x_batch.set_shape(x_shape)
+        elif self._partial_batch_size >= num_samples:
+          # Only one batch per epoch.
+          x_shape[0] = self._partial_batch_size
+          x_batch.set_shape(x_shape)
+        return x_batch
+
+      return nest.map_structure(_grab_batch, data)
+
+    dataset = dataset.map(
+        _nested_grab_batch, num_parallel_calls=dataset_ops.AUTOTUNE)
+
+    # Default optimizations are disabled to avoid the overhead of (unnecessary)
+    # input pipeline graph serialization and deserialization
+    options = dataset_ops.Options()
+    options.experimental_optimization.apply_default_optimizations = False
+    dataset = dataset.with_options(options)
+    self._dataset = dataset
+
+  def get_dataset(self):
+    return self._dataset
+
+  def get_size(self):
+    return self._size
+
+  def batch_size(self):
+    return self._batch_size
+
+  def has_partial_batch(self):
+    return self._has_partial_batch
+
+  def partial_batch_size(self):
+    return self._partial_batch_size
+
+  def should_recreate_iterator(self, _):
+    # An infinite dataset is always created here.
+    return False
+
+
+class CompositeTensorDataAdapter(DataAdapter):
+  """Adapter that handles Tensor-like objects, e.g. EagerTensor and NumPy."""
+
+  @staticmethod
+  def can_handle(x, y=None):
+    flat_inputs = nest.flatten(x)
+    if y is not None:
+      flat_inputs += nest.flatten(y)
+
+    def _is_composite(v):
       # Dataset inherits from CompositeTensor but shouldn't be handled here.
       if (isinstance(v, composite_tensor.CompositeTensor) and
           not isinstance(v, dataset_ops.DatasetV2)):
         return True
       return False
-    return all(_is_tensor_or_composite(v) for v in flat_inputs)
+
+    def _is_tensor_or_composite(v):
+      if isinstance(v, (ops.Tensor, np.ndarray)):
+        return True
+      return _is_composite(v)
+
+    return (any(_is_composite(v) for v in flat_inputs) and
+            all(_is_tensor_or_composite(v) for v in flat_inputs))
 
   def __init__(self, x, y=None, sample_weights=None, batch_size=None,
                steps=None, shuffle=False, **kwargs):
-    super(TensorLikeDataAdapter, self).__init__(x, y, **kwargs)
+    super(CompositeTensorDataAdapter, self).__init__(x, y, **kwargs)
     x = _process_numpy_inputs(x)
     y = _process_numpy_inputs(y)
     sample_weights = _process_numpy_inputs(sample_weights)
@@ -431,9 +592,8 @@ class KerasSequenceAdapter(DataAdapter):
 
 ALL_ADAPTER_CLS = [
-    ListsOfScalarsDataAdapter,
-    TensorLikeDataAdapter, DatasetAdapter, GeneratorDataAdapter,
-    KerasSequenceAdapter
+    ListsOfScalarsDataAdapter, TensorLikeDataAdapter, DatasetAdapter,
+    GeneratorDataAdapter, KerasSequenceAdapter, CompositeTensorDataAdapter
 ]
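
Not part of the diff: the "Vectorized version of shuffle" comment block above is the core of this change, so here is a minimal standalone sketch of the same idea written against the public tf.data API rather than the internal dataset_ops module. The names in it (make_shuffled_dataset, features, labels) are illustrative, and it ignores sample weights and the partial-batch shape recovery done by the adapter.

import numpy as np
import tensorflow as tf

def make_shuffled_dataset(features, labels, batch_size, epochs):
  """Yields shuffled batches by gathering rows of the full arrays by index."""
  num_samples = features.shape[0]
  num_full = num_samples // batch_size
  indices_ds = None
  for _ in range(epochs):
    idx = np.random.permutation(num_samples)  # vectorized shuffle of indices
    # Build all full-batch index rows in one reshape instead of shuffling
    # and batching elements one by one.
    full = idx[:num_full * batch_size].reshape(-1, batch_size)
    epoch_ds = tf.data.Dataset.from_tensors(full).unbatch()
    remainder = idx[num_full * batch_size:]
    if remainder.size:
      epoch_ds = epoch_ds.concatenate(tf.data.Dataset.from_tensors(remainder))
    indices_ds = epoch_ds if indices_ds is None else indices_ds.concatenate(epoch_ds)
  # Zip one copy of the whole arrays with the stream of index batches,
  # then gather a batch of rows per step in a parallelized map.
  data_ds = tf.data.Dataset.from_tensors((features, labels)).repeat()
  return tf.data.Dataset.zip((data_ds, indices_ds)).map(
      lambda data, idx: tf.nest.map_structure(lambda t: tf.gather(t, idx), data),
      num_parallel_calls=tf.data.experimental.AUTOTUNE)

A call such as make_shuffled_dataset(x_train, y_train, batch_size=32, epochs=2) yields two epochs' worth of shuffled batches from a single dataset, which is what lets the adapter report should_recreate_iterator() as False.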

Changed file 2 of 4

@@ -18,12 +18,16 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import math
+
 from absl.testing import parameterized
 import numpy as np
 
 from tensorflow.python import keras
 from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.eager import context
 from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import ops
 from tensorflow.python.framework import test_util
 from tensorflow.python.keras.engine import data_adapter
 from tensorflow.python.keras.utils import data_utils
@@ -106,11 +110,13 @@ class TensorLikeDataAdapterTest(DataAdapterTestBase):
     self.assertTrue(adapter.has_partial_batch())
     self.assertEqual(adapter.partial_batch_size(), 2)
 
+  @test_util.run_in_graph_and_eager_modes
   def test_training_numpy(self):
-    dataset = self.adapter_cls(
-        self.numpy_input, self.numpy_target, batch_size=5).get_dataset()
+    if not context.executing_eagerly():
+      return  # Only test in eager.
     self.model.compile(loss='sparse_categorical_crossentropy', optimizer='sgd')
-    self.model.fit(dataset)
+    self.model.fit(self.numpy_input, self.numpy_target, batch_size=5)
 
   def test_can_handle(self):
     self.assertTrue(self.adapter_cls.can_handle(self.tensor_input))
@@ -121,11 +127,13 @@ class TensorLikeDataAdapterTest(DataAdapterTestBase):
     self.assertFalse(self.adapter_cls.can_handle(self.generator_input))
     self.assertFalse(self.adapter_cls.can_handle(self.sequence_input))
 
+  @test_util.run_in_graph_and_eager_modes
   def test_training(self):
-    dataset = self.adapter_cls(
-        self.tensor_input, self.tensor_target, batch_size=5).get_dataset()
+    if not context.executing_eagerly():
+      return  # Only test EagerTensors.
     self.model.compile(loss='sparse_categorical_crossentropy', optimizer='sgd')
-    self.model.fit(dataset)
+    self.model.fit(self.tensor_input, self.tensor_target, batch_size=5)
 
   def test_size(self):
     adapter = self.adapter_cls(
@@ -133,6 +141,39 @@ class TensorLikeDataAdapterTest(DataAdapterTestBase):
     self.assertEqual(adapter.get_size(), 10)
     self.assertFalse(adapter.has_partial_batch())
 
+  def test_shuffle_correctness(self):
+    with context.eager_mode():
+      num_samples = 100
+      batch_size = 32
+      x = np.arange(num_samples)
+      np.random.seed(99)
+
+      adapter = self.adapter_cls(
+          x, y=None, batch_size=batch_size, shuffle=True, epochs=2)
+
+      def _get_epoch(ds_iter):
+        ds_data = []
+        for _ in range(int(math.ceil(num_samples / batch_size))):
+          ds_data.append(next(ds_iter)[0].numpy())
+        return np.concatenate(ds_data)
+
+      ds_iter = iter(adapter.get_dataset())
+
+      # First epoch.
+      epoch_data = _get_epoch(ds_iter)
+      # Check that shuffling occurred.
+      self.assertNotAllClose(x, epoch_data)
+      # Check that each element appears, and only once.
+      self.assertAllClose(x, np.sort(epoch_data))
+
+      # Second epoch.
+      second_epoch_data = _get_epoch(ds_iter)
+      # Check that shuffling occurred.
+      self.assertNotAllClose(x, second_epoch_data)
+      # Check that shuffling is different across epochs.
+      self.assertNotAllClose(epoch_data, second_epoch_data)
+      # Check that each element appears, and only once.
+      self.assertAllClose(x, np.sort(second_epoch_data))
+
   @parameterized.named_parameters(
       ('batch_size_5', 5, None, 5),
       ('batch_size_50', 50, 4, 50),  # Sanity check: batch_size takes precedence
@@ -259,4 +300,5 @@ class KerasSequenceAdapterTest(DataAdapterTestBase):
 
 if __name__ == '__main__':
+  ops.enable_eager_execution()
   test.main()
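
Not part of the test file: the assertions in test_shuffle_correctness boil down to a simple NumPy invariant, sketched here on its own (no TensorFlow required, names are illustrative): sorting one epoch's worth of yielded samples must recover 0..num_samples-1 exactly once, while two shuffled epochs should differ from each other.

import numpy as np

def is_permutation(epoch_data, num_samples):
  # A true shuffle of range(num_samples) sorts back to 0..num_samples-1,
  # which proves every sample appeared exactly once in the epoch.
  return np.array_equal(np.sort(epoch_data), np.arange(num_samples))

rng = np.random.RandomState(99)
epoch_1, epoch_2 = rng.permutation(100), rng.permutation(100)
assert is_permutation(epoch_1, 100) and is_permutation(epoch_2, 100)
assert not np.array_equal(epoch_1, epoch_2)  # epochs are shuffled differently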

Changed file 3 of 4

@@ -1783,24 +1783,25 @@ class Model(network.Network):
     layers = super(Model, self).layers  # Avoids the override in Sequential.
     if layers:
       first_layer = layers[0]
+      # The per-replica static batch size.
       static_batch_size = training_utils.get_static_batch_size(first_layer)
       if static_batch_size is not None:
-        split_batch_size = self._distribution_strategy and \
+
+        # Determine number of times the user-supplied batch size will be split.
+        if (self._distribution_strategy and
             distributed_training_utils.global_batch_size_supported(
-                self._distribution_strategy)
-        if split_batch_size:
-          num_replicas = self._distribution_strategy.num_replicas_in_sync
+                self._distribution_strategy)):
+          num_splits_for_ds = self._distribution_strategy.num_replicas_in_sync
+        else:
+          num_splits_for_ds = 1
 
         # Check `batch_size` argument is consistent with InputLayer.
         if batch_size is not None:
-          if split_batch_size:
-            if batch_size % num_replicas != 0:
-              raise ValueError('The `batch_size` argument value {} cannot be '
-                               'divisible by number of replicas {}'.format(
-                                   batch_size, num_replicas))
-            per_replica_batch_size = batch_size // num_replicas
-          else:
-            per_replica_batch_size = batch_size
+          if batch_size % num_splits_for_ds != 0:
+            raise ValueError('The `batch_size` argument value {} cannot be '
+                             'divisible by number of replicas {}'.format(
+                                 batch_size, num_splits_for_ds))
+          per_replica_batch_size = batch_size // num_splits_for_ds
 
           if per_replica_batch_size != static_batch_size:
             raise ValueError('The `batch_size` argument value {} is '
@@ -1814,23 +1815,23 @@
           ds_batch_size = tensor_shape.as_dimension(
               nest.flatten(dataset_ops.get_legacy_output_shapes(x))[0][0]).value
           if ds_batch_size is not None:
-            if split_batch_size:
-              if ds_batch_size % num_replicas != 0:
-                raise ValueError(
-                    'The batch output shape of your `Dataset` {} '
-                    'cannot be divisible by number of replicas {}'.format(
-                        ds_batch_size, num_replicas))
-              ds_batch_size = ds_batch_size // num_replicas
+            if ds_batch_size % num_splits_for_ds != 0:
+              raise ValueError(
+                  'The batch output shape of your `Dataset` {} '
+                  'cannot be divisible by number of replicas {}'.format(
+                      ds_batch_size, num_splits_for_ds))
 
-            if ds_batch_size != static_batch_size:
+            ds_per_replica_batch_size = ds_batch_size // num_splits_for_ds
+            if ds_per_replica_batch_size != static_batch_size:
               raise ValueError('The batch output shape of your `Dataset` is '
                                '{}, which is incompatible with the specified '
                                'batch size of your Input Layer: {}'.format(
-                                   ds_batch_size, static_batch_size))
+                                   ds_per_replica_batch_size,
+                                   static_batch_size))
 
         # Set inferred batch size from the InputLayer.
         if steps is None:
-          batch_size = static_batch_size
+          batch_size = static_batch_size * num_splits_for_ds
 
     if batch_size is None and steps is None:
       # Backwards compatibility
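
Not part of the diff: the refactor above replaces the nested split_batch_size branches with a single num_splits_for_ds that defaults to 1, so the strategy and no-strategy cases share one code path. A minimal standalone sketch of that validation logic (function and variable names are illustrative):

def validate_global_batch_size(batch_size, static_batch_size, num_splits_for_ds):
  """Global batch must split evenly across replicas and match the InputLayer."""
  if batch_size % num_splits_for_ds != 0:
    raise ValueError('batch_size {} is not divisible by the {} replicas'.format(
        batch_size, num_splits_for_ds))
  per_replica_batch_size = batch_size // num_splits_for_ds
  if per_replica_batch_size != static_batch_size:
    raise ValueError('per-replica batch size {} does not match the InputLayer '
                     'static batch size {}'.format(per_replica_batch_size,
                                                   static_batch_size))
  return per_replica_batch_size

# e.g. a global batch of 64 over 8 replicas gives 8 per replica; without a
# strategy num_splits_for_ds is 1. When the batch size is inferred from the
# InputLayer instead, it becomes static_batch_size * num_splits_for_ds,
# i.e. the global batch again.
validate_global_batch_size(64, static_batch_size=8, num_splits_for_ds=8)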

Changed file 4 of 4

@@ -49,7 +49,8 @@ _ADAPTER_FOR_VALIDATION_SPLIT = [data_adapter.TensorLikeDataAdapter]
 # dataset/generate/sequence input will be peeked and processed by
 # model._standardize_user_data()
 _ADAPTER_FOR_STANDARDIZE_USER_DATA = [
-    data_adapter.TensorLikeDataAdapter, data_adapter.DatasetAdapter
+    data_adapter.TensorLikeDataAdapter, data_adapter.DatasetAdapter,
+    data_adapter.CompositeTensorDataAdapter
 ]
@@ -207,6 +208,7 @@ class Loop(training_utils.TrainingLoop):
           x,
           y,
           batch_size=batch_size,
+          epochs=epochs,
           sample_weights=sample_weight,
           class_weights=class_weight,
           validation_split=validation_split,
@@ -247,11 +249,8 @@
           model, ModeKeys.TRAIN)
 
       training_data_iter = None
-      # Only recreate iterator when the data has a fixed length, which will be
-      # fully consumed every epoch, or has a unknown length (dataset, generator)
-      # and will be fully consumed (steps_per_epoch is None)
-      recreate_training_iterator = (training_data_adapter.get_size() is not None
-                                    or steps_per_epoch is None)
+      recreate_training_iterator = (
+          training_data_adapter.should_recreate_iterator(steps_per_epoch))
 
       if do_validation:
         if not validation_steps:
@@ -474,11 +473,22 @@ def _get_distribution_strategy(model):
   return strategy
 
 
-def _process_training_inputs(model, x, y, batch_size=None,
-                             sample_weights=None, class_weights=None,
-                             steps_per_epoch=None, validation_split=0.,
-                             validation_data=None, validation_steps=None,
-                             shuffle=True, distribution_strategy=None):
+def _process_training_inputs(model,
+                             x,
+                             y,
+                             batch_size=None,
+                             epochs=1,
+                             sample_weights=None,
+                             class_weights=None,
+                             steps_per_epoch=None,
+                             validation_split=0.,
+                             validation_data=None,
+                             validation_steps=None,
+                             shuffle=True,
+                             distribution_strategy=None,
+                             max_queue_size=10,
+                             workers=1,
+                             use_multiprocessing=False):
   """Process the data input for fit() with respect to validation_split."""
   if validation_split and 0. < validation_split < 1. and validation_data:
     raise ValueError('validation_data and validation_split cannot be used '
@@ -508,19 +518,33 @@
      val_x, val_y,
      val_sample_weights) = training_utils.split_training_and_validation_data(
          x, y, sample_weights, validation_split)
-    train_adapter = adapter_cls(x, y, batch_size=batch_size,
-                                sample_weights=sample_weights, shuffle=shuffle,
-                                distribution_strategy=distribution_strategy)
+    train_adapter = adapter_cls(
+        x,
+        y,
+        batch_size=batch_size,
+        epochs=epochs,
+        sample_weights=sample_weights,
+        shuffle=shuffle,
+        distribution_strategy=distribution_strategy)
     val_adapter = adapter_cls(val_x, val_y,
                               sample_weights=val_sample_weights,
                               batch_size=batch_size,
                               distribution_strategy=distribution_strategy)
   else:
-    train_adapter = _process_inputs(model, x, y, sample_weights=sample_weights,
-                                    batch_size=batch_size,
-                                    class_weights=class_weights,
-                                    shuffle=shuffle, steps=steps_per_epoch,
-                                    distribution_strategy=distribution_strategy)
+    train_adapter = _process_inputs(
+        model,
+        x,
+        y,
+        sample_weights=sample_weights,
+        batch_size=batch_size,
+        epochs=epochs,
+        class_weights=class_weights,
+        shuffle=shuffle,
+        steps=steps_per_epoch,
+        distribution_strategy=distribution_strategy,
+        max_queue_size=max_queue_size,
+        workers=workers,
+        use_multiprocessing=use_multiprocessing)
     val_adapter = None
     if validation_data:
       (val_x, val_y,
@@ -543,9 +567,19 @@
   return train_adapter, val_adapter
 
 
-def _process_inputs(model, x, y, batch_size=None, sample_weights=None,
-                    class_weights=None, shuffle=False, steps=None,
-                    distribution_strategy=None):
+def _process_inputs(model,
+                    x,
+                    y,
+                    batch_size=None,
+                    epochs=1,
+                    sample_weights=None,
+                    class_weights=None,
+                    shuffle=False,
+                    steps=None,
+                    distribution_strategy=None,
+                    max_queue_size=10,
+                    workers=1,
+                    use_multiprocessing=False):
   """Process the inputs for fit/eval/predict()."""
   adapter_cls = data_adapter.select_data_adapter(x, y)
   if adapter_cls in _ADAPTER_FOR_STANDARDIZE_USER_DATA:
@@ -557,9 +591,18 @@
         batch_size=batch_size,
         check_steps=False,
         steps=steps)
-  adapter = adapter_cls(x, y, batch_size=batch_size, steps=steps,
-                        sample_weights=sample_weights, shuffle=shuffle,
-                        distribution_strategy=distribution_strategy)
+  adapter = adapter_cls(
+      x,
+      y,
+      batch_size=batch_size,
+      epochs=epochs,
+      steps=steps,
+      sample_weights=sample_weights,
+      shuffle=shuffle,
+      distribution_strategy=distribution_strategy,
+      max_queue_size=max_queue_size,
+      workers=workers,
+      use_multiprocessing=use_multiprocessing)
   # As a fallback for the data type that does not work with
   # _standardize_user_data, use the _prepare_model_with_inputs.
   if adapter_cls not in _ADAPTER_FOR_STANDARDIZE_USER_DATA:
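
Not part of the diff: a usage sketch of the end-to-end effect. With these changes, fitting on NumPy inputs routes through TensorLikeDataAdapter, which now receives `epochs` and builds one infinite dataset of pre-shuffled index batches, so the v2 training loop keeps a single iterator alive across epochs (should_recreate_iterator returns False). The toy model and data shapes below are illustrative.

import numpy as np
import tensorflow as tf

model = tf.keras.Sequential([
    tf.keras.layers.Dense(64, activation='relu'),
    tf.keras.layers.Dense(10, activation='softmax'),
])
model.compile(loss='sparse_categorical_crossentropy', optimizer='sgd')

x = np.random.random((1000, 20)).astype('float32')
y = np.random.randint(0, 10, size=(1000,))

# Shuffling and batching of the NumPy arrays happen inside the adapter's
# vectorized tf.data pipeline rather than sample by sample.
model.fit(x, y, batch_size=32, epochs=2, shuffle=True)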