Support model.fit and evaluate in 2.0 with TPUStrategy using the experimental_run + train_on_batch API.
PiperOrigin-RevId: 251570029
commit a169923759
parent 0ba3119050
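In practice, this change lets a stock Keras workflow run on TPUs in eager mode without a hand-written training loop. A minimal usage sketch under assumed setup (the resolver target, model, and data below are illustrative and not part of this commit; `cloning` is the experimental `Model.compile` flag discussed in the hunks that follow):

    import tensorflow as tf

    # Hypothetical resolver target; a real run needs a reachable TPU worker.
    resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')
    tf.tpu.experimental.initialize_tpu_system(resolver)
    strategy = tf.distribute.experimental.TPUStrategy(resolver)

    with strategy.scope():
      model = tf.keras.Sequential([
          tf.keras.layers.Dense(10, activation='relu', input_shape=(8,)),
          tf.keras.layers.Dense(1),
      ])
      # cloning=False selects the experimental_run + train_on_batch path;
      # in eager mode the flag is ignored anyway and this path is used.
      model.compile(optimizer='sgd', loss='mse', cloning=False)

    dataset = tf.data.Dataset.from_tensor_slices(
        (tf.random.uniform((64, 8)), tf.random.uniform((64, 1)))).batch(8)
    model.fit(dataset, epochs=2, steps_per_epoch=8)
    model.evaluate(dataset, steps=8)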
@@ -668,6 +668,15 @@ class TPUExtended(distribute_lib.StrategyExtendedV1):
     if kwargs is None:
       kwargs = {}
 
+    # Remove None at the end of args as they are not replicatable
+    # If there are None in the middle we can't do anything about it
+    # so let those cases fail.
+    # For example when Keras model predict is used they pass the targets as
+    # None. We want to handle it here so all client libraries don't have to
+    # do this as other strategies can handle None values better.
+    while args and args[-1] is None:
+      args = args[:-1]
+
     # Used to re-structure flattened output tensors from `tpu.replicate()`
     # into a structured format.
     result = [[]]
@@ -84,7 +84,7 @@ distribute_py_test(
     srcs = ["distribute_strategy_test.py"],
     full_precision = True,
     main = "distribute_strategy_test.py",
-    shard_count = 4,
+    shard_count = 5,
     tags = [
         "multi_and_single_gpu",
         "no_oss",  # TODO(b/117919883): Fix python error.
@@ -297,6 +297,11 @@ def strategy_minus_tpu_combinations():
 
 
 def tpu_strategy_combinations():
+  return combinations.combine(distribution=tpu_strategies,
+                              mode=['graph', 'eager'])
+
+
+def tpu_strategy_combinations_graph_only():
   return combinations.combine(distribution=tpu_strategies,
                               mode=['graph'])
 
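For readers outside the TensorFlow test harness: `combinations.combine` expands keyword value lists into a cross product of test configurations, and `combinations.times`/`+` compose such sets, which is why the hunks above and below can widen or narrow the modes tested per strategy. A rough standalone sketch of `combine`'s semantics (an assumption for illustration, not the harness implementation):

    import itertools

    def combine(**kwargs):
      """Rough stand-in: one test configuration per cross-product entry."""
      keys = sorted(kwargs)
      return [dict(zip(keys, values))
              for values in itertools.product(*(kwargs[k] for k in keys))]

    # combine(distribution=['tpu'], mode=['graph', 'eager']) ->
    # [{'distribution': 'tpu', 'mode': 'graph'},
    #  {'distribution': 'tpu', 'mode': 'eager'}]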
@@ -313,8 +318,8 @@ def all_strategy_combinations_plus_cloning():
           cloning=[True, False]) +
       combinations.combine(
           distribution=tpu_strategies,
-          mode=['graph'],
-          cloning=[True, False]))
+          mode=['graph', 'eager'],
+          cloning=[False]))
 
 
 def all_strategy_minus_default_and_tpu_combinations():
@@ -334,8 +339,8 @@ def all_strategy_combinations_minus_default():
 
 
 def strategy_and_optimizer_combinations():
-  return combinations.times(
-      all_strategy_combinations(),
+  non_tpu_strategies = combinations.times(
+      strategy_minus_tpu_combinations(),
       # TODO(b/130808953): Simplify when optimizers v1 work with cloning=False.
       combinations.combine(
           optimizer=[
@@ -353,6 +358,32 @@ def strategy_and_optimizer_combinations():
               strategy_combinations.rmsprop_optimizer_keras_v2_fn
           ],
           cloning=[True, False]))
+  # TODO(b/130808953): Simplify when optimizers v1 work with cloning=False.
+  tpu_strategies_graph = combinations.combine(
+      distribution=tpu_strategies,
+      mode=['graph'],
+      cloning=[True],
+      optimizer=[
+          strategy_combinations.adagrad_optimizer_v1_fn,
+          strategy_combinations.adam_optimizer_v1_fn,
+          strategy_combinations.gradient_descent_optimizer_v1_fn,
+          strategy_combinations.rmsprop_optimizer_v1_fn,
+          strategy_combinations.adagrad_optimizer_keras_v2_fn,
+          strategy_combinations.adam_optimizer_keras_v2_fn,
+          strategy_combinations.gradient_descent_optimizer_keras_v2_fn,
+          strategy_combinations.rmsprop_optimizer_keras_v2_fn
+      ])
+  tpu_strategies_eager = combinations.combine(
+      distribution=tpu_strategies,
+      mode=['eager'],
+      cloning=[False],
+      optimizer=[
+          strategy_combinations.adagrad_optimizer_keras_v2_fn,
+          strategy_combinations.adam_optimizer_keras_v2_fn,
+          strategy_combinations.gradient_descent_optimizer_keras_v2_fn,
+          strategy_combinations.rmsprop_optimizer_keras_v2_fn
+      ])
+  return non_tpu_strategies + tpu_strategies_eager + tpu_strategies_graph
 
 
 class TestEstimatorDistributionStrategy(test_util.TensorFlowTestCase,
@@ -769,7 +800,7 @@ class TestDistributionStrategyWithNumpyArrays(test.TestCase,
     self.assertAllEqual([6, 7], outs[1].shape)
 
   @combinations.generate(
-      combinations.times(tpu_strategy_combinations(),
+      combinations.times(tpu_strategy_combinations_graph_only(),
                          combinations.combine(batch_size=[4, 6])))
   def test_evaluate_with_partial_batch(self, distribution, batch_size):
     with self.cached_session():
@@ -812,7 +843,7 @@ class TestDistributionStrategyWithNumpyArrays(test.TestCase,
                         rtol=1e-5)
 
   @combinations.generate(
-      combinations.times(tpu_strategy_combinations(),
+      combinations.times(tpu_strategy_combinations_graph_only(),
                          combinations.combine(cloning=[True, False])))
   def test_predict_with_partial_batch(self, distribution, cloning):
     with self.cached_session():
@@ -846,7 +877,7 @@ class TestDistributionStrategyWithNumpyArrays(test.TestCase,
                         atol=1e-5,
                         rtol=1e-5)
 
-  @combinations.generate(tpu_strategy_combinations())
+  @combinations.generate(tpu_strategy_combinations_graph_only())
   def test_no_target_model(self, distribution):
     with self.cached_session():
       optimizer = gradient_descent.GradientDescentOptimizer(0.001)
@@ -872,7 +903,7 @@ class TestDistributionStrategyWithNumpyArrays(test.TestCase,
       model.evaluate(inputs, steps=1)
 
   @combinations.generate(
-      combinations.times(tpu_strategy_combinations(),
+      combinations.times(tpu_strategy_combinations_graph_only(),
                          combinations.combine(cloning=[True, False])))
   def test_predict_multi_output_model_with_partial_batch(
       self, distribution, cloning):
@@ -1192,6 +1223,7 @@ class TestDistributionStrategyWithDatasets(test.TestCase,
   def test_fit_eval_and_predict_with_optimizer(self, distribution, optimizer,
                                                cloning):
     with self.cached_session():
+
       with distribution.scope():
 
         model = get_model()
@@ -1341,7 +1373,7 @@ class TestDistributionStrategyWithDatasets(test.TestCase,
       self.assertAllClose(0.001, keras.backend.get_value(model.optimizer.lr))
 
   @combinations.generate(
-      combinations.times(tpu_strategy_combinations(),
+      combinations.times(tpu_strategy_combinations_graph_only(),
                          combinations.combine(batch_size=[4, 6])))
   def test_evaluate_with_dataset_with_partial_batch(self, distribution,
                                                     batch_size):
@@ -1382,7 +1414,7 @@ class TestDistributionStrategyWithDatasets(test.TestCase,
                           rtol=1e-5)
 
   @combinations.generate(
-      combinations.times(tpu_strategy_combinations(),
+      combinations.times(tpu_strategy_combinations_graph_only(),
                          combinations.combine(cloning=[True, False])))
   def test_predict_with_dataset_with_partial_batch(self, distribution, cloning):
     with self.cached_session():
@@ -1411,7 +1443,7 @@ class TestDistributionStrategyWithDatasets(test.TestCase,
                           rtol=1e-5)
 
   @combinations.generate(
-      combinations.times(tpu_strategy_combinations(),
+      combinations.times(tpu_strategy_combinations_graph_only(),
                          combinations.combine(cloning=[True, False])))
   def test_predict_multi_output_model_with_dataset_with_partial_batch(
       self, distribution, cloning):
@@ -164,6 +164,15 @@ def unwrap_outputs(distribution_strategy, grouped_outputs,
         grouped_outputs[0], axis=None)
     all_outputs = flatten_per_replica_values(distribution_strategy,
                                              grouped_outputs[1:])
+    if (is_tpu_strategy(distribution_strategy) and
+        ops.executing_eagerly_outside_functions()):
+      # Choose 1 value per replica in the TPU case since all replicas produce
+      # the same output.
+      # We only do this in eager mode for now since this function is used in
+      # both graph and eager mode and in the graph case we currently don't use
+      # experimental_run so would need to be removed when we converge the graph
+      # code path as well.
+      all_outputs = all_outputs[::distribution_strategy.num_replicas_in_sync]
    return [loss] + all_outputs
 
 
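The stride slice added above relies on all TPU replicas producing identical outputs, so keeping every `num_replicas_in_sync`-th element retains exactly one copy per output. A toy illustration with hypothetical values:

    num_replicas = 2
    # Flattened outputs, one (identical) value per replica for each of two
    # metrics, as laid out by flatten_per_replica_values:
    all_outputs = [0.25, 0.25, 0.9, 0.9]   # [m0_r0, m0_r1, m1_r0, m1_r1]
    all_outputs = all_outputs[::num_replicas]
    assert all_outputs == [0.25, 0.9]      # one value per metric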
@@ -578,6 +587,9 @@ def _prepare_feed_values(model, inputs, targets, sample_weights, mode):
   """
   strategy = model._distribution_strategy
   inputs, targets, sample_weights = _get_input_from_iterator(inputs, model)
+  if is_tpu_strategy(strategy):
+    if sample_weights is not None:
+      raise ValueError('TPUStrategy does not support sample weights.')
 
   # When the inputs are dict, then we want to flatten it in the same order as
   # the input layers, such that the data are fed into the input layers in the
@@ -611,8 +623,8 @@ def is_distributing_by_cloning(model):
   """Decide whether this model is going to be distributed via cloning.
 
   We are going to distribute the model by cloning if the user has signaled
-  that intent by not setting `cloning=False` in `Model.compile()` unless we
-  are in graph mode or running on TPU.
+  that intent by setting `cloning=True` in `Model.compile()` unless we are in
+  graph mode.
 
   Args:
     model: Keras model to distribute.
@@ -621,9 +633,15 @@ def is_distributing_by_cloning(model):
     True if the `model` is going to be distributed using cloning and False
     otherwise.
   """
+  if (is_tpu_strategy(model._distribution_strategy) and
+      context.executing_eagerly()):
+    if model._cloning:
+      logging.warning(
+          'Model cloning is not supported in TPU Strategy in Eager mode. '
+          'cloning argument will be ignored.')
+    return False
   return (model._cloning or model._compile_distribution or
-          not ops.executing_eagerly_outside_functions() or
-          K.is_tpu_strategy(model._distribution_strategy))
+          not ops.executing_eagerly_outside_functions())
 
 
 def _custom_compile_for_predict(model):
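The net effect of the two hunks above can be restated as a small standalone predicate (a paraphrase for clarity, not the actual function):

    def distributes_by_cloning(cloning, compile_distribution,
                               eager_outside_functions, on_tpu, eager):
      """Paraphrase of is_distributing_by_cloning() after this change."""
      if on_tpu and eager:
        return False  # cloning is ignored (with a warning) on TPU in eager mode.
      return cloning or compile_distribution or not eager_outside_functions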
@@ -90,11 +90,20 @@ def strategies_for_embedding_models():
 
 
 def test_combinations_for_embedding_model():
+  # TODO(sourabhbajaj): Enable tests for eager mode
+  eager_mode_strategies = [s for s in strategies_for_embedding_models()
+                           if not s.required_tpu]
+
   return (combinations.times(
       combinations.combine(
           distribution=strategies_for_embedding_models(),
           cloning=[True, False]),
-      (graph_mode_test_configuration() + eager_mode_test_configuration())))
+      (graph_mode_test_configuration())) +
+          combinations.times(
+              combinations.combine(
+                  distribution=eager_mode_strategies,
+                  cloning=[False]),
+              (eager_mode_test_configuration())))
 
 
 def test_combinations_with_tpu_strategies():
@@ -322,7 +331,7 @@ def compare_results(results_with_ds,
 
     return default_tolerance
 
-  for key in results_with_ds:
+  for key in sorted(results_with_ds.keys()):
    if (key.startswith('training_history') and
        isinstance(distribution, (tpu_strategy.TPUStrategy,
                                  tpu_strategy.TPUStrategyV1)) and
@@ -420,9 +429,9 @@ class TestDistributionStrategyCorrectnessBase(test.TestCase,
   def get_model(self, distribution=None, cloning=None, input_shapes=None):
     raise NotImplementedError
 
-  def skip_unsupported_test_configuration(self, distribution):
-    if should_skip_tpu_with_eager(distribution):
-      self.skipTest('TPUStrategy does not support eager mode now.')
+  def skip_unsupported_test_configuration(self, distribution, cloning):
+    if should_skip_tpu_with_eager(distribution) and cloning:
+      self.skipTest('TPUStrategy does not support eager mode with cloning.')
     return
 
   def run_correctness_test(self,
@@ -443,7 +452,7 @@ class TestDistributionStrategyCorrectnessBase(test.TestCase,
       self.skipTest('Test broken; see b/129793413 and b/117920141')
     with self.cached_session():
       self.set_up_test_config(use_numpy, use_validation_data, with_batch_norm)
-      self.skip_unsupported_test_configuration(distribution)
+      self.skip_unsupported_test_configuration(distribution, cloning)
 
       if partial_last_batch == 'eval':
         x_train, y_train, x_eval, y_eval, x_predict = (
@@ -540,7 +549,7 @@ class TestDistributionStrategyCorrectnessBase(test.TestCase,
   def run_dynamic_lr_test(self, distribution, cloning=None):
     with self.cached_session():
       self.set_up_test_config()
-      self.skip_unsupported_test_configuration(distribution)
+      self.skip_unsupported_test_configuration(distribution, cloning)
 
       x_train, y_train, _ = self.get_data()
       model = self.get_model(cloning=cloning, input_shapes=get_shapes(x_train))
@@ -153,7 +153,7 @@ class TestDistributionStrategyDnnMetricCorrectness(
   def run_metric_correctness_test(self, distribution, cloning):
     with self.cached_session():
       self.set_up_test_config()
-      self.skip_unsupported_test_configuration(distribution)
+      self.skip_unsupported_test_configuration(distribution, cloning)
 
       x_train, y_train, _ = self.get_data()
       model = self.get_model(cloning, distribution=distribution)
@@ -195,7 +195,7 @@ class TestDistributionStrategyDnnMetricEvalCorrectness(
   def run_eval_metrics_correctness_test(self, distribution, cloning):
     with self.cached_session():
       self.set_up_test_config()
-      self.skip_unsupported_test_configuration(distribution)
+      self.skip_unsupported_test_configuration(distribution, cloning)
 
       model = self.get_model(cloning, distribution=distribution)
 
@@ -266,11 +266,17 @@ class TestDistributionStrategyDnnCorrectnessWithSubclassedModel(
       keras_correctness_test_base.all_strategy_and_input_config_combinations())
   def test_dnn_correctness(self, distribution, use_numpy, use_validation_data,
                            cloning):
-    if ((not cloning and context.executing_eagerly() and
-         not K.is_tpu_strategy(distribution)) or
+    if ((not cloning and context.executing_eagerly()) or
        is_default_strategy(distribution)):
      self.run_correctness_test(distribution, use_numpy, use_validation_data,
                                cloning)
+    elif K.is_tpu_strategy(distribution) and not context.executing_eagerly():
+      with self.assertRaisesRegexp(
+          ValueError,
+          'Expected `model` argument to be a functional `Model` instance, '
+          'but got a subclass model instead.'):
+        self.run_correctness_test(distribution, use_numpy, use_validation_data,
+                                  cloning)
     else:
       with self.assertRaisesRegexp(
           ValueError,
@@ -286,6 +292,12 @@ class TestDistributionStrategyDnnCorrectnessWithSubclassedModel(
         not K.is_tpu_strategy(distribution)) or
        is_default_strategy(distribution)):
      self.run_dynamic_lr_test(distribution, cloning)
+    elif K.is_tpu_strategy(distribution):
+      with self.assertRaisesRegexp(
+          ValueError,
+          'Expected `model` argument to be a functional `Model` instance, '
+          'but got a subclass model instead.'):
+        self.run_dynamic_lr_test(distribution, cloning)
     else:
       with self.assertRaisesRegexp(
           ValueError,
@@ -301,9 +313,8 @@ class TestDistributionStrategyDnnCorrectnessWithSubclassedModel(
                                          use_validation_data):
     with self.assertRaisesRegexp(
         ValueError,
-        'We currently do not support distribution strategy with a '
-        '`Sequential` model that is created without `input_shape`/'
-        '`input_dim` set in its first layer or a subclassed model.'):
+        'Expected `model` argument to be a functional `Model` instance, '
+        'but got a subclass model instead.'):
       self.run_correctness_test(
           distribution,
           use_numpy,
@@ -103,8 +103,8 @@ class TestDistributionStrategyWithCallbacks(test.TestCase,
         validation_steps=validation_steps,
         callbacks=[counter])
 
-    if isinstance(distribution, (tpu_strategy.TPUStrategy,
-                                 tpu_strategy.TPUStrategyV1)):
+    if (isinstance(distribution, tpu_strategy.TPUStrategyV1) and
+        not context.executing_eagerly()):
      # TPU Strategy can have multi step training, from extended.steps_per_run
      # if steps_per_run = 1, then num_batch_call_per_epoch = steps_per_epoch
      steps_per_run = distribution.extended.steps_per_run
@@ -26,6 +26,7 @@ from tensorflow.python.distribute import distribute_coordinator as dc
 from tensorflow.python.distribute import distribution_strategy_context
 from tensorflow.python.distribute import input_lib
 from tensorflow.python.distribute import reduce_util as ds_reduce_util
+from tensorflow.python.eager import context
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import errors
 from tensorflow.python.framework import ops
@@ -166,8 +167,6 @@ def experimental_tpu_fit_loop(model,
   # TODO(fchollet): add support for `steps_per_epoch=None` in TPU loops.
   current_strategy = model._distribution_strategy
   iterator = dist_utils.get_iterator(dataset, current_strategy)
-  steps_per_epoch = training_utils.infer_steps_for_dataset(
-      dataset, steps_per_epoch, epochs, steps_name='steps_per_epoch')
 
   scope = dist_utils.distributed_scope(
       strategy=current_strategy, learning_phase=1)
@@ -185,12 +184,8 @@ def experimental_tpu_fit_loop(model,
       tensor = model._all_metrics_tensors[name]
       initial_loop_values[name] = array_ops.zeros(tensor.shape, tensor.dtype)
 
-  if steps_per_epoch is not None:
-    iteration_value = min(steps_per_epoch,
-                          current_strategy.extended.steps_per_run)
-  else:
-    raise ValueError('Number of steps could not be inferred from the data, '
-                     'please pass the steps_per_epoch argument.')
+  iteration_value = min(steps_per_epoch,
+                        current_strategy.extended.steps_per_run)
 
   steps_per_run = K.variable(
       value=iteration_value,
@@ -320,8 +315,6 @@ def experimental_tpu_test_loop(model,
   mode = ModeKeys.TEST
   current_strategy = model._distribution_strategy
   iterator = dist_utils.get_iterator(dataset, current_strategy)
-  steps = training_utils.infer_steps_for_dataset(dataset, steps,
-                                                 steps_name='steps')
 
   scope = dist_utils.distributed_scope(
       strategy=current_strategy, learning_phase=0)
@@ -449,8 +442,6 @@ def experimental_tpu_predict_loop(model,
         (if the model has multiple outputs).
   """
   mode = ModeKeys.PREDICT
-  steps = training_utils.infer_steps_for_dataset(dataset, steps,
-                                                 steps_name='steps')
   dataset_fully_shaped = dist_utils.is_dataset_shape_fully_defined(dataset)
   padding_handler = None
   if not dataset_fully_shaped:
@@ -653,32 +644,40 @@ class DistributionSingleWorkerTrainingLoop(training_utils.TrainingLoop):
                        'distribution strategies.')
 
    if dist_utils.is_tpu_strategy(model._distribution_strategy):
-      return experimental_tpu_fit_loop(
-          model,
-          dataset,
-          epochs=epochs,
-          verbose=verbose,
-          callbacks=callbacks,
-          val_dataset=val_dataset,
-          initial_epoch=initial_epoch,
-          steps_per_epoch=steps_per_epoch,
-          validation_steps=validation_steps,
-          validation_freq=validation_freq)
-    else:
-      return training_arrays.fit_loop(
-          model,
-          dataset,
-          batch_size=batch_size,
-          epochs=epochs,
-          verbose=verbose,
-          callbacks=callbacks,
-          val_inputs=val_dataset,
-          shuffle=shuffle,
-          initial_epoch=initial_epoch,
-          steps_per_epoch=steps_per_epoch,
-          validation_steps=validation_steps,
-          validation_freq=validation_freq,
-          steps_name='steps_per_epoch')
+      steps_per_epoch = training_utils.infer_steps_for_dataset(
+          dataset, steps_per_epoch, epochs, steps_name='steps_per_epoch')
+      if steps_per_epoch is None:
+        raise ValueError('Number of steps could not be inferred from the data, '
+                         'please pass the steps_per_epoch argument.')
+
+      if not context.executing_eagerly():
+        # Run TPU training in a custom loop in graph mode.
+        return experimental_tpu_fit_loop(
+            model,
+            dataset,
+            epochs=epochs,
+            verbose=verbose,
+            callbacks=callbacks,
+            val_dataset=val_dataset,
+            initial_epoch=initial_epoch,
+            steps_per_epoch=steps_per_epoch,
+            validation_steps=validation_steps,
+            validation_freq=validation_freq)
+
+    return training_arrays.fit_loop(
+        model,
+        dataset,
+        batch_size=batch_size,
+        epochs=epochs,
+        verbose=verbose,
+        callbacks=callbacks,
+        val_inputs=val_dataset,
+        shuffle=shuffle,
+        initial_epoch=initial_epoch,
+        steps_per_epoch=steps_per_epoch,
+        validation_steps=validation_steps,
+        validation_freq=validation_freq,
+        steps_name='steps_per_epoch')
 
   def evaluate(self,
                model,
@@ -702,16 +701,24 @@ class DistributionSingleWorkerTrainingLoop(training_utils.TrainingLoop):
         allow_partial_batch=True)
 
     if dist_utils.is_tpu_strategy(model._distribution_strategy):
-      return experimental_tpu_test_loop(
-          model, dataset, verbose=verbose, steps=steps, callbacks=callbacks)
-    else:
-      return training_arrays.test_loop(
-          model,
-          inputs=dataset,
-          batch_size=batch_size,
-          verbose=verbose,
-          steps=steps,
-          callbacks=callbacks)
+      steps = training_utils.infer_steps_for_dataset(
+          dataset, steps, steps_name='steps')
+      if steps is None:
+        raise ValueError('Number of steps could not be inferred from the data, '
+                         'please pass the steps argument.')
+
+      if not context.executing_eagerly():
+        # Run TPU evaluation in a custom loop in graph mode.
+        return experimental_tpu_test_loop(
+            model, dataset, verbose=verbose, steps=steps, callbacks=callbacks)
+
+    return training_arrays.test_loop(
+        model,
+        inputs=dataset,
+        batch_size=batch_size,
+        verbose=verbose,
+        steps=steps,
+        callbacks=callbacks)
 
   def predict(self,
               model,
@@ -731,16 +738,21 @@ class DistributionSingleWorkerTrainingLoop(training_utils.TrainingLoop):
         batch_size=batch_size,
         allow_partial_batch=True)
     if dist_utils.is_tpu_strategy(model._distribution_strategy):
-      return experimental_tpu_predict_loop(
-          model, dataset, verbose=verbose, steps=steps, callbacks=callbacks)
-    else:
-      return training_arrays.predict_loop(
-          model,
-          dataset,
-          batch_size=batch_size,
-          verbose=verbose,
-          steps=steps,
-          callbacks=callbacks)
+      steps = training_utils.infer_steps_for_dataset(
+          dataset, steps, steps_name='steps')
+      if steps is None:
+        raise ValueError('Number of steps could not be inferred from the data, '
+                         'please pass the steps argument.')
+      if not context.executing_eagerly():
+        return experimental_tpu_predict_loop(
+            model, dataset, verbose=verbose, steps=steps, callbacks=callbacks)
+    return training_arrays.predict_loop(
+        model,
+        dataset,
+        batch_size=batch_size,
+        verbose=verbose,
+        steps=steps,
+        callbacks=callbacks)
 
   def _process_batch_and_step_size(
       self, model, inputs, batch_size, steps_per_epoch, mode):
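With steps inference hoisted into these fit/evaluate/predict wrappers, a TPU run needs a step count that is either passed explicitly or inferable from the dataset: finite batched pipelines infer it, while infinitely repeated ones must pass it. A short sketch reusing the `model` from the sketch near the top (shapes illustrative):

    finite = tf.data.Dataset.from_tensor_slices(
        (tf.zeros((64, 8)), tf.zeros((64, 1)))).batch(8)
    model.fit(finite, epochs=1)             # 8 steps per epoch inferred

    infinite = finite.repeat()              # step count not inferable
    model.fit(infinite, epochs=1, steps_per_epoch=8)
    model.evaluate(infinite, steps=8)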
@@ -90,6 +90,11 @@ def initialize_tpu_system(cluster_resolver=None):
 
     with ops.device(tpu_system_device):
       output = _tpu_init_fn()
+
+    # Clear out the eager context caches since the memory is invalid now.
+    logging.info("Clearing out eager caches")
+    context.context()._clear_caches()  # pylint: disable=protected-access
+
     serialized_topology = output.numpy()
   else:
     master = cluster_resolver.master()