Do not wrap a single x input passed to Model.fit in a tuple

This makes Numpy and Tensor array behavior consistent with the behavior when passing a Dataset, generator, or Keras Sequence object with a single Tensor. In all cases, the single Tensor is passed to Model.train_step directly without modification

PiperOrigin-RevId: 315431108
Change-Id: I65d7a57967ffa54ae7786029d235c7f3c37da80f
This commit is contained in:
Thomas O'Malley 2020-06-08 23:59:21 -07:00 committed by TensorFlower Gardener
parent 13c09da422
commit d570632ba8
3 changed files with 68 additions and 7 deletions

View File

@ -1500,12 +1500,13 @@ def pack_x_y_sample_weight(x, y=None, sample_weight=None):
>>> x = tf.ones((10, 1))
>>> data = tf.keras.utils.pack_x_y_sample_weight(x)
>>> len(data)
1
>>> isinstance(data, tf.Tensor)
True
>>> y = tf.ones((10, 1))
>>> data = tf.keras.utils.pack_x_y_sample_weight(x, y)
>>> len(data)
2
>>> isinstance(data, tuple)
True
>>> x, y = data
Arguments:
x: Features to pass to `Model`.
@ -1516,6 +1517,13 @@ def pack_x_y_sample_weight(x, y=None, sample_weight=None):
Tuple in the format used in `Model.fit`.
"""
if y is None:
# For single x-input, we do no tuple wrapping since in this case
# there is no ambiguity. This also makes NumPy and Dataset
# consistent in that the user does not have to wrap their Dataset
# data in an unecessary tuple
if not nest.is_sequence(x):
return x
else:
return (x,)
elif sample_weight is None:
return (x, y)

View File

@ -278,7 +278,7 @@ class TensorLikeDataAdapterTest(DataAdapterTestBase):
def _get_epoch(ds_iter):
ds_data = []
for _ in range(int(math.ceil(num_samples / batch_size))):
ds_data.append(next(ds_iter)[0].numpy())
ds_data.append(next(ds_iter).numpy())
return np.concatenate(ds_data)
ds_iter = iter(adapter.get_dataset())
@ -507,7 +507,7 @@ class GenericArrayLikeDataAdapterTest(DataAdapterTestBase):
def _get_epoch(ds_iter):
ds_data = []
for _ in range(int(math.ceil(num_samples / batch_size))):
ds_data.append(next(ds_iter)[0].numpy())
ds_data.append(next(ds_iter).numpy())
return np.concatenate(ds_data)
ds_iter = iter(adapter.get_dataset())
@ -981,6 +981,22 @@ class DataHandlerTest(keras_parameterized.TestCase):
2: 1.5
})
@parameterized.named_parameters(('numpy', True), ('dataset', False))
def test_single_x_input_no_tuple_wrapping(self, use_numpy):
x = np.ones((10, 1))
if use_numpy:
batch_size = 2
else:
x = dataset_ops.Dataset.from_tensor_slices(x).batch(2)
batch_size = None
data_handler = data_adapter.DataHandler(x, batch_size=batch_size)
for _, iterator in data_handler.enumerate_epochs():
for _ in data_handler.steps():
# Check that single x input is not wrapped in a tuple.
self.assertIsInstance(next(iterator), ops.Tensor)
class TestValidationSplit(keras_parameterized.TestCase):

View File

@ -1558,6 +1558,43 @@ class TrainingTest(keras_parameterized.TestCase):
# assign_add not called.
self.assertEqual(self.evaluate(layer.v), 1.)
@keras_parameterized.run_all_keras_modes(always_skip_v1=True)
@parameterized.named_parameters(('numpy', True), ('dataset', False))
def test_single_input_no_tuple_wrapping(self, use_numpy):
x = np.ones((10, 1))
if use_numpy:
batch_size = 3
else:
x = dataset_ops.Dataset.from_tensor_slices(x).batch(3)
batch_size = None
test_case = self
class MyModel(training_module.Model):
def train_step(self, data):
# No tuple wrapping for single x input and no targets.
test_case.assertIsInstance(data, ops.Tensor)
return super(MyModel, self).train_step(data)
def test_step(self, data):
test_case.assertIsInstance(data, ops.Tensor)
return super(MyModel, self).test_step(data)
def predict_step(self, data):
test_case.assertIsInstance(data, ops.Tensor)
return super(MyModel, self).predict_step(data)
inputs = layers_module.Input(1)
outputs = layers_module.Dense(1)(inputs)
model = MyModel(inputs, outputs)
model.add_loss(math_ops.reduce_sum(outputs))
model.compile('sgd', 'mse')
model.fit(x, batch_size=batch_size)
model.evaluate(x, batch_size=batch_size)
model.predict(x, batch_size=batch_size)
class TestExceptionsAndWarnings(keras_parameterized.TestCase):