Update training_v2 to count of samples if the total number is known.

This will bring back the existing behavior of progress bar and callbacks if they rely on the counting of number of example.

Also update the callback test to use v2 optimizer, since v1 will fail
with run_distributed = True.

PiperOrigin-RevId: 259854861
This commit is contained in:
Scott Zhu 2019-07-24 17:47:14 -07:00 committed by TensorFlower Gardener
parent 6b337b315e
commit 272d69f23c
2 changed files with 50 additions and 26 deletions
tensorflow/python/keras

View File

@ -869,7 +869,7 @@ class KerasCallbacksTest(keras_parameterized.TestCase):
num_hidden=NUM_HIDDEN, num_classes=NUM_CLASSES, input_dim=INPUT_DIM)
model.compile(
loss='categorical_crossentropy',
optimizer=keras.optimizers.SGD(lr=0.1))
optimizer=gradient_descent.SGD(lr=0.1))
return model
# TODO(psv): Make sure the callback works correctly when min_delta is
@ -975,7 +975,7 @@ class KerasCallbacksTest(keras_parameterized.TestCase):
num_hidden=NUM_HIDDEN, num_classes=NUM_CLASSES, input_dim=INPUT_DIM)
model.compile(
loss='categorical_crossentropy',
optimizer=keras.optimizers.SGD(lr=0.1),
optimizer=gradient_descent.SGD(lr=0.1),
metrics=['accuracy'])
return model

View File

