Merge pull request #31663 from tensorflow/ggadde-cp9
r2.0 Cherrypick: bug fixes and test fixes.
commit 1ad1f90069
@@ -306,13 +306,16 @@ class TrtConvertTest(test_util.TensorFlowTestCase):
         output_saved_model_dir=self.mkdtemp(),
         need_calibration=need_calibration)

-  def _CreateConverterV2(self,
-                         input_saved_model_dir,
-                         precision_mode=trt_convert.TrtPrecisionMode.FP32):
+  def _CreateConverterV2(
+      self,
+      input_saved_model_dir,
+      input_saved_model_signature_key=_SAVED_MODEL_SIGNATURE_KEY,
+      precision_mode=trt_convert.TrtPrecisionMode.FP32):
     return trt_convert.TrtGraphConverterV2(
         input_saved_model_dir=input_saved_model_dir,
-        input_saved_model_signature_key=_SAVED_MODEL_SIGNATURE_KEY,
+        input_saved_model_signature_key=input_saved_model_signature_key,
         conversion_params=trt_convert.DEFAULT_TRT_CONVERSION_PARAMS._replace(
             max_workspace_size_bytes=10 << 20,  # Use a smaller workspace.
             precision_mode=precision_mode,
             is_dynamic_op=True,
             maximum_cached_engines=2))
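Note: the test helper above is a thin wrapper around the public TrtGraphConverterV2 API. A minimal usage sketch of the new `input_saved_model_signature_key` argument (assumes a TF 2.0 build with TensorRT; the paths and signature key are hypothetical):

    from tensorflow.python.compiler.tensorrt import trt_convert

    converter = trt_convert.TrtGraphConverterV2(
        input_saved_model_dir='/tmp/my_saved_model',      # hypothetical path
        input_saved_model_signature_key='my_signature',   # hypothetical key
        conversion_params=trt_convert.DEFAULT_TRT_CONVERSION_PARAMS._replace(
            precision_mode=trt_convert.TrtPrecisionMode.FP32,
            is_dynamic_op=True))
    converter.convert()
    converter.save('/tmp/my_trt_saved_model')             # hypothetical path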
@@ -493,8 +496,8 @@ class TrtConvertTest(test_util.TensorFlowTestCase):
     root = model_class()
     save.save(root, input_saved_model_dir)

-    converter = trt_convert.TrtGraphConverterV2(
-        input_saved_model_dir=input_saved_model_dir)
+    converter = self._CreateConverterV2(
+        input_saved_model_dir, input_saved_model_signature_key=signature_key)
     converter.convert()
     output_saved_model_dir = self.mkdtemp()
     converter.save(output_saved_model_dir)
@@ -1577,22 +1577,6 @@ tf_py_test(
     ],
 )

-tf_py_test(
-    name = "training_v2_utils_test",
-    size = "medium",
-    srcs = ["engine/training_v2_utils_test.py"],
-    additional_deps = [
-        ":keras",
-        "@absl_py//absl/testing:parameterized",
-        "//third_party/py/numpy",
-        "//tensorflow/python:client_testlib",
-    ],
-    tags = [
-        "nomac",  # TODO(mihaimaruseac): b/127695564
-        "notsan",
-    ],
-)
-
 tf_py_test(
     name = "model_subclassing_test",
     size = "medium",
@@ -143,6 +143,7 @@ class CallbackCountsTest(keras_parameterized.TestCase):
   def test_callback_hooks_are_called_in_fit(self, data):
     x, y = data
     val_x, val_y = np.ones((4, 10)), np.ones((4, 1))
+    is_sequence = isinstance(x, keras.utils.data_utils.Sequence)

     model = self._get_model()
     counter = Counter()
@@ -150,7 +151,8 @@ class CallbackCountsTest(keras_parameterized.TestCase):
         x,
         y,
         validation_data=(val_x, val_y),
-        batch_size=2,
+        batch_size=2 if not is_sequence else None,
+        steps_per_epoch=5 if is_sequence else None,
         epochs=5,
         callbacks=[counter])
@@ -178,10 +180,16 @@ class CallbackCountsTest(keras_parameterized.TestCase):
                                   ('with_sequence', _get_sequence()))
   def test_callback_hooks_are_called_in_evaluate(self, data):
     x, y = data
+    is_sequence = isinstance(x, keras.utils.data_utils.Sequence)

     model = self._get_model()
     counter = Counter()
-    model.evaluate(x, y, batch_size=2, callbacks=[counter])
+    model.evaluate(
+        x,
+        y,
+        batch_size=2 if not is_sequence else None,
+        steps=5 if is_sequence else None,
+        callbacks=[counter])
     self._check_counts(
         counter, {
             'on_test_batch_begin': 5,
@@ -194,10 +202,15 @@ class CallbackCountsTest(keras_parameterized.TestCase):
                                   ('with_sequence', _get_sequence()))
   def test_callback_hooks_are_called_in_predict(self, data):
     x = data[0]
+    is_sequence = isinstance(x, keras.utils.data_utils.Sequence)

     model = self._get_model()
     counter = Counter()
-    model.predict(x, batch_size=2, callbacks=[counter])
+    model.predict(
+        x,
+        batch_size=2 if not is_sequence else None,
+        steps=5 if is_sequence else None,
+        callbacks=[counter])
     self._check_counts(
         counter, {
             'on_predict_batch_begin': 5,
@@ -270,6 +283,21 @@ class KerasCallbacksTest(keras_parameterized.TestCase):
       model.fit(dataset, epochs=2, steps_per_epoch=10)
     self.assertRegexpMatches(printed.contents(), expected_log)

+  @keras_parameterized.run_with_all_model_types
+  @keras_parameterized.run_all_keras_modes
+  def test_progbar_logging_validation_data(self):
+    model = self._get_model(input_shape=(3,))
+
+    x = array_ops.ones((50, 3))
+    y = array_ops.zeros((50, 2))
+    training_dataset = dataset_ops.Dataset.from_tensor_slices((x, y)).batch(10)
+    val_dataset = dataset_ops.Dataset.from_tensor_slices((x, y)).batch(10)
+    expected_log = r'(.*5/5.*- loss:.*- my_acc:.*- val_loss:.*- val_my_acc:.*)+'
+
+    with self.captureWritesToStream(sys.stdout) as printed:
+      model.fit(training_dataset, epochs=2, validation_data=val_dataset)
+    self.assertRegexpMatches(printed.contents(), expected_log)
+
   @keras_parameterized.run_with_all_model_types
   def test_ModelCheckpoint(self):
     if h5py is None:
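Note: these test changes hinge on the fact that a keras.utils.Sequence carries its own batching, so `batch_size` must be omitted and step counts used instead. A minimal sketch (hypothetical shapes and layer sizes):

    import numpy as np
    from tensorflow import keras

    class OnesSequence(keras.utils.Sequence):

      def __getitem__(self, idx):
        return np.ones((2, 10)), np.ones((2, 1))  # one pre-built batch

      def __len__(self):
        return 5  # five batches per epoch

    model = keras.Sequential([keras.layers.Dense(1, input_shape=(10,))])
    model.compile('sgd', 'mse')
    model.fit(OnesSequence(), epochs=1)  # no batch_size: the Sequence batches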
@@ -137,6 +137,38 @@ def unwrap_values(distribution_strategy, grouped_inputs, grouped_outputs,
   return all_inputs, all_outputs, all_updates, all_session_args


+def unwrap_output_dict(strategy, grouped_outputs, mode):
+  """Unwrap the list of outputs contained in the PerReplica parameters."""
+  if mode == ModeKeys.PREDICT:
+    return flatten_per_replica_values(strategy, grouped_outputs)
+
+  # In the case of fit/eval, grouped_outputs is a dict, whereas in predict,
+  # the output has the same structure as the model output. They need to be
+  # treated differently.
+  total_loss = strategy.reduce(reduce_util.ReduceOp.SUM,
+                               grouped_outputs['total_loss'][0], axis=None)
+  output_losses = flatten_per_replica_values(strategy,
+                                             grouped_outputs['output_losses'])
+  metrics = flatten_per_replica_values(strategy,
+                                       grouped_outputs['metrics'])
+  batch_size = strategy.reduce(reduce_util.ReduceOp.SUM,
+                               grouped_outputs['batch_size'], axis=None)
+  if (is_tpu_strategy(strategy) and
+      ops.executing_eagerly_outside_functions()):
+    # Choose 1 value per replica in the TPU case since all replicas produce
+    # the same output.
+    # We only do this in eager mode for now since this function is used in
+    # both graph and eager mode, and in the graph case we currently don't use
+    # experimental_run, so this would need to be revisited when we converge
+    # the graph code path as well.
+    output_losses = output_losses[::strategy.num_replicas_in_sync]
+    metrics = metrics[::strategy.num_replicas_in_sync]
+  return {'total_loss': [total_loss],
+          'output_losses': output_losses,
+          'metrics': metrics,
+          'batch_size': batch_size}
+
+
 def unwrap_outputs(distribution_strategy, grouped_outputs,
                    with_loss_tensor=False):
   """Unwrap the list of outputs contained in the PerReplica parameters.
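Note: the `[::strategy.num_replicas_in_sync]` slices above rely on the flattened per-replica list being laid out replica-by-replica for each output. A plain-Python illustration with hypothetical values and 2 replicas:

    num_replicas = 2
    # flatten_per_replica_values yields each output's replica copies in turn:
    flattened = ['loss_a_r0', 'loss_a_r1', 'loss_b_r0', 'loss_b_r1']
    picked = flattened[::num_replicas]  # ['loss_a_r0', 'loss_b_r0']
    # One copy per logical output is kept, since TPU replicas produce the
    # same value.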
@@ -70,7 +70,7 @@ class KerasPremadeModelsTest(test.TestCase, parameterized.TestCase):
       hist = model.fit(inputs, output, epochs=5)
     else:
       hist = model.fit(get_dataset(), epochs=5)
-    self.assertLess(hist.history['loss'][4], 0.1)
+    self.assertLess(hist.history['loss'][4], 0.2)

   @combinations.generate(strategy_combinations_eager_data_fn())
   def test_wide_deep_model(self, distribution, data_fn):
@@ -483,7 +483,7 @@ class Model(network.Network):
   def run_eagerly(self, value):
     self._run_eagerly = value

-  def _select_training_loop(self, inputs, callbacks):
+  def _select_training_loop(self, inputs):
     """Select training loop for fit/eval/predict based on the inputs."""
     # TODO(kaftan) or TODO(scottzhu): This check should eventually be nicely
     # integrated into the data adapters in the v2 loop. We can't do this yet
@@ -501,9 +501,7 @@ class Model(network.Network):
     if (context.executing_eagerly()
         and self._experimental_run_tf_function
         and not distributed_training_utils.is_tpu_strategy(
-            self._distribution_strategy)
-        and not training_v2_utils.should_fallback_to_v1_for_callback(
-            inputs, callbacks)):
+            self._distribution_strategy)):
       try:
         valid_adapter = data_adapter.select_data_adapter(inputs, None)
       except ValueError as data_failure_exception:
@@ -713,7 +711,7 @@ class Model(network.Network):
     self._assert_compile_was_called()
     self._check_call_args('fit')

-    func = self._select_training_loop(x, callbacks)
+    func = self._select_training_loop(x)
     return func.fit(
         self,
         x=x,
@@ -826,7 +824,7 @@ class Model(network.Network):
     self._assert_compile_was_called()
     self._check_call_args('evaluate')

-    func = self._select_training_loop(x, callbacks)
+    func = self._select_training_loop(x)
     return func.evaluate(
         self,
         x=x,
@@ -904,7 +902,7 @@ class Model(network.Network):
     _keras_api_gauge.get_cell('predict').set(True)
     self._check_call_args('predict')

-    func = self._select_training_loop(x, callbacks)
+    func = self._select_training_loop(x)
     return func.predict(
         self,
         x=x,
@@ -979,6 +977,8 @@ class Model(network.Network):
       outputs = training_v2_utils.train_on_batch(
           self, x, y=y, sample_weight=sample_weight,
           class_weight=class_weight, reset_metrics=reset_metrics)
+      outputs = (outputs['total_loss'] + outputs['output_losses'] +
+                 outputs['metrics'])
       outputs = [
           training_v2_utils._non_none_constant_value(v) for v in outputs]  # pylint: disable=protected-access
       if len(outputs) == 1:
@@ -1002,12 +1002,14 @@ class Model(network.Network):
     # for each replica by `self._distribution_strategy` and the same code path
     # as Eager is expected to be taken.
     if self.run_eagerly or self._distribution_strategy:
-      outputs = training_eager.train_on_batch(
+      output_dict = training_eager.train_on_batch(
           self,
           x,
           y,
           sample_weights=sample_weights,
           output_loss_metrics=self._output_loss_metrics)
+      outputs = (output_dict['total_loss'] + output_dict['output_losses']
+                 + output_dict['metrics'])
       outputs = [
           training_v2_utils._non_none_constant_value(v) for v in outputs]  # pylint: disable=protected-access
     else:
@@ -1072,6 +1074,8 @@ class Model(network.Network):
       outputs = training_v2_utils.test_on_batch(
           self, x, y=y, sample_weight=sample_weight,
           reset_metrics=reset_metrics)
+      outputs = (outputs['total_loss'] + outputs['output_losses'] +
+                 outputs['metrics'])
       outputs = [
           training_v2_utils._non_none_constant_value(v) for v in outputs]  # pylint: disable=protected-access
       if len(outputs) == 1:
@@ -1089,12 +1093,14 @@ class Model(network.Network):
     # If `self._distribution_strategy` is True, then we are in a replica context
     # at this point.
     if self.run_eagerly or self._distribution_strategy:
-      outputs = training_eager.test_on_batch(
+      output_dict = training_eager.test_on_batch(
           self,
           x,
           y,
           sample_weights=sample_weights,
           output_loss_metrics=self._output_loss_metrics)
+      outputs = (output_dict['total_loss'] + output_dict['output_losses']
+                 + output_dict['metrics'])
       outputs = [
           training_v2_utils._non_none_constant_value(v) for v in outputs]  # pylint: disable=protected-access
     else:
@@ -1763,9 +1769,16 @@ class Model(network.Network):
       The validated batch_size, auto-inferred from the first layer if not
       provided.
     """
-    if batch_size is not None and isinstance(x, dataset_ops.DatasetV2):
-      raise ValueError('The `batch_size` argument must not be specified when'
-                       ' using dataset as an input.')
+    if (isinstance(x, (dataset_ops.DatasetV1,
+                       dataset_ops.DatasetV2,
+                       data_utils.Sequence)) or
+        tf_inspect.isgenerator(x)):
+      if batch_size is not None:
+        raise ValueError(
+            'The `batch_size` argument must not be specified for the given '
+            'input type. Received input: {}, batch_size: {}'.format(
+                x, batch_size))
+      return

     layers = super(Model, self).layers  # Avoids the override in Sequential.
     if layers:
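Note: the broadened check means any input type that carries its own batching (DatasetV1/V2, Sequence, generator) now rejects an explicit `batch_size`. A sketch of the error path (hypothetical toy model and data):

    import numpy as np
    import tensorflow as tf

    model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(10,))])
    model.compile('sgd', 'mse')
    data = tf.data.Dataset.from_tensor_slices(
        (np.ones((8, 10)), np.ones((8, 1)))).batch(4)
    try:
      model.fit(data, batch_size=4)  # already-batched input + batch_size
    except ValueError as e:
      print(e)  # The `batch_size` argument must not be specified for the
                # given input type. ...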
@@ -1819,13 +1832,7 @@ class Model(network.Network):
       if steps is None:
         batch_size = static_batch_size

-    if (batch_size is None
-        and steps is None
-        and not isinstance(x, (dataset_ops.DatasetV2,
-                               iterator_ops.Iterator,
-                               iterator_ops.IteratorV2,
-                               data_utils.Sequence))
-        and not tf_inspect.isgenerator(x)):
+    if batch_size is None and steps is None:
       # Backwards compatibility
       batch_size = 32
     return batch_size
@@ -129,19 +129,16 @@ class TestTrainingWithDataset(keras_parameterized.TestCase):
                 sample_weight=sample_weight)

     # Test invalid usage
-    with self.assertRaisesRegexp(ValueError, 'The `batch_size` argument'
-                                 ' must not be specified when using dataset'
-                                 ' as an input.'):
+    with self.assertRaisesRegexp(
+        ValueError, 'The `batch_size` argument must not be specified'):
       model.fit(dataset, batch_size=10, epochs=1, steps_per_epoch=2,
                 verbose=0)

-    with self.assertRaisesRegexp(ValueError, 'The `batch_size` argument'
-                                 ' must not be specified when using dataset'
-                                 ' as an input.'):
+    with self.assertRaisesRegexp(
+        ValueError, 'The `batch_size` argument must not be specified'):
       model.predict(dataset, batch_size=10, steps=2, verbose=0)
-    with self.assertRaisesRegexp(ValueError, 'The `batch_size` argument'
-                                 ' must not be specified when using dataset'
-                                 ' as an input.'):
+    with self.assertRaisesRegexp(
+        ValueError, 'The `batch_size` argument must not be specified'):
       model.evaluate(dataset, batch_size=10, steps=2, verbose=0)

     with self.assertRaisesRegexp(ValueError,
@@ -294,7 +294,11 @@ def train_on_batch(model,
       loss values.

   Returns:
-      total loss and the loss associated with each output.
+      Dict with three items:
+        'total_loss': list with a single tensor for overall loss,
+        'output_losses': list of tensors for loss corresponding to each of
+          the model outputs. Can be an empty list when the model has only one output.
+        'metrics': list of tensors for the metrics specified.
   """
   inputs = training_utils.cast_to_model_input_dtypes(inputs, model)
   outs, total_loss, output_losses, masks = (
@@ -310,9 +314,9 @@ def train_on_batch(model,
   metrics_results = _eager_metrics_fn(
       model, outs, targets, sample_weights=sample_weights, masks=masks)
   total_loss = nest.flatten(total_loss)
-  results = total_loss + output_losses + metrics_results
-
-  return results
+  return {'total_loss': total_loss,
+          'output_losses': output_losses,
+          'metrics': metrics_results}


 def test_on_batch(model,
@@ -331,7 +335,11 @@ def test_on_batch(model,
       loss values.

   Returns:
-      total loss, loss and metrics associated with each output.
+      Dict with three items:
+        'total_loss': single tensor for overall loss,
+        'output_losses': list of tensors for loss corresponding to each of
+          the model outputs. Can be an empty list when the model has only one output.
+        'metrics': list of tensors for the metrics specified.
   """
   inputs = training_utils.cast_to_model_input_dtypes(inputs, model)
@@ -349,6 +357,7 @@ def test_on_batch(model,
   metrics_results = _eager_metrics_fn(
       model, outs, targets, sample_weights=sample_weights, masks=masks)
   total_loss = nest.flatten(total_loss)
-  results = total_loss + output_losses + metrics_results
-
-  return results
+  return {'total_loss': total_loss,
+          'output_losses': output_losses,
+          'metrics': metrics_results}
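Note: with these changes, train_on_batch/test_on_batch return a dict rather than a flat list, and the callers in training.py rebuild the historical list ordering. A sketch with a hypothetical result dict:

    output_dict = {'total_loss': [0.5], 'output_losses': [], 'metrics': [0.9]}
    outputs = (output_dict['total_loss'] + output_dict['output_losses'] +
               output_dict['metrics'])  # [0.5, 0.9]
    if len(outputs) == 1:
      outputs = outputs[0]  # single-output models unwrap to one value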
@@ -314,6 +314,34 @@ class TestGeneratorMethods(ForkRobustTestCase):
       model.evaluate(ones_generator(), steps=2)
       model.predict(ones_generator(), steps=2)

+  @keras_parameterized.run_with_all_model_types
+  @keras_parameterized.run_all_keras_modes
+  def test_invalid_batch_size_argument(self):
+
+    def ones_generator():
+      while True:
+        yield np.ones([10, 10], np.float32), np.ones([10, 1], np.float32)
+
+    model = testing_utils.get_small_mlp(
+        num_hidden=10, num_classes=1, input_dim=10)
+
+    model.compile(
+        'adam',
+        'binary_crossentropy',
+        run_eagerly=testing_utils.should_run_eagerly(),
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
+
+    with self.assertRaisesRegexp(
+        ValueError, 'The `batch_size` argument must not be specified'):
+      model.fit(ones_generator(), batch_size=2, epochs=2)
+    with self.assertRaisesRegexp(
+        ValueError, 'The `batch_size` argument must not be specified'):
+      model.evaluate(ones_generator(), batch_size=2)
+
+    with self.assertRaisesRegexp(
+        ValueError, 'The `batch_size` argument must not be specified'):
+      model.predict(ones_generator(), batch_size=2)
+

 class TestGeneratorMethodsWithSequences(ForkRobustTestCase):
@@ -1131,12 +1131,6 @@ class TrainingTest(keras_parameterized.TestCase):
                      'incompatible with the specified batch size'):
       model.fit(x, y, batch_size=4)

-    data = dataset_ops.DatasetV2.from_tensor_slices((x, y))
-    data = data.batch(4, drop_remainder=True)
-    with self.assertRaisesRegexp(ValueError,
-                                 'incompatible with the specified batch size'):
-      model.fit(data, steps_per_epoch=16)
-
   @tf_test_util.run_in_graph_and_eager_modes
   def test_compatible_batch_size_functional_model(self):
@@ -1563,11 +1557,10 @@ class TestExceptionsAndWarnings(keras_parameterized.TestCase):
         'sgd',
         loss='mse',
         run_eagerly=testing_utils.should_run_eagerly(),
-        experimental_run_tf_function=testing_utils.should_run_tf_function())
+        experimental_run_tf_function=False)
     err_msg = 'When passing input data as arrays, do not specify'

-    if testing_utils.should_run_eagerly(
-    ) and not model._experimental_run_tf_function:
+    if testing_utils.should_run_eagerly():
       with self.assertRaisesRegex(ValueError, err_msg):
         model.fit(x=np.zeros((100, 1)), y=np.ones((100, 1)), steps_per_epoch=4)
@@ -1581,11 +1574,42 @@ class TestExceptionsAndWarnings(keras_parameterized.TestCase):
         model._standardize_user_data(
             np.zeros((100, 1)),
             np.ones((100, 1)),
             batch_size=25,
             check_steps=True,
             steps=4)
       self.assertRegexpMatches(str(mock_log.call_args), err_msg)

+  @keras_parameterized.run_with_all_model_types
+  @keras_parameterized.run_all_keras_modes
+  def test_invalid_batch_size_argument_with_sequence_input(self):
+
+    class DummySequence(keras.utils.Sequence):
+
+      def __getitem__(self, idx):
+        return np.zeros([10, 2]), np.ones([10, 4])
+
+      def __len__(self):
+        return 10
+
+    model = testing_utils.get_small_mlp(
+        num_hidden=10, num_classes=1, input_dim=10)
+
+    model.compile(
+        'adam',
+        'binary_crossentropy',
+        run_eagerly=testing_utils.should_run_eagerly(),
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
+
+    with self.assertRaisesRegexp(
+        ValueError, 'The `batch_size` argument must not be specified'):
+      model.fit(DummySequence(), batch_size=2, epochs=2)
+    with self.assertRaisesRegexp(
+        ValueError, 'The `batch_size` argument must not be specified'):
+      model.evaluate(DummySequence(), batch_size=2)
+
+    with self.assertRaisesRegexp(
+        ValueError, 'The `batch_size` argument must not be specified'):
+      model.predict(DummySequence(), batch_size=2)
+

 class LossWeightingTest(keras_parameterized.TestCase):
@@ -35,6 +35,7 @@ from tensorflow.python.keras.engine import training_utils
 from tensorflow.python.keras.engine import training_v2_utils
 from tensorflow.python.keras.utils.mode_keys import ModeKeys
 from tensorflow.python.platform import tf_logging as logging
+from tensorflow.python.util import nest
 from tensorflow.python.util import tf_contextlib
@@ -148,11 +149,15 @@ def run_one_epoch(model,
         batch_logs['data_exhausted'] = True
         break

-      if not isinstance(batch_outs, list):
-        batch_outs = [batch_outs]
       if strategy:
-        batch_outs = dist_utils._per_replica_aggregate_batch(
-            strategy, batch_outs, model, mode)
+        if mode != ModeKeys.PREDICT:
+          data_batch_size = batch_outs['batch_size']
+          batch_outs = (batch_outs['total_loss'] + batch_outs['output_losses']
+                        + batch_outs['metrics'])
+          if current_batch_size != data_batch_size:
+            batch_logs['size'] = data_batch_size
+            current_batch_size = data_batch_size
+        else:
+          batch_outs = _aggregate_predict_results(strategy, batch_outs, model)

       if step == 0:
         aggregator.create(batch_outs)
@@ -215,6 +220,9 @@ class Loop(training_utils.TrainingLoop):
       use_sample = total_samples is not None
       do_validation = (validation_adapter is not None)

+      # TODO(psv): Add step inference for when steps/val_steps is None to
+      # prevent end of sequence warning message.
+
       if not steps_per_epoch:
         steps_per_epoch = training_data_adapter.get_size()
@@ -261,21 +269,28 @@ class Loop(training_utils.TrainingLoop):
             epochs=0)
         validation_dataset = strategy.experimental_distribute_dataset(
             validation_dataset)
+        val_total_samples = _get_total_number_of_samples(validation_adapter)
+      else:
+        val_total_samples = None

-      callbacks = cbks.configure_callbacks(
+      if verbose and (total_samples or steps_per_epoch):
+        _print_train_info(total_samples, steps_per_epoch, val_total_samples,
+                          validation_steps)
+
+      training_callbacks = cbks.configure_callbacks(
           callbacks,
           model,
           do_validation=do_validation,
           batch_size=batch_size,
           epochs=epochs,
           steps_per_epoch=steps_per_epoch,
-          samples=total_samples,
+          samples=total_samples or steps_per_epoch,
           count_mode='samples' if use_sample else 'steps',
           verbose=0,  # Handle ProgBarLogger separately in this loop.
           mode=ModeKeys.TRAIN)

-      with training_context.on_start(
-          model, callbacks, use_sample, verbose, ModeKeys.TRAIN):
+      with training_context.on_start(model, training_callbacks, use_sample,
+                                     verbose, ModeKeys.TRAIN):
         # TODO(scottzhu): Handle TPUStrategy training loop
         for epoch in range(initial_epoch, epochs):
           if training_context.callbacks.model.stop_training:
@@ -310,7 +325,7 @@ class Loop(training_utils.TrainingLoop):
           # Evaluation
           if (do_validation and
               training_utils.should_run_validation(validation_freq, epoch) and
-              not callbacks.model.stop_training):
+              not training_callbacks.model.stop_training):
             if (eval_data_iter is not None and
                 distribution_strategy_context.has_strategy()):
               # TODO(kaftan): remove this when MultiDeviceIterator is a
@@ -319,11 +334,24 @@ class Loop(training_utils.TrainingLoop):
             else:
               eval_data_iter = iter(validation_dataset)

-            val_total_samples = _get_total_number_of_samples(
-                validation_adapter)
+            validation_callbacks = cbks.configure_callbacks(
+                training_callbacks,
+                model,
+                batch_size=batch_size,
+                epochs=1,
+                steps_per_epoch=validation_steps,
+                samples=val_total_samples or validation_steps,
+                count_mode='samples' if use_sample else 'steps',
+                verbose=0,  # Handle ProgBarLogger separately in this loop.
+                mode=ModeKeys.TEST)
+
             eval_context = TrainingContext()
             with eval_context.on_start(
-                model, callbacks, use_sample, verbose=0, mode=ModeKeys.TEST):
+                model,
+                validation_callbacks,
+                use_sample,
+                verbose=0,
+                mode=ModeKeys.TEST):
               with eval_context.on_epoch(epoch, ModeKeys.TEST):
                 model.reset_metrics()
                 eval_result = run_one_epoch(
@@ -469,10 +497,12 @@ def _process_training_inputs(model, x, y, batch_size=None,
     # Retrieve the training section from x and y, and then construct dataset
     # from it.
     x, y, sample_weights = model._standardize_user_data(
-        x, y, sample_weight=sample_weights,
+        x,
+        y,
+        sample_weight=sample_weights,
         class_weight=class_weights,
         batch_size=batch_size,
-        check_steps=True,
+        check_steps=False,
         steps=steps_per_epoch)
     (x, y, sample_weights,
      val_x, val_y,
@@ -525,7 +555,7 @@ def _process_inputs(model, x, y, batch_size=None, sample_weights=None,
         sample_weight=sample_weights,
         class_weight=class_weights,
         batch_size=batch_size,
-        check_steps=True,
+        check_steps=False,
         steps=steps)
     adapter = adapter_cls(x, y, batch_size=batch_size, steps=steps,
                           sample_weights=sample_weights, shuffle=shuffle,
@@ -546,6 +576,30 @@ def _get_total_number_of_samples(adapter):
   return total_sample


+def _aggregate_predict_results(strategy, batch_outs, model):
+  if not isinstance(batch_outs, list):
+    batch_outs = [batch_outs]
+  total_batch_outs = []
+  for i in range(len(model.outputs)):
+    num_replicas = strategy.num_replicas_in_sync
+    nested_outs = batch_outs[i * num_replicas:i * num_replicas + num_replicas]
+    total_batch_outs.append(
+        dist_utils.concat_along_batch_dimension(nest.flatten(nested_outs)))
+  return total_batch_outs
+
+
+def _print_train_info(total_samples, steps, val_total_samples, val_steps):
+  increment = 'samples' if total_samples else 'steps'
+  conjunction = 'on' if total_samples else 'for'
+  msg = 'Train {} {} {}'.format(conjunction, total_samples or steps, increment)
+  if val_total_samples or val_steps:
+    increment = 'samples' if val_total_samples else 'steps'
+    conjunction = 'on' if val_total_samples else 'for'
+    msg += ', validate {} {} {}'.format(conjunction, val_total_samples or
+                                        val_steps, increment)
+  print(msg)
+
+
 class TrainingContext(object):
   """Utility object that wrap around callbacks and progress bars."""
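Note: the new `_print_train_info` helper prefers sample counts when they are known and falls back to step counts. A behavior sketch (values hypothetical):

    _print_train_info(total_samples=50, steps=None,
                      val_total_samples=None, val_steps=5)
    # prints: Train on 50 samples, validate for 5 steps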
@@ -28,15 +28,17 @@ import functools

 from tensorflow.python.distribute import distribution_strategy_context
 from tensorflow.python.eager import def_function
+from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import tensor_util
 from tensorflow.python.framework.ops import composite_tensor
 from tensorflow.python.keras import backend
-from tensorflow.python.keras import callbacks as cbks
 from tensorflow.python.keras.distribute import distributed_training_utils as dist_utils
 from tensorflow.python.keras.engine import data_adapter
 from tensorflow.python.keras.engine import training_eager
 from tensorflow.python.keras.engine import training_utils
 from tensorflow.python.keras.utils.mode_keys import ModeKeys
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
 from tensorflow.python.util import nest


 def _get_or_make_execution_function(model, mode):
@@ -70,8 +72,8 @@ def _make_execution_function(model, mode):
     outputs = strategy.experimental_run_v2(
         per_replica_function, args=(model, x, y, sample_weights))
     # Out of PerReplica outputs reduce or pick values to return.
-    all_outputs = dist_utils.unwrap_outputs(
-        strategy, outputs, with_loss_tensor=(mode != ModeKeys.PREDICT))
+    all_outputs = dist_utils.unwrap_output_dict(
+        strategy, outputs, mode)
     return all_outputs

   if not model.run_eagerly:
@@ -80,8 +82,8 @@ def _make_execution_function(model, mode):

   def execution_function(input_fn):
     # `numpy` translates Tensors to values in Eager mode.
-    return [_non_none_constant_value(out)
-            for out in distributed_function(input_fn)]
+    return nest.map_structure(_non_none_constant_value,
+                              distributed_function(input_fn))

   return execution_function
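Note: swapping the list comprehension for nest.map_structure lets the value extraction recurse into whatever structure distributed_function now returns, including the new output dicts. A toy sketch:

    from tensorflow.python.util import nest

    result = nest.map_structure(lambda v: v * 2,
                                {'total_loss': [1], 'metrics': [3]})
    # {'total_loss': [2], 'metrics': [6]}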
@@ -192,43 +194,6 @@ def _prepare_model_with_inputs(model, dataset):
                             model.sample_weight_mode)


-def should_fallback_to_v1_for_callback(inputs, callbacks):
-  """Whether to fallback to v1 training loop because of callbacks.
-
-  This is only a temporary solution until the v2 training loop is fixed for
-  using batch based callbacks.
-
-  Args:
-    inputs: the inputs to the model. Certain input type might not handle certain
-      callbacks well if it need batch based counting.
-    callbacks: list of callbacks configured for the fit/eval/predict.
-
-  Returns:
-    boolean, whether it should fallbacks to use v1 training loop.
-  """
-  try:
-    adapter_cls = data_adapter.select_data_adapter(inputs, None)
-    if adapter_cls not in (data_adapter.GeneratorDataAdapter,
-                           data_adapter.DatasetAdapter):
-      # For any input data that we know the overall size, eg numpy, list of
-      # list, etc, we don't need to fallback since the v2 loop can get the batch
-      # size.
-      return False
-  except ValueError:
-    # In case we can't find the adapter, then we should fallback to v1.
-    return True
-
-  callbacks = callbacks or []
-  for c in callbacks:
-    if isinstance(c, cbks.ModelCheckpoint) and isinstance(c.save_freq, int):
-      return True
-    elif (isinstance(c, cbks.TensorBoard) and
-          isinstance(c.update_freq, int) and
-          c.update_freq > 1):  # This is a implementation detail for TB.
-      return True
-  return False
-
-
 def train_on_batch(
     model,
     x,
@@ -286,7 +251,7 @@ def train_on_batch(
   x, y, sample_weights = model._standardize_user_data(
       x, y, sample_weight=sample_weight, class_weight=class_weight,
       extract_tensors_from_dataset=True)
-
+  batch_size = array_ops.shape(nest.flatten(x, expand_composites=True)[0])[0]
   # If `model._distribution_strategy` is True, then we are in a replica context
   # at this point because of the check above. `train_on_batch` is being run
   # for each replica by `model._distribution_strategy` and the same code path
@@ -301,6 +266,7 @@ def train_on_batch(
   if reset_metrics:
     model.reset_metrics()

+  outputs['batch_size'] = math_ops.cast(batch_size, dtypes.int64)
   return outputs
@@ -352,6 +318,7 @@ def test_on_batch(model, x, y=None, sample_weight=None, reset_metrics=True):
   x, y, sample_weights = model._standardize_user_data(
       x, y, sample_weight=sample_weight, extract_tensors_from_dataset=True)

+  batch_size = array_ops.shape(nest.flatten(x, expand_composites=True)[0])[0]
   outputs = training_eager.test_on_batch(
       model,
       x,
@@ -362,6 +329,7 @@ def test_on_batch(model, x, y=None, sample_weight=None, reset_metrics=True):
   if reset_metrics:
     model.reset_metrics()

+  outputs['batch_size'] = math_ops.cast(batch_size, dtypes.int64)
   return outputs
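Note: both batch functions now report the observed batch size so run_one_epoch can correct the progress-bar accounting when the final batch is short. A sketch of the bookkeeping with hypothetical tensors (public tf.* API used for brevity):

    import tensorflow as tf

    x = tf.ones((3, 10))  # a short final batch of 3 samples
    batch_size = tf.shape(tf.nest.flatten(x)[0])[0]
    outputs = {'total_loss': [0.1], 'output_losses': [], 'metrics': []}
    outputs['batch_size'] = tf.cast(batch_size, tf.int64)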
@@ -1,106 +0,0 @@
-# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for training utility functions."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import numpy as np
-
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.framework import constant_op
-from tensorflow.python.keras import callbacks as cbks
-from tensorflow.python.keras.engine import training_v2_utils
-from tensorflow.python.keras.utils import data_utils
-from tensorflow.python.ops import array_ops
-from tensorflow.python.platform import test
-
-
-class TestSequence(data_utils.Sequence):
-
-  def __init__(self, batch_size, feature_shape):
-    self.batch_size = batch_size
-    self.feature_shape = feature_shape
-
-  def __getitem__(self, item):
-    return (np.zeros((self.batch_size, self.feature_shape)),
-            np.ones((self.batch_size,)))
-
-  def __len__(self):
-    return 10
-
-
-class CallbackFallbackTest(test.TestCase):
-
-  def setUp(self):
-    super(CallbackFallbackTest, self).setUp()
-    self.batch_size = 5
-    self.numpy_input = np.zeros((50, 10))
-    self.numpy_target = np.ones(50)
-    self.tensor_input = constant_op.constant(2.0, shape=(50, 10))
-    self.tensor_target = array_ops.ones((50,))
-    self.dataset_input = dataset_ops.DatasetV2.from_tensor_slices(
-        (self.numpy_input, self.numpy_target)).shuffle(50).batch(
-            self.batch_size)
-
-    def generator():
-      yield (np.zeros((self.batch_size, 10)), np.ones(self.batch_size))
-    self.generator_input = generator()
-    self.sequence_input = TestSequence(batch_size=self.batch_size,
-                                       feature_shape=10)
-
-    self.fallback_ckeckpoint_cb = cbks.ModelCheckpoint(
-        self.get_temp_dir(), save_freq=10)
-    self.normal_checkpoint_cb = cbks.ModelCheckpoint(
-        self.get_temp_dir(), save_freq='epoch')
-    self.fallback_tensorboard_cb = cbks.TensorBoard(update_freq=10)
-    self.normal_tensorboard_cb = cbks.TensorBoard(update_freq='batch')
-    self.unaffected_cb = cbks.CSVLogger(self.get_temp_dir())
-
-  def test_not_fallback_based_on_input(self):
-    callback_list = [self.fallback_ckeckpoint_cb]
-
-    test_cases = [
-        [(self.numpy_input, self.numpy_target), False],
-        [[self.tensor_input, self.tensor_target], False],
-        [self.sequence_input, False],
-        [self.dataset_input, True],
-        [self.generator_input, True],
-    ]
-
-    for case in test_cases:
-      inputs, expected_result = case
-      self.assertEqual(training_v2_utils.should_fallback_to_v1_for_callback(
-          inputs, callback_list), expected_result)
-
-  def test_fallback_based_on_callbacks(self):
-    inputs = self.dataset_input
-    test_cases = [
-        [[self.fallback_ckeckpoint_cb], True],
-        [[self.normal_checkpoint_cb], False],
-        [[self.fallback_ckeckpoint_cb, self.normal_checkpoint_cb], True],
-        [[self.fallback_tensorboard_cb], True],
-        [[self.normal_tensorboard_cb], False],
-        [[self.unaffected_cb], False],
-    ]
-
-    for case in test_cases:
-      callbacks, expected_result = case
-      self.assertEqual(training_v2_utils.should_fallback_to_v1_for_callback(
-          inputs, callbacks), expected_result)
-
-
-if __name__ == '__main__':
-  test.main()