Export unpack_x_y_sample_weight and pack_x_y_sample_weight
These are useful utilities when overriding Model.train_step. PiperOrigin-RevId: 314477485 Change-Id: I4b425ae6a16285cbb5115562c563eeccbbb943d0
This commit is contained in:
parent
d4bcf76529
commit
ea96844d61
|
@ -42,6 +42,7 @@ keras_packages = [
|
|||
"tensorflow.python.keras.datasets.mnist",
|
||||
"tensorflow.python.keras.datasets.reuters",
|
||||
"tensorflow.python.keras.engine.base_layer",
|
||||
"tensorflow.python.keras.engine.data_adapter",
|
||||
"tensorflow.python.keras.engine.input_layer",
|
||||
"tensorflow.python.keras.engine.input_spec",
|
||||
"tensorflow.python.keras.engine.sequential",
|
||||
|
|
|
@ -51,6 +51,7 @@ from tensorflow.python.ops import random_ops
|
|||
from tensorflow.python.ops import script_ops
|
||||
from tensorflow.python.platform import tf_logging as logging
|
||||
from tensorflow.python.util import nest
|
||||
from tensorflow.python.util.tf_export import keras_export
|
||||
|
||||
try:
|
||||
from scipy import sparse as scipy_sparse # pylint: disable=g-import-not-at-top
|
||||
|
@ -1434,8 +1435,54 @@ def train_validation_split(arrays, validation_split, shuffle=True):
|
|||
return train_arrays, val_arrays
|
||||
|
||||
|
||||
@keras_export("keras.utils.unpack_x_y_sample_weight", v1=[])
|
||||
def unpack_x_y_sample_weight(data):
|
||||
"""Unpacks user-provided data tuple."""
|
||||
"""Unpacks user-provided data tuple.
|
||||
|
||||
This is a convenience utility to be used when overriding
|
||||
`Model.train_step`, `Model.test_step`, or `Model.predict_step`.
|
||||
This utility makes it easy to support data of the form `(x,)`,
|
||||
`(x, y)`, or `(x, y, sample_weight)`.
|
||||
|
||||
Standalone usage:
|
||||
|
||||
>>> features_batch = tf.ones((10, 5))
|
||||
>>> labels_batch = tf.zeros((10, 5))
|
||||
>>> data = (features_batch, labels_batch)
|
||||
>>> # `y` and `sample_weight` will default to `None` if not provided.
|
||||
>>> x, y, sample_weight = tf.keras.utils.unpack_x_y_sample_weight(data)
|
||||
>>> sample_weight is None
|
||||
True
|
||||
|
||||
Example in overridden `Model.train_step`:
|
||||
|
||||
```python
|
||||
class MyModel(tf.keras.Model):
|
||||
|
||||
def train_step(self, data):
|
||||
# If `sample_weight` is not provided, all samples will be weighted
|
||||
# equally.
|
||||
x, y, sample_weight = tf.keras.utils.unpack_x_y_sample_weight(data)
|
||||
|
||||
with tf.GradientTape() as tape:
|
||||
y_pred = self(x, training=True)
|
||||
loss = self.compiled_loss(
|
||||
y, y_pred, sample_weight, regularization_losses=self.losses)
|
||||
trainable_variables = self.trainable_variables
|
||||
gradients = tape.gradient(loss, trainable_variables)
|
||||
self.optimizer.apply_gradients(zip(gradients, trainable_variables))
|
||||
|
||||
self.compiled_metrics.update_state(y, y_pred, sample_weight)
|
||||
return {m.name: m.result() for m in self.metrics}
|
||||
```
|
||||
|
||||
Arguments:
|
||||
data: A tuple of the form `(x,)`, `(x, y)`, or `(x, y, sample_weight)`.
|
||||
|
||||
Returns:
|
||||
The unpacked tuple, with `None`s for `y` and `sample_weight` if they are not
|
||||
provided.
|
||||
"""
|
||||
if not isinstance(data, tuple):
|
||||
return (data, None, None)
|
||||
elif len(data) == 1:
|
||||
|
@ -1444,12 +1491,38 @@ def unpack_x_y_sample_weight(data):
|
|||
return (data[0], data[1], None)
|
||||
elif len(data) == 3:
|
||||
return (data[0], data[1], data[2])
|
||||
|
||||
raise ValueError("Data not understood.")
|
||||
else:
|
||||
error_msg = ("Data is expected to be in format `x`, `(x,)`, `(x, y)`, "
|
||||
"or `(x, y, sample_weight)`, found: {}").format(data)
|
||||
raise ValueError(error_msg)
|
||||
|
||||
|
||||
@keras_export("keras.utils.pack_x_y_sample_weight", v1=[])
|
||||
def pack_x_y_sample_weight(x, y=None, sample_weight=None):
|
||||
"""Packs user-provided data into a tuple."""
|
||||
"""Packs user-provided data into a tuple.
|
||||
|
||||
This is a convenience utility for packing data into the tuple formats
|
||||
that `Model.fit` uses.
|
||||
|
||||
Standalone usage:
|
||||
|
||||
>>> x = tf.ones((10, 1))
|
||||
>>> data = tf.keras.utils.pack_x_y_sample_weight(x)
|
||||
>>> len(data)
|
||||
1
|
||||
>>> y = tf.ones((10, 1))
|
||||
>>> data = tf.keras.utils.pack_x_y_sample_weight(x, y)
|
||||
>>> len(data)
|
||||
2
|
||||
|
||||
Arguments:
|
||||
x: Features to pass to `Model`.
|
||||
y: Ground-truth targets to pass to `Model`.
|
||||
sample_weight: Sample weight for each element.
|
||||
|
||||
Returns:
|
||||
Tuple in the format used in `Model.fit`.
|
||||
"""
|
||||
if y is None:
|
||||
return (x,)
|
||||
elif sample_weight is None:
|
||||
|
|
|
@ -72,6 +72,10 @@ tf_module {
|
|||
name: "normalize"
|
||||
argspec: "args=[\'x\', \'axis\', \'order\'], varargs=None, keywords=None, defaults=[\'-1\', \'2\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "pack_x_y_sample_weight"
|
||||
argspec: "args=[\'x\', \'y\', \'sample_weight\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "plot_model"
|
||||
argspec: "args=[\'model\', \'to_file\', \'show_shapes\', \'show_layer_names\', \'rankdir\', \'expand_nested\', \'dpi\'], varargs=None, keywords=None, defaults=[\'model.png\', \'False\', \'True\', \'TB\', \'False\', \'96\'], "
|
||||
|
@ -88,4 +92,8 @@ tf_module {
|
|||
name: "to_categorical"
|
||||
argspec: "args=[\'y\', \'num_classes\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'float32\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "unpack_x_y_sample_weight"
|
||||
argspec: "args=[\'data\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue