Small improvements to handling of Datasets in Keras.
* Allow sparse labels to work with Datasets. * Allow sample_weights to be passed as the third output of a Dataset (like how generator input is treated). PiperOrigin-RevId: 211834259
This commit is contained in:
parent
6d893ecfb9
commit
025277a159
@ -446,8 +446,7 @@ class TestWithDistributionStrategy(test.TestCase):
|
||||
dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets))
|
||||
dataset = dataset.repeat(100)
|
||||
|
||||
with self.assertRaisesRegexp(ValueError,
|
||||
'expected input to have 2 dimensions'):
|
||||
with self.assertRaisesRegexp(ValueError, 'expected input to have shape'):
|
||||
model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0)
|
||||
|
||||
# Wrong input shape
|
||||
|
@ -928,11 +928,16 @@ class Model(Network):
|
||||
'Make sure that your dataset can generate '
|
||||
'required number of samples.')
|
||||
|
||||
if not isinstance(next_element, (list, tuple)) or len(next_element) != 2:
|
||||
raise ValueError('Please provide model inputs as a list or tuple of 2 '
|
||||
'elements: input and target pair. '
|
||||
'Received %s' % next_element)
|
||||
x, y = next_element
|
||||
if (not isinstance(next_element, (list, tuple)) or
|
||||
len(next_element) not in [2, 3]):
|
||||
raise ValueError(
|
||||
'Please provide model inputs as a list or tuple of 2 or 3'
|
||||
'elements: (input, target) or (input, target, sample_weights)'
|
||||
'Received %s' % next_element)
|
||||
if len(next_element) == 2:
|
||||
x, y = next_element
|
||||
else:
|
||||
x, y, sample_weight = next_element
|
||||
x, y, sample_weights = self._standardize_weights(x, y, sample_weight,
|
||||
class_weight, batch_size)
|
||||
return x, y, sample_weights
|
||||
@ -1331,7 +1336,8 @@ class Model(Network):
|
||||
(in case the model has multiple inputs).
|
||||
- A dict mapping input names to the corresponding array/tensors,
|
||||
if the model has named inputs.
|
||||
- A `tf.data` dataset or a dataset iterator.
|
||||
- A `tf.data` dataset or a dataset iterator. Should return a tuple
|
||||
of either (inputs, targets) or (inputs, targets, sample_weights).
|
||||
y: Target data. Like the input data `x`,
|
||||
it could be either Numpy array(s) or TensorFlow tensor(s).
|
||||
It should be consistent with `x` (you cannot have Numpy inputs and
|
||||
@ -1396,7 +1402,8 @@ class Model(Network):
|
||||
to apply a different weight to every timestep of every sample.
|
||||
In this case you should make sure to specify
|
||||
`sample_weight_mode="temporal"` in `compile()`. This argument is not
|
||||
supported when `x` is a dataset or a dataset iterator.
|
||||
supported when `x` is a dataset or a dataset iterator, instead
|
||||
provide the sample_weights as the third element of `x`.
|
||||
initial_epoch: Integer.
|
||||
Epoch at which to start training
|
||||
(useful for resuming a previous training run).
|
||||
|
@ -417,11 +417,12 @@ def iterator_predict_loop(model, inputs, steps, verbose=0):
|
||||
"""
|
||||
assert isinstance(inputs, iterator_ops.EagerIterator)
|
||||
if not isinstance(inputs.output_shapes,
|
||||
(list, tuple)) or len(inputs.output_shapes) > 2:
|
||||
(list, tuple)) or len(inputs.output_shapes) > 3:
|
||||
raise ValueError(
|
||||
'Please provide data as a list or tuple of 1 or 2 elements '
|
||||
' - input or input and target pair. Received %s. We do not use the '
|
||||
'`target` value here.' % inputs.output_shapes)
|
||||
'Please provide data as a list or tuple of 1, 2, or 3 elements '
|
||||
' - `(input)`, or `(input, target)`, or `(input, target,'
|
||||
'sample_weights)`. Received %s. We do not use the `target` or'
|
||||
'`sample_weights` value here.' % inputs.output_shapes)
|
||||
outs = []
|
||||
if verbose == 1:
|
||||
progbar = generic_utils.Progbar(target=steps)
|
||||
|
@ -2097,6 +2097,43 @@ class TestTrainingWithDataset(test.TestCase):
|
||||
'you should specify the `steps` argument'):
|
||||
model.predict(dataset, verbose=0)
|
||||
|
||||
@tf_test_util.run_in_graph_and_eager_modes
|
||||
def test_dataset_with_sample_weights(self):
|
||||
model = testing_utils.get_small_functional_mlp(1, 4, input_dim=3)
|
||||
optimizer = RMSPropOptimizer(learning_rate=0.001)
|
||||
loss = 'mse'
|
||||
metrics = ['mae', metrics_module.CategoricalAccuracy()]
|
||||
model.compile(optimizer, loss, metrics=metrics)
|
||||
|
||||
inputs = np.zeros((10, 3), np.float32)
|
||||
targets = np.zeros((10, 4), np.float32)
|
||||
sample_weights = np.ones((10), np.float32)
|
||||
dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets,
|
||||
sample_weights))
|
||||
dataset = dataset.repeat(100)
|
||||
dataset = dataset.batch(10)
|
||||
|
||||
model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=1)
|
||||
model.evaluate(dataset, steps=2, verbose=1)
|
||||
model.predict(dataset, steps=2)
|
||||
model.train_on_batch(dataset)
|
||||
model.predict_on_batch(dataset)
|
||||
|
||||
@tf_test_util.run_in_graph_and_eager_modes
|
||||
def test_dataset_with_sparse_labels(self):
|
||||
model = testing_utils.get_small_functional_mlp(1, 4, input_dim=3)
|
||||
optimizer = RMSPropOptimizer(learning_rate=0.001)
|
||||
loss = 'sparse_categorical_crossentropy'
|
||||
model.compile(optimizer, loss)
|
||||
|
||||
inputs = np.zeros((10, 3))
|
||||
targets = np.random.randint(0, 4, size=10, dtype=np.int32)
|
||||
dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets))
|
||||
dataset = dataset.repeat(100)
|
||||
dataset = dataset.batch(10)
|
||||
|
||||
model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=1)
|
||||
|
||||
def test_dataset_input_shape_validation(self):
|
||||
with self.test_session():
|
||||
model = testing_utils.get_small_functional_mlp(1, 4, input_dim=3)
|
||||
@ -2108,8 +2145,10 @@ class TestTrainingWithDataset(test.TestCase):
|
||||
dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets))
|
||||
dataset = dataset.repeat(100)
|
||||
|
||||
with self.assertRaisesRegexp(ValueError,
|
||||
r'expected (.*?) to have 2 dimensions'):
|
||||
with self.assertRaisesRegexp(
|
||||
ValueError,
|
||||
r'expected (.*?) to have shape \(3,\) but got array with shape \(1,\)'
|
||||
):
|
||||
model.train_on_batch(dataset)
|
||||
|
||||
# Wrong input shape
|
||||
|
@ -210,10 +210,11 @@ def check_num_samples(ins,
|
||||
def standardize_single_array(x):
|
||||
if x is None:
|
||||
return None
|
||||
elif tensor_util.is_tensor(x):
|
||||
return x
|
||||
elif x.ndim == 1:
|
||||
x = np.expand_dims(x, 1)
|
||||
if x.shape is not None and len(x.shape) == 1:
|
||||
if tensor_util.is_tensor(x):
|
||||
return array_ops.expand_dims(x, axis=1)
|
||||
else:
|
||||
return np.expand_dims(x, 1)
|
||||
return x
|
||||
|
||||
|
||||
@ -341,7 +342,7 @@ def standardize_sample_or_class_weights(x_weight, output_names, weight_type):
|
||||
Raises:
|
||||
ValueError: In case of invalid user-provided argument.
|
||||
"""
|
||||
if x_weight is None or len(x_weight) == 0: # pylint: disable=g-explicit-length-test
|
||||
if x_weight is None or (isinstance(x_weight, list) and len(x_weight) == 0): # pylint: disable=g-explicit-length-test
|
||||
return [None for _ in output_names]
|
||||
if len(output_names) == 1:
|
||||
if isinstance(x_weight, list) and len(x_weight) == 1:
|
||||
@ -675,7 +676,8 @@ def standardize_weights(y,
|
||||
'Expected sample_weight with rank '
|
||||
'less than or equal to ' + str(len(y.shape)))
|
||||
|
||||
if y.shape[:sample_weight.ndim] != sample_weight.shape:
|
||||
if (not tensor_util.is_tensor(sample_weight) and
|
||||
y.shape[:sample_weight.ndim] != sample_weight.shape):
|
||||
raise ValueError(
|
||||
'Found a sample_weight array with shape ' + str(sample_weight.shape) +
|
||||
' for an input with shape ' + str(y.shape) + '. '
|
||||
@ -777,7 +779,9 @@ def validate_iterator_input(x, y, sample_weight, validation_split=None):
|
||||
'Received: %s' % (x, y))
|
||||
if sample_weight is not None:
|
||||
raise ValueError('`sample_weight` argument is not supported when input '
|
||||
'`x` is a dataset or a dataset iterator. '
|
||||
'`x` is a dataset or a dataset iterator. Instead, you'
|
||||
'can provide sample_weight as the third element of your'
|
||||
'dataset, i.e. (inputs, targets, sample_weight). '
|
||||
'Received: x=%s, sample_weight=%s' % (x, sample_weight))
|
||||
if validation_split is not None and validation_split != 0.0:
|
||||
raise ValueError(
|
||||
|
Loading…
Reference in New Issue
Block a user