Export unpack_x_y_sample_weight and pack_x_y_sample_weight
These are useful utilities when overriding Model.train_step. PiperOrigin-RevId: 314477485 Change-Id: I4b425ae6a16285cbb5115562c563eeccbbb943d0
This commit is contained in:
parent
d4bcf76529
commit
ea96844d61
|
@ -42,6 +42,7 @@ keras_packages = [
|
||||||
"tensorflow.python.keras.datasets.mnist",
|
"tensorflow.python.keras.datasets.mnist",
|
||||||
"tensorflow.python.keras.datasets.reuters",
|
"tensorflow.python.keras.datasets.reuters",
|
||||||
"tensorflow.python.keras.engine.base_layer",
|
"tensorflow.python.keras.engine.base_layer",
|
||||||
|
"tensorflow.python.keras.engine.data_adapter",
|
||||||
"tensorflow.python.keras.engine.input_layer",
|
"tensorflow.python.keras.engine.input_layer",
|
||||||
"tensorflow.python.keras.engine.input_spec",
|
"tensorflow.python.keras.engine.input_spec",
|
||||||
"tensorflow.python.keras.engine.sequential",
|
"tensorflow.python.keras.engine.sequential",
|
||||||
|
|
|
@ -51,6 +51,7 @@ from tensorflow.python.ops import random_ops
|
||||||
from tensorflow.python.ops import script_ops
|
from tensorflow.python.ops import script_ops
|
||||||
from tensorflow.python.platform import tf_logging as logging
|
from tensorflow.python.platform import tf_logging as logging
|
||||||
from tensorflow.python.util import nest
|
from tensorflow.python.util import nest
|
||||||
|
from tensorflow.python.util.tf_export import keras_export
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from scipy import sparse as scipy_sparse # pylint: disable=g-import-not-at-top
|
from scipy import sparse as scipy_sparse # pylint: disable=g-import-not-at-top
|
||||||
|
@ -1434,8 +1435,54 @@ def train_validation_split(arrays, validation_split, shuffle=True):
|
||||||
return train_arrays, val_arrays
|
return train_arrays, val_arrays
|
||||||
|
|
||||||
|
|
||||||
|
@keras_export("keras.utils.unpack_x_y_sample_weight", v1=[])
|
||||||
def unpack_x_y_sample_weight(data):
|
def unpack_x_y_sample_weight(data):
|
||||||
"""Unpacks user-provided data tuple."""
|
"""Unpacks user-provided data tuple.
|
||||||
|
|
||||||
|
This is a convenience utility to be used when overriding
|
||||||
|
`Model.train_step`, `Model.test_step`, or `Model.predict_step`.
|
||||||
|
This utility makes it easy to support data of the form `(x,)`,
|
||||||
|
`(x, y)`, or `(x, y, sample_weight)`.
|
||||||
|
|
||||||
|
Standalone usage:
|
||||||
|
|
||||||
|
>>> features_batch = tf.ones((10, 5))
|
||||||
|
>>> labels_batch = tf.zeros((10, 5))
|
||||||
|
>>> data = (features_batch, labels_batch)
|
||||||
|
>>> # `y` and `sample_weight` will default to `None` if not provided.
|
||||||
|
>>> x, y, sample_weight = tf.keras.utils.unpack_x_y_sample_weight(data)
|
||||||
|
>>> sample_weight is None
|
||||||
|
True
|
||||||
|
|
||||||
|
Example in overridden `Model.train_step`:
|
||||||
|
|
||||||
|
```python
|
||||||
|
class MyModel(tf.keras.Model):
|
||||||
|
|
||||||
|
def train_step(self, data):
|
||||||
|
# If `sample_weight` is not provided, all samples will be weighted
|
||||||
|
# equally.
|
||||||
|
x, y, sample_weight = tf.keras.utils.unpack_x_y_sample_weight(data)
|
||||||
|
|
||||||
|
with tf.GradientTape() as tape:
|
||||||
|
y_pred = self(x, training=True)
|
||||||
|
loss = self.compiled_loss(
|
||||||
|
y, y_pred, sample_weight, regularization_losses=self.losses)
|
||||||
|
trainable_variables = self.trainable_variables
|
||||||
|
gradients = tape.gradient(loss, trainable_variables)
|
||||||
|
self.optimizer.apply_gradients(zip(gradients, trainable_variables))
|
||||||
|
|
||||||
|
self.compiled_metrics.update_state(y, y_pred, sample_weight)
|
||||||
|
return {m.name: m.result() for m in self.metrics}
|
||||||
|
```
|
||||||
|
|
||||||
|
Arguments:
|
||||||
|
data: A tuple of the form `(x,)`, `(x, y)`, or `(x, y, sample_weight)`.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The unpacked tuple, with `None`s for `y` and `sample_weight` if they are not
|
||||||
|
provided.
|
||||||
|
"""
|
||||||
if not isinstance(data, tuple):
|
if not isinstance(data, tuple):
|
||||||
return (data, None, None)
|
return (data, None, None)
|
||||||
elif len(data) == 1:
|
elif len(data) == 1:
|
||||||
|
@ -1444,12 +1491,38 @@ def unpack_x_y_sample_weight(data):
|
||||||
return (data[0], data[1], None)
|
return (data[0], data[1], None)
|
||||||
elif len(data) == 3:
|
elif len(data) == 3:
|
||||||
return (data[0], data[1], data[2])
|
return (data[0], data[1], data[2])
|
||||||
|
else:
|
||||||
raise ValueError("Data not understood.")
|
error_msg = ("Data is expected to be in format `x`, `(x,)`, `(x, y)`, "
|
||||||
|
"or `(x, y, sample_weight)`, found: {}").format(data)
|
||||||
|
raise ValueError(error_msg)
|
||||||
|
|
||||||
|
|
||||||
|
@keras_export("keras.utils.pack_x_y_sample_weight", v1=[])
|
||||||
def pack_x_y_sample_weight(x, y=None, sample_weight=None):
|
def pack_x_y_sample_weight(x, y=None, sample_weight=None):
|
||||||
"""Packs user-provided data into a tuple."""
|
"""Packs user-provided data into a tuple.
|
||||||
|
|
||||||
|
This is a convenience utility for packing data into the tuple formats
|
||||||
|
that `Model.fit` uses.
|
||||||
|
|
||||||
|
Standalone usage:
|
||||||
|
|
||||||
|
>>> x = tf.ones((10, 1))
|
||||||
|
>>> data = tf.keras.utils.pack_x_y_sample_weight(x)
|
||||||
|
>>> len(data)
|
||||||
|
1
|
||||||
|
>>> y = tf.ones((10, 1))
|
||||||
|
>>> data = tf.keras.utils.pack_x_y_sample_weight(x, y)
|
||||||
|
>>> len(data)
|
||||||
|
2
|
||||||
|
|
||||||
|
Arguments:
|
||||||
|
x: Features to pass to `Model`.
|
||||||
|
y: Ground-truth targets to pass to `Model`.
|
||||||
|
sample_weight: Sample weight for each element.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple in the format used in `Model.fit`.
|
||||||
|
"""
|
||||||
if y is None:
|
if y is None:
|
||||||
return (x,)
|
return (x,)
|
||||||
elif sample_weight is None:
|
elif sample_weight is None:
|
||||||
|
|
|
@ -72,6 +72,10 @@ tf_module {
|
||||||
name: "normalize"
|
name: "normalize"
|
||||||
argspec: "args=[\'x\', \'axis\', \'order\'], varargs=None, keywords=None, defaults=[\'-1\', \'2\'], "
|
argspec: "args=[\'x\', \'axis\', \'order\'], varargs=None, keywords=None, defaults=[\'-1\', \'2\'], "
|
||||||
}
|
}
|
||||||
|
member_method {
|
||||||
|
name: "pack_x_y_sample_weight"
|
||||||
|
argspec: "args=[\'x\', \'y\', \'sample_weight\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
|
||||||
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "plot_model"
|
name: "plot_model"
|
||||||
argspec: "args=[\'model\', \'to_file\', \'show_shapes\', \'show_layer_names\', \'rankdir\', \'expand_nested\', \'dpi\'], varargs=None, keywords=None, defaults=[\'model.png\', \'False\', \'True\', \'TB\', \'False\', \'96\'], "
|
argspec: "args=[\'model\', \'to_file\', \'show_shapes\', \'show_layer_names\', \'rankdir\', \'expand_nested\', \'dpi\'], varargs=None, keywords=None, defaults=[\'model.png\', \'False\', \'True\', \'TB\', \'False\', \'96\'], "
|
||||||
|
@ -88,4 +92,8 @@ tf_module {
|
||||||
name: "to_categorical"
|
name: "to_categorical"
|
||||||
argspec: "args=[\'y\', \'num_classes\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'float32\'], "
|
argspec: "args=[\'y\', \'num_classes\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'float32\'], "
|
||||||
}
|
}
|
||||||
|
member_method {
|
||||||
|
name: "unpack_x_y_sample_weight"
|
||||||
|
argspec: "args=[\'data\'], varargs=None, keywords=None, defaults=None"
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue