Merge pull request #31663 from tensorflow/ggadde-cp9

r2.0 Cherrypick: bug fixes and test fixes.
Goldie Gadde 2019-08-15 17:06:13 -07:00 committed by GitHub
commit 1ad1f90069
13 changed files with 265 additions and 237 deletions

View File

@ -306,13 +306,16 @@ class TrtConvertTest(test_util.TensorFlowTestCase):
output_saved_model_dir=self.mkdtemp(),
need_calibration=need_calibration)
def _CreateConverterV2(self,
input_saved_model_dir,
precision_mode=trt_convert.TrtPrecisionMode.FP32):
def _CreateConverterV2(
self,
input_saved_model_dir,
input_saved_model_signature_key=_SAVED_MODEL_SIGNATURE_KEY,
precision_mode=trt_convert.TrtPrecisionMode.FP32):
return trt_convert.TrtGraphConverterV2(
input_saved_model_dir=input_saved_model_dir,
input_saved_model_signature_key=_SAVED_MODEL_SIGNATURE_KEY,
input_saved_model_signature_key=input_saved_model_signature_key,
conversion_params=trt_convert.DEFAULT_TRT_CONVERSION_PARAMS._replace(
max_workspace_size_bytes=10 << 20, # Use a smaller workspace.
precision_mode=precision_mode,
is_dynamic_op=True,
maximum_cached_engines=2))
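For context, the conversion_params object used above behaves like a namedtuple, so _replace returns a copy with only the named fields overridden while the shared defaults stay untouched. A minimal standalone sketch of that pattern (the ConversionParams fields below are illustrative, not the real TrtConversionParams definition):

import collections

# Hypothetical stand-in for the conversion-params namedtuple.
ConversionParams = collections.namedtuple(
    'ConversionParams', ['max_workspace_size_bytes', 'precision_mode'])
DEFAULTS = ConversionParams(max_workspace_size_bytes=1 << 30,
                            precision_mode='FP32')

# _replace leaves DEFAULTS unchanged and returns a modified copy.
params = DEFAULTS._replace(max_workspace_size_bytes=10 << 20,  # smaller workspace
                           precision_mode='FP16')
print(params.max_workspace_size_bytes)  # 10485760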
@ -493,8 +496,8 @@ class TrtConvertTest(test_util.TensorFlowTestCase):
root = model_class()
save.save(root, input_saved_model_dir)
converter = trt_convert.TrtGraphConverterV2(
input_saved_model_dir=input_saved_model_dir)
converter = self._CreateConverterV2(
input_saved_model_dir, input_saved_model_signature_key=signature_key)
converter.convert()
output_saved_model_dir = self.mkdtemp()
converter.save(output_saved_model_dir)

View File

@ -1577,22 +1577,6 @@ tf_py_test(
],
)
tf_py_test(
name = "training_v2_utils_test",
size = "medium",
srcs = ["engine/training_v2_utils_test.py"],
additional_deps = [
":keras",
"@absl_py//absl/testing:parameterized",
"//third_party/py/numpy",
"//tensorflow/python:client_testlib",
],
tags = [
"nomac", # TODO(mihaimaruseac): b/127695564
"notsan",
],
)
tf_py_test(
name = "model_subclassing_test",
size = "medium",

View File

@ -143,6 +143,7 @@ class CallbackCountsTest(keras_parameterized.TestCase):
def test_callback_hooks_are_called_in_fit(self, data):
x, y = data
val_x, val_y = np.ones((4, 10)), np.ones((4, 1))
is_sequence = isinstance(x, keras.utils.data_utils.Sequence)
model = self._get_model()
counter = Counter()
@ -150,7 +151,8 @@ class CallbackCountsTest(keras_parameterized.TestCase):
x,
y,
validation_data=(val_x, val_y),
batch_size=2,
batch_size=2 if not is_sequence else None,
steps_per_epoch=5 if is_sequence else None,
epochs=5,
callbacks=[counter])
@ -178,10 +180,16 @@ class CallbackCountsTest(keras_parameterized.TestCase):
('with_sequence', _get_sequence()))
def test_callback_hooks_are_called_in_evaluate(self, data):
x, y = data
is_sequence = isinstance(x, keras.utils.data_utils.Sequence)
model = self._get_model()
counter = Counter()
model.evaluate(x, y, batch_size=2, callbacks=[counter])
model.evaluate(
x,
y,
batch_size=2 if not is_sequence else None,
steps=5 if is_sequence else None,
callbacks=[counter])
self._check_counts(
counter, {
'on_test_batch_begin': 5,
@ -194,10 +202,15 @@ class CallbackCountsTest(keras_parameterized.TestCase):
('with_sequence', _get_sequence()))
def test_callback_hooks_are_called_in_predict(self, data):
x = data[0]
is_sequence = isinstance(x, keras.utils.data_utils.Sequence)
model = self._get_model()
counter = Counter()
model.predict(x, batch_size=2, callbacks=[counter])
model.predict(
x,
batch_size=2 if not is_sequence else None,
steps=5 if is_sequence else None,
callbacks=[counter])
self._check_counts(
counter, {
'on_predict_batch_begin': 5,
@ -270,6 +283,21 @@ class KerasCallbacksTest(keras_parameterized.TestCase):
model.fit(dataset, epochs=2, steps_per_epoch=10)
self.assertRegexpMatches(printed.contents(), expected_log)
@keras_parameterized.run_with_all_model_types
@keras_parameterized.run_all_keras_modes
def test_progbar_logging_validation_data(self):
model = self._get_model(input_shape=(3,))
x = array_ops.ones((50, 3))
y = array_ops.zeros((50, 2))
training_dataset = dataset_ops.Dataset.from_tensor_slices((x, y)).batch(10)
val_dataset = dataset_ops.Dataset.from_tensor_slices((x, y)).batch(10)
expected_log = r'(.*5/5.*- loss:.*- my_acc:.*- val_loss:.*- val_my_acc:.*)+'
with self.captureWritesToStream(sys.stdout) as printed:
model.fit(training_dataset, epochs=2, validation_data=val_dataset)
self.assertRegexpMatches(printed.contents(), expected_log)
@keras_parameterized.run_with_all_model_types
def test_ModelCheckpoint(self):
if h5py is None:

View File

@ -137,6 +137,38 @@ def unwrap_values(distribution_strategy, grouped_inputs, grouped_outputs,
return all_inputs, all_outputs, all_updates, all_session_args
def unwrap_output_dict(strategy, grouped_outputs, mode):
"""Unwrap the list of outputs contained in the PerReplica parameters."""
if mode == ModeKeys.PREDICT:
return flatten_per_replica_values(strategy, grouped_outputs)
# In the case of fit/eval, the grouped_outputs is a dict, whereas in predict,
# the output has the same structure as the model output. They need to be
# treated differently.
total_loss = strategy.reduce(reduce_util.ReduceOp.SUM,
grouped_outputs['total_loss'][0], axis=None)
output_losses = flatten_per_replica_values(strategy,
grouped_outputs['output_losses'])
metrics = flatten_per_replica_values(strategy,
grouped_outputs['metrics'])
batch_size = strategy.reduce(reduce_util.ReduceOp.SUM,
grouped_outputs['batch_size'], axis=None)
if (is_tpu_strategy(strategy) and
ops.executing_eagerly_outside_functions()):
# Choose 1 value per replica in the TPU case since all replicas produce the
# same output.
# We only do this in eager mode for now, since this function is used in
# both graph and eager mode, and in the graph case we currently don't use
# experimental_run; this slicing would need to be removed when we converge
# the graph code path as well.
output_losses = output_losses[::strategy.num_replicas_in_sync]
metrics = metrics[::strategy.num_replicas_in_sync]
return {'total_loss': [total_loss],
'output_losses': output_losses,
'metrics': metrics,
'batch_size': batch_size}
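A rough sketch of the structure returned above and of how the v2 loop consumes it (placeholder numbers stand in for the reduced/flattened tensors; compare the batch_outs handling in the training loop changes further down):

# Hypothetical unwrapped result for a two-output model with one metric.
unwrapped = {
    'total_loss': [0.25],           # overall loss, always wrapped in a list
    'output_losses': [0.10, 0.15],  # one entry per model output
    'metrics': [0.90],              # one entry per compiled metric
    'batch_size': 32,               # summed across replicas
}

# Callers that still expect the legacy flat list rebuild it like this:
flat_outputs = (unwrapped['total_loss'] + unwrapped['output_losses'] +
                unwrapped['metrics'])
data_batch_size = unwrapped['batch_size']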
def unwrap_outputs(distribution_strategy, grouped_outputs,
with_loss_tensor=False):
"""Unwrap the list of outputs contained in the PerReplica parameters.

View File

@ -70,7 +70,7 @@ class KerasPremadeModelsTest(test.TestCase, parameterized.TestCase):
hist = model.fit(inputs, output, epochs=5)
else:
hist = model.fit(get_dataset(), epochs=5)
self.assertLess(hist.history['loss'][4], 0.1)
self.assertLess(hist.history['loss'][4], 0.2)
@combinations.generate(strategy_combinations_eager_data_fn())
def test_wide_deep_model(self, distribution, data_fn):

View File

@ -483,7 +483,7 @@ class Model(network.Network):
def run_eagerly(self, value):
self._run_eagerly = value
def _select_training_loop(self, inputs, callbacks):
def _select_training_loop(self, inputs):
"""Select training loop for fit/eval/predict based on the inputs."""
# TODO(kaftan) or TODO(scottzhu): This check should eventually be nicely
# integrated into the data adapters in the v2 loop. We can't do this yet
@ -501,9 +501,7 @@ class Model(network.Network):
if (context.executing_eagerly()
and self._experimental_run_tf_function
and not distributed_training_utils.is_tpu_strategy(
self._distribution_strategy)
and not training_v2_utils.should_fallback_to_v1_for_callback(
inputs, callbacks)):
self._distribution_strategy)):
try:
valid_adapter = data_adapter.select_data_adapter(inputs, None)
except ValueError as data_failure_exception:
@ -713,7 +711,7 @@ class Model(network.Network):
self._assert_compile_was_called()
self._check_call_args('fit')
func = self._select_training_loop(x, callbacks)
func = self._select_training_loop(x)
return func.fit(
self,
x=x,
@ -826,7 +824,7 @@ class Model(network.Network):
self._assert_compile_was_called()
self._check_call_args('evaluate')
func = self._select_training_loop(x, callbacks)
func = self._select_training_loop(x)
return func.evaluate(
self,
x=x,
@ -904,7 +902,7 @@ class Model(network.Network):
_keras_api_gauge.get_cell('predict').set(True)
self._check_call_args('predict')
func = self._select_training_loop(x, callbacks)
func = self._select_training_loop(x)
return func.predict(
self,
x=x,
@ -979,6 +977,8 @@ class Model(network.Network):
outputs = training_v2_utils.train_on_batch(
self, x, y=y, sample_weight=sample_weight,
class_weight=class_weight, reset_metrics=reset_metrics)
outputs = (outputs['total_loss'] + outputs['output_losses'] +
outputs['metrics'])
outputs = [
training_v2_utils._non_none_constant_value(v) for v in outputs] # pylint: disable=protected-access
if len(outputs) == 1:
@ -1002,12 +1002,14 @@ class Model(network.Network):
# for each replica by `self._distribution_strategy` and the same code path
# as Eager is expected to be taken.
if self.run_eagerly or self._distribution_strategy:
outputs = training_eager.train_on_batch(
output_dict = training_eager.train_on_batch(
self,
x,
y,
sample_weights=sample_weights,
output_loss_metrics=self._output_loss_metrics)
outputs = (output_dict['total_loss'] + output_dict['output_losses']
+ output_dict['metrics'])
outputs = [
training_v2_utils._non_none_constant_value(v) for v in outputs] # pylint: disable=protected-access
else:
@ -1072,6 +1074,8 @@ class Model(network.Network):
outputs = training_v2_utils.test_on_batch(
self, x, y=y, sample_weight=sample_weight,
reset_metrics=reset_metrics)
outputs = (outputs['total_loss'] + outputs['output_losses'] +
outputs['metrics'])
outputs = [
training_v2_utils._non_none_constant_value(v) for v in outputs] # pylint: disable=protected-access
if len(outputs) == 1:
@ -1089,12 +1093,14 @@ class Model(network.Network):
# If `self._distribution_strategy` is True, then we are in a replica context
# at this point.
if self.run_eagerly or self._distribution_strategy:
outputs = training_eager.test_on_batch(
output_dict = training_eager.test_on_batch(
self,
x,
y,
sample_weights=sample_weights,
output_loss_metrics=self._output_loss_metrics)
outputs = (output_dict['total_loss'] + output_dict['output_losses']
+ output_dict['metrics'])
outputs = [
training_v2_utils._non_none_constant_value(v) for v in outputs] # pylint: disable=protected-access
else:
@ -1763,9 +1769,16 @@ class Model(network.Network):
The validated batch_size, auto-inferred from the first layer if not
provided.
"""
if batch_size is not None and isinstance(x, dataset_ops.DatasetV2):
raise ValueError('The `batch_size` argument must not be specified when'
' using dataset as an input.')
if (isinstance(x, (dataset_ops.DatasetV1,
dataset_ops.DatasetV2,
data_utils.Sequence)) or
tf_inspect.isgenerator(x)):
if batch_size is not None:
raise ValueError(
'The `batch_size` argument must not be specified for the given '
'input type. Received input: {}, batch_size: {}'.format(
x, batch_size))
return
layers = super(Model, self).layers # Avoids the override in Sequential.
if layers:
@ -1819,13 +1832,7 @@ class Model(network.Network):
if steps is None:
batch_size = static_batch_size
if (batch_size is None
and steps is None
and not isinstance(x, (dataset_ops.DatasetV2,
iterator_ops.Iterator,
iterator_ops.IteratorV2,
data_utils.Sequence))
and not tf_inspect.isgenerator(x)):
if batch_size is None and steps is None:
# Backwards compatibility
batch_size = 32
return batch_size

View File

@ -129,19 +129,16 @@ class TestTrainingWithDataset(keras_parameterized.TestCase):
sample_weight=sample_weight)
# Test invalid usage
with self.assertRaisesRegexp(ValueError, 'The `batch_size` argument'
' must not be specified when using dataset'
' as an input.'):
with self.assertRaisesRegexp(
ValueError, 'The `batch_size` argument must not be specified'):
model.fit(dataset, batch_size=10, epochs=1, steps_per_epoch=2,
verbose=0)
with self.assertRaisesRegexp(ValueError, 'The `batch_size` argument'
' must not be specified when using dataset'
' as an input.'):
with self.assertRaisesRegexp(
ValueError, 'The `batch_size` argument must not be specified'):
model.predict(dataset, batch_size=10, steps=2, verbose=0)
with self.assertRaisesRegexp(ValueError, 'The `batch_size` argument'
' must not be specified when using dataset'
' as an input.'):
with self.assertRaisesRegexp(
ValueError, 'The `batch_size` argument must not be specified'):
model.evaluate(dataset, batch_size=10, steps=2, verbose=0)
with self.assertRaisesRegexp(ValueError,

View File

@ -294,7 +294,11 @@ def train_on_batch(model,
loss values.
Returns:
total loss and the loss associated with each output.
Dict with three items:
'total_loss': list with a single tensor for the overall loss,
'output_losses': list of tensors for the loss corresponding to each of the
model outputs. Could be an empty list when the model has only one output.
'metrics': list of tensors for the metrics specified.
"""
inputs = training_utils.cast_to_model_input_dtypes(inputs, model)
outs, total_loss, output_losses, masks = (
@ -310,9 +314,9 @@ def train_on_batch(model,
metrics_results = _eager_metrics_fn(
model, outs, targets, sample_weights=sample_weights, masks=masks)
total_loss = nest.flatten(total_loss)
results = total_loss + output_losses + metrics_results
return results
return {'total_loss': total_loss,
'output_losses': output_losses,
'metrics': metrics_results}
def test_on_batch(model,
@ -331,7 +335,11 @@ def test_on_batch(model,
loss values.
Returns:
total loss, loss and metrics associated with each output.
Dict with three items:
'total_loss': single tensor for the overall loss,
'output_losses': list of tensors for the loss corresponding to each of the
model outputs. Could be an empty list when the model has only one output.
'metrics': list of tensors for the metrics specified.
"""
inputs = training_utils.cast_to_model_input_dtypes(inputs, model)
@ -349,6 +357,7 @@ def test_on_batch(model,
metrics_results = _eager_metrics_fn(
model, outs, targets, sample_weights=sample_weights, masks=masks)
total_loss = nest.flatten(total_loss)
results = total_loss + output_losses + metrics_results
return results
return {'total_loss': total_loss,
'output_losses': output_losses,
'metrics': metrics_results}

View File

@ -314,6 +314,34 @@ class TestGeneratorMethods(ForkRobustTestCase):
model.evaluate(ones_generator(), steps=2)
model.predict(ones_generator(), steps=2)
@keras_parameterized.run_with_all_model_types
@keras_parameterized.run_all_keras_modes
def test_invalid_batch_size_argument(self):
def ones_generator():
while True:
yield np.ones([10, 10], np.float32), np.ones([10, 1], np.float32)
model = testing_utils.get_small_mlp(
num_hidden=10, num_classes=1, input_dim=10)
model.compile(
'adam',
'binary_crossentropy',
run_eagerly=testing_utils.should_run_eagerly(),
experimental_run_tf_function=testing_utils.should_run_tf_function())
with self.assertRaisesRegexp(
ValueError, 'The `batch_size` argument must not be specified'):
model.fit(ones_generator(), batch_size=2, epochs=2)
with self.assertRaisesRegexp(
ValueError, 'The `batch_size` argument must not be specified'):
model.evaluate(ones_generator(), batch_size=2)
with self.assertRaisesRegexp(
ValueError, 'The `batch_size` argument must not be specified'):
model.predict(ones_generator(), batch_size=2)
class TestGeneratorMethodsWithSequences(ForkRobustTestCase):

View File

@ -1131,12 +1131,6 @@ class TrainingTest(keras_parameterized.TestCase):
'incompatible with the specified batch size'):
model.fit(x, y, batch_size=4)
data = dataset_ops.DatasetV2.from_tensor_slices((x, y))
data = data.batch(4, drop_remainder=True)
with self.assertRaisesRegexp(ValueError,
'incompatible with the specified batch size'):
model.fit(data, steps_per_epoch=16)
@tf_test_util.run_in_graph_and_eager_modes
def test_compatible_batch_size_functional_model(self):
@ -1563,11 +1557,10 @@ class TestExceptionsAndWarnings(keras_parameterized.TestCase):
'sgd',
loss='mse',
run_eagerly=testing_utils.should_run_eagerly(),
experimental_run_tf_function=testing_utils.should_run_tf_function())
experimental_run_tf_function=False)
err_msg = 'When passing input data as arrays, do not specify'
if testing_utils.should_run_eagerly(
) and not model._experimental_run_tf_function:
if testing_utils.should_run_eagerly():
with self.assertRaisesRegex(ValueError, err_msg):
model.fit(x=np.zeros((100, 1)), y=np.ones((100, 1)), steps_per_epoch=4)
@ -1581,11 +1574,42 @@ class TestExceptionsAndWarnings(keras_parameterized.TestCase):
model._standardize_user_data(
np.zeros((100, 1)),
np.ones((100, 1)),
batch_size=25,
check_steps=True,
steps=4)
self.assertRegexpMatches(str(mock_log.call_args), err_msg)
@keras_parameterized.run_with_all_model_types
@keras_parameterized.run_all_keras_modes
def test_invalid_batch_size_argument_with_sequence_input(self):
class DummySequence(keras.utils.Sequence):
def __getitem__(self, idx):
return np.zeros([10, 2]), np.ones([10, 4])
def __len__(self):
return 10
model = testing_utils.get_small_mlp(
num_hidden=10, num_classes=1, input_dim=10)
model.compile(
'adam',
'binary_crossentropy',
run_eagerly=testing_utils.should_run_eagerly(),
experimental_run_tf_function=testing_utils.should_run_tf_function())
with self.assertRaisesRegexp(
ValueError, 'The `batch_size` argument must not be specified'):
model.fit(DummySequence(), batch_size=2, epochs=2)
with self.assertRaisesRegexp(
ValueError, 'The `batch_size` argument must not be specified'):
model.evaluate(DummySequence(), batch_size=2)
with self.assertRaisesRegexp(
ValueError, 'The `batch_size` argument must not be specified'):
model.predict(DummySequence(), batch_size=2)
class LossWeightingTest(keras_parameterized.TestCase):

View File

@ -35,6 +35,7 @@ from tensorflow.python.keras.engine import training_utils
from tensorflow.python.keras.engine import training_v2_utils
from tensorflow.python.keras.utils.mode_keys import ModeKeys
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import nest
from tensorflow.python.util import tf_contextlib
@ -148,11 +149,15 @@ def run_one_epoch(model,
batch_logs['data_exhausted'] = True
break
if not isinstance(batch_outs, list):
batch_outs = [batch_outs]
if strategy:
batch_outs = dist_utils._per_replica_aggregate_batch(
strategy, batch_outs, model, mode)
if mode != ModeKeys.PREDICT:
data_batch_size = batch_outs['batch_size']
batch_outs = (batch_outs['total_loss'] + batch_outs['output_losses']
+ batch_outs['metrics'])
if current_batch_size != data_batch_size:
batch_logs['size'] = data_batch_size
current_batch_size = data_batch_size
else:
batch_outs = _aggregate_predict_results(strategy, batch_outs, model)
if step == 0:
aggregator.create(batch_outs)
@ -215,6 +220,9 @@ class Loop(training_utils.TrainingLoop):
use_sample = total_samples is not None
do_validation = (validation_adapter is not None)
# TODO(psv): Add step inference for when steps/val_steps is None to
# prevent end of sequence warning message.
if not steps_per_epoch:
steps_per_epoch = training_data_adapter.get_size()
@ -261,21 +269,28 @@ class Loop(training_utils.TrainingLoop):
epochs=0)
validation_dataset = strategy.experimental_distribute_dataset(
validation_dataset)
val_total_samples = _get_total_number_of_samples(validation_adapter)
else:
val_total_samples = None
callbacks = cbks.configure_callbacks(
if verbose and (total_samples or steps_per_epoch):
_print_train_info(total_samples, steps_per_epoch, val_total_samples,
validation_steps)
training_callbacks = cbks.configure_callbacks(
callbacks,
model,
do_validation=do_validation,
batch_size=batch_size,
epochs=epochs,
steps_per_epoch=steps_per_epoch,
samples=total_samples,
samples=total_samples or steps_per_epoch,
count_mode='samples' if use_sample else 'steps',
verbose=0, # Handle ProgBarLogger separately in this loop.
mode=ModeKeys.TRAIN)
with training_context.on_start(
model, callbacks, use_sample, verbose, ModeKeys.TRAIN):
with training_context.on_start(model, training_callbacks, use_sample,
verbose, ModeKeys.TRAIN):
# TODO(scottzhu): Handle TPUStrategy training loop
for epoch in range(initial_epoch, epochs):
if training_context.callbacks.model.stop_training:
@ -310,7 +325,7 @@ class Loop(training_utils.TrainingLoop):
# Evaluation
if (do_validation and
training_utils.should_run_validation(validation_freq, epoch) and
not callbacks.model.stop_training):
not training_callbacks.model.stop_training):
if (eval_data_iter is not None and
distribution_strategy_context.has_strategy()):
# TODO(kaftan): remove this when MultiDeviceIterator is a
@ -319,11 +334,24 @@ class Loop(training_utils.TrainingLoop):
else:
eval_data_iter = iter(validation_dataset)
val_total_samples = _get_total_number_of_samples(
validation_adapter)
validation_callbacks = cbks.configure_callbacks(
training_callbacks,
model,
batch_size=batch_size,
epochs=1,
steps_per_epoch=validation_steps,
samples=val_total_samples or validation_steps,
count_mode='samples' if use_sample else 'steps',
verbose=0, # Handle ProgBarLogger separately in this loop.
mode=ModeKeys.TEST)
eval_context = TrainingContext()
with eval_context.on_start(
model, callbacks, use_sample, verbose=0, mode=ModeKeys.TEST):
model,
validation_callbacks,
use_sample,
verbose=0,
mode=ModeKeys.TEST):
with eval_context.on_epoch(epoch, ModeKeys.TEST):
model.reset_metrics()
eval_result = run_one_epoch(
@ -469,10 +497,12 @@ def _process_training_inputs(model, x, y, batch_size=None,
# Retrieve the training section from x and y, and then construct dataset
# from it.
x, y, sample_weights = model._standardize_user_data(
x, y, sample_weight=sample_weights,
x,
y,
sample_weight=sample_weights,
class_weight=class_weights,
batch_size=batch_size,
check_steps=True,
check_steps=False,
steps=steps_per_epoch)
(x, y, sample_weights,
val_x, val_y,
@ -525,7 +555,7 @@ def _process_inputs(model, x, y, batch_size=None, sample_weights=None,
sample_weight=sample_weights,
class_weight=class_weights,
batch_size=batch_size,
check_steps=True,
check_steps=False,
steps=steps)
adapter = adapter_cls(x, y, batch_size=batch_size, steps=steps,
sample_weights=sample_weights, shuffle=shuffle,
@ -546,6 +576,30 @@ def _get_total_number_of_samples(adapter):
return total_sample
def _aggregate_predict_results(strategy, batch_outs, model):
if not isinstance(batch_outs, list):
batch_outs = [batch_outs]
total_batch_outs = []
for i in range(len(model.outputs)):
num_replicas = strategy.num_replicas_in_sync
nested_outs = batch_outs[i * num_replicas:i * num_replicas + num_replicas]
total_batch_outs.append(
dist_utils.concat_along_batch_dimension(nest.flatten(nested_outs)))
return total_batch_outs
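A small illustration of the slicing above, assuming two model outputs and two replicas; the flat per-replica list is regrouped so each model output gets its replica results concatenated together (the string labels are placeholders for the real tensors):

num_replicas = 2
num_outputs = 2
# Flat order implied by the slicing: all replicas of output 0, then output 1.
batch_outs = ['out0_replica0', 'out0_replica1', 'out1_replica0', 'out1_replica1']
for i in range(num_outputs):
    nested_outs = batch_outs[i * num_replicas:i * num_replicas + num_replicas]
    print(nested_outs)
# ['out0_replica0', 'out0_replica1']
# ['out1_replica0', 'out1_replica1']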
def _print_train_info(total_samples, steps, val_total_samples, val_steps):
increment = 'samples' if total_samples else 'steps'
conjunction = 'on' if total_samples else 'for'
msg = 'Train {} {} {}'.format(conjunction, total_samples or steps, increment)
if val_total_samples or val_steps:
increment = 'samples' if val_total_samples else 'steps'
conjunction = 'on' if val_total_samples else 'for'
msg += ', validate {} {} {}'.format(conjunction, val_total_samples or
val_steps, increment)
print(msg)
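The resulting message looks roughly like this (sample values only):

_print_train_info(total_samples=50, steps=None, val_total_samples=50, val_steps=None)
# Train on 50 samples, validate on 50 samples
_print_train_info(total_samples=None, steps=5, val_total_samples=None, val_steps=2)
# Train for 5 steps, validate for 2 steps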
class TrainingContext(object):
"""Utility object that wrap around callbacks and progress bars."""

View File

@ -28,15 +28,17 @@ import functools
from tensorflow.python.distribute import distribution_strategy_context
from tensorflow.python.eager import def_function
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import tensor_util
from tensorflow.python.framework.ops import composite_tensor
from tensorflow.python.keras import backend
from tensorflow.python.keras import callbacks as cbks
from tensorflow.python.keras.distribute import distributed_training_utils as dist_utils
from tensorflow.python.keras.engine import data_adapter
from tensorflow.python.keras.engine import training_eager
from tensorflow.python.keras.engine import training_utils
from tensorflow.python.keras.utils.mode_keys import ModeKeys
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.util import nest
def _get_or_make_execution_function(model, mode):
@ -70,8 +72,8 @@ def _make_execution_function(model, mode):
outputs = strategy.experimental_run_v2(
per_replica_function, args=(model, x, y, sample_weights))
# Out of PerReplica outputs reduce or pick values to return.
all_outputs = dist_utils.unwrap_outputs(
strategy, outputs, with_loss_tensor=(mode != ModeKeys.PREDICT))
all_outputs = dist_utils.unwrap_output_dict(
strategy, outputs, mode)
return all_outputs
if not model.run_eagerly:
@ -80,8 +82,8 @@ def _make_execution_function(model, mode):
def execution_function(input_fn):
# `numpy` translates Tensors to values in Eager mode.
return [_non_none_constant_value(out)
for out in distributed_function(input_fn)]
return nest.map_structure(_non_none_constant_value,
distributed_function(input_fn))
return execution_function
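The switch from a list comprehension to nest.map_structure reflects that the distributed function now returns the dict built by unwrap_output_dict rather than a flat list; map_structure applies the conversion to every leaf while preserving the nesting. A minimal sketch with eager tensors, using .numpy() as a stand-in for _non_none_constant_value:

import tensorflow as tf

outs = {'total_loss': [tf.constant(0.25)],
        'output_losses': [],
        'metrics': [tf.constant(0.9)],
        'batch_size': tf.constant(32, dtype=tf.int64)}

# Converts every tensor leaf to a numpy value, keeping the dict/list structure.
numpy_outs = tf.nest.map_structure(lambda t: t.numpy(), outs)
print(numpy_outs['total_loss'])  # [0.25]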
@ -192,43 +194,6 @@ def _prepare_model_with_inputs(model, dataset):
model.sample_weight_mode)
def should_fallback_to_v1_for_callback(inputs, callbacks):
"""Whether to fallback to v1 training loop because of callbacks.
This is only a temporary solution until the v2 training loop is fixed for
using batch based callbacks.
Args:
inputs: the inputs to the model. Certain input type might not handle certain
callbacks well if it need batch based counting.
callbacks: list of callbacks configured for the fit/eval/predict.
Returns:
boolean, whether it should fallbacks to use v1 training loop.
"""
try:
adapter_cls = data_adapter.select_data_adapter(inputs, None)
if adapter_cls not in (data_adapter.GeneratorDataAdapter,
data_adapter.DatasetAdapter):
# For any input data for which we know the overall size, e.g. numpy, list of
# lists, etc., we don't need to fall back since the v2 loop can get the batch
# size.
return False
except ValueError:
# In case we can't find the adapter, then we should fallback to v1.
return True
callbacks = callbacks or []
for c in callbacks:
if isinstance(c, cbks.ModelCheckpoint) and isinstance(c.save_freq, int):
return True
elif (isinstance(c, cbks.TensorBoard) and
isinstance(c.update_freq, int) and
c.update_freq > 1):  # This is an implementation detail for TB.
return True
return False
def train_on_batch(
model,
x,
@ -286,7 +251,7 @@ def train_on_batch(
x, y, sample_weights = model._standardize_user_data(
x, y, sample_weight=sample_weight, class_weight=class_weight,
extract_tensors_from_dataset=True)
batch_size = array_ops.shape(nest.flatten(x, expand_composites=True)[0])[0]
# If `model._distribution_strategy` is True, then we are in a replica context
# at this point because of the check above. `train_on_batch` is being run
# for each replica by `model._distribution_strategy` and the same code path
@ -301,6 +266,7 @@ def train_on_batch(
if reset_metrics:
model.reset_metrics()
outputs['batch_size'] = math_ops.cast(batch_size, dtypes.int64)
return outputs
@ -352,6 +318,7 @@ def test_on_batch(model, x, y=None, sample_weight=None, reset_metrics=True):
x, y, sample_weights = model._standardize_user_data(
x, y, sample_weight=sample_weight, extract_tensors_from_dataset=True)
batch_size = array_ops.shape(nest.flatten(x, expand_composites=True)[0])[0]
outputs = training_eager.test_on_batch(
model,
x,
@ -362,6 +329,7 @@ def test_on_batch(model, x, y=None, sample_weight=None, reset_metrics=True):
if reset_metrics:
model.reset_metrics()
outputs['batch_size'] = math_ops.cast(batch_size, dtypes.int64)
return outputs

View File

@ -1,106 +0,0 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for training utility functions."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import constant_op
from tensorflow.python.keras import callbacks as cbks
from tensorflow.python.keras.engine import training_v2_utils
from tensorflow.python.keras.utils import data_utils
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
class TestSequence(data_utils.Sequence):
def __init__(self, batch_size, feature_shape):
self.batch_size = batch_size
self.feature_shape = feature_shape
def __getitem__(self, item):
return (np.zeros((self.batch_size, self.feature_shape)),
np.ones((self.batch_size,)))
def __len__(self):
return 10
class CallbackFallbackTest(test.TestCase):
def setUp(self):
super(CallbackFallbackTest, self).setUp()
self.batch_size = 5
self.numpy_input = np.zeros((50, 10))
self.numpy_target = np.ones(50)
self.tensor_input = constant_op.constant(2.0, shape=(50, 10))
self.tensor_target = array_ops.ones((50,))
self.dataset_input = dataset_ops.DatasetV2.from_tensor_slices(
(self.numpy_input, self.numpy_target)).shuffle(50).batch(
self.batch_size)
def generator():
yield (np.zeros((self.batch_size, 10)), np.ones(self.batch_size))
self.generator_input = generator()
self.sequence_input = TestSequence(batch_size=self.batch_size,
feature_shape=10)
self.fallback_ckeckpoint_cb = cbks.ModelCheckpoint(
self.get_temp_dir(), save_freq=10)
self.normal_checkpoint_cb = cbks.ModelCheckpoint(
self.get_temp_dir(), save_freq='epoch')
self.fallback_tensorboard_cb = cbks.TensorBoard(update_freq=10)
self.normal_tensorboard_cb = cbks.TensorBoard(update_freq='batch')
self.unaffected_cb = cbks.CSVLogger(self.get_temp_dir())
def test_not_fallback_based_on_input(self):
callback_list = [self.fallback_ckeckpoint_cb]
test_cases = [
[(self.numpy_input, self.numpy_target), False],
[[self.tensor_input, self.tensor_target], False],
[self.sequence_input, False],
[self.dataset_input, True],
[self.generator_input, True],
]
for case in test_cases:
inputs, expected_result = case
self.assertEqual(training_v2_utils.should_fallback_to_v1_for_callback(
inputs, callback_list), expected_result)
def test_fallback_based_on_callbacks(self):
inputs = self.dataset_input
test_cases = [
[[self.fallback_ckeckpoint_cb], True],
[[self.normal_checkpoint_cb], False],
[[self.fallback_ckeckpoint_cb, self.normal_checkpoint_cb], True],
[[self.fallback_tensorboard_cb], True],
[[self.normal_tensorboard_cb], False],
[[self.unaffected_cb], False],
]
for case in test_cases:
callbacks, expected_result = case
self.assertEqual(training_v2_utils.should_fallback_to_v1_for_callback(
inputs, callbacks), expected_result)
if __name__ == '__main__':
test.main()