Remove the 'mode' argument in on_epoch_begin and on_epoch_end methods of CallbackList and Callback classes, to eliminate exposure of ModeKeys api that's intended to be TensorFlow-internal. Add doc that says the methods should only be called during TRAIN mode.
PiperOrigin-RevId: 226252687
This commit is contained in:
parent
0445684a64
commit
65093ecfe6
tensorflow
python/keras
tools/api/golden
v1
tensorflow.keras.callbacks.-c-s-v-logger.pbtxttensorflow.keras.callbacks.-callback.pbtxttensorflow.keras.callbacks.-early-stopping.pbtxttensorflow.keras.callbacks.-history.pbtxttensorflow.keras.callbacks.-lambda-callback.pbtxttensorflow.keras.callbacks.-model-checkpoint.pbtxttensorflow.keras.callbacks.-reduce-l-r-on-plateau.pbtxttensorflow.keras.callbacks.-remote-monitor.pbtxttensorflow.keras.callbacks.-terminate-on-na-n.pbtxt
v2
tensorflow.keras.callbacks.-c-s-v-logger.pbtxttensorflow.keras.callbacks.-callback.pbtxttensorflow.keras.callbacks.-early-stopping.pbtxttensorflow.keras.callbacks.-history.pbtxttensorflow.keras.callbacks.-lambda-callback.pbtxttensorflow.keras.callbacks.-model-checkpoint.pbtxttensorflow.keras.callbacks.-reduce-l-r-on-plateau.pbtxttensorflow.keras.callbacks.-remote-monitor.pbtxttensorflow.keras.callbacks.-terminate-on-na-n.pbtxt
@ -242,35 +242,35 @@ class CallbackList(object):
|
||||
def on_batch_end(self, batch, logs=None):
|
||||
self._call_batch_hook(_TRAIN, 'end', batch, logs=logs)
|
||||
|
||||
def on_epoch_begin(self, epoch, logs=None, mode='train'):
|
||||
def on_epoch_begin(self, epoch, logs=None):
|
||||
"""Calls the `on_epoch_begin` methods of its callbacks.
|
||||
|
||||
This function should only be called during TRAIN mode.
|
||||
|
||||
Arguments:
|
||||
epoch: integer, index of epoch.
|
||||
logs: dict. Currently no data is passed to this argument for this method
|
||||
but that may change in the future.
|
||||
mode: One of 'train'/'test'/'predict'
|
||||
"""
|
||||
if mode == _TRAIN:
|
||||
logs = logs or {}
|
||||
for callback in self.callbacks:
|
||||
callback.on_epoch_begin(epoch, logs)
|
||||
logs = logs or {}
|
||||
for callback in self.callbacks:
|
||||
callback.on_epoch_begin(epoch, logs)
|
||||
self._reset_batch_timing()
|
||||
|
||||
def on_epoch_end(self, epoch, logs=None, mode='train'):
|
||||
def on_epoch_end(self, epoch, logs=None):
|
||||
"""Calls the `on_epoch_end` methods of its callbacks.
|
||||
|
||||
This function should only be called during TRAIN mode.
|
||||
|
||||
Arguments:
|
||||
epoch: integer, index of epoch.
|
||||
logs: dict, metric results for this training epoch, and for the
|
||||
validation epoch if validation is performed. Validation result keys
|
||||
are prefixed with `val_`.
|
||||
mode: One of 'train'/'test'/'predict'
|
||||
"""
|
||||
if mode == _TRAIN:
|
||||
logs = logs or {}
|
||||
for callback in self.callbacks:
|
||||
callback.on_epoch_end(epoch, logs)
|
||||
logs = logs or {}
|
||||
for callback in self.callbacks:
|
||||
callback.on_epoch_end(epoch, logs)
|
||||
|
||||
def on_train_batch_begin(self, batch, logs=None):
|
||||
"""Calls the `on_train_batch_begin` methods of its callbacks.
|
||||
@ -437,29 +437,29 @@ class Callback(object):
|
||||
def on_batch_end(self, batch, logs=None):
|
||||
"""A backwards compatibility alias for `on_train_batch_end`."""
|
||||
|
||||
def on_epoch_begin(self, epoch, logs=None, mode='train'):
|
||||
def on_epoch_begin(self, epoch, logs=None):
|
||||
"""Called at the start of an epoch.
|
||||
|
||||
Subclasses should override for any actions to run.
|
||||
Subclasses should override for any actions to run. This function should only
|
||||
be called during TRAIN mode.
|
||||
|
||||
Arguments:
|
||||
epoch: integer, index of epoch.
|
||||
logs: dict. Currently no data is passed to this argument for this method
|
||||
but that may change in the future.
|
||||
mode: One of 'train'/'test'/'predict'
|
||||
"""
|
||||
|
||||
def on_epoch_end(self, epoch, logs=None, mode='train'):
|
||||
def on_epoch_end(self, epoch, logs=None):
|
||||
"""Called at the end of an epoch.
|
||||
|
||||
Subclasses should override for any actions to run.
|
||||
Subclasses should override for any actions to run. This function should only
|
||||
be called during TRAIN mode.
|
||||
|
||||
Arguments:
|
||||
epoch: integer, index of epoch.
|
||||
logs: dict, metric results for this training epoch, and for the
|
||||
validation epoch if validation is performed. Validation result keys
|
||||
are prefixed with `val_`.
|
||||
mode: One of 'train'/'test'/'predict'
|
||||
"""
|
||||
|
||||
def on_train_batch_begin(self, batch, logs=None):
|
||||
|
@ -32,6 +32,7 @@ from tensorflow.python.keras.engine import training_utils
|
||||
from tensorflow.python.keras.utils.generic_utils import make_batches
|
||||
from tensorflow.python.keras.utils.generic_utils import slice_arrays
|
||||
from tensorflow.python.platform import tf_logging as logging
|
||||
from tensorflow.python.training.mode_keys import ModeKeys
|
||||
|
||||
try:
|
||||
from scipy.sparse import issparse # pylint: disable=g-import-not-at-top
|
||||
@ -251,7 +252,8 @@ def model_iteration(model,
|
||||
# Setup work for each epoch
|
||||
epoch_logs = {}
|
||||
model.reset_metrics()
|
||||
callbacks.on_epoch_begin(epoch, epoch_logs, mode=mode)
|
||||
if mode == ModeKeys.TRAIN:
|
||||
callbacks.on_epoch_begin(epoch, epoch_logs)
|
||||
progbar.on_epoch_begin(epoch, epoch_logs)
|
||||
|
||||
if use_steps:
|
||||
@ -371,7 +373,7 @@ def model_iteration(model,
|
||||
|
||||
if mode == 'train':
|
||||
# Epochs only apply to `fit`.
|
||||
callbacks.on_epoch_end(epoch, epoch_logs, mode=mode)
|
||||
callbacks.on_epoch_end(epoch, epoch_logs)
|
||||
progbar.on_epoch_end(epoch, epoch_logs)
|
||||
|
||||
callbacks._call_end_hook(mode)
|
||||
|
@ -34,6 +34,7 @@ from tensorflow.python.keras.engine import training_utils
|
||||
from tensorflow.python.keras.utils import data_utils
|
||||
from tensorflow.python.keras.utils import generic_utils
|
||||
from tensorflow.python.platform import tf_logging as logging
|
||||
from tensorflow.python.training.mode_keys import ModeKeys
|
||||
from tensorflow.python.util import nest
|
||||
|
||||
|
||||
@ -170,7 +171,8 @@ def model_iteration(model,
|
||||
# Setup work for each epoch.
|
||||
model.reset_metrics()
|
||||
epoch_logs = {}
|
||||
callbacks.on_epoch_begin(epoch, epoch_logs, mode=mode)
|
||||
if mode == ModeKeys.TRAIN:
|
||||
callbacks.on_epoch_begin(epoch, epoch_logs)
|
||||
progbar.on_epoch_begin(epoch, epoch_logs)
|
||||
|
||||
for step in range(steps_per_epoch):
|
||||
@ -233,7 +235,7 @@ def model_iteration(model,
|
||||
|
||||
if mode == 'train':
|
||||
# Epochs only apply to `fit`.
|
||||
callbacks.on_epoch_end(epoch, epoch_logs, mode=mode)
|
||||
callbacks.on_epoch_end(epoch, epoch_logs)
|
||||
progbar.on_epoch_end(epoch, epoch_logs)
|
||||
|
||||
callbacks._call_end_hook(mode)
|
||||
|
@ -17,7 +17,7 @@ tf_class {
|
||||
}
|
||||
member_method {
|
||||
name: "on_epoch_begin"
|
||||
argspec: "args=[\'self\', \'epoch\', \'logs\', \'mode\'], varargs=None, keywords=None, defaults=[\'None\', \'train\'], "
|
||||
argspec: "args=[\'self\', \'epoch\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "on_epoch_end"
|
||||
|
@ -16,11 +16,11 @@ tf_class {
|
||||
}
|
||||
member_method {
|
||||
name: "on_epoch_begin"
|
||||
argspec: "args=[\'self\', \'epoch\', \'logs\', \'mode\'], varargs=None, keywords=None, defaults=[\'None\', \'train\'], "
|
||||
argspec: "args=[\'self\', \'epoch\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "on_epoch_end"
|
||||
argspec: "args=[\'self\', \'epoch\', \'logs\', \'mode\'], varargs=None, keywords=None, defaults=[\'None\', \'train\'], "
|
||||
argspec: "args=[\'self\', \'epoch\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "on_predict_batch_begin"
|
||||
|
@ -21,7 +21,7 @@ tf_class {
|
||||
}
|
||||
member_method {
|
||||
name: "on_epoch_begin"
|
||||
argspec: "args=[\'self\', \'epoch\', \'logs\', \'mode\'], varargs=None, keywords=None, defaults=[\'None\', \'train\'], "
|
||||
argspec: "args=[\'self\', \'epoch\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "on_epoch_end"
|
||||
|
@ -17,7 +17,7 @@ tf_class {
|
||||
}
|
||||
member_method {
|
||||
name: "on_epoch_begin"
|
||||
argspec: "args=[\'self\', \'epoch\', \'logs\', \'mode\'], varargs=None, keywords=None, defaults=[\'None\', \'train\'], "
|
||||
argspec: "args=[\'self\', \'epoch\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "on_epoch_end"
|
||||
|
@ -17,11 +17,11 @@ tf_class {
|
||||
}
|
||||
member_method {
|
||||
name: "on_epoch_begin"
|
||||
argspec: "args=[\'self\', \'epoch\', \'logs\', \'mode\'], varargs=None, keywords=None, defaults=[\'None\', \'train\'], "
|
||||
argspec: "args=[\'self\', \'epoch\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "on_epoch_end"
|
||||
argspec: "args=[\'self\', \'epoch\', \'logs\', \'mode\'], varargs=None, keywords=None, defaults=[\'None\', \'train\'], "
|
||||
argspec: "args=[\'self\', \'epoch\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "on_predict_batch_begin"
|
||||
|
@ -17,7 +17,7 @@ tf_class {
|
||||
}
|
||||
member_method {
|
||||
name: "on_epoch_begin"
|
||||
argspec: "args=[\'self\', \'epoch\', \'logs\', \'mode\'], varargs=None, keywords=None, defaults=[\'None\', \'train\'], "
|
||||
argspec: "args=[\'self\', \'epoch\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "on_epoch_end"
|
||||
|
@ -21,7 +21,7 @@ tf_class {
|
||||
}
|
||||
member_method {
|
||||
name: "on_epoch_begin"
|
||||
argspec: "args=[\'self\', \'epoch\', \'logs\', \'mode\'], varargs=None, keywords=None, defaults=[\'None\', \'train\'], "
|
||||
argspec: "args=[\'self\', \'epoch\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "on_epoch_end"
|
||||
|
@ -17,7 +17,7 @@ tf_class {
|
||||
}
|
||||
member_method {
|
||||
name: "on_epoch_begin"
|
||||
argspec: "args=[\'self\', \'epoch\', \'logs\', \'mode\'], varargs=None, keywords=None, defaults=[\'None\', \'train\'], "
|
||||
argspec: "args=[\'self\', \'epoch\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "on_epoch_end"
|
||||
|
@ -17,11 +17,11 @@ tf_class {
|
||||
}
|
||||
member_method {
|
||||
name: "on_epoch_begin"
|
||||
argspec: "args=[\'self\', \'epoch\', \'logs\', \'mode\'], varargs=None, keywords=None, defaults=[\'None\', \'train\'], "
|
||||
argspec: "args=[\'self\', \'epoch\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "on_epoch_end"
|
||||
argspec: "args=[\'self\', \'epoch\', \'logs\', \'mode\'], varargs=None, keywords=None, defaults=[\'None\', \'train\'], "
|
||||
argspec: "args=[\'self\', \'epoch\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "on_predict_batch_begin"
|
||||
|
@ -17,7 +17,7 @@ tf_class {
|
||||
}
|
||||
member_method {
|
||||
name: "on_epoch_begin"
|
||||
argspec: "args=[\'self\', \'epoch\', \'logs\', \'mode\'], varargs=None, keywords=None, defaults=[\'None\', \'train\'], "
|
||||
argspec: "args=[\'self\', \'epoch\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "on_epoch_end"
|
||||
|
@ -16,11 +16,11 @@ tf_class {
|
||||
}
|
||||
member_method {
|
||||
name: "on_epoch_begin"
|
||||
argspec: "args=[\'self\', \'epoch\', \'logs\', \'mode\'], varargs=None, keywords=None, defaults=[\'None\', \'train\'], "
|
||||
argspec: "args=[\'self\', \'epoch\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "on_epoch_end"
|
||||
argspec: "args=[\'self\', \'epoch\', \'logs\', \'mode\'], varargs=None, keywords=None, defaults=[\'None\', \'train\'], "
|
||||
argspec: "args=[\'self\', \'epoch\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "on_predict_batch_begin"
|
||||
|
@ -21,7 +21,7 @@ tf_class {
|
||||
}
|
||||
member_method {
|
||||
name: "on_epoch_begin"
|
||||
argspec: "args=[\'self\', \'epoch\', \'logs\', \'mode\'], varargs=None, keywords=None, defaults=[\'None\', \'train\'], "
|
||||
argspec: "args=[\'self\', \'epoch\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "on_epoch_end"
|
||||
|
@ -17,7 +17,7 @@ tf_class {
|
||||
}
|
||||
member_method {
|
||||
name: "on_epoch_begin"
|
||||
argspec: "args=[\'self\', \'epoch\', \'logs\', \'mode\'], varargs=None, keywords=None, defaults=[\'None\', \'train\'], "
|
||||
argspec: "args=[\'self\', \'epoch\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "on_epoch_end"
|
||||
|
@ -17,11 +17,11 @@ tf_class {
|
||||
}
|
||||
member_method {
|
||||
name: "on_epoch_begin"
|
||||
argspec: "args=[\'self\', \'epoch\', \'logs\', \'mode\'], varargs=None, keywords=None, defaults=[\'None\', \'train\'], "
|
||||
argspec: "args=[\'self\', \'epoch\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "on_epoch_end"
|
||||
argspec: "args=[\'self\', \'epoch\', \'logs\', \'mode\'], varargs=None, keywords=None, defaults=[\'None\', \'train\'], "
|
||||
argspec: "args=[\'self\', \'epoch\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "on_predict_batch_begin"
|
||||
|
@ -17,7 +17,7 @@ tf_class {
|
||||
}
|
||||
member_method {
|
||||
name: "on_epoch_begin"
|
||||
argspec: "args=[\'self\', \'epoch\', \'logs\', \'mode\'], varargs=None, keywords=None, defaults=[\'None\', \'train\'], "
|
||||
argspec: "args=[\'self\', \'epoch\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "on_epoch_end"
|
||||
|
@ -21,7 +21,7 @@ tf_class {
|
||||
}
|
||||
member_method {
|
||||
name: "on_epoch_begin"
|
||||
argspec: "args=[\'self\', \'epoch\', \'logs\', \'mode\'], varargs=None, keywords=None, defaults=[\'None\', \'train\'], "
|
||||
argspec: "args=[\'self\', \'epoch\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "on_epoch_end"
|
||||
|
@ -17,7 +17,7 @@ tf_class {
|
||||
}
|
||||
member_method {
|
||||
name: "on_epoch_begin"
|
||||
argspec: "args=[\'self\', \'epoch\', \'logs\', \'mode\'], varargs=None, keywords=None, defaults=[\'None\', \'train\'], "
|
||||
argspec: "args=[\'self\', \'epoch\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "on_epoch_end"
|
||||
|
@ -17,11 +17,11 @@ tf_class {
|
||||
}
|
||||
member_method {
|
||||
name: "on_epoch_begin"
|
||||
argspec: "args=[\'self\', \'epoch\', \'logs\', \'mode\'], varargs=None, keywords=None, defaults=[\'None\', \'train\'], "
|
||||
argspec: "args=[\'self\', \'epoch\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "on_epoch_end"
|
||||
argspec: "args=[\'self\', \'epoch\', \'logs\', \'mode\'], varargs=None, keywords=None, defaults=[\'None\', \'train\'], "
|
||||
argspec: "args=[\'self\', \'epoch\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "on_predict_batch_begin"
|
||||
|
Loading…
Reference in New Issue
Block a user