From ea96844d61de15addf2b85930310ae7c11bda0b0 Mon Sep 17 00:00:00 2001 From: Thomas O'Malley Date: Tue, 2 Jun 2020 23:40:33 -0700 Subject: [PATCH] Export unpack_x_y_sample_weight and pack_x_y_sample_weight These are useful utilities when overriding Model.train_step. PiperOrigin-RevId: 314477485 Change-Id: I4b425ae6a16285cbb5115562c563eeccbbb943d0 --- tensorflow/python/keras/api/BUILD | 1 + .../python/keras/engine/data_adapter.py | 81 ++++++++++++++++++- .../golden/v2/tensorflow.keras.utils.pbtxt | 8 ++ 3 files changed, 86 insertions(+), 4 deletions(-) diff --git a/tensorflow/python/keras/api/BUILD b/tensorflow/python/keras/api/BUILD index b393c2006e3..ff54400ae15 100644 --- a/tensorflow/python/keras/api/BUILD +++ b/tensorflow/python/keras/api/BUILD @@ -42,6 +42,7 @@ keras_packages = [ "tensorflow.python.keras.datasets.mnist", "tensorflow.python.keras.datasets.reuters", "tensorflow.python.keras.engine.base_layer", + "tensorflow.python.keras.engine.data_adapter", "tensorflow.python.keras.engine.input_layer", "tensorflow.python.keras.engine.input_spec", "tensorflow.python.keras.engine.sequential", diff --git a/tensorflow/python/keras/engine/data_adapter.py b/tensorflow/python/keras/engine/data_adapter.py index fdfd0af722f..ecfcfda589b 100644 --- a/tensorflow/python/keras/engine/data_adapter.py +++ b/tensorflow/python/keras/engine/data_adapter.py @@ -51,6 +51,7 @@ from tensorflow.python.ops import random_ops from tensorflow.python.ops import script_ops from tensorflow.python.platform import tf_logging as logging from tensorflow.python.util import nest +from tensorflow.python.util.tf_export import keras_export try: from scipy import sparse as scipy_sparse # pylint: disable=g-import-not-at-top @@ -1434,8 +1435,54 @@ def train_validation_split(arrays, validation_split, shuffle=True): return train_arrays, val_arrays +@keras_export("keras.utils.unpack_x_y_sample_weight", v1=[]) def unpack_x_y_sample_weight(data): - """Unpacks user-provided data tuple.""" + """Unpacks 
user-provided data tuple. + + This is a convenience utility to be used when overriding + `Model.train_step`, `Model.test_step`, or `Model.predict_step`. + This utility makes it easy to support data of the form `(x,)`, + `(x, y)`, or `(x, y, sample_weight)`. + + Standalone usage: + + >>> features_batch = tf.ones((10, 5)) + >>> labels_batch = tf.zeros((10, 5)) + >>> data = (features_batch, labels_batch) + >>> # `y` and `sample_weight` will default to `None` if not provided. + >>> x, y, sample_weight = tf.keras.utils.unpack_x_y_sample_weight(data) + >>> sample_weight is None + True + + Example in overridden `Model.train_step`: + + ```python + class MyModel(tf.keras.Model): + + def train_step(self, data): + # If `sample_weight` is not provided, all samples will be weighted + # equally. + x, y, sample_weight = tf.keras.utils.unpack_x_y_sample_weight(data) + + with tf.GradientTape() as tape: + y_pred = self(x, training=True) + loss = self.compiled_loss( + y, y_pred, sample_weight, regularization_losses=self.losses) + trainable_variables = self.trainable_variables + gradients = tape.gradient(loss, trainable_variables) + self.optimizer.apply_gradients(zip(gradients, trainable_variables)) + + self.compiled_metrics.update_state(y, y_pred, sample_weight) + return {m.name: m.result() for m in self.metrics} + ``` + + Arguments: + data: `x` alone, or a tuple `(x,)`, `(x, y)`, or `(x, y, sample_weight)`. + + Returns: + The unpacked tuple, with `None`s for `y` and `sample_weight` if they are not + provided. 
+ """ if not isinstance(data, tuple): return (data, None, None) elif len(data) == 1: @@ -1444,12 +1491,38 @@ def unpack_x_y_sample_weight(data): return (data[0], data[1], None) elif len(data) == 3: return (data[0], data[1], data[2]) - - raise ValueError("Data not understood.") + else: + error_msg = ("Data is expected to be in format `x`, `(x,)`, `(x, y)`, " + "or `(x, y, sample_weight)`, found: {}").format(data) + raise ValueError(error_msg) +@keras_export("keras.utils.pack_x_y_sample_weight", v1=[]) def pack_x_y_sample_weight(x, y=None, sample_weight=None): - """Packs user-provided data into a tuple.""" + """Packs user-provided data into a tuple. + + This is a convenience utility for packing data into the tuple formats + that `Model.fit` uses. + + Standalone usage: + + >>> x = tf.ones((10, 1)) + >>> data = tf.keras.utils.pack_x_y_sample_weight(x) + >>> len(data) + 1 + >>> y = tf.ones((10, 1)) + >>> data = tf.keras.utils.pack_x_y_sample_weight(x, y) + >>> len(data) + 2 + + Arguments: + x: Features to pass to `Model`. + y: Ground-truth targets to pass to `Model`. + sample_weight: Sample weight for each element. + + Returns: + Tuple in the format used in `Model.fit`. 
+ """ if y is None: return (x,) elif sample_weight is None: diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.utils.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.utils.pbtxt index ae616a1a620..9068d986446 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.utils.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.utils.pbtxt @@ -72,6 +72,10 @@ tf_module { name: "normalize" argspec: "args=[\'x\', \'axis\', \'order\'], varargs=None, keywords=None, defaults=[\'-1\', \'2\'], " } + member_method { + name: "pack_x_y_sample_weight" + argspec: "args=[\'x\', \'y\', \'sample_weight\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } member_method { name: "plot_model" argspec: "args=[\'model\', \'to_file\', \'show_shapes\', \'show_layer_names\', \'rankdir\', \'expand_nested\', \'dpi\'], varargs=None, keywords=None, defaults=[\'model.png\', \'False\', \'True\', \'TB\', \'False\', \'96\'], " @@ -88,4 +92,8 @@ tf_module { name: "to_categorical" argspec: "args=[\'y\', \'num_classes\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'float32\'], " } + member_method { + name: "unpack_x_y_sample_weight" + argspec: "args=[\'data\'], varargs=None, keywords=None, defaults=None" + } }