Small improvements to handling of Datasets in Keras.
* Allow sparse labels to work with Datasets. * Allow sample_weights to be passed as the third output of a Dataset (like how generator input is treated). PiperOrigin-RevId: 211834259
This commit is contained in:
parent
6d893ecfb9
commit
025277a159
@ -446,8 +446,7 @@ class TestWithDistributionStrategy(test.TestCase):
|
|||||||
dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets))
|
dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets))
|
||||||
dataset = dataset.repeat(100)
|
dataset = dataset.repeat(100)
|
||||||
|
|
||||||
with self.assertRaisesRegexp(ValueError,
|
with self.assertRaisesRegexp(ValueError, 'expected input to have shape'):
|
||||||
'expected input to have 2 dimensions'):
|
|
||||||
model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0)
|
model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0)
|
||||||
|
|
||||||
# Wrong input shape
|
# Wrong input shape
|
||||||
|
@ -928,11 +928,16 @@ class Model(Network):
|
|||||||
'Make sure that your dataset can generate '
|
'Make sure that your dataset can generate '
|
||||||
'required number of samples.')
|
'required number of samples.')
|
||||||
|
|
||||||
if not isinstance(next_element, (list, tuple)) or len(next_element) != 2:
|
if (not isinstance(next_element, (list, tuple)) or
|
||||||
raise ValueError('Please provide model inputs as a list or tuple of 2 '
|
len(next_element) not in [2, 3]):
|
||||||
'elements: input and target pair. '
|
raise ValueError(
|
||||||
'Received %s' % next_element)
|
'Please provide model inputs as a list or tuple of 2 or 3'
|
||||||
x, y = next_element
|
'elements: (input, target) or (input, target, sample_weights)'
|
||||||
|
'Received %s' % next_element)
|
||||||
|
if len(next_element) == 2:
|
||||||
|
x, y = next_element
|
||||||
|
else:
|
||||||
|
x, y, sample_weight = next_element
|
||||||
x, y, sample_weights = self._standardize_weights(x, y, sample_weight,
|
x, y, sample_weights = self._standardize_weights(x, y, sample_weight,
|
||||||
class_weight, batch_size)
|
class_weight, batch_size)
|
||||||
return x, y, sample_weights
|
return x, y, sample_weights
|
||||||
@ -1331,7 +1336,8 @@ class Model(Network):
|
|||||||
(in case the model has multiple inputs).
|
(in case the model has multiple inputs).
|
||||||
- A dict mapping input names to the corresponding array/tensors,
|
- A dict mapping input names to the corresponding array/tensors,
|
||||||
if the model has named inputs.
|
if the model has named inputs.
|
||||||
- A `tf.data` dataset or a dataset iterator.
|
- A `tf.data` dataset or a dataset iterator. Should return a tuple
|
||||||
|
of either (inputs, targets) or (inputs, targets, sample_weights).
|
||||||
y: Target data. Like the input data `x`,
|
y: Target data. Like the input data `x`,
|
||||||
it could be either Numpy array(s) or TensorFlow tensor(s).
|
it could be either Numpy array(s) or TensorFlow tensor(s).
|
||||||
It should be consistent with `x` (you cannot have Numpy inputs and
|
It should be consistent with `x` (you cannot have Numpy inputs and
|
||||||
@ -1396,7 +1402,8 @@ class Model(Network):
|
|||||||
to apply a different weight to every timestep of every sample.
|
to apply a different weight to every timestep of every sample.
|
||||||
In this case you should make sure to specify
|
In this case you should make sure to specify
|
||||||
`sample_weight_mode="temporal"` in `compile()`. This argument is not
|
`sample_weight_mode="temporal"` in `compile()`. This argument is not
|
||||||
supported when `x` is a dataset or a dataset iterator.
|
supported when `x` is a dataset or a dataset iterator, instead
|
||||||
|
provide the sample_weights as the third element of `x`.
|
||||||
initial_epoch: Integer.
|
initial_epoch: Integer.
|
||||||
Epoch at which to start training
|
Epoch at which to start training
|
||||||
(useful for resuming a previous training run).
|
(useful for resuming a previous training run).
|
||||||
|
@ -417,11 +417,12 @@ def iterator_predict_loop(model, inputs, steps, verbose=0):
|
|||||||
"""
|
"""
|
||||||
assert isinstance(inputs, iterator_ops.EagerIterator)
|
assert isinstance(inputs, iterator_ops.EagerIterator)
|
||||||
if not isinstance(inputs.output_shapes,
|
if not isinstance(inputs.output_shapes,
|
||||||
(list, tuple)) or len(inputs.output_shapes) > 2:
|
(list, tuple)) or len(inputs.output_shapes) > 3:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
'Please provide data as a list or tuple of 1 or 2 elements '
|
'Please provide data as a list or tuple of 1, 2, or 3 elements '
|
||||||
' - input or input and target pair. Received %s. We do not use the '
|
' - `(input)`, or `(input, target)`, or `(input, target,'
|
||||||
'`target` value here.' % inputs.output_shapes)
|
'sample_weights)`. Received %s. We do not use the `target` or'
|
||||||
|
'`sample_weights` value here.' % inputs.output_shapes)
|
||||||
outs = []
|
outs = []
|
||||||
if verbose == 1:
|
if verbose == 1:
|
||||||
progbar = generic_utils.Progbar(target=steps)
|
progbar = generic_utils.Progbar(target=steps)
|
||||||
|
@ -2097,6 +2097,43 @@ class TestTrainingWithDataset(test.TestCase):
|
|||||||
'you should specify the `steps` argument'):
|
'you should specify the `steps` argument'):
|
||||||
model.predict(dataset, verbose=0)
|
model.predict(dataset, verbose=0)
|
||||||
|
|
||||||
|
@tf_test_util.run_in_graph_and_eager_modes
|
||||||
|
def test_dataset_with_sample_weights(self):
|
||||||
|
model = testing_utils.get_small_functional_mlp(1, 4, input_dim=3)
|
||||||
|
optimizer = RMSPropOptimizer(learning_rate=0.001)
|
||||||
|
loss = 'mse'
|
||||||
|
metrics = ['mae', metrics_module.CategoricalAccuracy()]
|
||||||
|
model.compile(optimizer, loss, metrics=metrics)
|
||||||
|
|
||||||
|
inputs = np.zeros((10, 3), np.float32)
|
||||||
|
targets = np.zeros((10, 4), np.float32)
|
||||||
|
sample_weights = np.ones((10), np.float32)
|
||||||
|
dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets,
|
||||||
|
sample_weights))
|
||||||
|
dataset = dataset.repeat(100)
|
||||||
|
dataset = dataset.batch(10)
|
||||||
|
|
||||||
|
model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=1)
|
||||||
|
model.evaluate(dataset, steps=2, verbose=1)
|
||||||
|
model.predict(dataset, steps=2)
|
||||||
|
model.train_on_batch(dataset)
|
||||||
|
model.predict_on_batch(dataset)
|
||||||
|
|
||||||
|
@tf_test_util.run_in_graph_and_eager_modes
|
||||||
|
def test_dataset_with_sparse_labels(self):
|
||||||
|
model = testing_utils.get_small_functional_mlp(1, 4, input_dim=3)
|
||||||
|
optimizer = RMSPropOptimizer(learning_rate=0.001)
|
||||||
|
loss = 'sparse_categorical_crossentropy'
|
||||||
|
model.compile(optimizer, loss)
|
||||||
|
|
||||||
|
inputs = np.zeros((10, 3))
|
||||||
|
targets = np.random.randint(0, 4, size=10, dtype=np.int32)
|
||||||
|
dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets))
|
||||||
|
dataset = dataset.repeat(100)
|
||||||
|
dataset = dataset.batch(10)
|
||||||
|
|
||||||
|
model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=1)
|
||||||
|
|
||||||
def test_dataset_input_shape_validation(self):
|
def test_dataset_input_shape_validation(self):
|
||||||
with self.test_session():
|
with self.test_session():
|
||||||
model = testing_utils.get_small_functional_mlp(1, 4, input_dim=3)
|
model = testing_utils.get_small_functional_mlp(1, 4, input_dim=3)
|
||||||
@ -2108,8 +2145,10 @@ class TestTrainingWithDataset(test.TestCase):
|
|||||||
dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets))
|
dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets))
|
||||||
dataset = dataset.repeat(100)
|
dataset = dataset.repeat(100)
|
||||||
|
|
||||||
with self.assertRaisesRegexp(ValueError,
|
with self.assertRaisesRegexp(
|
||||||
r'expected (.*?) to have 2 dimensions'):
|
ValueError,
|
||||||
|
r'expected (.*?) to have shape \(3,\) but got array with shape \(1,\)'
|
||||||
|
):
|
||||||
model.train_on_batch(dataset)
|
model.train_on_batch(dataset)
|
||||||
|
|
||||||
# Wrong input shape
|
# Wrong input shape
|
||||||
|
@ -210,10 +210,11 @@ def check_num_samples(ins,
|
|||||||
def standardize_single_array(x):
|
def standardize_single_array(x):
|
||||||
if x is None:
|
if x is None:
|
||||||
return None
|
return None
|
||||||
elif tensor_util.is_tensor(x):
|
if x.shape is not None and len(x.shape) == 1:
|
||||||
return x
|
if tensor_util.is_tensor(x):
|
||||||
elif x.ndim == 1:
|
return array_ops.expand_dims(x, axis=1)
|
||||||
x = np.expand_dims(x, 1)
|
else:
|
||||||
|
return np.expand_dims(x, 1)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
@ -341,7 +342,7 @@ def standardize_sample_or_class_weights(x_weight, output_names, weight_type):
|
|||||||
Raises:
|
Raises:
|
||||||
ValueError: In case of invalid user-provided argument.
|
ValueError: In case of invalid user-provided argument.
|
||||||
"""
|
"""
|
||||||
if x_weight is None or len(x_weight) == 0: # pylint: disable=g-explicit-length-test
|
if x_weight is None or (isinstance(x_weight, list) and len(x_weight) == 0): # pylint: disable=g-explicit-length-test
|
||||||
return [None for _ in output_names]
|
return [None for _ in output_names]
|
||||||
if len(output_names) == 1:
|
if len(output_names) == 1:
|
||||||
if isinstance(x_weight, list) and len(x_weight) == 1:
|
if isinstance(x_weight, list) and len(x_weight) == 1:
|
||||||
@ -675,7 +676,8 @@ def standardize_weights(y,
|
|||||||
'Expected sample_weight with rank '
|
'Expected sample_weight with rank '
|
||||||
'less than or equal to ' + str(len(y.shape)))
|
'less than or equal to ' + str(len(y.shape)))
|
||||||
|
|
||||||
if y.shape[:sample_weight.ndim] != sample_weight.shape:
|
if (not tensor_util.is_tensor(sample_weight) and
|
||||||
|
y.shape[:sample_weight.ndim] != sample_weight.shape):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
'Found a sample_weight array with shape ' + str(sample_weight.shape) +
|
'Found a sample_weight array with shape ' + str(sample_weight.shape) +
|
||||||
' for an input with shape ' + str(y.shape) + '. '
|
' for an input with shape ' + str(y.shape) + '. '
|
||||||
@ -777,7 +779,9 @@ def validate_iterator_input(x, y, sample_weight, validation_split=None):
|
|||||||
'Received: %s' % (x, y))
|
'Received: %s' % (x, y))
|
||||||
if sample_weight is not None:
|
if sample_weight is not None:
|
||||||
raise ValueError('`sample_weight` argument is not supported when input '
|
raise ValueError('`sample_weight` argument is not supported when input '
|
||||||
'`x` is a dataset or a dataset iterator. '
|
'`x` is a dataset or a dataset iterator. Instead, you'
|
||||||
|
'can provide sample_weight as the third element of your'
|
||||||
|
'dataset, i.e. (inputs, targets, sample_weight). '
|
||||||
'Received: x=%s, sample_weight=%s' % (x, sample_weight))
|
'Received: x=%s, sample_weight=%s' % (x, sample_weight))
|
||||||
if validation_split is not None and validation_split != 0.0:
|
if validation_split is not None and validation_split != 0.0:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
|
Loading…
Reference in New Issue
Block a user