diff --git a/tensorflow/python/keras/callbacks_test.py b/tensorflow/python/keras/callbacks_test.py
index f072384d09f..8aca40f80aa 100644
--- a/tensorflow/python/keras/callbacks_test.py
+++ b/tensorflow/python/keras/callbacks_test.py
@@ -869,7 +869,7 @@ class KerasCallbacksTest(keras_parameterized.TestCase):
           num_hidden=NUM_HIDDEN, num_classes=NUM_CLASSES, input_dim=INPUT_DIM)
       model.compile(
           loss='categorical_crossentropy',
-          optimizer=keras.optimizers.SGD(lr=0.1))
+          optimizer=gradient_descent.SGD(lr=0.1))
       return model
 
     # TODO(psv): Make sure the callback works correctly when min_delta is
@@ -975,7 +975,7 @@ class KerasCallbacksTest(keras_parameterized.TestCase):
           num_hidden=NUM_HIDDEN, num_classes=NUM_CLASSES, input_dim=INPUT_DIM)
       model.compile(
           loss='categorical_crossentropy',
-          optimizer=keras.optimizers.SGD(lr=0.1),
+          optimizer=gradient_descent.SGD(lr=0.1),
           metrics=['accuracy'])
       return model
 
diff --git a/tensorflow/python/keras/engine/training_v2.py b/tensorflow/python/keras/engine/training_v2.py
index 7e89312d891..5d098476800 100644
--- a/tensorflow/python/keras/engine/training_v2.py
+++ b/tensorflow/python/keras/engine/training_v2.py
@@ -59,10 +59,10 @@ def run_one_epoch(model,
                   batch_size=None,
                   strategy=None,
                   steps_per_epoch=None,
+                  num_samples=None,
                   mode=ModeKeys.TRAIN,
                   training_context=None,
-                  total_epochs=None,
-                  partical_batch_size=None):
+                  total_epochs=None):
   """Run the execution function with the data from iterator.
 
   Given the dataset iterator and execution function, get the data from iterator
@@ -77,21 +77,18 @@ def run_one_epoch(model,
     batch_size: The size of the current batch.
     strategy: the distribution strategy instance from the model.
     steps_per_epoch: the number of steps to run for the epoch.
+    num_samples: the number of samples for the whole epoch if known. This can be
+      used to calculate the final partial batch, and scale the loss.
     mode: the mode for the current epoch.
     training_context: the context that contains callbacks and progress bar.
     total_epochs: the total number of epochs that will be run. Used when throw
       error when the iterator unexpectedly reaches its end.
-    partical_batch_size: the size of the final batch if it is already known. It
-      will be used to scale the loss value for the final batch.
 
   Returns:
     The loss and metric value from the model.
   """
   # Only use the sample to count if there is a partial batch at the end.
-  use_steps = not (partical_batch_size and batch_size and steps_per_epoch and
-                   steps_per_epoch == dataset_size)
-  num_samples = None if use_steps else batch_size * (steps_per_epoch -
-                                                     1) + partical_batch_size
+  use_steps = num_samples is None
 
   if mode == ModeKeys.PREDICT:
     aggregator = training_utils.OutputsAggregator(
@@ -112,10 +109,17 @@ def run_one_epoch(model,
 
   step = 0
   while step < target_steps:
+    if use_steps:
+      current_batch_size = 1
+    elif step < target_steps - 1:
+      current_batch_size = batch_size
+    else:
+      current_batch_size = num_samples - step * batch_size
+
     # TODO(scottzhu): Maybe update the training context to take into account
     # whether a batch of training happens. Then it could still use a
     # context manager
-    batch_logs = {'batch': step, 'size': 1}
+    batch_logs = {'batch': step, 'size': current_batch_size}
     training_context.callbacks._call_batch_hook(
         mode, 'begin', step, batch_logs)
     training_context.progbar.on_batch_begin(step, batch_logs)
@@ -162,7 +166,7 @@ def run_one_epoch(model,
       aggregator.aggregate(
           batch_outs,
           batch_start=step * batch_size,
-          batch_end=min((step + 1) * batch_size, num_samples))
+          batch_end=step * batch_size + current_batch_size)
 
       cbks.make_logs(model, batch_logs, batch_outs, mode)
       training_context.callbacks._call_batch_hook(
@@ -216,6 +220,8 @@ class Loop(training_utils.TrainingLoop):
           validation_steps=validation_steps,
           distribution_strategy=strategy)
 
+      total_samples = _get_total_number_of_samples(training_data_adapter)
+      use_sample = total_samples is not None
       do_validation = (validation_adapter is not None)
 
       if not steps_per_epoch:
@@ -273,11 +279,13 @@ class Loop(training_utils.TrainingLoop):
           batch_size=batch_size,
           epochs=epochs,
           steps_per_epoch=steps_per_epoch,
-          samples=None,
+          samples=total_samples,
+          count_mode='samples' if use_sample else 'steps',
           verbose=0,  # Handle ProgBarLogger separately in this loop.
           mode=ModeKeys.TRAIN)
 
-      with training_context.on_start(model, callbacks, verbose, ModeKeys.TRAIN):
+      with training_context.on_start(
+          model, callbacks, use_sample, verbose, ModeKeys.TRAIN):
         # TODO(scottzhu): Handle TPUStrategy training loop
         for epoch in range(initial_epoch, epochs):
           if training_context.callbacks.model.stop_training:
@@ -303,10 +311,10 @@ class Loop(training_utils.TrainingLoop):
                 batch_size=training_data_adapter.batch_size(),
                 strategy=strategy,
                 steps_per_epoch=steps_per_epoch,
+                num_samples=total_samples,
                 mode=ModeKeys.TRAIN,
                 training_context=training_context,
-                total_epochs=epochs,
-                partical_batch_size=training_data_adapter.partial_batch_size())
+                total_epochs=epochs)
             cbks.make_logs(model, epoch_logs, training_result, ModeKeys.TRAIN)
 
             # Evaluation
@@ -321,9 +329,11 @@ class Loop(training_utils.TrainingLoop):
              else:
                eval_data_iter = iter(validation_dataset)
 
+              val_total_samples = _get_total_number_of_samples(
+                  validation_adapter)
              eval_context = TrainingContext()
              with eval_context.on_start(
-                  model, callbacks, verbose=0, mode=ModeKeys.TEST):
+                  model, callbacks, use_sample, verbose=0, mode=ModeKeys.TEST):
                with eval_context.on_epoch(epoch, ModeKeys.TEST):
                  model.reset_metrics()
                  eval_result = run_one_epoch(
@@ -334,11 +344,10 @@ class Loop(training_utils.TrainingLoop):
                      batch_size=validation_adapter.batch_size(),
                      strategy=strategy,
                      steps_per_epoch=validation_steps,
+                      num_samples=val_total_samples,
                      mode=ModeKeys.TEST,
                      training_context=eval_context,
-                      total_epochs=1,
-                      partical_batch_size=validation_adapter.partial_batch_size(
-                      ))
+                      total_epochs=1)
                  cbks.make_logs(model, epoch_logs, eval_result, ModeKeys.TEST,
                                 prefix='val_')
 
@@ -365,6 +374,8 @@ class Loop(training_utils.TrainingLoop):
           sample_weights=sample_weight,
           steps=steps,
           distribution_strategy=strategy)
+      total_samples = _get_total_number_of_samples(adapter)
+      use_sample = total_samples is not None
 
       if not steps:
         steps = adapter.get_size()
@@ -393,11 +404,13 @@ class Loop(training_utils.TrainingLoop):
           batch_size=batch_size,
           epochs=1,
           steps_per_epoch=steps,
-          samples=None,
+          samples=use_sample,
+          count_mode='samples' if use_sample else 'steps',
           verbose=0,  # Handle ProgBarLogger separately in this loop.
           mode=mode)
 
-      with training_context.on_start(model, callbacks, verbose, mode):
+      with training_context.on_start(
+          model, callbacks, use_sample, verbose, mode):
        # TODO(scottzhu): Handle TPUStrategy training loop
        with training_context.on_epoch(0, mode) as epoch_logs:
          model.reset_metrics()
@@ -409,10 +422,10 @@ class Loop(training_utils.TrainingLoop):
              batch_size=adapter.batch_size(),
              strategy=strategy,
              steps_per_epoch=steps,
+              num_samples=total_samples,
              mode=mode,
              training_context=training_context,
-              total_epochs=1,
-              partical_batch_size=adapter.partial_batch_size())
+              total_epochs=1)
          cbks.make_logs(model, epoch_logs, result, mode)
 
          if len(result) == 1:
@@ -571,14 +584,25 @@ def _update_sample_weight_mode(model, mode, dataset):
   del iterator
 
 
+def _get_total_number_of_samples(adapter):
+  if not adapter.get_size() or not adapter.batch_size():
+    return None
+  total_sample = adapter.get_size() * adapter.batch_size()
+  if adapter.has_partial_batch():
+    total_sample -= (adapter.batch_size() - adapter.partial_batch_size())
+  return total_sample
+
+
 class TrainingContext(object):
   """Utility object that wrap around callbacks and progress bars."""
 
   @tf_contextlib.contextmanager
-  def on_start(self, model, callbacks=None, verbose=0, mode=ModeKeys.TRAIN):
+  def on_start(self, model, callbacks=None, use_samples=False, verbose=0,
+               mode=ModeKeys.TRAIN):
     """Provide a scope for the whole training process."""
     # TODO(omalleyt): Handle ProgBar as part of Callbacks once hooks are ready.
-    progbar = training_utils.get_progbar(model, 'steps')
+    progbar = training_utils.get_progbar(
+        model, 'samples' if use_samples else 'steps')
     progbar.params = callbacks.params
     progbar.params['verbose'] = verbose
     callbacks.model.stop_training = False
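
Note (not part of the patch): the standalone sketch below restates the sample accounting introduced above, so the arithmetic can be checked in isolation. The names total_number_of_samples and batch_size_for_step are illustrative only; they mirror _get_total_number_of_samples and the current_batch_size logic added to run_one_epoch.

def total_number_of_samples(num_batches, batch_size, partial_batch_size=None):
  """Total samples in an epoch, accounting for a final partial batch."""
  if not num_batches or not batch_size:
    return None  # Size unknown; callers fall back to counting steps.
  total = num_batches * batch_size
  if partial_batch_size:
    # The last batch is short by (batch_size - partial_batch_size) samples.
    total -= batch_size - partial_batch_size
  return total

def batch_size_for_step(step, target_steps, batch_size, num_samples):
  """Size reported for a given step, following the new run_one_epoch logic."""
  if num_samples is None:
    return 1  # Step counting: every batch counts as one unit.
  if step < target_steps - 1:
    return batch_size
  return num_samples - step * batch_size  # Final, possibly partial, batch.

# Example: 10 batches of 32 where the last batch holds only 20 samples.
assert total_number_of_samples(10, 32, 20) == 308
assert batch_size_for_step(0, 10, 32, 308) == 32
assert batch_size_for_step(9, 10, 32, 308) == 20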
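
A related check, also not part of the patch: when num_samples is known, the new aggregation bound step * batch_size + current_batch_size selects the same slice as the old min((step + 1) * batch_size, num_samples), including on the final partial batch.

def check_batch_end_equivalence(target_steps, batch_size, num_samples):
  """Asserts the old and new batch_end expressions agree for every step."""
  for step in range(target_steps):
    if step < target_steps - 1:
      current_batch_size = batch_size
    else:
      current_batch_size = num_samples - step * batch_size
    old_end = min((step + 1) * batch_size, num_samples)
    new_end = step * batch_size + current_batch_size
    assert old_end == new_end, (step, old_end, new_end)

check_batch_end_equivalence(target_steps=10, batch_size=32, num_samples=308)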