Docstring fixes in Keras datasets, callbacks, and utilities
PiperOrigin-RevId: 304544666 Change-Id: I6049cf11ff6fe8090dfadd36cfcebe4b2443bf5f
This commit is contained in:
parent
394b6bb711
commit
8a370a0077
@ -13,6 +13,7 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
# pylint: disable=g-import-not-at-top
|
# pylint: disable=g-import-not-at-top
|
||||||
|
# pylint: disable=g-classes-have-attributes
|
||||||
"""Callbacks: utilities called at certain points during model training.
|
"""Callbacks: utilities called at certain points during model training.
|
||||||
"""
|
"""
|
||||||
from __future__ import absolute_import
|
from __future__ import absolute_import
|
||||||
@ -336,8 +337,8 @@ class CallbackList(object):
|
|||||||
This function should only be called during TRAIN mode.
|
This function should only be called during TRAIN mode.
|
||||||
|
|
||||||
Arguments:
|
Arguments:
|
||||||
epoch: integer, index of epoch.
|
epoch: Integer, index of epoch.
|
||||||
logs: dict. Currently no data is passed to this argument for this method
|
logs: Dict. Currently no data is passed to this argument for this method
|
||||||
but that may change in the future.
|
but that may change in the future.
|
||||||
"""
|
"""
|
||||||
logs = logs or {}
|
logs = logs or {}
|
||||||
@ -357,8 +358,8 @@ class CallbackList(object):
|
|||||||
This function should only be called during TRAIN mode.
|
This function should only be called during TRAIN mode.
|
||||||
|
|
||||||
Arguments:
|
Arguments:
|
||||||
epoch: integer, index of epoch.
|
epoch: Integer, index of epoch.
|
||||||
logs: dict, metric results for this training epoch, and for the
|
logs: Dict, metric results for this training epoch, and for the
|
||||||
validation epoch if validation is performed. Validation result keys
|
validation epoch if validation is performed. Validation result keys
|
||||||
are prefixed with `val_`.
|
are prefixed with `val_`.
|
||||||
"""
|
"""
|
||||||
@ -376,8 +377,8 @@ class CallbackList(object):
|
|||||||
"""Calls the `on_train_batch_begin` methods of its callbacks.
|
"""Calls the `on_train_batch_begin` methods of its callbacks.
|
||||||
|
|
||||||
Arguments:
|
Arguments:
|
||||||
batch: integer, index of batch within the current epoch.
|
batch: Integer, index of batch within the current epoch.
|
||||||
logs: dict. Has keys `batch` and `size` representing the current batch
|
logs: Dict. Has keys `batch` and `size` representing the current batch
|
||||||
number and the size of the batch.
|
number and the size of the batch.
|
||||||
"""
|
"""
|
||||||
# TODO(b/150629188): Make ProgBarLogger callback not use batch hooks
|
# TODO(b/150629188): Make ProgBarLogger callback not use batch hooks
|
||||||
@ -389,8 +390,8 @@ class CallbackList(object):
|
|||||||
"""Calls the `on_train_batch_end` methods of its callbacks.
|
"""Calls the `on_train_batch_end` methods of its callbacks.
|
||||||
|
|
||||||
Arguments:
|
Arguments:
|
||||||
batch: integer, index of batch within the current epoch.
|
batch: Integer, index of batch within the current epoch.
|
||||||
logs: dict. Aggregated metric results up until this batch.
|
logs: Dict. Aggregated metric results up until this batch.
|
||||||
"""
|
"""
|
||||||
if self._should_call_train_batch_hooks:
|
if self._should_call_train_batch_hooks:
|
||||||
self._call_batch_hook(ModeKeys.TRAIN, 'end', batch, logs=logs)
|
self._call_batch_hook(ModeKeys.TRAIN, 'end', batch, logs=logs)
|
||||||
@ -399,8 +400,8 @@ class CallbackList(object):
|
|||||||
"""Calls the `on_test_batch_begin` methods of its callbacks.
|
"""Calls the `on_test_batch_begin` methods of its callbacks.
|
||||||
|
|
||||||
Arguments:
|
Arguments:
|
||||||
batch: integer, index of batch within the current epoch.
|
batch: Integer, index of batch within the current epoch.
|
||||||
logs: dict. Has keys `batch` and `size` representing the current batch
|
logs: Dict. Has keys `batch` and `size` representing the current batch
|
||||||
number and the size of the batch.
|
number and the size of the batch.
|
||||||
"""
|
"""
|
||||||
if self._should_call_test_batch_hooks:
|
if self._should_call_test_batch_hooks:
|
||||||
@ -410,8 +411,8 @@ class CallbackList(object):
|
|||||||
"""Calls the `on_test_batch_end` methods of its callbacks.
|
"""Calls the `on_test_batch_end` methods of its callbacks.
|
||||||
|
|
||||||
Arguments:
|
Arguments:
|
||||||
batch: integer, index of batch within the current epoch.
|
batch: Integer, index of batch within the current epoch.
|
||||||
logs: dict. Aggregated metric results up until this batch.
|
logs: Dict. Aggregated metric results up until this batch.
|
||||||
"""
|
"""
|
||||||
if self._should_call_test_batch_hooks:
|
if self._should_call_test_batch_hooks:
|
||||||
self._call_batch_hook(ModeKeys.TEST, 'end', batch, logs=logs)
|
self._call_batch_hook(ModeKeys.TEST, 'end', batch, logs=logs)
|
||||||
@ -420,8 +421,8 @@ class CallbackList(object):
|
|||||||
"""Calls the `on_predict_batch_begin` methods of its callbacks.
|
"""Calls the `on_predict_batch_begin` methods of its callbacks.
|
||||||
|
|
||||||
Arguments:
|
Arguments:
|
||||||
batch: integer, index of batch within the current epoch.
|
batch: Integer, index of batch within the current epoch.
|
||||||
logs: dict. Has keys `batch` and `size` representing the current batch
|
logs: Dict. Has keys `batch` and `size` representing the current batch
|
||||||
number and the size of the batch.
|
number and the size of the batch.
|
||||||
"""
|
"""
|
||||||
if self._should_call_predict_batch_hooks:
|
if self._should_call_predict_batch_hooks:
|
||||||
@ -431,8 +432,8 @@ class CallbackList(object):
|
|||||||
"""Calls the `on_predict_batch_end` methods of its callbacks.
|
"""Calls the `on_predict_batch_end` methods of its callbacks.
|
||||||
|
|
||||||
Arguments:
|
Arguments:
|
||||||
batch: integer, index of batch within the current epoch.
|
batch: Integer, index of batch within the current epoch.
|
||||||
logs: dict. Aggregated metric results up until this batch.
|
logs: Dict. Aggregated metric results up until this batch.
|
||||||
"""
|
"""
|
||||||
if self._should_call_predict_batch_hooks:
|
if self._should_call_predict_batch_hooks:
|
||||||
self._call_batch_hook(ModeKeys.PREDICT, 'end', batch, logs=logs)
|
self._call_batch_hook(ModeKeys.PREDICT, 'end', batch, logs=logs)
|
||||||
@ -441,7 +442,7 @@ class CallbackList(object):
|
|||||||
"""Calls the `on_train_begin` methods of its callbacks.
|
"""Calls the `on_train_begin` methods of its callbacks.
|
||||||
|
|
||||||
Arguments:
|
Arguments:
|
||||||
logs: dict. Currently no data is passed to this argument for this method
|
logs: Dict. Currently no data is passed to this argument for this method
|
||||||
but that may change in the future.
|
but that may change in the future.
|
||||||
"""
|
"""
|
||||||
logs = logs or {}
|
logs = logs or {}
|
||||||
@ -458,7 +459,7 @@ class CallbackList(object):
|
|||||||
"""Calls the `on_train_end` methods of its callbacks.
|
"""Calls the `on_train_end` methods of its callbacks.
|
||||||
|
|
||||||
Arguments:
|
Arguments:
|
||||||
logs: dict. Currently no data is passed to this argument for this method
|
logs: Dict. Currently no data is passed to this argument for this method
|
||||||
but that may change in the future.
|
but that may change in the future.
|
||||||
"""
|
"""
|
||||||
logs = logs or {}
|
logs = logs or {}
|
||||||
@ -475,7 +476,7 @@ class CallbackList(object):
|
|||||||
"""Calls the `on_test_begin` methods of its callbacks.
|
"""Calls the `on_test_begin` methods of its callbacks.
|
||||||
|
|
||||||
Arguments:
|
Arguments:
|
||||||
logs: dict. Currently no data is passed to this argument for this method
|
logs: Dict. Currently no data is passed to this argument for this method
|
||||||
but that may change in the future.
|
but that may change in the future.
|
||||||
"""
|
"""
|
||||||
logs = logs or {}
|
logs = logs or {}
|
||||||
@ -492,7 +493,7 @@ class CallbackList(object):
|
|||||||
"""Calls the `on_test_end` methods of its callbacks.
|
"""Calls the `on_test_end` methods of its callbacks.
|
||||||
|
|
||||||
Arguments:
|
Arguments:
|
||||||
logs: dict. Currently no data is passed to this argument for this method
|
logs: Dict. Currently no data is passed to this argument for this method
|
||||||
but that may change in the future.
|
but that may change in the future.
|
||||||
"""
|
"""
|
||||||
logs = logs or {}
|
logs = logs or {}
|
||||||
@ -509,7 +510,7 @@ class CallbackList(object):
|
|||||||
"""Calls the 'on_predict_begin` methods of its callbacks.
|
"""Calls the 'on_predict_begin` methods of its callbacks.
|
||||||
|
|
||||||
Arguments:
|
Arguments:
|
||||||
logs: dict. Currently no data is passed to this argument for this method
|
logs: Dict. Currently no data is passed to this argument for this method
|
||||||
but that may change in the future.
|
but that may change in the future.
|
||||||
"""
|
"""
|
||||||
logs = logs or {}
|
logs = logs or {}
|
||||||
@ -526,7 +527,7 @@ class CallbackList(object):
|
|||||||
"""Calls the `on_predict_end` methods of its callbacks.
|
"""Calls the `on_predict_end` methods of its callbacks.
|
||||||
|
|
||||||
Arguments:
|
Arguments:
|
||||||
logs: dict. Currently no data is passed to this argument for this method
|
logs: Dict. Currently no data is passed to this argument for this method
|
||||||
but that may change in the future.
|
but that may change in the future.
|
||||||
"""
|
"""
|
||||||
logs = logs or {}
|
logs = logs or {}
|
||||||
@ -548,9 +549,9 @@ class Callback(object):
|
|||||||
"""Abstract base class used to build new callbacks.
|
"""Abstract base class used to build new callbacks.
|
||||||
|
|
||||||
Attributes:
|
Attributes:
|
||||||
params: dict. Training parameters
|
params: Dict. Training parameters
|
||||||
(eg. verbosity, batch size, number of epochs...).
|
(eg. verbosity, batch size, number of epochs...).
|
||||||
model: instance of `keras.models.Model`.
|
model: Instance of `keras.models.Model`.
|
||||||
Reference of the model being trained.
|
Reference of the model being trained.
|
||||||
|
|
||||||
The `logs` dictionary that callback methods
|
The `logs` dictionary that callback methods
|
||||||
@ -591,8 +592,8 @@ class Callback(object):
|
|||||||
be called during TRAIN mode.
|
be called during TRAIN mode.
|
||||||
|
|
||||||
Arguments:
|
Arguments:
|
||||||
epoch: integer, index of epoch.
|
epoch: Integer, index of epoch.
|
||||||
logs: dict. Currently no data is passed to this argument for this method
|
logs: Dict. Currently no data is passed to this argument for this method
|
||||||
but that may change in the future.
|
but that may change in the future.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@ -604,8 +605,8 @@ class Callback(object):
|
|||||||
be called during TRAIN mode.
|
be called during TRAIN mode.
|
||||||
|
|
||||||
Arguments:
|
Arguments:
|
||||||
epoch: integer, index of epoch.
|
epoch: Integer, index of epoch.
|
||||||
logs: dict, metric results for this training epoch, and for the
|
logs: Dict, metric results for this training epoch, and for the
|
||||||
validation epoch if validation is performed. Validation result keys
|
validation epoch if validation is performed. Validation result keys
|
||||||
are prefixed with `val_`.
|
are prefixed with `val_`.
|
||||||
"""
|
"""
|
||||||
@ -618,8 +619,8 @@ class Callback(object):
|
|||||||
Subclasses should override for any actions to run.
|
Subclasses should override for any actions to run.
|
||||||
|
|
||||||
Arguments:
|
Arguments:
|
||||||
batch: integer, index of batch within the current epoch.
|
batch: Integer, index of batch within the current epoch.
|
||||||
logs: dict. Has keys `batch` and `size` representing the current batch
|
logs: Dict. Has keys `batch` and `size` representing the current batch
|
||||||
number and the size of the batch.
|
number and the size of the batch.
|
||||||
"""
|
"""
|
||||||
# For backwards compatibility.
|
# For backwards compatibility.
|
||||||
@ -633,8 +634,8 @@ class Callback(object):
|
|||||||
Subclasses should override for any actions to run.
|
Subclasses should override for any actions to run.
|
||||||
|
|
||||||
Arguments:
|
Arguments:
|
||||||
batch: integer, index of batch within the current epoch.
|
batch: Integer, index of batch within the current epoch.
|
||||||
logs: dict. Aggregated metric results up until this batch.
|
logs: Dict. Aggregated metric results up until this batch.
|
||||||
"""
|
"""
|
||||||
# For backwards compatibility.
|
# For backwards compatibility.
|
||||||
self.on_batch_end(batch, logs=logs)
|
self.on_batch_end(batch, logs=logs)
|
||||||
@ -650,8 +651,8 @@ class Callback(object):
|
|||||||
Subclasses should override for any actions to run.
|
Subclasses should override for any actions to run.
|
||||||
|
|
||||||
Arguments:
|
Arguments:
|
||||||
batch: integer, index of batch within the current epoch.
|
batch: Integer, index of batch within the current epoch.
|
||||||
logs: dict. Has keys `batch` and `size` representing the current batch
|
logs: Dict. Has keys `batch` and `size` representing the current batch
|
||||||
number and the size of the batch.
|
number and the size of the batch.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@ -666,8 +667,8 @@ class Callback(object):
|
|||||||
Subclasses should override for any actions to run.
|
Subclasses should override for any actions to run.
|
||||||
|
|
||||||
Arguments:
|
Arguments:
|
||||||
batch: integer, index of batch within the current epoch.
|
batch: Integer, index of batch within the current epoch.
|
||||||
logs: dict. Aggregated metric results up until this batch.
|
logs: Dict. Aggregated metric results up until this batch.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@doc_controls.for_subclass_implementers
|
@doc_controls.for_subclass_implementers
|
||||||
@ -678,8 +679,8 @@ class Callback(object):
|
|||||||
Subclasses should override for any actions to run.
|
Subclasses should override for any actions to run.
|
||||||
|
|
||||||
Arguments:
|
Arguments:
|
||||||
batch: integer, index of batch within the current epoch.
|
batch: Integer, index of batch within the current epoch.
|
||||||
logs: dict. Has keys `batch` and `size` representing the current batch
|
logs: Dict. Has keys `batch` and `size` representing the current batch
|
||||||
number and the size of the batch.
|
number and the size of the batch.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@ -691,8 +692,8 @@ class Callback(object):
|
|||||||
Subclasses should override for any actions to run.
|
Subclasses should override for any actions to run.
|
||||||
|
|
||||||
Arguments:
|
Arguments:
|
||||||
batch: integer, index of batch within the current epoch.
|
batch: Integer, index of batch within the current epoch.
|
||||||
logs: dict. Aggregated metric results up until this batch.
|
logs: Dict. Aggregated metric results up until this batch.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@doc_controls.for_subclass_implementers
|
@doc_controls.for_subclass_implementers
|
||||||
@ -702,7 +703,7 @@ class Callback(object):
|
|||||||
Subclasses should override for any actions to run.
|
Subclasses should override for any actions to run.
|
||||||
|
|
||||||
Arguments:
|
Arguments:
|
||||||
logs: dict. Currently no data is passed to this argument for this method
|
logs: Dict. Currently no data is passed to this argument for this method
|
||||||
but that may change in the future.
|
but that may change in the future.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@ -713,7 +714,7 @@ class Callback(object):
|
|||||||
Subclasses should override for any actions to run.
|
Subclasses should override for any actions to run.
|
||||||
|
|
||||||
Arguments:
|
Arguments:
|
||||||
logs: dict. Currently no data is passed to this argument for this method
|
logs: Dict. Currently no data is passed to this argument for this method
|
||||||
but that may change in the future.
|
but that may change in the future.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@ -724,7 +725,7 @@ class Callback(object):
|
|||||||
Subclasses should override for any actions to run.
|
Subclasses should override for any actions to run.
|
||||||
|
|
||||||
Arguments:
|
Arguments:
|
||||||
logs: dict. Currently no data is passed to this argument for this method
|
logs: Dict. Currently no data is passed to this argument for this method
|
||||||
but that may change in the future.
|
but that may change in the future.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@ -735,7 +736,7 @@ class Callback(object):
|
|||||||
Subclasses should override for any actions to run.
|
Subclasses should override for any actions to run.
|
||||||
|
|
||||||
Arguments:
|
Arguments:
|
||||||
logs: dict. Currently no data is passed to this argument for this method
|
logs: Dict. Currently no data is passed to this argument for this method
|
||||||
but that may change in the future.
|
but that may change in the future.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@ -746,7 +747,7 @@ class Callback(object):
|
|||||||
Subclasses should override for any actions to run.
|
Subclasses should override for any actions to run.
|
||||||
|
|
||||||
Arguments:
|
Arguments:
|
||||||
logs: dict. Currently no data is passed to this argument for this method
|
logs: Dict. Currently no data is passed to this argument for this method
|
||||||
but that may change in the future.
|
but that may change in the future.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@ -757,7 +758,7 @@ class Callback(object):
|
|||||||
Subclasses should override for any actions to run.
|
Subclasses should override for any actions to run.
|
||||||
|
|
||||||
Arguments:
|
Arguments:
|
||||||
logs: dict. Currently no data is passed to this argument for this method
|
logs: Dict. Currently no data is passed to this argument for this method
|
||||||
but that may change in the future.
|
but that may change in the future.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@ -847,7 +848,7 @@ class ProgbarLogger(Callback):
|
|||||||
"""Callback that prints metrics to stdout.
|
"""Callback that prints metrics to stdout.
|
||||||
|
|
||||||
Arguments:
|
Arguments:
|
||||||
count_mode: One of "steps" or "samples".
|
count_mode: One of `"steps"` or `"samples"`.
|
||||||
Whether the progress bar should
|
Whether the progress bar should
|
||||||
count samples seen or steps (batches) seen.
|
count samples seen or steps (batches) seen.
|
||||||
stateful_metrics: Iterable of string names of metrics that
|
stateful_metrics: Iterable of string names of metrics that
|
||||||
@ -1411,7 +1412,7 @@ class EarlyStopping(Callback):
|
|||||||
"""Stop training when a monitored metric has stopped improving.
|
"""Stop training when a monitored metric has stopped improving.
|
||||||
|
|
||||||
Assuming the goal of a training is to minimize the loss. With this, the
|
Assuming the goal of a training is to minimize the loss. With this, the
|
||||||
metric to be monitored would be 'loss', and mode would be 'min'. A
|
metric to be monitored would be `'loss'`, and mode would be `'min'`. A
|
||||||
`model.fit()` training loop will check at end of every epoch whether
|
`model.fit()` training loop will check at end of every epoch whether
|
||||||
the loss is no longer decreasing, considering the `min_delta` and
|
the loss is no longer decreasing, considering the `min_delta` and
|
||||||
`patience` if applicable. Once it's found no longer decreasing,
|
`patience` if applicable. Once it's found no longer decreasing,
|
||||||
@ -1420,6 +1421,30 @@ class EarlyStopping(Callback):
|
|||||||
The quantity to be monitored needs to be available in `logs` dict.
|
The quantity to be monitored needs to be available in `logs` dict.
|
||||||
To make it so, pass the loss or metrics at `model.compile()`.
|
To make it so, pass the loss or metrics at `model.compile()`.
|
||||||
|
|
||||||
|
Arguments:
|
||||||
|
monitor: Quantity to be monitored.
|
||||||
|
min_delta: Minimum change in the monitored quantity
|
||||||
|
to qualify as an improvement, i.e. an absolute
|
||||||
|
change of less than min_delta, will count as no
|
||||||
|
improvement.
|
||||||
|
patience: Number of epochs with no improvement
|
||||||
|
after which training will be stopped.
|
||||||
|
verbose: verbosity mode.
|
||||||
|
mode: One of `{"auto", "min", "max"}`. In `min` mode,
|
||||||
|
training will stop when the quantity
|
||||||
|
monitored has stopped decreasing; in `"max"`
|
||||||
|
mode it will stop when the quantity
|
||||||
|
monitored has stopped increasing; in `"auto"`
|
||||||
|
mode, the direction is automatically inferred
|
||||||
|
from the name of the monitored quantity.
|
||||||
|
baseline: Baseline value for the monitored quantity.
|
||||||
|
Training will stop if the model doesn't show improvement over the
|
||||||
|
baseline.
|
||||||
|
restore_best_weights: Whether to restore model weights from
|
||||||
|
the epoch with the best value of the monitored quantity.
|
||||||
|
If False, the model weights obtained at the last step of
|
||||||
|
training are used.
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
|
|
||||||
>>> callback = tf.keras.callbacks.EarlyStopping(monitor='loss', patience=3)
|
>>> callback = tf.keras.callbacks.EarlyStopping(monitor='loss', patience=3)
|
||||||
@ -1442,32 +1467,6 @@ class EarlyStopping(Callback):
|
|||||||
mode='auto',
|
mode='auto',
|
||||||
baseline=None,
|
baseline=None,
|
||||||
restore_best_weights=False):
|
restore_best_weights=False):
|
||||||
"""Initialize an EarlyStopping callback.
|
|
||||||
|
|
||||||
Arguments:
|
|
||||||
monitor: Quantity to be monitored.
|
|
||||||
min_delta: Minimum change in the monitored quantity
|
|
||||||
to qualify as an improvement, i.e. an absolute
|
|
||||||
change of less than min_delta, will count as no
|
|
||||||
improvement.
|
|
||||||
patience: Number of epochs with no improvement
|
|
||||||
after which training will be stopped.
|
|
||||||
verbose: verbosity mode.
|
|
||||||
mode: One of `{"auto", "min", "max"}`. In `min` mode,
|
|
||||||
training will stop when the quantity
|
|
||||||
monitored has stopped decreasing; in `max`
|
|
||||||
mode it will stop when the quantity
|
|
||||||
monitored has stopped increasing; in `auto`
|
|
||||||
mode, the direction is automatically inferred
|
|
||||||
from the name of the monitored quantity.
|
|
||||||
baseline: Baseline value for the monitored quantity.
|
|
||||||
Training will stop if the model doesn't show improvement over the
|
|
||||||
baseline.
|
|
||||||
restore_best_weights: Whether to restore model weights from
|
|
||||||
the epoch with the best value of the monitored quantity.
|
|
||||||
If False, the model weights obtained at the last step of
|
|
||||||
training are used.
|
|
||||||
"""
|
|
||||||
super(EarlyStopping, self).__init__()
|
super(EarlyStopping, self).__init__()
|
||||||
|
|
||||||
self.monitor = monitor
|
self.monitor = monitor
|
||||||
@ -1550,18 +1549,19 @@ class RemoteMonitor(Callback):
|
|||||||
Events are sent to `root + '/publish/epoch/end/'` by default. Calls are
|
Events are sent to `root + '/publish/epoch/end/'` by default. Calls are
|
||||||
HTTP POST, with a `data` argument which is a
|
HTTP POST, with a `data` argument which is a
|
||||||
JSON-encoded dictionary of event data.
|
JSON-encoded dictionary of event data.
|
||||||
If send_as_json is set to True, the content type of the request will be
|
If `send_as_json=True`, the content type of the request will be
|
||||||
application/json. Otherwise the serialized JSON will be sent within a form.
|
`"application/json"`.
|
||||||
|
Otherwise the serialized JSON will be sent within a form.
|
||||||
|
|
||||||
Arguments:
|
Arguments:
|
||||||
root: String; root url of the target server.
|
root: String; root url of the target server.
|
||||||
path: String; path relative to `root` to which the events will be sent.
|
path: String; path relative to `root` to which the events will be sent.
|
||||||
field: String; JSON field under which the data will be stored.
|
field: String; JSON field under which the data will be stored.
|
||||||
The field is used only if the payload is sent within a form
|
The field is used only if the payload is sent within a form
|
||||||
(i.e. send_as_json is set to False).
|
(i.e. send_as_json is set to False).
|
||||||
headers: Dictionary; optional custom HTTP headers.
|
headers: Dictionary; optional custom HTTP headers.
|
||||||
send_as_json: Boolean; whether the request should be
|
send_as_json: Boolean; whether the request should be
|
||||||
sent as application/json.
|
sent as `"application/json"`.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
@ -1613,6 +1613,12 @@ class LearningRateScheduler(Callback):
|
|||||||
and current learning rate, and applies the updated learning rate
|
and current learning rate, and applies the updated learning rate
|
||||||
on the optimizer.
|
on the optimizer.
|
||||||
|
|
||||||
|
Arguments:
|
||||||
|
schedule: a function that takes an epoch index (integer, indexed from 0)
|
||||||
|
and current learning rate (float) as inputs and returns a new
|
||||||
|
learning rate as output (float).
|
||||||
|
verbose: int. 0: quiet, 1: update messages.
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
|
|
||||||
>>> # This function keeps the initial learning rate for the first ten epochs
|
>>> # This function keeps the initial learning rate for the first ten epochs
|
||||||
@ -1637,14 +1643,6 @@ class LearningRateScheduler(Callback):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, schedule, verbose=0):
|
def __init__(self, schedule, verbose=0):
|
||||||
"""Initialize a `keras.callbacks.LearningRateScheduler` callback.
|
|
||||||
|
|
||||||
Arguments:
|
|
||||||
schedule: a function that takes an epoch index (integer, indexed from 0)
|
|
||||||
and current learning rate (float) as inputs and returns a new
|
|
||||||
learning rate as output (float).
|
|
||||||
verbose: int. 0: quiet, 1: update messages.
|
|
||||||
"""
|
|
||||||
super(LearningRateScheduler, self).__init__()
|
super(LearningRateScheduler, self).__init__()
|
||||||
self.schedule = schedule
|
self.schedule = schedule
|
||||||
self.verbose = verbose
|
self.verbose = verbose
|
||||||
@ -1689,7 +1687,7 @@ class TensorBoard(Callback, version_utils.TensorBoardVersionSelector):
|
|||||||
If you have installed TensorFlow with pip, you should be able
|
If you have installed TensorFlow with pip, you should be able
|
||||||
to launch TensorBoard from the command line:
|
to launch TensorBoard from the command line:
|
||||||
|
|
||||||
```sh
|
```
|
||||||
tensorboard --logdir=path_to_your_logs
|
tensorboard --logdir=path_to_your_logs
|
||||||
```
|
```
|
||||||
|
|
||||||
@ -1697,24 +1695,27 @@ class TensorBoard(Callback, version_utils.TensorBoardVersionSelector):
|
|||||||
[here](https://www.tensorflow.org/get_started/summaries_and_tensorboard).
|
[here](https://www.tensorflow.org/get_started/summaries_and_tensorboard).
|
||||||
|
|
||||||
Example (Basic):
|
Example (Basic):
|
||||||
|
|
||||||
```python
|
```python
|
||||||
tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir="./logs")
|
tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir="./logs")
|
||||||
model.fit(x_train, y_train, epochs=2, callbacks=[tensorboard_callback])
|
model.fit(x_train, y_train, epochs=2, callbacks=[tensorboard_callback])
|
||||||
# run the tensorboard command to view the visualizations.
|
# run the tensorboard command to view the visualizations.
|
||||||
```
|
```
|
||||||
|
|
||||||
Example (Profile):
|
Example (Profile):
|
||||||
|
|
||||||
```python
|
```python
|
||||||
# profile a single batch, e.g. the 5th batch.
|
# profile a single batch, e.g. the 5th batch.
|
||||||
tensorboard_callback =
|
tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir='./logs',
|
||||||
tf.keras.callbacks.TensorBoard(log_dir='./logs', profile_batch=5)
|
profile_batch=5)
|
||||||
model.fit(x_train, y_train, epochs=2, callbacks=[tensorboard_callback])
|
model.fit(x_train, y_train, epochs=2, callbacks=[tensorboard_callback])
|
||||||
# run the tensorboard command to view the visualizations in profile plugin.
|
# Now run the tensorboard command to view the visualizations (profile plugin).
|
||||||
|
|
||||||
# profile a range of batches, e.g. from 10 to 20.
|
# profile a range of batches, e.g. from 10 to 20.
|
||||||
tensorboard_callback =
|
tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir='./logs',
|
||||||
tf.keras.callbacks.TensorBoard(log_dir='./logs', profile_batch='10,20')
|
profile_batch='10,20')
|
||||||
model.fit(x_train, y_train, epochs=2, callbacks=[tensorboard_callback])
|
model.fit(x_train, y_train, epochs=2, callbacks=[tensorboard_callback])
|
||||||
# run the tensorboard command to view the visualizations in profile plugin.
|
# Now run the tensorboard command to view the visualizations (profile plugin).
|
||||||
```
|
```
|
||||||
|
|
||||||
Arguments:
|
Arguments:
|
||||||
@ -2143,14 +2144,15 @@ class ReduceLROnPlateau(Callback):
|
|||||||
|
|
||||||
Arguments:
|
Arguments:
|
||||||
monitor: quantity to be monitored.
|
monitor: quantity to be monitored.
|
||||||
factor: factor by which the learning rate will be reduced. new_lr = lr *
|
factor: factor by which the learning rate will be reduced.
|
||||||
factor
|
`new_lr = lr * factor`.
|
||||||
patience: number of epochs with no improvement after which learning rate
|
patience: number of epochs with no improvement after which learning rate
|
||||||
will be reduced.
|
will be reduced.
|
||||||
verbose: int. 0: quiet, 1: update messages.
|
verbose: int. 0: quiet, 1: update messages.
|
||||||
mode: one of {auto, min, max}. In `min` mode, lr will be reduced when the
|
mode: one of `{'auto', 'min', 'max'}`. In `'min'` mode,
|
||||||
quantity monitored has stopped decreasing; in `max` mode it will be
|
the learning rate will be reduced when the
|
||||||
reduced when the quantity monitored has stopped increasing; in `auto`
|
quantity monitored has stopped decreasing; in `'max'` mode it will be
|
||||||
|
reduced when the quantity monitored has stopped increasing; in `'auto'`
|
||||||
mode, the direction is automatically inferred from the name of the
|
mode, the direction is automatically inferred from the name of the
|
||||||
monitored quantity.
|
monitored quantity.
|
||||||
min_delta: threshold for measuring the new optimum, to only focus on
|
min_delta: threshold for measuring the new optimum, to only focus on
|
||||||
@ -2249,10 +2251,10 @@ class ReduceLROnPlateau(Callback):
|
|||||||
|
|
||||||
@keras_export('keras.callbacks.CSVLogger')
|
@keras_export('keras.callbacks.CSVLogger')
|
||||||
class CSVLogger(Callback):
|
class CSVLogger(Callback):
|
||||||
"""Callback that streams epoch results to a csv file.
|
"""Callback that streams epoch results to a CSV file.
|
||||||
|
|
||||||
Supports all values that can be represented as a string,
|
Supports all values that can be represented as a string,
|
||||||
including 1D iterables such as np.ndarray.
|
including 1D iterables such as `np.ndarray`.
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
|
|
||||||
@ -2262,10 +2264,10 @@ class CSVLogger(Callback):
|
|||||||
```
|
```
|
||||||
|
|
||||||
Arguments:
|
Arguments:
|
||||||
filename: filename of the csv file, e.g. 'run/log.csv'.
|
filename: Filename of the CSV file, e.g. `'run/log.csv'`.
|
||||||
separator: string used to separate elements in the csv file.
|
separator: String used to separate elements in the CSV file.
|
||||||
append: True: append if file exists (useful for continuing
|
append: Boolean. True: append if file exists (useful for continuing
|
||||||
training). False: overwrite existing file,
|
training). False: overwrite existing file.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, filename, separator=',', append=False):
|
def __init__(self, filename, separator=',', append=False):
|
||||||
@ -2348,12 +2350,12 @@ class LambdaCallback(Callback):
|
|||||||
at the appropriate time. Note that the callbacks expects positional
|
at the appropriate time. Note that the callbacks expects positional
|
||||||
arguments, as:
|
arguments, as:
|
||||||
|
|
||||||
- `on_epoch_begin` and `on_epoch_end` expect two positional arguments:
|
- `on_epoch_begin` and `on_epoch_end` expect two positional arguments:
|
||||||
`epoch`, `logs`
|
`epoch`, `logs`
|
||||||
- `on_batch_begin` and `on_batch_end` expect two positional arguments:
|
- `on_batch_begin` and `on_batch_end` expect two positional arguments:
|
||||||
`batch`, `logs`
|
`batch`, `logs`
|
||||||
- `on_train_begin` and `on_train_end` expect one positional argument:
|
- `on_train_begin` and `on_train_end` expect one positional argument:
|
||||||
`logs`
|
`logs`
|
||||||
|
|
||||||
Arguments:
|
Arguments:
|
||||||
on_epoch_begin: called at the beginning of every epoch.
|
on_epoch_begin: called at the beginning of every epoch.
|
||||||
|
@ -40,7 +40,7 @@ def load_data(path='boston_housing.npz', test_split=0.2, seed=113):
|
|||||||
|
|
||||||
Arguments:
|
Arguments:
|
||||||
path: path where to cache the dataset locally
|
path: path where to cache the dataset locally
|
||||||
(relative to ~/.keras/datasets).
|
(relative to `~/.keras/datasets`).
|
||||||
test_split: fraction of the data to reserve as test set.
|
test_split: fraction of the data to reserve as test set.
|
||||||
seed: Random seed for shuffling the data
|
seed: Random seed for shuffling the data
|
||||||
before computing the test split.
|
before computing the test split.
|
||||||
@ -48,10 +48,11 @@ def load_data(path='boston_housing.npz', test_split=0.2, seed=113):
|
|||||||
Returns:
|
Returns:
|
||||||
Tuple of Numpy arrays: `(x_train, y_train), (x_test, y_test)`.
|
Tuple of Numpy arrays: `(x_train, y_train), (x_test, y_test)`.
|
||||||
|
|
||||||
**x_train, x_test**: numpy arrays with shape (num_samples, 13) containing
|
**x_train, x_test**: numpy arrays with shape `(num_samples, 13)`
|
||||||
either the training samples (for x_train), or test samples (for y_train)
|
containing either the training samples (for x_train),
|
||||||
|
or test samples (for y_train).
|
||||||
|
|
||||||
**y_train, y_test**: numpy arrays of shape (num_samples, ) containing the
|
**y_train, y_test**: numpy arrays of shape `(num_samples,)` containing the
|
||||||
target scalars. The targets are float scalars typically between 10 and
|
target scalars. The targets are float scalars typically between 10 and
|
||||||
50 that represent the home prices in k$.
|
50 that represent the home prices in k$.
|
||||||
"""
|
"""
|
||||||
|
@ -40,9 +40,9 @@ def load_data():
|
|||||||
Tuple of Numpy arrays: `(x_train, y_train), (x_test, y_test)`.
|
Tuple of Numpy arrays: `(x_train, y_train), (x_test, y_test)`.
|
||||||
|
|
||||||
**x_train, x_test**: uint8 arrays of RGB image data with shape
|
**x_train, x_test**: uint8 arrays of RGB image data with shape
|
||||||
(num_samples, 3, 32, 32) if the `tf.keras.backend.image_data_format` is
|
`(num_samples, 3, 32, 32)` if `tf.keras.backend.image_data_format()` is
|
||||||
'channels_first', or (num_samples, 32, 32, 3) if the data format
|
`'channels_first'`, or `(num_samples, 32, 32, 3)` if the data format
|
||||||
is 'channels_last'.
|
is `'channels_last'`.
|
||||||
|
|
||||||
**y_train, y_test**: uint8 arrays of category labels
|
**y_train, y_test**: uint8 arrays of category labels
|
||||||
(integers in range 0-9) each with shape (num_samples, 1).
|
(integers in range 0-9) each with shape (num_samples, 1).
|
||||||
|
@ -46,9 +46,9 @@ def load_data(label_mode='fine'):
|
|||||||
Tuple of Numpy arrays: `(x_train, y_train), (x_test, y_test)`.
|
Tuple of Numpy arrays: `(x_train, y_train), (x_test, y_test)`.
|
||||||
|
|
||||||
**x_train, x_test**: uint8 arrays of RGB image data with shape
|
**x_train, x_test**: uint8 arrays of RGB image data with shape
|
||||||
(num_samples, 3, 32, 32) if the `tf.keras.backend.image_data_format` is
|
`(num_samples, 3, 32, 32)` if `tf.keras.backend.image_data_format()` is
|
||||||
'channels_first', or (num_samples, 32, 32, 3) if the data format
|
`'channels_first'`, or `(num_samples, 32, 32, 3)` if the data format
|
||||||
is 'channels_last'.
|
is `'channels_last'`.
|
||||||
|
|
||||||
**y_train, y_test**: uint8 arrays of category labels with shape
|
**y_train, y_test**: uint8 arrays of category labels with shape
|
||||||
(num_samples, 1).
|
(num_samples, 1).
|
||||||
|
@ -80,7 +80,7 @@ def load_data(path='imdb.npz',
|
|||||||
|
|
||||||
**x_train, x_test**: lists of sequences, which are lists of indexes
|
**x_train, x_test**: lists of sequences, which are lists of indexes
|
||||||
(integers). If the num_words argument was specific, the maximum
|
(integers). If the num_words argument was specific, the maximum
|
||||||
possible index value is num_words-1. If the `maxlen` argument was
|
possible index value is `num_words - 1`. If the `maxlen` argument was
|
||||||
specified, the largest possible sequence length is `maxlen`.
|
specified, the largest possible sequence length is `maxlen`.
|
||||||
|
|
||||||
**y_train, y_test**: lists of integer labels (1 or 0).
|
**y_train, y_test**: lists of integer labels (1 or 0).
|
||||||
|
@ -31,12 +31,12 @@ def load_data(path='mnist.npz'):
|
|||||||
This is a dataset of 60,000 28x28 grayscale images of the 10 digits,
|
This is a dataset of 60,000 28x28 grayscale images of the 10 digits,
|
||||||
along with a test set of 10,000 images.
|
along with a test set of 10,000 images.
|
||||||
More info can be found at the
|
More info can be found at the
|
||||||
(MNIST homepage)[http://yann.lecun.com/exdb/mnist/].
|
[MNIST homepage](http://yann.lecun.com/exdb/mnist/).
|
||||||
|
|
||||||
|
|
||||||
Arguments:
|
Arguments:
|
||||||
path: path where to cache the dataset locally
|
path: path where to cache the dataset locally
|
||||||
(relative to ~/.keras/datasets).
|
(relative to `~/.keras/datasets`).
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Tuple of Numpy arrays: `(x_train, y_train), (x_test, y_test)`.
|
Tuple of Numpy arrays: `(x_train, y_train), (x_test, y_test)`.
|
||||||
|
@ -42,11 +42,11 @@ def load_data(path='reuters.npz',
|
|||||||
"""Loads the Reuters newswire classification dataset.
|
"""Loads the Reuters newswire classification dataset.
|
||||||
|
|
||||||
This is a dataset of 11,228 newswires from Reuters, labeled over 46 topics.
|
This is a dataset of 11,228 newswires from Reuters, labeled over 46 topics.
|
||||||
|
|
||||||
This was originally generated by parsing and preprocessing the classic
|
This was originally generated by parsing and preprocessing the classic
|
||||||
Reuters-21578 dataset, but the preprocessing code is no longer packaged
|
Reuters-21578 dataset, but the preprocessing code is no longer packaged
|
||||||
with Keras.
|
with Keras. See this
|
||||||
|
[github discussion](https://github.com/keras-team/keras/issues/12072)
|
||||||
See this [github discussion](https://github.com/keras-team/keras/issues/12072)
|
|
||||||
for more info.
|
for more info.
|
||||||
|
|
||||||
Each newswire is encoded as a list of word indexes (integers).
|
Each newswire is encoded as a list of word indexes (integers).
|
||||||
@ -91,7 +91,7 @@ def load_data(path='reuters.npz',
|
|||||||
|
|
||||||
**x_train, x_test**: lists of sequences, which are lists of indexes
|
**x_train, x_test**: lists of sequences, which are lists of indexes
|
||||||
(integers). If the num_words argument was specific, the maximum
|
(integers). If the num_words argument was specific, the maximum
|
||||||
possible index value is num_words-1. If the `maxlen` argument was
|
possible index value is `num_words - 1`. If the `maxlen` argument was
|
||||||
specified, the largest possible sequence length is `maxlen`.
|
specified, the largest possible sequence length is `maxlen`.
|
||||||
|
|
||||||
**y_train, y_test**: lists of integer labels (1 or 0).
|
**y_train, y_test**: lists of integer labels (1 or 0).
|
||||||
|
@ -195,9 +195,9 @@ def get_file(fname,
|
|||||||
fname: Name of the file. If an absolute path `/path/to/file.txt` is
|
fname: Name of the file. If an absolute path `/path/to/file.txt` is
|
||||||
specified the file will be saved at that location.
|
specified the file will be saved at that location.
|
||||||
origin: Original URL of the file.
|
origin: Original URL of the file.
|
||||||
untar: Deprecated in favor of 'extract'.
|
untar: Deprecated in favor of `extract` argument.
|
||||||
boolean, whether the file should be decompressed
|
boolean, whether the file should be decompressed
|
||||||
md5_hash: Deprecated in favor of 'file_hash'.
|
md5_hash: Deprecated in favor of `file_hash` argument.
|
||||||
md5 hash of the file for verification
|
md5 hash of the file for verification
|
||||||
file_hash: The expected hash string of the file after download.
|
file_hash: The expected hash string of the file after download.
|
||||||
The sha256 and md5 hash algorithms are both supported.
|
The sha256 and md5 hash algorithms are both supported.
|
||||||
@ -205,17 +205,16 @@ def get_file(fname,
|
|||||||
saved. If an absolute path `/path/to/folder` is
|
saved. If an absolute path `/path/to/folder` is
|
||||||
specified the file will be saved at that location.
|
specified the file will be saved at that location.
|
||||||
hash_algorithm: Select the hash algorithm to verify the file.
|
hash_algorithm: Select the hash algorithm to verify the file.
|
||||||
options are 'md5', 'sha256', and 'auto'.
|
options are `'md5'`, `'sha256'`, and `'auto'`.
|
||||||
The default 'auto' detects the hash algorithm in use.
|
The default 'auto' detects the hash algorithm in use.
|
||||||
extract: True tries extracting the file as an Archive, like tar or zip.
|
extract: True tries extracting the file as an Archive, like tar or zip.
|
||||||
archive_format: Archive format to try for extracting the file.
|
archive_format: Archive format to try for extracting the file.
|
||||||
Options are 'auto', 'tar', 'zip', and None.
|
Options are `'auto'`, `'tar'`, `'zip'`, and `None`.
|
||||||
'tar' includes tar, tar.gz, and tar.bz files.
|
`'tar'` includes tar, tar.gz, and tar.bz files.
|
||||||
The default 'auto' is ['tar', 'zip'].
|
The default `'auto'` corresponds to `['tar', 'zip']`.
|
||||||
None or an empty list will return no matches found.
|
None or an empty list will return no matches found.
|
||||||
cache_dir: Location to store cached files, when None it
|
cache_dir: Location to store cached files, when None it
|
||||||
defaults to the [Keras
|
defaults to the default directory `~/.keras/`.
|
||||||
Directory](/faq/#where-is-the-keras-configuration-filed-stored).
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Path to the downloaded file
|
Path to the downloaded file
|
||||||
@ -315,8 +314,8 @@ def _hash_file(fpath, algorithm='sha256', chunk_size=65535):
|
|||||||
|
|
||||||
Arguments:
|
Arguments:
|
||||||
fpath: path to the file being validated
|
fpath: path to the file being validated
|
||||||
algorithm: hash algorithm, one of 'auto', 'sha256', or 'md5'.
|
algorithm: hash algorithm, one of `'auto'`, `'sha256'`, or `'md5'`.
|
||||||
The default 'auto' detects the hash algorithm in use.
|
The default `'auto'` detects the hash algorithm in use.
|
||||||
chunk_size: Bytes to read at a time, important for large files.
|
chunk_size: Bytes to read at a time, important for large files.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
@ -420,32 +419,32 @@ class Sequence(object):
|
|||||||
Examples:
|
Examples:
|
||||||
|
|
||||||
```python
|
```python
|
||||||
from skimage.io import imread
|
from skimage.io import imread
|
||||||
from skimage.transform import resize
|
from skimage.transform import resize
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import math
|
import math
|
||||||
|
|
||||||
# Here, `x_set` is list of path to the images
|
# Here, `x_set` is list of path to the images
|
||||||
# and `y_set` are the associated classes.
|
# and `y_set` are the associated classes.
|
||||||
|
|
||||||
class CIFAR10Sequence(Sequence):
|
class CIFAR10Sequence(Sequence):
|
||||||
|
|
||||||
def __init__(self, x_set, y_set, batch_size):
|
def __init__(self, x_set, y_set, batch_size):
|
||||||
self.x, self.y = x_set, y_set
|
self.x, self.y = x_set, y_set
|
||||||
self.batch_size = batch_size
|
self.batch_size = batch_size
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
return math.ceil(len(self.x) / self.batch_size)
|
return math.ceil(len(self.x) / self.batch_size)
|
||||||
|
|
||||||
def __getitem__(self, idx):
|
def __getitem__(self, idx):
|
||||||
batch_x = self.x[idx * self.batch_size:(idx + 1) *
|
batch_x = self.x[idx * self.batch_size:(idx + 1) *
|
||||||
self.batch_size]
|
self.batch_size]
|
||||||
batch_y = self.y[idx * self.batch_size:(idx + 1) *
|
batch_y = self.y[idx * self.batch_size:(idx + 1) *
|
||||||
self.batch_size]
|
self.batch_size]
|
||||||
|
|
||||||
return np.array([
|
return np.array([
|
||||||
resize(imread(file_name), (200, 200))
|
resize(imread(file_name), (200, 200))
|
||||||
for file_name in batch_x]), np.array(batch_y)
|
for file_name in batch_x]), np.array(batch_y)
|
||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@ -485,10 +484,10 @@ def iter_sequence_infinite(seq):
|
|||||||
"""Iterates indefinitely over a Sequence.
|
"""Iterates indefinitely over a Sequence.
|
||||||
|
|
||||||
Arguments:
|
Arguments:
|
||||||
seq: Sequence instance.
|
seq: `Sequence` instance.
|
||||||
|
|
||||||
Yields:
|
Yields:
|
||||||
Batches of data from the Sequence.
|
Batches of data from the `Sequence`.
|
||||||
"""
|
"""
|
||||||
while True:
|
while True:
|
||||||
for item in seq:
|
for item in seq:
|
||||||
|
@ -234,7 +234,7 @@ def get_registered_object(name, custom_objects=None, module_objects=None):
|
|||||||
|
|
||||||
@keras_export('keras.utils.serialize_keras_object')
|
@keras_export('keras.utils.serialize_keras_object')
|
||||||
def serialize_keras_object(instance):
|
def serialize_keras_object(instance):
|
||||||
"""Serialize Keras object into JSON."""
|
"""Serialize a Keras object into a JSON-compatible representation."""
|
||||||
_, instance = tf_decorator.unwrap(instance)
|
_, instance = tf_decorator.unwrap(instance)
|
||||||
if instance is None:
|
if instance is None:
|
||||||
return None
|
return None
|
||||||
@ -327,6 +327,7 @@ def deserialize_keras_object(identifier,
|
|||||||
module_objects=None,
|
module_objects=None,
|
||||||
custom_objects=None,
|
custom_objects=None,
|
||||||
printable_module_name='object'):
|
printable_module_name='object'):
|
||||||
|
"""Turns the serialized form of a Keras object back into an actual object."""
|
||||||
if identifier is None:
|
if identifier is None:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
@ -27,7 +27,18 @@ def to_categorical(y, num_classes=None, dtype='float32'):
|
|||||||
|
|
||||||
E.g. for use with categorical_crossentropy.
|
E.g. for use with categorical_crossentropy.
|
||||||
|
|
||||||
Usage Example:
|
Arguments:
|
||||||
|
y: class vector to be converted into a matrix
|
||||||
|
(integers from 0 to num_classes).
|
||||||
|
num_classes: total number of classes. If `None`, this would be inferred
|
||||||
|
as the (largest number in `y`) + 1.
|
||||||
|
dtype: The data type expected by the input. Default: `'float32'`.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A binary matrix representation of the input. The classes axis is placed
|
||||||
|
last.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
|
||||||
>>> a = tf.keras.utils.to_categorical([0, 1, 2, 3], num_classes=4)
|
>>> a = tf.keras.utils.to_categorical([0, 1, 2, 3], num_classes=4)
|
||||||
>>> a = tf.constant(a, shape=[4, 4])
|
>>> a = tf.constant(a, shape=[4, 4])
|
||||||
@ -51,29 +62,6 @@ def to_categorical(y, num_classes=None, dtype='float32'):
|
|||||||
>>> print(np.around(loss, 5))
|
>>> print(np.around(loss, 5))
|
||||||
[0. 0. 0. 0.]
|
[0. 0. 0. 0.]
|
||||||
|
|
||||||
Arguments:
|
|
||||||
y: class vector to be converted into a matrix
|
|
||||||
(integers from 0 to num_classes).
|
|
||||||
num_classes: total number of classes. If `None`, this would be inferred
|
|
||||||
as the (largest number in `y`) + 1.
|
|
||||||
dtype: The data type expected by the input. Default: `'float32'`.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A binary matrix representation of the input. The classes axis is placed
|
|
||||||
last.
|
|
||||||
|
|
||||||
Usage example:
|
|
||||||
|
|
||||||
>>> y = [0, 1, 2, 3, 3, 1, 0]
|
|
||||||
>>> tf.keras.utils.to_categorical(y, 4)
|
|
||||||
array([[1., 0., 0., 0.],
|
|
||||||
[0., 1., 0., 0.],
|
|
||||||
[0., 0., 1., 0.],
|
|
||||||
[0., 0., 0., 1.],
|
|
||||||
[0., 0., 0., 1.],
|
|
||||||
[0., 1., 0., 0.],
|
|
||||||
[1., 0., 0., 0.]], dtype=float32)
|
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
Value Error: If input contains string value
|
Value Error: If input contains string value
|
||||||
|
|
||||||
@ -100,7 +88,7 @@ def normalize(x, axis=-1, order=2):
|
|||||||
Arguments:
|
Arguments:
|
||||||
x: Numpy array to normalize.
|
x: Numpy array to normalize.
|
||||||
axis: axis along which to normalize.
|
axis: axis along which to normalize.
|
||||||
order: Normalization order (e.g. 2 for L2 norm).
|
order: Normalization order (e.g. `order=2` for L2 norm).
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A normalized copy of the array.
|
A normalized copy of the array.
|
||||||
|
Loading…
Reference in New Issue
Block a user