move v1-specific distributed_training_utils to distributed_training_utils_v1.py
PiperOrigin-RevId: 333120339
Change-Id: I52ee3d95d95165a7f56fcc43f791011ef7fae154

parent ef2c14d030
commit a151f928f3
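
For downstream code the change is mechanical: v1-only helpers now live in
distributed_training_utils_v1, while cross-version utilities stay in
distributed_training_utils. A minimal sketch of the import update applied
throughout the hunks below (the alias name is illustrative):

    # Before: v1-specific helpers were reached through the shared module.
    from tensorflow.python.keras.distribute import distributed_training_utils as dist_utils

    # After: v1-specific helpers come from the new module.
    from tensorflow.python.keras.distribute import distributed_training_utils_v1 as dist_utils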
@@ -29,6 +29,7 @@ py_library(
     srcs = [
         "__init__.py",
         "distributed_training_utils.py",
+        "distributed_training_utils_v1.py",
     ],
     srcs_version = "PY2AND3",
     deps = [
@@ -40,6 +40,7 @@ from tensorflow.python.framework import sparse_tensor
 from tensorflow.python.framework import test_combinations as combinations
 from tensorflow.python.keras import testing_utils
 from tensorflow.python.keras.distribute import distributed_training_utils
+from tensorflow.python.keras.distribute import distributed_training_utils_v1
 from tensorflow.python.keras.distribute import optimizer_combinations
 from tensorflow.python.keras.engine import base_layer_utils
 from tensorflow.python.keras.mixed_precision.experimental import policy
@@ -363,13 +364,13 @@ class TestDistributionStrategyWithNumpyArrays(test.TestCase,
 
     with self.cached_session():
       # Default global batch size 32 for input with 64 samples run in 2 steps
-      steps, batch_size = distributed_training_utils.get_input_params(
+      steps, batch_size = distributed_training_utils_v1.get_input_params(
           distribution, 64, steps=None, batch_size=None)
       self.assertEqual(batch_size, 32 // replica_scale_factor)
       self.assertEqual(steps, 2)
 
       # Computed global batch size 20 is lower than 32 if we pass less samples.
-      steps, batch_size = distributed_training_utils.get_input_params(
+      steps, batch_size = distributed_training_utils_v1.get_input_params(
           distribution, 20, steps=None, batch_size=None)
       self.assertEqual(batch_size, 20 // replica_scale_factor)
       self.assertEqual(steps, 1)
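
As a reading aid, here is a rough sketch of the arithmetic the assertions above
encode, assuming a default global batch size of 32 and that replica_scale_factor
is num_replicas_in_sync for strategies that take per-replica batch sizes; this is
an illustration, not the real get_input_params implementation:

    def _sketch_input_params(num_samples, replica_scale_factor=1):
      # Default global batch size is 32, capped by the number of samples.
      global_batch = min(32, num_samples)
      steps = num_samples // global_batch
      # Strategies that take per-replica batches see the global batch divided up.
      return steps, global_batch // replica_scale_factor

    assert _sketch_input_params(64) == (2, 32)  # 64 samples -> 2 steps of 32
    assert _sketch_input_params(20) == (1, 20)  # fewer samples -> smaller batch, 1 step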
@@ -385,27 +386,27 @@ class TestDistributionStrategyWithNumpyArrays(test.TestCase,
 
     with self.cached_session():
       # Computed global batch size is correct for number of specified 1 step
-      steps, batch_size = distributed_training_utils.get_input_params(
+      steps, batch_size = distributed_training_utils_v1.get_input_params(
           distribution, 64, steps=1, batch_size=None)
       self.assertEqual(batch_size, 64 // replica_scale_factor)
       self.assertEqual(steps, 1)
 
       # Computed global batch size is correct for number of specified 2 steps
-      steps, batch_size = distributed_training_utils.get_input_params(
+      steps, batch_size = distributed_training_utils_v1.get_input_params(
           distribution, 64, steps=2, batch_size=None)
       self.assertEqual(batch_size, 32 // replica_scale_factor)
       self.assertEqual(steps, 2)
 
       # All samples can not be consumed in specified number of steps
       with self.assertRaisesRegex(ValueError, 'not divisible by steps'):
-        distributed_training_utils.get_input_params(
+        distributed_training_utils_v1.get_input_params(
             distribution, 63, steps=2, batch_size=None)
 
       # This cases is different for different strategies due to the
       # difference in supported batch size being global or per-replica.
       if replica_scale_factor == 1:
         # Computed global batch size is correct even if not sharadable
-        steps, batch_size = distributed_training_utils.get_input_params(
+        steps, batch_size = distributed_training_utils_v1.get_input_params(
             distribution, 63, steps=3, batch_size=None)
         self.assertEqual(batch_size, 21)
         self.assertEqual(steps, 3)
@@ -414,7 +415,7 @@ class TestDistributionStrategyWithNumpyArrays(test.TestCase,
       with self.assertRaisesRegex(
           ValueError, 'could not be sharded evenly '
           'across the sync replicas'):
-        distributed_training_utils.get_input_params(
+        distributed_training_utils_v1.get_input_params(
             distribution, 63, steps=1, batch_size=None)
 
   @ds_combinations.generate(all_strategy_combinations())
@@ -428,13 +429,13 @@ class TestDistributionStrategyWithNumpyArrays(test.TestCase,
 
     with self.cached_session():
       # Computed steps is correct for specified batch size
-      steps, batch_size = distributed_training_utils.get_input_params(
+      steps, batch_size = distributed_training_utils_v1.get_input_params(
          distribution, 64, steps=None, batch_size=16)
       self.assertEqual(batch_size, 16)
       self.assertEqual(steps, 4 // replica_scale_factor)
 
       # Computed steps is correct for specified batch size
-      steps, batch_size = distributed_training_utils.get_input_params(
+      steps, batch_size = distributed_training_utils_v1.get_input_params(
          distribution, 64, steps=None, batch_size=32)
       self.assertEqual(batch_size, 32)
       self.assertEqual(steps, 2 // replica_scale_factor)
@@ -444,14 +445,14 @@ class TestDistributionStrategyWithNumpyArrays(test.TestCase,
       self, distribution):
     with self.cached_session():
       # No change to steps and batch size if both specified and feasible
-      steps, batch_size = distributed_training_utils.get_input_params(
+      steps, batch_size = distributed_training_utils_v1.get_input_params(
          distribution, 64, steps=5, batch_size=3)
       self.assertEqual(batch_size, 3)
       self.assertEqual(steps, 5)
 
       # Number of samples is less than global batch size * steps
       with self.assertRaisesRegex(ValueError, 'less than samples required'):
-        distributed_training_utils.get_input_params(
+        distributed_training_utils_v1.get_input_params(
             distribution, 64, steps=10, batch_size=13)
 
   @ds_combinations.generate(all_strategy_combinations())
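
The caller-specified cases above follow the same bookkeeping; a small
illustrative check of the numbers (not the real implementation):

    num_samples, batch_size, replica_scale_factor = 64, 16, 1
    steps = num_samples // (batch_size * replica_scale_factor)
    assert steps == 4  # and 2 when 16 is a per-replica batch on two replicas

    # steps=10 with batch_size=13 would need 130 samples but only 64 are given,
    # which is what the 'less than samples required' error above guards against.
    assert 10 * 13 > num_samples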
(File diff suppressed because it is too large.)
@@ -19,7 +19,7 @@ from __future__ import division
 from __future__ import print_function
 
 from tensorflow.python.keras import callbacks
-from tensorflow.python.keras.distribute import distributed_training_utils
+from tensorflow.python.keras.distribute import distributed_training_utils_v1
 from tensorflow.python.keras.optimizer_v2 import adam
 from tensorflow.python.platform import test
 from tensorflow.python.training import adam as v1_adam
@@ -39,7 +39,7 @@ class DistributedTrainingUtilsTest(test.TestCase):
         callbacks.RemoteMonitor()
     ]
 
-    distributed_training_utils.validate_callbacks(
+    distributed_training_utils_v1.validate_callbacks(
         supported_predefined_callbacks, adam.Adam())
 
     unsupported_predefined_callbacks = [
@@ -50,8 +50,8 @@ class DistributedTrainingUtilsTest(test.TestCase):
     for callback in unsupported_predefined_callbacks:
       with self.assertRaisesRegex(ValueError,
                                   'You must specify a Keras Optimizer V2'):
-        distributed_training_utils.validate_callbacks([callback],
-                                                      v1_adam.AdamOptimizer())
+        distributed_training_utils_v1.validate_callbacks(
+            [callback], v1_adam.AdamOptimizer())
 
 
 if __name__ == '__main__':
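
For context, a hypothetical sketch of the kind of check these tests exercise:
validate_callbacks rejecting callbacks that touch optimizer state unless a Keras
V2 optimizer is used. The structure and callback list here are assumptions, not
the real implementation:

    from tensorflow.python.keras import callbacks
    from tensorflow.python.keras.optimizer_v2 import optimizer_v2

    def _validate_callbacks_sketch(input_callbacks, optimizer):
      for cb in input_callbacks:
        # Callbacks that read or schedule learning rates need the V2 optimizer API.
        if isinstance(cb, (callbacks.LearningRateScheduler,
                           callbacks.ReduceLROnPlateau)):
          if not isinstance(optimizer, optimizer_v2.OptimizerV2):
            raise ValueError('You must specify a Keras Optimizer V2 when using '
                             '%s callback with DistributionStrategy.' % cb)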
tensorflow/python/keras/distribute/distributed_training_utils_v1.py (new file, 1158 lines)
(File diff suppressed because it is too large.)
@@ -35,7 +35,7 @@ from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import test_combinations as combinations
 from tensorflow.python.keras import losses
 from tensorflow.python.keras.distribute import distribute_strategy_test as keras_test_lib
-from tensorflow.python.keras.distribute import distributed_training_utils
+from tensorflow.python.keras.distribute import distributed_training_utils_v1
 from tensorflow.python.keras.distribute import optimizer_combinations
 from tensorflow.python.platform import test
 from tensorflow.python.training import gradient_descent
@@ -203,7 +203,7 @@ class TestDistributionStrategyErrorCases(test.TestCase, parameterized.TestCase):
         'distributed tensor inputs '
         'DistributedValues:.+'):
       with distribution.scope():
-        distributed_training_utils.validate_distributed_dataset_inputs(
+        distributed_training_utils_v1.validate_distributed_dataset_inputs(
            distribution, x, y)
 
   @ds_combinations.generate(
@@ -227,7 +227,7 @@ class TestDistributionStrategyErrorCases(test.TestCase, parameterized.TestCase):
         'distributed tensor inputs '
         'DistributedValues:.+'):
       with distribution.scope():
-        distributed_training_utils.validate_distributed_dataset_inputs(
+        distributed_training_utils_v1.validate_distributed_dataset_inputs(
            distribution, x, y)
 
   @ds_combinations.generate(
@@ -29,7 +29,7 @@ from tensorflow.python.eager import context
 from tensorflow.python.framework import errors
 from tensorflow.python.keras import backend as K
 from tensorflow.python.keras import callbacks as cbks
-from tensorflow.python.keras.distribute import distributed_training_utils
+from tensorflow.python.keras.distribute import distributed_training_utils_v1
 from tensorflow.python.keras.engine import training_utils
 from tensorflow.python.keras.utils.generic_utils import make_batches
 from tensorflow.python.keras.utils.generic_utils import slice_arrays
@@ -145,7 +145,7 @@ def model_iteration(model,
 
   # Enter tf.distribute.Strategy scope.
   if model._distribution_strategy:
-    scope = distributed_training_utils.distributed_scope(
+    scope = distributed_training_utils_v1.distributed_scope(
         strategy=model._distribution_strategy,
         learning_phase=(1 if mode == ModeKeys.TRAIN else 0))
     scope.__enter__()
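
The scope obtained here is used as a manual context manager: entered before the
batch loop and exited near the end of model_iteration (see the scope.__exit__
hunk further down). The generic pattern, with a stand-in class for illustration:

    class _Scope(object):
      """Stand-in for the object returned by distributed_scope (illustrative)."""

      def __enter__(self):
        print('enter strategy scope / set learning phase')
        return self

      def __exit__(self, exc_type, exc_value, traceback):
        print('restore previous scope / learning phase')

    scope = _Scope()
    scope.__enter__()                  # entered explicitly, as in the hunk above
    # ... the batch loop runs here ...
    scope.__exit__(None, None, None)   # exited explicitly once the loop finishes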
@@ -251,7 +251,8 @@ def model_iteration(model,
       steps=steps_per_epoch)
 
   if model._compile_distribution:
-    distributed_training_utils._copy_weights_to_distributed_model(model, mode)
+    distributed_training_utils_v1._copy_weights_to_distributed_model(
+        model, mode)
 
   callbacks.model.stop_training = False
   callbacks._call_begin_hook(mode)
@@ -288,9 +289,9 @@ def model_iteration(model,
       # Get outputs.
       try:
         # `ins` can be callable in tf.distribute.Strategy + eager case.
-        if not callable(ins) or (
-            model._distribution_strategy and
-            not distributed_training_utils.is_distributing_by_cloning(model)):
+        if not callable(ins) or (model._distribution_strategy and
+                                 not distributed_training_utils_v1
+                                 .is_distributing_by_cloning(model)):
           actual_inputs = ins
         else:
           actual_inputs = ins()
@@ -329,8 +330,9 @@ def model_iteration(model,
         batch_outs = [batch_outs]
 
       if model._distribution_strategy:
-        batch_outs = distributed_training_utils._per_replica_aggregate_batch(
-            model._distribution_strategy, batch_outs, model, mode)
+        batch_outs = (
+            distributed_training_utils_v1._per_replica_aggregate_batch(
+                model._distribution_strategy, batch_outs, model, mode))
 
       # Aggregate results.
       if step == 0:
@@ -413,7 +415,7 @@ def model_iteration(model,
       if model._compile_distribution:
         # Since we create a new clone from the original model we need to copy
         # the weights back to the original model before we can run validation.
-        distributed_training_utils._copy_weights_to_original_model(
+        distributed_training_utils_v1._copy_weights_to_original_model(
             model, ModeKeys.TRAIN)
 
       val_results = model_iteration(
@@ -450,7 +452,7 @@ def model_iteration(model,
   if model._distribution_strategy:
     if model._compile_distribution:
       # TODO(priyag, psv): Copy back metrics to the original model as well?
-      distributed_training_utils._copy_weights_to_original_model(model, mode)
+      distributed_training_utils_v1._copy_weights_to_original_model(model, mode)
     scope.__exit__(None, None, None)
 
   if mode == ModeKeys.TRAIN:
@@ -500,11 +502,11 @@ def _prepare_feed_values(model, inputs, targets, sample_weights, mode):
   """
   if model._distribution_strategy:
     if isinstance(inputs, (dataset_ops.DatasetV1, dataset_ops.DatasetV2)):
-      inputs = distributed_training_utils.get_iterator(
+      inputs = distributed_training_utils_v1.get_iterator(
          inputs, model._distribution_strategy)
 
     def get_distributed_inputs():
-      return distributed_training_utils._prepare_feed_values(
+      return distributed_training_utils_v1._prepare_feed_values(
          model, inputs, targets, sample_weights, mode)
 
     # In the eager case, we want to call the input method per step, so return
@@ -537,14 +539,14 @@ def _prepare_feed_values(model, inputs, targets, sample_weights, mode):
 
 def _get_iterator(inputs, distribution_strategy=None):
   if distribution_strategy:
-    return distributed_training_utils.get_iterator(
+    return distributed_training_utils_v1.get_iterator(
        inputs, distribution_strategy)
   return training_utils.get_iterator(inputs)
 
 
 def _reinitialize_iterator(iterator, distribution_strategy=None):
   if distribution_strategy:
-    distributed_training_utils.initialize_iterator(
+    distributed_training_utils_v1.initialize_iterator(
        iterator, distribution_strategy)
   else:
     training_utils.initialize_iterator(iterator)
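
Both helpers above share one dispatch rule: use the distribute-aware v1
utilities only when a strategy is passed, otherwise fall back to the plain
training_utils path. A tiny self-contained sketch of that rule with stand-in
functions (not the TensorFlow ones):

    def _make_iterator_sketch(inputs, strategy=None):
      if strategy:
        return ('distributed iterator', inputs, strategy)  # v1 helper path
      return ('plain iterator', inputs)                    # training_utils path

    assert _make_iterator_sketch('ds')[0] == 'plain iterator'
    assert _make_iterator_sketch('ds', strategy='mirrored')[0] == 'distributed iterator'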
@@ -553,7 +555,7 @@ def _reinitialize_iterator(iterator, distribution_strategy=None):
 def _make_execution_function(model, mode):
   """Makes function to run one step of model execution."""
   if model._distribution_strategy:
-    return distributed_training_utils._make_execution_function(model, mode)
+    return distributed_training_utils_v1._make_execution_function(model, mode)
   return model._make_execution_function(mode)
 
 
@@ -580,8 +582,8 @@ def _update_sample_weight_mode(model, mode, inputs):
   # Call the DistributionStrategy specific function to update the
   # sample_weight_mode on the model.
   if model._distribution_strategy:
-    distributed_training_utils._update_sample_weight_modes(model, mode,
-                                                           sample_weights)
+    distributed_training_utils_v1._update_sample_weight_modes(model, mode,
+                                                              sample_weights)
 
 # For backwards compatibility for internal users of these loops.
 fit_loop = functools.partial(model_iteration, mode=ModeKeys.TRAIN)
@@ -31,7 +31,8 @@ from tensorflow.python.framework import errors
 from tensorflow.python.framework import ops
 from tensorflow.python.keras import backend as K
 from tensorflow.python.keras import callbacks as cbks
-from tensorflow.python.keras.distribute import distributed_training_utils as dist_utils
+from tensorflow.python.keras.distribute import distributed_training_utils as dist_utils_v2
+from tensorflow.python.keras.distribute import distributed_training_utils_v1 as dist_utils
 from tensorflow.python.keras.engine import partial_batch_padding_handler as padding_util
 from tensorflow.python.keras.engine import training_arrays_v1
 from tensorflow.python.keras.engine import training_utils
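
Note on the aliasing above: after this hunk, dist_utils points at the v1-only
module while dist_utils_v2 points at the shared module, so existing dist_utils.*
call sites in this file keep resolving to the v1 helpers, and only the
cross-version helpers (such as is_tpu_strategy in the hunks below) switch to
dist_utils_v2:

    from tensorflow.python.keras.distribute import distributed_training_utils as dist_utils_v2
    from tensorflow.python.keras.distribute import distributed_training_utils_v1 as dist_utils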
@@ -648,7 +649,7 @@ class DistributionSingleWorkerTrainingLoop(training_utils.TrainingLoop):
       raise ValueError('validation_split argument is not supported with '
                        'distribution strategies.')
 
-    if dist_utils.is_tpu_strategy(model._distribution_strategy):
+    if dist_utils_v2.is_tpu_strategy(model._distribution_strategy):
       steps_per_epoch = training_utils.infer_steps_for_dataset(
           model, dataset, steps_per_epoch, epochs, steps_name='steps_per_epoch')
       if steps_per_epoch is None:
@@ -705,7 +706,7 @@ class DistributionSingleWorkerTrainingLoop(training_utils.TrainingLoop):
         batch_size=batch_size,
         allow_partial_batch=True)
 
-    if dist_utils.is_tpu_strategy(model._distribution_strategy):
+    if dist_utils_v2.is_tpu_strategy(model._distribution_strategy):
       steps = training_utils.infer_steps_for_dataset(
           model, dataset, steps, steps_name='steps')
       if steps is None:
@@ -742,7 +743,7 @@ class DistributionSingleWorkerTrainingLoop(training_utils.TrainingLoop):
         x,
         batch_size=batch_size,
         allow_partial_batch=True)
-    if dist_utils.is_tpu_strategy(model._distribution_strategy):
+    if dist_utils_v2.is_tpu_strategy(model._distribution_strategy):
       steps = training_utils.infer_steps_for_dataset(
           model, dataset, steps, steps_name='steps')
       if steps is None:
@@ -42,6 +42,7 @@ from tensorflow.python.keras import losses
 from tensorflow.python.keras import metrics as metrics_module
 from tensorflow.python.keras import optimizers
 from tensorflow.python.keras.distribute import distributed_training_utils
+from tensorflow.python.keras.distribute import distributed_training_utils_v1
 from tensorflow.python.keras.engine import base_layer
 from tensorflow.python.keras.engine import training as training_lib
 from tensorflow.python.keras.engine import training_arrays_v1
@@ -993,7 +994,7 @@ class Model(training_lib.Model):
 
     # Reset metrics on all the distributed (cloned) models.
     if self._distribution_strategy:
-      distributed_training_utils._reset_metrics(self)  # pylint: disable=protected-access
+      distributed_training_utils_v1._reset_metrics(self)  # pylint: disable=protected-access
 
   def train_on_batch(self,
                      x,
@@ -1398,7 +1399,7 @@ class Model(training_lib.Model):
           'We currently do not support enabling `run_eagerly` with '
           'distribution strategy.')
 
-    if (distributed_training_utils.is_distributing_by_cloning(self) and
+    if (distributed_training_utils_v1.is_distributing_by_cloning(self) and
         (not self.built or not self.inputs or not self.outputs)):
       raise ValueError(
           'We currently do not support distribution strategy with a '
@@ -2856,7 +2857,7 @@ class DistributedCallbackModel(Model):
     self._original_model.load_weights(filepath, by_name=False)
     # Copy the weights from the original model to each of the replicated models.
     orig_model_weights = self._original_model.get_weights()
-    distributed_training_utils.set_weights(
+    distributed_training_utils_v1.set_weights(
         self._original_model._distribution_strategy, self,  # pylint: disable=protected-access
         orig_model_weights)
 