Docstring fixes in Keras datasets, callbacks, and utilities

PiperOrigin-RevId: 304544666
Change-Id: I6049cf11ff6fe8090dfadd36cfcebe4b2443bf5f
Francois Chollet 2020-04-02 22:14:06 -07:00 committed by TensorFlower Gardener
parent 394b6bb711
commit 8a370a0077
10 changed files with 182 additions and 191 deletions

tensorflow/python/keras/callbacks.py

@@ -13,6 +13,7 @@
 # limitations under the License.
 # ==============================================================================
 # pylint: disable=g-import-not-at-top
+# pylint: disable=g-classes-have-attributes
 """Callbacks: utilities called at certain points during model training.
 """
 from __future__ import absolute_import
@@ -336,8 +337,8 @@ class CallbackList(object):
     This function should only be called during TRAIN mode.
     Arguments:
-      epoch: integer, index of epoch.
-      logs: dict. Currently no data is passed to this argument for this method
+      epoch: Integer, index of epoch.
+      logs: Dict. Currently no data is passed to this argument for this method
         but that may change in the future.
     """
     logs = logs or {}
@@ -357,8 +358,8 @@ class CallbackList(object):
     This function should only be called during TRAIN mode.
     Arguments:
-      epoch: integer, index of epoch.
-      logs: dict, metric results for this training epoch, and for the
+      epoch: Integer, index of epoch.
+      logs: Dict, metric results for this training epoch, and for the
         validation epoch if validation is performed. Validation result keys
         are prefixed with `val_`.
     """
@@ -376,8 +377,8 @@ class CallbackList(object):
     """Calls the `on_train_batch_begin` methods of its callbacks.
     Arguments:
-      batch: integer, index of batch within the current epoch.
-      logs: dict. Has keys `batch` and `size` representing the current batch
+      batch: Integer, index of batch within the current epoch.
+      logs: Dict. Has keys `batch` and `size` representing the current batch
         number and the size of the batch.
     """
     # TODO(b/150629188): Make ProgBarLogger callback not use batch hooks
@@ -389,8 +390,8 @@ class CallbackList(object):
     """Calls the `on_train_batch_end` methods of its callbacks.
     Arguments:
-      batch: integer, index of batch within the current epoch.
-      logs: dict. Aggregated metric results up until this batch.
+      batch: Integer, index of batch within the current epoch.
+      logs: Dict. Aggregated metric results up until this batch.
     """
     if self._should_call_train_batch_hooks:
       self._call_batch_hook(ModeKeys.TRAIN, 'end', batch, logs=logs)
@@ -399,8 +400,8 @@ class CallbackList(object):
     """Calls the `on_test_batch_begin` methods of its callbacks.
     Arguments:
-      batch: integer, index of batch within the current epoch.
-      logs: dict. Has keys `batch` and `size` representing the current batch
+      batch: Integer, index of batch within the current epoch.
+      logs: Dict. Has keys `batch` and `size` representing the current batch
         number and the size of the batch.
     """
     if self._should_call_test_batch_hooks:
@@ -410,8 +411,8 @@ class CallbackList(object):
     """Calls the `on_test_batch_end` methods of its callbacks.
     Arguments:
-      batch: integer, index of batch within the current epoch.
-      logs: dict. Aggregated metric results up until this batch.
+      batch: Integer, index of batch within the current epoch.
+      logs: Dict. Aggregated metric results up until this batch.
     """
     if self._should_call_test_batch_hooks:
       self._call_batch_hook(ModeKeys.TEST, 'end', batch, logs=logs)
@@ -420,8 +421,8 @@ class CallbackList(object):
     """Calls the `on_predict_batch_begin` methods of its callbacks.
     Arguments:
-      batch: integer, index of batch within the current epoch.
-      logs: dict. Has keys `batch` and `size` representing the current batch
+      batch: Integer, index of batch within the current epoch.
+      logs: Dict. Has keys `batch` and `size` representing the current batch
         number and the size of the batch.
     """
     if self._should_call_predict_batch_hooks:
@@ -431,8 +432,8 @@ class CallbackList(object):
     """Calls the `on_predict_batch_end` methods of its callbacks.
     Arguments:
-      batch: integer, index of batch within the current epoch.
-      logs: dict. Aggregated metric results up until this batch.
+      batch: Integer, index of batch within the current epoch.
+      logs: Dict. Aggregated metric results up until this batch.
     """
     if self._should_call_predict_batch_hooks:
       self._call_batch_hook(ModeKeys.PREDICT, 'end', batch, logs=logs)
@@ -441,7 +442,7 @@ class CallbackList(object):
     """Calls the `on_train_begin` methods of its callbacks.
     Arguments:
-      logs: dict. Currently no data is passed to this argument for this method
+      logs: Dict. Currently no data is passed to this argument for this method
         but that may change in the future.
     """
     logs = logs or {}
@@ -458,7 +459,7 @@ class CallbackList(object):
     """Calls the `on_train_end` methods of its callbacks.
     Arguments:
-      logs: dict. Currently no data is passed to this argument for this method
+      logs: Dict. Currently no data is passed to this argument for this method
        but that may change in the future.
     """
     logs = logs or {}
@@ -475,7 +476,7 @@ class CallbackList(object):
     """Calls the `on_test_begin` methods of its callbacks.
     Arguments:
-      logs: dict. Currently no data is passed to this argument for this method
+      logs: Dict. Currently no data is passed to this argument for this method
        but that may change in the future.
     """
     logs = logs or {}
@@ -492,7 +493,7 @@ class CallbackList(object):
     """Calls the `on_test_end` methods of its callbacks.
     Arguments:
-      logs: dict. Currently no data is passed to this argument for this method
+      logs: Dict. Currently no data is passed to this argument for this method
        but that may change in the future.
     """
     logs = logs or {}
@@ -509,7 +510,7 @@ class CallbackList(object):
     """Calls the 'on_predict_begin` methods of its callbacks.
     Arguments:
-      logs: dict. Currently no data is passed to this argument for this method
+      logs: Dict. Currently no data is passed to this argument for this method
        but that may change in the future.
     """
     logs = logs or {}
@@ -526,7 +527,7 @@ class CallbackList(object):
     """Calls the `on_predict_end` methods of its callbacks.
     Arguments:
-      logs: dict. Currently no data is passed to this argument for this method
+      logs: Dict. Currently no data is passed to this argument for this method
        but that may change in the future.
     """
     logs = logs or {}
@@ -548,9 +549,9 @@ class Callback(object):
   """Abstract base class used to build new callbacks.
   Attributes:
-    params: dict. Training parameters
+    params: Dict. Training parameters
       (eg. verbosity, batch size, number of epochs...).
-    model: instance of `keras.models.Model`.
+    model: Instance of `keras.models.Model`.
       Reference of the model being trained.
   The `logs` dictionary that callback methods
@@ -591,8 +592,8 @@ class Callback(object):
     be called during TRAIN mode.
     Arguments:
-      epoch: integer, index of epoch.
-      logs: dict. Currently no data is passed to this argument for this method
+      epoch: Integer, index of epoch.
+      logs: Dict. Currently no data is passed to this argument for this method
        but that may change in the future.
     """
@@ -604,8 +605,8 @@ class Callback(object):
     be called during TRAIN mode.
     Arguments:
-      epoch: integer, index of epoch.
-      logs: dict, metric results for this training epoch, and for the
+      epoch: Integer, index of epoch.
+      logs: Dict, metric results for this training epoch, and for the
        validation epoch if validation is performed. Validation result keys
        are prefixed with `val_`.
     """
@@ -618,8 +619,8 @@ class Callback(object):
     Subclasses should override for any actions to run.
     Arguments:
-      batch: integer, index of batch within the current epoch.
-      logs: dict. Has keys `batch` and `size` representing the current batch
+      batch: Integer, index of batch within the current epoch.
+      logs: Dict. Has keys `batch` and `size` representing the current batch
        number and the size of the batch.
     """
     # For backwards compatibility.
@@ -633,8 +634,8 @@ class Callback(object):
     Subclasses should override for any actions to run.
     Arguments:
-      batch: integer, index of batch within the current epoch.
-      logs: dict. Aggregated metric results up until this batch.
+      batch: Integer, index of batch within the current epoch.
+      logs: Dict. Aggregated metric results up until this batch.
     """
     # For backwards compatibility.
     self.on_batch_end(batch, logs=logs)
@@ -650,8 +651,8 @@ class Callback(object):
     Subclasses should override for any actions to run.
     Arguments:
-      batch: integer, index of batch within the current epoch.
-      logs: dict. Has keys `batch` and `size` representing the current batch
+      batch: Integer, index of batch within the current epoch.
+      logs: Dict. Has keys `batch` and `size` representing the current batch
        number and the size of the batch.
     """
@@ -666,8 +667,8 @@ class Callback(object):
     Subclasses should override for any actions to run.
     Arguments:
-      batch: integer, index of batch within the current epoch.
-      logs: dict. Aggregated metric results up until this batch.
+      batch: Integer, index of batch within the current epoch.
+      logs: Dict. Aggregated metric results up until this batch.
     """
   @doc_controls.for_subclass_implementers
@@ -678,8 +679,8 @@ class Callback(object):
     Subclasses should override for any actions to run.
     Arguments:
-      batch: integer, index of batch within the current epoch.
-      logs: dict. Has keys `batch` and `size` representing the current batch
+      batch: Integer, index of batch within the current epoch.
+      logs: Dict. Has keys `batch` and `size` representing the current batch
        number and the size of the batch.
     """
@@ -691,8 +692,8 @@ class Callback(object):
     Subclasses should override for any actions to run.
     Arguments:
-      batch: integer, index of batch within the current epoch.
-      logs: dict. Aggregated metric results up until this batch.
+      batch: Integer, index of batch within the current epoch.
+      logs: Dict. Aggregated metric results up until this batch.
     """
   @doc_controls.for_subclass_implementers
@@ -702,7 +703,7 @@ class Callback(object):
     Subclasses should override for any actions to run.
     Arguments:
-      logs: dict. Currently no data is passed to this argument for this method
+      logs: Dict. Currently no data is passed to this argument for this method
        but that may change in the future.
     """
@@ -713,7 +714,7 @@ class Callback(object):
     Subclasses should override for any actions to run.
     Arguments:
-      logs: dict. Currently no data is passed to this argument for this method
+      logs: Dict. Currently no data is passed to this argument for this method
        but that may change in the future.
     """
@@ -724,7 +725,7 @@ class Callback(object):
     Subclasses should override for any actions to run.
     Arguments:
-      logs: dict. Currently no data is passed to this argument for this method
+      logs: Dict. Currently no data is passed to this argument for this method
        but that may change in the future.
     """
@@ -735,7 +736,7 @@ class Callback(object):
     Subclasses should override for any actions to run.
     Arguments:
-      logs: dict. Currently no data is passed to this argument for this method
+      logs: Dict. Currently no data is passed to this argument for this method
        but that may change in the future.
     """
@@ -746,7 +747,7 @@ class Callback(object):
     Subclasses should override for any actions to run.
     Arguments:
-      logs: dict. Currently no data is passed to this argument for this method
+      logs: Dict. Currently no data is passed to this argument for this method
        but that may change in the future.
     """
@@ -757,7 +758,7 @@ class Callback(object):
     Subclasses should override for any actions to run.
     Arguments:
-      logs: dict. Currently no data is passed to this argument for this method
+      logs: Dict. Currently no data is passed to this argument for this method
        but that may change in the future.
     """
@@ -847,7 +848,7 @@ class ProgbarLogger(Callback):
   """Callback that prints metrics to stdout.
   Arguments:
-    count_mode: One of "steps" or "samples".
+    count_mode: One of `"steps"` or `"samples"`.
      Whether the progress bar should
      count samples seen or steps (batches) seen.
    stateful_metrics: Iterable of string names of metrics that
@@ -1411,7 +1412,7 @@ class EarlyStopping(Callback):
   """Stop training when a monitored metric has stopped improving.
   Assuming the goal of a training is to minimize the loss. With this, the
-  metric to be monitored would be 'loss', and mode would be 'min'. A
+  metric to be monitored would be `'loss'`, and mode would be `'min'`. A
   `model.fit()` training loop will check at end of every epoch whether
   the loss is no longer decreasing, considering the `min_delta` and
   `patience` if applicable. Once it's found no longer decreasing,
@@ -1420,6 +1421,30 @@ class EarlyStopping(Callback):
   The quantity to be monitored needs to be available in `logs` dict.
   To make it so, pass the loss or metrics at `model.compile()`.
+  Arguments:
+    monitor: Quantity to be monitored.
+    min_delta: Minimum change in the monitored quantity
+      to qualify as an improvement, i.e. an absolute
+      change of less than min_delta, will count as no
+      improvement.
+    patience: Number of epochs with no improvement
+      after which training will be stopped.
+    verbose: verbosity mode.
+    mode: One of `{"auto", "min", "max"}`. In `min` mode,
+      training will stop when the quantity
+      monitored has stopped decreasing; in `"max"`
+      mode it will stop when the quantity
+      monitored has stopped increasing; in `"auto"`
+      mode, the direction is automatically inferred
+      from the name of the monitored quantity.
+    baseline: Baseline value for the monitored quantity.
+      Training will stop if the model doesn't show improvement over the
+      baseline.
+    restore_best_weights: Whether to restore model weights from
+      the epoch with the best value of the monitored quantity.
+      If False, the model weights obtained at the last step of
+      training are used.
   Example:
   >>> callback = tf.keras.callbacks.EarlyStopping(monitor='loss', patience=3)
@@ -1442,32 +1467,6 @@
                mode='auto',
                baseline=None,
                restore_best_weights=False):
-    """Initialize an EarlyStopping callback.
-    Arguments:
-      monitor: Quantity to be monitored.
-      min_delta: Minimum change in the monitored quantity
-        to qualify as an improvement, i.e. an absolute
-        change of less than min_delta, will count as no
-        improvement.
-      patience: Number of epochs with no improvement
-        after which training will be stopped.
-      verbose: verbosity mode.
-      mode: One of `{"auto", "min", "max"}`. In `min` mode,
-        training will stop when the quantity
-        monitored has stopped decreasing; in `max`
-        mode it will stop when the quantity
-        monitored has stopped increasing; in `auto`
-        mode, the direction is automatically inferred
-        from the name of the monitored quantity.
-      baseline: Baseline value for the monitored quantity.
-        Training will stop if the model doesn't show improvement over the
-        baseline.
-      restore_best_weights: Whether to restore model weights from
-        the epoch with the best value of the monitored quantity.
-        If False, the model weights obtained at the last step of
-        training are used.
-    """
     super(EarlyStopping, self).__init__()
     self.monitor = monitor
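
A minimal sketch of the documented arguments in use, assuming a compiled `model` and training arrays named `x_train`/`y_train` (placeholders, not defined here):

```python
import tensorflow as tf

# Stop when `val_loss` fails to improve by at least 1e-3 for 5 epochs,
# then roll the model back to its best-performing weights.
early_stop = tf.keras.callbacks.EarlyStopping(
    monitor='val_loss', min_delta=1e-3, patience=5,
    restore_best_weights=True)
model.fit(x_train, y_train, validation_split=0.2, epochs=100,
          callbacks=[early_stop])
```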
@@ -1550,18 +1549,19 @@ class RemoteMonitor(Callback):
   Events are sent to `root + '/publish/epoch/end/'` by default. Calls are
   HTTP POST, with a `data` argument which is a
   JSON-encoded dictionary of event data.
-  If send_as_json is set to True, the content type of the request will be
-  application/json. Otherwise the serialized JSON will be sent within a form.
+  If `send_as_json=True`, the content type of the request will be
+  `"application/json"`.
+  Otherwise the serialized JSON will be sent within a form.
   Arguments:
     root: String; root url of the target server.
     path: String; path relative to `root` to which the events will be sent.
     field: String; JSON field under which the data will be stored.
       The field is used only if the payload is sent within a form
       (i.e. send_as_json is set to False).
     headers: Dictionary; optional custom HTTP headers.
     send_as_json: Boolean; whether the request should be
-      sent as application/json.
+      sent as `"application/json"`.
   """
   def __init__(self,
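
A hedged wiring sketch using the documented defaults (`http://localhost:9000` and `/publish/epoch/end/`); the endpoint is a placeholder for your own monitoring server:

```python
import tensorflow as tf

monitor = tf.keras.callbacks.RemoteMonitor(
    root='http://localhost:9000',  # placeholder monitoring server
    path='/publish/epoch/end/',
    field='data',
    send_as_json=True)             # POST body content type: application/json
# model.fit(x_train, y_train, callbacks=[monitor])
```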
@@ -1613,6 +1613,12 @@ class LearningRateScheduler(Callback):
   and current learning rate, and applies the updated learning rate
   on the optimizer.
+  Arguments:
+    schedule: a function that takes an epoch index (integer, indexed from 0)
+      and current learning rate (float) as inputs and returns a new
+      learning rate as output (float).
+    verbose: int. 0: quiet, 1: update messages.
   Example:
   >>> # This function keeps the initial learning rate for the first ten epochs
@@ -1637,14 +1643,6 @@
   """
   def __init__(self, schedule, verbose=0):
-    """Initialize a `keras.callbacks.LearningRateScheduler` callback.
-    Arguments:
-      schedule: a function that takes an epoch index (integer, indexed from 0)
-        and current learning rate (float) as inputs and returns a new
-        learning rate as output (float).
-      verbose: int. 0: quiet, 1: update messages.
-    """
    super(LearningRateScheduler, self).__init__()
    self.schedule = schedule
    self.verbose = verbose
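
A short sketch of the `schedule` contract described above, assuming a compiled `model` defined elsewhere:

```python
import tensorflow as tf

def schedule(epoch, lr):
  # Hold the initial rate for ten epochs, then decay it exponentially.
  if epoch < 10:
    return lr
  return lr * tf.math.exp(-0.1)

lr_callback = tf.keras.callbacks.LearningRateScheduler(schedule, verbose=1)
# model.fit(x_train, y_train, epochs=15, callbacks=[lr_callback])
```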
@@ -1689,7 +1687,7 @@ class TensorBoard(Callback, version_utils.TensorBoardVersionSelector):
   If you have installed TensorFlow with pip, you should be able
   to launch TensorBoard from the command line:
-  ```sh
+  ```
   tensorboard --logdir=path_to_your_logs
   ```
@@ -1697,24 +1695,27 @@ class TensorBoard(Callback, version_utils.TensorBoardVersionSelector):
   [here](https://www.tensorflow.org/get_started/summaries_and_tensorboard).
   Example (Basic):
   ```python
   tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir="./logs")
   model.fit(x_train, y_train, epochs=2, callbacks=[tensorboard_callback])
   # run the tensorboard command to view the visualizations.
   ```
   Example (Profile):
   ```python
   # profile a single batch, e.g. the 5th batch.
-  tensorboard_callback =
-      tf.keras.callbacks.TensorBoard(log_dir='./logs', profile_batch=5)
+  tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir='./logs',
+                                                        profile_batch=5)
   model.fit(x_train, y_train, epochs=2, callbacks=[tensorboard_callback])
-  # run the tensorboard command to view the visualizations in profile plugin.
+  # Now run the tensorboard command to view the visualizations (profile plugin).
   # profile a range of batches, e.g. from 10 to 20.
-  tensorboard_callback =
-      tf.keras.callbacks.TensorBoard(log_dir='./logs', profile_batch='10,20')
+  tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir='./logs',
+                                                        profile_batch='10,20')
   model.fit(x_train, y_train, epochs=2, callbacks=[tensorboard_callback])
-  # run the tensorboard command to view the visualizations in profile plugin.
+  # Now run the tensorboard command to view the visualizations (profile plugin).
   ```
   Arguments:
@@ -2143,14 +2144,15 @@ class ReduceLROnPlateau(Callback):
   Arguments:
     monitor: quantity to be monitored.
-    factor: factor by which the learning rate will be reduced. new_lr = lr *
-      factor
+    factor: factor by which the learning rate will be reduced.
+      `new_lr = lr * factor`.
     patience: number of epochs with no improvement after which learning rate
       will be reduced.
     verbose: int. 0: quiet, 1: update messages.
-    mode: one of {auto, min, max}. In `min` mode, lr will be reduced when the
-      quantity monitored has stopped decreasing; in `max` mode it will be
-      reduced when the quantity monitored has stopped increasing; in `auto`
+    mode: one of `{'auto', 'min', 'max'}`. In `'min'` mode,
+      the learning rate will be reduced when the
+      quantity monitored has stopped decreasing; in `'max'` mode it will be
+      reduced when the quantity monitored has stopped increasing; in `'auto'`
       mode, the direction is automatically inferred from the name of the
       monitored quantity.
     min_delta: threshold for measuring the new optimum, to only focus on
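
A minimal sketch of these arguments together; `model` and the data arrays are placeholders:

```python
import tensorflow as tf

# Halve the learning rate after 5 stagnant epochs, bottoming out at 1e-5.
reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(
    monitor='val_loss', factor=0.5, patience=5, min_lr=1e-5)
# model.fit(x_train, y_train, validation_split=0.2, callbacks=[reduce_lr])
```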
@@ -2249,10 +2251,10 @@
 @keras_export('keras.callbacks.CSVLogger')
 class CSVLogger(Callback):
-  """Callback that streams epoch results to a csv file.
+  """Callback that streams epoch results to a CSV file.
   Supports all values that can be represented as a string,
-  including 1D iterables such as np.ndarray.
+  including 1D iterables such as `np.ndarray`.
   Example:
@@ -2262,10 +2264,10 @@ class CSVLogger(Callback):
   ```
   Arguments:
-    filename: filename of the csv file, e.g. 'run/log.csv'.
-    separator: string used to separate elements in the csv file.
-    append: True: append if file exists (useful for continuing
-      training). False: overwrite existing file,
+    filename: Filename of the CSV file, e.g. `'run/log.csv'`.
+    separator: String used to separate elements in the CSV file.
+    append: Boolean. True: append if file exists (useful for continuing
+      training). False: overwrite existing file.
   """
   def __init__(self, filename, separator=',', append=False):
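
A sketch matching the documented constructor, with a hypothetical log path:

```python
import tensorflow as tf

csv_logger = tf.keras.callbacks.CSVLogger('training.log', append=False)
# model.fit(x_train, y_train, epochs=10, callbacks=[csv_logger])
# Each completed epoch appends one row of metrics to training.log.
```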
@@ -2348,12 +2350,12 @@ class LambdaCallback(Callback):
   at the appropriate time. Note that the callbacks expects positional
   arguments, as:
   - `on_epoch_begin` and `on_epoch_end` expect two positional arguments:
     `epoch`, `logs`
   - `on_batch_begin` and `on_batch_end` expect two positional arguments:
     `batch`, `logs`
   - `on_train_begin` and `on_train_end` expect one positional argument:
     `logs`
   Arguments:
     on_epoch_begin: called at the beginning of every epoch.
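
A minimal sketch of the positional-argument contract listed above:

```python
import tensorflow as tf

# Print the index of every batch as it begins; `batch` and `logs` are
# the two positional arguments that on_batch_begin receives.
batch_print = tf.keras.callbacks.LambdaCallback(
    on_batch_begin=lambda batch, logs: print('Starting batch', batch))
# model.fit(x_train, y_train, callbacks=[batch_print])
```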

tensorflow/python/keras/datasets/boston_housing.py

@@ -40,7 +40,7 @@ def load_data(path='boston_housing.npz', test_split=0.2, seed=113):
   Arguments:
     path: path where to cache the dataset locally
-      (relative to ~/.keras/datasets).
+      (relative to `~/.keras/datasets`).
     test_split: fraction of the data to reserve as test set.
     seed: Random seed for shuffling the data
       before computing the test split.
@@ -48,10 +48,11 @@ def load_data(path='boston_housing.npz', test_split=0.2, seed=113):
   Returns:
     Tuple of Numpy arrays: `(x_train, y_train), (x_test, y_test)`.
-    **x_train, x_test**: numpy arrays with shape (num_samples, 13) containing
-      either the training samples (for x_train), or test samples (for y_train)
-    **y_train, y_test**: numpy arrays of shape (num_samples, ) containing the
+    **x_train, x_test**: numpy arrays with shape `(num_samples, 13)`
+      containing either the training samples (for x_train),
+      or test samples (for y_train).
+    **y_train, y_test**: numpy arrays of shape `(num_samples,)` containing the
       target scalars. The targets are float scalars typically between 10 and
       50 that represent the home prices in k$.
   """

tensorflow/python/keras/datasets/cifar10.py

@@ -40,9 +40,9 @@ def load_data():
   Tuple of Numpy arrays: `(x_train, y_train), (x_test, y_test)`.
   **x_train, x_test**: uint8 arrays of RGB image data with shape
-    (num_samples, 3, 32, 32) if the `tf.keras.backend.image_data_format` is
-    'channels_first', or (num_samples, 32, 32, 3) if the data format
-    is 'channels_last'.
+    `(num_samples, 3, 32, 32)` if `tf.keras.backend.image_data_format()` is
+    `'channels_first'`, or `(num_samples, 32, 32, 3)` if the data format
+    is `'channels_last'`.
   **y_train, y_test**: uint8 arrays of category labels
     (integers in range 0-9) each with shape (num_samples, 1).
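
For illustration, the default `'channels_last'` layout looks like:

```python
from tensorflow.keras.datasets import cifar10

(x_train, y_train), (x_test, y_test) = cifar10.load_data()
print(x_train.shape, x_train.dtype)  # (50000, 32, 32, 3) uint8
print(y_train.shape)                 # (50000, 1), labels 0-9
```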

tensorflow/python/keras/datasets/cifar100.py

@@ -46,9 +46,9 @@ def load_data(label_mode='fine'):
   Tuple of Numpy arrays: `(x_train, y_train), (x_test, y_test)`.
   **x_train, x_test**: uint8 arrays of RGB image data with shape
-    (num_samples, 3, 32, 32) if the `tf.keras.backend.image_data_format` is
-    'channels_first', or (num_samples, 32, 32, 3) if the data format
-    is 'channels_last'.
+    `(num_samples, 3, 32, 32)` if `tf.keras.backend.image_data_format()` is
+    `'channels_first'`, or `(num_samples, 32, 32, 3)` if the data format
+    is `'channels_last'`.
   **y_train, y_test**: uint8 arrays of category labels with shape
     (num_samples, 1).

tensorflow/python/keras/datasets/imdb.py

@@ -80,7 +80,7 @@ def load_data(path='imdb.npz',
   **x_train, x_test**: lists of sequences, which are lists of indexes
     (integers). If the num_words argument was specific, the maximum
-    possible index value is num_words-1. If the `maxlen` argument was
+    possible index value is `num_words - 1`. If the `maxlen` argument was
     specified, the largest possible sequence length is `maxlen`.
   **y_train, y_test**: lists of integer labels (1 or 0).
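
A sketch of the `num_words` cap described above:

```python
from tensorflow.keras.datasets import imdb

# Keep only the 10,000 most frequent words; rarer ones map to an OOV index.
(x_train, y_train), (x_test, y_test) = imdb.load_data(num_words=10000)
assert max(max(seq) for seq in x_train) <= 10000 - 1
```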

tensorflow/python/keras/datasets/mnist.py

@@ -31,12 +31,12 @@ def load_data(path='mnist.npz'):
   This is a dataset of 60,000 28x28 grayscale images of the 10 digits,
   along with a test set of 10,000 images.
   More info can be found at the
-  (MNIST homepage)[http://yann.lecun.com/exdb/mnist/].
+  [MNIST homepage](http://yann.lecun.com/exdb/mnist/).
   Arguments:
     path: path where to cache the dataset locally
-      (relative to ~/.keras/datasets).
+      (relative to `~/.keras/datasets`).
   Returns:
     Tuple of Numpy arrays: `(x_train, y_train), (x_test, y_test)`.
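
A minimal load-and-scale sketch:

```python
from tensorflow.keras.datasets import mnist

(x_train, y_train), (x_test, y_test) = mnist.load_data()
print(x_train.shape, x_train.dtype)  # (60000, 28, 28) uint8
x_train = x_train / 255.0            # scale pixels to [0, 1] before training
```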

tensorflow/python/keras/datasets/reuters.py

@@ -42,11 +42,11 @@ def load_data(path='reuters.npz',
   """Loads the Reuters newswire classification dataset.
   This is a dataset of 11,228 newswires from Reuters, labeled over 46 topics.
   This was originally generated by parsing and preprocessing the classic
   Reuters-21578 dataset, but the preprocessing code is no longer packaged
-  with Keras.
-  See this [github discussion](https://github.com/keras-team/keras/issues/12072)
+  with Keras. See this
+  [github discussion](https://github.com/keras-team/keras/issues/12072)
   for more info.
   Each newswire is encoded as a list of word indexes (integers).
@@ -91,7 +91,7 @@ def load_data(path='reuters.npz',
   **x_train, x_test**: lists of sequences, which are lists of indexes
     (integers). If the num_words argument was specific, the maximum
-    possible index value is num_words-1. If the `maxlen` argument was
+    possible index value is `num_words - 1`. If the `maxlen` argument was
     specified, the largest possible sequence length is `maxlen`.
   **y_train, y_test**: lists of integer labels (1 or 0).
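
A quick sketch; the split sizes assume the default `test_split=0.2`:

```python
from tensorflow.keras.datasets import reuters

(x_train, y_train), (x_test, y_test) = reuters.load_data(num_words=10000)
print(len(x_train), len(x_test))  # 8982 2246
print(y_train[:5])                # integer topic labels (46 topics)
```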

tensorflow/python/keras/utils/data_utils.py

@@ -195,9 +195,9 @@ def get_file(fname,
   Arguments:
     fname: Name of the file. If an absolute path `/path/to/file.txt` is
       specified the file will be saved at that location.
     origin: Original URL of the file.
-    untar: Deprecated in favor of 'extract'.
+    untar: Deprecated in favor of `extract` argument.
       boolean, whether the file should be decompressed
-    md5_hash: Deprecated in favor of 'file_hash'.
+    md5_hash: Deprecated in favor of `file_hash` argument.
       md5 hash of the file for verification
     file_hash: The expected hash string of the file after download.
       The sha256 and md5 hash algorithms are both supported.
@@ -205,17 +205,16 @@
       saved. If an absolute path `/path/to/folder` is
       specified the file will be saved at that location.
     hash_algorithm: Select the hash algorithm to verify the file.
-      options are 'md5', 'sha256', and 'auto'.
+      options are `'md5'`, `'sha256'`, and `'auto'`.
       The default 'auto' detects the hash algorithm in use.
     extract: True tries extracting the file as an Archive, like tar or zip.
     archive_format: Archive format to try for extracting the file.
-      Options are 'auto', 'tar', 'zip', and None.
-      'tar' includes tar, tar.gz, and tar.bz files.
-      The default 'auto' is ['tar', 'zip'].
+      Options are `'auto'`, `'tar'`, `'zip'`, and `None`.
+      `'tar'` includes tar, tar.gz, and tar.bz files.
+      The default `'auto'` corresponds to `['tar', 'zip']`.
       None or an empty list will return no matches found.
     cache_dir: Location to store cached files, when None it
-      defaults to the [Keras
-      Directory](/faq/#where-is-the-keras-configuration-filed-stored).
+      defaults to the default directory `~/.keras/`.
   Returns:
     Path to the downloaded file
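
A hedged sketch of `get_file` with `extract=True`; the URL points at a TensorFlow-hosted sample archive and is only illustrative:

```python
import tensorflow as tf

path = tf.keras.utils.get_file(
    'flower_photos.tgz',
    'https://storage.googleapis.com/download.tensorflow.org/'
    'example_images/flower_photos.tgz',
    extract=True)  # also unpacks the tarball next to the cached file
print(path)        # cached under ~/.keras/datasets/ by default
```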
@@ -315,8 +314,8 @@ def _hash_file(fpath, algorithm='sha256', chunk_size=65535):
   Arguments:
     fpath: path to the file being validated
-    algorithm: hash algorithm, one of 'auto', 'sha256', or 'md5'.
-      The default 'auto' detects the hash algorithm in use.
+    algorithm: hash algorithm, one of `'auto'`, `'sha256'`, or `'md5'`.
+      The default `'auto'` detects the hash algorithm in use.
     chunk_size: Bytes to read at a time, important for large files.
   Returns:
@@ -420,32 +419,32 @@ class Sequence(object):
   Examples:
   ```python
   from skimage.io import imread
   from skimage.transform import resize
   import numpy as np
   import math

   # Here, `x_set` is list of path to the images
   # and `y_set` are the associated classes.

   class CIFAR10Sequence(Sequence):

       def __init__(self, x_set, y_set, batch_size):
           self.x, self.y = x_set, y_set
           self.batch_size = batch_size

       def __len__(self):
           return math.ceil(len(self.x) / self.batch_size)

       def __getitem__(self, idx):
           batch_x = self.x[idx * self.batch_size:(idx + 1) *
               self.batch_size]
           batch_y = self.y[idx * self.batch_size:(idx + 1) *
               self.batch_size]
           return np.array([
               resize(imread(file_name), (200, 200))
               for file_name in batch_x]), np.array(batch_y)
   ```
   """
@@ -485,10 +484,10 @@ def iter_sequence_infinite(seq):
   """Iterates indefinitely over a Sequence.
   Arguments:
-    seq: Sequence instance.
+    seq: `Sequence` instance.
   Yields:
-    Batches of data from the Sequence.
+    Batches of data from the `Sequence`.
   """
   while True:
     for item in seq:

tensorflow/python/keras/utils/generic_utils.py

@@ -234,7 +234,7 @@ def get_registered_object(name, custom_objects=None, module_objects=None):
 @keras_export('keras.utils.serialize_keras_object')
 def serialize_keras_object(instance):
-  """Serialize Keras object into JSON."""
+  """Serialize a Keras object into a JSON-compatible representation."""
   _, instance = tf_decorator.unwrap(instance)
   if instance is None:
     return None
@@ -327,6 +327,7 @@ def deserialize_keras_object(identifier,
                              module_objects=None,
                              custom_objects=None,
                              printable_module_name='object'):
+  """Turns the serialized form of a Keras object back into an actual object."""
   if identifier is None:
     return None
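
A round-trip sketch of the pair, serializing a layer and rebuilding an equivalent (untrained) one:

```python
import tensorflow as tf

layer = tf.keras.layers.Dense(3, activation='relu')
config = tf.keras.utils.serialize_keras_object(layer)
print(config['class_name'])  # 'Dense'

# tf.keras.layers.deserialize wraps deserialize_keras_object for layers.
restored = tf.keras.layers.deserialize(config)
```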

tensorflow/python/keras/utils/np_utils.py

@@ -27,7 +27,18 @@ def to_categorical(y, num_classes=None, dtype='float32'):
   E.g. for use with categorical_crossentropy.
-  Usage Example:
+  Arguments:
+    y: class vector to be converted into a matrix
+      (integers from 0 to num_classes).
+    num_classes: total number of classes. If `None`, this would be inferred
+      as the (largest number in `y`) + 1.
+    dtype: The data type expected by the input. Default: `'float32'`.
+  Returns:
+    A binary matrix representation of the input. The classes axis is placed
+    last.
+  Example:
   >>> a = tf.keras.utils.to_categorical([0, 1, 2, 3], num_classes=4)
   >>> a = tf.constant(a, shape=[4, 4])
@@ -51,29 +62,6 @@ def to_categorical(y, num_classes=None, dtype='float32'):
   >>> print(np.around(loss, 5))
   [0. 0. 0. 0.]
-  Arguments:
-    y: class vector to be converted into a matrix
-      (integers from 0 to num_classes).
-    num_classes: total number of classes. If `None`, this would be inferred
-      as the (largest number in `y`) + 1.
-    dtype: The data type expected by the input. Default: `'float32'`.
-  Returns:
-    A binary matrix representation of the input. The classes axis is placed
-    last.
-  Usage example:
-  >>> y = [0, 1, 2, 3, 3, 1, 0]
-  >>> tf.keras.utils.to_categorical(y, 4)
-  array([[1., 0., 0., 0.],
-         [0., 1., 0., 0.],
-         [0., 0., 1., 0.],
-         [0., 0., 0., 1.],
-         [0., 0., 0., 1.],
-         [0., 1., 0., 0.],
-         [1., 0., 0., 0.]], dtype=float32)
   Raises:
     Value Error: If input contains string value
@@ -100,7 +88,7 @@ def normalize(x, axis=-1, order=2):
   Arguments:
     x: Numpy array to normalize.
     axis: axis along which to normalize.
-    order: Normalization order (e.g. 2 for L2 norm).
+    order: Normalization order (e.g. `order=2` for L2 norm).
   Returns:
     A normalized copy of the array.
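
A one-line check of the L2 default:

```python
import numpy as np
import tensorflow as tf

x = np.array([[3.0, 4.0]])
print(tf.keras.utils.normalize(x, axis=-1, order=2))  # [[0.6 0.8]]
```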