Experimental work-in-progress support for TPUStrategy in keras.
PiperOrigin-RevId: 211705274
This commit is contained in:
parent
d6c6a759b6
commit
40e262c0dc
@ -58,13 +58,13 @@ def get_input_datasets():
|
||||
train_ds = tf.data.Dataset.from_tensor_slices((x_train, y_train))
|
||||
train_ds = train_ds.repeat()
|
||||
train_ds = train_ds.shuffle(100)
|
||||
train_ds = train_ds.batch(64)
|
||||
train_ds = train_ds.batch(64, drop_remainder=True)
|
||||
|
||||
# eval dataset
|
||||
eval_ds = tf.data.Dataset.from_tensor_slices((x_test, y_test))
|
||||
eval_ds = eval_ds.repeat()
|
||||
eval_ds = eval_ds.shuffle(100)
|
||||
eval_ds = eval_ds.batch(64)
|
||||
eval_ds = eval_ds.batch(64, drop_remainder=True)
|
||||
|
||||
return train_ds, eval_ds, input_shape
|
||||
|
||||
|
@ -19,13 +19,16 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
import numpy as np
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.keras import backend as K
|
||||
from tensorflow.python.keras import callbacks as cbks
|
||||
from tensorflow.python.keras import optimizers
|
||||
from tensorflow.python.keras.engine import distributed_training_utils
|
||||
from tensorflow.python.keras.utils.generic_utils import Progbar
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.platform import tf_logging as logging
|
||||
from tensorflow.python.training import distribute as distribute_lib
|
||||
|
||||
|
||||
def fit_loop(
|
||||
@ -64,6 +67,11 @@ def fit_loop(
|
||||
"""
|
||||
current_strategy = model._distribution_strategy
|
||||
|
||||
# TODO(priyag, sourabhbajaj): Remove this when the codepaths are merged.
|
||||
if current_strategy.__class__.__name__ == 'TPUStrategy':
|
||||
return _experimental_fit_loop(
|
||||
model, iterator, epochs, initial_epoch, steps_per_epoch)
|
||||
|
||||
clone_model_on_towers(
|
||||
model, current_strategy, make_callback_model=True)
|
||||
|
||||
@ -116,11 +124,6 @@ def fit_loop(
|
||||
do_validation = False
|
||||
if validation_steps:
|
||||
do_validation = True
|
||||
if steps_per_epoch is None:
|
||||
raise ValueError('Can only use `validation_steps` '
|
||||
'when doing step-wise '
|
||||
'training, i.e. `steps_per_epoch` '
|
||||
'must be set.')
|
||||
|
||||
# Copy the weights from the original model to each of the replicated models.
|
||||
orig_model_weights = model.get_weights()
|
||||
@ -140,44 +143,46 @@ def fit_loop(
|
||||
verbose=verbose)
|
||||
out_labels = model.metrics_names or []
|
||||
callbacks.on_train_begin()
|
||||
|
||||
assert steps_per_epoch is not None
|
||||
|
||||
for epoch in range(initial_epoch, epochs):
|
||||
callbacks.on_epoch_begin(epoch)
|
||||
if steps_per_epoch is not None:
|
||||
epoch_logs = {}
|
||||
for step_index in range(steps_per_epoch):
|
||||
batch_logs = {'batch': step_index, 'size': 1}
|
||||
callbacks.on_batch_begin(step_index, batch_logs)
|
||||
try:
|
||||
outs = distributed_train_function(ins)
|
||||
except errors.OutOfRangeError:
|
||||
logging.warning('Your dataset iterator ran out of data; '
|
||||
'interrupting training. Make sure that your dataset '
|
||||
'can generate at least `steps_per_epoch * epochs` '
|
||||
'batches (in this case, %d batches).' %
|
||||
steps_per_epoch * epochs)
|
||||
break
|
||||
epoch_logs = {}
|
||||
for step_index in range(steps_per_epoch):
|
||||
batch_logs = {'batch': step_index, 'size': 1}
|
||||
callbacks.on_batch_begin(step_index, batch_logs)
|
||||
try:
|
||||
outs = distributed_train_function(ins)
|
||||
except errors.OutOfRangeError:
|
||||
logging.warning('Your dataset iterator ran out of data; '
|
||||
'interrupting training. Make sure that your dataset '
|
||||
'can generate at least `steps_per_epoch * epochs` '
|
||||
'batches (in this case, %d batches).' %
|
||||
steps_per_epoch * epochs)
|
||||
break
|
||||
|
||||
if not isinstance(outs, list):
|
||||
outs = [outs]
|
||||
if not isinstance(outs, list):
|
||||
outs = [outs]
|
||||
|
||||
outs = _aggregate_metrics_across_towers(
|
||||
current_strategy.num_towers, out_labels, outs)
|
||||
for l, o in zip(out_labels, outs):
|
||||
batch_logs[l] = o
|
||||
callbacks.on_batch_end(step_index, batch_logs)
|
||||
if callbacks.model.stop_training:
|
||||
break
|
||||
if do_validation:
|
||||
val_outs = test_loop(
|
||||
model,
|
||||
val_iterator,
|
||||
steps=validation_steps,
|
||||
verbose=0)
|
||||
if not isinstance(val_outs, list):
|
||||
val_outs = [val_outs]
|
||||
# Same labels assumed.
|
||||
for l, o in zip(out_labels, val_outs):
|
||||
epoch_logs['val_' + l] = o
|
||||
outs = _aggregate_metrics_across_towers(
|
||||
current_strategy.num_towers, out_labels, outs)
|
||||
for l, o in zip(out_labels, outs):
|
||||
batch_logs[l] = o
|
||||
callbacks.on_batch_end(step_index, batch_logs)
|
||||
if callbacks.model.stop_training:
|
||||
break
|
||||
if do_validation:
|
||||
val_outs = test_loop(
|
||||
model,
|
||||
val_iterator,
|
||||
steps=validation_steps,
|
||||
verbose=0)
|
||||
if not isinstance(val_outs, list):
|
||||
val_outs = [val_outs]
|
||||
# Same labels assumed.
|
||||
for l, o in zip(out_labels, val_outs):
|
||||
epoch_logs['val_' + l] = o
|
||||
|
||||
callbacks.on_epoch_end(epoch, epoch_logs)
|
||||
if callbacks.model.stop_training:
|
||||
@ -192,6 +197,139 @@ def fit_loop(
|
||||
return model.history
|
||||
|
||||
|
||||
def _experimental_fit_loop(
|
||||
model,
|
||||
iterator,
|
||||
epochs=100,
|
||||
initial_epoch=0,
|
||||
steps_per_epoch=None):
|
||||
"""fit function when using TPU DistributionStrategy for training.
|
||||
|
||||
Arguments:
|
||||
model: Keras Model instance.
|
||||
iterator: Iterator that returns inputs and targets
|
||||
epochs: Number of times to iterate over the data
|
||||
initial_epoch: Epoch at which to start training
|
||||
(useful for resuming a previous training run)
|
||||
steps_per_epoch: Total number of steps (batches of samples)
|
||||
before declaring one epoch finished and starting the
|
||||
next epoch. Ignored with the default value of `None`.
|
||||
|
||||
Returns:
|
||||
Returns `None`.
|
||||
|
||||
Raises:
|
||||
ValueError: in case of invalid arguments.
|
||||
"""
|
||||
current_strategy = model._distribution_strategy
|
||||
|
||||
# TODO(priyag): Add validation that shapes are fully defined for TPU case.
|
||||
|
||||
# TODO(priyag, sourabhbajaj): This should be moved into a callback instead.
|
||||
K.get_session().run(current_strategy.initialize())
|
||||
|
||||
def _per_device_train_function(model):
|
||||
model._make_train_function()
|
||||
return (model.train_function.inputs,
|
||||
model.train_function.outputs,
|
||||
model.train_function.updates_op,
|
||||
model.train_function.session_kwargs)
|
||||
|
||||
# TODO(priyag, sourabhbajaj): This should likely not be hardcoded here.
|
||||
K.set_learning_phase(1)
|
||||
|
||||
def step_fn(ctx, inputs, targets):
|
||||
"""Clones the model and calls make_train_function."""
|
||||
# TODO(priyag, sourabhbajaj): Should cache this keyed on input shapes.
|
||||
clone_model_on_towers(
|
||||
model,
|
||||
current_strategy,
|
||||
make_callback_model=True,
|
||||
inputs=inputs,
|
||||
targets=targets)
|
||||
|
||||
(grouped_inputs, grouped_outputs, grouped_updates,
|
||||
grouped_session_args) = current_strategy.call_for_each_tower(
|
||||
_per_device_train_function, model._grouped_model)
|
||||
(all_inputs, all_outputs, all_updates,
|
||||
all_session_args) = distributed_training_utils.unwrap_values(
|
||||
current_strategy, grouped_inputs, grouped_outputs,
|
||||
grouped_updates, grouped_session_args, with_loss_tensor=True)
|
||||
combined_fn = K.Function(
|
||||
all_inputs, all_outputs,
|
||||
updates=all_updates,
|
||||
name='distributed_train_function',
|
||||
**all_session_args)
|
||||
|
||||
# TODO(priyag, sourabhbajaj): Perhaps the aggregation type needs to be
|
||||
# something else for different outputs.
|
||||
out_labels = model.metrics_names or []
|
||||
for label, output in zip(out_labels, combined_fn.outputs):
|
||||
ctx.set_last_step_output(label, output,
|
||||
aggregation=distribute_lib.get_loss_reduction())
|
||||
|
||||
# TODO(priyag, sourabhbajaj): Ignoring these things from the combined_fn:
|
||||
# feed_dict, session kwargs, run options, run_metadata for now. These should
|
||||
# be handled appropriately
|
||||
return combined_fn.updates_op
|
||||
|
||||
# Add initial dummy values for loss and other metric tensors.
|
||||
initial_loop_values = {}
|
||||
initial_loop_values['loss'] = constant_op.constant(1e7)
|
||||
for name, tensor in zip(model.metrics_names[1:], model.metrics_tensors):
|
||||
initial_loop_values[name] = array_ops.zeros(tensor.shape, tensor.dtype)
|
||||
|
||||
with current_strategy.scope():
|
||||
# TODO(priyag, sourabhbajaj): Adjust steps_per_run appropriately based on
|
||||
# steps_per_epoch and number of epochs.
|
||||
ctx = current_strategy.run_steps_on_dataset(
|
||||
step_fn, iterator, iterations=current_strategy.steps_per_run,
|
||||
initial_loop_values=initial_loop_values)
|
||||
|
||||
train_op = ctx.run_op
|
||||
output_tensors = ctx.last_step_outputs
|
||||
|
||||
# Copy the weights from the original model to each of the replicated models.
|
||||
orig_model_weights = model.get_weights()
|
||||
with current_strategy.scope():
|
||||
distributed_model = current_strategy.unwrap(model._grouped_model)[0]
|
||||
distributed_training_utils.set_weights(
|
||||
current_strategy, distributed_model, orig_model_weights)
|
||||
|
||||
assert steps_per_epoch is not None
|
||||
|
||||
# TODO(priyag, sourabhbajaj): Add callbacks support.
|
||||
# TODO(priyag, sourabhbajaj): Add validation.
|
||||
for epoch in range(initial_epoch, epochs):
|
||||
for step_index in range(
|
||||
0, steps_per_epoch, current_strategy.steps_per_run):
|
||||
try:
|
||||
_, outs = K.get_session().run([train_op, output_tensors])
|
||||
# TODO(priyag, sourabhbajaj): Remove this logging in favor of proper
|
||||
# summaries through callbacks.
|
||||
print('Epoch: {}, step_index: {}, loss: {}'.format(
|
||||
epoch, step_index, outs['loss']))
|
||||
for label, out in outs.items():
|
||||
print(label, ': ', out)
|
||||
except errors.OutOfRangeError:
|
||||
logging.warning('Your dataset iterator ran out of data; '
|
||||
'interrupting training. Make sure that your dataset '
|
||||
'can generate at least `steps_per_epoch * epochs` '
|
||||
'batches (in this case, %d batches).' %
|
||||
steps_per_epoch * epochs)
|
||||
break
|
||||
|
||||
# Copy the weights back from the replicated model to the original model.
|
||||
with current_strategy.scope():
|
||||
updated_weights = current_strategy.unwrap(
|
||||
model._grouped_model)[0].get_weights()
|
||||
model.set_weights(updated_weights)
|
||||
|
||||
K.get_session().run(current_strategy.finalize())
|
||||
|
||||
# TODO(priyag, sourabhbajaj): Return history.
|
||||
|
||||
|
||||
def test_loop(model, iterator, verbose=0, steps=None):
|
||||
"""evaluate method to validate a model that uses DistributionStrategy.
|
||||
|
||||
@ -373,12 +511,12 @@ def predict_loop(model, iterator, verbose=0, steps=None):
|
||||
]
|
||||
|
||||
|
||||
def _clone_and_build_model(model):
|
||||
def _clone_and_build_model(model, inputs=None, targets=None):
|
||||
"""Clone and build the given keras_model."""
|
||||
# We need to set the import here since we run into a circular dependency
|
||||
# error.
|
||||
from tensorflow.python.keras import models # pylint: disable=g-import-not-at-top
|
||||
cloned_model = models.clone_model(model, input_tensors=None)
|
||||
cloned_model = models.clone_model(model, input_tensors=inputs)
|
||||
|
||||
# Compile and build model.
|
||||
if isinstance(model.optimizer, optimizers.TFOptimizer):
|
||||
@ -387,22 +525,29 @@ def _clone_and_build_model(model):
|
||||
optimizer_config = model.optimizer.get_config()
|
||||
optimizer = model.optimizer.__class__.from_config(optimizer_config)
|
||||
|
||||
# TODO(priyag): Is there a cleaner way to do this? The API doc suggests a
|
||||
# single tensor should be OK but it throws an error in that case.
|
||||
if (targets is not None and not isinstance(targets, list) and
|
||||
not isinstance(targets, dict)):
|
||||
targets = [targets]
|
||||
cloned_model.compile(
|
||||
optimizer,
|
||||
model.loss,
|
||||
metrics=model.metrics,
|
||||
loss_weights=model.loss_weights,
|
||||
sample_weight_mode=model.sample_weight_mode,
|
||||
weighted_metrics=model.weighted_metrics)
|
||||
weighted_metrics=model.weighted_metrics,
|
||||
target_tensors=targets)
|
||||
return cloned_model
|
||||
|
||||
|
||||
def clone_model_on_towers(model, strategy, make_callback_model=False):
|
||||
def clone_model_on_towers(
|
||||
model, strategy, make_callback_model=False, inputs=None, targets=None):
|
||||
"""Create a cloned model on each tower, unless already created."""
|
||||
if not model._grouped_model:
|
||||
with strategy.scope():
|
||||
model._grouped_model = strategy.call_for_each_tower(
|
||||
_clone_and_build_model, model)
|
||||
_clone_and_build_model, model, inputs, targets)
|
||||
if make_callback_model:
|
||||
model._make_callback_model()
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user