Disallow dataset iterators in Keras fit, predict, and evaluate.
PiperOrigin-RevId: 259071119
parent 6dcc61a0ae
commit 229dae116a
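In practical terms, callers that previously wrapped a `Dataset` in a one-shot iterator now pass the `Dataset` itself. A minimal before/after sketch using the public `tf.keras` API (model and shapes are illustrative, not taken from this commit):

    import numpy as np
    import tensorflow as tf

    model = tf.keras.Sequential([tf.keras.layers.Dense(4, input_shape=(3,))])
    model.compile('rmsprop', 'mse')

    x = np.zeros((10, 3), np.float32)
    y = np.zeros((10, 4), np.float32)
    dataset = tf.data.Dataset.from_tensor_slices((x, y)).repeat(100).batch(10)

    # Before this change: iterator = make_one_shot_iterator(dataset); model.fit(iterator, ...)
    # After this change: pass the Dataset directly.
    model.fit(dataset, epochs=1, steps_per_epoch=10)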
@@ -153,9 +153,8 @@ class TestSequential(keras_parameterized.TestCase):
     dataset = dataset_ops.Dataset.from_tensor_slices((x, y))
     dataset = dataset.repeat(100)
     dataset = dataset.batch(10)
-    iterator = dataset_ops.make_one_shot_iterator(dataset)

-    model.fit(iterator, epochs=1, steps_per_epoch=steps_per_epoch)
+    model.fit(dataset, epochs=1, steps_per_epoch=steps_per_epoch)
     self.assertTrue(model.built)
     self.assertEqual(len(model.weights), 2 * 2)
     self.assertFalse(model._is_graph_network)
@@ -465,11 +465,21 @@ class Model(network.Network):

   def _select_training_loop(self, inputs):
     """Select training loop for fit/eval/predict based on the inputs."""
+    # TODO(kaftan) or TODO(scottzhu): This check should eventually be nicely
+    # integrated into the data adapters in the v2 loop. We can't do this yet
+    # because we currently have to fall back for unhandled data types.
+    if isinstance(inputs, (iterator_ops.Iterator,
+                           iterator_ops.IteratorV2)):
+      raise ValueError('For performance reasons Keras `fit`, `evaluate` and '
+                       '`predict` accept tf.data `Datasets` as input but not '
+                       'iterators that have been manually generated from '
+                       'Datasets by users. Please directly pass in the '
+                       'original `Dataset` object instead of passing in '
+                       '`iter(dataset)`.')
+
     # Experiment training loop with default DS path.
     if (context.executing_eagerly()
         and self._run_distributed
-        and not isinstance(inputs, (iterator_ops.Iterator,
-                                    iterator_ops.IteratorV2))
         # TODO(scottzhu): Finish getting sequences working with the v2 loops.
         and not isinstance(inputs, (data_utils.Sequence))
         and not distributed_training_utils.is_tpu_strategy(
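A rough sketch of what the new check rejects, assuming eager execution so that `iter(dataset)` yields an `IteratorV2` (model and data are placeholders, not from this commit):

    dataset = tf.data.Dataset.from_tensor_slices((x, y)).batch(10)
    model.fit(dataset, epochs=1)            # accepted: a Dataset
    try:
        model.fit(iter(dataset), epochs=1)  # rejected: a manually created iterator
    except ValueError as e:
        print(e)  # suggests passing the original `Dataset` object instead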
@@ -535,7 +545,7 @@ class Model(network.Network):
             (in case the model has multiple inputs).
           - A dict mapping input names to the corresponding array/tensors,
             if the model has named inputs.
-          - A `tf.data` dataset or a dataset iterator. Should return a tuple
+          - A `tf.data` dataset. Should return a tuple
             of either `(inputs, targets)` or
             `(inputs, targets, sample_weights)`.
           - A generator or `keras.utils.Sequence` returning `(inputs, targets)`
@@ -543,14 +553,14 @@ class Model(network.Network):
         y: Target data. Like the input data `x`,
           it could be either Numpy array(s) or TensorFlow tensor(s).
           It should be consistent with `x` (you cannot have Numpy inputs and
-          tensor targets, or inversely). If `x` is a dataset, dataset
-          iterator, generator, or `keras.utils.Sequence` instance, `y` should
+          tensor targets, or inversely). If `x` is a dataset, generator,
+          or `keras.utils.Sequence` instance, `y` should
           not be specified (since targets will be obtained from `x`).
         batch_size: Integer or `None`.
           Number of samples per gradient update.
           If unspecified, `batch_size` will default to 32.
           Do not specify the `batch_size` if your data is in the
-          form of symbolic tensors, dataset, dataset iterators,
+          form of symbolic tensors, datasets,
           generators, or `keras.utils.Sequence` instances (since they generate
           batches).
         epochs: Integer. Number of epochs to train the model.
@@ -577,7 +587,7 @@ class Model(network.Network):
           on this data at the end of each epoch.
           The validation data is selected from the last samples
           in the `x` and `y` data provided, before shuffling. This argument is
-          not supported when `x` is a dataset, dataset iterator, generator or
+          not supported when `x` is a dataset, generator or
           `keras.utils.Sequence` instance.
         validation_data: Data on which to evaluate
           the loss and any model metrics at the end of each epoch.
@@ -586,7 +596,7 @@ class Model(network.Network):
           `validation_data` could be:
             - tuple `(x_val, y_val)` of Numpy arrays or tensors
             - tuple `(x_val, y_val, val_sample_weights)` of Numpy arrays
-            - dataset or a dataset iterator
+            - dataset
           For the first two cases, `batch_size` must be provided.
           For the last case, `validation_steps` must be provided.
         shuffle: Boolean (whether to shuffle the training data
@@ -611,7 +621,7 @@ class Model(network.Network):
           to apply a different weight to every timestep of every sample.
           In this case you should make sure to specify
           `sample_weight_mode="temporal"` in `compile()`. This argument is not
-          supported when `x` is a dataset, dataset iterator, generator, or
+          supported when `x` is a dataset, generator, or
           `keras.utils.Sequence` instance, instead provide the sample_weights
           as the third element of `x`.
         initial_epoch: Integer.
@@ -624,14 +634,14 @@ class Model(network.Network):
           TensorFlow data tensors, the default `None` is equal to
           the number of samples in your dataset divided by
           the batch size, or 1 if that cannot be determined. If x is a
-          `tf.data` dataset or a dataset iterator, and 'steps_per_epoch'
+          `tf.data` dataset, and 'steps_per_epoch'
           is None, the epoch will run until the input dataset is exhausted.
           This argument is not supported with array inputs.
         validation_steps: Only relevant if `validation_data` is provided and
-          is a dataset or dataset iterator. Total number of steps (batches of
+          is a `tf.data` dataset. Total number of steps (batches of
           samples) to draw before stopping when performing validation
           at the end of every epoch. If validation_data is a `tf.data` dataset
-          or a dataset iterator, and 'validation_steps' is None, validation
+          and 'validation_steps' is None, validation
           will run until the `validation_data` dataset is exhausted.
         validation_freq: Only relevant if validation data is provided. Integer
           or `collections.Container` instance (e.g. list, tuple, etc.). If an
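The `steps_per_epoch`/`validation_steps` semantics described above can be summarized in a short, hedged sketch (dataset names are illustrative):

    finite_ds = tf.data.Dataset.from_tensor_slices((x, y)).batch(10)
    infinite_ds = finite_ds.repeat()

    # Finite datasets: step arguments may be omitted; each pass runs to exhaustion.
    model.fit(finite_ds, epochs=2, validation_data=finite_ds)

    # Infinite datasets: both step arguments must be given explicitly.
    model.fit(infinite_ds, epochs=2, steps_per_epoch=10,
              validation_data=infinite_ds, validation_steps=5)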
@@ -722,20 +732,20 @@ class Model(network.Network):
             (in case the model has multiple inputs).
           - A dict mapping input names to the corresponding array/tensors,
             if the model has named inputs.
-          - A `tf.data` dataset or a dataset iterator.
+          - A `tf.data` dataset.
           - A generator or `keras.utils.Sequence` instance.
         y: Target data. Like the input data `x`,
           it could be either Numpy array(s) or TensorFlow tensor(s).
           It should be consistent with `x` (you cannot have Numpy inputs and
           tensor targets, or inversely).
-          If `x` is a dataset, dataset iterator, generator or
+          If `x` is a dataset, generator or
           `keras.utils.Sequence` instance, `y` should not be specified (since
           targets will be obtained from the iterator/dataset).
         batch_size: Integer or `None`.
           Number of samples per gradient update.
           If unspecified, `batch_size` will default to 32.
           Do not specify the `batch_size` if your data is in the
-          form of symbolic tensors, dataset, dataset iterators,
+          form of symbolic tensors, dataset,
           generators, or `keras.utils.Sequence` instances (since they generate
           batches).
         verbose: 0 or 1. Verbosity mode.
@@ -751,13 +761,13 @@ class Model(network.Network):
           to apply a different weight to every timestep of every sample.
           In this case you should make sure to specify
           `sample_weight_mode="temporal"` in `compile()`. This argument is not
-          supported when `x` is a dataset or a dataset iterator, instead pass
+          supported when `x` is a dataset, instead pass
           sample weights as the third element of `x`.
         steps: Integer or `None`.
           Total number of steps (batches of samples)
           before declaring the evaluation round finished.
           Ignored with the default value of `None`.
-          If x is a `tf.data` dataset or a dataset iterator, and `steps` is
+          If x is a `tf.data` dataset and `steps` is
           None, 'evaluate' will run until the dataset is exhausted.
           This argument is not supported with array inputs.
         callbacks: List of `keras.callbacks.Callback` instances.
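A one-line illustration of the `steps=None` behavior documented here (assumes a compiled model and a finite dataset named `eval_dataset`):

    results = model.evaluate(eval_dataset)  # steps=None: runs until eval_dataset is exhausted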
@@ -822,20 +832,20 @@ class Model(network.Network):
             (in case the model has multiple inputs).
           - A TensorFlow tensor, or a list of tensors
             (in case the model has multiple inputs).
-          - A `tf.data` dataset or a dataset iterator.
+          - A `tf.data` dataset.
           - A generator or `keras.utils.Sequence` instance.
         batch_size: Integer or `None`.
           Number of samples per gradient update.
           If unspecified, `batch_size` will default to 32.
           Do not specify the `batch_size` if your data is in the
-          form of symbolic tensors, dataset, dataset iterators,
+          form of symbolic tensors, dataset,
           generators, or `keras.utils.Sequence` instances (since they generate
           batches).
         verbose: Verbosity mode, 0 or 1.
         steps: Total number of steps (batches of samples)
           before declaring the prediction round finished.
           Ignored with the default value of `None`. If x is a `tf.data`
-          dataset or a dataset iterator, and `steps` is None, `predict` will
+          dataset and `steps` is None, `predict` will
           run until the input dataset is exhausted.
         callbacks: List of `keras.callbacks.Callback` instances.
           List of callbacks to apply during prediction.
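And the analogous sketch for `predict` (again assuming a finite dataset named `pred_dataset`):

    preds = model.predict(pred_dataset)  # steps=None: predicts until pred_dataset is exhausted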
@@ -904,11 +914,11 @@ class Model(network.Network):
             (in case the model has multiple inputs).
           - A dict mapping input names to the corresponding array/tensors,
             if the model has named inputs.
-          - A `tf.data` dataset or a dataset iterator.
+          - A `tf.data` dataset.
         y: Target data. Like the input data `x`, it could be either Numpy
           array(s) or TensorFlow tensor(s). It should be consistent with `x`
           (you cannot have Numpy inputs and tensor targets, or inversely). If
-          `x` is a dataset or a dataset iterator, `y` should not be specified
+          `x` is a dataset, `y` should not be specified
           (since targets will be obtained from the iterator).
         sample_weight: Optional array of the same length as x, containing
           weights to apply to the model's loss for each sample. In the case of
@@ -916,7 +926,7 @@ class Model(network.Network):
           sequence_length), to apply a different weight to every timestep of
           every sample. In this case you should make sure to specify
           sample_weight_mode="temporal" in compile(). This argument is not
-          supported when `x` is a dataset or a dataset iterator.
+          supported when `x` is a dataset.
         class_weight: Optional dictionary mapping class indices (integers) to a
           weight (float) to apply to the model's loss for the samples from this
           class during training. This can be useful to tell the model to "pay
@@ -993,13 +1003,12 @@ class Model(network.Network):
             (in case the model has multiple inputs).
           - A dict mapping input names to the corresponding array/tensors,
             if the model has named inputs.
-          - A `tf.data` dataset or a dataset iterator.
+          - A `tf.data` dataset.
         y: Target data. Like the input data `x`,
           it could be either Numpy array(s) or TensorFlow tensor(s).
           It should be consistent with `x` (you cannot have Numpy inputs and
-          tensor targets, or inversely). If `x` is a dataset or a
-          dataset iterator, `y` should not be specified
-          (since targets will be obtained from the iterator).
+          tensor targets, or inversely). If `x` is a dataset, `y` should
+          not be specified (since targets will be obtained from the iterator).
         sample_weight: Optional array of the same length as x, containing
           weights to apply to the model's loss for each sample.
           In the case of temporal data, you can pass a 2D array
@@ -1007,7 +1016,7 @@ class Model(network.Network):
           to apply a different weight to every timestep of every sample.
           In this case you should make sure to specify
           sample_weight_mode="temporal" in compile(). This argument is not
-          supported when `x` is a dataset or a dataset iterator.
+          supported when `x` is a dataset.
         reset_metrics: If `True`, the metrics returned will be only for this
           batch. If `False`, the metrics will be statefully accumulated across
           batches.
@@ -1068,7 +1077,7 @@ class Model(network.Network):
             (in case the model has multiple inputs).
           - A TensorFlow tensor, or a list of tensors
             (in case the model has multiple inputs).
-          - A `tf.data` dataset or a dataset iterator.
+          - A `tf.data` dataset.

     Returns:
         Numpy array(s) of predictions.
@@ -2221,13 +2230,12 @@ class Model(network.Network):
             (in case the model has multiple inputs).
           - A dict mapping input names to the corresponding array/tensors,
             if the model has named inputs.
-          - A `tf.data` dataset or a dataset iterator.
+          - A `tf.data` dataset.
         y: Target data. Like the input data `x`,
           it could be either Numpy array(s) or TensorFlow tensor(s).
           It should be consistent with `x` (you cannot have Numpy inputs and
-          tensor targets, or inversely). If `x` is a dataset or a
-          dataset iterator, `y` should not be specified
-          (since targets will be obtained from the iterator).
+          tensor targets, or inversely). If `x` is a dataset, `y` should not be
+          specified (since targets will be obtained from the iterator).
         sample_weight: An optional sample-weight array passed by the user to
           weight the importance of each sample in `x`.
         class_weight: An optional class-weight array by the user to
@@ -47,100 +47,6 @@ class BatchCounterCallback(callbacks.Callback):
     self.batch_count += 1


-class TestTrainingWithDatasetIterators(keras_parameterized.TestCase):
-
-  @keras_parameterized.run_with_all_model_types
-  @keras_parameterized.run_all_keras_modes
-  def test_training_and_eval_methods_on_iterators_single_io(self):
-    model = testing_utils.get_small_mlp(1, 4, input_dim=3)
-    optimizer = 'rmsprop'
-    loss = 'mse'
-    metrics = ['mae', metrics_module.CategoricalAccuracy()]
-    model.compile(
-        optimizer,
-        loss,
-        metrics=metrics,
-        run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
-
-    inputs = np.zeros((10, 3), np.float32)
-    targets = np.zeros((10, 4), np.float32)
-    dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets))
-    dataset = dataset.repeat(100)
-    dataset = dataset.batch(10)
-    iterator = dataset_ops.make_one_shot_iterator(dataset)
-
-    model.fit(iterator, epochs=1, steps_per_epoch=2, verbose=1)
-    model.evaluate(iterator, steps=2, verbose=1)
-    model.predict(iterator, steps=2)
-
-    # Test with validation data
-    model.fit(iterator,
-              epochs=1, steps_per_epoch=2, verbose=0,
-              validation_data=iterator, validation_steps=2)
-    # Test with validation split
-    with self.assertRaisesRegexp(
-        ValueError, '`validation_split` argument is not supported when '):
-      model.fit(iterator,
-                epochs=1, steps_per_epoch=2, verbose=0,
-                validation_split=0.5, validation_steps=2)
-
-    # Test with sample weight.
-    sample_weight = np.random.random((10,))
-    with self.assertRaisesRegexp(
-        ValueError, '`sample_weight` argument is not supported '
-        'when input `x` is a dataset or a dataset iterator'):
-      model.fit(
-          iterator,
-          epochs=1,
-          steps_per_epoch=2,
-          verbose=0,
-          sample_weight=sample_weight)
-
-    # Test invalid usage
-    with self.assertRaisesRegexp(ValueError,
-                                 'you should not specify a target'):
-      model.fit(iterator, iterator,
-                epochs=1, steps_per_epoch=2, verbose=0)
-
-    with self.assertRaisesRegexp(
-        ValueError, 'the `steps_per_epoch` argument'):
-      model.fit(iterator, epochs=1, verbose=0)
-    with self.assertRaisesRegexp(ValueError,
-                                 'the `steps` argument'):
-      model.evaluate(iterator, verbose=0)
-    with self.assertRaisesRegexp(ValueError,
-                                 'the `steps` argument'):
-      model.predict(iterator, verbose=0)
-
-  @keras_parameterized.run_with_all_model_types
-  @keras_parameterized.run_all_keras_modes
-  def test_iterators_running_out_of_data(self):
-    model = testing_utils.get_small_mlp(1, 4, input_dim=3)
-    optimizer = 'rmsprop'
-    loss = 'mse'
-    metrics = ['mae']
-    model.compile(
-        optimizer,
-        loss,
-        metrics=metrics,
-        run_eagerly=testing_utils.should_run_eagerly(),
-        run_distributed=testing_utils.should_run_distributed())
-
-    inputs = np.zeros((10, 3), np.float32)
-    targets = np.zeros((10, 4), np.float32)
-    dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets))
-    dataset = dataset.repeat(2)
-    dataset = dataset.batch(10)
-    iterator = dataset_ops.make_one_shot_iterator(dataset)
-
-    with test.mock.patch.object(logging, 'warning') as mock_log:
-      model.fit(iterator, epochs=1, steps_per_epoch=3, verbose=0)
-      self.assertRegexpMatches(
-          str(mock_log.call_args),
-          'dataset iterator ran out of data')
-
-
 class TestTrainingWithDataset(keras_parameterized.TestCase):

   @keras_parameterized.run_with_all_model_types
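The deleted iterator tests are not replaced one-for-one in this hunk; a dataset-based equivalent of the basic train/eval/predict flow, sketched with the same test utilities, would look roughly like:

    dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets))
    dataset = dataset.repeat(100).batch(10)
    model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=1)
    model.evaluate(dataset, steps=2, verbose=1)
    model.predict(dataset, steps=2)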
@@ -618,11 +524,11 @@ class TestTrainingWithDataset(keras_parameterized.TestCase):
     model.fit(dataset)


-class TestMetricsWithDatasetIterators(keras_parameterized.TestCase):
+class TestMetricsWithDatasets(keras_parameterized.TestCase):

   @keras_parameterized.run_with_all_model_types
   @keras_parameterized.run_all_keras_modes
-  def test_metrics_correctness_with_iterator(self):
+  def test_metrics_correctness_with_dataset(self):
     layers = [
         keras.layers.Dense(8, activation='relu', input_dim=4,
                            kernel_initializer='ones'),
@@ -643,8 +549,7 @@ class TestMetricsWithDatasetIterators(keras_parameterized.TestCase):
     y = np.random.randint(2, size=(100, 1)).astype(np.float32)
     dataset = dataset_ops.Dataset.from_tensor_slices((x, y))
     dataset = dataset.batch(10)
-    iterator = dataset_ops.make_one_shot_iterator(dataset)
-    outs = model.evaluate(iterator, steps=10)
+    outs = model.evaluate(dataset, steps=10)
     self.assertEqual(np.around(outs[1], decimals=1), 0.5)
     self.assertEqual(np.around(outs[2], decimals=1), 0.5)

@@ -652,8 +557,7 @@ class TestMetricsWithDatasetIterators(keras_parameterized.TestCase):
     dataset = dataset_ops.Dataset.from_tensor_slices((x, y))
     dataset = dataset.repeat(100)
     dataset = dataset.batch(10)
-    iterator = dataset_ops.make_one_shot_iterator(dataset)
-    outs = model.evaluate(iterator, steps=10)
+    outs = model.evaluate(dataset, steps=10)
     self.assertEqual(outs[1], 0.)
     self.assertEqual(outs[2], 0.)

@@ -183,30 +183,20 @@ class TrainingTest(keras_parameterized.TestCase):
     x = array_ops.zeros(shape=(10, 3))
     y = array_ops.zeros(shape=(10, 4))
     dataset = dataset_ops.Dataset.from_tensor_slices((x, y)).repeat(10).batch(5)
-    iterator = dataset_ops.make_one_shot_iterator(dataset)
     validation_dataset = dataset_ops.Dataset.from_tensor_slices(
         (x, y)).repeat().batch(5) # Infinite dataset.
-    validation_iterator = dataset_ops.make_one_shot_iterator(validation_dataset)

-    with self.assertRaisesRegexp(
-        ValueError, r'specify .* `steps_per_epoch`'):
-      model.fit(iterator, epochs=1, verbose=0)
-    if not context.executing_eagerly():
-      # In eager execution, `array_ops.zeros` returns value tensors
-      # which can be used for validation without a `validation_steps` argument.
-      with self.assertRaisesRegexp(
-          ValueError, r'provide either `batch_size` or `validation_steps`'):
-        model.fit(iterator, steps_per_epoch=2, epochs=1, verbose=0,
-                  validation_data=(x, y))
+    model.fit(dataset, epochs=1, verbose=0)
+
     # Step argument is required for infinite datasets.
     with self.assertRaisesRegexp(ValueError,
                                  'specify the `validation_steps` argument.'):
-      model.fit(iterator, steps_per_epoch=2, epochs=1, verbose=0,
+      model.fit(dataset, steps_per_epoch=2, epochs=1, verbose=0,
                 validation_data=validation_dataset)
     with self.assertRaisesRegexp(ValueError,
                                  'specify the `validation_steps` argument.'):
-      model.fit(iterator, steps_per_epoch=2, epochs=1, verbose=0,
-                validation_data=validation_iterator)
+      model.fit(dataset, steps_per_epoch=2, epochs=1, verbose=0,
+                validation_data=validation_dataset)

     # TODO(b/120931266): Enable test on subclassed models after bug causing an
     # extra dimension to be added to predict outputs is fixed.
@@ -282,8 +272,7 @@ class CorrectnessTest(keras_parameterized.TestCase):
     dataset = dataset_ops.Dataset.from_tensor_slices((x, y))
     dataset = dataset.repeat(100)
     dataset = dataset.batch(10)
-    iterator = dataset_ops.make_one_shot_iterator(dataset)
-    history = model.fit(iterator, epochs=1, steps_per_epoch=10)
+    history = model.fit(dataset, epochs=1, steps_per_epoch=10)
     self.assertAlmostEqual(history.history['loss'][-1], 0.5836, 4)

   def test_loss_in_call(self):
@@ -859,8 +859,7 @@ class TrainingTest(keras_parameterized.TestCase):
     dataset = dataset_ops.Dataset.from_tensor_slices((x_train, y_train))
     dataset = dataset.repeat(10)
     dataset = dataset.batch(10)
-    iterator = dataset_ops.make_one_shot_iterator(dataset)
-    model.fit(iterator, epochs=1, steps_per_epoch=2)
+    model.fit(dataset, epochs=1, steps_per_epoch=2)

     if context.executing_eagerly():
       # Test with eager execution
@@ -870,7 +869,7 @@ class TrainingTest(keras_parameterized.TestCase):
       model.fit(x_train, y_train, batch_size=5, epochs=1)

       # Test with eager execution and iterator
-      model.fit(iterator, epochs=1, steps_per_epoch=2)
+      model.fit(dataset, epochs=1, steps_per_epoch=2)

   def test_losses_in_defun(self):
     with context.eager_mode():
@@ -178,11 +178,11 @@ def train_on_batch(
          (in case the model has multiple inputs).
        - A dict mapping input names to the corresponding array/tensors,
          if the model has named inputs.
-       - A `tf.data` dataset or a dataset iterator.
+       - A `tf.data` dataset.
      y: Target data. Like the input data `x`, it could be either Numpy
        array(s) or TensorFlow tensor(s). It should be consistent with `x`
        (you cannot have Numpy inputs and tensor targets, or inversely). If
-       `x` is a dataset or a dataset iterator, `y` should not be specified
+       `x` is a dataset, `y` should not be specified
        (since targets will be obtained from the iterator).
      sample_weight: Optional array of the same length as x, containing
        weights to apply to the model's loss for each sample. In the case of
@@ -190,7 +190,7 @@ def train_on_batch(
        sequence_length), to apply a different weight to every timestep of
        every sample. In this case you should make sure to specify
        sample_weight_mode="temporal" in compile(). This argument is not
-       supported when `x` is a dataset or a dataset iterator.
+       supported when `x` is a dataset.
      class_weight: Optional dictionary mapping class indices (integers) to a
        weight (float) to apply to the model's loss for the samples from this
        class during training. This can be useful to tell the model to "pay
|
|||||||
(in case the model has multiple inputs).
|
(in case the model has multiple inputs).
|
||||||
- A dict mapping input names to the corresponding array/tensors,
|
- A dict mapping input names to the corresponding array/tensors,
|
||||||
if the model has named inputs.
|
if the model has named inputs.
|
||||||
- A `tf.data` dataset or a dataset iterator.
|
- A `tf.data` dataset.
|
||||||
y: Target data. Like the input data `x`,
|
y: Target data. Like the input data `x`,
|
||||||
it could be either Numpy array(s) or TensorFlow tensor(s).
|
it could be either Numpy array(s) or TensorFlow tensor(s).
|
||||||
It should be consistent with `x` (you cannot have Numpy inputs and
|
It should be consistent with `x` (you cannot have Numpy inputs and
|
||||||
tensor targets, or inversely). If `x` is a dataset or a
|
tensor targets, or inversely). If `x` is a dataset,
|
||||||
dataset iterator, `y` should not be specified
|
`y` should not be specified
|
||||||
(since targets will be obtained from the iterator).
|
(since targets will be obtained from the iterator).
|
||||||
sample_weight: Optional array of the same length as x, containing
|
sample_weight: Optional array of the same length as x, containing
|
||||||
weights to apply to the model's loss for each sample.
|
weights to apply to the model's loss for each sample.
|
||||||
@@ -263,7 +263,7 @@ def test_on_batch(model, x, y=None, sample_weight=None, reset_metrics=True):
        to apply a different weight to every timestep of every sample.
        In this case you should make sure to specify
        sample_weight_mode="temporal" in compile(). This argument is not
-       supported when `x` is a dataset or a dataset iterator.
+       supported when `x` is a dataset.
      reset_metrics: If `True`, the metrics returned will be only for this
        batch. If `False`, the metrics will be statefully accumulated across
        batches.
@@ -310,7 +310,7 @@ def predict_on_batch(model, x):
          (in case the model has multiple inputs).
        - A TensorFlow tensor, or a list of tensors
          (in case the model has multiple inputs).
-       - A `tf.data` dataset or a dataset iterator.
+       - A `tf.data` dataset.

   Returns:
       Numpy array(s) of predictions.
@@ -646,7 +646,7 @@ class ModelSubclassCompiledTest(keras_parameterized.TestCase):
     model.fit([x1, x2], [y1, y2], epochs=2, batch_size=32, verbose=0)
     _ = model.evaluate([x1, x2], [y1, y2], verbose=0)

-  def test_single_io_workflow_with_dataset_iterators(self):
+  def test_single_io_workflow_with_datasets(self):
     num_classes = 2
     num_samples = 10
     input_dim = 50
@@ -664,10 +664,9 @@ class ModelSubclassCompiledTest(keras_parameterized.TestCase):
     dataset = dataset_ops.Dataset.from_tensor_slices((x, y))
     dataset = dataset.repeat(100)
     dataset = dataset.batch(10)
-    iterator = dataset_ops.make_one_shot_iterator(dataset)

-    model.fit(iterator, epochs=2, steps_per_epoch=10, verbose=0)
-    _ = model.evaluate(iterator, steps=10, verbose=0)
+    model.fit(dataset, epochs=2, steps_per_epoch=10, verbose=0)
+    _ = model.evaluate(dataset, steps=10, verbose=0)

   def test_attributes(self):
     # layers, weights, trainable_weights, non_trainable_weights, inputs, outputs