1. Do not raise the "steps unsupported with numpy arrays" warning message in the single execution path.

2. Raise an error if the `batch_size` argument is used when the input is a dataset, generator, or Keras Sequence.

PiperOrigin-RevId: 263272222
Pavithra Vijay 2019-08-13 20:36:51 -07:00 committed by TensorFlower Gardener
parent 20180d323c
commit a59ad83d06
6 changed files with 100 additions and 35 deletions
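
In practical terms, change 2 means any call that combines `batch_size` with an input type that already carries its own batching now fails fast. A minimal sketch of the new behavior (the toy model and data are illustrative, not part of this commit):

    import numpy as np
    import tensorflow as tf

    # Illustrative model; any compiled Keras model behaves the same way.
    model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(10,))])
    model.compile('sgd', 'mse')

    # A tf.data.Dataset defines its own batching via .batch().
    dataset = tf.data.Dataset.from_tensor_slices(
        (np.ones((20, 10)), np.ones((20, 1)))).batch(4)

    model.fit(dataset, epochs=1)                # OK
    model.fit(dataset, batch_size=4, epochs=1)  # ValueError: The `batch_size`
                                                # argument must not be specified...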


@@ -147,6 +147,7 @@ class CallbackCountsTest(keras_parameterized.TestCase):
   def test_callback_hooks_are_called_in_fit(self, data):
     x, y = data
     val_x, val_y = np.ones((4, 10)), np.ones((4, 1))
+    is_sequence = isinstance(x, keras.utils.data_utils.Sequence)
 
     model = self._get_model()
     counter = Counter()
@@ -154,7 +155,8 @@ class CallbackCountsTest(keras_parameterized.TestCase):
     model.fit(
         x,
         y,
         validation_data=(val_x, val_y),
-        batch_size=2,
+        batch_size=2 if not is_sequence else None,
+        steps_per_epoch=5 if is_sequence else None,
         epochs=5,
         callbacks=[counter])
@@ -182,10 +184,16 @@ class CallbackCountsTest(keras_parameterized.TestCase):
       ('with_sequence', _get_sequence()))
   def test_callback_hooks_are_called_in_evaluate(self, data):
     x, y = data
+    is_sequence = isinstance(x, keras.utils.data_utils.Sequence)
 
     model = self._get_model()
     counter = Counter()
-    model.evaluate(x, y, batch_size=2, callbacks=[counter])
+    model.evaluate(
+        x,
+        y,
+        batch_size=2 if not is_sequence else None,
+        steps=5 if is_sequence else None,
+        callbacks=[counter])
     self._check_counts(
         counter, {
             'on_test_batch_begin': 5,
@@ -198,10 +206,15 @@ class CallbackCountsTest(keras_parameterized.TestCase):
       ('with_sequence', _get_sequence()))
   def test_callback_hooks_are_called_in_predict(self, data):
     x = data[0]
+    is_sequence = isinstance(x, keras.utils.data_utils.Sequence)
 
     model = self._get_model()
     counter = Counter()
-    model.predict(x, batch_size=2, callbacks=[counter])
+    model.predict(
+        x,
+        batch_size=2 if not is_sequence else None,
+        steps=5 if is_sequence else None,
+        callbacks=[counter])
     self._check_counts(
         counter, {
             'on_predict_batch_begin': 5,


@@ -1801,9 +1801,16 @@ class Model(network.Network):
       The validated batch_size, auto-inferred from the first layer if not
       provided.
     """
-    if batch_size is not None and isinstance(x, dataset_ops.DatasetV2):
-      raise ValueError('The `batch_size` argument must not be specified when'
-                       ' using dataset as an input.')
+    if (isinstance(x, (dataset_ops.DatasetV1,
+                       dataset_ops.DatasetV2,
+                       data_utils.Sequence)) or
+        tf_inspect.isgenerator(x)):
+      if batch_size is not None:
+        raise ValueError(
+            'The `batch_size` argument must not be specified for the given '
+            'input type. Received input: {}, batch_size: {}'.format(
+                x, batch_size))
+      return
 
     layers = super(Model, self).layers  # Avoids the override in Sequential.
     if layers:
@@ -1857,13 +1864,7 @@ class Model(network.Network):
         if steps is None:
           batch_size = static_batch_size
 
-    if (batch_size is None
-        and steps is None
-        and not isinstance(x, (dataset_ops.DatasetV2,
-                               iterator_ops.Iterator,
-                               iterator_ops.IteratorV2,
-                               data_utils.Sequence))
-        and not tf_inspect.isgenerator(x)):
+    if batch_size is None and steps is None:
       # Backwards compatibility
       batch_size = 32
     return batch_size
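
For readers skimming the two hunks above: after this change `_validate_or_infer_batch_size` reduces to roughly the flow below. This is a simplified standalone sketch, not the TensorFlow source; `batched_types` stands in for the (DatasetV1, DatasetV2, Sequence) tuple, and the static batch-size inference from the first layer is elided.

    import inspect

    def validate_or_infer_batch_size(x, batch_size=None, steps=None,
                                     batched_types=()):
      # Inputs that carry their own batching (datasets, Keras Sequences,
      # generators) must not also receive an explicit `batch_size`.
      if isinstance(x, batched_types) or inspect.isgenerator(x):
        if batch_size is not None:
          raise ValueError(
              'The `batch_size` argument must not be specified for the given '
              'input type. Received input: {}, batch_size: {}'.format(
                  x, batch_size))
        return None  # iteration is steps-driven; nothing to infer
      # (Static batch-size inference against the first layer elided here.)
      if batch_size is None and steps is None:
        batch_size = 32  # backwards-compatibility default
      return batch_size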


@@ -129,19 +129,16 @@ class TestTrainingWithDataset(keras_parameterized.TestCase):
                   sample_weight=sample_weight)
 
     # Test invalid usage
-    with self.assertRaisesRegexp(ValueError, 'The `batch_size` argument'
-                                 ' must not be specified when using dataset'
-                                 ' as an input.'):
+    with self.assertRaisesRegexp(
+        ValueError, 'The `batch_size` argument must not be specified'):
       model.fit(dataset, batch_size=10, epochs=1, steps_per_epoch=2,
                 verbose=0)
-    with self.assertRaisesRegexp(ValueError, 'The `batch_size` argument'
-                                 ' must not be specified when using dataset'
-                                 ' as an input.'):
+    with self.assertRaisesRegexp(
+        ValueError, 'The `batch_size` argument must not be specified'):
       model.predict(dataset, batch_size=10, steps=2, verbose=0)
-    with self.assertRaisesRegexp(ValueError, 'The `batch_size` argument'
-                                 ' must not be specified when using dataset'
-                                 ' as an input.'):
+    with self.assertRaisesRegexp(
+        ValueError, 'The `batch_size` argument must not be specified'):
       model.evaluate(dataset, batch_size=10, steps=2, verbose=0)
 
     with self.assertRaisesRegexp(ValueError,


@@ -314,6 +314,34 @@ class TestGeneratorMethods(ForkRobustTestCase):
       model.evaluate(ones_generator(), steps=2)
       model.predict(ones_generator(), steps=2)
 
+  @keras_parameterized.run_with_all_model_types
+  @keras_parameterized.run_all_keras_modes
+  def test_invalid_batch_size_argument(self):
+
+    def ones_generator():
+      while True:
+        yield np.ones([10, 10], np.float32), np.ones([10, 1], np.float32)
+
+    model = testing_utils.get_small_mlp(
+        num_hidden=10, num_classes=1, input_dim=10)
+    model.compile(
+        'adam',
+        'binary_crossentropy',
+        run_eagerly=testing_utils.should_run_eagerly(),
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
+
+    with self.assertRaisesRegexp(
+        ValueError, 'The `batch_size` argument must not be specified'):
+      model.fit(ones_generator(), batch_size=2, epochs=2)
+    with self.assertRaisesRegexp(
+        ValueError, 'The `batch_size` argument must not be specified'):
+      model.evaluate(ones_generator(), batch_size=2)
+    with self.assertRaisesRegexp(
+        ValueError, 'The `batch_size` argument must not be specified'):
+      model.predict(ones_generator(), batch_size=2)
+
+
 class TestGeneratorMethodsWithSequences(ForkRobustTestCase):


@@ -1131,12 +1131,6 @@ class TrainingTest(keras_parameterized.TestCase):
                                  'incompatible with the specified batch size'):
       model.fit(x, y, batch_size=4)
 
-    data = dataset_ops.DatasetV2.from_tensor_slices((x, y))
-    data = data.batch(4, drop_remainder=True)
-
-    with self.assertRaisesRegexp(ValueError,
-                                 'incompatible with the specified batch size'):
-      model.fit(data, steps_per_epoch=16)
 
   @tf_test_util.run_in_graph_and_eager_modes
   def test_compatible_batch_size_functional_model(self):
@@ -1563,11 +1557,10 @@ class TestExceptionsAndWarnings(keras_parameterized.TestCase):
         'sgd',
         loss='mse',
         run_eagerly=testing_utils.should_run_eagerly(),
-        experimental_run_tf_function=testing_utils.should_run_tf_function())
+        experimental_run_tf_function=False)
     err_msg = 'When passing input data as arrays, do not specify'
-    if testing_utils.should_run_eagerly(
-    ) and not model._experimental_run_tf_function:
+    if testing_utils.should_run_eagerly():
       with self.assertRaisesRegex(ValueError, err_msg):
         model.fit(x=np.zeros((100, 1)), y=np.ones((100, 1)), steps_per_epoch=4)
@@ -1581,11 +1574,42 @@ class TestExceptionsAndWarnings(keras_parameterized.TestCase):
         model._standardize_user_data(
             np.zeros((100, 1)),
             np.ones((100, 1)),
             batch_size=25,
             check_steps=True,
             steps=4)
       self.assertRegexpMatches(str(mock_log.call_args), err_msg)
 
+  @keras_parameterized.run_with_all_model_types
+  @keras_parameterized.run_all_keras_modes
+  def test_invalid_batch_size_argument_with_sequence_input(self):
+
+    class DummySequence(keras.utils.Sequence):
+
+      def __getitem__(self, idx):
+        return np.zeros([10, 2]), np.ones([10, 4])
+
+      def __len__(self):
+        return 10
+
+    model = testing_utils.get_small_mlp(
+        num_hidden=10, num_classes=1, input_dim=10)
+    model.compile(
+        'adam',
+        'binary_crossentropy',
+        run_eagerly=testing_utils.should_run_eagerly(),
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
+
+    with self.assertRaisesRegexp(
+        ValueError, 'The `batch_size` argument must not be specified'):
+      model.fit(DummySequence(), batch_size=2, epochs=2)
+    with self.assertRaisesRegexp(
+        ValueError, 'The `batch_size` argument must not be specified'):
+      model.evaluate(DummySequence(), batch_size=2)
+    with self.assertRaisesRegexp(
+        ValueError, 'The `batch_size` argument must not be specified'):
+      model.predict(DummySequence(), batch_size=2)
+
+
 class LossWeightingTest(keras_parameterized.TestCase):


@@ -490,10 +490,12 @@ def _process_training_inputs(model, x, y, batch_size=None,
     # Retrieve the training section from x and y, and then construct dataset
     # from it.
     x, y, sample_weights = model._standardize_user_data(
-        x, y, sample_weight=sample_weights,
+        x,
+        y,
+        sample_weight=sample_weights,
         class_weight=class_weights,
         batch_size=batch_size,
-        check_steps=True,
+        check_steps=False,
         steps=steps_per_epoch)
     (x, y, sample_weights,
      val_x, val_y,
@@ -550,7 +552,7 @@ def _process_inputs(model, x, y, batch_size=None, sample_weights=None,
         sample_weight=sample_weights,
         class_weight=class_weights,
         batch_size=batch_size,
-        check_steps=True,
+        check_steps=False,
         steps=steps)
     adapter = adapter_cls(x, y, batch_size=batch_size, steps=steps,
                           sample_weights=sample_weights, shuffle=shuffle,
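
The two `check_steps=False` flips above implement change 1 for the v2 (single) execution path: `_standardize_user_data` no longer runs the steps check behind the 'When passing input data as arrays, do not specify' message there. A quick sketch of the now-accepted call, assuming a build where the single execution path is the default; the toy model is illustrative:

    import numpy as np
    import tensorflow as tf

    model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(1,))])
    model.compile('sgd', 'mse')

    # Numpy arrays combined with steps_per_epoch previously triggered the
    # steps-unsupported-with-arrays warning; in the single execution path
    # this now simply trains 4 steps per epoch.
    model.fit(np.zeros((100, 1)), np.ones((100, 1)), steps_per_epoch=4, epochs=1)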