Support model.fit and evaluate in 2.0 with TPUStrategy using the experimental_run + train_on_batch API.
PiperOrigin-RevId: 251570029
This commit is contained in:
parent
0ba3119050
commit
a169923759
@ -668,6 +668,15 @@ class TPUExtended(distribute_lib.StrategyExtendedV1):
|
||||
if kwargs is None:
|
||||
kwargs = {}
|
||||
|
||||
# Remove None at the end of args as they are not replicatable
|
||||
# If there are None in the middle we can't do anything about it
|
||||
# so let those cases fail.
|
||||
# For example when Keras model predict is used they pass the targets as
|
||||
# None. We want to handle it here so all client libraries don't have to
|
||||
# do this as other strategies can handle None values better.
|
||||
while args and args[-1] is None:
|
||||
args = args[:-1]
|
||||
|
||||
# Used to re-structure flattened output tensors from `tpu.replicate()`
|
||||
# into a structured format.
|
||||
result = [[]]
|
||||
|
@ -84,7 +84,7 @@ distribute_py_test(
|
||||
srcs = ["distribute_strategy_test.py"],
|
||||
full_precision = True,
|
||||
main = "distribute_strategy_test.py",
|
||||
shard_count = 4,
|
||||
shard_count = 5,
|
||||
tags = [
|
||||
"multi_and_single_gpu",
|
||||
"no_oss", # TODO(b/117919883): Fix python error.
|
||||
|
@ -297,6 +297,11 @@ def strategy_minus_tpu_combinations():
|
||||
|
||||
|
||||
def tpu_strategy_combinations():
|
||||
return combinations.combine(distribution=tpu_strategies,
|
||||
mode=['graph', 'eager'])
|
||||
|
||||
|
||||
def tpu_strategy_combinations_graph_only():
|
||||
return combinations.combine(distribution=tpu_strategies,
|
||||
mode=['graph'])
|
||||
|
||||
@ -313,8 +318,8 @@ def all_strategy_combinations_plus_cloning():
|
||||
cloning=[True, False]) +
|
||||
combinations.combine(
|
||||
distribution=tpu_strategies,
|
||||
mode=['graph'],
|
||||
cloning=[True, False]))
|
||||
mode=['graph', 'eager'],
|
||||
cloning=[False]))
|
||||
|
||||
|
||||
def all_strategy_minus_default_and_tpu_combinations():
|
||||
@ -334,8 +339,8 @@ def all_strategy_combinations_minus_default():
|
||||
|
||||
|
||||
def strategy_and_optimizer_combinations():
|
||||
return combinations.times(
|
||||
all_strategy_combinations(),
|
||||
non_tpu_strategies = combinations.times(
|
||||
strategy_minus_tpu_combinations(),
|
||||
# TODO(b/130808953): Simplify when optimizers v1 work with cloning=False.
|
||||
combinations.combine(
|
||||
optimizer=[
|
||||
@ -353,6 +358,32 @@ def strategy_and_optimizer_combinations():
|
||||
strategy_combinations.rmsprop_optimizer_keras_v2_fn
|
||||
],
|
||||
cloning=[True, False]))
|
||||
# TODO(b/130808953): Simplify when optimizers v1 work with cloning=False.
|
||||
tpu_strategies_graph = combinations.combine(
|
||||
distribution=tpu_strategies,
|
||||
mode=['graph'],
|
||||
cloning=[True],
|
||||
optimizer=[
|
||||
strategy_combinations.adagrad_optimizer_v1_fn,
|
||||
strategy_combinations.adam_optimizer_v1_fn,
|
||||
strategy_combinations.gradient_descent_optimizer_v1_fn,
|
||||
strategy_combinations.rmsprop_optimizer_v1_fn,
|
||||
strategy_combinations.adagrad_optimizer_keras_v2_fn,
|
||||
strategy_combinations.adam_optimizer_keras_v2_fn,
|
||||
strategy_combinations.gradient_descent_optimizer_keras_v2_fn,
|
||||
strategy_combinations.rmsprop_optimizer_keras_v2_fn
|
||||
])
|
||||
tpu_strategies_eager = combinations.combine(
|
||||
distribution=tpu_strategies,
|
||||
mode=['eager'],
|
||||
cloning=[False],
|
||||
optimizer=[
|
||||
strategy_combinations.adagrad_optimizer_keras_v2_fn,
|
||||
strategy_combinations.adam_optimizer_keras_v2_fn,
|
||||
strategy_combinations.gradient_descent_optimizer_keras_v2_fn,
|
||||
strategy_combinations.rmsprop_optimizer_keras_v2_fn
|
||||
])
|
||||
return non_tpu_strategies + tpu_strategies_eager + tpu_strategies_graph
|
||||
|
||||
|
||||
class TestEstimatorDistributionStrategy(test_util.TensorFlowTestCase,
|
||||
@ -769,7 +800,7 @@ class TestDistributionStrategyWithNumpyArrays(test.TestCase,
|
||||
self.assertAllEqual([6, 7], outs[1].shape)
|
||||
|
||||
@combinations.generate(
|
||||
combinations.times(tpu_strategy_combinations(),
|
||||
combinations.times(tpu_strategy_combinations_graph_only(),
|
||||
combinations.combine(batch_size=[4, 6])))
|
||||
def test_evaluate_with_partial_batch(self, distribution, batch_size):
|
||||
with self.cached_session():
|
||||
@ -812,7 +843,7 @@ class TestDistributionStrategyWithNumpyArrays(test.TestCase,
|
||||
rtol=1e-5)
|
||||
|
||||
@combinations.generate(
|
||||
combinations.times(tpu_strategy_combinations(),
|
||||
combinations.times(tpu_strategy_combinations_graph_only(),
|
||||
combinations.combine(cloning=[True, False])))
|
||||
def test_predict_with_partial_batch(self, distribution, cloning):
|
||||
with self.cached_session():
|
||||
@ -846,7 +877,7 @@ class TestDistributionStrategyWithNumpyArrays(test.TestCase,
|
||||
atol=1e-5,
|
||||
rtol=1e-5)
|
||||
|
||||
@combinations.generate(tpu_strategy_combinations())
|
||||
@combinations.generate(tpu_strategy_combinations_graph_only())
|
||||
def test_no_target_model(self, distribution):
|
||||
with self.cached_session():
|
||||
optimizer = gradient_descent.GradientDescentOptimizer(0.001)
|
||||
@ -872,7 +903,7 @@ class TestDistributionStrategyWithNumpyArrays(test.TestCase,
|
||||
model.evaluate(inputs, steps=1)
|
||||
|
||||
@combinations.generate(
|
||||
combinations.times(tpu_strategy_combinations(),
|
||||
combinations.times(tpu_strategy_combinations_graph_only(),
|
||||
combinations.combine(cloning=[True, False])))
|
||||
def test_predict_multi_output_model_with_partial_batch(
|
||||
self, distribution, cloning):
|
||||
@ -1192,6 +1223,7 @@ class TestDistributionStrategyWithDatasets(test.TestCase,
|
||||
def test_fit_eval_and_predict_with_optimizer(self, distribution, optimizer,
|
||||
cloning):
|
||||
with self.cached_session():
|
||||
|
||||
with distribution.scope():
|
||||
|
||||
model = get_model()
|
||||
@ -1341,7 +1373,7 @@ class TestDistributionStrategyWithDatasets(test.TestCase,
|
||||
self.assertAllClose(0.001, keras.backend.get_value(model.optimizer.lr))
|
||||
|
||||
@combinations.generate(
|
||||
combinations.times(tpu_strategy_combinations(),
|
||||
combinations.times(tpu_strategy_combinations_graph_only(),
|
||||
combinations.combine(batch_size=[4, 6])))
|
||||
def test_evaluate_with_dataset_with_partial_batch(self, distribution,
|
||||
batch_size):
|
||||
@ -1382,7 +1414,7 @@ class TestDistributionStrategyWithDatasets(test.TestCase,
|
||||
rtol=1e-5)
|
||||
|
||||
@combinations.generate(
|
||||
combinations.times(tpu_strategy_combinations(),
|
||||
combinations.times(tpu_strategy_combinations_graph_only(),
|
||||
combinations.combine(cloning=[True, False])))
|
||||
def test_predict_with_dataset_with_partial_batch(self, distribution, cloning):
|
||||
with self.cached_session():
|
||||
@ -1411,7 +1443,7 @@ class TestDistributionStrategyWithDatasets(test.TestCase,
|
||||
rtol=1e-5)
|
||||
|
||||
@combinations.generate(
|
||||
combinations.times(tpu_strategy_combinations(),
|
||||
combinations.times(tpu_strategy_combinations_graph_only(),
|
||||
combinations.combine(cloning=[True, False])))
|
||||
def test_predict_multi_output_model_with_dataset_with_partial_batch(
|
||||
self, distribution, cloning):
|
||||
|
@ -164,6 +164,15 @@ def unwrap_outputs(distribution_strategy, grouped_outputs,
|
||||
grouped_outputs[0], axis=None)
|
||||
all_outputs = flatten_per_replica_values(distribution_strategy,
|
||||
grouped_outputs[1:])
|
||||
if (is_tpu_strategy(distribution_strategy) and
|
||||
ops.executing_eagerly_outside_functions()):
|
||||
# Choose 1 value per replica in the TPU case since all replicas produce the
|
||||
# same output.
|
||||
# We only do this in eager mode for now since this function is used in
|
||||
# both graph and eager mode and in the graph case we currently don't use
|
||||
# experimental_run so would need to be removed when we converge the graph
|
||||
# code path as well.
|
||||
all_outputs = all_outputs[::distribution_strategy.num_replicas_in_sync]
|
||||
return [loss] + all_outputs
|
||||
|
||||
|
||||
@ -578,6 +587,9 @@ def _prepare_feed_values(model, inputs, targets, sample_weights, mode):
|
||||
"""
|
||||
strategy = model._distribution_strategy
|
||||
inputs, targets, sample_weights = _get_input_from_iterator(inputs, model)
|
||||
if is_tpu_strategy(strategy):
|
||||
if sample_weights is not None:
|
||||
raise ValueError('TPUStrategy does not support sample weights.')
|
||||
|
||||
# When the inputs are dict, then we want to flatten it in the same order as
|
||||
# the input layers, such that the data are fed into the input layers in the
|
||||
@ -611,8 +623,8 @@ def is_distributing_by_cloning(model):
|
||||
"""Decide whether this model is going to be distributed via cloning.
|
||||
|
||||
We are going to distribute the model by cloning if the user has signaled
|
||||
that intent by not setting `cloning=False` in `Model.compile()` unless we
|
||||
are in graph mode or running on TPU.
|
||||
that intent by setting `cloning=True` in `Model.compile()` unless we are in
|
||||
graph mode.
|
||||
|
||||
Args:
|
||||
model: Keras model to distribute.
|
||||
@ -621,9 +633,15 @@ def is_distributing_by_cloning(model):
|
||||
True if the `model` is going to be distributed using cloning and False
|
||||
otherwise.
|
||||
"""
|
||||
if (is_tpu_strategy(model._distribution_strategy) and
|
||||
context.executing_eagerly):
|
||||
if model._cloning:
|
||||
logging.warning(
|
||||
'Model cloning is not supported in TPU Strategy in Eager mode.'
|
||||
'cloning argument will be ignored.')
|
||||
return False
|
||||
return (model._cloning or model._compile_distribution or
|
||||
not ops.executing_eagerly_outside_functions() or
|
||||
K.is_tpu_strategy(model._distribution_strategy))
|
||||
not ops.executing_eagerly_outside_functions())
|
||||
|
||||
|
||||
def _custom_compile_for_predict(model):
|
||||
|
@ -90,11 +90,20 @@ def strategies_for_embedding_models():
|
||||
|
||||
|
||||
def test_combinations_for_embedding_model():
|
||||
# TODO(sourabhbajaj): Enable tests for eager mode
|
||||
eager_mode_strategies = [s for s in strategies_for_embedding_models()
|
||||
if not s.required_tpu]
|
||||
|
||||
return (combinations.times(
|
||||
combinations.combine(
|
||||
distribution=strategies_for_embedding_models(),
|
||||
cloning=[True, False]),
|
||||
(graph_mode_test_configuration() + eager_mode_test_configuration())))
|
||||
(graph_mode_test_configuration())) +
|
||||
combinations.times(
|
||||
combinations.combine(
|
||||
distribution=eager_mode_strategies,
|
||||
cloning=[False]),
|
||||
(eager_mode_test_configuration())))
|
||||
|
||||
|
||||
def test_combinations_with_tpu_strategies():
|
||||
@ -322,7 +331,7 @@ def compare_results(results_with_ds,
|
||||
|
||||
return default_tolerance
|
||||
|
||||
for key in results_with_ds:
|
||||
for key in sorted(results_with_ds.keys()):
|
||||
if (key.startswith('training_history') and
|
||||
isinstance(distribution, (tpu_strategy.TPUStrategy,
|
||||
tpu_strategy.TPUStrategyV1)) and
|
||||
@ -420,9 +429,9 @@ class TestDistributionStrategyCorrectnessBase(test.TestCase,
|
||||
def get_model(self, distribution=None, cloning=None, input_shapes=None):
|
||||
raise NotImplementedError
|
||||
|
||||
def skip_unsupported_test_configuration(self, distribution):
|
||||
if should_skip_tpu_with_eager(distribution):
|
||||
self.skipTest('TPUStrategy does not support eager mode now.')
|
||||
def skip_unsupported_test_configuration(self, distribution, cloning):
|
||||
if should_skip_tpu_with_eager(distribution) and cloning:
|
||||
self.skipTest('TPUStrategy does not support eager mode with cloning.')
|
||||
return
|
||||
|
||||
def run_correctness_test(self,
|
||||
@ -443,7 +452,7 @@ class TestDistributionStrategyCorrectnessBase(test.TestCase,
|
||||
self.skipTest('Test broken; see b/129793413 and b/117920141')
|
||||
with self.cached_session():
|
||||
self.set_up_test_config(use_numpy, use_validation_data, with_batch_norm)
|
||||
self.skip_unsupported_test_configuration(distribution)
|
||||
self.skip_unsupported_test_configuration(distribution, cloning)
|
||||
|
||||
if partial_last_batch == 'eval':
|
||||
x_train, y_train, x_eval, y_eval, x_predict = (
|
||||
@ -540,7 +549,7 @@ class TestDistributionStrategyCorrectnessBase(test.TestCase,
|
||||
def run_dynamic_lr_test(self, distribution, cloning=None):
|
||||
with self.cached_session():
|
||||
self.set_up_test_config()
|
||||
self.skip_unsupported_test_configuration(distribution)
|
||||
self.skip_unsupported_test_configuration(distribution, cloning)
|
||||
|
||||
x_train, y_train, _ = self.get_data()
|
||||
model = self.get_model(cloning=cloning, input_shapes=get_shapes(x_train))
|
||||
|
@ -153,7 +153,7 @@ class TestDistributionStrategyDnnMetricCorrectness(
|
||||
def run_metric_correctness_test(self, distribution, cloning):
|
||||
with self.cached_session():
|
||||
self.set_up_test_config()
|
||||
self.skip_unsupported_test_configuration(distribution)
|
||||
self.skip_unsupported_test_configuration(distribution, cloning)
|
||||
|
||||
x_train, y_train, _ = self.get_data()
|
||||
model = self.get_model(cloning, distribution=distribution)
|
||||
@ -195,7 +195,7 @@ class TestDistributionStrategyDnnMetricEvalCorrectness(
|
||||
def run_eval_metrics_correctness_test(self, distribution, cloning):
|
||||
with self.cached_session():
|
||||
self.set_up_test_config()
|
||||
self.skip_unsupported_test_configuration(distribution)
|
||||
self.skip_unsupported_test_configuration(distribution, cloning)
|
||||
|
||||
model = self.get_model(cloning, distribution=distribution)
|
||||
|
||||
@ -266,11 +266,17 @@ class TestDistributionStrategyDnnCorrectnessWithSubclassedModel(
|
||||
keras_correctness_test_base.all_strategy_and_input_config_combinations())
|
||||
def test_dnn_correctness(self, distribution, use_numpy, use_validation_data,
|
||||
cloning):
|
||||
if ((not cloning and context.executing_eagerly() and
|
||||
not K.is_tpu_strategy(distribution)) or
|
||||
if ((not cloning and context.executing_eagerly()) or
|
||||
is_default_strategy(distribution)):
|
||||
self.run_correctness_test(distribution, use_numpy, use_validation_data,
|
||||
cloning)
|
||||
elif K.is_tpu_strategy(distribution) and not context.executing_eagerly():
|
||||
with self.assertRaisesRegexp(
|
||||
ValueError,
|
||||
'Expected `model` argument to be a functional `Model` instance, '
|
||||
'but got a subclass model instead.'):
|
||||
self.run_correctness_test(distribution, use_numpy, use_validation_data,
|
||||
cloning)
|
||||
else:
|
||||
with self.assertRaisesRegexp(
|
||||
ValueError,
|
||||
@ -286,6 +292,12 @@ class TestDistributionStrategyDnnCorrectnessWithSubclassedModel(
|
||||
not K.is_tpu_strategy(distribution)) or
|
||||
is_default_strategy(distribution)):
|
||||
self.run_dynamic_lr_test(distribution, cloning)
|
||||
elif K.is_tpu_strategy(distribution):
|
||||
with self.assertRaisesRegexp(
|
||||
ValueError,
|
||||
'Expected `model` argument to be a functional `Model` instance, '
|
||||
'but got a subclass model instead.'):
|
||||
self.run_dynamic_lr_test(distribution, cloning)
|
||||
else:
|
||||
with self.assertRaisesRegexp(
|
||||
ValueError,
|
||||
@ -301,9 +313,8 @@ class TestDistributionStrategyDnnCorrectnessWithSubclassedModel(
|
||||
use_validation_data):
|
||||
with self.assertRaisesRegexp(
|
||||
ValueError,
|
||||
'We currently do not support distribution strategy with a '
|
||||
'`Sequential` model that is created without `input_shape`/'
|
||||
'`input_dim` set in its first layer or a subclassed model.'):
|
||||
'Expected `model` argument to be a functional `Model` instance, '
|
||||
'but got a subclass model instead.'):
|
||||
self.run_correctness_test(
|
||||
distribution,
|
||||
use_numpy,
|
||||
|
@ -103,8 +103,8 @@ class TestDistributionStrategyWithCallbacks(test.TestCase,
|
||||
validation_steps=validation_steps,
|
||||
callbacks=[counter])
|
||||
|
||||
if isinstance(distribution, (tpu_strategy.TPUStrategy,
|
||||
tpu_strategy.TPUStrategyV1)):
|
||||
if (isinstance(distribution, tpu_strategy.TPUStrategyV1) and
|
||||
not context.executing_eagerly()):
|
||||
# TPU Strategy can have multi step training, from extended.steps_per_run
|
||||
# if steps_per_run = 1, then num_batch_call_per_epoch = steps_per_epoch
|
||||
steps_per_run = distribution.extended.steps_per_run
|
||||
|
@ -26,6 +26,7 @@ from tensorflow.python.distribute import distribute_coordinator as dc
|
||||
from tensorflow.python.distribute import distribution_strategy_context
|
||||
from tensorflow.python.distribute import input_lib
|
||||
from tensorflow.python.distribute import reduce_util as ds_reduce_util
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.framework import ops
|
||||
@ -166,8 +167,6 @@ def experimental_tpu_fit_loop(model,
|
||||
# TODO(fchollet): add support for `steps_per_epoch=None` in TPU loops.
|
||||
current_strategy = model._distribution_strategy
|
||||
iterator = dist_utils.get_iterator(dataset, current_strategy)
|
||||
steps_per_epoch = training_utils.infer_steps_for_dataset(
|
||||
dataset, steps_per_epoch, epochs, steps_name='steps_per_epoch')
|
||||
|
||||
scope = dist_utils.distributed_scope(
|
||||
strategy=current_strategy, learning_phase=1)
|
||||
@ -185,12 +184,8 @@ def experimental_tpu_fit_loop(model,
|
||||
tensor = model._all_metrics_tensors[name]
|
||||
initial_loop_values[name] = array_ops.zeros(tensor.shape, tensor.dtype)
|
||||
|
||||
if steps_per_epoch is not None:
|
||||
iteration_value = min(steps_per_epoch,
|
||||
current_strategy.extended.steps_per_run)
|
||||
else:
|
||||
raise ValueError('Number of steps could not be infered from the data, '
|
||||
'please pass the steps_per_epoch argument.')
|
||||
iteration_value = min(steps_per_epoch,
|
||||
current_strategy.extended.steps_per_run)
|
||||
|
||||
steps_per_run = K.variable(
|
||||
value=iteration_value,
|
||||
@ -320,8 +315,6 @@ def experimental_tpu_test_loop(model,
|
||||
mode = ModeKeys.TEST
|
||||
current_strategy = model._distribution_strategy
|
||||
iterator = dist_utils.get_iterator(dataset, current_strategy)
|
||||
steps = training_utils.infer_steps_for_dataset(dataset, steps,
|
||||
steps_name='steps')
|
||||
|
||||
scope = dist_utils.distributed_scope(
|
||||
strategy=current_strategy, learning_phase=0)
|
||||
@ -449,8 +442,6 @@ def experimental_tpu_predict_loop(model,
|
||||
(if the model has multiple outputs).
|
||||
"""
|
||||
mode = ModeKeys.PREDICT
|
||||
steps = training_utils.infer_steps_for_dataset(dataset, steps,
|
||||
steps_name='steps')
|
||||
dataset_fully_shaped = dist_utils.is_dataset_shape_fully_defined(dataset)
|
||||
padding_handler = None
|
||||
if not dataset_fully_shaped:
|
||||
@ -653,32 +644,40 @@ class DistributionSingleWorkerTrainingLoop(training_utils.TrainingLoop):
|
||||
'distribution strategies.')
|
||||
|
||||
if dist_utils.is_tpu_strategy(model._distribution_strategy):
|
||||
return experimental_tpu_fit_loop(
|
||||
model,
|
||||
dataset,
|
||||
epochs=epochs,
|
||||
verbose=verbose,
|
||||
callbacks=callbacks,
|
||||
val_dataset=val_dataset,
|
||||
initial_epoch=initial_epoch,
|
||||
steps_per_epoch=steps_per_epoch,
|
||||
validation_steps=validation_steps,
|
||||
validation_freq=validation_freq)
|
||||
else:
|
||||
return training_arrays.fit_loop(
|
||||
model,
|
||||
dataset,
|
||||
batch_size=batch_size,
|
||||
epochs=epochs,
|
||||
verbose=verbose,
|
||||
callbacks=callbacks,
|
||||
val_inputs=val_dataset,
|
||||
shuffle=shuffle,
|
||||
initial_epoch=initial_epoch,
|
||||
steps_per_epoch=steps_per_epoch,
|
||||
validation_steps=validation_steps,
|
||||
validation_freq=validation_freq,
|
||||
steps_name='steps_per_epoch')
|
||||
steps_per_epoch = training_utils.infer_steps_for_dataset(
|
||||
dataset, steps_per_epoch, epochs, steps_name='steps_per_epoch')
|
||||
if steps_per_epoch is None:
|
||||
raise ValueError('Number of steps could not be infered from the data, '
|
||||
'please pass the steps_per_epoch argument.')
|
||||
|
||||
if not context.executing_eagerly():
|
||||
# Run TPU training in a custom loop in graph mode.
|
||||
return experimental_tpu_fit_loop(
|
||||
model,
|
||||
dataset,
|
||||
epochs=epochs,
|
||||
verbose=verbose,
|
||||
callbacks=callbacks,
|
||||
val_dataset=val_dataset,
|
||||
initial_epoch=initial_epoch,
|
||||
steps_per_epoch=steps_per_epoch,
|
||||
validation_steps=validation_steps,
|
||||
validation_freq=validation_freq)
|
||||
|
||||
return training_arrays.fit_loop(
|
||||
model,
|
||||
dataset,
|
||||
batch_size=batch_size,
|
||||
epochs=epochs,
|
||||
verbose=verbose,
|
||||
callbacks=callbacks,
|
||||
val_inputs=val_dataset,
|
||||
shuffle=shuffle,
|
||||
initial_epoch=initial_epoch,
|
||||
steps_per_epoch=steps_per_epoch,
|
||||
validation_steps=validation_steps,
|
||||
validation_freq=validation_freq,
|
||||
steps_name='steps_per_epoch')
|
||||
|
||||
def evaluate(self,
|
||||
model,
|
||||
@ -702,16 +701,24 @@ class DistributionSingleWorkerTrainingLoop(training_utils.TrainingLoop):
|
||||
allow_partial_batch=True)
|
||||
|
||||
if dist_utils.is_tpu_strategy(model._distribution_strategy):
|
||||
return experimental_tpu_test_loop(
|
||||
model, dataset, verbose=verbose, steps=steps, callbacks=callbacks)
|
||||
else:
|
||||
return training_arrays.test_loop(
|
||||
model,
|
||||
inputs=dataset,
|
||||
batch_size=batch_size,
|
||||
verbose=verbose,
|
||||
steps=steps,
|
||||
callbacks=callbacks)
|
||||
steps = training_utils.infer_steps_for_dataset(
|
||||
dataset, steps, steps_name='steps')
|
||||
if steps is None:
|
||||
raise ValueError('Number of steps could not be infered from the data, '
|
||||
'please pass the steps argument.')
|
||||
|
||||
if not context.executing_eagerly():
|
||||
# Run TPU evaluation in a custom loop in graph mode.
|
||||
return experimental_tpu_test_loop(
|
||||
model, dataset, verbose=verbose, steps=steps, callbacks=callbacks)
|
||||
|
||||
return training_arrays.test_loop(
|
||||
model,
|
||||
inputs=dataset,
|
||||
batch_size=batch_size,
|
||||
verbose=verbose,
|
||||
steps=steps,
|
||||
callbacks=callbacks)
|
||||
|
||||
def predict(self,
|
||||
model,
|
||||
@ -731,16 +738,21 @@ class DistributionSingleWorkerTrainingLoop(training_utils.TrainingLoop):
|
||||
batch_size=batch_size,
|
||||
allow_partial_batch=True)
|
||||
if dist_utils.is_tpu_strategy(model._distribution_strategy):
|
||||
return experimental_tpu_predict_loop(
|
||||
model, dataset, verbose=verbose, steps=steps, callbacks=callbacks)
|
||||
else:
|
||||
return training_arrays.predict_loop(
|
||||
model,
|
||||
dataset,
|
||||
batch_size=batch_size,
|
||||
verbose=verbose,
|
||||
steps=steps,
|
||||
callbacks=callbacks)
|
||||
steps = training_utils.infer_steps_for_dataset(
|
||||
dataset, steps, steps_name='steps')
|
||||
if steps is None:
|
||||
raise ValueError('Number of steps could not be infered from the data, '
|
||||
'please pass the steps argument.')
|
||||
if not context.executing_eagerly():
|
||||
return experimental_tpu_predict_loop(
|
||||
model, dataset, verbose=verbose, steps=steps, callbacks=callbacks)
|
||||
return training_arrays.predict_loop(
|
||||
model,
|
||||
dataset,
|
||||
batch_size=batch_size,
|
||||
verbose=verbose,
|
||||
steps=steps,
|
||||
callbacks=callbacks)
|
||||
|
||||
def _process_batch_and_step_size(
|
||||
self, model, inputs, batch_size, steps_per_epoch, mode):
|
||||
|
@ -90,6 +90,11 @@ def initialize_tpu_system(cluster_resolver=None):
|
||||
|
||||
with ops.device(tpu_system_device):
|
||||
output = _tpu_init_fn()
|
||||
|
||||
# Clear out the eager context caches since the memory is invalid now.
|
||||
logging.info("Clearing out eager caches")
|
||||
context.context()._clear_caches() # pylint: disable=protected-access
|
||||
|
||||
serialized_topology = output.numpy()
|
||||
else:
|
||||
master = cluster_resolver.master()
|
||||
|
Loading…
Reference in New Issue
Block a user