move v1-specific distributed_training_utils to distributed_training_utils_v1.py

PiperOrigin-RevId: 333120339
Change-Id: I52ee3d95d95165a7f56fcc43f791011ef7fae154
Tomer Kaftan 2020-09-22 11:16:57 -07:00 committed by TensorFlower Gardener
parent ef2c14d030
commit a151f928f3
9 changed files with 1211 additions and 1182 deletions

View File

@@ -29,6 +29,7 @@ py_library(
srcs = [
"__init__.py",
"distributed_training_utils.py",
"distributed_training_utils_v1.py",
],
srcs_version = "PY2AND3",
deps = [

View File

@ -40,6 +40,7 @@ from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import test_combinations as combinations
from tensorflow.python.keras import testing_utils
from tensorflow.python.keras.distribute import distributed_training_utils
from tensorflow.python.keras.distribute import distributed_training_utils_v1
from tensorflow.python.keras.distribute import optimizer_combinations
from tensorflow.python.keras.engine import base_layer_utils
from tensorflow.python.keras.mixed_precision.experimental import policy
@@ -363,13 +364,13 @@ class TestDistributionStrategyWithNumpyArrays(test.TestCase,
with self.cached_session():
# Default global batch size 32 for input with 64 samples run in 2 steps
steps, batch_size = distributed_training_utils.get_input_params(
steps, batch_size = distributed_training_utils_v1.get_input_params(
distribution, 64, steps=None, batch_size=None)
self.assertEqual(batch_size, 32 // replica_scale_factor)
self.assertEqual(steps, 2)
# Computed global batch size 20 is lower than 32 if we pass fewer samples.
steps, batch_size = distributed_training_utils.get_input_params(
steps, batch_size = distributed_training_utils_v1.get_input_params(
distribution, 20, steps=None, batch_size=None)
self.assertEqual(batch_size, 20 // replica_scale_factor)
self.assertEqual(steps, 1)
@@ -385,27 +386,27 @@ class TestDistributionStrategyWithNumpyArrays(test.TestCase,
with self.cached_session():
# Computed global batch size is correct when 1 step is specified
steps, batch_size = distributed_training_utils.get_input_params(
steps, batch_size = distributed_training_utils_v1.get_input_params(
distribution, 64, steps=1, batch_size=None)
self.assertEqual(batch_size, 64 // replica_scale_factor)
self.assertEqual(steps, 1)
# Computed global batch size is correct when 2 steps are specified
steps, batch_size = distributed_training_utils.get_input_params(
steps, batch_size = distributed_training_utils_v1.get_input_params(
distribution, 64, steps=2, batch_size=None)
self.assertEqual(batch_size, 32 // replica_scale_factor)
self.assertEqual(steps, 2)
# All samples cannot be consumed in the specified number of steps
with self.assertRaisesRegex(ValueError, 'not divisible by steps'):
distributed_training_utils.get_input_params(
distributed_training_utils_v1.get_input_params(
distribution, 63, steps=2, batch_size=None)
# This case differs between strategies because the supported batch size
# may be global or per-replica.
if replica_scale_factor == 1:
# Computed global batch size is correct even if not shardable
steps, batch_size = distributed_training_utils.get_input_params(
steps, batch_size = distributed_training_utils_v1.get_input_params(
distribution, 63, steps=3, batch_size=None)
self.assertEqual(batch_size, 21)
self.assertEqual(steps, 3)
@@ -414,7 +415,7 @@ class TestDistributionStrategyWithNumpyArrays(test.TestCase,
with self.assertRaisesRegex(
ValueError, 'could not be sharded evenly '
'across the sync replicas'):
distributed_training_utils.get_input_params(
distributed_training_utils_v1.get_input_params(
distribution, 63, steps=1, batch_size=None)
@ds_combinations.generate(all_strategy_combinations())
@@ -428,13 +429,13 @@ class TestDistributionStrategyWithNumpyArrays(test.TestCase,
with self.cached_session():
# Computed steps are correct for the specified batch size
steps, batch_size = distributed_training_utils.get_input_params(
steps, batch_size = distributed_training_utils_v1.get_input_params(
distribution, 64, steps=None, batch_size=16)
self.assertEqual(batch_size, 16)
self.assertEqual(steps, 4 // replica_scale_factor)
# Computed steps are correct for the specified batch size
steps, batch_size = distributed_training_utils.get_input_params(
steps, batch_size = distributed_training_utils_v1.get_input_params(
distribution, 64, steps=None, batch_size=32)
self.assertEqual(batch_size, 32)
self.assertEqual(steps, 2 // replica_scale_factor)
@@ -444,14 +445,14 @@ class TestDistributionStrategyWithNumpyArrays(test.TestCase,
self, distribution):
with self.cached_session():
# No change to steps and batch size if both specified and feasible
steps, batch_size = distributed_training_utils.get_input_params(
steps, batch_size = distributed_training_utils_v1.get_input_params(
distribution, 64, steps=5, batch_size=3)
self.assertEqual(batch_size, 3)
self.assertEqual(steps, 5)
# Number of samples is less than global batch size * steps
with self.assertRaisesRegex(ValueError, 'less than samples required'):
distributed_training_utils.get_input_params(
distributed_training_utils_v1.get_input_params(
distribution, 64, steps=10, batch_size=13)
@ds_combinations.generate(all_strategy_combinations())
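For readers following these test changes: the assertions above reduce to a simple division rule over (num_samples, steps, batch_size). Below is a minimal, hypothetical sketch of that rule; it ignores the replica scale factor and the global-versus-per-replica batch handling that the real distributed_training_utils_v1.get_input_params performs, and the helper name is invented for illustration.

def infer_steps_and_batch_size(num_samples, steps=None, batch_size=None):
    # Hypothetical, simplified stand-in for get_input_params (strategy handling omitted).
    if steps is None and batch_size is None:
        batch_size = min(num_samples, 32)   # default global batch size of 32, capped at num_samples
        steps = num_samples // batch_size
    elif batch_size is None:
        if num_samples % steps:
            raise ValueError('The number of samples is not divisible by steps.')
        batch_size = num_samples // steps   # derive batch size from the requested steps
    elif steps is None:
        steps = num_samples // batch_size   # derive steps from the requested batch size
    elif steps * batch_size > num_samples:
        raise ValueError('Number of samples is less than samples required.')
    return steps, batch_size

assert infer_steps_and_batch_size(64) == (2, 32)
assert infer_steps_and_batch_size(64, steps=2) == (2, 32)
assert infer_steps_and_batch_size(64, batch_size=16) == (4, 16)
assert infer_steps_and_batch_size(64, steps=5, batch_size=3) == (5, 3)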

View File

@@ -19,7 +19,7 @@ from __future__ import division
from __future__ import print_function
from tensorflow.python.keras import callbacks
from tensorflow.python.keras.distribute import distributed_training_utils
from tensorflow.python.keras.distribute import distributed_training_utils_v1
from tensorflow.python.keras.optimizer_v2 import adam
from tensorflow.python.platform import test
from tensorflow.python.training import adam as v1_adam
@@ -39,7 +39,7 @@ class DistributedTrainingUtilsTest(test.TestCase):
callbacks.RemoteMonitor()
]
distributed_training_utils.validate_callbacks(
distributed_training_utils_v1.validate_callbacks(
supported_predefined_callbacks, adam.Adam())
unsupported_predefined_callbacks = [
@@ -50,8 +50,8 @@ class DistributedTrainingUtilsTest(test.TestCase):
for callback in unsupported_predefined_callbacks:
with self.assertRaisesRegex(ValueError,
'You must specify a Keras Optimizer V2'):
distributed_training_utils.validate_callbacks([callback],
v1_adam.AdamOptimizer())
distributed_training_utils_v1.validate_callbacks(
[callback], v1_adam.AdamOptimizer())
if __name__ == '__main__':
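A short usage sketch of the relocated helper, assuming only what the test above shows: a built-in callback paired with a Keras V2 optimizer passes validation, while pairing with a V1 optimizer raises the 'You must specify a Keras Optimizer V2' error for callbacks that need to inspect optimizer state.

from tensorflow.python.keras import callbacks
from tensorflow.python.keras.distribute import distributed_training_utils_v1
from tensorflow.python.keras.optimizer_v2 import adam

# Passes: a supported predefined callback together with a Keras V2 optimizer.
distributed_training_utils_v1.validate_callbacks(
    [callbacks.RemoteMonitor()], adam.Adam())
# The same call with a V1 optimizer and one of the unsupported predefined
# callbacks raises ValueError, as the test above asserts.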

File diff suppressed because it is too large

View File

@@ -35,7 +35,7 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import test_combinations as combinations
from tensorflow.python.keras import losses
from tensorflow.python.keras.distribute import distribute_strategy_test as keras_test_lib
from tensorflow.python.keras.distribute import distributed_training_utils
from tensorflow.python.keras.distribute import distributed_training_utils_v1
from tensorflow.python.keras.distribute import optimizer_combinations
from tensorflow.python.platform import test
from tensorflow.python.training import gradient_descent
@@ -203,7 +203,7 @@ class TestDistributionStrategyErrorCases(test.TestCase, parameterized.TestCase):
'distributed tensor inputs '
'DistributedValues:.+'):
with distribution.scope():
distributed_training_utils.validate_distributed_dataset_inputs(
distributed_training_utils_v1.validate_distributed_dataset_inputs(
distribution, x, y)
@ds_combinations.generate(
@@ -227,7 +227,7 @@ class TestDistributionStrategyErrorCases(test.TestCase, parameterized.TestCase):
'distributed tensor inputs '
'DistributedValues:.+'):
with distribution.scope():
distributed_training_utils.validate_distributed_dataset_inputs(
distributed_training_utils_v1.validate_distributed_dataset_inputs(
distribution, x, y)
@ds_combinations.generate(

View File

@@ -29,7 +29,7 @@ from tensorflow.python.eager import context
from tensorflow.python.framework import errors
from tensorflow.python.keras import backend as K
from tensorflow.python.keras import callbacks as cbks
from tensorflow.python.keras.distribute import distributed_training_utils
from tensorflow.python.keras.distribute import distributed_training_utils_v1
from tensorflow.python.keras.engine import training_utils
from tensorflow.python.keras.utils.generic_utils import make_batches
from tensorflow.python.keras.utils.generic_utils import slice_arrays
@@ -145,7 +145,7 @@ def model_iteration(model,
# Enter tf.distribute.Strategy scope.
if model._distribution_strategy:
scope = distributed_training_utils.distributed_scope(
scope = distributed_training_utils_v1.distributed_scope(
strategy=model._distribution_strategy,
learning_phase=(1 if mode == ModeKeys.TRAIN else 0))
scope.__enter__()
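The distribution scope is entered manually here and only exited near the end of model_iteration (scope.__exit__(None, None, None) further down), since a with-block would have to wrap most of the function body. A minimal sketch of that idiom, with a hypothetical scope object and illustrative function names:

import contextlib

@contextlib.contextmanager
def distributed_scope():
    # Hypothetical stand-in for the strategy/learning-phase scope used here.
    print('enter scope')
    try:
        yield
    finally:
        print('exit scope')

def loop_body(use_strategy):
    scope = None
    if use_strategy:
        scope = distributed_scope()
        scope.__enter__()            # enter without a with-block
    # ... many training/eval steps run inside the scope ...
    if scope is not None:
        scope.__exit__(None, None, None)   # matching exit at the end of the loop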
@@ -251,7 +251,8 @@ def model_iteration(model,
steps=steps_per_epoch)
if model._compile_distribution:
distributed_training_utils._copy_weights_to_distributed_model(model, mode)
distributed_training_utils_v1._copy_weights_to_distributed_model(
model, mode)
callbacks.model.stop_training = False
callbacks._call_begin_hook(mode)
@@ -288,9 +289,9 @@ def model_iteration(model,
# Get outputs.
try:
# `ins` can be callable in tf.distribute.Strategy + eager case.
if not callable(ins) or (
model._distribution_strategy and
not distributed_training_utils.is_distributing_by_cloning(model)):
if not callable(ins) or (model._distribution_strategy and
not distributed_training_utils_v1
.is_distributing_by_cloning(model)):
actual_inputs = ins
else:
actual_inputs = ins()
@@ -329,8 +330,9 @@ def model_iteration(model,
batch_outs = [batch_outs]
if model._distribution_strategy:
batch_outs = distributed_training_utils._per_replica_aggregate_batch(
model._distribution_strategy, batch_outs, model, mode)
batch_outs = (
distributed_training_utils_v1._per_replica_aggregate_batch(
model._distribution_strategy, batch_outs, model, mode))
# Aggregate results.
if step == 0:
@@ -413,7 +415,7 @@ def model_iteration(model,
if model._compile_distribution:
# Since we create a new clone from the original model we need to copy
# the weights back to the original model before we can run validation.
distributed_training_utils._copy_weights_to_original_model(
distributed_training_utils_v1._copy_weights_to_original_model(
model, ModeKeys.TRAIN)
val_results = model_iteration(
@@ -450,7 +452,7 @@ def model_iteration(model,
if model._distribution_strategy:
if model._compile_distribution:
# TODO(priyag, psv): Copy back metrics to the original model as well?
distributed_training_utils._copy_weights_to_original_model(model, mode)
distributed_training_utils_v1._copy_weights_to_original_model(model, mode)
scope.__exit__(None, None, None)
if mode == ModeKeys.TRAIN:
@@ -500,11 +502,11 @@ def _prepare_feed_values(model, inputs, targets, sample_weights, mode):
"""
if model._distribution_strategy:
if isinstance(inputs, (dataset_ops.DatasetV1, dataset_ops.DatasetV2)):
inputs = distributed_training_utils.get_iterator(
inputs = distributed_training_utils_v1.get_iterator(
inputs, model._distribution_strategy)
def get_distributed_inputs():
return distributed_training_utils._prepare_feed_values(
return distributed_training_utils_v1._prepare_feed_values(
model, inputs, targets, sample_weights, mode)
# In the eager case, we want to call the input method per step, so return
@@ -537,14 +539,14 @@ def _prepare_feed_values(model, inputs, targets, sample_weights, mode):
def _get_iterator(inputs, distribution_strategy=None):
if distribution_strategy:
return distributed_training_utils.get_iterator(
return distributed_training_utils_v1.get_iterator(
inputs, distribution_strategy)
return training_utils.get_iterator(inputs)
def _reinitialize_iterator(iterator, distribution_strategy=None):
if distribution_strategy:
distributed_training_utils.initialize_iterator(
distributed_training_utils_v1.initialize_iterator(
iterator, distribution_strategy)
else:
training_utils.initialize_iterator(iterator)
@@ -553,7 +555,7 @@ def _reinitialize_iterator(iterator, distribution_strategy=None):
def _make_execution_function(model, mode):
"""Makes function to run one step of model execution."""
if model._distribution_strategy:
return distributed_training_utils._make_execution_function(model, mode)
return distributed_training_utils_v1._make_execution_function(model, mode)
return model._make_execution_function(mode)
@@ -580,8 +582,8 @@ def _update_sample_weight_mode(model, mode, inputs):
# Call the DistributionStrategy specific function to update the
# sample_weight_mode on the model.
if model._distribution_strategy:
distributed_training_utils._update_sample_weight_modes(model, mode,
sample_weights)
distributed_training_utils_v1._update_sample_weight_modes(model, mode,
sample_weights)
# For backwards compatibility for internal users of these loops.
fit_loop = functools.partial(model_iteration, mode=ModeKeys.TRAIN)
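The fit_loop alias above is plain partial application. A self-contained sketch of the same backwards-compatibility pattern follows; the mode strings and the model_iteration body are illustrative stand-ins, not the real ModeKeys constants or loop.

import functools

def model_iteration(model, inputs, mode='train'):
    # Stand-in for the shared loop that fit/evaluate/predict all funnel through.
    return '%s loop over %s' % (mode, model)

# Mode-specialized aliases kept for internal callers, mirroring fit_loop above.
fit_loop = functools.partial(model_iteration, mode='train')
test_loop = functools.partial(model_iteration, mode='test')

print(fit_loop('my_model', inputs=[1, 2, 3]))   # -> 'train loop over my_model'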

View File

@@ -31,7 +31,8 @@ from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.keras import backend as K
from tensorflow.python.keras import callbacks as cbks
from tensorflow.python.keras.distribute import distributed_training_utils as dist_utils
from tensorflow.python.keras.distribute import distributed_training_utils as dist_utils_v2
from tensorflow.python.keras.distribute import distributed_training_utils_v1 as dist_utils
from tensorflow.python.keras.engine import partial_batch_padding_handler as padding_util
from tensorflow.python.keras.engine import training_arrays_v1
from tensorflow.python.keras.engine import training_utils
@@ -648,7 +649,7 @@ class DistributionSingleWorkerTrainingLoop(training_utils.TrainingLoop):
raise ValueError('validation_split argument is not supported with '
'distribution strategies.')
if dist_utils.is_tpu_strategy(model._distribution_strategy):
if dist_utils_v2.is_tpu_strategy(model._distribution_strategy):
steps_per_epoch = training_utils.infer_steps_for_dataset(
model, dataset, steps_per_epoch, epochs, steps_name='steps_per_epoch')
if steps_per_epoch is None:
@@ -705,7 +706,7 @@ class DistributionSingleWorkerTrainingLoop(training_utils.TrainingLoop):
batch_size=batch_size,
allow_partial_batch=True)
if dist_utils.is_tpu_strategy(model._distribution_strategy):
if dist_utils_v2.is_tpu_strategy(model._distribution_strategy):
steps = training_utils.infer_steps_for_dataset(
model, dataset, steps, steps_name='steps')
if steps is None:
@@ -742,7 +743,7 @@ class DistributionSingleWorkerTrainingLoop(training_utils.TrainingLoop):
x,
batch_size=batch_size,
allow_partial_batch=True)
if dist_utils.is_tpu_strategy(model._distribution_strategy):
if dist_utils_v2.is_tpu_strategy(model._distribution_strategy):
steps = training_utils.infer_steps_for_dataset(
model, dataset, steps, steps_name='steps')
if steps is None:
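One point worth noting in this file: the old dist_utils alias now resolves to the v1-only module, while the new dist_utils_v2 alias keeps access to version-agnostic helpers such as is_tpu_strategy. A compressed sketch of how call sites split across the two aliases (uses_tpu is an invented wrapper for illustration):

from tensorflow.python.keras.distribute import distributed_training_utils as dist_utils_v2
from tensorflow.python.keras.distribute import distributed_training_utils_v1 as dist_utils

def uses_tpu(model):
    # Version-agnostic strategy inspection stays in the shared module.
    return dist_utils_v2.is_tpu_strategy(model._distribution_strategy)

# v1-only helpers (iterators, feed values, weight copying, ...) now come from
# dist_utils, i.e. distributed_training_utils_v1.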

View File

@@ -42,6 +42,7 @@ from tensorflow.python.keras import losses
from tensorflow.python.keras import metrics as metrics_module
from tensorflow.python.keras import optimizers
from tensorflow.python.keras.distribute import distributed_training_utils
from tensorflow.python.keras.distribute import distributed_training_utils_v1
from tensorflow.python.keras.engine import base_layer
from tensorflow.python.keras.engine import training as training_lib
from tensorflow.python.keras.engine import training_arrays_v1
@@ -993,7 +994,7 @@ class Model(training_lib.Model):
# Reset metrics on all the distributed (cloned) models.
if self._distribution_strategy:
distributed_training_utils._reset_metrics(self) # pylint: disable=protected-access
distributed_training_utils_v1._reset_metrics(self) # pylint: disable=protected-access
def train_on_batch(self,
x,
@@ -1398,7 +1399,7 @@ class Model(training_lib.Model):
'We currently do not support enabling `run_eagerly` with '
'distribution strategy.')
if (distributed_training_utils.is_distributing_by_cloning(self) and
if (distributed_training_utils_v1.is_distributing_by_cloning(self) and
(not self.built or not self.inputs or not self.outputs)):
raise ValueError(
'We currently do not support distribution strategy with a '
@@ -2856,7 +2857,7 @@ class DistributedCallbackModel(Model):
self._original_model.load_weights(filepath, by_name=False)
# Copy the weights from the original model to each of the replicated models.
orig_model_weights = self._original_model.get_weights()
distributed_training_utils.set_weights(
distributed_training_utils_v1.set_weights(
self._original_model._distribution_strategy, self, # pylint: disable=protected-access
orig_model_weights)
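The weight copy above has a simple shape once the strategy plumbing is stripped away: read the weights off the original model and push them into the replicated one. A minimal sketch using plain Keras get_weights/set_weights rather than the strategy-aware distributed_training_utils_v1.set_weights:

import numpy as np
from tensorflow import keras

def build():
    return keras.Sequential([keras.layers.Dense(4, input_shape=(8,)),
                             keras.layers.Dense(1)])

original, replica = build(), build()
replica.set_weights(original.get_weights())   # copy original -> replica
assert all(np.array_equal(a, b)
           for a, b in zip(original.get_weights(), replica.get_weights()))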