When doing validation within fit using a distribution strategy, do not copy the weights over.
PiperOrigin-RevId: 224393616
This commit is contained in:
parent
a6e4af0a5a
commit
e33bca07de
@ -1273,12 +1273,6 @@ class TestDistributionStrategyCorrectness(test.TestCase,
|
|||||||
# TODO(b/119257215): use the default one once the flakiness is fixed.
|
# TODO(b/119257215): use the default one once the flakiness is fixed.
|
||||||
tolerance = 1e-4
|
tolerance = 1e-4
|
||||||
|
|
||||||
if (use_validation_data and
|
|
||||||
not isinstance(distribution, tpu_strategy.TPUStrategy)):
|
|
||||||
# TODO(b/120435565): Enable tests with use_validation_data once the
|
|
||||||
# underlying bug is fixed.
|
|
||||||
return
|
|
||||||
|
|
||||||
keras.backend.set_image_data_format('channels_last')
|
keras.backend.set_image_data_format('channels_last')
|
||||||
np.random.seed(_RANDOM_SEED)
|
np.random.seed(_RANDOM_SEED)
|
||||||
random_seed.set_random_seed(_RANDOM_SEED)
|
random_seed.set_random_seed(_RANDOM_SEED)
|
||||||
|
@ -138,6 +138,7 @@ def model_iteration(model,
|
|||||||
steps_per_epoch=None,
|
steps_per_epoch=None,
|
||||||
validation_steps=None,
|
validation_steps=None,
|
||||||
mode='train',
|
mode='train',
|
||||||
|
validation_in_fit=False,
|
||||||
**kwargs):
|
**kwargs):
|
||||||
"""Loop function for arrays of data with modes 'train'/'test'/'predict'.
|
"""Loop function for arrays of data with modes 'train'/'test'/'predict'.
|
||||||
|
|
||||||
@ -164,6 +165,9 @@ def model_iteration(model,
|
|||||||
validation_steps: Number of steps to run validation for (only if doing
|
validation_steps: Number of steps to run validation for (only if doing
|
||||||
validation from data tensors). Ignored with the default value of `None`.
|
validation from data tensors). Ignored with the default value of `None`.
|
||||||
mode: One of 'train'/'test'/'predict'.
|
mode: One of 'train'/'test'/'predict'.
|
||||||
|
validation_in_fit: if true, then this method is invoked from within
|
||||||
|
training iteration (for validation). In this case, do not copy weights
|
||||||
|
when using a tf.distribute.Strategy.
|
||||||
**kwargs: Additional arguments for backwards compatibility.
|
**kwargs: Additional arguments for backwards compatibility.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
@ -230,8 +234,9 @@ def model_iteration(model,
|
|||||||
aggregator = training_utils.MetricsAggregator(use_steps,
|
aggregator = training_utils.MetricsAggregator(use_steps,
|
||||||
num_samples_or_steps)
|
num_samples_or_steps)
|
||||||
|
|
||||||
if model._distribution_strategy:
|
if model._distribution_strategy and not validation_in_fit:
|
||||||
training_distributed._copy_weights_to_distributed_model(model)
|
training_distributed._copy_weights_to_distributed_model(
|
||||||
|
model, model._grouped_model)
|
||||||
|
|
||||||
callbacks.model.stop_training = False
|
callbacks.model.stop_training = False
|
||||||
callbacks._call_begin_hook(mode)
|
callbacks._call_begin_hook(mode)
|
||||||
@ -356,7 +361,8 @@ def model_iteration(model,
|
|||||||
steps_per_epoch=validation_steps,
|
steps_per_epoch=validation_steps,
|
||||||
callbacks=callbacks,
|
callbacks=callbacks,
|
||||||
verbose=0,
|
verbose=0,
|
||||||
mode='test')
|
mode='test',
|
||||||
|
validation_in_fit=True)
|
||||||
if not isinstance(val_results, list):
|
if not isinstance(val_results, list):
|
||||||
val_results = [val_results]
|
val_results = [val_results]
|
||||||
epoch_logs.update(
|
epoch_logs.update(
|
||||||
@ -367,7 +373,10 @@ def model_iteration(model,
|
|||||||
callbacks._call_end_hook(mode)
|
callbacks._call_end_hook(mode)
|
||||||
|
|
||||||
if model._distribution_strategy:
|
if model._distribution_strategy:
|
||||||
training_distributed._copy_weights_to_original_model(model, mode)
|
if not validation_in_fit:
|
||||||
|
training_distributed._copy_weights_to_original_model(
|
||||||
|
model, model._grouped_model, mode)
|
||||||
|
|
||||||
scope.__exit__(None, None, None)
|
scope.__exit__(None, None, None)
|
||||||
|
|
||||||
if mode == 'train':
|
if mode == 'train':
|
||||||
|
@ -163,11 +163,9 @@ def experimental_fit_loop(model,
|
|||||||
do_validation = bool(validation_steps)
|
do_validation = bool(validation_steps)
|
||||||
|
|
||||||
# Copy the weights from the original model to each of the replicated models.
|
# Copy the weights from the original model to each of the replicated models.
|
||||||
orig_model_weights = model.get_weights()
|
|
||||||
with current_strategy.scope():
|
with current_strategy.scope():
|
||||||
distributed_model = current_strategy.unwrap(model._grouped_model_train)[0]
|
_copy_weights_to_distributed_model(model, model._grouped_model_train)
|
||||||
distributed_training_utils.set_weights(
|
|
||||||
current_strategy, distributed_model, orig_model_weights)
|
|
||||||
callbacks = cbks.configure_callbacks(
|
callbacks = cbks.configure_callbacks(
|
||||||
callbacks,
|
callbacks,
|
||||||
model,
|
model,
|
||||||
@ -217,9 +215,8 @@ def experimental_fit_loop(model,
|
|||||||
# Since we create a new clone from the original model we need to copy
|
# Since we create a new clone from the original model we need to copy
|
||||||
# the weights back to the original model before we can run validation.
|
# the weights back to the original model before we can run validation.
|
||||||
with current_strategy.scope():
|
with current_strategy.scope():
|
||||||
updated_weights = current_strategy.unwrap(
|
_copy_weights_to_original_model(model, model._grouped_model_train,
|
||||||
model._grouped_model_train)[0].get_weights()
|
'train')
|
||||||
model.set_weights(updated_weights)
|
|
||||||
|
|
||||||
val_outs = experimental_test_loop( # pylint: disable=undefined-variable
|
val_outs = experimental_test_loop( # pylint: disable=undefined-variable
|
||||||
model,
|
model,
|
||||||
@ -240,9 +237,7 @@ def experimental_fit_loop(model,
|
|||||||
|
|
||||||
# Copy the weights back from the replicated model to the original model.
|
# Copy the weights back from the replicated model to the original model.
|
||||||
with current_strategy.scope():
|
with current_strategy.scope():
|
||||||
updated_weights = current_strategy.unwrap(
|
_copy_weights_to_original_model(model, model._grouped_model_train, 'train')
|
||||||
model._grouped_model_train)[0].get_weights()
|
|
||||||
model.set_weights(updated_weights)
|
|
||||||
|
|
||||||
K.get_session().run(current_strategy.finalize())
|
K.get_session().run(current_strategy.finalize())
|
||||||
return model.history
|
return model.history
|
||||||
@ -345,11 +340,8 @@ def experimental_test_loop(model,
|
|||||||
progbar = Progbar(target=steps)
|
progbar = Progbar(target=steps)
|
||||||
|
|
||||||
# Copy the weights from the original model to each of the replicated models.
|
# Copy the weights from the original model to each of the replicated models.
|
||||||
orig_model_weights = model.get_weights()
|
|
||||||
with current_strategy.scope():
|
with current_strategy.scope():
|
||||||
distributed_model = current_strategy.unwrap(model._grouped_model_test)[0]
|
_copy_weights_to_distributed_model(model, model._grouped_model_test)
|
||||||
distributed_training_utils.set_weights(
|
|
||||||
current_strategy, distributed_model, orig_model_weights)
|
|
||||||
|
|
||||||
assert steps is not None
|
assert steps is not None
|
||||||
outs = [0.] * len(model.metrics_names)
|
outs = [0.] * len(model.metrics_names)
|
||||||
@ -455,11 +447,8 @@ def experimental_predict_loop(model, iterator, verbose=0, steps=None):
|
|||||||
progbar = Progbar(target=steps)
|
progbar = Progbar(target=steps)
|
||||||
|
|
||||||
# Copy the weights from the original model to each of the replicated models.
|
# Copy the weights from the original model to each of the replicated models.
|
||||||
orig_model_weights = model.get_weights()
|
|
||||||
with current_strategy.scope():
|
with current_strategy.scope():
|
||||||
distributed_model = current_strategy.unwrap(model._grouped_model_predict)[0]
|
_copy_weights_to_distributed_model(model, model._grouped_model_predict)
|
||||||
distributed_training_utils.set_weights(
|
|
||||||
current_strategy, distributed_model, orig_model_weights)
|
|
||||||
|
|
||||||
assert steps is not None
|
assert steps is not None
|
||||||
# Since we do not know how many samples we will see, we cannot pre-allocate
|
# Since we do not know how many samples we will see, we cannot pre-allocate
|
||||||
@ -695,22 +684,23 @@ def _prepare_feed_values(model, inputs, targets, sample_weights, mode):
|
|||||||
return ins
|
return ins
|
||||||
|
|
||||||
|
|
||||||
def _copy_weights_to_distributed_model(model):
|
def _copy_weights_to_distributed_model(original_model, grouped_model):
|
||||||
"""Copies weights from original model to distributed models."""
|
"""Copies weights from original model to distributed models."""
|
||||||
if model._distribution_strategy:
|
strategy = original_model._distribution_strategy
|
||||||
# Copy the weights from the original model to each of the replicated models.
|
if strategy:
|
||||||
orig_model_weights = model.get_weights()
|
# Copy the weights from the original model to each of the replicated
|
||||||
distributed_model = model._distribution_strategy.unwrap(
|
# models.
|
||||||
model._grouped_model)[0]
|
orig_model_weights = original_model.get_weights()
|
||||||
distributed_training_utils.set_weights(
|
distributed_model = strategy.unwrap(grouped_model)[0]
|
||||||
model._distribution_strategy, distributed_model, orig_model_weights)
|
distributed_training_utils.set_weights(strategy, distributed_model,
|
||||||
|
orig_model_weights)
|
||||||
|
|
||||||
|
|
||||||
def _copy_weights_to_original_model(model, mode):
|
def _copy_weights_to_original_model(model, grouped_model, mode):
|
||||||
"""Copies weights from first distributed model back to original model."""
|
"""Copies weights from first distributed model back to original model."""
|
||||||
if model._distribution_strategy and mode == 'train':
|
if model._distribution_strategy and mode == 'train':
|
||||||
updated_weights = model._distribution_strategy.unwrap(
|
updated_weights = model._distribution_strategy.unwrap(
|
||||||
model._grouped_model)[0].get_weights()
|
grouped_model)[0].get_weights()
|
||||||
model.set_weights(updated_weights)
|
model.set_weights(updated_weights)
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user