Merge pull request #31667 from tensorflow/ggadde-cp10
Improve NumPy to Dataset performance with vectorized shuffling.
Commit: d96ab53c6e
@@ -162,6 +162,13 @@ class DataAdapter(object):
     """
     raise NotImplementedError
 
+  def should_recreate_iterator(self, steps_per_epoch):
+    """Returns whether a new iterator should be created every epoch."""
+    # Only recreate the iterator when the data has a fixed length, which will
+    # be fully consumed every epoch, or has an unknown length (dataset,
+    # generator) and will be fully consumed (steps_per_epoch is None).
+    return self.get_size() is not None or steps_per_epoch is None
+
 
 class TensorLikeDataAdapter(DataAdapter):
   """Adapter that handles Tensor-like objects, e.g. EagerTensor and NumPy."""
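The new default policy reduces to a small truth table. As a plain-Python sketch (illustrative only, not the TF source itself):

    # Recreate the iterator between epochs only when the epoch boundary is
    # well defined: either the data has a known fixed length, or an
    # unknown-length source will be run to exhaustion (no steps_per_epoch cap).
    def should_recreate_iterator(size, steps_per_epoch):
      return size is not None or steps_per_epoch is None

    assert should_recreate_iterator(size=10, steps_per_epoch=None)
    assert should_recreate_iterator(size=10, steps_per_epoch=5)
    assert should_recreate_iterator(size=None, steps_per_epoch=None)
    assert not should_recreate_iterator(size=None, steps_per_epoch=5)  # reuse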
@@ -174,20 +181,174 @@ class TensorLikeDataAdapter(DataAdapter):
     if y is not None:
       flat_inputs += nest.flatten(y)
 
-    def _is_tensor_or_composite(v):
+    def _is_tensor(v):
       if isinstance(v, (ops.Tensor, np.ndarray)):
         return True
+      return False
+
+    return all(_is_tensor(v) for v in flat_inputs)
+
+  def __init__(self,
+               x,
+               y=None,
+               sample_weights=None,
+               batch_size=None,
+               epochs=1,
+               steps=None,
+               shuffle=False,
+               **kwargs):
+    super(TensorLikeDataAdapter, self).__init__(x, y, **kwargs)
+    x = _process_numpy_inputs(x)
+    y = _process_numpy_inputs(y)
+    sample_weights = _process_numpy_inputs(sample_weights)
+
+    # If sample_weights are not specified for an output, use 1.0 as weights.
+    if sample_weights is not None and None in sample_weights:
+      weight = next(s for s in sample_weights if s is not None)
+      sample_weights = training_utils.list_to_tuple([
+          array_ops.ones((weight.shape[0],)) if sw is None else sw
+          for sw in sample_weights
+      ])
+
+    if y is not None and sample_weights is not None:
+      inputs = (x, y, sample_weights)
+    elif y is not None:
+      # Sample weights are only needed for training, so if y is None, then
+      # sample_weights is ignored.
+      inputs = (x, y)
+    else:
+      inputs = (x,)
+
+    num_samples = int(nest.flatten(x)[0].shape[0])
+
+    # If batch_size is not passed but steps is, calculate from the input data.
+    if steps and not batch_size:
+      batch_size = int(math.ceil(num_samples / steps))
+
+    if not batch_size:
+      raise ValueError(
+          "`batch_size` or `steps` is required for `Tensor` or `NumPy`"
+          " input data.")
+
+    self._size = int(math.ceil(num_samples / batch_size))
+    self._batch_size = batch_size
+    self._has_partial_batch = (self._size != (num_samples // batch_size))
+
+    self._partial_batch_size = None
+    if self._has_partial_batch:
+      self._partial_batch_size = (
+          num_samples - (self._size - 1) * self._batch_size)
+
+    # Vectorized version of shuffle.
+    # This is a performance improvement over using `from_tensor_slices`.
+    # The indices of the data are shuffled and batched, and these indices
+    # are then zipped with the data and used to extract a batch of the data
+    # at each step. The performance improvements here come from:
+    # 1. vectorized batch using gather
+    # 2. parallelized map
+    # 3. vectorized shuffle by using reshape and unbatch
+    # 4. disabled static optimizations
+    indices_list = []
+    for _ in range(epochs):
+      indices = np.arange(num_samples)
+      if shuffle:
+        np.random.shuffle(indices)
+
+      full_batch_indices = np.reshape(
+          indices[:(num_samples // batch_size) * batch_size], [-1, batch_size])
+      partial_batch_indices = indices[(num_samples // batch_size) * batch_size:]
+
+      epoch_indices_ds = dataset_ops.DatasetV2.from_tensors(
+          full_batch_indices).unbatch()
+      if partial_batch_indices.size:
+        epoch_indices_ds = epoch_indices_ds.concatenate(
+            dataset_ops.DatasetV2.from_tensors(partial_batch_indices))
+
+      indices_list.append(epoch_indices_ds)
+
+    indices_ds = dataset_ops.DatasetV2.from_tensor_slices(
+        indices_list).flat_map(lambda x: x)
+
+    data_ds = dataset_ops.DatasetV2.from_tensors(inputs).repeat()
+    dataset = dataset_ops.DatasetV2.zip((data_ds, indices_ds))
+
+    def _nested_grab_batch(data, indices):
+      """Grabs batches of Tensors in `data` based on `indices`."""
+
+      def _grab_batch(x):
+        """Grabs a batch of `x`."""
+        x_batch = array_ops.gather(x, indices)
+        x_shape = x.shape.as_list()
+
+        if not self._has_partial_batch:
+          # Recover the batch shape info.
+          x_shape[0] = self._batch_size
+          x_batch.set_shape(x_shape)
+        elif self._partial_batch_size >= num_samples:
+          # Only one batch per epoch.
+          x_shape[0] = self._partial_batch_size
+          x_batch.set_shape(x_shape)
+        return x_batch
+
+      return nest.map_structure(_grab_batch, data)
+
+    dataset = dataset.map(
+        _nested_grab_batch, num_parallel_calls=dataset_ops.AUTOTUNE)
+
+    # Default optimizations are disabled to avoid the overhead of (unnecessary)
+    # input pipeline graph serialization and deserialization.
+    options = dataset_ops.Options()
+    options.experimental_optimization.apply_default_optimizations = False
+    dataset = dataset.with_options(options)
+    self._dataset = dataset
+
+  def get_dataset(self):
+    return self._dataset
+
+  def get_size(self):
+    return self._size
+
+  def batch_size(self):
+    return self._batch_size
+
+  def has_partial_batch(self):
+    return self._has_partial_batch
+
+  def partial_batch_size(self):
+    return self._partial_batch_size
+
+  def should_recreate_iterator(self, _):
+    # An infinite dataset is always created here.
+    return False
+
+
+class CompositeTensorDataAdapter(DataAdapter):
+  """Adapter that handles composite tensor inputs, e.g. SparseTensor."""
+
+  @staticmethod
+  def can_handle(x, y=None):
+    flat_inputs = nest.flatten(x)
+    if y is not None:
+      flat_inputs += nest.flatten(y)
+
+    def _is_composite(v):
       # Dataset inherits from CompositeTensor but shouldn't be handled here.
       if (isinstance(v, composite_tensor.CompositeTensor) and
           not isinstance(v, dataset_ops.DatasetV2)):
         return True
       return False
 
-    return all(_is_tensor_or_composite(v) for v in flat_inputs)
+    def _is_tensor_or_composite(v):
+      if isinstance(v, (ops.Tensor, np.ndarray)):
+        return True
+      return _is_composite(v)
+
+    return (any(_is_composite(v) for v in flat_inputs) and
+            all(_is_tensor_or_composite(v) for v in flat_inputs))
 
   def __init__(self, x, y=None, sample_weights=None, batch_size=None,
                steps=None, shuffle=False, **kwargs):
-    super(TensorLikeDataAdapter, self).__init__(x, y, **kwargs)
+    super(CompositeTensorDataAdapter, self).__init__(x, y, **kwargs)
     x = _process_numpy_inputs(x)
     y = _process_numpy_inputs(y)
     sample_weights = _process_numpy_inputs(sample_weights)
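The four optimizations listed in the comment block above can be seen in isolation in the following sketch, written against the public TF 2.x API (tf.data.Dataset standing in for dataset_ops.DatasetV2; sample-weight handling, static-shape recovery, and the options tweak are omitted, and the helper name make_shuffled_dataset is ours):

    import numpy as np
    import tensorflow as tf

    def make_shuffled_dataset(features, labels, batch_size, epochs=1):
      num_samples = int(features.shape[0])
      # One dataset element holds ALL the data; batches are gathered out of it.
      data_ds = tf.data.Dataset.from_tensors((features, labels)).repeat()

      per_epoch = []
      for _ in range(epochs):
        indices = np.random.permutation(num_samples)       # vectorized shuffle
        num_full = (num_samples // batch_size) * batch_size
        full = indices[:num_full].reshape(-1, batch_size)  # one row per batch
        # unbatch() splits the stacked rows into per-batch index vectors.
        epoch_ds = tf.data.Dataset.from_tensors(full).unbatch()
        remainder = indices[num_full:]
        if remainder.size:
          epoch_ds = epoch_ds.concatenate(
              tf.data.Dataset.from_tensors(remainder))
        per_epoch.append(epoch_ds)

      # A dataset of per-epoch index datasets, flattened into one stream.
      indices_ds = tf.data.Dataset.from_tensor_slices(per_epoch).flat_map(
          lambda ds: ds)

      def grab_batch(data, batch_indices):
        # Vectorized batching: one gather per tensor, not per-example slicing.
        return tf.nest.map_structure(
            lambda t: tf.gather(t, batch_indices), data)

      return tf.data.Dataset.zip((data_ds, indices_ds)).map(
          grab_batch, num_parallel_calls=tf.data.experimental.AUTOTUNE)

The design point: shuffling happens once per epoch, in NumPy, over cheap integer indices, and each batch is then materialized with a single gather, instead of the per-example slicing and buffered shuffling that a `from_tensor_slices(...).shuffle(...).batch(...)` pipeline implies.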
@@ -431,9 +592,8 @@ class KerasSequenceAdapter(DataAdapter):
 
 
 ALL_ADAPTER_CLS = [
-    ListsOfScalarsDataAdapter,
-    TensorLikeDataAdapter, DatasetAdapter, GeneratorDataAdapter,
-    KerasSequenceAdapter
+    ListsOfScalarsDataAdapter, TensorLikeDataAdapter, DatasetAdapter,
+    GeneratorDataAdapter, KerasSequenceAdapter, CompositeTensorDataAdapter
 ]
 
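Ordering in this list is cosmetic; dispatch is by each class's can_handle. A simplified sketch of the selection logic (assumed to mirror data_adapter.select_data_adapter; error messages abbreviated):

    def select_data_adapter(x, y):
      candidates = [cls for cls in ALL_ADAPTER_CLS if cls.can_handle(x, y)]
      if not candidates:
        raise ValueError('Failed to find a data adapter that can handle '
                         'input: {}, {}'.format(type(x), type(y)))
      if len(candidates) > 1:
        raise RuntimeError('Data adapters should be mutually exclusive.')
      return candidates[0]

This is why CompositeTensorDataAdapter.can_handle above requires at least one composite input: otherwise it would overlap with TensorLikeDataAdapter and trip the exclusivity check.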
@@ -18,12 +18,16 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import math
+
 from absl.testing import parameterized
 import numpy as np
 
 from tensorflow.python import keras
 from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.eager import context
 from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import ops
 from tensorflow.python.framework import test_util
 from tensorflow.python.keras.engine import data_adapter
 from tensorflow.python.keras.utils import data_utils
@@ -106,11 +110,13 @@ class TensorLikeDataAdapterTest(DataAdapterTestBase):
     self.assertTrue(adapter.has_partial_batch())
     self.assertEqual(adapter.partial_batch_size(), 2)
 
+  @test_util.run_in_graph_and_eager_modes
   def test_training_numpy(self):
-    dataset = self.adapter_cls(
-        self.numpy_input, self.numpy_target, batch_size=5).get_dataset()
+    if not context.executing_eagerly():
+      return  # Only test in eager.
+
     self.model.compile(loss='sparse_categorical_crossentropy', optimizer='sgd')
-    self.model.fit(dataset)
+    self.model.fit(self.numpy_input, self.numpy_target, batch_size=5)
 
   def test_can_handle(self):
     self.assertTrue(self.adapter_cls.can_handle(self.tensor_input))
@@ -121,11 +127,13 @@ class TensorLikeDataAdapterTest(DataAdapterTestBase):
     self.assertFalse(self.adapter_cls.can_handle(self.generator_input))
     self.assertFalse(self.adapter_cls.can_handle(self.sequence_input))
 
+  @test_util.run_in_graph_and_eager_modes
   def test_training(self):
-    dataset = self.adapter_cls(
-        self.tensor_input, self.tensor_target, batch_size=5).get_dataset()
+    if not context.executing_eagerly():
+      return  # Only test EagerTensors.
+
     self.model.compile(loss='sparse_categorical_crossentropy', optimizer='sgd')
-    self.model.fit(dataset)
+    self.model.fit(self.tensor_input, self.tensor_target, batch_size=5)
 
   def test_size(self):
     adapter = self.adapter_cls(
@@ -133,6 +141,39 @@ class TensorLikeDataAdapterTest(DataAdapterTestBase):
     self.assertEqual(adapter.get_size(), 10)
     self.assertFalse(adapter.has_partial_batch())
 
+  def test_shuffle_correctness(self):
+    with context.eager_mode():
+      num_samples = 100
+      batch_size = 32
+      x = np.arange(num_samples)
+      np.random.seed(99)
+      adapter = self.adapter_cls(
+          x, y=None, batch_size=batch_size, shuffle=True, epochs=2)
+
+      def _get_epoch(ds_iter):
+        ds_data = []
+        for _ in range(int(math.ceil(num_samples / batch_size))):
+          ds_data.append(next(ds_iter)[0].numpy())
+        return np.concatenate(ds_data)
+
+      ds_iter = iter(adapter.get_dataset())
+
+      # First epoch.
+      epoch_data = _get_epoch(ds_iter)
+      # Check that shuffling occurred.
+      self.assertNotAllClose(x, epoch_data)
+      # Check that each element appears, and only once.
+      self.assertAllClose(x, np.sort(epoch_data))
+
+      # Second epoch.
+      second_epoch_data = _get_epoch(ds_iter)
+      # Check that shuffling occurred.
+      self.assertNotAllClose(x, second_epoch_data)
+      # Check that shuffling is different across epochs.
+      self.assertNotAllClose(epoch_data, second_epoch_data)
+      # Check that each element appears, and only once.
+      self.assertAllClose(x, np.sort(second_epoch_data))
+
   @parameterized.named_parameters(
       ('batch_size_5', 5, None, 5),
       ('batch_size_50', 50, 4, 50),  # Sanity check: batch_size takes precedence
@@ -259,4 +300,5 @@ class KerasSequenceAdapterTest(DataAdapterTestBase):
 
 
 if __name__ == '__main__':
+  ops.enable_eager_execution()
   test.main()
@@ -1783,24 +1783,25 @@ class Model(network.Network):
     layers = super(Model, self).layers  # Avoids the override in Sequential.
     if layers:
       first_layer = layers[0]
+      # The per-replica static batch size.
       static_batch_size = training_utils.get_static_batch_size(first_layer)
       if static_batch_size is not None:
-        split_batch_size = self._distribution_strategy and \
-            distributed_training_utils.global_batch_size_supported(
-                self._distribution_strategy)
-        if split_batch_size:
-          num_replicas = self._distribution_strategy.num_replicas_in_sync
+        # Determine number of times the user-supplied batch size will be split.
+        if (self._distribution_strategy and
+            distributed_training_utils.global_batch_size_supported(
+                self._distribution_strategy)):
+          num_splits_for_ds = self._distribution_strategy.num_replicas_in_sync
+        else:
+          num_splits_for_ds = 1
+
 
         # Check `batch_size` argument is consistent with InputLayer.
         if batch_size is not None:
-          if split_batch_size:
-            if batch_size % num_replicas != 0:
-              raise ValueError('The `batch_size` argument value {} cannot be '
-                               'divisible by number of replicas {}'.format(
-                                   batch_size, num_replicas))
-            per_replica_batch_size = batch_size // num_replicas
-          else:
-            per_replica_batch_size = batch_size
+          if batch_size % num_splits_for_ds != 0:
+            raise ValueError('The `batch_size` argument value {} is not '
+                             'divisible by the number of replicas {}'.format(
+                                 batch_size, num_splits_for_ds))
+          per_replica_batch_size = batch_size // num_splits_for_ds
 
           if per_replica_batch_size != static_batch_size:
             raise ValueError('The `batch_size` argument value {} is '
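The rewrite collapses the old split/no-split branches into one code path by letting num_splits_for_ds default to 1. The arithmetic being enforced, in isolation (a sketch; the helper name is ours):

    def per_replica_batch_size(global_batch_size, num_splits):
      # Under a strategy that accepts a global batch size, each replica sees
      # global_batch_size / num_splits samples per step, and that per-replica
      # value is what must match the InputLayer's static batch size.
      if global_batch_size % num_splits != 0:
        raise ValueError('batch_size {} is not divisible by the number of '
                         'replicas {}'.format(global_batch_size, num_splits))
      return global_batch_size // num_splits

    assert per_replica_batch_size(64, 4) == 16  # e.g. 4 replicas in sync
    assert per_replica_batch_size(64, 1) == 64  # no strategy: no splitting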
@@ -1814,23 +1815,23 @@ class Model(network.Network):
           ds_batch_size = tensor_shape.as_dimension(
               nest.flatten(dataset_ops.get_legacy_output_shapes(x))[0][0]).value
           if ds_batch_size is not None:
-            if split_batch_size:
-              if ds_batch_size % num_replicas != 0:
-                raise ValueError(
-                    'The batch output shape of your `Dataset` {} '
-                    'cannot be divisible by number of replicas {}'.format(
-                        ds_batch_size, num_replicas))
-              ds_batch_size = ds_batch_size // num_replicas
+            if ds_batch_size % num_splits_for_ds != 0:
+              raise ValueError(
+                  'The batch output shape of your `Dataset` {} '
+                  'is not divisible by the number of replicas {}'.format(
+                      ds_batch_size, num_splits_for_ds))
 
-            if ds_batch_size != static_batch_size:
+            ds_per_replica_batch_size = ds_batch_size // num_splits_for_ds
+            if ds_per_replica_batch_size != static_batch_size:
               raise ValueError('The batch output shape of your `Dataset` is '
                                '{}, which is incompatible with the specified '
                                'batch size of your Input Layer: {}'.format(
-                                   ds_batch_size, static_batch_size))
+                                   ds_per_replica_batch_size,
+                                   static_batch_size))
 
         # Set inferred batch size from the InputLayer.
         if steps is None:
-          batch_size = static_batch_size
+          batch_size = static_batch_size * num_splits_for_ds
 
     if batch_size is None and steps is None:
       # Backwards compatibility
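The dataset branch applies the same divisibility rule to the static leading dimension of the dataset's elements. The diff reads it through the legacy output-shapes helper; with the public TF 2.x element_spec API the same inference looks roughly like this sketch:

    import tensorflow as tf

    ds = tf.data.Dataset.from_tensor_slices(tf.zeros([96, 3]))
    ds = ds.batch(32, drop_remainder=True)   # static leading dim: 32
    spec = tf.nest.flatten(ds.element_spec)[0]
    ds_batch_size = spec.shape[0]            # 32; None when not static
    per_replica = ds_batch_size // 4         # e.g. 4 replicas in sync -> 8

Note also that the inferred batch size is now scaled back up (static_batch_size * num_splits_for_ds) when steps is None, so downstream code receives a global rather than per-replica batch size.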
@@ -49,7 +49,8 @@ _ADAPTER_FOR_VALIDATION_SPLIT = [data_adapter.TensorLikeDataAdapter]
 # dataset/generator/sequence input will be peeked and processed by
 # model._standardize_user_data()
 _ADAPTER_FOR_STANDARDIZE_USER_DATA = [
-    data_adapter.TensorLikeDataAdapter, data_adapter.DatasetAdapter
+    data_adapter.TensorLikeDataAdapter, data_adapter.DatasetAdapter,
+    data_adapter.CompositeTensorDataAdapter
 ]
 
@@ -207,6 +208,7 @@ class Loop(training_utils.TrainingLoop):
         x,
         y,
         batch_size=batch_size,
+        epochs=epochs,
         sample_weights=sample_weight,
         class_weights=class_weight,
         validation_split=validation_split,
@@ -247,11 +249,8 @@ class Loop(training_utils.TrainingLoop):
                                                           model, ModeKeys.TRAIN)
 
       training_data_iter = None
-      # Only recreate iterator when the data has a fixed length, which will be
-      # fully consumed every epoch, or has a unknown length (dataset, generator)
-      # and will be fully consumed (steps_per_epoch is None)
-      recreate_training_iterator = (training_data_adapter.get_size() is not None
-                                    or steps_per_epoch is None)
+      recreate_training_iterator = (
+          training_data_adapter.should_recreate_iterator(steps_per_epoch))
 
       if do_validation:
         if not validation_steps:
@@ -474,11 +473,22 @@ def _get_distribution_strategy(model):
   return strategy
 
 
-def _process_training_inputs(model, x, y, batch_size=None,
-                             sample_weights=None, class_weights=None,
-                             steps_per_epoch=None, validation_split=0.,
-                             validation_data=None, validation_steps=None,
-                             shuffle=True, distribution_strategy=None):
+def _process_training_inputs(model,
+                             x,
+                             y,
+                             batch_size=None,
+                             epochs=1,
+                             sample_weights=None,
+                             class_weights=None,
+                             steps_per_epoch=None,
+                             validation_split=0.,
+                             validation_data=None,
+                             validation_steps=None,
+                             shuffle=True,
+                             distribution_strategy=None,
+                             max_queue_size=10,
+                             workers=1,
+                             use_multiprocessing=False):
   """Process the data input for fit() with respect to validation_split."""
   if validation_split and 0. < validation_split < 1. and validation_data:
     raise ValueError('validation_data and validation_split cannot be used '
@@ -508,19 +518,33 @@ def _process_training_inputs(model, x, y, batch_size=None,
      val_x, val_y,
      val_sample_weights) = training_utils.split_training_and_validation_data(
          x, y, sample_weights, validation_split)
-    train_adapter = adapter_cls(x, y, batch_size=batch_size,
-                                sample_weights=sample_weights, shuffle=shuffle,
-                                distribution_strategy=distribution_strategy)
+    train_adapter = adapter_cls(
+        x,
+        y,
+        batch_size=batch_size,
+        epochs=epochs,
+        sample_weights=sample_weights,
+        shuffle=shuffle,
+        distribution_strategy=distribution_strategy)
     val_adapter = adapter_cls(val_x, val_y,
                               sample_weights=val_sample_weights,
                               batch_size=batch_size,
                               distribution_strategy=distribution_strategy)
   else:
-    train_adapter = _process_inputs(model, x, y, sample_weights=sample_weights,
-                                    batch_size=batch_size,
-                                    class_weights=class_weights,
-                                    shuffle=shuffle, steps=steps_per_epoch,
-                                    distribution_strategy=distribution_strategy)
+    train_adapter = _process_inputs(
+        model,
+        x,
+        y,
+        sample_weights=sample_weights,
+        batch_size=batch_size,
+        epochs=epochs,
+        class_weights=class_weights,
+        shuffle=shuffle,
+        steps=steps_per_epoch,
+        distribution_strategy=distribution_strategy,
+        max_queue_size=max_queue_size,
+        workers=workers,
+        use_multiprocessing=use_multiprocessing)
     val_adapter = None
     if validation_data:
       (val_x, val_y,
@@ -543,9 +567,19 @@ def _process_training_inputs(model, x, y, batch_size=None,
   return train_adapter, val_adapter
 
 
-def _process_inputs(model, x, y, batch_size=None, sample_weights=None,
-                    class_weights=None, shuffle=False, steps=None,
-                    distribution_strategy=None):
+def _process_inputs(model,
+                    x,
+                    y,
+                    batch_size=None,
+                    epochs=1,
+                    sample_weights=None,
+                    class_weights=None,
+                    shuffle=False,
+                    steps=None,
+                    distribution_strategy=None,
+                    max_queue_size=10,
+                    workers=1,
+                    use_multiprocessing=False):
   """Process the inputs for fit/eval/predict()."""
   adapter_cls = data_adapter.select_data_adapter(x, y)
   if adapter_cls in _ADAPTER_FOR_STANDARDIZE_USER_DATA:
@@ -557,9 +591,18 @@ def _process_inputs(model, x, y, batch_size=None, sample_weights=None,
         batch_size=batch_size,
         check_steps=False,
         steps=steps)
-    adapter = adapter_cls(x, y, batch_size=batch_size, steps=steps,
-                          sample_weights=sample_weights, shuffle=shuffle,
-                          distribution_strategy=distribution_strategy)
+    adapter = adapter_cls(
+        x,
+        y,
+        batch_size=batch_size,
+        epochs=epochs,
+        steps=steps,
+        sample_weights=sample_weights,
+        shuffle=shuffle,
+        distribution_strategy=distribution_strategy,
+        max_queue_size=max_queue_size,
+        workers=workers,
+        use_multiprocessing=use_multiprocessing)
   # As a fallback for the data type that does not work with
   # _standardize_user_data, use the _prepare_model_with_inputs.
   if adapter_cls not in _ADAPTER_FOR_STANDARDIZE_USER_DATA:
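End to end, the user-visible API is unchanged; only the plumbing differs. A minimal usage sketch of what now flows through TensorLikeDataAdapter (standard Keras calls, nothing here is new API):

    import numpy as np
    import tensorflow as tf

    x = np.random.random((1000, 32)).astype(np.float32)
    y = np.random.randint(0, 10, size=(1000,))

    model = tf.keras.Sequential([
        tf.keras.layers.Dense(64, activation='relu', input_shape=(32,)),
        tf.keras.layers.Dense(10, activation='softmax'),
    ])
    model.compile(loss='sparse_categorical_crossentropy', optimizer='sgd')

    # `epochs` is now threaded down to the adapter, which materializes a
    # shuffled index stream for every epoch up front.
    model.fit(x, y, batch_size=32, epochs=2, shuffle=True)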