Frees up Model.fit to take advantage of async eager when no batch-level Callbacks are passed.

For Datasets of unknown size, the first epoch will still block on each batch. However, after the first epoch the number of steps to run will be known, so these Datasets will be able to take advantage of async eager as well.
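
For illustration, a hypothetical user-level sketch of the distinction (not part of this change): a Callback that only overrides epoch-level hooks lets fit() keep batches asynchronous, while one that overrides a batch-level hook forces a sync on every batch.

import numpy as np
import tensorflow as tf

class EpochLogger(tf.keras.callbacks.Callback):
  # Epoch-level hook only: fit() can run batches with async eager.
  def on_epoch_end(self, epoch, logs=None):
    print('epoch', epoch, logs)

class BatchLogger(tf.keras.callbacks.Callback):
  # Batch-level hook: fit() must block on each batch to build its logs.
  def on_train_batch_end(self, batch, logs=None):
    print('batch', batch, logs)

model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
model.compile('sgd', 'mse')
x, y = np.ones((64, 1)), np.ones((64, 1))
model.fit(x, y, epochs=2, callbacks=[EpochLogger()], verbose=0)  # async-friendly
model.fit(x, y, epochs=2, callbacks=[BatchLogger()], verbose=0)  # syncs per batch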

Also fixes an issue where the `steps` argument was not always passed correctly to Callbacks (now tested in data_adapter_test).

PiperOrigin-RevId: 298728965
Change-Id: I430a6935ffc19718dab1be992f03a637d3ff4584
Thomas O'Malley 2020-03-03 17:07:36 -08:00 committed by TensorFlower Gardener
parent c738e60e3d
commit 3d5ed682b4
11 changed files with 279 additions and 72 deletions

View File

@@ -39,6 +39,7 @@ from tensorflow.python.eager import context
from tensorflow.python.framework import ops
from tensorflow.python.keras import backend as K
from tensorflow.python.keras.distribute import multi_worker_training_state as training_state
from tensorflow.python.keras.utils import generic_utils
from tensorflow.python.keras.utils import tf_utils
from tensorflow.python.keras.utils.data_utils import Sequence
from tensorflow.python.keras.utils.generic_utils import Progbar
@@ -221,6 +222,18 @@ class CallbackList(object):
self._queue_length = 10
self._reset_batch_timing()
# Determines if batch-level hooks need to be called.
# This is important for performance, because processing batch-level logs
# will cause async eager to block on each batch.
# pylint: disable=protected-access
self._should_call_train_batch_hooks = any(
cb._implements_train_batch_hooks() for cb in self.callbacks)
self._should_call_test_batch_hooks = any(
cb._implements_test_batch_hooks() for cb in self.callbacks)
self._should_call_predict_batch_hooks = any(
cb._implements_predict_batch_hooks() for cb in self.callbacks)
# pylint: enable=protected-access
def _add_default_callbacks(self, add_history, add_progbar):
"""Adds `Callback`s that are always present."""
self._progbar = None
@@ -311,12 +324,14 @@ class CallbackList(object):
self.on_predict_end()
def on_batch_begin(self, batch, logs=None):
logs = self._process_logs(logs)
self._call_batch_hook(ModeKeys.TRAIN, 'begin', batch, logs=logs)
if self._should_call_train_batch_hooks:
logs = self._process_logs(logs)
self._call_batch_hook(ModeKeys.TRAIN, 'begin', batch, logs=logs)
def on_batch_end(self, batch, logs=None):
logs = self._process_logs(logs)
self._call_batch_hook(ModeKeys.TRAIN, 'end', batch, logs=logs)
if self._should_call_train_batch_hooks:
logs = self._process_logs(logs)
self._call_batch_hook(ModeKeys.TRAIN, 'end', batch, logs=logs)
def on_epoch_begin(self, epoch, logs=None):
"""Calls the `on_epoch_begin` methods of its callbacks.
@@ -356,8 +371,11 @@ class CallbackList(object):
logs: dict. Has keys `batch` and `size` representing the current batch
number and the size of the batch.
"""
logs = self._process_logs(logs)
self._call_batch_hook(ModeKeys.TRAIN, 'begin', batch, logs=logs)
# TODO(b/150629188): Make ProgBarLogger callback not use batch hooks
# when verbose != 1
if self._should_call_train_batch_hooks:
logs = self._process_logs(logs)
self._call_batch_hook(ModeKeys.TRAIN, 'begin', batch, logs=logs)
def on_train_batch_end(self, batch, logs=None):
"""Calls the `on_train_batch_end` methods of its callbacks.
@@ -366,8 +384,9 @@ class CallbackList(object):
batch: integer, index of batch within the current epoch.
logs: dict. Metric results for this batch.
"""
logs = self._process_logs(logs)
self._call_batch_hook(ModeKeys.TRAIN, 'end', batch, logs=logs)
if self._should_call_train_batch_hooks:
logs = self._process_logs(logs)
self._call_batch_hook(ModeKeys.TRAIN, 'end', batch, logs=logs)
def on_test_batch_begin(self, batch, logs=None):
"""Calls the `on_test_batch_begin` methods of its callbacks.
@@ -377,8 +396,9 @@ class CallbackList(object):
logs: dict. Has keys `batch` and `size` representing the current batch
number and the size of the batch.
"""
logs = self._process_logs(logs)
self._call_batch_hook(ModeKeys.TEST, 'begin', batch, logs=logs)
if self._should_call_test_batch_hooks:
logs = self._process_logs(logs)
self._call_batch_hook(ModeKeys.TEST, 'begin', batch, logs=logs)
def on_test_batch_end(self, batch, logs=None):
"""Calls the `on_test_batch_end` methods of its callbacks.
@@ -387,7 +407,9 @@ class CallbackList(object):
batch: integer, index of batch within the current epoch.
logs: dict. Metric results for this batch.
"""
self._call_batch_hook(ModeKeys.TEST, 'end', batch, logs=logs)
if self._should_call_test_batch_hooks:
logs = self._process_logs(logs)
self._call_batch_hook(ModeKeys.TEST, 'end', batch, logs=logs)
def on_predict_batch_begin(self, batch, logs=None):
"""Calls the `on_predict_batch_begin` methods of its callbacks.
@@ -397,8 +419,9 @@ class CallbackList(object):
logs: dict. Has keys `batch` and `size` representing the current batch
number and the size of the batch.
"""
logs = self._process_logs(logs)
self._call_batch_hook(ModeKeys.PREDICT, 'begin', batch, logs=logs)
if self._should_call_predict_batch_hooks:
logs = self._process_logs(logs)
self._call_batch_hook(ModeKeys.PREDICT, 'begin', batch, logs=logs)
def on_predict_batch_end(self, batch, logs=None):
"""Calls the `on_predict_batch_end` methods of its callbacks.
@@ -407,8 +430,9 @@ class CallbackList(object):
batch: integer, index of batch within the current epoch.
logs: dict. Metric results for this batch.
"""
logs = self._process_logs(logs)
self._call_batch_hook(ModeKeys.PREDICT, 'end', batch, logs=logs)
if self._should_call_predict_batch_hooks:
logs = self._process_logs(logs)
self._call_batch_hook(ModeKeys.PREDICT, 'end', batch, logs=logs)
def on_train_begin(self, logs=None):
"""Calls the `on_train_begin` methods of its callbacks.
@@ -524,10 +548,12 @@ class Callback(object):
self.model = model
@doc_controls.for_subclass_implementers
@generic_utils.default
def on_batch_begin(self, batch, logs=None):
"""A backwards compatibility alias for `on_train_batch_begin`."""
@doc_controls.for_subclass_implementers
@generic_utils.default
def on_batch_end(self, batch, logs=None):
"""A backwards compatibility alias for `on_train_batch_end`."""
@@ -559,6 +585,7 @@ class Callback(object):
"""
@doc_controls.for_subclass_implementers
@generic_utils.default
def on_train_batch_begin(self, batch, logs=None):
"""Called at the beginning of a training batch in `fit` methods.
@@ -573,6 +600,7 @@ class Callback(object):
self.on_batch_begin(batch, logs=logs)
@doc_controls.for_subclass_implementers
@generic_utils.default
def on_train_batch_end(self, batch, logs=None):
"""Called at the end of a training batch in `fit` methods.
@@ -586,6 +614,7 @@ class Callback(object):
self.on_batch_end(batch, logs=logs)
@doc_controls.for_subclass_implementers
@generic_utils.default
def on_test_batch_begin(self, batch, logs=None):
"""Called at the beginning of a batch in `evaluate` methods.
@@ -601,6 +630,7 @@ class Callback(object):
"""
@doc_controls.for_subclass_implementers
@generic_utils.default
def on_test_batch_end(self, batch, logs=None):
"""Called at the end of a batch in `evaluate` methods.
@@ -615,6 +645,7 @@ class Callback(object):
"""
@doc_controls.for_subclass_implementers
@generic_utils.default
def on_predict_batch_begin(self, batch, logs=None):
"""Called at the beginning of a batch in `predict` methods.
@@ -627,6 +658,7 @@ class Callback(object):
"""
@doc_controls.for_subclass_implementers
@generic_utils.default
def on_predict_batch_end(self, batch, logs=None):
"""Called at the end of a batch in `predict` methods.
@@ -703,6 +735,23 @@ class Callback(object):
but that may change in the future.
"""
def _implements_train_batch_hooks(self):
"""Determines if this Callback should be called for each train batch."""
return (not generic_utils.is_default(self.on_batch_begin) or
not generic_utils.is_default(self.on_batch_end) or
not generic_utils.is_default(self.on_train_batch_begin) or
not generic_utils.is_default(self.on_train_batch_end))
def _implements_test_batch_hooks(self):
"""Determines if this Callback should be called for each test batch."""
return (not generic_utils.is_default(self.on_test_batch_begin) or
not generic_utils.is_default(self.on_test_batch_end))
def _implements_predict_batch_hooks(self):
"""Determines if this Callback should be called for each predict batch."""
return (not generic_utils.is_default(self.on_predict_batch_begin) or
not generic_utils.is_default(self.on_predict_batch_end))
@keras_export('keras.callbacks.BaseLogger')
class BaseLogger(Callback):
@@ -1105,8 +1154,8 @@ class ModelCheckpoint(Callback):
self.model._training_state = None
def on_batch_end(self, batch, logs=None):
logs = logs or {}
if isinstance(self.save_freq, int):
if self._implements_train_batch_hooks():
logs = logs or {}
self._batches_seen_since_last_saving += 1
if self._batches_seen_since_last_saving >= self.save_freq:
self._save_model(epoch=self._current_epoch, logs=logs)
@@ -1307,6 +1356,10 @@ class ModelCheckpoint(Callback):
# the file path with the largest file name.
return file_path_with_largest_file_name
def _implements_train_batch_hooks(self):
# If save_freq="epoch", batch-level hooks don't need to be run.
return isinstance(self.save_freq, int)
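
Usage note (hypothetical example): with save_freq='epoch' the callback now opts out of batch hooks entirely, so it no longer forces a per-batch sync, while an integer save_freq still needs batch hooks to count batches between saves.

import tensorflow as tf

# No batch hooks: saving happens from epoch-level hooks only.
ckpt_epoch = tf.keras.callbacks.ModelCheckpoint('/tmp/ckpt', save_freq='epoch')

# Batch hooks run: the callback counts batches and saves every 100th.
ckpt_steps = tf.keras.callbacks.ModelCheckpoint('/tmp/ckpt', save_freq=100)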
@keras_export('keras.callbacks.EarlyStopping')
class EarlyStopping(Callback):
@@ -1911,6 +1964,8 @@ class TensorBoard(Callback):
batch: Integer, index of batch within the current epoch.
logs: Dict. Metric results for this batch.
"""
# TODO(b/150629188): Make TensorBoard callback not use batch hooks
# by default.
if self.update_freq == 'epoch' and self._start_batch is None:
return

View File

@@ -1442,6 +1442,123 @@ class KerasCallbacksTest(keras_parameterized.TestCase):
self.assertTrue(callback.on_batch_end_called)
self.assertTrue(callback.on_batch_end_called)
@keras_parameterized.run_all_keras_modes(always_skip_v1=True)
def test_implements_batch_hooks(self):
class MyCallbackWithBatchHooks(keras.callbacks.Callback):
def __init__(self):
self.train_batches = 0
self.test_batches = 0
self.predict_batches = 0
def on_train_batch_end(self, batch, logs=None):
self.train_batches += 1
def on_test_batch_end(self, batch, logs=None):
self.test_batches += 1
def on_predict_batch_end(self, batch, logs=None):
self.predict_batches += 1
class MyCallbackWithoutBatchHooks(keras.callbacks.Callback):
def __init__(self):
self.epochs = 0
def on_epoch_end(self, epoch, logs=None):
self.epochs += 1
x, y = np.ones((10, 1)), np.ones((10, 1))
model = keras.Sequential([keras.layers.Dense(1)])
model.compile('sgd', 'mse')
my_cb = MyCallbackWithBatchHooks()
cb_list = keras.callbacks.CallbackList([my_cb], verbose=0)
self.assertTrue(cb_list._should_call_train_batch_hooks)
self.assertTrue(cb_list._should_call_test_batch_hooks)
self.assertTrue(cb_list._should_call_predict_batch_hooks)
model.fit(x, y, epochs=2, batch_size=10, callbacks=[my_cb], verbose=0)
model.evaluate(x, y, batch_size=10, callbacks=[my_cb], verbose=0)
model.predict(x, batch_size=10, callbacks=[my_cb], verbose=0)
self.assertEqual(my_cb.train_batches, 2)
self.assertEqual(my_cb.test_batches, 1)
self.assertEqual(my_cb.predict_batches, 1)
my_cb = MyCallbackWithoutBatchHooks()
cb_list = keras.callbacks.CallbackList([my_cb], verbose=0)
self.assertLen(cb_list.callbacks, 1)
self.assertFalse(cb_list._should_call_train_batch_hooks)
self.assertFalse(cb_list._should_call_test_batch_hooks)
self.assertFalse(cb_list._should_call_predict_batch_hooks)
model.fit(x, y, epochs=2, batch_size=10, callbacks=[my_cb], verbose=0)
model.evaluate(x, y, batch_size=10, callbacks=[my_cb], verbose=0)
model.predict(x, batch_size=10, callbacks=[my_cb], verbose=0)
@keras_parameterized.run_all_keras_modes(always_skip_v1=True)
def test_implements_batch_hooks_override(self):
class MyCallback(keras.callbacks.Callback):
def __init__(self, should_run=True):
self.should_run = should_run
self.train_batches = 0
self.test_batches = 0
self.predict_batches = 0
def on_train_batch_end(self, batch, logs=None):
self.train_batches += 1
def on_test_batch_end(self, batch, logs=None):
self.test_batches += 1
def on_predict_batch_end(self, batch, logs=None):
self.predict_batches += 1
def _implements_train_batch_hooks(self):
return self.should_run
def _implements_test_batch_hooks(self):
return self.should_run
def _implements_predict_batch_hooks(self):
return self.should_run
x, y = np.ones((10, 1)), np.ones((10, 1))
model = keras.Sequential([keras.layers.Dense(1)])
model.compile('sgd', 'mse')
my_cb = MyCallback(should_run=True)
cb_list = keras.callbacks.CallbackList([my_cb], verbose=0)
self.assertTrue(cb_list._should_call_train_batch_hooks)
self.assertTrue(cb_list._should_call_test_batch_hooks)
self.assertTrue(cb_list._should_call_predict_batch_hooks)
model.fit(x, y, epochs=2, batch_size=10, callbacks=[my_cb], verbose=0)
model.evaluate(x, y, batch_size=10, callbacks=[my_cb], verbose=0)
model.predict(x, batch_size=10, callbacks=[my_cb], verbose=0)
self.assertEqual(my_cb.train_batches, 2)
self.assertEqual(my_cb.test_batches, 1)
self.assertEqual(my_cb.predict_batches, 1)
my_cb = MyCallback(should_run=False)
cb_list = keras.callbacks.CallbackList([my_cb], verbose=0)
self.assertFalse(cb_list._should_call_train_batch_hooks)
self.assertFalse(cb_list._should_call_test_batch_hooks)
self.assertFalse(cb_list._should_call_predict_batch_hooks)
model.fit(x, y, epochs=2, batch_size=10, callbacks=[my_cb], verbose=0)
model.evaluate(x, y, batch_size=10, callbacks=[my_cb], verbose=0)
model.predict(x, batch_size=10, callbacks=[my_cb], verbose=0)
self.assertEqual(my_cb.train_batches, 0)
self.assertEqual(my_cb.test_batches, 0)
self.assertEqual(my_cb.predict_batches, 0)
# A summary that was emitted during a test. Fields:
# logdir: str. The logdir of the FileWriter to which the summary was

View File

@@ -383,7 +383,7 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
self._auto_track_sub_layers = True
@trackable.no_automatic_dependency_tracking
@base_layer_utils.default
@generic_utils.default
def build(self, input_shape):
"""Creates the variables of the layer (optional, for subclass implementers).
@@ -597,7 +597,7 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
self._non_trainable_weights.append(variable)
return variable
@base_layer_utils.default
@generic_utils.default
def get_config(self):
"""Returns the config of the layer.
@@ -737,7 +737,7 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
lambda s: tensor_spec.TensorSpec(dtype=dtype, shape=s),
output_shape)
@base_layer_utils.default
@generic_utils.default
def compute_mask(self, inputs, mask=None): # pylint: disable=unused-argument
"""Computes an output mask tensor.

View File

@@ -651,12 +651,6 @@ def mark_as_return(outputs, acd):
return nest.map_structure(_mark_as_return, outputs)
def default(method):
"""Decorates a method to detect overrides in subclasses."""
method._is_default = True # pylint: disable=protected-access
return method
V2_DTYPE_BEHAVIOR = None

View File

@@ -254,7 +254,7 @@ class Layer(base_layer.Layer):
self._auto_track_sub_layers = True
@trackable.no_automatic_dependency_tracking
@base_layer_utils.default
@generic_utils.default
def build(self, input_shape):
"""Creates the variables of the layer (optional, for subclass implementers).
@@ -467,7 +467,7 @@ class Layer(base_layer.Layer):
self._non_trainable_weights.append(variable)
return variable
@base_layer_utils.default
@generic_utils.default
def get_config(self):
"""Returns the config of the layer.
@@ -605,7 +605,7 @@ class Layer(base_layer.Layer):
lambda s: tensor_spec.TensorSpec(dtype=dtype, shape=s),
output_shape)
@base_layer_utils.default
@generic_utils.default
def compute_mask(self, inputs, mask=None): # pylint: disable=unused-argument
"""Computes an output mask tensor.

View File

@@ -33,6 +33,7 @@ from tensorflow.python.data.experimental.ops import cardinality
from tensorflow.python.data.experimental.ops import distribute_options
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import distribution_strategy_context as ds_context
from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
@@ -1111,7 +1112,7 @@ class DataHandler(object):
dataset = self._adapter.get_dataset()
if class_weight:
dataset = dataset.map(_make_class_weight_map_fn(class_weight))
self._steps_per_epoch = self._infer_steps(steps_per_epoch, dataset)
self._inferred_steps = self._infer_steps(steps_per_epoch, dataset)
self._dataset = strategy.experimental_distribute_dataset(dataset)
def enumerate_epochs(self):
@@ -1135,12 +1136,13 @@ class DataHandler(object):
"""Catches errors when an iterator runs out of data."""
try:
yield
context.async_wait()
except (StopIteration, errors.OutOfRangeError):
if (self._adapter.get_size() is None and self._steps_per_epoch is None and
if (self._adapter.get_size() is None and self._inferred_steps is None and
self._current_step > 0):
# The input passed by the user ran out of batches.
# Now we know the cardinality of the input (dataset or generator).
self._steps_per_epoch = self._current_step
self._inferred_steps = self._current_step
else:
self._insufficient_data = True
total_epochs = self._epochs - self._initial_epoch
@@ -1150,19 +1152,34 @@ class DataHandler(object):
"least `steps_per_epoch * epochs` batches (in this case, "
"{} batches). You may need to use the repeat() function "
"when building your dataset.".format(total_epochs *
self._steps_per_epoch))
self._inferred_steps))
def steps(self):
"""Yields steps for the current epoch."""
self._current_step = 0
# `self._steps_per_epoch` can be changed by `catch_stop_iteration`.
while (self._steps_per_epoch is None or
self._current_step < self._steps_per_epoch):
# `self._inferred_steps` can be changed by `catch_stop_iteration`.
while (self._inferred_steps is None or
self._current_step < self._inferred_steps):
if self._insufficient_data: # Set by `catch_stop_iteration`.
break
yield self._current_step
self._current_step += 1
@property
def inferred_steps(self):
"""The inferred steps per epoch of the created `Dataset`.
This will be `None` in the case where:
(1) A `Dataset` of unknown cardinality was passed to the `DataHandler`, and
(2) `steps_per_epoch` was not provided, and
(3) The first epoch of iteration has not yet completed.
Returns:
The inferred steps per epoch of the created `Dataset`.
"""
return self._inferred_steps
def _infer_steps(self, steps, dataset):
"""Infers steps_per_epoch needed to loop through a dataset."""
if steps is not None:
@@ -1189,17 +1206,13 @@ class DataHandler(object):
raise ValueError("When passing an infinitely repeating dataset, you "
"must specify how many steps to draw.")
if size >= 0:
return size
return size.numpy().item()
return None
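
For context, a sketch using the same cardinality helpers this file imports: cardinality comes back as a scalar Tensor, hence the new size.numpy().item() conversion to a plain Python int, and a Dataset whose size depends on a predicate reports UNKNOWN until it has been iterated once.

import tensorflow as tf
from tensorflow.python.data.experimental.ops import cardinality

ds = tf.data.Dataset.range(6).batch(2)
size = cardinality.cardinality(ds)  # scalar int64 Tensor
steps = size.numpy().item()         # plain Python int: 3

# Size depends on the filter predicate, so it is unknown up front.
unknown = tf.data.Dataset.range(6).filter(lambda x: x % 2 == 0).batch(2)
assert cardinality.cardinality(unknown).numpy() == cardinality.UNKNOWN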
@property
def _samples(self):
return self._adapter.get_samples()
@property
def _steps(self):
return self._adapter.get_size()
def _make_class_weight_map_fn(class_weight):
"""Applies class weighting to a `Dataset`.

View File

@@ -786,6 +786,7 @@ class DataHandlerTest(keras_parameterized.TestCase):
# User can choose to only partially consume `Dataset`.
data_handler = data_adapter.DataHandler(
data, initial_epoch=0, epochs=2, steps_per_epoch=2)
self.assertEqual(data_handler.inferred_steps, 2)
self.assertFalse(data_handler._adapter.should_recreate_iterator())
returned_data = []
for _, iterator in data_handler.enumerate_epochs():
@@ -798,6 +799,7 @@ class DataHandlerTest(keras_parameterized.TestCase):
def test_finite_dataset_without_steps_per_epoch(self):
data = dataset_ops.Dataset.from_tensor_slices([0, 1, 2]).batch(1)
data_handler = data_adapter.DataHandler(data, initial_epoch=0, epochs=2)
self.assertEqual(data_handler.inferred_steps, 3)
returned_data = []
for _, iterator in data_handler.enumerate_epochs():
epoch_data = []
@@ -851,6 +853,7 @@ class DataHandlerTest(keras_parameterized.TestCase):
returned_data.append(epoch_data)
returned_data = self.evaluate(returned_data)
self.assertEqual(returned_data, [[0, 1], [2, 3]])
self.assertEqual(data_handler.inferred_steps, 2)
def test_unknown_cardinality_dataset_without_steps_per_epoch(self):
ds = dataset_ops.DatasetV2.from_tensor_slices([0, 1, 2, 3, 4, 5, 6])
@@ -860,6 +863,7 @@ class DataHandlerTest(keras_parameterized.TestCase):
data_handler = data_adapter.DataHandler(
filtered_ds, initial_epoch=0, epochs=2)
self.assertEqual(data_handler.inferred_steps, None)
self.assertTrue(data_handler._adapter.should_recreate_iterator())
returned_data = []
for _, iterator in data_handler.enumerate_epochs():
@@ -870,7 +874,7 @@ class DataHandlerTest(keras_parameterized.TestCase):
returned_data.append(epoch_data)
returned_data = self.evaluate(returned_data)
self.assertEqual(returned_data, [[0, 1, 2, 3], [0, 1, 2, 3]])
self.assertEqual(data_handler._steps_per_epoch, 4)
self.assertEqual(data_handler.inferred_steps, 4)
def test_insufficient_data(self):
ds = dataset_ops.DatasetV2.from_tensor_slices([0, 1])

View File

@@ -586,7 +586,7 @@ class Network(base_layer.Layer):
"""
return
@base_layer_utils.default
@generic_utils.default
def build(self, input_shape):
"""Builds the model based on input shapes received.

View File

@@ -23,7 +23,6 @@ import copy
from tensorflow.python.keras import layers as layer_module
from tensorflow.python.keras.engine import base_layer
from tensorflow.python.keras.engine import base_layer_utils
from tensorflow.python.keras.engine import input_layer
from tensorflow.python.keras.engine import training
from tensorflow.python.keras.engine import training_utils
@@ -255,7 +254,7 @@ class Sequential(training.Model):
self._init_graph_network(self.inputs, self.outputs, name=self.name)
self.built = True
@base_layer_utils.default
@generic_utils.default
def build(self, input_shape=None):
if self._is_graph_network:
self._init_graph_network(self.inputs, self.outputs, name=self.name)

View File

@@ -18,6 +18,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import copy
from tensorflow.python.distribute import distribute_coordinator as dc
from tensorflow.python.distribute import distribution_strategy_context as ds_context
from tensorflow.python.distribute import values as ds_values
@@ -743,14 +745,15 @@ class Model(network.Network, version_utils.ModelVersionSelector):
model=self)
# Container that configures and calls `tf.keras.Callback`s.
callbacks = callbacks_module.CallbackList(
callbacks,
add_history=True,
add_progbar=True,
model=self,
verbose=verbose,
epochs=epochs,
steps=data_handler._steps) # pylint: disable=protected-access
if not isinstance(callbacks, callbacks_module.CallbackList):
callbacks = callbacks_module.CallbackList(
callbacks,
add_history=True,
add_progbar=verbose != 0,
model=self,
verbose=verbose,
epochs=epochs,
steps=data_handler.inferred_steps)
self.stop_training = False
train_function = self.make_train_function()
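
One effect of the new isinstance guard, sketched hypothetically: a fully configured CallbackList can be passed straight to fit() and is reused as-is instead of being re-wrapped (the constructor arguments below mirror the ones this commit uses; the caller takes over the configuration fit() would otherwise do).

import numpy as np
from tensorflow.python import keras

model = keras.Sequential([keras.layers.Dense(1)])
model.compile('sgd', 'mse')
x, y = np.ones((10, 1)), np.ones((10, 1))

cb_list = keras.callbacks.CallbackList(
    [keras.callbacks.EarlyStopping(monitor='loss')],
    add_history=True,
    add_progbar=False,
    model=model,
    verbose=0,
    epochs=2,
    steps=1)
model.fit(x, y, epochs=2, batch_size=10, callbacks=cb_list, verbose=0)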
@@ -773,12 +776,14 @@ class Model(network.Network, version_utils.ModelVersionSelector):
batch_size=batch_size):
callbacks.on_train_batch_begin(step)
tmp_logs = train_function(iterator)
# Catch possible OutOfRangeError here.
# TODO(b/150292341): Allow multiple async steps.
context.async_wait()
logs = tmp_logs
# Catch OutOfRangeError for Datasets of unknown size.
# This blocks until the batch has finished executing.
# TODO(b/150292341): Allow multiple async steps here.
if not data_handler.inferred_steps:
context.async_wait()
logs = tmp_logs # No error, now safe to assign to logs.
callbacks.on_train_batch_end(step, logs)
epoch_logs = {m.name: m.result() for m in self.metrics}
epoch_logs = copy.copy(logs)
# Run validation.
if validation_data and self._should_eval(epoch, validation_freq):
@@ -986,11 +991,11 @@ class Model(network.Network, version_utils.ModelVersionSelector):
callbacks = callbacks_module.CallbackList(
callbacks,
add_history=True,
add_progbar=True,
add_progbar=verbose != 0,
model=self,
verbose=verbose,
epochs=1,
steps=data_handler._steps) # pylint: disable=protected-access
steps=data_handler.inferred_steps)
test_function = self.make_test_function()
callbacks.on_test_begin()
@@ -1004,8 +1009,12 @@ class Model(network.Network, version_utils.ModelVersionSelector):
step_num=step):
callbacks.on_test_batch_begin(step)
tmp_logs = test_function(iterator)
context.async_wait() # Possible OutOfRangeError here.
logs = tmp_logs
# Catch OutOfRangeError for Datasets of unknown size.
# This blocks until the batch has finished executing.
# TODO(b/150292341): Allow multiple async steps here.
if not data_handler.inferred_steps:
context.async_wait()
logs = tmp_logs # No error, now safe to assign to logs.
callbacks.on_test_batch_end(step, logs)
callbacks.on_test_end()
@@ -1170,14 +1179,15 @@ class Model(network.Network, version_utils.ModelVersionSelector):
model=self)
# Container that configures and calls `tf.keras.Callback`s.
callbacks = callbacks_module.CallbackList(
callbacks,
add_history=True,
add_progbar=True,
model=self,
verbose=verbose,
epochs=1,
steps=data_handler._steps) # pylint: disable=protected-access
if not isinstance(callbacks, callbacks_module.CallbackList):
callbacks = callbacks_module.CallbackList(
callbacks,
add_history=True,
add_progbar=verbose != 0,
model=self,
verbose=verbose,
epochs=1,
steps=data_handler.inferred_steps)
predict_function = self.make_predict_function()
callbacks.on_predict_begin()
@@ -1186,8 +1196,12 @@ class Model(network.Network, version_utils.ModelVersionSelector):
for step in data_handler.steps():
callbacks.on_predict_batch_begin(step)
tmp_batch_outputs = predict_function(iterator)
context.async_wait() # Possible OutOfRangeError here.
batch_outputs = tmp_batch_outputs
# Catch OutOfRangeError for Datasets of unknown size.
# This blocks until the batch has finished executing.
# TODO(b/150292341): Allow multiple async steps here.
if not data_handler.inferred_steps:
context.async_wait()
batch_outputs = tmp_batch_outputs # No error, now safe to assign.
if outputs is None:
outputs = nest.map_structure(lambda batch_output: [batch_output],
batch_outputs)

View File

@@ -795,3 +795,14 @@ def validate_kwargs(kwargs,
def validate_config(config):
"""Determines whether config appears to be a valid layer config."""
return isinstance(config, dict) and _LAYER_UNDEFINED_CONFIG_KEY not in config
def default(method):
"""Decorates a method to detect overrides in subclasses."""
method._is_default = True # pylint: disable=protected-access
return method
def is_default(method):
"""Check if a method is decorated with the `default` wrapper."""
return getattr(method, '_is_default', False)
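
A minimal standalone sketch of the override-detection pattern these two helpers implement (generic names, independent of Keras):

def default(method):
  method._is_default = True
  return method

def is_default(method):
  return getattr(method, '_is_default', False)

class Base(object):

  @default
  def on_train_batch_end(self, batch, logs=None):
    pass

class Quiet(Base):
  pass  # inherits the flagged default implementation

class Noisy(Base):

  def on_train_batch_end(self, batch, logs=None):  # override lacks _is_default
    print(batch)

# Attribute lookup on a bound method falls through to the underlying
# function, so the flag survives binding.
assert is_default(Quiet().on_train_batch_end)
assert not is_default(Noisy().on_train_batch_end)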