When doing validation within fit using a distribution strategy, do not copy the weights over.

PiperOrigin-RevId: 224393616
Shining Sun 2018-12-06 13:11:25 -08:00 committed by TensorFlower Gardener
parent a6e4af0a5a
commit e33bca07de
3 changed files with 31 additions and 38 deletions

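In short: fit runs its per-epoch validation by re-entering the same loop in 'test' mode, and the new validation_in_fit flag tells that inner call to skip the weight sync between the original model and its replicas, since training just left the replicas holding the freshest weights. A toy sketch of the guard below (all names are illustrative stand-ins, not TensorFlow APIs):

class ToyModel:
  """Stand-in for a Keras model wired to a distribution strategy."""

  def __init__(self):
    self.distributed = True

  def sync_to_replicas(self):
    print('copy weights: original -> replicas')

  def sync_from_replicas(self):
    print('copy weights: replicas -> original')

  def run_steps(self, mode):
    print('run %s steps on replicas' % mode)


def run_loop(model, mode='train', validation_in_fit=False):
  """Toy model_iteration: only the outermost call syncs weights."""
  if model.distributed and not validation_in_fit:
    model.sync_to_replicas()  # skipped when fit calls us for validation

  model.run_steps(mode)
  if mode == 'train':
    # In-fit validation re-enters the loop in 'test' mode and flags
    # itself, so neither the sync above nor the one below runs again.
    run_loop(model, mode='test', validation_in_fit=True)

  if model.distributed and mode == 'train' and not validation_in_fit:
    model.sync_from_replicas()


run_loop(ToyModel())  # one sync in, one sync out, validation for free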

@@ -1273,12 +1273,6 @@ class TestDistributionStrategyCorrectness(test.TestCase,
     # TODO(b/119257215): use the default one once the flakyness is fixed.
     tolerance = 1e-4
-    if (use_validation_data and
-        not isinstance(distribution, tpu_strategy.TPUStrategy)):
-      # TODO(b/120435565): Enable tests with use_validation_data once the
-      # the underlying bug is fixed.
-      return
     keras.backend.set_image_data_format('channels_last')
     np.random.seed(_RANDOM_SEED)
     random_seed.set_random_seed(_RANDOM_SEED)


@@ -138,6 +138,7 @@ def model_iteration(model,
                     steps_per_epoch=None,
                     validation_steps=None,
                     mode='train',
+                    validation_in_fit=False,
                     **kwargs):
   """Loop function for arrays of data with modes 'train'/'test'/'predict'.

@@ -164,6 +165,9 @@ def model_iteration(model,
       validation_steps: Number of steps to run validation for (only if doing
         validation from data tensors). Ignored with the default value of `None`.
       mode: One of 'train'/'test'/'predict'.
+      validation_in_fit: if true, then this method is invoked from within
+        training iteration (for validation). In this case, do not copy weights
+        when using a tf.distribute.Strategy.
       **kwargs: Additional arguments for backwards compatibility.

   Returns:
@@ -230,8 +234,9 @@ def model_iteration(model,
   aggregator = training_utils.MetricsAggregator(use_steps,
                                                 num_samples_or_steps)

-  if model._distribution_strategy:
-    training_distributed._copy_weights_to_distributed_model(model)
+  if model._distribution_strategy and not validation_in_fit:
+    training_distributed._copy_weights_to_distributed_model(
+        model, model._grouped_model)

   callbacks.model.stop_training = False
   callbacks._call_begin_hook(mode)
@@ -356,7 +361,8 @@ def model_iteration(model,
             steps_per_epoch=validation_steps,
             callbacks=callbacks,
             verbose=0,
-            mode='test')
+            mode='test',
+            validation_in_fit=True)

         if not isinstance(val_results, list):
           val_results = [val_results]
         epoch_logs.update(
@@ -367,7 +373,10 @@ def model_iteration(model,
   callbacks._call_end_hook(mode)

   if model._distribution_strategy:
-    training_distributed._copy_weights_to_original_model(model, mode)
+    if not validation_in_fit:
+      training_distributed._copy_weights_to_original_model(
+          model, model._grouped_model, mode)
     scope.__exit__(None, None, None)

   if mode == 'train':

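From the user's side nothing changes except per-epoch cost: validation inside fit no longer copies weights out of and back into the replicated model every epoch. A hedged usage sketch of the path this diff optimizes, assuming the TF 1.x-era distribute= argument to compile and tf.contrib.distribute.MirroredStrategy as of this commit's vintage (later releases build the model under strategy.scope() instead):

import numpy as np
import tensorflow as tf

# Sketch only: API names reflect the TF 1.12-era contrib surface.
strategy = tf.contrib.distribute.MirroredStrategy()

model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(8,))])
model.compile(optimizer='sgd', loss='mse', distribute=strategy)

x = np.random.random((64, 8)).astype(np.float32)
y = np.random.random((64, 1)).astype(np.float32)

# The per-epoch validation pass below now runs against the already-synced
# replicated model instead of syncing weights out and back each epoch.
model.fit(x, y, epochs=2, batch_size=8, validation_data=(x, y))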

@@ -163,11 +163,9 @@ def experimental_fit_loop(model,
   do_validation = bool(validation_steps)

   # Copy the weights from the original model to each of the replicated models.
-  orig_model_weights = model.get_weights()
   with current_strategy.scope():
-    distributed_model = current_strategy.unwrap(model._grouped_model_train)[0]
-    distributed_training_utils.set_weights(
-        current_strategy, distributed_model, orig_model_weights)
+    _copy_weights_to_distributed_model(model, model._grouped_model_train)

   callbacks = cbks.configure_callbacks(
       callbacks,
       model,
@@ -217,9 +215,8 @@ def experimental_fit_loop(model,
       # Since we create a new clone from the original model we need to copy
       # the weights back to the original model before we can run validation.
       with current_strategy.scope():
-        updated_weights = current_strategy.unwrap(
-            model._grouped_model_train)[0].get_weights()
-        model.set_weights(updated_weights)
+        _copy_weights_to_original_model(model, model._grouped_model_train,
+                                        'train')

       val_outs = experimental_test_loop(  # pylint: disable=undefined-variable
           model,
@@ -240,9 +237,7 @@ def experimental_fit_loop(model,
   # Copy the weights back from the replicated model to the original model.
   with current_strategy.scope():
-    updated_weights = current_strategy.unwrap(
-        model._grouped_model_train)[0].get_weights()
-    model.set_weights(updated_weights)
+    _copy_weights_to_original_model(model, model._grouped_model_train, 'train')

   K.get_session().run(current_strategy.finalize())
   return model.history
@@ -345,11 +340,8 @@ def experimental_test_loop(model,
     progbar = Progbar(target=steps)

   # Copy the weights from the original model to each of the replicated models.
-  orig_model_weights = model.get_weights()
   with current_strategy.scope():
-    distributed_model = current_strategy.unwrap(model._grouped_model_test)[0]
-    distributed_training_utils.set_weights(
-        current_strategy, distributed_model, orig_model_weights)
+    _copy_weights_to_distributed_model(model, model._grouped_model_test)

   assert steps is not None
   outs = [0.] * len(model.metrics_names)
@@ -455,11 +447,8 @@ def experimental_predict_loop(model, iterator, verbose=0, steps=None):
     progbar = Progbar(target=steps)

   # Copy the weights from the original model to each of the replicated models.
-  orig_model_weights = model.get_weights()
   with current_strategy.scope():
-    distributed_model = current_strategy.unwrap(model._grouped_model_predict)[0]
-    distributed_training_utils.set_weights(
-        current_strategy, distributed_model, orig_model_weights)
+    _copy_weights_to_distributed_model(model, model._grouped_model_predict)

   assert steps is not None

   # Since we do not know how many samples we will see, we cannot pre-allocate
@@ -695,22 +684,23 @@ def _prepare_feed_values(model, inputs, targets, sample_weights, mode):
   return ins


-def _copy_weights_to_distributed_model(model):
+def _copy_weights_to_distributed_model(original_model, grouped_model):
   """Copies weights from original model to distributed models."""
-  if model._distribution_strategy:
-    # Copy the weights from the original model to each of the replicated models.
-    orig_model_weights = model.get_weights()
-    distributed_model = model._distribution_strategy.unwrap(
-        model._grouped_model)[0]
-    distributed_training_utils.set_weights(
-        model._distribution_strategy, distributed_model, orig_model_weights)
+  strategy = original_model._distribution_strategy
+  if strategy:
+    # Copy the weights from the original model to each of the replicated
+    # models.
+    orig_model_weights = original_model.get_weights()
+    distributed_model = strategy.unwrap(grouped_model)[0]
+    distributed_training_utils.set_weights(strategy, distributed_model,
+                                           orig_model_weights)


-def _copy_weights_to_original_model(model, mode):
+def _copy_weights_to_original_model(model, grouped_model, mode):
   """Copies weights from first distributed model back to original model."""
   if model._distribution_strategy and mode == 'train':
     updated_weights = model._distribution_strategy.unwrap(
-        model._grouped_model)[0].get_weights()
+        grouped_model)[0].get_weights()
     model.set_weights(updated_weights)
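
The helpers are now parameterized by the grouped model, so the train/test/predict clones share one copy path instead of three inline unwrap-and-set-weights blocks. A toy illustration of that shape (stand-in classes, not TensorFlow APIs):

class ToyStrategy:
  """Stand-in for a distribution strategy holding replicated models."""

  def unwrap(self, grouped_model):
    # Real strategies return the per-replica models; index 0 is the one
    # whose weights the helpers read and write.
    return grouped_model


class ToyNet:
  def __init__(self, weights):
    self._w = list(weights)

  def get_weights(self):
    return list(self._w)

  def set_weights(self, weights):
    self._w = list(weights)


def copy_weights_to_grouped(strategy, original, grouped_model):
  """Shaped like _copy_weights_to_distributed_model: the caller chooses
  which clone (train/test/predict) receives the weights."""
  strategy.unwrap(grouped_model)[0].set_weights(original.get_weights())


strategy = ToyStrategy()
original = ToyNet([1.0, 2.0])
train_clone = [ToyNet([0.0, 0.0])]
test_clone = [ToyNet([0.0, 0.0])]
copy_weights_to_grouped(strategy, original, train_clone)  # before fit
copy_weights_to_grouped(strategy, original, test_clone)   # before evaluate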