Export unpack_x_y_sample_weight and pack_x_y_sample_weight
These are useful utilities when overriding Model.train_step. PiperOrigin-RevId: 314477485 Change-Id: I4b425ae6a16285cbb5115562c563eeccbbb943d0
This commit is contained in:
		
							parent
							
								
									d4bcf76529
								
							
						
					
					
						commit
						ea96844d61
					
				| @ -42,6 +42,7 @@ keras_packages = [ | |||||||
|     "tensorflow.python.keras.datasets.mnist", |     "tensorflow.python.keras.datasets.mnist", | ||||||
|     "tensorflow.python.keras.datasets.reuters", |     "tensorflow.python.keras.datasets.reuters", | ||||||
|     "tensorflow.python.keras.engine.base_layer", |     "tensorflow.python.keras.engine.base_layer", | ||||||
|  |     "tensorflow.python.keras.engine.data_adapter", | ||||||
|     "tensorflow.python.keras.engine.input_layer", |     "tensorflow.python.keras.engine.input_layer", | ||||||
|     "tensorflow.python.keras.engine.input_spec", |     "tensorflow.python.keras.engine.input_spec", | ||||||
|     "tensorflow.python.keras.engine.sequential", |     "tensorflow.python.keras.engine.sequential", | ||||||
|  | |||||||
| @ -51,6 +51,7 @@ from tensorflow.python.ops import random_ops | |||||||
| from tensorflow.python.ops import script_ops | from tensorflow.python.ops import script_ops | ||||||
| from tensorflow.python.platform import tf_logging as logging | from tensorflow.python.platform import tf_logging as logging | ||||||
| from tensorflow.python.util import nest | from tensorflow.python.util import nest | ||||||
|  | from tensorflow.python.util.tf_export import keras_export | ||||||
| 
 | 
 | ||||||
| try: | try: | ||||||
|   from scipy import sparse as scipy_sparse  # pylint: disable=g-import-not-at-top |   from scipy import sparse as scipy_sparse  # pylint: disable=g-import-not-at-top | ||||||
| @ -1434,8 +1435,54 @@ def train_validation_split(arrays, validation_split, shuffle=True): | |||||||
|   return train_arrays, val_arrays |   return train_arrays, val_arrays | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  | @keras_export("keras.utils.unpack_x_y_sample_weight", v1=[]) | ||||||
| def unpack_x_y_sample_weight(data): | def unpack_x_y_sample_weight(data): | ||||||
|   """Unpacks user-provided data tuple.""" |   """Unpacks user-provided data tuple. | ||||||
|  | 
 | ||||||
|  |   This is a convenience utility to be used when overriding | ||||||
|  |   `Model.train_step`, `Model.test_step`, or `Model.predict_step`. | ||||||
|  |   This utility makes it easy to support data of the form `(x,)`, | ||||||
|  |   `(x, y)`, or `(x, y, sample_weight)`. | ||||||
|  | 
 | ||||||
|  |   Standalone usage: | ||||||
|  | 
 | ||||||
|  |   >>> features_batch = tf.ones((10, 5)) | ||||||
|  |   >>> labels_batch = tf.zeros((10, 5)) | ||||||
|  |   >>> data = (features_batch, labels_batch) | ||||||
|  |   >>> # `y` and `sample_weight` will default to `None` if not provided. | ||||||
|  |   >>> x, y, sample_weight = tf.keras.utils.unpack_x_y_sample_weight(data) | ||||||
|  |   >>> sample_weight is None | ||||||
|  |   True | ||||||
|  | 
 | ||||||
|  |   Example in overridden `Model.train_step`: | ||||||
|  | 
 | ||||||
|  |   ```python | ||||||
|  |   class MyModel(tf.keras.Model): | ||||||
|  | 
 | ||||||
|  |     def train_step(self, data): | ||||||
|  |       # If `sample_weight` is not provided, all samples will be weighted | ||||||
|  |       # equally. | ||||||
|  |       x, y, sample_weight = tf.keras.utils.unpack_x_y_sample_weight(data) | ||||||
|  | 
 | ||||||
|  |       with tf.GradientTape() as tape: | ||||||
|  |         y_pred = self(x, training=True) | ||||||
|  |         loss = self.compiled_loss( | ||||||
|  |           y, y_pred, sample_weight, regularization_losses=self.losses) | ||||||
|  |         trainable_variables = self.trainable_variables | ||||||
|  |         gradients = tape.gradient(loss, trainable_variables) | ||||||
|  |         self.optimizer.apply_gradients(zip(gradients, trainable_variables)) | ||||||
|  | 
 | ||||||
|  |       self.compiled_metrics.update_state(y, y_pred, sample_weight) | ||||||
|  |       return {m.name: m.result() for m in self.metrics} | ||||||
|  |   ``` | ||||||
|  | 
 | ||||||
|  |   Arguments: | ||||||
|  |     data: A tuple of the form `(x,)`, `(x, y)`, or `(x, y, sample_weight)`. | ||||||
|  | 
 | ||||||
|  |   Returns: | ||||||
|  |     The unpacked tuple, with `None`s for `y` and `sample_weight` if they are not | ||||||
|  |     provided. | ||||||
|  |   """ | ||||||
|   if not isinstance(data, tuple): |   if not isinstance(data, tuple): | ||||||
|     return (data, None, None) |     return (data, None, None) | ||||||
|   elif len(data) == 1: |   elif len(data) == 1: | ||||||
| @ -1444,12 +1491,38 @@ def unpack_x_y_sample_weight(data): | |||||||
|     return (data[0], data[1], None) |     return (data[0], data[1], None) | ||||||
|   elif len(data) == 3: |   elif len(data) == 3: | ||||||
|     return (data[0], data[1], data[2]) |     return (data[0], data[1], data[2]) | ||||||
| 
 |   else: | ||||||
|   raise ValueError("Data not understood.") |     error_msg = ("Data is expected to be in format `x`, `(x,)`, `(x, y)`, " | ||||||
|  |                  "or `(x, y, sample_weight)`, found: {}").format(data) | ||||||
|  |     raise ValueError(error_msg) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  | @keras_export("keras.utils.pack_x_y_sample_weight", v1=[]) | ||||||
| def pack_x_y_sample_weight(x, y=None, sample_weight=None): | def pack_x_y_sample_weight(x, y=None, sample_weight=None): | ||||||
|   """Packs user-provided data into a tuple.""" |   """Packs user-provided data into a tuple. | ||||||
|  | 
 | ||||||
|  |   This is a convenience utility for packing data into the tuple formats | ||||||
|  |   that `Model.fit` uses. | ||||||
|  | 
 | ||||||
|  |   Standalone usage: | ||||||
|  | 
 | ||||||
|  |   >>> x = tf.ones((10, 1)) | ||||||
|  |   >>> data = tf.keras.utils.pack_x_y_sample_weight(x) | ||||||
|  |   >>> len(data) | ||||||
|  |   1 | ||||||
|  |   >>> y = tf.ones((10, 1)) | ||||||
|  |   >>> data = tf.keras.utils.pack_x_y_sample_weight(x, y) | ||||||
|  |   >>> len(data) | ||||||
|  |   2 | ||||||
|  | 
 | ||||||
|  |   Arguments: | ||||||
|  |     x: Features to pass to `Model`. | ||||||
|  |     y: Ground-truth targets to pass to `Model`. | ||||||
|  |     sample_weight: Sample weight for each element. | ||||||
|  | 
 | ||||||
|  |   Returns: | ||||||
|  |     Tuple in the format used in `Model.fit`. | ||||||
|  |   """ | ||||||
|   if y is None: |   if y is None: | ||||||
|     return (x,) |     return (x,) | ||||||
|   elif sample_weight is None: |   elif sample_weight is None: | ||||||
|  | |||||||
| @ -72,6 +72,10 @@ tf_module { | |||||||
|     name: "normalize" |     name: "normalize" | ||||||
|     argspec: "args=[\'x\', \'axis\', \'order\'], varargs=None, keywords=None, defaults=[\'-1\', \'2\'], " |     argspec: "args=[\'x\', \'axis\', \'order\'], varargs=None, keywords=None, defaults=[\'-1\', \'2\'], " | ||||||
|   } |   } | ||||||
|  |   member_method { | ||||||
|  |     name: "pack_x_y_sample_weight" | ||||||
|  |     argspec: "args=[\'x\', \'y\', \'sample_weight\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " | ||||||
|  |   } | ||||||
|   member_method { |   member_method { | ||||||
|     name: "plot_model" |     name: "plot_model" | ||||||
|     argspec: "args=[\'model\', \'to_file\', \'show_shapes\', \'show_layer_names\', \'rankdir\', \'expand_nested\', \'dpi\'], varargs=None, keywords=None, defaults=[\'model.png\', \'False\', \'True\', \'TB\', \'False\', \'96\'], " |     argspec: "args=[\'model\', \'to_file\', \'show_shapes\', \'show_layer_names\', \'rankdir\', \'expand_nested\', \'dpi\'], varargs=None, keywords=None, defaults=[\'model.png\', \'False\', \'True\', \'TB\', \'False\', \'96\'], " | ||||||
| @ -88,4 +92,8 @@ tf_module { | |||||||
|     name: "to_categorical" |     name: "to_categorical" | ||||||
|     argspec: "args=[\'y\', \'num_classes\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'float32\'], " |     argspec: "args=[\'y\', \'num_classes\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'float32\'], " | ||||||
|   } |   } | ||||||
|  |   member_method { | ||||||
|  |     name: "unpack_x_y_sample_weight" | ||||||
|  |     argspec: "args=[\'data\'], varargs=None, keywords=None, defaults=None" | ||||||
|  |   } | ||||||
| } | } | ||||||
|  | |||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user