Experimental work-in-progress support for TPUStrategy in keras.

PiperOrigin-RevId: 211705274
This commit is contained in:
Priya Gupta 2018-09-05 15:00:09 -07:00 committed by TensorFlower Gardener
parent d6c6a759b6
commit 40e262c0dc
2 changed files with 191 additions and 46 deletions

View File

@ -58,13 +58,13 @@ def get_input_datasets():
train_ds = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_ds = train_ds.repeat()
train_ds = train_ds.shuffle(100)
train_ds = train_ds.batch(64)
train_ds = train_ds.batch(64, drop_remainder=True)
# eval dataset
eval_ds = tf.data.Dataset.from_tensor_slices((x_test, y_test))
eval_ds = eval_ds.repeat()
eval_ds = eval_ds.shuffle(100)
eval_ds = eval_ds.batch(64)
eval_ds = eval_ds.batch(64, drop_remainder=True)
return train_ds, eval_ds, input_shape

View File

@ -19,13 +19,16 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import errors
from tensorflow.python.keras import backend as K
from tensorflow.python.keras import callbacks as cbks
from tensorflow.python.keras import optimizers
from tensorflow.python.keras.engine import distributed_training_utils
from tensorflow.python.keras.utils.generic_utils import Progbar
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import distribute as distribute_lib
def fit_loop(
@ -64,6 +67,11 @@ def fit_loop(
"""
current_strategy = model._distribution_strategy
# TODO(priyag, sourabhbajaj): Remove this when the codepaths are merged.
if current_strategy.__class__.__name__ == 'TPUStrategy':
return _experimental_fit_loop(
model, iterator, epochs, initial_epoch, steps_per_epoch)
clone_model_on_towers(
model, current_strategy, make_callback_model=True)
@ -116,11 +124,6 @@ def fit_loop(
do_validation = False
if validation_steps:
do_validation = True
if steps_per_epoch is None:
raise ValueError('Can only use `validation_steps` '
'when doing step-wise '
'training, i.e. `steps_per_epoch` '
'must be set.')
# Copy the weights from the original model to each of the replicated models.
orig_model_weights = model.get_weights()
@ -140,44 +143,46 @@ def fit_loop(
verbose=verbose)
out_labels = model.metrics_names or []
callbacks.on_train_begin()
assert steps_per_epoch is not None
for epoch in range(initial_epoch, epochs):
callbacks.on_epoch_begin(epoch)
if steps_per_epoch is not None:
epoch_logs = {}
for step_index in range(steps_per_epoch):
batch_logs = {'batch': step_index, 'size': 1}
callbacks.on_batch_begin(step_index, batch_logs)
try:
outs = distributed_train_function(ins)
except errors.OutOfRangeError:
logging.warning('Your dataset iterator ran out of data; '
'interrupting training. Make sure that your dataset '
'can generate at least `steps_per_epoch * epochs` '
'batches (in this case, %d batches).' %
steps_per_epoch * epochs)
break
epoch_logs = {}
for step_index in range(steps_per_epoch):
batch_logs = {'batch': step_index, 'size': 1}
callbacks.on_batch_begin(step_index, batch_logs)
try:
outs = distributed_train_function(ins)
except errors.OutOfRangeError:
logging.warning('Your dataset iterator ran out of data; '
'interrupting training. Make sure that your dataset '
'can generate at least `steps_per_epoch * epochs` '
'batches (in this case, %d batches).' %
steps_per_epoch * epochs)
break
if not isinstance(outs, list):
outs = [outs]
if not isinstance(outs, list):
outs = [outs]
outs = _aggregate_metrics_across_towers(
current_strategy.num_towers, out_labels, outs)
for l, o in zip(out_labels, outs):
batch_logs[l] = o
callbacks.on_batch_end(step_index, batch_logs)
if callbacks.model.stop_training:
break
if do_validation:
val_outs = test_loop(
model,
val_iterator,
steps=validation_steps,
verbose=0)
if not isinstance(val_outs, list):
val_outs = [val_outs]
# Same labels assumed.
for l, o in zip(out_labels, val_outs):
epoch_logs['val_' + l] = o
outs = _aggregate_metrics_across_towers(
current_strategy.num_towers, out_labels, outs)
for l, o in zip(out_labels, outs):
batch_logs[l] = o
callbacks.on_batch_end(step_index, batch_logs)
if callbacks.model.stop_training:
break
if do_validation:
val_outs = test_loop(
model,
val_iterator,
steps=validation_steps,
verbose=0)
if not isinstance(val_outs, list):
val_outs = [val_outs]
# Same labels assumed.
for l, o in zip(out_labels, val_outs):
epoch_logs['val_' + l] = o
callbacks.on_epoch_end(epoch, epoch_logs)
if callbacks.model.stop_training:
@ -192,6 +197,139 @@ def fit_loop(
return model.history
def _experimental_fit_loop(
model,
iterator,
epochs=100,
initial_epoch=0,
steps_per_epoch=None):
"""fit function when using TPU DistributionStrategy for training.
Arguments:
model: Keras Model instance.
iterator: Iterator that returns inputs and targets
epochs: Number of times to iterate over the data
initial_epoch: Epoch at which to start training
(useful for resuming a previous training run)
steps_per_epoch: Total number of steps (batches of samples)
before declaring one epoch finished and starting the
next epoch. Ignored with the default value of `None`.
Returns:
Returns `None`.
Raises:
ValueError: in case of invalid arguments.
"""
current_strategy = model._distribution_strategy
# TODO(priyag): Add validation that shapes are fully defined for TPU case.
# TODO(priyag, sourabhbajaj): This should be moved into a callback instead.
K.get_session().run(current_strategy.initialize())
def _per_device_train_function(model):
model._make_train_function()
return (model.train_function.inputs,
model.train_function.outputs,
model.train_function.updates_op,
model.train_function.session_kwargs)
# TODO(priyag, sourabhbajaj): This should likely not be hardcoded here.
K.set_learning_phase(1)
def step_fn(ctx, inputs, targets):
"""Clones the model and calls make_train_function."""
# TODO(priyag, sourabhbajaj): Should cache this keyed on input shapes.
clone_model_on_towers(
model,
current_strategy,
make_callback_model=True,
inputs=inputs,
targets=targets)
(grouped_inputs, grouped_outputs, grouped_updates,
grouped_session_args) = current_strategy.call_for_each_tower(
_per_device_train_function, model._grouped_model)
(all_inputs, all_outputs, all_updates,
all_session_args) = distributed_training_utils.unwrap_values(
current_strategy, grouped_inputs, grouped_outputs,
grouped_updates, grouped_session_args, with_loss_tensor=True)
combined_fn = K.Function(
all_inputs, all_outputs,
updates=all_updates,
name='distributed_train_function',
**all_session_args)
# TODO(priyag, sourabhbajaj): Perhaps the aggregation type needs to be
# something else for different outputs.
out_labels = model.metrics_names or []
for label, output in zip(out_labels, combined_fn.outputs):
ctx.set_last_step_output(label, output,
aggregation=distribute_lib.get_loss_reduction())
# TODO(priyag, sourabhbajaj): Ignoring these things from the combined_fn:
# feed_dict, session kwargs, run options, run_metadata for now. These should
# be handled appropriately
return combined_fn.updates_op
# Add initial dummy values for loss and other metric tensors.
initial_loop_values = {}
initial_loop_values['loss'] = constant_op.constant(1e7)
for name, tensor in zip(model.metrics_names[1:], model.metrics_tensors):
initial_loop_values[name] = array_ops.zeros(tensor.shape, tensor.dtype)
with current_strategy.scope():
# TODO(priyag, sourabhbajaj): Adjust steps_per_run appropriately based on
# steps_per_epoch and number of epochs.
ctx = current_strategy.run_steps_on_dataset(
step_fn, iterator, iterations=current_strategy.steps_per_run,
initial_loop_values=initial_loop_values)
train_op = ctx.run_op
output_tensors = ctx.last_step_outputs
# Copy the weights from the original model to each of the replicated models.
orig_model_weights = model.get_weights()
with current_strategy.scope():
distributed_model = current_strategy.unwrap(model._grouped_model)[0]
distributed_training_utils.set_weights(
current_strategy, distributed_model, orig_model_weights)
assert steps_per_epoch is not None
# TODO(priyag, sourabhbajaj): Add callbacks support.
# TODO(priyag, sourabhbajaj): Add validation.
for epoch in range(initial_epoch, epochs):
for step_index in range(
0, steps_per_epoch, current_strategy.steps_per_run):
try:
_, outs = K.get_session().run([train_op, output_tensors])
# TODO(priyag, sourabhbajaj): Remove this logging in favor of proper
# summaries through callbacks.
print('Epoch: {}, step_index: {}, loss: {}'.format(
epoch, step_index, outs['loss']))
for label, out in outs.items():
print(label, ': ', out)
except errors.OutOfRangeError:
logging.warning('Your dataset iterator ran out of data; '
'interrupting training. Make sure that your dataset '
'can generate at least `steps_per_epoch * epochs` '
'batches (in this case, %d batches).' %
steps_per_epoch * epochs)
break
# Copy the weights back from the replicated model to the original model.
with current_strategy.scope():
updated_weights = current_strategy.unwrap(
model._grouped_model)[0].get_weights()
model.set_weights(updated_weights)
K.get_session().run(current_strategy.finalize())
# TODO(priyag, sourabhbajaj): Return history.
def test_loop(model, iterator, verbose=0, steps=None):
"""evaluate method to validate a model that uses DistributionStrategy.
@ -373,12 +511,12 @@ def predict_loop(model, iterator, verbose=0, steps=None):
]
def _clone_and_build_model(model):
def _clone_and_build_model(model, inputs=None, targets=None):
"""Clone and build the given keras_model."""
# We need to set the import here since we run into a circular dependency
# error.
from tensorflow.python.keras import models # pylint: disable=g-import-not-at-top
cloned_model = models.clone_model(model, input_tensors=None)
cloned_model = models.clone_model(model, input_tensors=inputs)
# Compile and build model.
if isinstance(model.optimizer, optimizers.TFOptimizer):
@ -387,22 +525,29 @@ def _clone_and_build_model(model):
optimizer_config = model.optimizer.get_config()
optimizer = model.optimizer.__class__.from_config(optimizer_config)
# TODO(priyag): Is there a cleaner way to do this? The API doc suggests a
# single tensor should be OK but it throws an error in that case.
if (targets is not None and not isinstance(targets, list) and
not isinstance(targets, dict)):
targets = [targets]
cloned_model.compile(
optimizer,
model.loss,
metrics=model.metrics,
loss_weights=model.loss_weights,
sample_weight_mode=model.sample_weight_mode,
weighted_metrics=model.weighted_metrics)
weighted_metrics=model.weighted_metrics,
target_tensors=targets)
return cloned_model
def clone_model_on_towers(model, strategy, make_callback_model=False):
def clone_model_on_towers(
model, strategy, make_callback_model=False, inputs=None, targets=None):
"""Create a cloned model on each tower, unless already created."""
if not model._grouped_model:
with strategy.scope():
model._grouped_model = strategy.call_for_each_tower(
_clone_and_build_model, model)
_clone_and_build_model, model, inputs, targets)
if make_callback_model:
model._make_callback_model()