When doing validation within fit using a distribution strategy, do not copy the weights over.
PiperOrigin-RevId: 224393616
commit e33bca07de
parent a6e4af0a5a
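For orientation before the hunks: the change threads a new validation_in_fit flag through Keras' array training loop so that the weight sync between the original model and its distributed clone is skipped when the loop is re-entered for validation inside fit. The toy below is a minimal sketch of that gating pattern only, not TensorFlow code; ToyModel and run_loop are invented names.

    class ToyModel(object):
      """Stand-in for a Keras model that only tracks a flat weight list."""

      def __init__(self, weights):
        self.weights = list(weights)

      def get_weights(self):
        return list(self.weights)

      def set_weights(self, weights):
        self.weights = list(weights)


    def run_loop(original, replica, mode='train', validation_in_fit=False):
      """Mirrors model_iteration's copy-in / run / copy-out structure."""
      if not validation_in_fit:
        # Stand-alone invocation: push canonical weights to the replica first.
        replica.set_weights(original.get_weights())
      if mode == 'train':
        replica.weights = [w + 1.0 for w in replica.weights]  # fake train step
      if not validation_in_fit and mode == 'train':
        # Pull the trained weights back into the original model.
        original.set_weights(replica.get_weights())


    original = ToyModel([0.0])
    replica = ToyModel([0.0])
    run_loop(original, replica, mode='train')  # syncs both ways
    run_loop(original, replica, mode='test', validation_in_fit=True)  # no copy
    assert original.get_weights() == [1.0]  # validation clobbered nothing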
@@ -1273,12 +1273,6 @@ class TestDistributionStrategyCorrectness(test.TestCase,
     # TODO(b/119257215): use the default one once the flakyness is fixed.
     tolerance = 1e-4

-    if (use_validation_data and
-        not isinstance(distribution, tpu_strategy.TPUStrategy)):
-      # TODO(b/120435565): Enable tests with use_validation_data once the
-      # the underlying bug is fixed.
-      return
-
     keras.backend.set_image_data_format('channels_last')
     np.random.seed(_RANDOM_SEED)
     random_seed.set_random_seed(_RANDOM_SEED)
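Note on the hunk above: it deletes the test-side skip that b/120435565 had introduced, so the use_validation_data variants of the correctness test run again now that validation inside fit no longer copies weights.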
@@ -138,6 +138,7 @@ def model_iteration(model,
                     steps_per_epoch=None,
                     validation_steps=None,
                     mode='train',
+                    validation_in_fit=False,
                     **kwargs):
   """Loop function for arrays of data with modes 'train'/'test'/'predict'.
@@ -164,6 +165,9 @@ def model_iteration(model,
       validation_steps: Number of steps to run validation for (only if doing
         validation from data tensors). Ignored with the default value of `None`.
       mode: One of 'train'/'test'/'predict'.
+      validation_in_fit: if true, then this method is invoked from within
+        training iteration (for validation). In this case, do not copy weights
+        when using a tf.distribute.Strategy.
       **kwargs: Additional arguments for backwards compatibility.

   Returns:
@@ -230,8 +234,9 @@ def model_iteration(model,
   aggregator = training_utils.MetricsAggregator(use_steps,
                                                 num_samples_or_steps)

-  if model._distribution_strategy:
-    training_distributed._copy_weights_to_distributed_model(model)
+  if model._distribution_strategy and not validation_in_fit:
+    training_distributed._copy_weights_to_distributed_model(
+        model, model._grouped_model)

   callbacks.model.stop_training = False
   callbacks._call_begin_hook(mode)
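The rationale for the new guard: while fit is running, the freshest weights live on the distributed (grouped) model, and the original model is only synced back after training finishes. An unconditional copy before the in-fit validation pass would therefore overwrite the just-trained replica weights with stale ones, which is exactly what validation_in_fit=True now prevents.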
@@ -356,7 +361,8 @@ def model_iteration(model,
           steps_per_epoch=validation_steps,
           callbacks=callbacks,
           verbose=0,
-          mode='test')
+          mode='test',
+          validation_in_fit=True)
       if not isinstance(val_results, list):
         val_results = [val_results]
       epoch_logs.update(
@@ -367,7 +373,10 @@ def model_iteration(model,
   callbacks._call_end_hook(mode)

   if model._distribution_strategy:
-    training_distributed._copy_weights_to_original_model(model, mode)
+    if not validation_in_fit:
+      training_distributed._copy_weights_to_original_model(
+          model, model._grouped_model, mode)

     scope.__exit__(None, None, None)

   if mode == 'train':
@@ -163,11 +163,9 @@ def experimental_fit_loop(model,
   do_validation = bool(validation_steps)

   # Copy the weights from the original model to each of the replicated models.
-  orig_model_weights = model.get_weights()
   with current_strategy.scope():
-    distributed_model = current_strategy.unwrap(model._grouped_model_train)[0]
-    distributed_training_utils.set_weights(
-        current_strategy, distributed_model, orig_model_weights)
+    _copy_weights_to_distributed_model(model, model._grouped_model_train)

   callbacks = cbks.configure_callbacks(
       callbacks,
       model,
@@ -217,9 +215,8 @@ def experimental_fit_loop(model,
       # Since we create a new clone from the original model we need to copy
       # the weights back to the original model before we can run validation.
       with current_strategy.scope():
-        updated_weights = current_strategy.unwrap(
-            model._grouped_model_train)[0].get_weights()
-        model.set_weights(updated_weights)
+        _copy_weights_to_original_model(model, model._grouped_model_train,
+                                        'train')

       val_outs = experimental_test_loop(  # pylint: disable=undefined-variable
           model,
@@ -240,9 +237,7 @@ def experimental_fit_loop(model,

   # Copy the weights back from the replicated model to the original model.
   with current_strategy.scope():
-    updated_weights = current_strategy.unwrap(
-        model._grouped_model_train)[0].get_weights()
-    model.set_weights(updated_weights)
+    _copy_weights_to_original_model(model, model._grouped_model_train, 'train')

   K.get_session().run(current_strategy.finalize())
   return model.history
@@ -345,11 +340,8 @@ def experimental_test_loop(model,
   progbar = Progbar(target=steps)

   # Copy the weights from the original model to each of the replicated models.
-  orig_model_weights = model.get_weights()
   with current_strategy.scope():
-    distributed_model = current_strategy.unwrap(model._grouped_model_test)[0]
-    distributed_training_utils.set_weights(
-        current_strategy, distributed_model, orig_model_weights)
+    _copy_weights_to_distributed_model(model, model._grouped_model_test)

   assert steps is not None
   outs = [0.] * len(model.metrics_names)
@@ -455,11 +447,8 @@ def experimental_predict_loop(model, iterator, verbose=0, steps=None):
   progbar = Progbar(target=steps)

   # Copy the weights from the original model to each of the replicated models.
-  orig_model_weights = model.get_weights()
   with current_strategy.scope():
-    distributed_model = current_strategy.unwrap(model._grouped_model_predict)[0]
-    distributed_training_utils.set_weights(
-        current_strategy, distributed_model, orig_model_weights)
+    _copy_weights_to_distributed_model(model, model._grouped_model_predict)

   assert steps is not None
   # Since we do not know how many samples we will see, we cannot pre-allocate
@@ -695,22 +684,23 @@ def _prepare_feed_values(model, inputs, targets, sample_weights, mode):
   return ins


-def _copy_weights_to_distributed_model(model):
+def _copy_weights_to_distributed_model(original_model, grouped_model):
   """Copies weights from original model to distributed models."""
-  if model._distribution_strategy:
-    # Copy the weights from the original model to each of the replicated models.
-    orig_model_weights = model.get_weights()
-    distributed_model = model._distribution_strategy.unwrap(
-        model._grouped_model)[0]
-    distributed_training_utils.set_weights(
-        model._distribution_strategy, distributed_model, orig_model_weights)
+  strategy = original_model._distribution_strategy
+  if strategy:
+    # Copy the weights from the original model to each of the replicated
+    # models.
+    orig_model_weights = original_model.get_weights()
+    distributed_model = strategy.unwrap(grouped_model)[0]
+    distributed_training_utils.set_weights(strategy, distributed_model,
+                                           orig_model_weights)


-def _copy_weights_to_original_model(model, mode):
+def _copy_weights_to_original_model(model, grouped_model, mode):
   """Copies weights from first distributed model back to original model."""
   if model._distribution_strategy and mode == 'train':
     updated_weights = model._distribution_strategy.unwrap(
-        model._grouped_model)[0].get_weights()
+        grouped_model)[0].get_weights()
     model.set_weights(updated_weights)
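To make the refactor above concrete, here is a self-contained toy, under invented names (ToyStrategy, ToyGrouped, and so on, none of them TensorFlow API), of the calling convention the new helper signatures imply: the grouped model is now an explicit argument, so one helper pair serves the train, test, and predict clones alike.

    class ToyStrategy(object):
      """Stand-in for a distribution strategy; unwrap() yields the replicas."""

      def unwrap(self, grouped_model):
        return grouped_model.replicas


    class ToyReplica(object):
      def __init__(self):
        self.weights = []

      def get_weights(self):
        return list(self.weights)

      def set_weights(self, weights):
        self.weights = list(weights)


    class ToyGrouped(object):
      def __init__(self, num_replicas):
        self.replicas = [ToyReplica() for _ in range(num_replicas)]


    class ToyOriginal(ToyReplica):
      def __init__(self, strategy, weights):
        super(ToyOriginal, self).__init__()
        self._distribution_strategy = strategy
        self.set_weights(weights)


    def copy_weights_to_distributed_model(original_model, grouped_model):
      strategy = original_model._distribution_strategy
      if strategy:
        weights = original_model.get_weights()
        for replica in strategy.unwrap(grouped_model):  # every replica synced
          replica.set_weights(weights)


    def copy_weights_to_original_model(model, grouped_model, mode):
      if model._distribution_strategy and mode == 'train':
        # Replicas are kept in sync, so the first one is representative.
        first = model._distribution_strategy.unwrap(grouped_model)[0]
        model.set_weights(first.get_weights())


    strategy = ToyStrategy()
    model = ToyOriginal(strategy, [1.0, 2.0])
    grouped_train = ToyGrouped(num_replicas=2)
    copy_weights_to_distributed_model(model, grouped_train)   # fit: copy out
    grouped_train.replicas[0].weights = [3.0, 4.0]            # fake training
    grouped_train.replicas[1].weights = [3.0, 4.0]
    copy_weights_to_original_model(model, grouped_train, 'train')  # copy back
    assert model.get_weights() == [3.0, 4.0]

In the real code the per-replica fan-out is delegated to distributed_training_utils.set_weights, as the hunk shows.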