Docstring fixes in Keras datasets, callbacks, and utilities

PiperOrigin-RevId: 304544666
Change-Id: I6049cf11ff6fe8090dfadd36cfcebe4b2443bf5f
This commit is contained in:
Francois Chollet 2020-04-02 22:14:06 -07:00 committed by TensorFlower Gardener
parent 394b6bb711
commit 8a370a0077
10 changed files with 182 additions and 191 deletions

View File

@ -13,6 +13,7 @@
# limitations under the License.
# ==============================================================================
# pylint: disable=g-import-not-at-top
# pylint: disable=g-classes-have-attributes
"""Callbacks: utilities called at certain points during model training.
"""
from __future__ import absolute_import
@ -336,8 +337,8 @@ class CallbackList(object):
This function should only be called during TRAIN mode.
Arguments:
epoch: integer, index of epoch.
logs: dict. Currently no data is passed to this argument for this method
epoch: Integer, index of epoch.
logs: Dict. Currently no data is passed to this argument for this method
but that may change in the future.
"""
logs = logs or {}
@ -357,8 +358,8 @@ class CallbackList(object):
This function should only be called during TRAIN mode.
Arguments:
epoch: integer, index of epoch.
logs: dict, metric results for this training epoch, and for the
epoch: Integer, index of epoch.
logs: Dict, metric results for this training epoch, and for the
validation epoch if validation is performed. Validation result keys
are prefixed with `val_`.
"""
@ -376,8 +377,8 @@ class CallbackList(object):
"""Calls the `on_train_batch_begin` methods of its callbacks.
Arguments:
batch: integer, index of batch within the current epoch.
logs: dict. Has keys `batch` and `size` representing the current batch
batch: Integer, index of batch within the current epoch.
logs: Dict. Has keys `batch` and `size` representing the current batch
number and the size of the batch.
"""
# TODO(b/150629188): Make ProgBarLogger callback not use batch hooks
@ -389,8 +390,8 @@ class CallbackList(object):
"""Calls the `on_train_batch_end` methods of its callbacks.
Arguments:
batch: integer, index of batch within the current epoch.
logs: dict. Aggregated metric results up until this batch.
batch: Integer, index of batch within the current epoch.
logs: Dict. Aggregated metric results up until this batch.
"""
if self._should_call_train_batch_hooks:
self._call_batch_hook(ModeKeys.TRAIN, 'end', batch, logs=logs)
@ -399,8 +400,8 @@ class CallbackList(object):
"""Calls the `on_test_batch_begin` methods of its callbacks.
Arguments:
batch: integer, index of batch within the current epoch.
logs: dict. Has keys `batch` and `size` representing the current batch
batch: Integer, index of batch within the current epoch.
logs: Dict. Has keys `batch` and `size` representing the current batch
number and the size of the batch.
"""
if self._should_call_test_batch_hooks:
@ -410,8 +411,8 @@ class CallbackList(object):
"""Calls the `on_test_batch_end` methods of its callbacks.
Arguments:
batch: integer, index of batch within the current epoch.
logs: dict. Aggregated metric results up until this batch.
batch: Integer, index of batch within the current epoch.
logs: Dict. Aggregated metric results up until this batch.
"""
if self._should_call_test_batch_hooks:
self._call_batch_hook(ModeKeys.TEST, 'end', batch, logs=logs)
@ -420,8 +421,8 @@ class CallbackList(object):
"""Calls the `on_predict_batch_begin` methods of its callbacks.
Arguments:
batch: integer, index of batch within the current epoch.
logs: dict. Has keys `batch` and `size` representing the current batch
batch: Integer, index of batch within the current epoch.
logs: Dict. Has keys `batch` and `size` representing the current batch
number and the size of the batch.
"""
if self._should_call_predict_batch_hooks:
@ -431,8 +432,8 @@ class CallbackList(object):
"""Calls the `on_predict_batch_end` methods of its callbacks.
Arguments:
batch: integer, index of batch within the current epoch.
logs: dict. Aggregated metric results up until this batch.
batch: Integer, index of batch within the current epoch.
logs: Dict. Aggregated metric results up until this batch.
"""
if self._should_call_predict_batch_hooks:
self._call_batch_hook(ModeKeys.PREDICT, 'end', batch, logs=logs)
@ -441,7 +442,7 @@ class CallbackList(object):
"""Calls the `on_train_begin` methods of its callbacks.
Arguments:
logs: dict. Currently no data is passed to this argument for this method
logs: Dict. Currently no data is passed to this argument for this method
but that may change in the future.
"""
logs = logs or {}
@ -458,7 +459,7 @@ class CallbackList(object):
"""Calls the `on_train_end` methods of its callbacks.
Arguments:
logs: dict. Currently no data is passed to this argument for this method
logs: Dict. Currently no data is passed to this argument for this method
but that may change in the future.
"""
logs = logs or {}
@ -475,7 +476,7 @@ class CallbackList(object):
"""Calls the `on_test_begin` methods of its callbacks.
Arguments:
logs: dict. Currently no data is passed to this argument for this method
logs: Dict. Currently no data is passed to this argument for this method
but that may change in the future.
"""
logs = logs or {}
@ -492,7 +493,7 @@ class CallbackList(object):
"""Calls the `on_test_end` methods of its callbacks.
Arguments:
logs: dict. Currently no data is passed to this argument for this method
logs: Dict. Currently no data is passed to this argument for this method
but that may change in the future.
"""
logs = logs or {}
@ -509,7 +510,7 @@ class CallbackList(object):
"""Calls the 'on_predict_begin` methods of its callbacks.
Arguments:
logs: dict. Currently no data is passed to this argument for this method
logs: Dict. Currently no data is passed to this argument for this method
but that may change in the future.
"""
logs = logs or {}
@ -526,7 +527,7 @@ class CallbackList(object):
"""Calls the `on_predict_end` methods of its callbacks.
Arguments:
logs: dict. Currently no data is passed to this argument for this method
logs: Dict. Currently no data is passed to this argument for this method
but that may change in the future.
"""
logs = logs or {}
@ -548,9 +549,9 @@ class Callback(object):
"""Abstract base class used to build new callbacks.
Attributes:
params: dict. Training parameters
params: Dict. Training parameters
(eg. verbosity, batch size, number of epochs...).
model: instance of `keras.models.Model`.
model: Instance of `keras.models.Model`.
Reference of the model being trained.
The `logs` dictionary that callback methods
@ -591,8 +592,8 @@ class Callback(object):
be called during TRAIN mode.
Arguments:
epoch: integer, index of epoch.
logs: dict. Currently no data is passed to this argument for this method
epoch: Integer, index of epoch.
logs: Dict. Currently no data is passed to this argument for this method
but that may change in the future.
"""
@ -604,8 +605,8 @@ class Callback(object):
be called during TRAIN mode.
Arguments:
epoch: integer, index of epoch.
logs: dict, metric results for this training epoch, and for the
epoch: Integer, index of epoch.
logs: Dict, metric results for this training epoch, and for the
validation epoch if validation is performed. Validation result keys
are prefixed with `val_`.
"""
@ -618,8 +619,8 @@ class Callback(object):
Subclasses should override for any actions to run.
Arguments:
batch: integer, index of batch within the current epoch.
logs: dict. Has keys `batch` and `size` representing the current batch
batch: Integer, index of batch within the current epoch.
logs: Dict. Has keys `batch` and `size` representing the current batch
number and the size of the batch.
"""
# For backwards compatibility.
@ -633,8 +634,8 @@ class Callback(object):
Subclasses should override for any actions to run.
Arguments:
batch: integer, index of batch within the current epoch.
logs: dict. Aggregated metric results up until this batch.
batch: Integer, index of batch within the current epoch.
logs: Dict. Aggregated metric results up until this batch.
"""
# For backwards compatibility.
self.on_batch_end(batch, logs=logs)
@ -650,8 +651,8 @@ class Callback(object):
Subclasses should override for any actions to run.
Arguments:
batch: integer, index of batch within the current epoch.
logs: dict. Has keys `batch` and `size` representing the current batch
batch: Integer, index of batch within the current epoch.
logs: Dict. Has keys `batch` and `size` representing the current batch
number and the size of the batch.
"""
@ -666,8 +667,8 @@ class Callback(object):
Subclasses should override for any actions to run.
Arguments:
batch: integer, index of batch within the current epoch.
logs: dict. Aggregated metric results up until this batch.
batch: Integer, index of batch within the current epoch.
logs: Dict. Aggregated metric results up until this batch.
"""
@doc_controls.for_subclass_implementers
@ -678,8 +679,8 @@ class Callback(object):
Subclasses should override for any actions to run.
Arguments:
batch: integer, index of batch within the current epoch.
logs: dict. Has keys `batch` and `size` representing the current batch
batch: Integer, index of batch within the current epoch.
logs: Dict. Has keys `batch` and `size` representing the current batch
number and the size of the batch.
"""
@ -691,8 +692,8 @@ class Callback(object):
Subclasses should override for any actions to run.
Arguments:
batch: integer, index of batch within the current epoch.
logs: dict. Aggregated metric results up until this batch.
batch: Integer, index of batch within the current epoch.
logs: Dict. Aggregated metric results up until this batch.
"""
@doc_controls.for_subclass_implementers
@ -702,7 +703,7 @@ class Callback(object):
Subclasses should override for any actions to run.
Arguments:
logs: dict. Currently no data is passed to this argument for this method
logs: Dict. Currently no data is passed to this argument for this method
but that may change in the future.
"""
@ -713,7 +714,7 @@ class Callback(object):
Subclasses should override for any actions to run.
Arguments:
logs: dict. Currently no data is passed to this argument for this method
logs: Dict. Currently no data is passed to this argument for this method
but that may change in the future.
"""
@ -724,7 +725,7 @@ class Callback(object):
Subclasses should override for any actions to run.
Arguments:
logs: dict. Currently no data is passed to this argument for this method
logs: Dict. Currently no data is passed to this argument for this method
but that may change in the future.
"""
@ -735,7 +736,7 @@ class Callback(object):
Subclasses should override for any actions to run.
Arguments:
logs: dict. Currently no data is passed to this argument for this method
logs: Dict. Currently no data is passed to this argument for this method
but that may change in the future.
"""
@ -746,7 +747,7 @@ class Callback(object):
Subclasses should override for any actions to run.
Arguments:
logs: dict. Currently no data is passed to this argument for this method
logs: Dict. Currently no data is passed to this argument for this method
but that may change in the future.
"""
@ -757,7 +758,7 @@ class Callback(object):
Subclasses should override for any actions to run.
Arguments:
logs: dict. Currently no data is passed to this argument for this method
logs: Dict. Currently no data is passed to this argument for this method
but that may change in the future.
"""
@ -847,7 +848,7 @@ class ProgbarLogger(Callback):
"""Callback that prints metrics to stdout.
Arguments:
count_mode: One of "steps" or "samples".
count_mode: One of `"steps"` or `"samples"`.
Whether the progress bar should
count samples seen or steps (batches) seen.
stateful_metrics: Iterable of string names of metrics that
@ -1411,7 +1412,7 @@ class EarlyStopping(Callback):
"""Stop training when a monitored metric has stopped improving.
Assuming the goal of a training is to minimize the loss. With this, the
metric to be monitored would be 'loss', and mode would be 'min'. A
metric to be monitored would be `'loss'`, and mode would be `'min'`. A
`model.fit()` training loop will check at end of every epoch whether
the loss is no longer decreasing, considering the `min_delta` and
`patience` if applicable. Once it's found no longer decreasing,
@ -1420,6 +1421,30 @@ class EarlyStopping(Callback):
The quantity to be monitored needs to be available in `logs` dict.
To make it so, pass the loss or metrics at `model.compile()`.
Arguments:
monitor: Quantity to be monitored.
min_delta: Minimum change in the monitored quantity
to qualify as an improvement, i.e. an absolute
change of less than min_delta, will count as no
improvement.
patience: Number of epochs with no improvement
after which training will be stopped.
verbose: verbosity mode.
mode: One of `{"auto", "min", "max"}`. In `min` mode,
training will stop when the quantity
monitored has stopped decreasing; in `"max"`
mode it will stop when the quantity
monitored has stopped increasing; in `"auto"`
mode, the direction is automatically inferred
from the name of the monitored quantity.
baseline: Baseline value for the monitored quantity.
Training will stop if the model doesn't show improvement over the
baseline.
restore_best_weights: Whether to restore model weights from
the epoch with the best value of the monitored quantity.
If False, the model weights obtained at the last step of
training are used.
Example:
>>> callback = tf.keras.callbacks.EarlyStopping(monitor='loss', patience=3)
@ -1442,32 +1467,6 @@ class EarlyStopping(Callback):
mode='auto',
baseline=None,
restore_best_weights=False):
"""Initialize an EarlyStopping callback.
Arguments:
monitor: Quantity to be monitored.
min_delta: Minimum change in the monitored quantity
to qualify as an improvement, i.e. an absolute
change of less than min_delta, will count as no
improvement.
patience: Number of epochs with no improvement
after which training will be stopped.
verbose: verbosity mode.
mode: One of `{"auto", "min", "max"}`. In `min` mode,
training will stop when the quantity
monitored has stopped decreasing; in `max`
mode it will stop when the quantity
monitored has stopped increasing; in `auto`
mode, the direction is automatically inferred
from the name of the monitored quantity.
baseline: Baseline value for the monitored quantity.
Training will stop if the model doesn't show improvement over the
baseline.
restore_best_weights: Whether to restore model weights from
the epoch with the best value of the monitored quantity.
If False, the model weights obtained at the last step of
training are used.
"""
super(EarlyStopping, self).__init__()
self.monitor = monitor
@ -1550,18 +1549,19 @@ class RemoteMonitor(Callback):
Events are sent to `root + '/publish/epoch/end/'` by default. Calls are
HTTP POST, with a `data` argument which is a
JSON-encoded dictionary of event data.
If send_as_json is set to True, the content type of the request will be
application/json. Otherwise the serialized JSON will be sent within a form.
If `send_as_json=True`, the content type of the request will be
`"application/json"`.
Otherwise the serialized JSON will be sent within a form.
Arguments:
root: String; root url of the target server.
path: String; path relative to `root` to which the events will be sent.
field: String; JSON field under which the data will be stored.
The field is used only if the payload is sent within a form
(i.e. send_as_json is set to False).
headers: Dictionary; optional custom HTTP headers.
send_as_json: Boolean; whether the request should be
sent as application/json.
root: String; root url of the target server.
path: String; path relative to `root` to which the events will be sent.
field: String; JSON field under which the data will be stored.
The field is used only if the payload is sent within a form
(i.e. send_as_json is set to False).
headers: Dictionary; optional custom HTTP headers.
send_as_json: Boolean; whether the request should be
sent as `"application/json"`.
"""
def __init__(self,
@ -1613,6 +1613,12 @@ class LearningRateScheduler(Callback):
and current learning rate, and applies the updated learning rate
on the optimizer.
Arguments:
schedule: a function that takes an epoch index (integer, indexed from 0)
and current learning rate (float) as inputs and returns a new
learning rate as output (float).
verbose: int. 0: quiet, 1: update messages.
Example:
>>> # This function keeps the initial learning rate for the first ten epochs
@ -1637,14 +1643,6 @@ class LearningRateScheduler(Callback):
"""
def __init__(self, schedule, verbose=0):
"""Initialize a `keras.callbacks.LearningRateScheduler` callback.
Arguments:
schedule: a function that takes an epoch index (integer, indexed from 0)
and current learning rate (float) as inputs and returns a new
learning rate as output (float).
verbose: int. 0: quiet, 1: update messages.
"""
super(LearningRateScheduler, self).__init__()
self.schedule = schedule
self.verbose = verbose
@ -1689,7 +1687,7 @@ class TensorBoard(Callback, version_utils.TensorBoardVersionSelector):
If you have installed TensorFlow with pip, you should be able
to launch TensorBoard from the command line:
```sh
```
tensorboard --logdir=path_to_your_logs
```
@ -1697,24 +1695,27 @@ class TensorBoard(Callback, version_utils.TensorBoardVersionSelector):
[here](https://www.tensorflow.org/get_started/summaries_and_tensorboard).
Example (Basic):
```python
tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir="./logs")
model.fit(x_train, y_train, epochs=2, callbacks=[tensorboard_callback])
# run the tensorboard command to view the visualizations.
```
Example (Profile):
```python
# profile a single batch, e.g. the 5th batch.
tensorboard_callback =
tf.keras.callbacks.TensorBoard(log_dir='./logs', profile_batch=5)
tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir='./logs',
profile_batch=5)
model.fit(x_train, y_train, epochs=2, callbacks=[tensorboard_callback])
# run the tensorboard command to view the visualizations in profile plugin.
# Now run the tensorboard command to view the visualizations (profile plugin).
# profile a range of batches, e.g. from 10 to 20.
tensorboard_callback =
tf.keras.callbacks.TensorBoard(log_dir='./logs', profile_batch='10,20')
tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir='./logs',
profile_batch='10,20')
model.fit(x_train, y_train, epochs=2, callbacks=[tensorboard_callback])
# run the tensorboard command to view the visualizations in profile plugin.
# Now run the tensorboard command to view the visualizations (profile plugin).
```
Arguments:
@ -2143,14 +2144,15 @@ class ReduceLROnPlateau(Callback):
Arguments:
monitor: quantity to be monitored.
factor: factor by which the learning rate will be reduced. new_lr = lr *
factor
factor: factor by which the learning rate will be reduced.
`new_lr = lr * factor`.
patience: number of epochs with no improvement after which learning rate
will be reduced.
verbose: int. 0: quiet, 1: update messages.
mode: one of {auto, min, max}. In `min` mode, lr will be reduced when the
quantity monitored has stopped decreasing; in `max` mode it will be
reduced when the quantity monitored has stopped increasing; in `auto`
mode: one of `{'auto', 'min', 'max'}`. In `'min'` mode,
the learning rate will be reduced when the
quantity monitored has stopped decreasing; in `'max'` mode it will be
reduced when the quantity monitored has stopped increasing; in `'auto'`
mode, the direction is automatically inferred from the name of the
monitored quantity.
min_delta: threshold for measuring the new optimum, to only focus on
@ -2249,10 +2251,10 @@ class ReduceLROnPlateau(Callback):
@keras_export('keras.callbacks.CSVLogger')
class CSVLogger(Callback):
"""Callback that streams epoch results to a csv file.
"""Callback that streams epoch results to a CSV file.
Supports all values that can be represented as a string,
including 1D iterables such as np.ndarray.
including 1D iterables such as `np.ndarray`.
Example:
@ -2262,10 +2264,10 @@ class CSVLogger(Callback):
```
Arguments:
filename: filename of the csv file, e.g. 'run/log.csv'.
separator: string used to separate elements in the csv file.
append: True: append if file exists (useful for continuing
training). False: overwrite existing file,
filename: Filename of the CSV file, e.g. `'run/log.csv'`.
separator: String used to separate elements in the CSV file.
append: Boolean. True: append if file exists (useful for continuing
training). False: overwrite existing file.
"""
def __init__(self, filename, separator=',', append=False):
@ -2348,12 +2350,12 @@ class LambdaCallback(Callback):
at the appropriate time. Note that the callbacks expects positional
arguments, as:
- `on_epoch_begin` and `on_epoch_end` expect two positional arguments:
`epoch`, `logs`
- `on_batch_begin` and `on_batch_end` expect two positional arguments:
`batch`, `logs`
- `on_train_begin` and `on_train_end` expect one positional argument:
`logs`
- `on_epoch_begin` and `on_epoch_end` expect two positional arguments:
`epoch`, `logs`
- `on_batch_begin` and `on_batch_end` expect two positional arguments:
`batch`, `logs`
- `on_train_begin` and `on_train_end` expect one positional argument:
`logs`
Arguments:
on_epoch_begin: called at the beginning of every epoch.

View File

@ -40,7 +40,7 @@ def load_data(path='boston_housing.npz', test_split=0.2, seed=113):
Arguments:
path: path where to cache the dataset locally
(relative to ~/.keras/datasets).
(relative to `~/.keras/datasets`).
test_split: fraction of the data to reserve as test set.
seed: Random seed for shuffling the data
before computing the test split.
@ -48,10 +48,11 @@ def load_data(path='boston_housing.npz', test_split=0.2, seed=113):
Returns:
Tuple of Numpy arrays: `(x_train, y_train), (x_test, y_test)`.
**x_train, x_test**: numpy arrays with shape (num_samples, 13) containing
either the training samples (for x_train), or test samples (for y_train)
**x_train, x_test**: numpy arrays with shape `(num_samples, 13)`
containing either the training samples (for x_train),
or test samples (for y_train).
**y_train, y_test**: numpy arrays of shape (num_samples, ) containing the
**y_train, y_test**: numpy arrays of shape `(num_samples,)` containing the
target scalars. The targets are float scalars typically between 10 and
50 that represent the home prices in k$.
"""

View File

@ -40,9 +40,9 @@ def load_data():
Tuple of Numpy arrays: `(x_train, y_train), (x_test, y_test)`.
**x_train, x_test**: uint8 arrays of RGB image data with shape
(num_samples, 3, 32, 32) if the `tf.keras.backend.image_data_format` is
'channels_first', or (num_samples, 32, 32, 3) if the data format
is 'channels_last'.
`(num_samples, 3, 32, 32)` if `tf.keras.backend.image_data_format()` is
`'channels_first'`, or `(num_samples, 32, 32, 3)` if the data format
is `'channels_last'`.
**y_train, y_test**: uint8 arrays of category labels
(integers in range 0-9) each with shape (num_samples, 1).

View File

@ -46,9 +46,9 @@ def load_data(label_mode='fine'):
Tuple of Numpy arrays: `(x_train, y_train), (x_test, y_test)`.
**x_train, x_test**: uint8 arrays of RGB image data with shape
(num_samples, 3, 32, 32) if the `tf.keras.backend.image_data_format` is
'channels_first', or (num_samples, 32, 32, 3) if the data format
is 'channels_last'.
`(num_samples, 3, 32, 32)` if `tf.keras.backend.image_data_format()` is
`'channels_first'`, or `(num_samples, 32, 32, 3)` if the data format
is `'channels_last'`.
**y_train, y_test**: uint8 arrays of category labels with shape
(num_samples, 1).

View File

@ -80,7 +80,7 @@ def load_data(path='imdb.npz',
**x_train, x_test**: lists of sequences, which are lists of indexes
(integers). If the num_words argument was specific, the maximum
possible index value is num_words-1. If the `maxlen` argument was
possible index value is `num_words - 1`. If the `maxlen` argument was
specified, the largest possible sequence length is `maxlen`.
**y_train, y_test**: lists of integer labels (1 or 0).

View File

@ -31,12 +31,12 @@ def load_data(path='mnist.npz'):
This is a dataset of 60,000 28x28 grayscale images of the 10 digits,
along with a test set of 10,000 images.
More info can be found at the
(MNIST homepage)[http://yann.lecun.com/exdb/mnist/].
[MNIST homepage](http://yann.lecun.com/exdb/mnist/).
Arguments:
path: path where to cache the dataset locally
(relative to ~/.keras/datasets).
(relative to `~/.keras/datasets`).
Returns:
Tuple of Numpy arrays: `(x_train, y_train), (x_test, y_test)`.

View File

@ -42,11 +42,11 @@ def load_data(path='reuters.npz',
"""Loads the Reuters newswire classification dataset.
This is a dataset of 11,228 newswires from Reuters, labeled over 46 topics.
This was originally generated by parsing and preprocessing the classic
Reuters-21578 dataset, but the preprocessing code is no longer packaged
with Keras.
See this [github discussion](https://github.com/keras-team/keras/issues/12072)
with Keras. See this
[github discussion](https://github.com/keras-team/keras/issues/12072)
for more info.
Each newswire is encoded as a list of word indexes (integers).
@ -91,7 +91,7 @@ def load_data(path='reuters.npz',
**x_train, x_test**: lists of sequences, which are lists of indexes
(integers). If the num_words argument was specific, the maximum
possible index value is num_words-1. If the `maxlen` argument was
possible index value is `num_words - 1`. If the `maxlen` argument was
specified, the largest possible sequence length is `maxlen`.
**y_train, y_test**: lists of integer labels (1 or 0).

View File

@ -195,9 +195,9 @@ def get_file(fname,
fname: Name of the file. If an absolute path `/path/to/file.txt` is
specified the file will be saved at that location.
origin: Original URL of the file.
untar: Deprecated in favor of 'extract'.
untar: Deprecated in favor of `extract` argument.
boolean, whether the file should be decompressed
md5_hash: Deprecated in favor of 'file_hash'.
md5_hash: Deprecated in favor of `file_hash` argument.
md5 hash of the file for verification
file_hash: The expected hash string of the file after download.
The sha256 and md5 hash algorithms are both supported.
@ -205,17 +205,16 @@ def get_file(fname,
saved. If an absolute path `/path/to/folder` is
specified the file will be saved at that location.
hash_algorithm: Select the hash algorithm to verify the file.
options are 'md5', 'sha256', and 'auto'.
options are `'md5'`, `'sha256'`, and `'auto'`.
The default 'auto' detects the hash algorithm in use.
extract: True tries extracting the file as an Archive, like tar or zip.
archive_format: Archive format to try for extracting the file.
Options are 'auto', 'tar', 'zip', and None.
'tar' includes tar, tar.gz, and tar.bz files.
The default 'auto' is ['tar', 'zip'].
Options are `'auto'`, `'tar'`, `'zip'`, and `None`.
`'tar'` includes tar, tar.gz, and tar.bz files.
The default `'auto'` corresponds to `['tar', 'zip']`.
None or an empty list will return no matches found.
cache_dir: Location to store cached files, when None it
defaults to the [Keras
Directory](/faq/#where-is-the-keras-configuration-filed-stored).
defaults to the default directory `~/.keras/`.
Returns:
Path to the downloaded file
@ -315,8 +314,8 @@ def _hash_file(fpath, algorithm='sha256', chunk_size=65535):
Arguments:
fpath: path to the file being validated
algorithm: hash algorithm, one of 'auto', 'sha256', or 'md5'.
The default 'auto' detects the hash algorithm in use.
algorithm: hash algorithm, one of `'auto'`, `'sha256'`, or `'md5'`.
The default `'auto'` detects the hash algorithm in use.
chunk_size: Bytes to read at a time, important for large files.
Returns:
@ -420,32 +419,32 @@ class Sequence(object):
Examples:
```python
from skimage.io import imread
from skimage.transform import resize
import numpy as np
import math
from skimage.io import imread
from skimage.transform import resize
import numpy as np
import math
# Here, `x_set` is list of path to the images
# and `y_set` are the associated classes.
# Here, `x_set` is list of path to the images
# and `y_set` are the associated classes.
class CIFAR10Sequence(Sequence):
class CIFAR10Sequence(Sequence):
def __init__(self, x_set, y_set, batch_size):
self.x, self.y = x_set, y_set
self.batch_size = batch_size
def __init__(self, x_set, y_set, batch_size):
self.x, self.y = x_set, y_set
self.batch_size = batch_size
def __len__(self):
return math.ceil(len(self.x) / self.batch_size)
def __len__(self):
return math.ceil(len(self.x) / self.batch_size)
def __getitem__(self, idx):
batch_x = self.x[idx * self.batch_size:(idx + 1) *
self.batch_size]
batch_y = self.y[idx * self.batch_size:(idx + 1) *
self.batch_size]
def __getitem__(self, idx):
batch_x = self.x[idx * self.batch_size:(idx + 1) *
self.batch_size]
batch_y = self.y[idx * self.batch_size:(idx + 1) *
self.batch_size]
return np.array([
resize(imread(file_name), (200, 200))
for file_name in batch_x]), np.array(batch_y)
return np.array([
resize(imread(file_name), (200, 200))
for file_name in batch_x]), np.array(batch_y)
```
"""
@ -485,10 +484,10 @@ def iter_sequence_infinite(seq):
"""Iterates indefinitely over a Sequence.
Arguments:
seq: Sequence instance.
seq: `Sequence` instance.
Yields:
Batches of data from the Sequence.
Batches of data from the `Sequence`.
"""
while True:
for item in seq:

View File

@ -234,7 +234,7 @@ def get_registered_object(name, custom_objects=None, module_objects=None):
@keras_export('keras.utils.serialize_keras_object')
def serialize_keras_object(instance):
"""Serialize Keras object into JSON."""
"""Serialize a Keras object into a JSON-compatible representation."""
_, instance = tf_decorator.unwrap(instance)
if instance is None:
return None
@ -327,6 +327,7 @@ def deserialize_keras_object(identifier,
module_objects=None,
custom_objects=None,
printable_module_name='object'):
"""Turns the serialized form of a Keras object back into an actual object."""
if identifier is None:
return None

View File

@ -27,7 +27,18 @@ def to_categorical(y, num_classes=None, dtype='float32'):
E.g. for use with categorical_crossentropy.
Usage Example:
Arguments:
y: class vector to be converted into a matrix
(integers from 0 to num_classes).
num_classes: total number of classes. If `None`, this would be inferred
as the (largest number in `y`) + 1.
dtype: The data type expected by the input. Default: `'float32'`.
Returns:
A binary matrix representation of the input. The classes axis is placed
last.
Example:
>>> a = tf.keras.utils.to_categorical([0, 1, 2, 3], num_classes=4)
>>> a = tf.constant(a, shape=[4, 4])
@ -51,29 +62,6 @@ def to_categorical(y, num_classes=None, dtype='float32'):
>>> print(np.around(loss, 5))
[0. 0. 0. 0.]
Arguments:
y: class vector to be converted into a matrix
(integers from 0 to num_classes).
num_classes: total number of classes. If `None`, this would be inferred
as the (largest number in `y`) + 1.
dtype: The data type expected by the input. Default: `'float32'`.
Returns:
A binary matrix representation of the input. The classes axis is placed
last.
Usage example:
>>> y = [0, 1, 2, 3, 3, 1, 0]
>>> tf.keras.utils.to_categorical(y, 4)
array([[1., 0., 0., 0.],
[0., 1., 0., 0.],
[0., 0., 1., 0.],
[0., 0., 0., 1.],
[0., 0., 0., 1.],
[0., 1., 0., 0.],
[1., 0., 0., 0.]], dtype=float32)
Raises:
Value Error: If input contains string value
@ -100,7 +88,7 @@ def normalize(x, axis=-1, order=2):
Arguments:
x: Numpy array to normalize.
axis: axis along which to normalize.
order: Normalization order (e.g. 2 for L2 norm).
order: Normalization order (e.g. `order=2` for L2 norm).
Returns:
A normalized copy of the array.