@ -59,10 +59,10 @@ def run_one_epoch(model,
batch_size=None,
strategy=None,
steps_per_epoch=None,
num_samples=None,
mode=ModeKeys.TRAIN,
training_context=None,
total_epochs=None,
partical_batch_size=None):
total_epochs=None):
"""Run the execution function with the data from iterator.
Given the dataset iterator and execution function, get the data from iterator
@ -77,21 +77,18 @@ def run_one_epoch(model,
batch_size: The size of the current batch.
strategy: the distribution strategy instance from the model.
steps_per_epoch: the number of steps to run for the epoch.
num_samples: the number of samples for the whole epoch if known. This can be
used to calculate the final partial batch, and scale the loss.
mode: the mode for the current epoch.
training_context: the context that contains callbacks and progress bar.
total_epochs: the total number of epochs that will be run.
Used when throw error when the iterator unexpectedly
reaches its end.
partical_batch_size: the size of the final batch if it is already known. It
will be used to scale the loss value for the final batch.
Returns:
The loss and metric value from the model.
"""
# Only use the sample to count if there is a partial batch at the end.
use_steps = not (partical_batch_size and batch_size and steps_per_epoch and
steps_per_epoch == dataset_size)
num_samples = None if use_steps else batch_size * (steps_per_epoch -
1) + partical_batch_size
use_steps = num_samples is None
if mode == ModeKeys.PREDICT:
aggregator = training_utils.OutputsAggregator(
@ -112,10 +109,17 @@ def run_one_epoch(model,
step = 0
while step < target_steps:
if use_steps:
current_batch_size = 1
elif step < target_steps - 1:
current_batch_size = batch_size
else:
current_batch_size = num_samples - step * batch_size
# TODO(scottzhu): Maybe update the training context to take into account
# whether a batch of training happens. Then it could still use a
# context manager
batch_logs = {'batch': step, 'size': 1}
batch_logs = {'batch': step, 'size': current_batch_size}
training_context.callbacks._call_batch_hook(
mode, 'begin', step, batch_logs)
training_context.progbar.on_batch_begin(step, batch_logs)
@ -162,7 +166,7 @@ def run_one_epoch(model,
aggregator.aggregate(
batch_outs,
batch_start=step * batch_size,
batch_end=min((step + 1) * batch_size, num_samples))
batch_end=step * batch_size + current_batch_size)
cbks.make_logs(model, batch_logs, batch_outs, mode)
training_context.callbacks._call_batch_hook(
@ -216,6 +220,8 @@ class Loop(training_utils.TrainingLoop):
validation_steps=validation_steps,
distribution_strategy=strategy)
total_samples = _get_total_number_of_samples(training_data_adapter)
use_sample = total_samples is not None
do_validation = (validation_adapter is not None)
if not steps_per_epoch:
@ -273,11 +279,13 @@ class Loop(training_utils.TrainingLoop):
batch_size=batch_size,
epochs=epochs,
steps_per_epoch=steps_per_epoch,
samples=None,
samples=total_samples,
count_mode='samples' if use_sample else 'steps',
verbose=0, # Handle ProgBarLogger separately in this loop.
mode=ModeKeys.TRAIN)
with training_context.on_start(model, callbacks, verbose, ModeKeys.TRAIN):
with training_context.on_start(
model, callbacks, use_sample, verbose, ModeKeys.TRAIN):
# TODO(scottzhu): Handle TPUStrategy training loop
for epoch in range(initial_epoch, epochs):
if training_context.callbacks.model.stop_training:
@ -303,10 +311,10 @@ class Loop(training_utils.TrainingLoop):
batch_size=training_data_adapter.batch_size(),
strategy=strategy,
steps_per_epoch=steps_per_epoch,
num_samples=total_samples,
mode=ModeKeys.TRAIN,
training_context=training_context,
total_epochs=epochs,
partical_batch_size=training_data_adapter.partial_batch_size())
total_epochs=epochs)
cbks.make_logs(model, epoch_logs, training_result, ModeKeys.TRAIN)
# Evaluation
@ -321,9 +329,11 @@ class Loop(training_utils.TrainingLoop):
else:
eval_data_iter = iter(validation_dataset)
val_total_samples = _get_total_number_of_samples(
validation_adapter)
eval_context = TrainingContext()
with eval_context.on_start(
model, callbacks, verbose=0, mode=ModeKeys.TEST):
model, callbacks, use_sample, verbose=0, mode=ModeKeys.TEST):
with eval_context.on_epoch(epoch, ModeKeys.TEST):
model.reset_metrics()
eval_result = run_one_epoch(
@ -334,11 +344,10 @@ class Loop(training_utils.TrainingLoop):
batch_size=validation_adapter.batch_size(),
strategy=strategy,
steps_per_epoch=validation_steps,
num_samples=val_total_samples,
mode=ModeKeys.TEST,
training_context=eval_context,
total_epochs=1,
partical_batch_size=validation_adapter.partial_batch_size(
))
total_epochs=1)
cbks.make_logs(model, epoch_logs, eval_result, ModeKeys.TEST,
prefix='val_')
@ -365,6 +374,8 @@ class Loop(training_utils.TrainingLoop):
sample_weights=sample_weight,
steps=steps,
distribution_strategy=strategy)
total_samples = _get_total_number_of_samples(adapter)
use_sample = total_samples is not None
if not steps:
steps = adapter.get_size()
@ -393,11 +404,13 @@ class Loop(training_utils.TrainingLoop):
batch_size=batch_size,
epochs=1,
steps_per_epoch=steps,
samples=None,
samples=use_sample,
count_mode='samples' if use_sample else 'steps',
verbose=0, # Handle ProgBarLogger separately in this loop.
mode=mode)
with training_context.on_start(model, callbacks, verbose, mode):
with training_context.on_start(
model, callbacks, use_sample, verbose, mode):
# TODO(scottzhu): Handle TPUStrategy training loop
with training_context.on_epoch(0, mode) as epoch_logs:
model.reset_metrics()
@ -409,10 +422,10 @@ class Loop(training_utils.TrainingLoop):
batch_size=adapter.batch_size(),
strategy=strategy,
steps_per_epoch=steps,
num_samples=total_samples,
mode=mode,
training_context=training_context,
total_epochs=1,
partical_batch_size=adapter.partial_batch_size())
total_epochs=1)
cbks.make_logs(model, epoch_logs, result, mode)
if len(result) == 1:
@ -571,14 +584,25 @@ def _update_sample_weight_mode(model, mode, dataset):
del iterator
def _get_total_number_of_samples(adapter):
if not adapter.get_size() or not adapter.batch_size():
return None
total_sample = adapter.get_size() * adapter.batch_size()
if adapter.has_partial_batch():
total_sample -= (adapter.batch_size() - adapter.partial_batch_size())
return total_sample
class TrainingContext(object):
"""Utility object that wrap around callbacks and progress bars."""
@tf_contextlib.contextmanager
def on_start(self, model, callbacks=None, verbose=0, mode=ModeKeys.TRAIN):
def on_start(self, model, callbacks=None, use_samples=False, verbose=0,
mode=ModeKeys.TRAIN):
"""Provide a scope for the whole training process."""
# TODO(omalleyt): Handle ProgBar as part of Callbacks once hooks are ready.
progbar = training_utils.get_progbar(model, 'steps')
progbar = training_utils.get_progbar(
model, 'samples' if use_samples else 'steps')
progbar.params = callbacks.params
progbar.params['verbose'] = verbose
callbacks.model.stop_training = False