diff --git a/tensorflow/contrib/distribute/python/examples/keras_mnist.py b/tensorflow/contrib/distribute/python/examples/keras_mnist.py
index a20069c4fe4..04951346366 100644
--- a/tensorflow/contrib/distribute/python/examples/keras_mnist.py
+++ b/tensorflow/contrib/distribute/python/examples/keras_mnist.py
@@ -58,13 +58,13 @@ def get_input_datasets():
   train_ds = tf.data.Dataset.from_tensor_slices((x_train, y_train))
   train_ds = train_ds.repeat()
   train_ds = train_ds.shuffle(100)
-  train_ds = train_ds.batch(64)
+  train_ds = train_ds.batch(64, drop_remainder=True)
 
   # eval dataset
   eval_ds = tf.data.Dataset.from_tensor_slices((x_test, y_test))
   eval_ds = eval_ds.repeat()
   eval_ds = eval_ds.shuffle(100)
-  eval_ds = eval_ds.batch(64)
+  eval_ds = eval_ds.batch(64, drop_remainder=True)
 
   return train_ds, eval_ds, input_shape
diff --git a/tensorflow/python/keras/engine/training_distributed.py b/tensorflow/python/keras/engine/training_distributed.py
index a7bb1f81776..e440e02bfb0 100644
--- a/tensorflow/python/keras/engine/training_distributed.py
+++ b/tensorflow/python/keras/engine/training_distributed.py
@@ -19,13 +19,16 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
 import numpy as np
+from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import errors
 from tensorflow.python.keras import backend as K
 from tensorflow.python.keras import callbacks as cbks
 from tensorflow.python.keras import optimizers
 from tensorflow.python.keras.engine import distributed_training_utils
 from tensorflow.python.keras.utils.generic_utils import Progbar
+from tensorflow.python.ops import array_ops
 from tensorflow.python.platform import tf_logging as logging
+from tensorflow.python.training import distribute as distribute_lib
 
 
 def fit_loop(
@@ -64,6 +67,11 @@ def fit_loop(
   """
   current_strategy = model._distribution_strategy
 
+  # TODO(priyag, sourabhbajaj): Remove this when the codepaths are merged.
+  if current_strategy.__class__.__name__ == 'TPUStrategy':
+    return _experimental_fit_loop(
+        model, iterator, epochs, initial_epoch, steps_per_epoch)
+
   clone_model_on_towers(
       model, current_strategy, make_callback_model=True)
 
@@ -116,11 +124,6 @@ def fit_loop(
   do_validation = False
   if validation_steps:
     do_validation = True
-    if steps_per_epoch is None:
-      raise ValueError('Can only use `validation_steps` '
-                       'when doing step-wise '
-                       'training, i.e. `steps_per_epoch` '
-                       'must be set.')
 
   # Copy the weights from the original model to each of the replicated models.
   orig_model_weights = model.get_weights()
@@ -140,44 +143,46 @@ def fit_loop(
       verbose=verbose)
   out_labels = model.metrics_names or []
   callbacks.on_train_begin()
+
+  assert steps_per_epoch is not None
+
   for epoch in range(initial_epoch, epochs):
     callbacks.on_epoch_begin(epoch)
-    if steps_per_epoch is not None:
-      epoch_logs = {}
-      for step_index in range(steps_per_epoch):
-        batch_logs = {'batch': step_index, 'size': 1}
-        callbacks.on_batch_begin(step_index, batch_logs)
-        try:
-          outs = distributed_train_function(ins)
-        except errors.OutOfRangeError:
-          logging.warning('Your dataset iterator ran out of data; '
-                          'interrupting training. Make sure that your dataset '
-                          'can generate at least `steps_per_epoch * epochs` '
-                          'batches (in this case, %d batches).' %
-                          steps_per_epoch * epochs)
-          break
+    epoch_logs = {}
+    for step_index in range(steps_per_epoch):
+      batch_logs = {'batch': step_index, 'size': 1}
+      callbacks.on_batch_begin(step_index, batch_logs)
+      try:
+        outs = distributed_train_function(ins)
+      except errors.OutOfRangeError:
+        logging.warning('Your dataset iterator ran out of data; '
+                        'interrupting training. Make sure that your dataset '
+                        'can generate at least `steps_per_epoch * epochs` '
+                        'batches (in this case, %d batches).' %
+                        (steps_per_epoch * epochs))
+        break
 
-        if not isinstance(outs, list):
-          outs = [outs]
+      if not isinstance(outs, list):
+        outs = [outs]
 
-        outs = _aggregate_metrics_across_towers(
-            current_strategy.num_towers, out_labels, outs)
-        for l, o in zip(out_labels, outs):
-          batch_logs[l] = o
-        callbacks.on_batch_end(step_index, batch_logs)
-        if callbacks.model.stop_training:
-          break
-      if do_validation:
-        val_outs = test_loop(
-            model,
-            val_iterator,
-            steps=validation_steps,
-            verbose=0)
-        if not isinstance(val_outs, list):
-          val_outs = [val_outs]
-        # Same labels assumed.
-        for l, o in zip(out_labels, val_outs):
-          epoch_logs['val_' + l] = o
+      outs = _aggregate_metrics_across_towers(
+          current_strategy.num_towers, out_labels, outs)
+      for l, o in zip(out_labels, outs):
+        batch_logs[l] = o
+      callbacks.on_batch_end(step_index, batch_logs)
+      if callbacks.model.stop_training:
+        break
+    if do_validation:
+      val_outs = test_loop(
+          model,
+          val_iterator,
+          steps=validation_steps,
+          verbose=0)
+      if not isinstance(val_outs, list):
+        val_outs = [val_outs]
+      # Same labels assumed.
+      for l, o in zip(out_labels, val_outs):
+        epoch_logs['val_' + l] = o
 
     callbacks.on_epoch_end(epoch, epoch_logs)
     if callbacks.model.stop_training:
@@ -192,6 +197,139 @@
   return model.history
 
 
+def _experimental_fit_loop(
+    model,
+    iterator,
+    epochs=100,
+    initial_epoch=0,
+    steps_per_epoch=None):
+  """fit function when using TPU DistributionStrategy for training.
+
+  Arguments:
+      model: Keras Model instance.
+      iterator: Iterator that returns inputs and targets.
+      epochs: Number of times to iterate over the data.
+      initial_epoch: Epoch at which to start training
+          (useful for resuming a previous training run).
+      steps_per_epoch: Total number of steps (batches of samples)
+          before declaring one epoch finished and starting the
+          next epoch. Ignored with the default value of `None`.
+
+  Returns:
+      Returns `None`.
+
+  Raises:
+      ValueError: in case of invalid arguments.
+  """
+  current_strategy = model._distribution_strategy
+
+  # TODO(priyag): Add validation that shapes are fully defined for TPU case.
+
+  # TODO(priyag, sourabhbajaj): This should be moved into a callback instead.
+  K.get_session().run(current_strategy.initialize())
+
+  def _per_device_train_function(model):
+    model._make_train_function()
+    return (model.train_function.inputs,
+            model.train_function.outputs,
+            model.train_function.updates_op,
+            model.train_function.session_kwargs)
+
+  # TODO(priyag, sourabhbajaj): This should likely not be hardcoded here.
+  K.set_learning_phase(1)
+
+  def step_fn(ctx, inputs, targets):
+    """Clones the model and calls make_train_function."""
+    # TODO(priyag, sourabhbajaj): Should cache this keyed on input shapes.
+    clone_model_on_towers(
+        model,
+        current_strategy,
+        make_callback_model=True,
+        inputs=inputs,
+        targets=targets)
+
+    (grouped_inputs, grouped_outputs, grouped_updates,
+     grouped_session_args) = current_strategy.call_for_each_tower(
+         _per_device_train_function, model._grouped_model)
+    (all_inputs, all_outputs, all_updates,
+     all_session_args) = distributed_training_utils.unwrap_values(
+         current_strategy, grouped_inputs, grouped_outputs,
+         grouped_updates, grouped_session_args, with_loss_tensor=True)
+    combined_fn = K.Function(
+        all_inputs, all_outputs,
+        updates=all_updates,
+        name='distributed_train_function',
+        **all_session_args)
+
+    # TODO(priyag, sourabhbajaj): Perhaps the aggregation type needs to be
+    # something else for different outputs.
+    out_labels = model.metrics_names or []
+    for label, output in zip(out_labels, combined_fn.outputs):
+      ctx.set_last_step_output(label, output,
+                               aggregation=distribute_lib.get_loss_reduction())
+
+    # TODO(priyag, sourabhbajaj): Ignoring these things from the combined_fn:
+    # feed_dict, session kwargs, run options, run_metadata for now. These should
+    # be handled appropriately.
+    return combined_fn.updates_op
+
+  # Add initial dummy values for loss and other metric tensors.
+  initial_loop_values = {}
+  initial_loop_values['loss'] = constant_op.constant(1e7)
+  for name, tensor in zip(model.metrics_names[1:], model.metrics_tensors):
+    initial_loop_values[name] = array_ops.zeros(tensor.shape, tensor.dtype)
+
+  with current_strategy.scope():
+    # TODO(priyag, sourabhbajaj): Adjust steps_per_run appropriately based on
+    # steps_per_epoch and number of epochs.
+    ctx = current_strategy.run_steps_on_dataset(
+        step_fn, iterator, iterations=current_strategy.steps_per_run,
+        initial_loop_values=initial_loop_values)
+
+  train_op = ctx.run_op
+  output_tensors = ctx.last_step_outputs
+
+  # Copy the weights from the original model to each of the replicated models.
+  orig_model_weights = model.get_weights()
+  with current_strategy.scope():
+    distributed_model = current_strategy.unwrap(model._grouped_model)[0]
+    distributed_training_utils.set_weights(
+        current_strategy, distributed_model, orig_model_weights)
+
+  assert steps_per_epoch is not None
+
+  # TODO(priyag, sourabhbajaj): Add callbacks support.
+  # TODO(priyag, sourabhbajaj): Add validation.
+  for epoch in range(initial_epoch, epochs):
+    for step_index in range(
+        0, steps_per_epoch, current_strategy.steps_per_run):
+      try:
+        _, outs = K.get_session().run([train_op, output_tensors])
+        # TODO(priyag, sourabhbajaj): Remove this logging in favor of proper
+        # summaries through callbacks.
+        print('Epoch: {}, step_index: {}, loss: {}'.format(
+            epoch, step_index, outs['loss']))
+        for label, out in outs.items():
+          print(label, ': ', out)
+      except errors.OutOfRangeError:
+        logging.warning('Your dataset iterator ran out of data; '
+                        'interrupting training. Make sure that your dataset '
+                        'can generate at least `steps_per_epoch * epochs` '
+                        'batches (in this case, %d batches).' %
+                        (steps_per_epoch * epochs))
+        break
+
+  # Copy the weights back from the replicated model to the original model.
+  with current_strategy.scope():
+    updated_weights = current_strategy.unwrap(
+        model._grouped_model)[0].get_weights()
+    model.set_weights(updated_weights)
+
+  K.get_session().run(current_strategy.finalize())
+
+  # TODO(priyag, sourabhbajaj): Return history.
+
+
 def test_loop(model, iterator, verbose=0, steps=None):
   """evaluate method to validate a model that uses DistributionStrategy.
@@ -373,12 +511,12 @@ def predict_loop(model, iterator, verbose=0, steps=None):
   ]
 
 
-def _clone_and_build_model(model):
+def _clone_and_build_model(model, inputs=None, targets=None):
   """Clone and build the given keras_model."""
   # We need to set the import here since we run into a circular dependency
   # error.
   from tensorflow.python.keras import models  # pylint: disable=g-import-not-at-top
-  cloned_model = models.clone_model(model, input_tensors=None)
+  cloned_model = models.clone_model(model, input_tensors=inputs)
 
   # Compile and build model.
   if isinstance(model.optimizer, optimizers.TFOptimizer):
@@ -387,22 +525,29 @@ def _clone_and_build_model(model):
     optimizer_config = model.optimizer.get_config()
     optimizer = model.optimizer.__class__.from_config(optimizer_config)
 
+  # TODO(priyag): Is there a cleaner way to do this? The API doc suggests a
+  # single tensor should be OK but it throws an error in that case.
+  if (targets is not None and not isinstance(targets, list) and
+      not isinstance(targets, dict)):
+    targets = [targets]
   cloned_model.compile(
       optimizer,
       model.loss,
      metrics=model.metrics,
       loss_weights=model.loss_weights,
       sample_weight_mode=model.sample_weight_mode,
-      weighted_metrics=model.weighted_metrics)
+      weighted_metrics=model.weighted_metrics,
+      target_tensors=targets)
   return cloned_model
 
 
-def clone_model_on_towers(model, strategy, make_callback_model=False):
+def clone_model_on_towers(
+    model, strategy, make_callback_model=False, inputs=None, targets=None):
   """Create a cloned model on each tower, unless already created."""
   if not model._grouped_model:
     with strategy.scope():
       model._grouped_model = strategy.call_for_each_tower(
-          _clone_and_build_model, model)
+          _clone_and_build_model, model, inputs, targets)
     if make_callback_model:
       model._make_callback_model()
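
A note on the warning-message arithmetic in the hunks above: string formatting with `%` and repetition with `*` share the same precedence and group left to right, so the product of the two counts needs its own parentheses. A minimal standalone Python sketch with illustrative values only:

    steps_per_epoch, epochs = 3, 2

    # '%' binds before the trailing '*', so the formatted string is repeated
    # `epochs` times instead of reporting the total number of batches.
    print('%d batches' % steps_per_epoch * epochs)    # -> 3 batches3 batches

    # Parenthesizing the product formats the intended total batch count.
    print('%d batches' % (steps_per_epoch * epochs))  # -> 6 batches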
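
Similarly, for the `drop_remainder=True` change in keras_mnist.py: dropping the final partial batch is what gives the dataset a fully defined (static) batch dimension, which the TPU codepath expects (see the TODO about fully defined shapes in `_experimental_fit_loop`). A minimal sketch against the TF 1.x `tf.data` API; the tensor shape here is purely illustrative:

    import tensorflow as tf

    images = tf.zeros([130, 28, 28, 1])  # 130 is not a multiple of 64
    ds = tf.data.Dataset.from_tensor_slices(images)

    # The last, partial batch leaves the batch dimension unknown (?).
    print(ds.batch(64).output_shapes)                       # (?, 28, 28, 1)

    # Dropping the remainder makes every batch exactly 64, so the shape is static.
    print(ds.batch(64, drop_remainder=True).output_shapes)  # (64, 28, 28, 1)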