Docstring fixes in Keras datasets, callbacks, and utilities
PiperOrigin-RevId: 304544666 Change-Id: I6049cf11ff6fe8090dfadd36cfcebe4b2443bf5f
This commit is contained in:
parent
394b6bb711
commit
8a370a0077
@ -13,6 +13,7 @@
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
# pylint: disable=g-import-not-at-top
|
||||
# pylint: disable=g-classes-have-attributes
|
||||
"""Callbacks: utilities called at certain points during model training.
|
||||
"""
|
||||
from __future__ import absolute_import
|
||||
@ -336,8 +337,8 @@ class CallbackList(object):
|
||||
This function should only be called during TRAIN mode.
|
||||
|
||||
Arguments:
|
||||
epoch: integer, index of epoch.
|
||||
logs: dict. Currently no data is passed to this argument for this method
|
||||
epoch: Integer, index of epoch.
|
||||
logs: Dict. Currently no data is passed to this argument for this method
|
||||
but that may change in the future.
|
||||
"""
|
||||
logs = logs or {}
|
||||
@ -357,8 +358,8 @@ class CallbackList(object):
|
||||
This function should only be called during TRAIN mode.
|
||||
|
||||
Arguments:
|
||||
epoch: integer, index of epoch.
|
||||
logs: dict, metric results for this training epoch, and for the
|
||||
epoch: Integer, index of epoch.
|
||||
logs: Dict, metric results for this training epoch, and for the
|
||||
validation epoch if validation is performed. Validation result keys
|
||||
are prefixed with `val_`.
|
||||
"""
|
||||
@ -376,8 +377,8 @@ class CallbackList(object):
|
||||
"""Calls the `on_train_batch_begin` methods of its callbacks.
|
||||
|
||||
Arguments:
|
||||
batch: integer, index of batch within the current epoch.
|
||||
logs: dict. Has keys `batch` and `size` representing the current batch
|
||||
batch: Integer, index of batch within the current epoch.
|
||||
logs: Dict. Has keys `batch` and `size` representing the current batch
|
||||
number and the size of the batch.
|
||||
"""
|
||||
# TODO(b/150629188): Make ProgBarLogger callback not use batch hooks
|
||||
@ -389,8 +390,8 @@ class CallbackList(object):
|
||||
"""Calls the `on_train_batch_end` methods of its callbacks.
|
||||
|
||||
Arguments:
|
||||
batch: integer, index of batch within the current epoch.
|
||||
logs: dict. Aggregated metric results up until this batch.
|
||||
batch: Integer, index of batch within the current epoch.
|
||||
logs: Dict. Aggregated metric results up until this batch.
|
||||
"""
|
||||
if self._should_call_train_batch_hooks:
|
||||
self._call_batch_hook(ModeKeys.TRAIN, 'end', batch, logs=logs)
|
||||
@ -399,8 +400,8 @@ class CallbackList(object):
|
||||
"""Calls the `on_test_batch_begin` methods of its callbacks.
|
||||
|
||||
Arguments:
|
||||
batch: integer, index of batch within the current epoch.
|
||||
logs: dict. Has keys `batch` and `size` representing the current batch
|
||||
batch: Integer, index of batch within the current epoch.
|
||||
logs: Dict. Has keys `batch` and `size` representing the current batch
|
||||
number and the size of the batch.
|
||||
"""
|
||||
if self._should_call_test_batch_hooks:
|
||||
@ -410,8 +411,8 @@ class CallbackList(object):
|
||||
"""Calls the `on_test_batch_end` methods of its callbacks.
|
||||
|
||||
Arguments:
|
||||
batch: integer, index of batch within the current epoch.
|
||||
logs: dict. Aggregated metric results up until this batch.
|
||||
batch: Integer, index of batch within the current epoch.
|
||||
logs: Dict. Aggregated metric results up until this batch.
|
||||
"""
|
||||
if self._should_call_test_batch_hooks:
|
||||
self._call_batch_hook(ModeKeys.TEST, 'end', batch, logs=logs)
|
||||
@ -420,8 +421,8 @@ class CallbackList(object):
|
||||
"""Calls the `on_predict_batch_begin` methods of its callbacks.
|
||||
|
||||
Arguments:
|
||||
batch: integer, index of batch within the current epoch.
|
||||
logs: dict. Has keys `batch` and `size` representing the current batch
|
||||
batch: Integer, index of batch within the current epoch.
|
||||
logs: Dict. Has keys `batch` and `size` representing the current batch
|
||||
number and the size of the batch.
|
||||
"""
|
||||
if self._should_call_predict_batch_hooks:
|
||||
@ -431,8 +432,8 @@ class CallbackList(object):
|
||||
"""Calls the `on_predict_batch_end` methods of its callbacks.
|
||||
|
||||
Arguments:
|
||||
batch: integer, index of batch within the current epoch.
|
||||
logs: dict. Aggregated metric results up until this batch.
|
||||
batch: Integer, index of batch within the current epoch.
|
||||
logs: Dict. Aggregated metric results up until this batch.
|
||||
"""
|
||||
if self._should_call_predict_batch_hooks:
|
||||
self._call_batch_hook(ModeKeys.PREDICT, 'end', batch, logs=logs)
|
||||
@ -441,7 +442,7 @@ class CallbackList(object):
|
||||
"""Calls the `on_train_begin` methods of its callbacks.
|
||||
|
||||
Arguments:
|
||||
logs: dict. Currently no data is passed to this argument for this method
|
||||
logs: Dict. Currently no data is passed to this argument for this method
|
||||
but that may change in the future.
|
||||
"""
|
||||
logs = logs or {}
|
||||
@ -458,7 +459,7 @@ class CallbackList(object):
|
||||
"""Calls the `on_train_end` methods of its callbacks.
|
||||
|
||||
Arguments:
|
||||
logs: dict. Currently no data is passed to this argument for this method
|
||||
logs: Dict. Currently no data is passed to this argument for this method
|
||||
but that may change in the future.
|
||||
"""
|
||||
logs = logs or {}
|
||||
@ -475,7 +476,7 @@ class CallbackList(object):
|
||||
"""Calls the `on_test_begin` methods of its callbacks.
|
||||
|
||||
Arguments:
|
||||
logs: dict. Currently no data is passed to this argument for this method
|
||||
logs: Dict. Currently no data is passed to this argument for this method
|
||||
but that may change in the future.
|
||||
"""
|
||||
logs = logs or {}
|
||||
@ -492,7 +493,7 @@ class CallbackList(object):
|
||||
"""Calls the `on_test_end` methods of its callbacks.
|
||||
|
||||
Arguments:
|
||||
logs: dict. Currently no data is passed to this argument for this method
|
||||
logs: Dict. Currently no data is passed to this argument for this method
|
||||
but that may change in the future.
|
||||
"""
|
||||
logs = logs or {}
|
||||
@ -509,7 +510,7 @@ class CallbackList(object):
|
||||
"""Calls the 'on_predict_begin` methods of its callbacks.
|
||||
|
||||
Arguments:
|
||||
logs: dict. Currently no data is passed to this argument for this method
|
||||
logs: Dict. Currently no data is passed to this argument for this method
|
||||
but that may change in the future.
|
||||
"""
|
||||
logs = logs or {}
|
||||
@ -526,7 +527,7 @@ class CallbackList(object):
|
||||
"""Calls the `on_predict_end` methods of its callbacks.
|
||||
|
||||
Arguments:
|
||||
logs: dict. Currently no data is passed to this argument for this method
|
||||
logs: Dict. Currently no data is passed to this argument for this method
|
||||
but that may change in the future.
|
||||
"""
|
||||
logs = logs or {}
|
||||
@ -548,9 +549,9 @@ class Callback(object):
|
||||
"""Abstract base class used to build new callbacks.
|
||||
|
||||
Attributes:
|
||||
params: dict. Training parameters
|
||||
params: Dict. Training parameters
|
||||
(eg. verbosity, batch size, number of epochs...).
|
||||
model: instance of `keras.models.Model`.
|
||||
model: Instance of `keras.models.Model`.
|
||||
Reference of the model being trained.
|
||||
|
||||
The `logs` dictionary that callback methods
|
||||
@ -591,8 +592,8 @@ class Callback(object):
|
||||
be called during TRAIN mode.
|
||||
|
||||
Arguments:
|
||||
epoch: integer, index of epoch.
|
||||
logs: dict. Currently no data is passed to this argument for this method
|
||||
epoch: Integer, index of epoch.
|
||||
logs: Dict. Currently no data is passed to this argument for this method
|
||||
but that may change in the future.
|
||||
"""
|
||||
|
||||
@ -604,8 +605,8 @@ class Callback(object):
|
||||
be called during TRAIN mode.
|
||||
|
||||
Arguments:
|
||||
epoch: integer, index of epoch.
|
||||
logs: dict, metric results for this training epoch, and for the
|
||||
epoch: Integer, index of epoch.
|
||||
logs: Dict, metric results for this training epoch, and for the
|
||||
validation epoch if validation is performed. Validation result keys
|
||||
are prefixed with `val_`.
|
||||
"""
|
||||
@ -618,8 +619,8 @@ class Callback(object):
|
||||
Subclasses should override for any actions to run.
|
||||
|
||||
Arguments:
|
||||
batch: integer, index of batch within the current epoch.
|
||||
logs: dict. Has keys `batch` and `size` representing the current batch
|
||||
batch: Integer, index of batch within the current epoch.
|
||||
logs: Dict. Has keys `batch` and `size` representing the current batch
|
||||
number and the size of the batch.
|
||||
"""
|
||||
# For backwards compatibility.
|
||||
@ -633,8 +634,8 @@ class Callback(object):
|
||||
Subclasses should override for any actions to run.
|
||||
|
||||
Arguments:
|
||||
batch: integer, index of batch within the current epoch.
|
||||
logs: dict. Aggregated metric results up until this batch.
|
||||
batch: Integer, index of batch within the current epoch.
|
||||
logs: Dict. Aggregated metric results up until this batch.
|
||||
"""
|
||||
# For backwards compatibility.
|
||||
self.on_batch_end(batch, logs=logs)
|
||||
@ -650,8 +651,8 @@ class Callback(object):
|
||||
Subclasses should override for any actions to run.
|
||||
|
||||
Arguments:
|
||||
batch: integer, index of batch within the current epoch.
|
||||
logs: dict. Has keys `batch` and `size` representing the current batch
|
||||
batch: Integer, index of batch within the current epoch.
|
||||
logs: Dict. Has keys `batch` and `size` representing the current batch
|
||||
number and the size of the batch.
|
||||
"""
|
||||
|
||||
@ -666,8 +667,8 @@ class Callback(object):
|
||||
Subclasses should override for any actions to run.
|
||||
|
||||
Arguments:
|
||||
batch: integer, index of batch within the current epoch.
|
||||
logs: dict. Aggregated metric results up until this batch.
|
||||
batch: Integer, index of batch within the current epoch.
|
||||
logs: Dict. Aggregated metric results up until this batch.
|
||||
"""
|
||||
|
||||
@doc_controls.for_subclass_implementers
|
||||
@ -678,8 +679,8 @@ class Callback(object):
|
||||
Subclasses should override for any actions to run.
|
||||
|
||||
Arguments:
|
||||
batch: integer, index of batch within the current epoch.
|
||||
logs: dict. Has keys `batch` and `size` representing the current batch
|
||||
batch: Integer, index of batch within the current epoch.
|
||||
logs: Dict. Has keys `batch` and `size` representing the current batch
|
||||
number and the size of the batch.
|
||||
"""
|
||||
|
||||
@ -691,8 +692,8 @@ class Callback(object):
|
||||
Subclasses should override for any actions to run.
|
||||
|
||||
Arguments:
|
||||
batch: integer, index of batch within the current epoch.
|
||||
logs: dict. Aggregated metric results up until this batch.
|
||||
batch: Integer, index of batch within the current epoch.
|
||||
logs: Dict. Aggregated metric results up until this batch.
|
||||
"""
|
||||
|
||||
@doc_controls.for_subclass_implementers
|
||||
@ -702,7 +703,7 @@ class Callback(object):
|
||||
Subclasses should override for any actions to run.
|
||||
|
||||
Arguments:
|
||||
logs: dict. Currently no data is passed to this argument for this method
|
||||
logs: Dict. Currently no data is passed to this argument for this method
|
||||
but that may change in the future.
|
||||
"""
|
||||
|
||||
@ -713,7 +714,7 @@ class Callback(object):
|
||||
Subclasses should override for any actions to run.
|
||||
|
||||
Arguments:
|
||||
logs: dict. Currently no data is passed to this argument for this method
|
||||
logs: Dict. Currently no data is passed to this argument for this method
|
||||
but that may change in the future.
|
||||
"""
|
||||
|
||||
@ -724,7 +725,7 @@ class Callback(object):
|
||||
Subclasses should override for any actions to run.
|
||||
|
||||
Arguments:
|
||||
logs: dict. Currently no data is passed to this argument for this method
|
||||
logs: Dict. Currently no data is passed to this argument for this method
|
||||
but that may change in the future.
|
||||
"""
|
||||
|
||||
@ -735,7 +736,7 @@ class Callback(object):
|
||||
Subclasses should override for any actions to run.
|
||||
|
||||
Arguments:
|
||||
logs: dict. Currently no data is passed to this argument for this method
|
||||
logs: Dict. Currently no data is passed to this argument for this method
|
||||
but that may change in the future.
|
||||
"""
|
||||
|
||||
@ -746,7 +747,7 @@ class Callback(object):
|
||||
Subclasses should override for any actions to run.
|
||||
|
||||
Arguments:
|
||||
logs: dict. Currently no data is passed to this argument for this method
|
||||
logs: Dict. Currently no data is passed to this argument for this method
|
||||
but that may change in the future.
|
||||
"""
|
||||
|
||||
@ -757,7 +758,7 @@ class Callback(object):
|
||||
Subclasses should override for any actions to run.
|
||||
|
||||
Arguments:
|
||||
logs: dict. Currently no data is passed to this argument for this method
|
||||
logs: Dict. Currently no data is passed to this argument for this method
|
||||
but that may change in the future.
|
||||
"""
|
||||
|
||||
@ -847,7 +848,7 @@ class ProgbarLogger(Callback):
|
||||
"""Callback that prints metrics to stdout.
|
||||
|
||||
Arguments:
|
||||
count_mode: One of "steps" or "samples".
|
||||
count_mode: One of `"steps"` or `"samples"`.
|
||||
Whether the progress bar should
|
||||
count samples seen or steps (batches) seen.
|
||||
stateful_metrics: Iterable of string names of metrics that
|
||||
@ -1411,7 +1412,7 @@ class EarlyStopping(Callback):
|
||||
"""Stop training when a monitored metric has stopped improving.
|
||||
|
||||
Assuming the goal of a training is to minimize the loss. With this, the
|
||||
metric to be monitored would be 'loss', and mode would be 'min'. A
|
||||
metric to be monitored would be `'loss'`, and mode would be `'min'`. A
|
||||
`model.fit()` training loop will check at end of every epoch whether
|
||||
the loss is no longer decreasing, considering the `min_delta` and
|
||||
`patience` if applicable. Once it's found no longer decreasing,
|
||||
@ -1420,6 +1421,30 @@ class EarlyStopping(Callback):
|
||||
The quantity to be monitored needs to be available in `logs` dict.
|
||||
To make it so, pass the loss or metrics at `model.compile()`.
|
||||
|
||||
Arguments:
|
||||
monitor: Quantity to be monitored.
|
||||
min_delta: Minimum change in the monitored quantity
|
||||
to qualify as an improvement, i.e. an absolute
|
||||
change of less than min_delta, will count as no
|
||||
improvement.
|
||||
patience: Number of epochs with no improvement
|
||||
after which training will be stopped.
|
||||
verbose: verbosity mode.
|
||||
mode: One of `{"auto", "min", "max"}`. In `min` mode,
|
||||
training will stop when the quantity
|
||||
monitored has stopped decreasing; in `"max"`
|
||||
mode it will stop when the quantity
|
||||
monitored has stopped increasing; in `"auto"`
|
||||
mode, the direction is automatically inferred
|
||||
from the name of the monitored quantity.
|
||||
baseline: Baseline value for the monitored quantity.
|
||||
Training will stop if the model doesn't show improvement over the
|
||||
baseline.
|
||||
restore_best_weights: Whether to restore model weights from
|
||||
the epoch with the best value of the monitored quantity.
|
||||
If False, the model weights obtained at the last step of
|
||||
training are used.
|
||||
|
||||
Example:
|
||||
|
||||
>>> callback = tf.keras.callbacks.EarlyStopping(monitor='loss', patience=3)
|
||||
@ -1442,32 +1467,6 @@ class EarlyStopping(Callback):
|
||||
mode='auto',
|
||||
baseline=None,
|
||||
restore_best_weights=False):
|
||||
"""Initialize an EarlyStopping callback.
|
||||
|
||||
Arguments:
|
||||
monitor: Quantity to be monitored.
|
||||
min_delta: Minimum change in the monitored quantity
|
||||
to qualify as an improvement, i.e. an absolute
|
||||
change of less than min_delta, will count as no
|
||||
improvement.
|
||||
patience: Number of epochs with no improvement
|
||||
after which training will be stopped.
|
||||
verbose: verbosity mode.
|
||||
mode: One of `{"auto", "min", "max"}`. In `min` mode,
|
||||
training will stop when the quantity
|
||||
monitored has stopped decreasing; in `max`
|
||||
mode it will stop when the quantity
|
||||
monitored has stopped increasing; in `auto`
|
||||
mode, the direction is automatically inferred
|
||||
from the name of the monitored quantity.
|
||||
baseline: Baseline value for the monitored quantity.
|
||||
Training will stop if the model doesn't show improvement over the
|
||||
baseline.
|
||||
restore_best_weights: Whether to restore model weights from
|
||||
the epoch with the best value of the monitored quantity.
|
||||
If False, the model weights obtained at the last step of
|
||||
training are used.
|
||||
"""
|
||||
super(EarlyStopping, self).__init__()
|
||||
|
||||
self.monitor = monitor
|
||||
@ -1550,18 +1549,19 @@ class RemoteMonitor(Callback):
|
||||
Events are sent to `root + '/publish/epoch/end/'` by default. Calls are
|
||||
HTTP POST, with a `data` argument which is a
|
||||
JSON-encoded dictionary of event data.
|
||||
If send_as_json is set to True, the content type of the request will be
|
||||
application/json. Otherwise the serialized JSON will be sent within a form.
|
||||
If `send_as_json=True`, the content type of the request will be
|
||||
`"application/json"`.
|
||||
Otherwise the serialized JSON will be sent within a form.
|
||||
|
||||
Arguments:
|
||||
root: String; root url of the target server.
|
||||
path: String; path relative to `root` to which the events will be sent.
|
||||
field: String; JSON field under which the data will be stored.
|
||||
The field is used only if the payload is sent within a form
|
||||
(i.e. send_as_json is set to False).
|
||||
headers: Dictionary; optional custom HTTP headers.
|
||||
send_as_json: Boolean; whether the request should be
|
||||
sent as application/json.
|
||||
root: String; root url of the target server.
|
||||
path: String; path relative to `root` to which the events will be sent.
|
||||
field: String; JSON field under which the data will be stored.
|
||||
The field is used only if the payload is sent within a form
|
||||
(i.e. send_as_json is set to False).
|
||||
headers: Dictionary; optional custom HTTP headers.
|
||||
send_as_json: Boolean; whether the request should be
|
||||
sent as `"application/json"`.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
@ -1613,6 +1613,12 @@ class LearningRateScheduler(Callback):
|
||||
and current learning rate, and applies the updated learning rate
|
||||
on the optimizer.
|
||||
|
||||
Arguments:
|
||||
schedule: a function that takes an epoch index (integer, indexed from 0)
|
||||
and current learning rate (float) as inputs and returns a new
|
||||
learning rate as output (float).
|
||||
verbose: int. 0: quiet, 1: update messages.
|
||||
|
||||
Example:
|
||||
|
||||
>>> # This function keeps the initial learning rate for the first ten epochs
|
||||
@ -1637,14 +1643,6 @@ class LearningRateScheduler(Callback):
|
||||
"""
|
||||
|
||||
def __init__(self, schedule, verbose=0):
|
||||
"""Initialize a `keras.callbacks.LearningRateScheduler` callback.
|
||||
|
||||
Arguments:
|
||||
schedule: a function that takes an epoch index (integer, indexed from 0)
|
||||
and current learning rate (float) as inputs and returns a new
|
||||
learning rate as output (float).
|
||||
verbose: int. 0: quiet, 1: update messages.
|
||||
"""
|
||||
super(LearningRateScheduler, self).__init__()
|
||||
self.schedule = schedule
|
||||
self.verbose = verbose
|
||||
@ -1689,7 +1687,7 @@ class TensorBoard(Callback, version_utils.TensorBoardVersionSelector):
|
||||
If you have installed TensorFlow with pip, you should be able
|
||||
to launch TensorBoard from the command line:
|
||||
|
||||
```sh
|
||||
```
|
||||
tensorboard --logdir=path_to_your_logs
|
||||
```
|
||||
|
||||
@ -1697,24 +1695,27 @@ class TensorBoard(Callback, version_utils.TensorBoardVersionSelector):
|
||||
[here](https://www.tensorflow.org/get_started/summaries_and_tensorboard).
|
||||
|
||||
Example (Basic):
|
||||
|
||||
```python
|
||||
tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir="./logs")
|
||||
model.fit(x_train, y_train, epochs=2, callbacks=[tensorboard_callback])
|
||||
# run the tensorboard command to view the visualizations.
|
||||
```
|
||||
|
||||
Example (Profile):
|
||||
|
||||
```python
|
||||
# profile a single batch, e.g. the 5th batch.
|
||||
tensorboard_callback =
|
||||
tf.keras.callbacks.TensorBoard(log_dir='./logs', profile_batch=5)
|
||||
tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir='./logs',
|
||||
profile_batch=5)
|
||||
model.fit(x_train, y_train, epochs=2, callbacks=[tensorboard_callback])
|
||||
# run the tensorboard command to view the visualizations in profile plugin.
|
||||
# Now run the tensorboard command to view the visualizations (profile plugin).
|
||||
|
||||
# profile a range of batches, e.g. from 10 to 20.
|
||||
tensorboard_callback =
|
||||
tf.keras.callbacks.TensorBoard(log_dir='./logs', profile_batch='10,20')
|
||||
tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir='./logs',
|
||||
profile_batch='10,20')
|
||||
model.fit(x_train, y_train, epochs=2, callbacks=[tensorboard_callback])
|
||||
# run the tensorboard command to view the visualizations in profile plugin.
|
||||
# Now run the tensorboard command to view the visualizations (profile plugin).
|
||||
```
|
||||
|
||||
Arguments:
|
||||
@ -2143,14 +2144,15 @@ class ReduceLROnPlateau(Callback):
|
||||
|
||||
Arguments:
|
||||
monitor: quantity to be monitored.
|
||||
factor: factor by which the learning rate will be reduced. new_lr = lr *
|
||||
factor
|
||||
factor: factor by which the learning rate will be reduced.
|
||||
`new_lr = lr * factor`.
|
||||
patience: number of epochs with no improvement after which learning rate
|
||||
will be reduced.
|
||||
verbose: int. 0: quiet, 1: update messages.
|
||||
mode: one of {auto, min, max}. In `min` mode, lr will be reduced when the
|
||||
quantity monitored has stopped decreasing; in `max` mode it will be
|
||||
reduced when the quantity monitored has stopped increasing; in `auto`
|
||||
mode: one of `{'auto', 'min', 'max'}`. In `'min'` mode,
|
||||
the learning rate will be reduced when the
|
||||
quantity monitored has stopped decreasing; in `'max'` mode it will be
|
||||
reduced when the quantity monitored has stopped increasing; in `'auto'`
|
||||
mode, the direction is automatically inferred from the name of the
|
||||
monitored quantity.
|
||||
min_delta: threshold for measuring the new optimum, to only focus on
|
||||
@ -2249,10 +2251,10 @@ class ReduceLROnPlateau(Callback):
|
||||
|
||||
@keras_export('keras.callbacks.CSVLogger')
|
||||
class CSVLogger(Callback):
|
||||
"""Callback that streams epoch results to a csv file.
|
||||
"""Callback that streams epoch results to a CSV file.
|
||||
|
||||
Supports all values that can be represented as a string,
|
||||
including 1D iterables such as np.ndarray.
|
||||
including 1D iterables such as `np.ndarray`.
|
||||
|
||||
Example:
|
||||
|
||||
@ -2262,10 +2264,10 @@ class CSVLogger(Callback):
|
||||
```
|
||||
|
||||
Arguments:
|
||||
filename: filename of the csv file, e.g. 'run/log.csv'.
|
||||
separator: string used to separate elements in the csv file.
|
||||
append: True: append if file exists (useful for continuing
|
||||
training). False: overwrite existing file,
|
||||
filename: Filename of the CSV file, e.g. `'run/log.csv'`.
|
||||
separator: String used to separate elements in the CSV file.
|
||||
append: Boolean. True: append if file exists (useful for continuing
|
||||
training). False: overwrite existing file.
|
||||
"""
|
||||
|
||||
def __init__(self, filename, separator=',', append=False):
|
||||
@ -2348,12 +2350,12 @@ class LambdaCallback(Callback):
|
||||
at the appropriate time. Note that the callbacks expects positional
|
||||
arguments, as:
|
||||
|
||||
- `on_epoch_begin` and `on_epoch_end` expect two positional arguments:
|
||||
`epoch`, `logs`
|
||||
- `on_batch_begin` and `on_batch_end` expect two positional arguments:
|
||||
`batch`, `logs`
|
||||
- `on_train_begin` and `on_train_end` expect one positional argument:
|
||||
`logs`
|
||||
- `on_epoch_begin` and `on_epoch_end` expect two positional arguments:
|
||||
`epoch`, `logs`
|
||||
- `on_batch_begin` and `on_batch_end` expect two positional arguments:
|
||||
`batch`, `logs`
|
||||
- `on_train_begin` and `on_train_end` expect one positional argument:
|
||||
`logs`
|
||||
|
||||
Arguments:
|
||||
on_epoch_begin: called at the beginning of every epoch.
|
||||
|
@ -40,7 +40,7 @@ def load_data(path='boston_housing.npz', test_split=0.2, seed=113):
|
||||
|
||||
Arguments:
|
||||
path: path where to cache the dataset locally
|
||||
(relative to ~/.keras/datasets).
|
||||
(relative to `~/.keras/datasets`).
|
||||
test_split: fraction of the data to reserve as test set.
|
||||
seed: Random seed for shuffling the data
|
||||
before computing the test split.
|
||||
@ -48,10 +48,11 @@ def load_data(path='boston_housing.npz', test_split=0.2, seed=113):
|
||||
Returns:
|
||||
Tuple of Numpy arrays: `(x_train, y_train), (x_test, y_test)`.
|
||||
|
||||
**x_train, x_test**: numpy arrays with shape (num_samples, 13) containing
|
||||
either the training samples (for x_train), or test samples (for y_train)
|
||||
**x_train, x_test**: numpy arrays with shape `(num_samples, 13)`
|
||||
containing either the training samples (for x_train),
|
||||
or test samples (for y_train).
|
||||
|
||||
**y_train, y_test**: numpy arrays of shape (num_samples, ) containing the
|
||||
**y_train, y_test**: numpy arrays of shape `(num_samples,)` containing the
|
||||
target scalars. The targets are float scalars typically between 10 and
|
||||
50 that represent the home prices in k$.
|
||||
"""
|
||||
|
@ -40,9 +40,9 @@ def load_data():
|
||||
Tuple of Numpy arrays: `(x_train, y_train), (x_test, y_test)`.
|
||||
|
||||
**x_train, x_test**: uint8 arrays of RGB image data with shape
|
||||
(num_samples, 3, 32, 32) if the `tf.keras.backend.image_data_format` is
|
||||
'channels_first', or (num_samples, 32, 32, 3) if the data format
|
||||
is 'channels_last'.
|
||||
`(num_samples, 3, 32, 32)` if `tf.keras.backend.image_data_format()` is
|
||||
`'channels_first'`, or `(num_samples, 32, 32, 3)` if the data format
|
||||
is `'channels_last'`.
|
||||
|
||||
**y_train, y_test**: uint8 arrays of category labels
|
||||
(integers in range 0-9) each with shape (num_samples, 1).
|
||||
|
@ -46,9 +46,9 @@ def load_data(label_mode='fine'):
|
||||
Tuple of Numpy arrays: `(x_train, y_train), (x_test, y_test)`.
|
||||
|
||||
**x_train, x_test**: uint8 arrays of RGB image data with shape
|
||||
(num_samples, 3, 32, 32) if the `tf.keras.backend.image_data_format` is
|
||||
'channels_first', or (num_samples, 32, 32, 3) if the data format
|
||||
is 'channels_last'.
|
||||
`(num_samples, 3, 32, 32)` if `tf.keras.backend.image_data_format()` is
|
||||
`'channels_first'`, or `(num_samples, 32, 32, 3)` if the data format
|
||||
is `'channels_last'`.
|
||||
|
||||
**y_train, y_test**: uint8 arrays of category labels with shape
|
||||
(num_samples, 1).
|
||||
|
@ -80,7 +80,7 @@ def load_data(path='imdb.npz',
|
||||
|
||||
**x_train, x_test**: lists of sequences, which are lists of indexes
|
||||
(integers). If the num_words argument was specific, the maximum
|
||||
possible index value is num_words-1. If the `maxlen` argument was
|
||||
possible index value is `num_words - 1`. If the `maxlen` argument was
|
||||
specified, the largest possible sequence length is `maxlen`.
|
||||
|
||||
**y_train, y_test**: lists of integer labels (1 or 0).
|
||||
|
@ -31,12 +31,12 @@ def load_data(path='mnist.npz'):
|
||||
This is a dataset of 60,000 28x28 grayscale images of the 10 digits,
|
||||
along with a test set of 10,000 images.
|
||||
More info can be found at the
|
||||
(MNIST homepage)[http://yann.lecun.com/exdb/mnist/].
|
||||
[MNIST homepage](http://yann.lecun.com/exdb/mnist/).
|
||||
|
||||
|
||||
Arguments:
|
||||
path: path where to cache the dataset locally
|
||||
(relative to ~/.keras/datasets).
|
||||
(relative to `~/.keras/datasets`).
|
||||
|
||||
Returns:
|
||||
Tuple of Numpy arrays: `(x_train, y_train), (x_test, y_test)`.
|
||||
|
@ -42,11 +42,11 @@ def load_data(path='reuters.npz',
|
||||
"""Loads the Reuters newswire classification dataset.
|
||||
|
||||
This is a dataset of 11,228 newswires from Reuters, labeled over 46 topics.
|
||||
|
||||
This was originally generated by parsing and preprocessing the classic
|
||||
Reuters-21578 dataset, but the preprocessing code is no longer packaged
|
||||
with Keras.
|
||||
|
||||
See this [github discussion](https://github.com/keras-team/keras/issues/12072)
|
||||
with Keras. See this
|
||||
[github discussion](https://github.com/keras-team/keras/issues/12072)
|
||||
for more info.
|
||||
|
||||
Each newswire is encoded as a list of word indexes (integers).
|
||||
@ -91,7 +91,7 @@ def load_data(path='reuters.npz',
|
||||
|
||||
**x_train, x_test**: lists of sequences, which are lists of indexes
|
||||
(integers). If the num_words argument was specific, the maximum
|
||||
possible index value is num_words-1. If the `maxlen` argument was
|
||||
possible index value is `num_words - 1`. If the `maxlen` argument was
|
||||
specified, the largest possible sequence length is `maxlen`.
|
||||
|
||||
**y_train, y_test**: lists of integer labels (1 or 0).
|
||||
|
@ -195,9 +195,9 @@ def get_file(fname,
|
||||
fname: Name of the file. If an absolute path `/path/to/file.txt` is
|
||||
specified the file will be saved at that location.
|
||||
origin: Original URL of the file.
|
||||
untar: Deprecated in favor of 'extract'.
|
||||
untar: Deprecated in favor of `extract` argument.
|
||||
boolean, whether the file should be decompressed
|
||||
md5_hash: Deprecated in favor of 'file_hash'.
|
||||
md5_hash: Deprecated in favor of `file_hash` argument.
|
||||
md5 hash of the file for verification
|
||||
file_hash: The expected hash string of the file after download.
|
||||
The sha256 and md5 hash algorithms are both supported.
|
||||
@ -205,17 +205,16 @@ def get_file(fname,
|
||||
saved. If an absolute path `/path/to/folder` is
|
||||
specified the file will be saved at that location.
|
||||
hash_algorithm: Select the hash algorithm to verify the file.
|
||||
options are 'md5', 'sha256', and 'auto'.
|
||||
options are `'md5'`, `'sha256'`, and `'auto'`.
|
||||
The default 'auto' detects the hash algorithm in use.
|
||||
extract: True tries extracting the file as an Archive, like tar or zip.
|
||||
archive_format: Archive format to try for extracting the file.
|
||||
Options are 'auto', 'tar', 'zip', and None.
|
||||
'tar' includes tar, tar.gz, and tar.bz files.
|
||||
The default 'auto' is ['tar', 'zip'].
|
||||
Options are `'auto'`, `'tar'`, `'zip'`, and `None`.
|
||||
`'tar'` includes tar, tar.gz, and tar.bz files.
|
||||
The default `'auto'` corresponds to `['tar', 'zip']`.
|
||||
None or an empty list will return no matches found.
|
||||
cache_dir: Location to store cached files, when None it
|
||||
defaults to the [Keras
|
||||
Directory](/faq/#where-is-the-keras-configuration-filed-stored).
|
||||
defaults to the default directory `~/.keras/`.
|
||||
|
||||
Returns:
|
||||
Path to the downloaded file
|
||||
@ -315,8 +314,8 @@ def _hash_file(fpath, algorithm='sha256', chunk_size=65535):
|
||||
|
||||
Arguments:
|
||||
fpath: path to the file being validated
|
||||
algorithm: hash algorithm, one of 'auto', 'sha256', or 'md5'.
|
||||
The default 'auto' detects the hash algorithm in use.
|
||||
algorithm: hash algorithm, one of `'auto'`, `'sha256'`, or `'md5'`.
|
||||
The default `'auto'` detects the hash algorithm in use.
|
||||
chunk_size: Bytes to read at a time, important for large files.
|
||||
|
||||
Returns:
|
||||
@ -420,32 +419,32 @@ class Sequence(object):
|
||||
Examples:
|
||||
|
||||
```python
|
||||
from skimage.io import imread
|
||||
from skimage.transform import resize
|
||||
import numpy as np
|
||||
import math
|
||||
from skimage.io import imread
|
||||
from skimage.transform import resize
|
||||
import numpy as np
|
||||
import math
|
||||
|
||||
# Here, `x_set` is list of path to the images
|
||||
# and `y_set` are the associated classes.
|
||||
# Here, `x_set` is list of path to the images
|
||||
# and `y_set` are the associated classes.
|
||||
|
||||
class CIFAR10Sequence(Sequence):
|
||||
class CIFAR10Sequence(Sequence):
|
||||
|
||||
def __init__(self, x_set, y_set, batch_size):
|
||||
self.x, self.y = x_set, y_set
|
||||
self.batch_size = batch_size
|
||||
def __init__(self, x_set, y_set, batch_size):
|
||||
self.x, self.y = x_set, y_set
|
||||
self.batch_size = batch_size
|
||||
|
||||
def __len__(self):
|
||||
return math.ceil(len(self.x) / self.batch_size)
|
||||
def __len__(self):
|
||||
return math.ceil(len(self.x) / self.batch_size)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
batch_x = self.x[idx * self.batch_size:(idx + 1) *
|
||||
self.batch_size]
|
||||
batch_y = self.y[idx * self.batch_size:(idx + 1) *
|
||||
self.batch_size]
|
||||
def __getitem__(self, idx):
|
||||
batch_x = self.x[idx * self.batch_size:(idx + 1) *
|
||||
self.batch_size]
|
||||
batch_y = self.y[idx * self.batch_size:(idx + 1) *
|
||||
self.batch_size]
|
||||
|
||||
return np.array([
|
||||
resize(imread(file_name), (200, 200))
|
||||
for file_name in batch_x]), np.array(batch_y)
|
||||
return np.array([
|
||||
resize(imread(file_name), (200, 200))
|
||||
for file_name in batch_x]), np.array(batch_y)
|
||||
```
|
||||
"""
|
||||
|
||||
@ -485,10 +484,10 @@ def iter_sequence_infinite(seq):
|
||||
"""Iterates indefinitely over a Sequence.
|
||||
|
||||
Arguments:
|
||||
seq: Sequence instance.
|
||||
seq: `Sequence` instance.
|
||||
|
||||
Yields:
|
||||
Batches of data from the Sequence.
|
||||
Batches of data from the `Sequence`.
|
||||
"""
|
||||
while True:
|
||||
for item in seq:
|
||||
|
@ -234,7 +234,7 @@ def get_registered_object(name, custom_objects=None, module_objects=None):
|
||||
|
||||
@keras_export('keras.utils.serialize_keras_object')
|
||||
def serialize_keras_object(instance):
|
||||
"""Serialize Keras object into JSON."""
|
||||
"""Serialize a Keras object into a JSON-compatible representation."""
|
||||
_, instance = tf_decorator.unwrap(instance)
|
||||
if instance is None:
|
||||
return None
|
||||
@ -327,6 +327,7 @@ def deserialize_keras_object(identifier,
|
||||
module_objects=None,
|
||||
custom_objects=None,
|
||||
printable_module_name='object'):
|
||||
"""Turns the serialized form of a Keras object back into an actual object."""
|
||||
if identifier is None:
|
||||
return None
|
||||
|
||||
|
@ -27,7 +27,18 @@ def to_categorical(y, num_classes=None, dtype='float32'):
|
||||
|
||||
E.g. for use with categorical_crossentropy.
|
||||
|
||||
Usage Example:
|
||||
Arguments:
|
||||
y: class vector to be converted into a matrix
|
||||
(integers from 0 to num_classes).
|
||||
num_classes: total number of classes. If `None`, this would be inferred
|
||||
as the (largest number in `y`) + 1.
|
||||
dtype: The data type expected by the input. Default: `'float32'`.
|
||||
|
||||
Returns:
|
||||
A binary matrix representation of the input. The classes axis is placed
|
||||
last.
|
||||
|
||||
Example:
|
||||
|
||||
>>> a = tf.keras.utils.to_categorical([0, 1, 2, 3], num_classes=4)
|
||||
>>> a = tf.constant(a, shape=[4, 4])
|
||||
@ -51,29 +62,6 @@ def to_categorical(y, num_classes=None, dtype='float32'):
|
||||
>>> print(np.around(loss, 5))
|
||||
[0. 0. 0. 0.]
|
||||
|
||||
Arguments:
|
||||
y: class vector to be converted into a matrix
|
||||
(integers from 0 to num_classes).
|
||||
num_classes: total number of classes. If `None`, this would be inferred
|
||||
as the (largest number in `y`) + 1.
|
||||
dtype: The data type expected by the input. Default: `'float32'`.
|
||||
|
||||
Returns:
|
||||
A binary matrix representation of the input. The classes axis is placed
|
||||
last.
|
||||
|
||||
Usage example:
|
||||
|
||||
>>> y = [0, 1, 2, 3, 3, 1, 0]
|
||||
>>> tf.keras.utils.to_categorical(y, 4)
|
||||
array([[1., 0., 0., 0.],
|
||||
[0., 1., 0., 0.],
|
||||
[0., 0., 1., 0.],
|
||||
[0., 0., 0., 1.],
|
||||
[0., 0., 0., 1.],
|
||||
[0., 1., 0., 0.],
|
||||
[1., 0., 0., 0.]], dtype=float32)
|
||||
|
||||
Raises:
|
||||
Value Error: If input contains string value
|
||||
|
||||
@ -100,7 +88,7 @@ def normalize(x, axis=-1, order=2):
|
||||
Arguments:
|
||||
x: Numpy array to normalize.
|
||||
axis: axis along which to normalize.
|
||||
order: Normalization order (e.g. 2 for L2 norm).
|
||||
order: Normalization order (e.g. `order=2` for L2 norm).
|
||||
|
||||
Returns:
|
||||
A normalized copy of the array.
|
||||
|
Loading…
Reference in New Issue
Block a user