move v1-specific distributed_training_utils to distributed_training_utils_v1.py
PiperOrigin-RevId: 333120339 Change-Id: I52ee3d95d95165a7f56fcc43f791011ef7fae154
Parent: ef2c14d030
Commit: a151f928f3
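The hunks below update the Keras BUILD rule, the distribution tests, and the V1 training loops to pull the V1-only helpers from the new module. For downstream code the change amounts to an import rename; a minimal sketch of the migration follows (the `distribution` object is an assumed tf.distribute strategy, and `get_input_params` is just one of the moved helpers shown in this diff):

# Before: V1-only helpers lived in the shared module.
# from tensorflow.python.keras.distribute import distributed_training_utils
# steps, batch_size = distributed_training_utils.get_input_params(
#     distribution, 64, steps=None, batch_size=None)

# After: the same helpers are imported from the new _v1 module.
from tensorflow.python.keras.distribute import distributed_training_utils_v1

steps, batch_size = distributed_training_utils_v1.get_input_params(
    distribution, 64, steps=None, batch_size=None)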
@@ -29,6 +29,7 @@ py_library(
     srcs = [
         "__init__.py",
         "distributed_training_utils.py",
+        "distributed_training_utils_v1.py",
     ],
     srcs_version = "PY2AND3",
     deps = [
@@ -40,6 +40,7 @@ from tensorflow.python.framework import sparse_tensor
 from tensorflow.python.framework import test_combinations as combinations
 from tensorflow.python.keras import testing_utils
 from tensorflow.python.keras.distribute import distributed_training_utils
+from tensorflow.python.keras.distribute import distributed_training_utils_v1
 from tensorflow.python.keras.distribute import optimizer_combinations
 from tensorflow.python.keras.engine import base_layer_utils
 from tensorflow.python.keras.mixed_precision.experimental import policy
@@ -363,13 +364,13 @@ class TestDistributionStrategyWithNumpyArrays(test.TestCase,

     with self.cached_session():
       # Default global batch size 32 for input with 64 samples run in 2 steps
-      steps, batch_size = distributed_training_utils.get_input_params(
+      steps, batch_size = distributed_training_utils_v1.get_input_params(
           distribution, 64, steps=None, batch_size=None)
       self.assertEqual(batch_size, 32 // replica_scale_factor)
       self.assertEqual(steps, 2)

       # Computed global batch size 20 is lower than 32 if we pass less samples.
-      steps, batch_size = distributed_training_utils.get_input_params(
+      steps, batch_size = distributed_training_utils_v1.get_input_params(
           distribution, 20, steps=None, batch_size=None)
       self.assertEqual(batch_size, 20 // replica_scale_factor)
       self.assertEqual(steps, 1)
@@ -385,27 +386,27 @@ class TestDistributionStrategyWithNumpyArrays(test.TestCase,

     with self.cached_session():
       # Computed global batch size is correct for number of specified 1 step
-      steps, batch_size = distributed_training_utils.get_input_params(
+      steps, batch_size = distributed_training_utils_v1.get_input_params(
           distribution, 64, steps=1, batch_size=None)
       self.assertEqual(batch_size, 64 // replica_scale_factor)
       self.assertEqual(steps, 1)

       # Computed global batch size is correct for number of specified 2 steps
-      steps, batch_size = distributed_training_utils.get_input_params(
+      steps, batch_size = distributed_training_utils_v1.get_input_params(
           distribution, 64, steps=2, batch_size=None)
       self.assertEqual(batch_size, 32 // replica_scale_factor)
       self.assertEqual(steps, 2)

       # All samples can not be consumed in specified number of steps
       with self.assertRaisesRegex(ValueError, 'not divisible by steps'):
-        distributed_training_utils.get_input_params(
+        distributed_training_utils_v1.get_input_params(
             distribution, 63, steps=2, batch_size=None)

       # This cases is different for different strategies due to the
       # difference in supported batch size being global or per-replica.
       if replica_scale_factor == 1:
         # Computed global batch size is correct even if not sharadable
-        steps, batch_size = distributed_training_utils.get_input_params(
+        steps, batch_size = distributed_training_utils_v1.get_input_params(
             distribution, 63, steps=3, batch_size=None)
         self.assertEqual(batch_size, 21)
         self.assertEqual(steps, 3)
@@ -414,7 +415,7 @@ class TestDistributionStrategyWithNumpyArrays(test.TestCase,
         with self.assertRaisesRegex(
             ValueError, 'could not be sharded evenly '
             'across the sync replicas'):
-          distributed_training_utils.get_input_params(
+          distributed_training_utils_v1.get_input_params(
               distribution, 63, steps=1, batch_size=None)

   @ds_combinations.generate(all_strategy_combinations())
@@ -428,13 +429,13 @@ class TestDistributionStrategyWithNumpyArrays(test.TestCase,

     with self.cached_session():
       # Computed steps is correct for specified batch size
-      steps, batch_size = distributed_training_utils.get_input_params(
+      steps, batch_size = distributed_training_utils_v1.get_input_params(
           distribution, 64, steps=None, batch_size=16)
       self.assertEqual(batch_size, 16)
       self.assertEqual(steps, 4 // replica_scale_factor)

       # Computed steps is correct for specified batch size
-      steps, batch_size = distributed_training_utils.get_input_params(
+      steps, batch_size = distributed_training_utils_v1.get_input_params(
           distribution, 64, steps=None, batch_size=32)
       self.assertEqual(batch_size, 32)
       self.assertEqual(steps, 2 // replica_scale_factor)
@@ -444,14 +445,14 @@ class TestDistributionStrategyWithNumpyArrays(test.TestCase,
       self, distribution):
     with self.cached_session():
       # No change to steps and batch size if both specified and feasible
-      steps, batch_size = distributed_training_utils.get_input_params(
+      steps, batch_size = distributed_training_utils_v1.get_input_params(
           distribution, 64, steps=5, batch_size=3)
       self.assertEqual(batch_size, 3)
       self.assertEqual(steps, 5)

       # Number of samples is less than global batch size * steps
       with self.assertRaisesRegex(ValueError, 'less than samples required'):
-        distributed_training_utils.get_input_params(
+        distributed_training_utils_v1.get_input_params(
             distribution, 64, steps=10, batch_size=13)

   @ds_combinations.generate(all_strategy_combinations())
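Taken together, the assertions above document the contract of `get_input_params`: given a strategy, a sample count, and optionally `steps` and/or `batch_size`, it fills in whichever value is missing and raises when the samples cannot be split as requested. A rough single-replica illustration of the behavior these tests assert (a reader's sketch, not the library's implementation):

def infer_steps_and_batch_size(num_samples, steps=None, batch_size=None,
                               default_global_batch_size=32):
  """Illustrative sketch of what the tests above assert, for one replica."""
  if steps is None and batch_size is None:
    batch_size = min(num_samples, default_global_batch_size)
    steps = num_samples // batch_size          # 64 -> (2, 32); 20 -> (1, 20)
  elif batch_size is None:
    if num_samples % steps:
      raise ValueError('Number of samples is not divisible by steps')
    batch_size = num_samples // steps          # 64 with steps=2 -> batch_size 32
  elif steps is None:
    steps = num_samples // batch_size          # 64 with batch_size=16 -> steps 4
  elif num_samples < steps * batch_size:
    raise ValueError('Number of samples is less than samples required')
  return steps, batch_size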
(File diff suppressed because it is too large.)
@@ -19,7 +19,7 @@ from __future__ import division
 from __future__ import print_function

 from tensorflow.python.keras import callbacks
-from tensorflow.python.keras.distribute import distributed_training_utils
+from tensorflow.python.keras.distribute import distributed_training_utils_v1
 from tensorflow.python.keras.optimizer_v2 import adam
 from tensorflow.python.platform import test
 from tensorflow.python.training import adam as v1_adam
@@ -39,7 +39,7 @@ class DistributedTrainingUtilsTest(test.TestCase):
         callbacks.RemoteMonitor()
     ]

-    distributed_training_utils.validate_callbacks(
+    distributed_training_utils_v1.validate_callbacks(
         supported_predefined_callbacks, adam.Adam())

     unsupported_predefined_callbacks = [
@@ -50,8 +50,8 @@ class DistributedTrainingUtilsTest(test.TestCase):
     for callback in unsupported_predefined_callbacks:
       with self.assertRaisesRegex(ValueError,
                                   'You must specify a Keras Optimizer V2'):
-        distributed_training_utils.validate_callbacks([callback],
-                                                      v1_adam.AdamOptimizer())
+        distributed_training_utils_v1.validate_callbacks(
+            [callback], v1_adam.AdamOptimizer())


 if __name__ == '__main__':
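The reworked test keeps the same behavior: `validate_callbacks` accepts predefined callbacks when paired with a Keras Optimizer V2 and raises for a TF1 optimizer. A hedged usage sketch follows; the `ReduceLROnPlateau` callback here is an assumption for illustration, since the diff does not show which callbacks sit in the unsupported list:

from tensorflow.python.keras import callbacks
from tensorflow.python.keras.distribute import distributed_training_utils_v1
from tensorflow.python.keras.optimizer_v2 import adam
from tensorflow.python.training import adam as v1_adam

# Accepted: a predefined callback together with a Keras Optimizer V2.
distributed_training_utils_v1.validate_callbacks(
    [callbacks.RemoteMonitor()], adam.Adam())

# Raises for callbacks that require a V2 optimizer when given a V1 optimizer.
try:
  distributed_training_utils_v1.validate_callbacks(
      [callbacks.ReduceLROnPlateau()], v1_adam.AdamOptimizer())  # illustrative callback
except ValueError as e:
  print(e)  # mentions 'You must specify a Keras Optimizer V2'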
tensorflow/python/keras/distribute/distributed_training_utils_v1.py — new file, 1158 lines (diff suppressed because it is too large).
@@ -35,7 +35,7 @@ from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import test_combinations as combinations
 from tensorflow.python.keras import losses
 from tensorflow.python.keras.distribute import distribute_strategy_test as keras_test_lib
-from tensorflow.python.keras.distribute import distributed_training_utils
+from tensorflow.python.keras.distribute import distributed_training_utils_v1
 from tensorflow.python.keras.distribute import optimizer_combinations
 from tensorflow.python.platform import test
 from tensorflow.python.training import gradient_descent
@@ -203,7 +203,7 @@ class TestDistributionStrategyErrorCases(test.TestCase, parameterized.TestCase):
                                 'distributed tensor inputs '
                                 'DistributedValues:.+'):
       with distribution.scope():
-        distributed_training_utils.validate_distributed_dataset_inputs(
+        distributed_training_utils_v1.validate_distributed_dataset_inputs(
             distribution, x, y)

   @ds_combinations.generate(
@@ -227,7 +227,7 @@ class TestDistributionStrategyErrorCases(test.TestCase, parameterized.TestCase):
                                 'distributed tensor inputs '
                                 'DistributedValues:.+'):
       with distribution.scope():
-        distributed_training_utils.validate_distributed_dataset_inputs(
+        distributed_training_utils_v1.validate_distributed_dataset_inputs(
             distribution, x, y)

   @ds_combinations.generate(
@@ -29,7 +29,7 @@ from tensorflow.python.eager import context
 from tensorflow.python.framework import errors
 from tensorflow.python.keras import backend as K
 from tensorflow.python.keras import callbacks as cbks
-from tensorflow.python.keras.distribute import distributed_training_utils
+from tensorflow.python.keras.distribute import distributed_training_utils_v1
 from tensorflow.python.keras.engine import training_utils
 from tensorflow.python.keras.utils.generic_utils import make_batches
 from tensorflow.python.keras.utils.generic_utils import slice_arrays
@@ -145,7 +145,7 @@ def model_iteration(model,

   # Enter tf.distribute.Strategy scope.
   if model._distribution_strategy:
-    scope = distributed_training_utils.distributed_scope(
+    scope = distributed_training_utils_v1.distributed_scope(
         strategy=model._distribution_strategy,
         learning_phase=(1 if mode == ModeKeys.TRAIN else 0))
     scope.__enter__()
@@ -251,7 +251,8 @@ def model_iteration(model,
         steps=steps_per_epoch)

   if model._compile_distribution:
-    distributed_training_utils._copy_weights_to_distributed_model(model, mode)
+    distributed_training_utils_v1._copy_weights_to_distributed_model(
+        model, mode)

   callbacks.model.stop_training = False
   callbacks._call_begin_hook(mode)
@@ -288,9 +289,9 @@ def model_iteration(model,
       # Get outputs.
       try:
         # `ins` can be callable in tf.distribute.Strategy + eager case.
-        if not callable(ins) or (
-            model._distribution_strategy and
-            not distributed_training_utils.is_distributing_by_cloning(model)):
+        if not callable(ins) or (model._distribution_strategy and
+                                 not distributed_training_utils_v1
+                                 .is_distributing_by_cloning(model)):
           actual_inputs = ins
         else:
           actual_inputs = ins()
@@ -329,8 +330,9 @@ def model_iteration(model,
           batch_outs = [batch_outs]

         if model._distribution_strategy:
-          batch_outs = distributed_training_utils._per_replica_aggregate_batch(
-              model._distribution_strategy, batch_outs, model, mode)
+          batch_outs = (
+              distributed_training_utils_v1._per_replica_aggregate_batch(
+                  model._distribution_strategy, batch_outs, model, mode))

         # Aggregate results.
         if step == 0:
@@ -413,7 +415,7 @@ def model_iteration(model,
       if model._compile_distribution:
         # Since we create a new clone from the original model we need to copy
         # the weights back to the original model before we can run validation.
-        distributed_training_utils._copy_weights_to_original_model(
+        distributed_training_utils_v1._copy_weights_to_original_model(
             model, ModeKeys.TRAIN)

       val_results = model_iteration(
@@ -450,7 +452,7 @@ def model_iteration(model,
   if model._distribution_strategy:
     if model._compile_distribution:
       # TODO(priyag, psv): Copy back metrics to the original model as well?
-      distributed_training_utils._copy_weights_to_original_model(model, mode)
+      distributed_training_utils_v1._copy_weights_to_original_model(model, mode)
     scope.__exit__(None, None, None)

   if mode == ModeKeys.TRAIN:
@@ -500,11 +502,11 @@ def _prepare_feed_values(model, inputs, targets, sample_weights, mode):
   """
   if model._distribution_strategy:
     if isinstance(inputs, (dataset_ops.DatasetV1, dataset_ops.DatasetV2)):
-      inputs = distributed_training_utils.get_iterator(
+      inputs = distributed_training_utils_v1.get_iterator(
           inputs, model._distribution_strategy)

     def get_distributed_inputs():
-      return distributed_training_utils._prepare_feed_values(
+      return distributed_training_utils_v1._prepare_feed_values(
           model, inputs, targets, sample_weights, mode)

     # In the eager case, we want to call the input method per step, so return
@@ -537,14 +539,14 @@ def _prepare_feed_values(model, inputs, targets, sample_weights, mode):

 def _get_iterator(inputs, distribution_strategy=None):
   if distribution_strategy:
-    return distributed_training_utils.get_iterator(
+    return distributed_training_utils_v1.get_iterator(
         inputs, distribution_strategy)
   return training_utils.get_iterator(inputs)


 def _reinitialize_iterator(iterator, distribution_strategy=None):
   if distribution_strategy:
-    distributed_training_utils.initialize_iterator(
+    distributed_training_utils_v1.initialize_iterator(
         iterator, distribution_strategy)
   else:
     training_utils.initialize_iterator(iterator)
@@ -553,7 +555,7 @@ def _reinitialize_iterator(iterator, distribution_strategy=None):
 def _make_execution_function(model, mode):
   """Makes function to run one step of model execution."""
   if model._distribution_strategy:
-    return distributed_training_utils._make_execution_function(model, mode)
+    return distributed_training_utils_v1._make_execution_function(model, mode)
   return model._make_execution_function(mode)


@@ -580,7 +582,7 @@ def _update_sample_weight_mode(model, mode, inputs):
   # Call the DistributionStrategy specific function to update the
   # sample_weight_mode on the model.
   if model._distribution_strategy:
-    distributed_training_utils._update_sample_weight_modes(model, mode,
+    distributed_training_utils_v1._update_sample_weight_modes(model, mode,
                                                            sample_weights)

 # For backwards compatibility for internal users of these loops.
@@ -31,7 +31,8 @@ from tensorflow.python.framework import errors
 from tensorflow.python.framework import ops
 from tensorflow.python.keras import backend as K
 from tensorflow.python.keras import callbacks as cbks
-from tensorflow.python.keras.distribute import distributed_training_utils as dist_utils
+from tensorflow.python.keras.distribute import distributed_training_utils as dist_utils_v2
+from tensorflow.python.keras.distribute import distributed_training_utils_v1 as dist_utils
 from tensorflow.python.keras.engine import partial_batch_padding_handler as padding_util
 from tensorflow.python.keras.engine import training_arrays_v1
 from tensorflow.python.keras.engine import training_utils
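Note the alias choice in this hunk: helpers that stay in the shared module, such as `is_tpu_strategy` used in the hunks below, are now reached through `dist_utils_v2`, while the V1-only training-loop helpers come from `distributed_training_utils_v1` under the old `dist_utils` name, so the rest of the file's call sites stay untouched. A short sketch of the resulting pattern (`strategy` is an assumed tf.distribute strategy):

# Shared helpers keep coming from the original module ...
from tensorflow.python.keras.distribute import distributed_training_utils as dist_utils_v2
# ... while V1-only helpers now come from the new module, under the old alias.
from tensorflow.python.keras.distribute import distributed_training_utils_v1 as dist_utils

def uses_tpu(strategy):
  # Same call shape as the training loop below; only the alias changed.
  return dist_utils_v2.is_tpu_strategy(strategy)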
@@ -648,7 +649,7 @@ class DistributionSingleWorkerTrainingLoop(training_utils.TrainingLoop):
       raise ValueError('validation_split argument is not supported with '
                        'distribution strategies.')

-    if dist_utils.is_tpu_strategy(model._distribution_strategy):
+    if dist_utils_v2.is_tpu_strategy(model._distribution_strategy):
       steps_per_epoch = training_utils.infer_steps_for_dataset(
           model, dataset, steps_per_epoch, epochs, steps_name='steps_per_epoch')
       if steps_per_epoch is None:
@@ -705,7 +706,7 @@ class DistributionSingleWorkerTrainingLoop(training_utils.TrainingLoop):
         batch_size=batch_size,
         allow_partial_batch=True)

-    if dist_utils.is_tpu_strategy(model._distribution_strategy):
+    if dist_utils_v2.is_tpu_strategy(model._distribution_strategy):
       steps = training_utils.infer_steps_for_dataset(
           model, dataset, steps, steps_name='steps')
       if steps is None:
@@ -742,7 +743,7 @@ class DistributionSingleWorkerTrainingLoop(training_utils.TrainingLoop):
         x,
         batch_size=batch_size,
         allow_partial_batch=True)
-    if dist_utils.is_tpu_strategy(model._distribution_strategy):
+    if dist_utils_v2.is_tpu_strategy(model._distribution_strategy):
       steps = training_utils.infer_steps_for_dataset(
           model, dataset, steps, steps_name='steps')
       if steps is None:
@@ -42,6 +42,7 @@ from tensorflow.python.keras import losses
 from tensorflow.python.keras import metrics as metrics_module
 from tensorflow.python.keras import optimizers
 from tensorflow.python.keras.distribute import distributed_training_utils
+from tensorflow.python.keras.distribute import distributed_training_utils_v1
 from tensorflow.python.keras.engine import base_layer
 from tensorflow.python.keras.engine import training as training_lib
 from tensorflow.python.keras.engine import training_arrays_v1
@@ -993,7 +994,7 @@ class Model(training_lib.Model):

     # Reset metrics on all the distributed (cloned) models.
     if self._distribution_strategy:
-      distributed_training_utils._reset_metrics(self)  # pylint: disable=protected-access
+      distributed_training_utils_v1._reset_metrics(self)  # pylint: disable=protected-access

   def train_on_batch(self,
                      x,
@@ -1398,7 +1399,7 @@ class Model(training_lib.Model):
           'We currently do not support enabling `run_eagerly` with '
           'distribution strategy.')

-    if (distributed_training_utils.is_distributing_by_cloning(self) and
+    if (distributed_training_utils_v1.is_distributing_by_cloning(self) and
         (not self.built or not self.inputs or not self.outputs)):
       raise ValueError(
           'We currently do not support distribution strategy with a '
@@ -2856,7 +2857,7 @@ class DistributedCallbackModel(Model):
     self._original_model.load_weights(filepath, by_name=False)
     # Copy the weights from the original model to each of the replicated models.
     orig_model_weights = self._original_model.get_weights()
-    distributed_training_utils.set_weights(
+    distributed_training_utils_v1.set_weights(
         self._original_model._distribution_strategy, self,  # pylint: disable=protected-access
         orig_model_weights)
