Remove the 'mode' argument in on_epoch_begin and on_epoch_end methods of CallbackList and Callback classes, to eliminate exposure of ModeKeys api that's intended to be TensorFlow-internal. Add doc that says the methods should only be called during TRAIN mode.

PiperOrigin-RevId: 226252687
This commit is contained in:
Rick Chao 2018-12-19 16:25:49 -08:00 committed by TensorFlower Gardener
parent 0445684a64
commit 65093ecfe6
21 changed files with 50 additions and 46 deletions

View File

@ -242,35 +242,35 @@ class CallbackList(object):
def on_batch_end(self, batch, logs=None):
self._call_batch_hook(_TRAIN, 'end', batch, logs=logs)
def on_epoch_begin(self, epoch, logs=None, mode='train'):
def on_epoch_begin(self, epoch, logs=None):
"""Calls the `on_epoch_begin` methods of its callbacks.
This function should only be called during TRAIN mode.
Arguments:
epoch: integer, index of epoch.
logs: dict. Currently no data is passed to this argument for this method
but that may change in the future.
mode: One of 'train'/'test'/'predict'
"""
if mode == _TRAIN:
logs = logs or {}
for callback in self.callbacks:
callback.on_epoch_begin(epoch, logs)
logs = logs or {}
for callback in self.callbacks:
callback.on_epoch_begin(epoch, logs)
self._reset_batch_timing()
def on_epoch_end(self, epoch, logs=None, mode='train'):
def on_epoch_end(self, epoch, logs=None):
"""Calls the `on_epoch_end` methods of its callbacks.
This function should only be called during TRAIN mode.
Arguments:
epoch: integer, index of epoch.
logs: dict, metric results for this training epoch, and for the
validation epoch if validation is performed. Validation result keys
are prefixed with `val_`.
mode: One of 'train'/'test'/'predict'
"""
if mode == _TRAIN:
logs = logs or {}
for callback in self.callbacks:
callback.on_epoch_end(epoch, logs)
logs = logs or {}
for callback in self.callbacks:
callback.on_epoch_end(epoch, logs)
def on_train_batch_begin(self, batch, logs=None):
"""Calls the `on_train_batch_begin` methods of its callbacks.
@ -437,29 +437,29 @@ class Callback(object):
def on_batch_end(self, batch, logs=None):
"""A backwards compatibility alias for `on_train_batch_end`."""
def on_epoch_begin(self, epoch, logs=None, mode='train'):
def on_epoch_begin(self, epoch, logs=None):
"""Called at the start of an epoch.
Subclasses should override for any actions to run.
Subclasses should override for any actions to run. This function should only
be called during TRAIN mode.
Arguments:
epoch: integer, index of epoch.
logs: dict. Currently no data is passed to this argument for this method
but that may change in the future.
mode: One of 'train'/'test'/'predict'
"""
def on_epoch_end(self, epoch, logs=None, mode='train'):
def on_epoch_end(self, epoch, logs=None):
"""Called at the end of an epoch.
Subclasses should override for any actions to run.
Subclasses should override for any actions to run. This function should only
be called during TRAIN mode.
Arguments:
epoch: integer, index of epoch.
logs: dict, metric results for this training epoch, and for the
validation epoch if validation is performed. Validation result keys
are prefixed with `val_`.
mode: One of 'train'/'test'/'predict'
"""
def on_train_batch_begin(self, batch, logs=None):

View File

@ -32,6 +32,7 @@ from tensorflow.python.keras.engine import training_utils
from tensorflow.python.keras.utils.generic_utils import make_batches
from tensorflow.python.keras.utils.generic_utils import slice_arrays
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training.mode_keys import ModeKeys
try:
from scipy.sparse import issparse # pylint: disable=g-import-not-at-top
@ -251,7 +252,8 @@ def model_iteration(model,
# Setup work for each epoch
epoch_logs = {}
model.reset_metrics()
callbacks.on_epoch_begin(epoch, epoch_logs, mode=mode)
if mode == ModeKeys.TRAIN:
callbacks.on_epoch_begin(epoch, epoch_logs)
progbar.on_epoch_begin(epoch, epoch_logs)
if use_steps:
@ -371,7 +373,7 @@ def model_iteration(model,
if mode == 'train':
# Epochs only apply to `fit`.
callbacks.on_epoch_end(epoch, epoch_logs, mode=mode)
callbacks.on_epoch_end(epoch, epoch_logs)
progbar.on_epoch_end(epoch, epoch_logs)
callbacks._call_end_hook(mode)

View File

@ -34,6 +34,7 @@ from tensorflow.python.keras.engine import training_utils
from tensorflow.python.keras.utils import data_utils
from tensorflow.python.keras.utils import generic_utils
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training.mode_keys import ModeKeys
from tensorflow.python.util import nest
@ -170,7 +171,8 @@ def model_iteration(model,
# Setup work for each epoch.
model.reset_metrics()
epoch_logs = {}
callbacks.on_epoch_begin(epoch, epoch_logs, mode=mode)
if mode == ModeKeys.TRAIN:
callbacks.on_epoch_begin(epoch, epoch_logs)
progbar.on_epoch_begin(epoch, epoch_logs)
for step in range(steps_per_epoch):
@ -233,7 +235,7 @@ def model_iteration(model,
if mode == 'train':
# Epochs only apply to `fit`.
callbacks.on_epoch_end(epoch, epoch_logs, mode=mode)
callbacks.on_epoch_end(epoch, epoch_logs)
progbar.on_epoch_end(epoch, epoch_logs)
callbacks._call_end_hook(mode)

View File

@ -17,7 +17,7 @@ tf_class {
}
member_method {
name: "on_epoch_begin"
argspec: "args=[\'self\', \'epoch\', \'logs\', \'mode\'], varargs=None, keywords=None, defaults=[\'None\', \'train\'], "
argspec: "args=[\'self\', \'epoch\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "on_epoch_end"

View File

@ -16,11 +16,11 @@ tf_class {
}
member_method {
name: "on_epoch_begin"
argspec: "args=[\'self\', \'epoch\', \'logs\', \'mode\'], varargs=None, keywords=None, defaults=[\'None\', \'train\'], "
argspec: "args=[\'self\', \'epoch\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "on_epoch_end"
argspec: "args=[\'self\', \'epoch\', \'logs\', \'mode\'], varargs=None, keywords=None, defaults=[\'None\', \'train\'], "
argspec: "args=[\'self\', \'epoch\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "on_predict_batch_begin"

View File

@ -21,7 +21,7 @@ tf_class {
}
member_method {
name: "on_epoch_begin"
argspec: "args=[\'self\', \'epoch\', \'logs\', \'mode\'], varargs=None, keywords=None, defaults=[\'None\', \'train\'], "
argspec: "args=[\'self\', \'epoch\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "on_epoch_end"

View File

@ -17,7 +17,7 @@ tf_class {
}
member_method {
name: "on_epoch_begin"
argspec: "args=[\'self\', \'epoch\', \'logs\', \'mode\'], varargs=None, keywords=None, defaults=[\'None\', \'train\'], "
argspec: "args=[\'self\', \'epoch\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "on_epoch_end"

View File

@ -17,11 +17,11 @@ tf_class {
}
member_method {
name: "on_epoch_begin"
argspec: "args=[\'self\', \'epoch\', \'logs\', \'mode\'], varargs=None, keywords=None, defaults=[\'None\', \'train\'], "
argspec: "args=[\'self\', \'epoch\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "on_epoch_end"
argspec: "args=[\'self\', \'epoch\', \'logs\', \'mode\'], varargs=None, keywords=None, defaults=[\'None\', \'train\'], "
argspec: "args=[\'self\', \'epoch\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "on_predict_batch_begin"

View File

@ -17,7 +17,7 @@ tf_class {
}
member_method {
name: "on_epoch_begin"
argspec: "args=[\'self\', \'epoch\', \'logs\', \'mode\'], varargs=None, keywords=None, defaults=[\'None\', \'train\'], "
argspec: "args=[\'self\', \'epoch\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "on_epoch_end"

View File

@ -21,7 +21,7 @@ tf_class {
}
member_method {
name: "on_epoch_begin"
argspec: "args=[\'self\', \'epoch\', \'logs\', \'mode\'], varargs=None, keywords=None, defaults=[\'None\', \'train\'], "
argspec: "args=[\'self\', \'epoch\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "on_epoch_end"

View File

@ -17,7 +17,7 @@ tf_class {
}
member_method {
name: "on_epoch_begin"
argspec: "args=[\'self\', \'epoch\', \'logs\', \'mode\'], varargs=None, keywords=None, defaults=[\'None\', \'train\'], "
argspec: "args=[\'self\', \'epoch\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "on_epoch_end"

View File

@ -17,11 +17,11 @@ tf_class {
}
member_method {
name: "on_epoch_begin"
argspec: "args=[\'self\', \'epoch\', \'logs\', \'mode\'], varargs=None, keywords=None, defaults=[\'None\', \'train\'], "
argspec: "args=[\'self\', \'epoch\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "on_epoch_end"
argspec: "args=[\'self\', \'epoch\', \'logs\', \'mode\'], varargs=None, keywords=None, defaults=[\'None\', \'train\'], "
argspec: "args=[\'self\', \'epoch\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "on_predict_batch_begin"

View File

@ -17,7 +17,7 @@ tf_class {
}
member_method {
name: "on_epoch_begin"
argspec: "args=[\'self\', \'epoch\', \'logs\', \'mode\'], varargs=None, keywords=None, defaults=[\'None\', \'train\'], "
argspec: "args=[\'self\', \'epoch\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "on_epoch_end"

View File

@ -16,11 +16,11 @@ tf_class {
}
member_method {
name: "on_epoch_begin"
argspec: "args=[\'self\', \'epoch\', \'logs\', \'mode\'], varargs=None, keywords=None, defaults=[\'None\', \'train\'], "
argspec: "args=[\'self\', \'epoch\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "on_epoch_end"
argspec: "args=[\'self\', \'epoch\', \'logs\', \'mode\'], varargs=None, keywords=None, defaults=[\'None\', \'train\'], "
argspec: "args=[\'self\', \'epoch\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "on_predict_batch_begin"

View File

@ -21,7 +21,7 @@ tf_class {
}
member_method {
name: "on_epoch_begin"
argspec: "args=[\'self\', \'epoch\', \'logs\', \'mode\'], varargs=None, keywords=None, defaults=[\'None\', \'train\'], "
argspec: "args=[\'self\', \'epoch\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "on_epoch_end"

View File

@ -17,7 +17,7 @@ tf_class {
}
member_method {
name: "on_epoch_begin"
argspec: "args=[\'self\', \'epoch\', \'logs\', \'mode\'], varargs=None, keywords=None, defaults=[\'None\', \'train\'], "
argspec: "args=[\'self\', \'epoch\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "on_epoch_end"

View File

@ -17,11 +17,11 @@ tf_class {
}
member_method {
name: "on_epoch_begin"
argspec: "args=[\'self\', \'epoch\', \'logs\', \'mode\'], varargs=None, keywords=None, defaults=[\'None\', \'train\'], "
argspec: "args=[\'self\', \'epoch\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "on_epoch_end"
argspec: "args=[\'self\', \'epoch\', \'logs\', \'mode\'], varargs=None, keywords=None, defaults=[\'None\', \'train\'], "
argspec: "args=[\'self\', \'epoch\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "on_predict_batch_begin"

View File

@ -17,7 +17,7 @@ tf_class {
}
member_method {
name: "on_epoch_begin"
argspec: "args=[\'self\', \'epoch\', \'logs\', \'mode\'], varargs=None, keywords=None, defaults=[\'None\', \'train\'], "
argspec: "args=[\'self\', \'epoch\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "on_epoch_end"

View File

@ -21,7 +21,7 @@ tf_class {
}
member_method {
name: "on_epoch_begin"
argspec: "args=[\'self\', \'epoch\', \'logs\', \'mode\'], varargs=None, keywords=None, defaults=[\'None\', \'train\'], "
argspec: "args=[\'self\', \'epoch\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "on_epoch_end"

View File

@ -17,7 +17,7 @@ tf_class {
}
member_method {
name: "on_epoch_begin"
argspec: "args=[\'self\', \'epoch\', \'logs\', \'mode\'], varargs=None, keywords=None, defaults=[\'None\', \'train\'], "
argspec: "args=[\'self\', \'epoch\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "on_epoch_end"

View File

@ -17,11 +17,11 @@ tf_class {
}
member_method {
name: "on_epoch_begin"
argspec: "args=[\'self\', \'epoch\', \'logs\', \'mode\'], varargs=None, keywords=None, defaults=[\'None\', \'train\'], "
argspec: "args=[\'self\', \'epoch\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "on_epoch_end"
argspec: "args=[\'self\', \'epoch\', \'logs\', \'mode\'], varargs=None, keywords=None, defaults=[\'None\', \'train\'], "
argspec: "args=[\'self\', \'epoch\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "on_predict_batch_begin"