Export unpack_x_y_sample_weight and pack_x_y_sample_weight

These are useful utilities when overriding Model.train_step.

PiperOrigin-RevId: 314477485
Change-Id: I4b425ae6a16285cbb5115562c563eeccbbb943d0
This commit is contained in:
Thomas O'Malley 2020-06-02 23:40:33 -07:00 committed by TensorFlower Gardener
parent d4bcf76529
commit ea96844d61
3 changed files with 86 additions and 4 deletions

View File

@ -42,6 +42,7 @@ keras_packages = [
"tensorflow.python.keras.datasets.mnist", "tensorflow.python.keras.datasets.mnist",
"tensorflow.python.keras.datasets.reuters", "tensorflow.python.keras.datasets.reuters",
"tensorflow.python.keras.engine.base_layer", "tensorflow.python.keras.engine.base_layer",
"tensorflow.python.keras.engine.data_adapter",
"tensorflow.python.keras.engine.input_layer", "tensorflow.python.keras.engine.input_layer",
"tensorflow.python.keras.engine.input_spec", "tensorflow.python.keras.engine.input_spec",
"tensorflow.python.keras.engine.sequential", "tensorflow.python.keras.engine.sequential",

View File

@ -51,6 +51,7 @@ from tensorflow.python.ops import random_ops
from tensorflow.python.ops import script_ops from tensorflow.python.ops import script_ops
from tensorflow.python.platform import tf_logging as logging from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import nest from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import keras_export
try: try:
from scipy import sparse as scipy_sparse # pylint: disable=g-import-not-at-top from scipy import sparse as scipy_sparse # pylint: disable=g-import-not-at-top
@ -1434,8 +1435,54 @@ def train_validation_split(arrays, validation_split, shuffle=True):
return train_arrays, val_arrays return train_arrays, val_arrays
@keras_export("keras.utils.unpack_x_y_sample_weight", v1=[])
def unpack_x_y_sample_weight(data): def unpack_x_y_sample_weight(data):
"""Unpacks user-provided data tuple.""" """Unpacks user-provided data tuple.
This is a convenience utility to be used when overriding
`Model.train_step`, `Model.test_step`, or `Model.predict_step`.
This utility makes it easy to support data of the form `(x,)`,
`(x, y)`, or `(x, y, sample_weight)`.
Standalone usage:
>>> features_batch = tf.ones((10, 5))
>>> labels_batch = tf.zeros((10, 5))
>>> data = (features_batch, labels_batch)
>>> # `y` and `sample_weight` will default to `None` if not provided.
>>> x, y, sample_weight = tf.keras.utils.unpack_x_y_sample_weight(data)
>>> sample_weight is None
True
Example in overridden `Model.train_step`:
```python
class MyModel(tf.keras.Model):
def train_step(self, data):
# If `sample_weight` is not provided, all samples will be weighted
# equally.
x, y, sample_weight = tf.keras.utils.unpack_x_y_sample_weight(data)
with tf.GradientTape() as tape:
y_pred = self(x, training=True)
loss = self.compiled_loss(
y, y_pred, sample_weight, regularization_losses=self.losses)
trainable_variables = self.trainable_variables
gradients = tape.gradient(loss, trainable_variables)
self.optimizer.apply_gradients(zip(gradients, trainable_variables))
self.compiled_metrics.update_state(y, y_pred, sample_weight)
return {m.name: m.result() for m in self.metrics}
```
Arguments:
data: A tuple of the form `(x,)`, `(x, y)`, or `(x, y, sample_weight)`.
Returns:
The unpacked tuple, with `None`s for `y` and `sample_weight` if they are not
provided.
"""
if not isinstance(data, tuple): if not isinstance(data, tuple):
return (data, None, None) return (data, None, None)
elif len(data) == 1: elif len(data) == 1:
@ -1444,12 +1491,38 @@ def unpack_x_y_sample_weight(data):
return (data[0], data[1], None) return (data[0], data[1], None)
elif len(data) == 3: elif len(data) == 3:
return (data[0], data[1], data[2]) return (data[0], data[1], data[2])
else:
raise ValueError("Data not understood.") error_msg = ("Data is expected to be in format `x`, `(x,)`, `(x, y)`, "
"or `(x, y, sample_weight)`, found: {}").format(data)
raise ValueError(error_msg)
@keras_export("keras.utils.pack_x_y_sample_weight", v1=[])
def pack_x_y_sample_weight(x, y=None, sample_weight=None): def pack_x_y_sample_weight(x, y=None, sample_weight=None):
"""Packs user-provided data into a tuple.""" """Packs user-provided data into a tuple.
This is a convenience utility for packing data into the tuple formats
that `Model.fit` uses.
Standalone usage:
>>> x = tf.ones((10, 1))
>>> data = tf.keras.utils.pack_x_y_sample_weight(x)
>>> len(data)
1
>>> y = tf.ones((10, 1))
>>> data = tf.keras.utils.pack_x_y_sample_weight(x, y)
>>> len(data)
2
Arguments:
x: Features to pass to `Model`.
y: Ground-truth targets to pass to `Model`.
sample_weight: Sample weight for each element.
Returns:
Tuple in the format used in `Model.fit`.
"""
if y is None: if y is None:
return (x,) return (x,)
elif sample_weight is None: elif sample_weight is None:

View File

@ -72,6 +72,10 @@ tf_module {
name: "normalize" name: "normalize"
argspec: "args=[\'x\', \'axis\', \'order\'], varargs=None, keywords=None, defaults=[\'-1\', \'2\'], " argspec: "args=[\'x\', \'axis\', \'order\'], varargs=None, keywords=None, defaults=[\'-1\', \'2\'], "
} }
member_method {
name: "pack_x_y_sample_weight"
argspec: "args=[\'x\', \'y\', \'sample_weight\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
}
member_method { member_method {
name: "plot_model" name: "plot_model"
argspec: "args=[\'model\', \'to_file\', \'show_shapes\', \'show_layer_names\', \'rankdir\', \'expand_nested\', \'dpi\'], varargs=None, keywords=None, defaults=[\'model.png\', \'False\', \'True\', \'TB\', \'False\', \'96\'], " argspec: "args=[\'model\', \'to_file\', \'show_shapes\', \'show_layer_names\', \'rankdir\', \'expand_nested\', \'dpi\'], varargs=None, keywords=None, defaults=[\'model.png\', \'False\', \'True\', \'TB\', \'False\', \'96\'], "
@ -88,4 +92,8 @@ tf_module {
name: "to_categorical" name: "to_categorical"
argspec: "args=[\'y\', \'num_classes\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'float32\'], " argspec: "args=[\'y\', \'num_classes\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'float32\'], "
} }
member_method {
name: "unpack_x_y_sample_weight"
argspec: "args=[\'data\'], varargs=None, keywords=None, defaults=None"
}
} }