Experimental work-in-progress support for TPUStrategy in keras.
PiperOrigin-RevId: 211705274
commit 40e262c0dc
parent d6c6a759b6
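The new code path is reached through the regular Keras entry points: a model compiled against a DistributionStrategy still goes through `fit_loop`, which now dispatches to `_experimental_fit_loop` whenever the strategy class is `TPUStrategy`. A minimal driver sketch follows; it is not part of this change, and it assumes the contrib-era `TPUStrategy` constructor, `TPUClusterResolver`, and the `distribute=` argument to `compile` (the TPU address and `steps_per_run` value are placeholders):

    # Hypothetical usage sketch (not part of this commit); assumes TF 1.x
    # contrib APIs from the same era.
    import tensorflow as tf

    resolver = tf.contrib.cluster_resolver.TPUClusterResolver(tpu='')  # placeholder address
    strategy = tf.contrib.distribute.TPUStrategy(resolver, steps_per_run=10)

    model = tf.keras.Sequential([
        tf.keras.layers.Dense(64, activation='relu', input_shape=(784,)),
        tf.keras.layers.Dense(10, activation='softmax'),
    ])
    # Compiling with `distribute=` routes model.fit() through the
    # DistributionStrategy training loops patched below.
    model.compile(optimizer=tf.train.GradientDescentOptimizer(0.01),
                  loss='sparse_categorical_crossentropy',
                  metrics=['accuracy'],
                  distribute=strategy)

    train_ds, eval_ds, _ = get_input_datasets()  # from the example patched below
    model.fit(train_ds, epochs=2, steps_per_epoch=100)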
@@ -58,13 +58,13 @@ def get_input_datasets():
   train_ds = tf.data.Dataset.from_tensor_slices((x_train, y_train))
   train_ds = train_ds.repeat()
   train_ds = train_ds.shuffle(100)
-  train_ds = train_ds.batch(64)
+  train_ds = train_ds.batch(64, drop_remainder=True)
 
   # eval dataset
   eval_ds = tf.data.Dataset.from_tensor_slices((x_test, y_test))
   eval_ds = eval_ds.repeat()
   eval_ds = eval_ds.shuffle(100)
-  eval_ds = eval_ds.batch(64)
+  eval_ds = eval_ds.batch(64, drop_remainder=True)
 
   return train_ds, eval_ds, input_shape
 
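The `drop_remainder=True` additions above are what make the dataset's batch dimension static; the TPU path assumes fully defined shapes (see the shape-validation TODO in `_experimental_fit_loop` further down). A small illustration of the effect, independent of this commit, using the TF 1.x `output_shapes` property:

    import tensorflow as tf

    ds = tf.data.Dataset.range(10)
    print(ds.batch(3).output_shapes)                       # (?,)  - last batch may be smaller
    print(ds.batch(3, drop_remainder=True).output_shapes)  # (3,)  - batch dim fully defined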
@@ -19,13 +19,16 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 import numpy as np
+from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import errors
 from tensorflow.python.keras import backend as K
 from tensorflow.python.keras import callbacks as cbks
 from tensorflow.python.keras import optimizers
 from tensorflow.python.keras.engine import distributed_training_utils
 from tensorflow.python.keras.utils.generic_utils import Progbar
+from tensorflow.python.ops import array_ops
 from tensorflow.python.platform import tf_logging as logging
+from tensorflow.python.training import distribute as distribute_lib
 
 
 def fit_loop(
@@ -64,6 +67,11 @@ def fit_loop(
   """
   current_strategy = model._distribution_strategy
 
+  # TODO(priyag, sourabhbajaj): Remove this when the codepaths are merged.
+  if current_strategy.__class__.__name__ == 'TPUStrategy':
+    return _experimental_fit_loop(
+        model, iterator, epochs, initial_epoch, steps_per_epoch)
+
   clone_model_on_towers(
       model, current_strategy, make_callback_model=True)
 
@@ -116,11 +124,6 @@ def fit_loop(
   do_validation = False
   if validation_steps:
     do_validation = True
-    if steps_per_epoch is None:
-      raise ValueError('Can only use `validation_steps` '
-                       'when doing step-wise '
-                       'training, i.e. `steps_per_epoch` '
-                       'must be set.')
 
   # Copy the weights from the original model to each of the replicated models.
   orig_model_weights = model.get_weights()
@@ -140,9 +143,11 @@ def fit_loop(
       verbose=verbose)
   out_labels = model.metrics_names or []
   callbacks.on_train_begin()
+
+  assert steps_per_epoch is not None
+
   for epoch in range(initial_epoch, epochs):
     callbacks.on_epoch_begin(epoch)
-    if steps_per_epoch is not None:
-      epoch_logs = {}
-      for step_index in range(steps_per_epoch):
-        batch_logs = {'batch': step_index, 'size': 1}
+    epoch_logs = {}
+    for step_index in range(steps_per_epoch):
+      batch_logs = {'batch': step_index, 'size': 1}
@@ -192,6 +197,139 @@ def fit_loop(
   return model.history
 
 
+def _experimental_fit_loop(
+    model,
+    iterator,
+    epochs=100,
+    initial_epoch=0,
+    steps_per_epoch=None):
+  """fit function when using TPU DistributionStrategy for training.
+
+  Arguments:
+      model: Keras Model instance.
+      iterator: Iterator that returns inputs and targets
+      epochs: Number of times to iterate over the data
+      initial_epoch: Epoch at which to start training
+          (useful for resuming a previous training run)
+      steps_per_epoch: Total number of steps (batches of samples)
+          before declaring one epoch finished and starting the
+          next epoch. Ignored with the default value of `None`.
+
+  Returns:
+      Returns `None`.
+
+  Raises:
+      ValueError: in case of invalid arguments.
+  """
+  current_strategy = model._distribution_strategy
+
+  # TODO(priyag): Add validation that shapes are fully defined for TPU case.
+
+  # TODO(priyag, sourabhbajaj): This should be moved into a callback instead.
+  K.get_session().run(current_strategy.initialize())
+
+  def _per_device_train_function(model):
+    model._make_train_function()
+    return (model.train_function.inputs,
+            model.train_function.outputs,
+            model.train_function.updates_op,
+            model.train_function.session_kwargs)
+
+  # TODO(priyag, sourabhbajaj): This should likely not be hardcoded here.
+  K.set_learning_phase(1)
+
+  def step_fn(ctx, inputs, targets):
+    """Clones the model and calls make_train_function."""
+    # TODO(priyag, sourabhbajaj): Should cache this keyed on input shapes.
+    clone_model_on_towers(
+        model,
+        current_strategy,
+        make_callback_model=True,
+        inputs=inputs,
+        targets=targets)
+
+    (grouped_inputs, grouped_outputs, grouped_updates,
+     grouped_session_args) = current_strategy.call_for_each_tower(
+         _per_device_train_function, model._grouped_model)
+    (all_inputs, all_outputs, all_updates,
+     all_session_args) = distributed_training_utils.unwrap_values(
+         current_strategy, grouped_inputs, grouped_outputs,
+         grouped_updates, grouped_session_args, with_loss_tensor=True)
+    combined_fn = K.Function(
+        all_inputs, all_outputs,
+        updates=all_updates,
+        name='distributed_train_function',
+        **all_session_args)
+
+    # TODO(priyag, sourabhbajaj): Perhaps the aggregation type needs to be
+    # something else for different outputs.
+    out_labels = model.metrics_names or []
+    for label, output in zip(out_labels, combined_fn.outputs):
+      ctx.set_last_step_output(label, output,
+                               aggregation=distribute_lib.get_loss_reduction())
+
+    # TODO(priyag, sourabhbajaj): Ignoring these things from the combined_fn:
+    # feed_dict, session kwargs, run options, run_metadata for now. These should
+    # be handled appropriately
+    return combined_fn.updates_op
+
+  # Add initial dummy values for loss and other metric tensors.
+  initial_loop_values = {}
+  initial_loop_values['loss'] = constant_op.constant(1e7)
+  for name, tensor in zip(model.metrics_names[1:], model.metrics_tensors):
+    initial_loop_values[name] = array_ops.zeros(tensor.shape, tensor.dtype)
+
+  with current_strategy.scope():
+    # TODO(priyag, sourabhbajaj): Adjust steps_per_run appropriately based on
+    # steps_per_epoch and number of epochs.
+    ctx = current_strategy.run_steps_on_dataset(
+        step_fn, iterator, iterations=current_strategy.steps_per_run,
+        initial_loop_values=initial_loop_values)
+
+  train_op = ctx.run_op
+  output_tensors = ctx.last_step_outputs
+
+  # Copy the weights from the original model to each of the replicated models.
+  orig_model_weights = model.get_weights()
+  with current_strategy.scope():
+    distributed_model = current_strategy.unwrap(model._grouped_model)[0]
+    distributed_training_utils.set_weights(
+        current_strategy, distributed_model, orig_model_weights)
+
+  assert steps_per_epoch is not None
+
+  # TODO(priyag, sourabhbajaj): Add callbacks support.
+  # TODO(priyag, sourabhbajaj): Add validation.
+  for epoch in range(initial_epoch, epochs):
+    for step_index in range(
+        0, steps_per_epoch, current_strategy.steps_per_run):
+      try:
+        _, outs = K.get_session().run([train_op, output_tensors])
+        # TODO(priyag, sourabhbajaj): Remove this logging in favor of proper
+        # summaries through callbacks.
+        print('Epoch: {}, step_index: {}, loss: {}'.format(
+            epoch, step_index, outs['loss']))
+        for label, out in outs.items():
+          print(label, ': ', out)
+      except errors.OutOfRangeError:
+        logging.warning('Your dataset iterator ran out of data; '
                        'interrupting training. Make sure that your dataset '
                        'can generate at least `steps_per_epoch * epochs` '
                        'batches (in this case, %d batches).' %
+                        (steps_per_epoch * epochs))
+        break
+
+  # Copy the weights back from the replicated model to the original model.
+  with current_strategy.scope():
+    updated_weights = current_strategy.unwrap(
+        model._grouped_model)[0].get_weights()
+    model.set_weights(updated_weights)
+
+  K.get_session().run(current_strategy.finalize())
+
+  # TODO(priyag, sourabhbajaj): Return history.
+
+
 def test_loop(model, iterator, verbose=0, steps=None):
   """evaluate method to validate a model that uses DistributionStrategy.
 
@@ -373,12 +511,12 @@ def predict_loop(model, iterator, verbose=0, steps=None):
     ]
 
 
-def _clone_and_build_model(model):
+def _clone_and_build_model(model, inputs=None, targets=None):
   """Clone and build the given keras_model."""
   # We need to set the import here since we run into a circular dependency
   # error.
   from tensorflow.python.keras import models  # pylint: disable=g-import-not-at-top
-  cloned_model = models.clone_model(model, input_tensors=None)
+  cloned_model = models.clone_model(model, input_tensors=inputs)
 
   # Compile and build model.
   if isinstance(model.optimizer, optimizers.TFOptimizer):
@@ -387,22 +525,29 @@ def _clone_and_build_model(model):
     optimizer_config = model.optimizer.get_config()
     optimizer = model.optimizer.__class__.from_config(optimizer_config)
 
+  # TODO(priyag): Is there a cleaner way to do this? The API doc suggests a
+  # single tensor should be OK but it throws an error in that case.
+  if (targets is not None and not isinstance(targets, list) and
+      not isinstance(targets, dict)):
+    targets = [targets]
   cloned_model.compile(
       optimizer,
       model.loss,
       metrics=model.metrics,
       loss_weights=model.loss_weights,
       sample_weight_mode=model.sample_weight_mode,
-      weighted_metrics=model.weighted_metrics)
+      weighted_metrics=model.weighted_metrics,
+      target_tensors=targets)
   return cloned_model
 
 
-def clone_model_on_towers(model, strategy, make_callback_model=False):
+def clone_model_on_towers(
+    model, strategy, make_callback_model=False, inputs=None, targets=None):
   """Create a cloned model on each tower, unless already created."""
   if not model._grouped_model:
     with strategy.scope():
       model._grouped_model = strategy.call_for_each_tower(
-          _clone_and_build_model, model)
+          _clone_and_build_model, model, inputs, targets)
   if make_callback_model:
     model._make_callback_model()
 