Update training_v2 to count samples if the total number is known.

This brings back the existing behavior of the progress bar and callbacks when
they rely on counting the number of examples. Also update the callback test to
use the v2 optimizer, since v1 will fail with run_distributed = True.

PiperOrigin-RevId: 259854861
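For context, a minimal standalone sketch of the sample-count arithmetic this
commit relies on, using hypothetical numbers (an epoch of 10 steps with
batch_size 32 whose final batch holds only 5 samples):

    # Hypothetical epoch: 10 steps, batch_size 32, final batch of 5 samples.
    steps_per_epoch = 10
    batch_size = 32
    partial_batch_size = 5

    # Full batches plus the short final batch.
    total_samples = batch_size * (steps_per_epoch - 1) + partial_batch_size
    assert total_samples == 293  # the progress bar counts up to this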
parent 6b337b315e
commit 272d69f23c
tensorflow/python/keras
@@ -869,7 +869,7 @@ class KerasCallbacksTest(keras_parameterized.TestCase):
           num_hidden=NUM_HIDDEN, num_classes=NUM_CLASSES, input_dim=INPUT_DIM)
       model.compile(
           loss='categorical_crossentropy',
-          optimizer=keras.optimizers.SGD(lr=0.1))
+          optimizer=gradient_descent.SGD(lr=0.1))
       return model
 
     # TODO(psv): Make sure the callback works correctly when min_delta is
@@ -975,7 +975,7 @@ class KerasCallbacksTest(keras_parameterized.TestCase):
           num_hidden=NUM_HIDDEN, num_classes=NUM_CLASSES, input_dim=INPUT_DIM)
       model.compile(
           loss='categorical_crossentropy',
-          optimizer=keras.optimizers.SGD(lr=0.1),
+          optimizer=gradient_descent.SGD(lr=0.1),
           metrics=['accuracy'])
       return model
 
@@ -59,10 +59,10 @@ def run_one_epoch(model,
                   batch_size=None,
                   strategy=None,
                   steps_per_epoch=None,
+                  num_samples=None,
                   mode=ModeKeys.TRAIN,
                   training_context=None,
-                  total_epochs=None,
-                  partical_batch_size=None):
+                  total_epochs=None):
   """Run the execution function with the data from iterator.
 
   Given the dataset iterator and execution function, get the data from iterator
@@ -77,21 +77,18 @@ def run_one_epoch(model,
     batch_size: The size of the current batch.
     strategy: the distribution strategy instance from the model.
     steps_per_epoch: the number of steps to run for the epoch.
+    num_samples: the number of samples for the whole epoch if known. This can be
+      used to calculate the final partial batch, and scale the loss.
     mode: the mode for the current epoch.
     training_context: the context that contains callbacks and progress bar.
     total_epochs: the total number of epochs that will be run.
       Used when throw error when the iterator unexpectedly
       reaches its end.
-    partical_batch_size: the size of the final batch if it is already known. It
-      will be used to scale the loss value for the final batch.
+
   Returns:
       The loss and metric value from the model.
   """
-  # Only use the sample to count if there is a partial batch at the end.
-  use_steps = not (partical_batch_size and batch_size and steps_per_epoch and
-                   steps_per_epoch == dataset_size)
-  num_samples = None if use_steps else batch_size * (steps_per_epoch -
-                                                     1) + partical_batch_size
+  use_steps = num_samples is None
 
   if mode == ModeKeys.PREDICT:
     aggregator = training_utils.OutputsAggregator(
@@ -112,10 +109,17 @@ def run_one_epoch(model,
   step = 0
 
   while step < target_steps:
+    if use_steps:
+      current_batch_size = 1
+    elif step < target_steps - 1:
+      current_batch_size = batch_size
+    else:
+      current_batch_size = num_samples - step * batch_size
+
     # TODO(scottzhu): Maybe update the training context to take into account
     # whether a batch of training happens. Then it could still use a
     # context manager
-    batch_logs = {'batch': step, 'size': 1}
+    batch_logs = {'batch': step, 'size': current_batch_size}
     training_context.callbacks._call_batch_hook(
         mode, 'begin', step, batch_logs)
     training_context.progbar.on_batch_begin(step, batch_logs)
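A quick sanity check of the per-step size logic added above, under the same
hypothetical numbers (num_samples is known, so the sample-counting branches
run):

    num_samples, batch_size, target_steps = 293, 32, 10
    sizes = []
    for step in range(target_steps):
      if step < target_steps - 1:
        sizes.append(batch_size)  # every batch but the last is full
      else:
        sizes.append(num_samples - step * batch_size)  # final partial batch
    assert sizes[-1] == 5 and sum(sizes) == num_samples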
@@ -162,7 +166,7 @@ def run_one_epoch(model,
         aggregator.aggregate(
             batch_outs,
             batch_start=step * batch_size,
-            batch_end=min((step + 1) * batch_size, num_samples))
+            batch_end=step * batch_size + current_batch_size)
     cbks.make_logs(model, batch_logs, batch_outs, mode)
 
     training_context.callbacks._call_batch_hook(
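With the same numbers, the aggregation window for the last step works out the
same either way, but the new form no longer needs num_samples at aggregation
time:

    step, batch_size, current_batch_size = 9, 32, 5
    batch_start = step * batch_size               # 288
    batch_end = batch_start + current_batch_size  # 293, same as min(320, 293)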
@@ -216,6 +220,8 @@ class Loop(training_utils.TrainingLoop):
           validation_steps=validation_steps,
           distribution_strategy=strategy)
 
+      total_samples = _get_total_number_of_samples(training_data_adapter)
+      use_sample = total_samples is not None
       do_validation = (validation_adapter is not None)
 
       if not steps_per_epoch:
@@ -273,11 +279,13 @@ class Loop(training_utils.TrainingLoop):
           batch_size=batch_size,
           epochs=epochs,
           steps_per_epoch=steps_per_epoch,
-          samples=None,
+          samples=total_samples,
+          count_mode='samples' if use_sample else 'steps',
           verbose=0,  # Handle ProgBarLogger separately in this loop.
           mode=ModeKeys.TRAIN)
 
-      with training_context.on_start(model, callbacks, verbose, ModeKeys.TRAIN):
+      with training_context.on_start(
+          model, callbacks, use_sample, verbose, ModeKeys.TRAIN):
         # TODO(scottzhu): Handle TPUStrategy training loop
         for epoch in range(initial_epoch, epochs):
           if training_context.callbacks.model.stop_training:
@@ -303,10 +311,10 @@ class Loop(training_utils.TrainingLoop):
               batch_size=training_data_adapter.batch_size(),
               strategy=strategy,
               steps_per_epoch=steps_per_epoch,
+              num_samples=total_samples,
               mode=ModeKeys.TRAIN,
               training_context=training_context,
-              total_epochs=epochs,
-              partical_batch_size=training_data_adapter.partial_batch_size())
+              total_epochs=epochs)
           cbks.make_logs(model, epoch_logs, training_result, ModeKeys.TRAIN)
 
           # Evaluation
@@ -321,9 +329,11 @@ class Loop(training_utils.TrainingLoop):
           else:
             eval_data_iter = iter(validation_dataset)
 
+          val_total_samples = _get_total_number_of_samples(
+              validation_adapter)
           eval_context = TrainingContext()
           with eval_context.on_start(
-              model, callbacks, verbose=0, mode=ModeKeys.TEST):
+              model, callbacks, use_sample, verbose=0, mode=ModeKeys.TEST):
             with eval_context.on_epoch(epoch, ModeKeys.TEST):
               model.reset_metrics()
               eval_result = run_one_epoch(
@@ -334,11 +344,10 @@ class Loop(training_utils.TrainingLoop):
                   batch_size=validation_adapter.batch_size(),
                   strategy=strategy,
                   steps_per_epoch=validation_steps,
+                  num_samples=val_total_samples,
                   mode=ModeKeys.TEST,
                   training_context=eval_context,
-                  total_epochs=1,
-                  partical_batch_size=validation_adapter.partial_batch_size(
-                  ))
+                  total_epochs=1)
               cbks.make_logs(model, epoch_logs, eval_result, ModeKeys.TEST,
                              prefix='val_')
 
@@ -365,6 +374,8 @@ class Loop(training_utils.TrainingLoop):
           sample_weights=sample_weight,
           steps=steps,
           distribution_strategy=strategy)
+      total_samples = _get_total_number_of_samples(adapter)
+      use_sample = total_samples is not None
 
       if not steps:
         steps = adapter.get_size()
@@ -393,11 +404,13 @@ class Loop(training_utils.TrainingLoop):
           batch_size=batch_size,
           epochs=1,
           steps_per_epoch=steps,
-          samples=None,
+          samples=use_sample,
+          count_mode='samples' if use_sample else 'steps',
           verbose=0,  # Handle ProgBarLogger separately in this loop.
           mode=mode)
 
-      with training_context.on_start(model, callbacks, verbose, mode):
+      with training_context.on_start(
+          model, callbacks, use_sample, verbose, mode):
         # TODO(scottzhu): Handle TPUStrategy training loop
         with training_context.on_epoch(0, mode) as epoch_logs:
           model.reset_metrics()
@@ -409,10 +422,10 @@ class Loop(training_utils.TrainingLoop):
               batch_size=adapter.batch_size(),
               strategy=strategy,
               steps_per_epoch=steps,
+              num_samples=total_samples,
               mode=mode,
               training_context=training_context,
-              total_epochs=1,
-              partical_batch_size=adapter.partial_batch_size())
+              total_epochs=1)
           cbks.make_logs(model, epoch_logs, result, mode)
 
           if len(result) == 1:
@@ -571,14 +584,25 @@ def _update_sample_weight_mode(model, mode, dataset):
     del iterator
 
 
+def _get_total_number_of_samples(adapter):
+  if not adapter.get_size() or not adapter.batch_size():
+    return None
+  total_sample = adapter.get_size() * adapter.batch_size()
+  if adapter.has_partial_batch():
+    total_sample -= (adapter.batch_size() - adapter.partial_batch_size())
+  return total_sample
+
+
 class TrainingContext(object):
   """Utility object that wrap around callbacks and progress bars."""
 
   @tf_contextlib.contextmanager
-  def on_start(self, model, callbacks=None, verbose=0, mode=ModeKeys.TRAIN):
+  def on_start(self, model, callbacks=None, use_samples=False, verbose=0,
+               mode=ModeKeys.TRAIN):
     """Provide a scope for the whole training process."""
     # TODO(omalleyt): Handle ProgBar as part of Callbacks once hooks are ready.
-    progbar = training_utils.get_progbar(model, 'steps')
+    progbar = training_utils.get_progbar(
+        model, 'samples' if use_samples else 'steps')
     progbar.params = callbacks.params
     progbar.params['verbose'] = verbose
     callbacks.model.stop_training = False
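Worked example of the helper added above, with hypothetical adapter metadata
(get_size() = 10 batches, batch_size() = 32, partial_batch_size() = 5):

    total_sample = 10 * 32    # 320 if every batch were full
    total_sample -= (32 - 5)  # drop the missing tail of the final batch
    assert total_sample == 293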