Frees up Model.fit to take advantage of async eager when no batch-level Callbacks are
passed. For Datasets of unknown sizes, the first epoch will still block on each batch. However, after the first epoch, the number of steps to run will be known, and therefore these Datasets will be able to take advantage of async eager as well. Also fixes issue where the "steps" argument was not always passed correctly to Callbacks (tested now in data_adapter_test) PiperOrigin-RevId: 298728965 Change-Id: I430a6935ffc19718dab1be992f03a637d3ff4584
This commit is contained in:
parent
c738e60e3d
commit
3d5ed682b4
@ -39,6 +39,7 @@ from tensorflow.python.eager import context
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.keras import backend as K
|
||||
from tensorflow.python.keras.distribute import multi_worker_training_state as training_state
|
||||
from tensorflow.python.keras.utils import generic_utils
|
||||
from tensorflow.python.keras.utils import tf_utils
|
||||
from tensorflow.python.keras.utils.data_utils import Sequence
|
||||
from tensorflow.python.keras.utils.generic_utils import Progbar
|
||||
@ -221,6 +222,18 @@ class CallbackList(object):
|
||||
self._queue_length = 10
|
||||
self._reset_batch_timing()
|
||||
|
||||
# Determines if batch-level hooks need to be called.
|
||||
# This is important for performance, because processing batch-level logs
|
||||
# will cause async eager to block on each batch.
|
||||
# pylint: disable=protected-access
|
||||
self._should_call_train_batch_hooks = any(
|
||||
cb._implements_train_batch_hooks() for cb in self.callbacks)
|
||||
self._should_call_test_batch_hooks = any(
|
||||
cb._implements_test_batch_hooks() for cb in self.callbacks)
|
||||
self._should_call_predict_batch_hooks = any(
|
||||
cb._implements_predict_batch_hooks() for cb in self.callbacks)
|
||||
# pylint: enable=protected-access
|
||||
|
||||
def _add_default_callbacks(self, add_history, add_progbar):
|
||||
"""Adds `Callback`s that are always present."""
|
||||
self._progbar = None
|
||||
@ -311,12 +324,14 @@ class CallbackList(object):
|
||||
self.on_predict_end()
|
||||
|
||||
def on_batch_begin(self, batch, logs=None):
|
||||
logs = self._process_logs(logs)
|
||||
self._call_batch_hook(ModeKeys.TRAIN, 'begin', batch, logs=logs)
|
||||
if self._should_call_train_batch_hooks:
|
||||
logs = self._process_logs(logs)
|
||||
self._call_batch_hook(ModeKeys.TRAIN, 'begin', batch, logs=logs)
|
||||
|
||||
def on_batch_end(self, batch, logs=None):
|
||||
logs = self._process_logs(logs)
|
||||
self._call_batch_hook(ModeKeys.TRAIN, 'end', batch, logs=logs)
|
||||
if self._should_call_train_batch_hooks:
|
||||
logs = self._process_logs(logs)
|
||||
self._call_batch_hook(ModeKeys.TRAIN, 'end', batch, logs=logs)
|
||||
|
||||
def on_epoch_begin(self, epoch, logs=None):
|
||||
"""Calls the `on_epoch_begin` methods of its callbacks.
|
||||
@ -356,8 +371,11 @@ class CallbackList(object):
|
||||
logs: dict. Has keys `batch` and `size` representing the current batch
|
||||
number and the size of the batch.
|
||||
"""
|
||||
logs = self._process_logs(logs)
|
||||
self._call_batch_hook(ModeKeys.TRAIN, 'begin', batch, logs=logs)
|
||||
# TODO(b/150629188): Make ProgBarLogger callback not use batch hooks
|
||||
# when verbose != 1
|
||||
if self._should_call_train_batch_hooks:
|
||||
logs = self._process_logs(logs)
|
||||
self._call_batch_hook(ModeKeys.TRAIN, 'begin', batch, logs=logs)
|
||||
|
||||
def on_train_batch_end(self, batch, logs=None):
|
||||
"""Calls the `on_train_batch_end` methods of its callbacks.
|
||||
@ -366,8 +384,9 @@ class CallbackList(object):
|
||||
batch: integer, index of batch within the current epoch.
|
||||
logs: dict. Metric results for this batch.
|
||||
"""
|
||||
logs = self._process_logs(logs)
|
||||
self._call_batch_hook(ModeKeys.TRAIN, 'end', batch, logs=logs)
|
||||
if self._should_call_train_batch_hooks:
|
||||
logs = self._process_logs(logs)
|
||||
self._call_batch_hook(ModeKeys.TRAIN, 'end', batch, logs=logs)
|
||||
|
||||
def on_test_batch_begin(self, batch, logs=None):
|
||||
"""Calls the `on_test_batch_begin` methods of its callbacks.
|
||||
@ -377,8 +396,9 @@ class CallbackList(object):
|
||||
logs: dict. Has keys `batch` and `size` representing the current batch
|
||||
number and the size of the batch.
|
||||
"""
|
||||
logs = self._process_logs(logs)
|
||||
self._call_batch_hook(ModeKeys.TEST, 'begin', batch, logs=logs)
|
||||
if self._should_call_test_batch_hooks:
|
||||
logs = self._process_logs(logs)
|
||||
self._call_batch_hook(ModeKeys.TEST, 'begin', batch, logs=logs)
|
||||
|
||||
def on_test_batch_end(self, batch, logs=None):
|
||||
"""Calls the `on_test_batch_end` methods of its callbacks.
|
||||
@ -387,7 +407,9 @@ class CallbackList(object):
|
||||
batch: integer, index of batch within the current epoch.
|
||||
logs: dict. Metric results for this batch.
|
||||
"""
|
||||
self._call_batch_hook(ModeKeys.TEST, 'end', batch, logs=logs)
|
||||
if self._should_call_test_batch_hooks:
|
||||
logs = self._process_logs(logs)
|
||||
self._call_batch_hook(ModeKeys.TEST, 'end', batch, logs=logs)
|
||||
|
||||
def on_predict_batch_begin(self, batch, logs=None):
|
||||
"""Calls the `on_predict_batch_begin` methods of its callbacks.
|
||||
@ -397,8 +419,9 @@ class CallbackList(object):
|
||||
logs: dict. Has keys `batch` and `size` representing the current batch
|
||||
number and the size of the batch.
|
||||
"""
|
||||
logs = self._process_logs(logs)
|
||||
self._call_batch_hook(ModeKeys.PREDICT, 'begin', batch, logs=logs)
|
||||
if self._should_call_predict_batch_hooks:
|
||||
logs = self._process_logs(logs)
|
||||
self._call_batch_hook(ModeKeys.PREDICT, 'begin', batch, logs=logs)
|
||||
|
||||
def on_predict_batch_end(self, batch, logs=None):
|
||||
"""Calls the `on_predict_batch_end` methods of its callbacks.
|
||||
@ -407,8 +430,9 @@ class CallbackList(object):
|
||||
batch: integer, index of batch within the current epoch.
|
||||
logs: dict. Metric results for this batch.
|
||||
"""
|
||||
logs = self._process_logs(logs)
|
||||
self._call_batch_hook(ModeKeys.PREDICT, 'end', batch, logs=logs)
|
||||
if self._should_call_predict_batch_hooks:
|
||||
logs = self._process_logs(logs)
|
||||
self._call_batch_hook(ModeKeys.PREDICT, 'end', batch, logs=logs)
|
||||
|
||||
def on_train_begin(self, logs=None):
|
||||
"""Calls the `on_train_begin` methods of its callbacks.
|
||||
@ -524,10 +548,12 @@ class Callback(object):
|
||||
self.model = model
|
||||
|
||||
@doc_controls.for_subclass_implementers
|
||||
@generic_utils.default
|
||||
def on_batch_begin(self, batch, logs=None):
|
||||
"""A backwards compatibility alias for `on_train_batch_begin`."""
|
||||
|
||||
@doc_controls.for_subclass_implementers
|
||||
@generic_utils.default
|
||||
def on_batch_end(self, batch, logs=None):
|
||||
"""A backwards compatibility alias for `on_train_batch_end`."""
|
||||
|
||||
@ -559,6 +585,7 @@ class Callback(object):
|
||||
"""
|
||||
|
||||
@doc_controls.for_subclass_implementers
|
||||
@generic_utils.default
|
||||
def on_train_batch_begin(self, batch, logs=None):
|
||||
"""Called at the beginning of a training batch in `fit` methods.
|
||||
|
||||
@ -573,6 +600,7 @@ class Callback(object):
|
||||
self.on_batch_begin(batch, logs=logs)
|
||||
|
||||
@doc_controls.for_subclass_implementers
|
||||
@generic_utils.default
|
||||
def on_train_batch_end(self, batch, logs=None):
|
||||
"""Called at the end of a training batch in `fit` methods.
|
||||
|
||||
@ -586,6 +614,7 @@ class Callback(object):
|
||||
self.on_batch_end(batch, logs=logs)
|
||||
|
||||
@doc_controls.for_subclass_implementers
|
||||
@generic_utils.default
|
||||
def on_test_batch_begin(self, batch, logs=None):
|
||||
"""Called at the beginning of a batch in `evaluate` methods.
|
||||
|
||||
@ -601,6 +630,7 @@ class Callback(object):
|
||||
"""
|
||||
|
||||
@doc_controls.for_subclass_implementers
|
||||
@generic_utils.default
|
||||
def on_test_batch_end(self, batch, logs=None):
|
||||
"""Called at the end of a batch in `evaluate` methods.
|
||||
|
||||
@ -615,6 +645,7 @@ class Callback(object):
|
||||
"""
|
||||
|
||||
@doc_controls.for_subclass_implementers
|
||||
@generic_utils.default
|
||||
def on_predict_batch_begin(self, batch, logs=None):
|
||||
"""Called at the beginning of a batch in `predict` methods.
|
||||
|
||||
@ -627,6 +658,7 @@ class Callback(object):
|
||||
"""
|
||||
|
||||
@doc_controls.for_subclass_implementers
|
||||
@generic_utils.default
|
||||
def on_predict_batch_end(self, batch, logs=None):
|
||||
"""Called at the end of a batch in `predict` methods.
|
||||
|
||||
@ -703,6 +735,23 @@ class Callback(object):
|
||||
but that may change in the future.
|
||||
"""
|
||||
|
||||
def _implements_train_batch_hooks(self):
|
||||
"""Determines if this Callback should be called for each train batch."""
|
||||
return (not generic_utils.is_default(self.on_batch_begin) or
|
||||
not generic_utils.is_default(self.on_batch_end) or
|
||||
not generic_utils.is_default(self.on_train_batch_begin) or
|
||||
not generic_utils.is_default(self.on_train_batch_end))
|
||||
|
||||
def _implements_test_batch_hooks(self):
|
||||
"""Determines if this Callback should be called for each test batch."""
|
||||
return (not generic_utils.is_default(self.on_test_batch_begin) or
|
||||
not generic_utils.is_default(self.on_test_batch_end))
|
||||
|
||||
def _implements_predict_batch_hooks(self):
|
||||
"""Determines if this Callback should be called for each predict batch."""
|
||||
return (not generic_utils.is_default(self.on_predict_batch_begin) or
|
||||
not generic_utils.is_default(self.on_predict_batch_end))
|
||||
|
||||
|
||||
@keras_export('keras.callbacks.BaseLogger')
|
||||
class BaseLogger(Callback):
|
||||
@ -1105,8 +1154,8 @@ class ModelCheckpoint(Callback):
|
||||
self.model._training_state = None
|
||||
|
||||
def on_batch_end(self, batch, logs=None):
|
||||
logs = logs or {}
|
||||
if isinstance(self.save_freq, int):
|
||||
if self._implements_train_batch_hooks():
|
||||
logs = logs or {}
|
||||
self._batches_seen_since_last_saving += 1
|
||||
if self._batches_seen_since_last_saving >= self.save_freq:
|
||||
self._save_model(epoch=self._current_epoch, logs=logs)
|
||||
@ -1307,6 +1356,10 @@ class ModelCheckpoint(Callback):
|
||||
# the file path with the largest file name.
|
||||
return file_path_with_largest_file_name
|
||||
|
||||
def _implements_train_batch_hooks(self):
|
||||
# If save_freq="epoch", batch-level hooks don't need to be run.
|
||||
return isinstance(self.save_freq, int)
|
||||
|
||||
|
||||
@keras_export('keras.callbacks.EarlyStopping')
|
||||
class EarlyStopping(Callback):
|
||||
@ -1911,6 +1964,8 @@ class TensorBoard(Callback):
|
||||
batch: Integer, index of batch within the current epoch.
|
||||
logs: Dict. Metric results for this batch.
|
||||
"""
|
||||
# TODO(b/150629188): Make TensorBoard callback not use batch hooks
|
||||
# by default.
|
||||
if self.update_freq == 'epoch' and self._start_batch is None:
|
||||
return
|
||||
|
||||
|
||||
@ -1442,6 +1442,123 @@ class KerasCallbacksTest(keras_parameterized.TestCase):
|
||||
self.assertTrue(callback.on_batch_end_called)
|
||||
self.assertTrue(callback.on_batch_end_called)
|
||||
|
||||
@keras_parameterized.run_all_keras_modes(always_skip_v1=True)
|
||||
def test_implements_batch_hooks(self):
|
||||
|
||||
class MyCallbackWithBatchHooks(keras.callbacks.Callback):
|
||||
|
||||
def __init__(self):
|
||||
self.train_batches = 0
|
||||
self.test_batches = 0
|
||||
self.predict_batches = 0
|
||||
|
||||
def on_train_batch_end(self, batch, logs=None):
|
||||
self.train_batches += 1
|
||||
|
||||
def on_test_batch_end(self, batch, logs=None):
|
||||
self.test_batches += 1
|
||||
|
||||
def on_predict_batch_end(self, batch, logs=None):
|
||||
self.predict_batches += 1
|
||||
|
||||
class MyCallbackWithoutBatchHooks(keras.callbacks.Callback):
|
||||
|
||||
def __init__(self):
|
||||
self.epochs = 0
|
||||
|
||||
def on_epoch_end(self, epoch, logs=None):
|
||||
self.epochs += 1
|
||||
|
||||
x, y = np.ones((10, 1)), np.ones((10, 1))
|
||||
model = keras.Sequential([keras.layers.Dense(1)])
|
||||
model.compile('sgd', 'mse')
|
||||
|
||||
my_cb = MyCallbackWithBatchHooks()
|
||||
cb_list = keras.callbacks.CallbackList([my_cb], verbose=0)
|
||||
self.assertTrue(cb_list._should_call_train_batch_hooks)
|
||||
self.assertTrue(cb_list._should_call_test_batch_hooks)
|
||||
self.assertTrue(cb_list._should_call_predict_batch_hooks)
|
||||
|
||||
model.fit(x, y, epochs=2, batch_size=10, callbacks=[my_cb], verbose=0)
|
||||
model.evaluate(x, y, batch_size=10, callbacks=[my_cb], verbose=0)
|
||||
model.predict(x, batch_size=10, callbacks=[my_cb], verbose=0)
|
||||
|
||||
self.assertEqual(my_cb.train_batches, 2)
|
||||
self.assertEqual(my_cb.test_batches, 1)
|
||||
self.assertEqual(my_cb.predict_batches, 1)
|
||||
|
||||
my_cb = MyCallbackWithoutBatchHooks()
|
||||
cb_list = keras.callbacks.CallbackList([my_cb], verbose=0)
|
||||
self.assertLen(cb_list.callbacks, 1)
|
||||
self.assertFalse(cb_list._should_call_train_batch_hooks)
|
||||
self.assertFalse(cb_list._should_call_test_batch_hooks)
|
||||
self.assertFalse(cb_list._should_call_predict_batch_hooks)
|
||||
|
||||
model.fit(x, y, epochs=2, batch_size=10, callbacks=[my_cb], verbose=0)
|
||||
model.evaluate(x, y, batch_size=10, callbacks=[my_cb], verbose=0)
|
||||
model.predict(x, batch_size=10, callbacks=[my_cb], verbose=0)
|
||||
|
||||
@keras_parameterized.run_all_keras_modes(always_skip_v1=True)
|
||||
def test_implements_batch_hooks_override(self):
|
||||
|
||||
class MyCallback(keras.callbacks.Callback):
|
||||
|
||||
def __init__(self, should_run=True):
|
||||
self.should_run = should_run
|
||||
self.train_batches = 0
|
||||
self.test_batches = 0
|
||||
self.predict_batches = 0
|
||||
|
||||
def on_train_batch_end(self, batch, logs=None):
|
||||
self.train_batches += 1
|
||||
|
||||
def on_test_batch_end(self, batch, logs=None):
|
||||
self.test_batches += 1
|
||||
|
||||
def on_predict_batch_end(self, batch, logs=None):
|
||||
self.predict_batches += 1
|
||||
|
||||
def _implements_train_batch_hooks(self):
|
||||
return self.should_run
|
||||
|
||||
def _implements_test_batch_hooks(self):
|
||||
return self.should_run
|
||||
|
||||
def _implements_predict_batch_hooks(self):
|
||||
return self.should_run
|
||||
|
||||
x, y = np.ones((10, 1)), np.ones((10, 1))
|
||||
model = keras.Sequential([keras.layers.Dense(1)])
|
||||
model.compile('sgd', 'mse')
|
||||
|
||||
my_cb = MyCallback(should_run=True)
|
||||
cb_list = keras.callbacks.CallbackList([my_cb], verbose=0)
|
||||
self.assertTrue(cb_list._should_call_train_batch_hooks)
|
||||
self.assertTrue(cb_list._should_call_test_batch_hooks)
|
||||
self.assertTrue(cb_list._should_call_predict_batch_hooks)
|
||||
|
||||
model.fit(x, y, epochs=2, batch_size=10, callbacks=[my_cb], verbose=0)
|
||||
model.evaluate(x, y, batch_size=10, callbacks=[my_cb], verbose=0)
|
||||
model.predict(x, batch_size=10, callbacks=[my_cb], verbose=0)
|
||||
|
||||
self.assertEqual(my_cb.train_batches, 2)
|
||||
self.assertEqual(my_cb.test_batches, 1)
|
||||
self.assertEqual(my_cb.predict_batches, 1)
|
||||
|
||||
my_cb = MyCallback(should_run=False)
|
||||
cb_list = keras.callbacks.CallbackList([my_cb], verbose=0)
|
||||
self.assertFalse(cb_list._should_call_train_batch_hooks)
|
||||
self.assertFalse(cb_list._should_call_test_batch_hooks)
|
||||
self.assertFalse(cb_list._should_call_predict_batch_hooks)
|
||||
|
||||
model.fit(x, y, epochs=2, batch_size=10, callbacks=[my_cb], verbose=0)
|
||||
model.evaluate(x, y, batch_size=10, callbacks=[my_cb], verbose=0)
|
||||
model.predict(x, batch_size=10, callbacks=[my_cb], verbose=0)
|
||||
|
||||
self.assertEqual(my_cb.train_batches, 0)
|
||||
self.assertEqual(my_cb.test_batches, 0)
|
||||
self.assertEqual(my_cb.predict_batches, 0)
|
||||
|
||||
|
||||
# A summary that was emitted during a test. Fields:
|
||||
# logdir: str. The logdir of the FileWriter to which the summary was
|
||||
|
||||
@ -383,7 +383,7 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
|
||||
self._auto_track_sub_layers = True
|
||||
|
||||
@trackable.no_automatic_dependency_tracking
|
||||
@base_layer_utils.default
|
||||
@generic_utils.default
|
||||
def build(self, input_shape):
|
||||
"""Creates the variables of the layer (optional, for subclass implementers).
|
||||
|
||||
@ -597,7 +597,7 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
|
||||
self._non_trainable_weights.append(variable)
|
||||
return variable
|
||||
|
||||
@base_layer_utils.default
|
||||
@generic_utils.default
|
||||
def get_config(self):
|
||||
"""Returns the config of the layer.
|
||||
|
||||
@ -737,7 +737,7 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
|
||||
lambda s: tensor_spec.TensorSpec(dtype=dtype, shape=s),
|
||||
output_shape)
|
||||
|
||||
@base_layer_utils.default
|
||||
@generic_utils.default
|
||||
def compute_mask(self, inputs, mask=None): # pylint: disable=unused-argument
|
||||
"""Computes an output mask tensor.
|
||||
|
||||
|
||||
@ -651,12 +651,6 @@ def mark_as_return(outputs, acd):
|
||||
return nest.map_structure(_mark_as_return, outputs)
|
||||
|
||||
|
||||
def default(method):
|
||||
"""Decorates a method to detect overrides in subclasses."""
|
||||
method._is_default = True # pylint: disable=protected-access
|
||||
return method
|
||||
|
||||
|
||||
V2_DTYPE_BEHAVIOR = None
|
||||
|
||||
|
||||
|
||||
@ -254,7 +254,7 @@ class Layer(base_layer.Layer):
|
||||
self._auto_track_sub_layers = True
|
||||
|
||||
@trackable.no_automatic_dependency_tracking
|
||||
@base_layer_utils.default
|
||||
@generic_utils.default
|
||||
def build(self, input_shape):
|
||||
"""Creates the variables of the layer (optional, for subclass implementers).
|
||||
|
||||
@ -467,7 +467,7 @@ class Layer(base_layer.Layer):
|
||||
self._non_trainable_weights.append(variable)
|
||||
return variable
|
||||
|
||||
@base_layer_utils.default
|
||||
@generic_utils.default
|
||||
def get_config(self):
|
||||
"""Returns the config of the layer.
|
||||
|
||||
@ -605,7 +605,7 @@ class Layer(base_layer.Layer):
|
||||
lambda s: tensor_spec.TensorSpec(dtype=dtype, shape=s),
|
||||
output_shape)
|
||||
|
||||
@base_layer_utils.default
|
||||
@generic_utils.default
|
||||
def compute_mask(self, inputs, mask=None): # pylint: disable=unused-argument
|
||||
"""Computes an output mask tensor.
|
||||
|
||||
|
||||
@ -33,6 +33,7 @@ from tensorflow.python.data.experimental.ops import cardinality
|
||||
from tensorflow.python.data.experimental.ops import distribute_options
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.distribute import distribution_strategy_context as ds_context
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.framework import ops
|
||||
@ -1111,7 +1112,7 @@ class DataHandler(object):
|
||||
dataset = self._adapter.get_dataset()
|
||||
if class_weight:
|
||||
dataset = dataset.map(_make_class_weight_map_fn(class_weight))
|
||||
self._steps_per_epoch = self._infer_steps(steps_per_epoch, dataset)
|
||||
self._inferred_steps = self._infer_steps(steps_per_epoch, dataset)
|
||||
self._dataset = strategy.experimental_distribute_dataset(dataset)
|
||||
|
||||
def enumerate_epochs(self):
|
||||
@ -1135,12 +1136,13 @@ class DataHandler(object):
|
||||
"""Catches errors when an iterator runs out of data."""
|
||||
try:
|
||||
yield
|
||||
context.async_wait()
|
||||
except (StopIteration, errors.OutOfRangeError):
|
||||
if (self._adapter.get_size() is None and self._steps_per_epoch is None and
|
||||
if (self._adapter.get_size() is None and self._inferred_steps is None and
|
||||
self._current_step > 0):
|
||||
# The input passed by the user ran out of batches.
|
||||
# Now we know the cardinality of the input(dataset or generator).
|
||||
self._steps_per_epoch = self._current_step
|
||||
self._inferred_steps = self._current_step
|
||||
else:
|
||||
self._insufficient_data = True
|
||||
total_epochs = self._epochs - self._initial_epoch
|
||||
@ -1150,19 +1152,34 @@ class DataHandler(object):
|
||||
"least `steps_per_epoch * epochs` batches (in this case, "
|
||||
"{} batches). You may need to use the repeat() function "
|
||||
"when building your dataset.".format(total_epochs *
|
||||
self._steps_per_epoch))
|
||||
self._inferred_steps))
|
||||
|
||||
def steps(self):
|
||||
"""Yields steps for the current epoch."""
|
||||
self._current_step = 0
|
||||
# `self._steps_per_epoch` can be changed by `catch_stop_iteration`.
|
||||
while (self._steps_per_epoch is None or
|
||||
self._current_step < self._steps_per_epoch):
|
||||
# `self._inferred_steps` can be changed by `catch_stop_iteration`.
|
||||
while (self._inferred_steps is None or
|
||||
self._current_step < self._inferred_steps):
|
||||
if self._insufficient_data: # Set by `catch_stop_iteration`.
|
||||
break
|
||||
yield self._current_step
|
||||
self._current_step += 1
|
||||
|
||||
@property
|
||||
def inferred_steps(self):
|
||||
"""The inferred steps per epoch of the created `Dataset`.
|
||||
|
||||
This will be `None` in the case where:
|
||||
|
||||
(1) A `Dataset` of unknown cardinality was passed to the `DataHandler`, and
|
||||
(2) `steps_per_epoch` was not provided, and
|
||||
(3) The first epoch of iteration has not yet completed.
|
||||
|
||||
Returns:
|
||||
The inferred steps per epoch of the created `Dataset`.
|
||||
"""
|
||||
return self._inferred_steps
|
||||
|
||||
def _infer_steps(self, steps, dataset):
|
||||
"""Infers steps_per_epoch needed to loop through a dataset."""
|
||||
if steps is not None:
|
||||
@ -1189,17 +1206,13 @@ class DataHandler(object):
|
||||
raise ValueError("When passing an infinitely repeating dataset, you "
|
||||
"must specify how many steps to draw.")
|
||||
if size >= 0:
|
||||
return size
|
||||
return size.numpy().item()
|
||||
return None
|
||||
|
||||
@property
|
||||
def _samples(self):
|
||||
return self._adapter.get_samples()
|
||||
|
||||
@property
|
||||
def _steps(self):
|
||||
return self._adapter.get_size()
|
||||
|
||||
|
||||
def _make_class_weight_map_fn(class_weight):
|
||||
"""Applies class weighting to a `Dataset`.
|
||||
|
||||
@ -786,6 +786,7 @@ class DataHandlerTest(keras_parameterized.TestCase):
|
||||
# User can choose to only partially consume `Dataset`.
|
||||
data_handler = data_adapter.DataHandler(
|
||||
data, initial_epoch=0, epochs=2, steps_per_epoch=2)
|
||||
self.assertEqual(data_handler.inferred_steps, 2)
|
||||
self.assertFalse(data_handler._adapter.should_recreate_iterator())
|
||||
returned_data = []
|
||||
for _, iterator in data_handler.enumerate_epochs():
|
||||
@ -798,6 +799,7 @@ class DataHandlerTest(keras_parameterized.TestCase):
|
||||
def test_finite_dataset_without_steps_per_epoch(self):
|
||||
data = dataset_ops.Dataset.from_tensor_slices([0, 1, 2]).batch(1)
|
||||
data_handler = data_adapter.DataHandler(data, initial_epoch=0, epochs=2)
|
||||
self.assertEqual(data_handler.inferred_steps, 3)
|
||||
returned_data = []
|
||||
for _, iterator in data_handler.enumerate_epochs():
|
||||
epoch_data = []
|
||||
@ -851,6 +853,7 @@ class DataHandlerTest(keras_parameterized.TestCase):
|
||||
returned_data.append(epoch_data)
|
||||
returned_data = self.evaluate(returned_data)
|
||||
self.assertEqual(returned_data, [[0, 1], [2, 3]])
|
||||
self.assertEqual(data_handler.inferred_steps, 2)
|
||||
|
||||
def test_unknown_cardinality_dataset_without_steps_per_epoch(self):
|
||||
ds = dataset_ops.DatasetV2.from_tensor_slices([0, 1, 2, 3, 4, 5, 6])
|
||||
@ -860,6 +863,7 @@ class DataHandlerTest(keras_parameterized.TestCase):
|
||||
|
||||
data_handler = data_adapter.DataHandler(
|
||||
filtered_ds, initial_epoch=0, epochs=2)
|
||||
self.assertEqual(data_handler.inferred_steps, None)
|
||||
self.assertTrue(data_handler._adapter.should_recreate_iterator())
|
||||
returned_data = []
|
||||
for _, iterator in data_handler.enumerate_epochs():
|
||||
@ -870,7 +874,7 @@ class DataHandlerTest(keras_parameterized.TestCase):
|
||||
returned_data.append(epoch_data)
|
||||
returned_data = self.evaluate(returned_data)
|
||||
self.assertEqual(returned_data, [[0, 1, 2, 3], [0, 1, 2, 3]])
|
||||
self.assertEqual(data_handler._steps_per_epoch, 4)
|
||||
self.assertEqual(data_handler.inferred_steps, 4)
|
||||
|
||||
def test_insufficient_data(self):
|
||||
ds = dataset_ops.DatasetV2.from_tensor_slices([0, 1])
|
||||
|
||||
@ -586,7 +586,7 @@ class Network(base_layer.Layer):
|
||||
"""
|
||||
return
|
||||
|
||||
@base_layer_utils.default
|
||||
@generic_utils.default
|
||||
def build(self, input_shape):
|
||||
"""Builds the model based on input shapes received.
|
||||
|
||||
|
||||
@ -23,7 +23,6 @@ import copy
|
||||
|
||||
from tensorflow.python.keras import layers as layer_module
|
||||
from tensorflow.python.keras.engine import base_layer
|
||||
from tensorflow.python.keras.engine import base_layer_utils
|
||||
from tensorflow.python.keras.engine import input_layer
|
||||
from tensorflow.python.keras.engine import training
|
||||
from tensorflow.python.keras.engine import training_utils
|
||||
@ -255,7 +254,7 @@ class Sequential(training.Model):
|
||||
self._init_graph_network(self.inputs, self.outputs, name=self.name)
|
||||
self.built = True
|
||||
|
||||
@base_layer_utils.default
|
||||
@generic_utils.default
|
||||
def build(self, input_shape=None):
|
||||
if self._is_graph_network:
|
||||
self._init_graph_network(self.inputs, self.outputs, name=self.name)
|
||||
|
||||
@ -18,6 +18,8 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import copy
|
||||
|
||||
from tensorflow.python.distribute import distribute_coordinator as dc
|
||||
from tensorflow.python.distribute import distribution_strategy_context as ds_context
|
||||
from tensorflow.python.distribute import values as ds_values
|
||||
@ -743,14 +745,15 @@ class Model(network.Network, version_utils.ModelVersionSelector):
|
||||
model=self)
|
||||
|
||||
# Container that configures and calls `tf.keras.Callback`s.
|
||||
callbacks = callbacks_module.CallbackList(
|
||||
callbacks,
|
||||
add_history=True,
|
||||
add_progbar=True,
|
||||
model=self,
|
||||
verbose=verbose,
|
||||
epochs=epochs,
|
||||
steps=data_handler._steps) # pylint: disable=protected-access
|
||||
if not isinstance(callbacks, callbacks_module.CallbackList):
|
||||
callbacks = callbacks_module.CallbackList(
|
||||
callbacks,
|
||||
add_history=True,
|
||||
add_progbar=verbose != 0,
|
||||
model=self,
|
||||
verbose=verbose,
|
||||
epochs=epochs,
|
||||
steps=data_handler.inferred_steps)
|
||||
|
||||
self.stop_training = False
|
||||
train_function = self.make_train_function()
|
||||
@ -773,12 +776,14 @@ class Model(network.Network, version_utils.ModelVersionSelector):
|
||||
batch_size=batch_size):
|
||||
callbacks.on_train_batch_begin(step)
|
||||
tmp_logs = train_function(iterator)
|
||||
# Catch possible OutOfRangeError here.
|
||||
# TODO(b/150292341): Allow multiple async steps.
|
||||
context.async_wait()
|
||||
logs = tmp_logs
|
||||
# Catch OutOfRangeError for Datasets of unknown size.
|
||||
# This blocks until the batch has finished executing.
|
||||
# TODO(b/150292341): Allow multiple async steps here.
|
||||
if not data_handler.inferred_steps:
|
||||
context.async_wait()
|
||||
logs = tmp_logs # No error, now safe to assign to logs.
|
||||
callbacks.on_train_batch_end(step, logs)
|
||||
epoch_logs = {m.name: m.result() for m in self.metrics}
|
||||
epoch_logs = copy.copy(logs)
|
||||
|
||||
# Run validation.
|
||||
if validation_data and self._should_eval(epoch, validation_freq):
|
||||
@ -986,11 +991,11 @@ class Model(network.Network, version_utils.ModelVersionSelector):
|
||||
callbacks = callbacks_module.CallbackList(
|
||||
callbacks,
|
||||
add_history=True,
|
||||
add_progbar=True,
|
||||
add_progbar=verbose != 0,
|
||||
model=self,
|
||||
verbose=verbose,
|
||||
epochs=1,
|
||||
steps=data_handler._steps) # pylint: disable=protected-access
|
||||
steps=data_handler.inferred_steps)
|
||||
|
||||
test_function = self.make_test_function()
|
||||
callbacks.on_test_begin()
|
||||
@ -1004,8 +1009,12 @@ class Model(network.Network, version_utils.ModelVersionSelector):
|
||||
step_num=step):
|
||||
callbacks.on_test_batch_begin(step)
|
||||
tmp_logs = test_function(iterator)
|
||||
context.async_wait() # Possible OutOfRangeError here.
|
||||
logs = tmp_logs
|
||||
# Catch OutOfRangeError for Datasets of unknown size.
|
||||
# This blocks until the batch has finished executing.
|
||||
# TODO(b/150292341): Allow multiple async steps here.
|
||||
if not data_handler.inferred_steps:
|
||||
context.async_wait()
|
||||
logs = tmp_logs # No error, now safe to assign to logs.
|
||||
callbacks.on_test_batch_end(step, logs)
|
||||
callbacks.on_test_end()
|
||||
|
||||
@ -1170,14 +1179,15 @@ class Model(network.Network, version_utils.ModelVersionSelector):
|
||||
model=self)
|
||||
|
||||
# Container that configures and calls `tf.keras.Callback`s.
|
||||
callbacks = callbacks_module.CallbackList(
|
||||
callbacks,
|
||||
add_history=True,
|
||||
add_progbar=True,
|
||||
model=self,
|
||||
verbose=verbose,
|
||||
epochs=1,
|
||||
steps=data_handler._steps) # pylint: disable=protected-access
|
||||
if not isinstance(callbacks, callbacks_module.CallbackList):
|
||||
callbacks = callbacks_module.CallbackList(
|
||||
callbacks,
|
||||
add_history=True,
|
||||
add_progbar=verbose != 0,
|
||||
model=self,
|
||||
verbose=verbose,
|
||||
epochs=1,
|
||||
steps=data_handler.inferred_steps)
|
||||
|
||||
predict_function = self.make_predict_function()
|
||||
callbacks.on_predict_begin()
|
||||
@ -1186,8 +1196,12 @@ class Model(network.Network, version_utils.ModelVersionSelector):
|
||||
for step in data_handler.steps():
|
||||
callbacks.on_predict_batch_begin(step)
|
||||
tmp_batch_outputs = predict_function(iterator)
|
||||
context.async_wait() # Possible OutOfRangeError here.
|
||||
batch_outputs = tmp_batch_outputs
|
||||
# Catch OutOfRangeError for Datasets of unknown size.
|
||||
# This blocks until the batch has finished executing.
|
||||
# TODO(b/150292341): Allow multiple async steps here.
|
||||
if not data_handler.inferred_steps:
|
||||
context.async_wait()
|
||||
batch_outputs = tmp_batch_outputs # No error, now safe to assign.
|
||||
if outputs is None:
|
||||
outputs = nest.map_structure(lambda batch_output: [batch_output],
|
||||
batch_outputs)
|
||||
|
||||
@ -795,3 +795,14 @@ def validate_kwargs(kwargs,
|
||||
def validate_config(config):
|
||||
"""Determines whether config appears to be a valid layer config."""
|
||||
return isinstance(config, dict) and _LAYER_UNDEFINED_CONFIG_KEY not in config
|
||||
|
||||
|
||||
def default(method):
|
||||
"""Decorates a method to detect overrides in subclasses."""
|
||||
method._is_default = True # pylint: disable=protected-access
|
||||
return method
|
||||
|
||||
|
||||
def is_default(method):
|
||||
"""Check if a method is decorated with the `default` wrapper."""
|
||||
return getattr(method, '_is_default', False)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user