diff --git a/tensorflow/python/keras/engine/data_adapter.py b/tensorflow/python/keras/engine/data_adapter.py
index 5a1fbb69937..28973bbf1f5 100644
--- a/tensorflow/python/keras/engine/data_adapter.py
+++ b/tensorflow/python/keras/engine/data_adapter.py
@@ -162,6 +162,13 @@ class DataAdapter(object):
     """
     raise NotImplementedError
 
+  def should_recreate_iterator(self, steps_per_epoch):
+    """Returns whether a new iterator should be created every epoch."""
+    # Only recreate iterator when the data has a fixed length, which will be
+    # fully consumed every epoch, or has an unknown length (dataset,
+    # generator) and will be fully consumed (steps_per_epoch is None).
+    return self.get_size() is not None or steps_per_epoch is None
+
 
 class TensorLikeDataAdapter(DataAdapter):
   """Adapter that handles Tensor-like objects, e.g. EagerTensor and NumPy."""
@@ -174,20 +181,174 @@ class TensorLikeDataAdapter(DataAdapter):
     if y is not None:
       flat_inputs += nest.flatten(y)
 
-    def _is_tensor_or_composite(v):
+    def _is_tensor(v):
       if isinstance(v, (ops.Tensor, np.ndarray)):
         return True
+      return False
+
+    return all(_is_tensor(v) for v in flat_inputs)
+
+  def __init__(self,
+               x,
+               y=None,
+               sample_weights=None,
+               batch_size=None,
+               epochs=1,
+               steps=None,
+               shuffle=False,
+               **kwargs):
+    super(TensorLikeDataAdapter, self).__init__(x, y, **kwargs)
+    x = _process_numpy_inputs(x)
+    y = _process_numpy_inputs(y)
+    sample_weights = _process_numpy_inputs(sample_weights)
+
+    # If sample_weights are not specified for an output, use 1.0 as weights.
+    if sample_weights is not None and None in sample_weights:
+      weight = next(s for s in sample_weights if s is not None)
+      sample_weights = training_utils.list_to_tuple([
+          array_ops.ones((weight.shape[0],)) if sw is None else sw
+          for sw in sample_weights
+      ])
+
+    if y is not None and sample_weights is not None:
+      inputs = (x, y, sample_weights)
+    elif y is not None:
+      # Sample weight is only needed for training, so if y is None, then
+      # sample_weight is ignored.
+      inputs = (x, y)
+    else:
+      inputs = (x,)
+
+    num_samples = int(nest.flatten(x)[0].shape[0])
+
+    # If batch_size is not passed but steps is, calculate from the input data.
+    if steps and not batch_size:
+      batch_size = int(math.ceil(num_samples / steps))
+
+    if not batch_size:
+      raise ValueError(
+          "`batch_size` or `steps` is required for `Tensor` or `NumPy`"
+          " input data.")
+
+    self._size = int(math.ceil(num_samples / batch_size))
+    self._batch_size = batch_size
+    self._has_partial_batch = (self._size != (num_samples // batch_size))
+
+    self._partial_batch_size = None
+    if self._has_partial_batch:
+      self._partial_batch_size = (
+          num_samples - (self._size - 1) * self._batch_size)
+
+    # Vectorized version of shuffle.
+    # This is a performance improvement over using `from_tensor_slices`.
+    # The indices of the data are shuffled and batched, and these indices
+    # are then zipped with the data and used to extract a batch of the data
+    # at each step. The performance improvements here come from:
+    #   1. vectorized batch using gather
+    #   2. parallelized map
+    #   3. vectorized shuffle by using reshape and unbatch
+    #   4. disabled static optimizations
+    indices_ds = None
+    for _ in range(epochs):
+      indices = np.arange(num_samples)
+      if shuffle:
+        np.random.shuffle(indices)
+
+      full_batch_indices = np.reshape(
+          indices[:(num_samples // batch_size) * batch_size], [-1, batch_size])
+      partial_batch_indices = indices[(num_samples // batch_size) * batch_size:]
+
+      epoch_indices_ds = dataset_ops.DatasetV2.from_tensors(
+          full_batch_indices).unbatch()
+      if partial_batch_indices.size:
+        epoch_indices_ds = epoch_indices_ds.concatenate(
+            dataset_ops.DatasetV2.from_tensors(partial_batch_indices))
+
+      if indices_ds is None:
+        indices_ds = epoch_indices_ds
+      else:
+        indices_ds = indices_ds.concatenate(epoch_indices_ds)
+
+    data_ds = dataset_ops.DatasetV2.from_tensors(inputs).repeat()
+    dataset = dataset_ops.DatasetV2.zip((data_ds, indices_ds))
+
+    def _nested_grab_batch(data, indices):
+      """Grabs batches of Tensors in `data` based on `indices`."""
+
+      def _grab_batch(x):
+        """Grabs a batch of `x`."""
+        x_batch = array_ops.gather(x, indices)
+        x_shape = x.shape.as_list()
+
+        if not self._has_partial_batch:
+          # Recover the batch shape info.
+          x_shape[0] = self._batch_size
+          x_batch.set_shape(x_shape)
+        elif self._partial_batch_size >= num_samples:
+          # Only one batch per epoch.
+          x_shape[0] = self._partial_batch_size
+          x_batch.set_shape(x_shape)
+        return x_batch
+
+      return nest.map_structure(_grab_batch, data)
+
+    dataset = dataset.map(
+        _nested_grab_batch, num_parallel_calls=dataset_ops.AUTOTUNE)
+
+    # Default optimizations are disabled to avoid the overhead of
+    # (unnecessary) input pipeline graph serialization and deserialization.
+    options = dataset_ops.Options()
+    options.experimental_optimization.apply_default_optimizations = False
+    dataset = dataset.with_options(options)
+    self._dataset = dataset
+
+  def get_dataset(self):
+    return self._dataset
+
+  def get_size(self):
+    return self._size
+
+  def batch_size(self):
+    return self._batch_size
+
+  def has_partial_batch(self):
+    return self._has_partial_batch
+
+  def partial_batch_size(self):
+    return self._partial_batch_size
+
+  def should_recreate_iterator(self, _):
+    # An infinite dataset is always created here.
+    return False
+
+
+class CompositeTensorDataAdapter(DataAdapter):
+  """Adapter that handles composite tensors, e.g. SparseTensor."""
+
+  @staticmethod
+  def can_handle(x, y=None):
+    flat_inputs = nest.flatten(x)
+    if y is not None:
+      flat_inputs += nest.flatten(y)
+
+    def _is_composite(v):
       # Dataset inherits from CompositeTensor but shouldn't be handled here.
       if (isinstance(v, composite_tensor.CompositeTensor) and
           not isinstance(v, dataset_ops.DatasetV2)):
         return True
       return False
 
-    return all(_is_tensor_or_composite(v) for v in flat_inputs)
+    def _is_tensor_or_composite(v):
+      if isinstance(v, (ops.Tensor, np.ndarray)):
+        return True
+      return _is_composite(v)
+
+    return (any(_is_composite(v) for v in flat_inputs) and
+            all(_is_tensor_or_composite(v) for v in flat_inputs))
 
   def __init__(self, x, y=None, sample_weights=None, batch_size=None,
                steps=None, shuffle=False, **kwargs):
-    super(TensorLikeDataAdapter, self).__init__(x, y, **kwargs)
+    super(CompositeTensorDataAdapter, self).__init__(x, y, **kwargs)
     x = _process_numpy_inputs(x)
     y = _process_numpy_inputs(y)
     sample_weights = _process_numpy_inputs(sample_weights)
@@ -431,9 +592,8 @@ class KerasSequenceAdapter(DataAdapter):
 
 
 ALL_ADAPTER_CLS = [
-    ListsOfScalarsDataAdapter,
-    TensorLikeDataAdapter, DatasetAdapter, GeneratorDataAdapter,
-    KerasSequenceAdapter
+    ListsOfScalarsDataAdapter, TensorLikeDataAdapter, DatasetAdapter,
+    GeneratorDataAdapter, KerasSequenceAdapter, CompositeTensorDataAdapter
 ]
 
 
diff --git a/tensorflow/python/keras/engine/data_adapter_test.py b/tensorflow/python/keras/engine/data_adapter_test.py
index 8f5fe16acdc..525f99217b4 100644
--- a/tensorflow/python/keras/engine/data_adapter_test.py
+++ b/tensorflow/python/keras/engine/data_adapter_test.py
@@ -18,11 +18,14 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import math
+
 from absl.testing import parameterized
 import numpy as np
 
 from tensorflow.python import keras
 from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.eager import context
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import test_util
 from tensorflow.python.keras.engine import data_adapter
@@ -133,6 +136,39 @@ class TensorLikeDataAdapterTest(DataAdapterTestBase):
     self.assertEqual(adapter.get_size(), 10)
     self.assertFalse(adapter.has_partial_batch())
 
+  def test_shuffle_correctness(self):
+    with context.eager_mode():
+      num_samples = 100
+      batch_size = 32
+      x = np.arange(num_samples)
+      np.random.seed(99)
+      adapter = self.adapter_cls(
+          x, y=None, batch_size=batch_size, shuffle=True, epochs=2)
+
+      def _get_epoch(ds_iter):
+        ds_data = []
+        for _ in range(int(math.ceil(num_samples / batch_size))):
+          ds_data.append(next(ds_iter)[0].numpy())
+        return np.concatenate(ds_data)
+
+      ds_iter = iter(adapter.get_dataset())
+
+      # First epoch.
+      epoch_data = _get_epoch(ds_iter)
+      # Check that shuffling occurred.
+      self.assertNotAllClose(x, epoch_data)
+      # Check that each element appears, and only once.
+      self.assertAllClose(x, np.sort(epoch_data))
+
+      # Second epoch.
+      second_epoch_data = _get_epoch(ds_iter)
+      # Check that shuffling occurred.
+      self.assertNotAllClose(x, second_epoch_data)
+      # Check that shuffling is different across epochs.
+      self.assertNotAllClose(epoch_data, second_epoch_data)
+      # Check that each element appears, and only once.
+      self.assertAllClose(x, np.sort(second_epoch_data))
+
   @parameterized.named_parameters(
       ('batch_size_5', 5, None, 5),
       ('batch_size_50', 50, 4, 50),  # Sanity check: batch_size takes precedence
diff --git a/tensorflow/python/keras/engine/training.py b/tensorflow/python/keras/engine/training.py
index 986198800b2..d7ee4341d4d 100644
--- a/tensorflow/python/keras/engine/training.py
+++ b/tensorflow/python/keras/engine/training.py
@@ -1783,24 +1783,25 @@ class Model(network.Network):
     layers = super(Model, self).layers  # Avoids the override in Sequential.
     if layers:
       first_layer = layers[0]
+      # The per-replica static batch size.
       static_batch_size = training_utils.get_static_batch_size(first_layer)
       if static_batch_size is not None:
-        split_batch_size = self._distribution_strategy and \
+
+        # Determine number of times the user-supplied batch size will be split.
+        if (self._distribution_strategy and
             distributed_training_utils.global_batch_size_supported(
-                self._distribution_strategy)
-        if split_batch_size:
-          num_replicas = self._distribution_strategy.num_replicas_in_sync
+                self._distribution_strategy)):
+          num_splits_for_ds = self._distribution_strategy.num_replicas_in_sync
+        else:
+          num_splits_for_ds = 1
 
         # Check `batch_size` argument is consistent with InputLayer.
         if batch_size is not None:
-          if split_batch_size:
-            if batch_size % num_replicas != 0:
-              raise ValueError('The `batch_size` argument value {} cannot be '
-                               'divisible by number of replicas {}'.format(
-                                   batch_size, num_replicas))
-            per_replica_batch_size = batch_size // num_replicas
-          else:
-            per_replica_batch_size = batch_size
+          if batch_size % num_splits_for_ds != 0:
+            raise ValueError('The `batch_size` argument value {} cannot be '
+                             'divisible by number of replicas {}'.format(
+                                 batch_size, num_splits_for_ds))
+          per_replica_batch_size = batch_size // num_splits_for_ds
 
           if per_replica_batch_size != static_batch_size:
             raise ValueError('The `batch_size` argument value {} is '
@@ -1814,23 +1815,23 @@ class Model(network.Network):
           ds_batch_size = tensor_shape.as_dimension(
               nest.flatten(dataset_ops.get_legacy_output_shapes(x))[0][0]).value
           if ds_batch_size is not None:
-            if split_batch_size:
-              if ds_batch_size % num_replicas != 0:
-                raise ValueError(
-                    'The batch output shape of your `Dataset` {} '
-                    'cannot be divisible by number of replicas {}'.format(
-                        ds_batch_size, num_replicas))
-              ds_batch_size = ds_batch_size // num_replicas
+            if ds_batch_size % num_splits_for_ds != 0:
+              raise ValueError(
+                  'The batch output shape of your `Dataset` {} '
+                  'cannot be divisible by number of replicas {}'.format(
+                      ds_batch_size, num_splits_for_ds))
 
-            if ds_batch_size != static_batch_size:
+            ds_per_replica_batch_size = ds_batch_size // num_splits_for_ds
+            if ds_per_replica_batch_size != static_batch_size:
               raise ValueError('The batch output shape of your `Dataset` is '
                                '{}, which is incompatible with the specified '
                                'batch size of your Input Layer: {}'.format(
-                                   ds_batch_size, static_batch_size))
+                                   ds_per_replica_batch_size,
+                                   static_batch_size))
 
         # Set inferred batch size from the InputLayer.
         if steps is None:
-          batch_size = static_batch_size
+          batch_size = static_batch_size * num_splits_for_ds
 
     if batch_size is None and steps is None:
       # Backwards compatibility
diff --git a/tensorflow/python/keras/engine/training_v2.py b/tensorflow/python/keras/engine/training_v2.py
index a576d69ec17..6c4fa930ca5 100644
--- a/tensorflow/python/keras/engine/training_v2.py
+++ b/tensorflow/python/keras/engine/training_v2.py
@@ -49,7 +49,8 @@ _ADAPTER_FOR_VALIDATION_SPLIT = [data_adapter.TensorLikeDataAdapter]
 # dataset/generate/sequence input will be peeked and processed by
 # model._standardize_user_data()
 _ADAPTER_FOR_STANDARDIZE_USER_DATA = [
-    data_adapter.TensorLikeDataAdapter, data_adapter.DatasetAdapter
+    data_adapter.TensorLikeDataAdapter, data_adapter.DatasetAdapter,
+    data_adapter.CompositeTensorDataAdapter
 ]
 
 
@@ -207,6 +208,7 @@ class Loop(training_utils.TrainingLoop):
         x,
         y,
         batch_size=batch_size,
+        epochs=epochs,
         sample_weights=sample_weight,
         class_weights=class_weight,
         validation_split=validation_split,
@@ -247,11 +249,8 @@ class Loop(training_utils.TrainingLoop):
           model, ModeKeys.TRAIN)
 
       training_data_iter = None
-      # Only recreate iterator when the data has a fixed length, which will be
-      # fully consumed every epoch, or has a unknown length (dataset, generator)
-      # and will be fully consumed (steps_per_epoch is None)
-      recreate_training_iterator = (training_data_adapter.get_size() is not None
-                                    or steps_per_epoch is None)
+      recreate_training_iterator = (
+          training_data_adapter.should_recreate_iterator(steps_per_epoch))
 
       if do_validation:
         if not validation_steps:
@@ -474,11 +473,22 @@ def _get_distribution_strategy(model):
   return strategy
 
 
-def _process_training_inputs(model, x, y, batch_size=None,
-                             sample_weights=None, class_weights=None,
-                             steps_per_epoch=None, validation_split=0.,
-                             validation_data=None, validation_steps=None,
-                             shuffle=True, distribution_strategy=None):
+def _process_training_inputs(model,
+                             x,
+                             y,
+                             batch_size=None,
+                             epochs=1,
+                             sample_weights=None,
+                             class_weights=None,
+                             steps_per_epoch=None,
+                             validation_split=0.,
+                             validation_data=None,
+                             validation_steps=None,
+                             shuffle=True,
+                             distribution_strategy=None,
+                             max_queue_size=10,
+                             workers=1,
+                             use_multiprocessing=False):
   """Process the data input for fit() with respect to validation_split."""
   if validation_split and 0. < validation_split < 1. and validation_data:
     raise ValueError('validation_data and validation_split cannot be used '
@@ -508,19 +518,33 @@ def _process_training_inputs(model, x, y, batch_size=None,
      val_x, val_y,
      val_sample_weights) = training_utils.split_training_and_validation_data(
          x, y, sample_weights, validation_split)
-    train_adapter = adapter_cls(x, y, batch_size=batch_size,
-                                sample_weights=sample_weights, shuffle=shuffle,
-                                distribution_strategy=distribution_strategy)
+    train_adapter = adapter_cls(
+        x,
+        y,
+        batch_size=batch_size,
+        epochs=epochs,
+        sample_weights=sample_weights,
+        shuffle=shuffle,
+        distribution_strategy=distribution_strategy)
     val_adapter = adapter_cls(val_x, val_y,
                               sample_weights=val_sample_weights,
                               batch_size=batch_size,
                               distribution_strategy=distribution_strategy)
   else:
-    train_adapter = _process_inputs(model, x, y, sample_weights=sample_weights,
-                                    batch_size=batch_size,
-                                    class_weights=class_weights,
-                                    shuffle=shuffle, steps=steps_per_epoch,
-                                    distribution_strategy=distribution_strategy)
+    train_adapter = _process_inputs(
+        model,
+        x,
+        y,
+        sample_weights=sample_weights,
+        batch_size=batch_size,
+        epochs=epochs,
+        class_weights=class_weights,
+        shuffle=shuffle,
+        steps=steps_per_epoch,
+        distribution_strategy=distribution_strategy,
+        max_queue_size=max_queue_size,
+        workers=workers,
+        use_multiprocessing=use_multiprocessing)
     val_adapter = None
     if validation_data:
       (val_x, val_y,
@@ -543,9 +567,19 @@ def _process_training_inputs(model, x, y, batch_size=None,
   return train_adapter, val_adapter
 
 
-def _process_inputs(model, x, y, batch_size=None, sample_weights=None,
-                    class_weights=None, shuffle=False, steps=None,
-                    distribution_strategy=None):
+def _process_inputs(model,
+                    x,
+                    y,
+                    batch_size=None,
+                    epochs=1,
+                    sample_weights=None,
+                    class_weights=None,
+                    shuffle=False,
+                    steps=None,
+                    distribution_strategy=None,
+                    max_queue_size=10,
+                    workers=1,
+                    use_multiprocessing=False):
   """Process the inputs for fit/eval/predict()."""
   adapter_cls = data_adapter.select_data_adapter(x, y)
   if adapter_cls in _ADAPTER_FOR_STANDARDIZE_USER_DATA:
@@ -557,9 +591,18 @@ def _process_inputs(model, x, y, batch_size=None, sample_weights=None,
         batch_size=batch_size,
         check_steps=False,
         steps=steps)
-  adapter = adapter_cls(x, y, batch_size=batch_size, steps=steps,
-                        sample_weights=sample_weights, shuffle=shuffle,
-                        distribution_strategy=distribution_strategy)
+  adapter = adapter_cls(
+      x,
+      y,
+      batch_size=batch_size,
+      epochs=epochs,
+      steps=steps,
+      sample_weights=sample_weights,
+      shuffle=shuffle,
+      distribution_strategy=distribution_strategy,
+      max_queue_size=max_queue_size,
+      workers=workers,
+      use_multiprocessing=use_multiprocessing)
   # As a fallback for the data type that does not work with
   # _standardize_user_data, use the _prepare_model_with_inputs.
   if adapter_cls not in _ADAPTER_FOR_STANDARDIZE_USER_DATA:
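
Note (illustration only, not part of the patch): the core idea in the new TensorLikeDataAdapter above is to pre-shuffle an index array once per epoch, turn the index batches into a dataset via from_tensors(...).unbatch(), zip that with a single from_tensors copy of the data, and gather the selected rows at each step. The following minimal standalone sketch mirrors that technique with only the public tf.data API; it assumes TF 2.x eager execution, and the helper name make_gather_batched_dataset is hypothetical (it does not appear in the patch).

# Sketch of the gather-based, pre-shuffled batching used by the adapter above.
# Assumes TF 2.x; names here are illustrative and not part of the patch.
import numpy as np
import tensorflow as tf


def make_gather_batched_dataset(x, batch_size, epochs, shuffle=True):
  """Yields `epochs` passes over `x`, batched by gathering shuffled indices."""
  num_samples = x.shape[0]
  num_full = num_samples // batch_size

  indices_ds = None
  for _ in range(epochs):
    indices = np.arange(num_samples)
    if shuffle:
      np.random.shuffle(indices)
    # All full batches form one [num_full, batch_size] tensor that is
    # unbatched, so no per-element shuffle buffer is needed.
    full = np.reshape(indices[:num_full * batch_size], [-1, batch_size])
    epoch_ds = tf.data.Dataset.from_tensors(full).unbatch()
    partial = indices[num_full * batch_size:]
    if partial.size:
      epoch_ds = epoch_ds.concatenate(tf.data.Dataset.from_tensors(partial))
    indices_ds = (epoch_ds if indices_ds is None
                  else indices_ds.concatenate(epoch_ds))

  # The whole input is materialized once and repeated; each step only gathers
  # the rows named by the current index batch.
  data_ds = tf.data.Dataset.from_tensors(x).repeat()
  return tf.data.Dataset.zip((data_ds, indices_ds)).map(
      lambda data, idx: tf.gather(data, idx),
      num_parallel_calls=tf.data.experimental.AUTOTUNE)


# 10 samples, batch size 4, 2 epochs -> 2 full batches + 1 partial per epoch.
for batch in make_gather_batched_dataset(np.arange(10), 4, epochs=2):
  print(batch.numpy())

Compared with from_tensor_slices(x).shuffle(buffer).batch(batch_size), this keeps the data as one tensor and replaces the per-element shuffle buffer with a vectorized shuffle of indices, which is the performance motivation listed in the adapter's comment block.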