Export unpack_x_y_sample_weight and pack_x_y_sample_weight
These are useful utilities when overriding Model.train_step. PiperOrigin-RevId: 314477485 Change-Id: I4b425ae6a16285cbb5115562c563eeccbbb943d0
This commit is contained in:
		
							parent
							
								
									d4bcf76529
								
							
						
					
					
						commit
						ea96844d61
					
				@ -42,6 +42,7 @@ keras_packages = [
 | 
			
		||||
    "tensorflow.python.keras.datasets.mnist",
 | 
			
		||||
    "tensorflow.python.keras.datasets.reuters",
 | 
			
		||||
    "tensorflow.python.keras.engine.base_layer",
 | 
			
		||||
    "tensorflow.python.keras.engine.data_adapter",
 | 
			
		||||
    "tensorflow.python.keras.engine.input_layer",
 | 
			
		||||
    "tensorflow.python.keras.engine.input_spec",
 | 
			
		||||
    "tensorflow.python.keras.engine.sequential",
 | 
			
		||||
 | 
			
		||||
@ -51,6 +51,7 @@ from tensorflow.python.ops import random_ops
 | 
			
		||||
from tensorflow.python.ops import script_ops
 | 
			
		||||
from tensorflow.python.platform import tf_logging as logging
 | 
			
		||||
from tensorflow.python.util import nest
 | 
			
		||||
from tensorflow.python.util.tf_export import keras_export
 | 
			
		||||
 | 
			
		||||
try:
 | 
			
		||||
  from scipy import sparse as scipy_sparse  # pylint: disable=g-import-not-at-top
 | 
			
		||||
@ -1434,8 +1435,54 @@ def train_validation_split(arrays, validation_split, shuffle=True):
 | 
			
		||||
  return train_arrays, val_arrays
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@keras_export("keras.utils.unpack_x_y_sample_weight", v1=[])
 | 
			
		||||
def unpack_x_y_sample_weight(data):
 | 
			
		||||
  """Unpacks user-provided data tuple."""
 | 
			
		||||
  """Unpacks user-provided data tuple.
 | 
			
		||||
 | 
			
		||||
  This is a convenience utility to be used when overriding
 | 
			
		||||
  `Model.train_step`, `Model.test_step`, or `Model.predict_step`.
 | 
			
		||||
  This utility makes it easy to support data of the form `(x,)`,
 | 
			
		||||
  `(x, y)`, or `(x, y, sample_weight)`.
 | 
			
		||||
 | 
			
		||||
  Standalone usage:
 | 
			
		||||
 | 
			
		||||
  >>> features_batch = tf.ones((10, 5))
 | 
			
		||||
  >>> labels_batch = tf.zeros((10, 5))
 | 
			
		||||
  >>> data = (features_batch, labels_batch)
 | 
			
		||||
  >>> # `y` and `sample_weight` will default to `None` if not provided.
 | 
			
		||||
  >>> x, y, sample_weight = tf.keras.utils.unpack_x_y_sample_weight(data)
 | 
			
		||||
  >>> sample_weight is None
 | 
			
		||||
  True
 | 
			
		||||
 | 
			
		||||
  Example in overridden `Model.train_step`:
 | 
			
		||||
 | 
			
		||||
  ```python
 | 
			
		||||
  class MyModel(tf.keras.Model):
 | 
			
		||||
 | 
			
		||||
    def train_step(self, data):
 | 
			
		||||
      # If `sample_weight` is not provided, all samples will be weighted
 | 
			
		||||
      # equally.
 | 
			
		||||
      x, y, sample_weight = tf.keras.utils.unpack_x_y_sample_weight(data)
 | 
			
		||||
 | 
			
		||||
      with tf.GradientTape() as tape:
 | 
			
		||||
        y_pred = self(x, training=True)
 | 
			
		||||
        loss = self.compiled_loss(
 | 
			
		||||
          y, y_pred, sample_weight, regularization_losses=self.losses)
 | 
			
		||||
        trainable_variables = self.trainable_variables
 | 
			
		||||
        gradients = tape.gradient(loss, trainable_variables)
 | 
			
		||||
        self.optimizer.apply_gradients(zip(gradients, trainable_variables))
 | 
			
		||||
 | 
			
		||||
      self.compiled_metrics.update_state(y, y_pred, sample_weight)
 | 
			
		||||
      return {m.name: m.result() for m in self.metrics}
 | 
			
		||||
  ```
 | 
			
		||||
 | 
			
		||||
  Arguments:
 | 
			
		||||
    data: A tuple of the form `(x,)`, `(x, y)`, or `(x, y, sample_weight)`.
 | 
			
		||||
 | 
			
		||||
  Returns:
 | 
			
		||||
    The unpacked tuple, with `None`s for `y` and `sample_weight` if they are not
 | 
			
		||||
    provided.
 | 
			
		||||
  """
 | 
			
		||||
  if not isinstance(data, tuple):
 | 
			
		||||
    return (data, None, None)
 | 
			
		||||
  elif len(data) == 1:
 | 
			
		||||
@ -1444,12 +1491,38 @@ def unpack_x_y_sample_weight(data):
 | 
			
		||||
    return (data[0], data[1], None)
 | 
			
		||||
  elif len(data) == 3:
 | 
			
		||||
    return (data[0], data[1], data[2])
 | 
			
		||||
 | 
			
		||||
  raise ValueError("Data not understood.")
 | 
			
		||||
  else:
 | 
			
		||||
    error_msg = ("Data is expected to be in format `x`, `(x,)`, `(x, y)`, "
 | 
			
		||||
                 "or `(x, y, sample_weight)`, found: {}").format(data)
 | 
			
		||||
    raise ValueError(error_msg)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@keras_export("keras.utils.pack_x_y_sample_weight", v1=[])
 | 
			
		||||
def pack_x_y_sample_weight(x, y=None, sample_weight=None):
 | 
			
		||||
  """Packs user-provided data into a tuple."""
 | 
			
		||||
  """Packs user-provided data into a tuple.
 | 
			
		||||
 | 
			
		||||
  This is a convenience utility for packing data into the tuple formats
 | 
			
		||||
  that `Model.fit` uses.
 | 
			
		||||
 | 
			
		||||
  Standalone usage:
 | 
			
		||||
 | 
			
		||||
  >>> x = tf.ones((10, 1))
 | 
			
		||||
  >>> data = tf.keras.utils.pack_x_y_sample_weight(x)
 | 
			
		||||
  >>> len(data)
 | 
			
		||||
  1
 | 
			
		||||
  >>> y = tf.ones((10, 1))
 | 
			
		||||
  >>> data = tf.keras.utils.pack_x_y_sample_weight(x, y)
 | 
			
		||||
  >>> len(data)
 | 
			
		||||
  2
 | 
			
		||||
 | 
			
		||||
  Arguments:
 | 
			
		||||
    x: Features to pass to `Model`.
 | 
			
		||||
    y: Ground-truth targets to pass to `Model`.
 | 
			
		||||
    sample_weight: Sample weight for each element.
 | 
			
		||||
 | 
			
		||||
  Returns:
 | 
			
		||||
    Tuple in the format used in `Model.fit`.
 | 
			
		||||
  """
 | 
			
		||||
  if y is None:
 | 
			
		||||
    return (x,)
 | 
			
		||||
  elif sample_weight is None:
 | 
			
		||||
 | 
			
		||||
@ -72,6 +72,10 @@ tf_module {
 | 
			
		||||
    name: "normalize"
 | 
			
		||||
    argspec: "args=[\'x\', \'axis\', \'order\'], varargs=None, keywords=None, defaults=[\'-1\', \'2\'], "
 | 
			
		||||
  }
 | 
			
		||||
  member_method {
 | 
			
		||||
    name: "pack_x_y_sample_weight"
 | 
			
		||||
    argspec: "args=[\'x\', \'y\', \'sample_weight\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
 | 
			
		||||
  }
 | 
			
		||||
  member_method {
 | 
			
		||||
    name: "plot_model"
 | 
			
		||||
    argspec: "args=[\'model\', \'to_file\', \'show_shapes\', \'show_layer_names\', \'rankdir\', \'expand_nested\', \'dpi\'], varargs=None, keywords=None, defaults=[\'model.png\', \'False\', \'True\', \'TB\', \'False\', \'96\'], "
 | 
			
		||||
@ -88,4 +92,8 @@ tf_module {
 | 
			
		||||
    name: "to_categorical"
 | 
			
		||||
    argspec: "args=[\'y\', \'num_classes\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'float32\'], "
 | 
			
		||||
  }
 | 
			
		||||
  member_method {
 | 
			
		||||
    name: "unpack_x_y_sample_weight"
 | 
			
		||||
    argspec: "args=[\'data\'], varargs=None, keywords=None, defaults=None"
 | 
			
		||||
  }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user