Experimental work-in-progress support for TPUStrategy in keras.

PiperOrigin-RevId: 211705274
Priya Gupta 2018-09-05 15:00:09 -07:00 committed by TensorFlower Gardener
parent d6c6a759b6
commit 40e262c0dc
2 changed files with 191 additions and 46 deletions

View File

@@ -58,13 +58,13 @@ def get_input_datasets():
   train_ds = tf.data.Dataset.from_tensor_slices((x_train, y_train))
   train_ds = train_ds.repeat()
   train_ds = train_ds.shuffle(100)
-  train_ds = train_ds.batch(64)
+  train_ds = train_ds.batch(64, drop_remainder=True)

   # eval dataset
   eval_ds = tf.data.Dataset.from_tensor_slices((x_test, y_test))
   eval_ds = eval_ds.repeat()
   eval_ds = eval_ds.shuffle(100)
-  eval_ds = eval_ds.batch(64)
+  eval_ds = eval_ds.batch(64, drop_remainder=True)

   return train_ds, eval_ds, input_shape
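The switch to `drop_remainder=True` matters for the TPU path: dropping any smaller final batch makes the batch dimension statically known. A minimal sketch of the shape difference, assuming the TF 1.x-era `tf.data` API and made-up MNIST-shaped arrays (not part of this commit):

import numpy as np
import tensorflow as tf

# Made-up MNIST-shaped inputs, only to show the static shapes.
x = np.zeros((100, 28, 28, 1), dtype=np.float32)
ds = tf.data.Dataset.from_tensor_slices(x)

# Without drop_remainder the last batch may be smaller than 64, so the batch
# dimension is unknown in the static shape.
print(ds.batch(64).output_shapes)                       # (?, 28, 28, 1)

# With drop_remainder=True every emitted batch has exactly 64 elements and the
# static shape is fully defined, which the TPU code path relies on.
print(ds.batch(64, drop_remainder=True).output_shapes)  # (64, 28, 28, 1)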

View File

@@ -19,13 +19,16 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

 import numpy as np

+from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import errors
 from tensorflow.python.keras import backend as K
 from tensorflow.python.keras import callbacks as cbks
 from tensorflow.python.keras import optimizers
 from tensorflow.python.keras.engine import distributed_training_utils
 from tensorflow.python.keras.utils.generic_utils import Progbar
+from tensorflow.python.ops import array_ops
 from tensorflow.python.platform import tf_logging as logging
+from tensorflow.python.training import distribute as distribute_lib


 def fit_loop(
@@ -64,6 +67,11 @@ def fit_loop(
   """
   current_strategy = model._distribution_strategy

+  # TODO(priyag, sourabhbajaj): Remove this when the codepaths are merged.
+  if current_strategy.__class__.__name__ == 'TPUStrategy':
+    return _experimental_fit_loop(
+        model, iterator, epochs, initial_epoch, steps_per_epoch)
+
   clone_model_on_towers(
       model, current_strategy, make_callback_model=True)
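The new branch only triggers when the model was compiled with a TPU distribution strategy. A hedged sketch of how that path might be exercised with the contrib-era API of the time; the resolver address, `steps_per_run` value, model, and dataset below are placeholders, not taken from this commit:

import tensorflow as tf

# Placeholder resolver/strategy setup; the exact constructor arguments are an
# assumption about the contrib-era API, not part of this change.
resolver = tf.contrib.cluster_resolver.TPUClusterResolver(tpu='')
strategy = tf.contrib.distribute.TPUStrategy(resolver, steps_per_run=10)

train_ds = tf.data.Dataset.from_tensor_slices(
    (tf.zeros((640, 28, 28, 1)), tf.zeros((640,), dtype=tf.int32))
).repeat().batch(64, drop_remainder=True)

model = tf.keras.Sequential([
    tf.keras.layers.Flatten(input_shape=(28, 28, 1)),
    tf.keras.layers.Dense(10, activation='softmax'),
])
model.compile(optimizer=tf.train.GradientDescentOptimizer(0.01),
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'],
              distribute=strategy)  # routes fit() into the code path added above

model.fit(train_ds, epochs=2, steps_per_epoch=50)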
@@ -116,11 +124,6 @@ def fit_loop(
   do_validation = False
   if validation_steps:
     do_validation = True
-    if steps_per_epoch is None:
-      raise ValueError('Can only use `validation_steps` '
-                       'when doing step-wise '
-                       'training, i.e. `steps_per_epoch` '
-                       'must be set.')

   # Copy the weights from the original model to each of the replicated models.
   orig_model_weights = model.get_weights()
@@ -140,9 +143,11 @@ def fit_loop(
       verbose=verbose)
   out_labels = model.metrics_names or []
   callbacks.on_train_begin()
+
+  assert steps_per_epoch is not None
+
   for epoch in range(initial_epoch, epochs):
     callbacks.on_epoch_begin(epoch)
-    if steps_per_epoch is not None:
-      epoch_logs = {}
-      for step_index in range(steps_per_epoch):
-        batch_logs = {'batch': step_index, 'size': 1}
+    epoch_logs = {}
+    for step_index in range(steps_per_epoch):
+      batch_logs = {'batch': step_index, 'size': 1}
@@ -192,6 +197,139 @@ def fit_loop(
   return model.history


+def _experimental_fit_loop(
+    model,
+    iterator,
+    epochs=100,
+    initial_epoch=0,
+    steps_per_epoch=None):
+  """fit function when using TPU DistributionStrategy for training.
+
+  Arguments:
+      model: Keras Model instance.
+      iterator: Iterator that returns inputs and targets.
+      epochs: Number of times to iterate over the data.
+      initial_epoch: Epoch at which to start training
+          (useful for resuming a previous training run).
+      steps_per_epoch: Total number of steps (batches of samples)
+          before declaring one epoch finished and starting the
+          next epoch. Ignored with the default value of `None`.
+
+  Returns:
+      Returns `None`.
+
+  Raises:
+      ValueError: in case of invalid arguments.
+  """
+  current_strategy = model._distribution_strategy
+
+  # TODO(priyag): Add validation that shapes are fully defined for TPU case.
+
+  # TODO(priyag, sourabhbajaj): This should be moved into a callback instead.
+  K.get_session().run(current_strategy.initialize())
+
+  def _per_device_train_function(model):
+    model._make_train_function()
+    return (model.train_function.inputs,
+            model.train_function.outputs,
+            model.train_function.updates_op,
+            model.train_function.session_kwargs)
+
+  # TODO(priyag, sourabhbajaj): This should likely not be hardcoded here.
+  K.set_learning_phase(1)
+
+  def step_fn(ctx, inputs, targets):
+    """Clones the model and calls make_train_function."""
+    # TODO(priyag, sourabhbajaj): Should cache this keyed on input shapes.
+    clone_model_on_towers(
+        model,
+        current_strategy,
+        make_callback_model=True,
+        inputs=inputs,
+        targets=targets)
+
+    (grouped_inputs, grouped_outputs, grouped_updates,
+     grouped_session_args) = current_strategy.call_for_each_tower(
+         _per_device_train_function, model._grouped_model)
+    (all_inputs, all_outputs, all_updates,
+     all_session_args) = distributed_training_utils.unwrap_values(
+         current_strategy, grouped_inputs, grouped_outputs,
+         grouped_updates, grouped_session_args, with_loss_tensor=True)
+    combined_fn = K.Function(
+        all_inputs, all_outputs,
+        updates=all_updates,
+        name='distributed_train_function',
+        **all_session_args)
+
+    # TODO(priyag, sourabhbajaj): Perhaps the aggregation type needs to be
+    # something else for different outputs.
+    out_labels = model.metrics_names or []
+    for label, output in zip(out_labels, combined_fn.outputs):
+      ctx.set_last_step_output(label, output,
+                               aggregation=distribute_lib.get_loss_reduction())
+
+    # TODO(priyag, sourabhbajaj): Ignoring these things from the combined_fn
+    # for now: feed_dict, session kwargs, run options, run_metadata. These
+    # should be handled appropriately.
+    return combined_fn.updates_op
+
+  # Add initial dummy values for loss and other metric tensors.
+  initial_loop_values = {}
+  initial_loop_values['loss'] = constant_op.constant(1e7)
+  for name, tensor in zip(model.metrics_names[1:], model.metrics_tensors):
+    initial_loop_values[name] = array_ops.zeros(tensor.shape, tensor.dtype)
+
+  with current_strategy.scope():
+    # TODO(priyag, sourabhbajaj): Adjust steps_per_run appropriately based on
+    # steps_per_epoch and number of epochs.
+    ctx = current_strategy.run_steps_on_dataset(
+        step_fn, iterator, iterations=current_strategy.steps_per_run,
+        initial_loop_values=initial_loop_values)
+
+    train_op = ctx.run_op
+    output_tensors = ctx.last_step_outputs
+
+  # Copy the weights from the original model to each of the replicated models.
+  orig_model_weights = model.get_weights()
+  with current_strategy.scope():
+    distributed_model = current_strategy.unwrap(model._grouped_model)[0]
+    distributed_training_utils.set_weights(
+        current_strategy, distributed_model, orig_model_weights)
+
+  assert steps_per_epoch is not None
+
+  # TODO(priyag, sourabhbajaj): Add callbacks support.
+  # TODO(priyag, sourabhbajaj): Add validation.
+  for epoch in range(initial_epoch, epochs):
+    for step_index in range(
+        0, steps_per_epoch, current_strategy.steps_per_run):
+      try:
+        _, outs = K.get_session().run([train_op, output_tensors])
+        # TODO(priyag, sourabhbajaj): Remove this logging in favor of proper
+        # summaries through callbacks.
+        print('Epoch: {}, step_index: {}, loss: {}'.format(
+            epoch, step_index, outs['loss']))
+        for label, out in outs.items():
+          print(label, ': ', out)
+      except errors.OutOfRangeError:
+        logging.warning('Your dataset iterator ran out of data; '
+                        'interrupting training. Make sure that your dataset '
+                        'can generate at least `steps_per_epoch * epochs` '
+                        'batches (in this case, %d batches).' %
+                        (steps_per_epoch * epochs))
+        break
+
+  # Copy the weights back from the replicated model to the original model.
+  with current_strategy.scope():
+    updated_weights = current_strategy.unwrap(
+        model._grouped_model)[0].get_weights()
+    model.set_weights(updated_weights)
+
+  K.get_session().run(current_strategy.finalize())
+
+  # TODO(priyag, sourabhbajaj): Return history.
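Because the inner loop above advances by `current_strategy.steps_per_run`, and each `K.get_session().run(train_op)` executes a full `steps_per_run` dataset iterations (that is how `run_steps_on_dataset` builds `train_op`), an epoch is effectively rounded up to a multiple of `steps_per_run`; that is what the "Adjust steps_per_run" TODO refers to. A small plain-Python illustration with arbitrary numbers:

steps_per_epoch = 100  # requested by the caller
steps_per_run = 8      # steps executed per train_op run

run_calls = list(range(0, steps_per_epoch, steps_per_run))
print(len(run_calls))                  # 13 session.run calls per epoch
print(len(run_calls) * steps_per_run)  # 104 training steps actually executed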
 def test_loop(model, iterator, verbose=0, steps=None):
   """evaluate method to validate a model that uses DistributionStrategy.
@@ -373,12 +511,12 @@ def predict_loop(model, iterator, verbose=0, steps=None):
   ]


-def _clone_and_build_model(model):
+def _clone_and_build_model(model, inputs=None, targets=None):
   """Clone and build the given keras_model."""
   # We need to set the import here since we run into a circular dependency
   # error.
   from tensorflow.python.keras import models  # pylint: disable=g-import-not-at-top
-  cloned_model = models.clone_model(model, input_tensors=None)
+  cloned_model = models.clone_model(model, input_tensors=inputs)

   # Compile and build model.
   if isinstance(model.optimizer, optimizers.TFOptimizer):
@@ -387,22 +525,29 @@ def _clone_and_build_model(model):
     optimizer_config = model.optimizer.get_config()
     optimizer = model.optimizer.__class__.from_config(optimizer_config)

+  # TODO(priyag): Is there a cleaner way to do this? The API doc suggests a
+  # single tensor should be OK but it throws an error in that case.
+  if (targets is not None and not isinstance(targets, list) and
+      not isinstance(targets, dict)):
+    targets = [targets]
+
   cloned_model.compile(
       optimizer,
       model.loss,
       metrics=model.metrics,
       loss_weights=model.loss_weights,
       sample_weight_mode=model.sample_weight_mode,
-      weighted_metrics=model.weighted_metrics)
+      weighted_metrics=model.weighted_metrics,
+      target_tensors=targets)
   return cloned_model


-def clone_model_on_towers(model, strategy, make_callback_model=False):
+def clone_model_on_towers(
+    model, strategy, make_callback_model=False, inputs=None, targets=None):
   """Create a cloned model on each tower, unless already created."""
   if not model._grouped_model:
     with strategy.scope():
       model._grouped_model = strategy.call_for_each_tower(
-          _clone_and_build_model, model)
+          _clone_and_build_model, model, inputs, targets)
   if make_callback_model:
     model._make_callback_model()
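Threading `inputs`/`targets` through `clone_model_on_towers` wires the tensors produced by each tower's iterator directly into the cloned model: `clone_model(..., input_tensors=...)` rebuilds the model on the feature tensor, and `compile(..., target_tensors=...)` uses the label tensor as the training target (a single tensor is wrapped in a list, per the TODO above). A standalone sketch of the same wiring outside any distribution strategy; the model, shapes, and dataset are made up for illustration:

import tensorflow as tf

# Toy functional model; layers and shapes are illustrative only.
inp = tf.keras.layers.Input(shape=(784,))
out = tf.keras.layers.Dense(10, activation='softmax')(inp)
base = tf.keras.Model(inp, out)

ds = tf.data.Dataset.from_tensor_slices(
    (tf.zeros((128, 784)), tf.zeros((128,), dtype=tf.int64))
).batch(32, drop_remainder=True).repeat()
features, labels = ds.make_one_shot_iterator().get_next()

# Rebuild the model on the dataset's feature tensor ...
clone = tf.keras.models.clone_model(base, input_tensors=[features])
# ... and compile against the dataset's label tensor, wrapped in a list just as
# _clone_and_build_model does above.
clone.compile(optimizer='sgd',
              loss='sparse_categorical_crossentropy',
              target_tensors=[labels])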