Do not wrap a single x input passed to Model.fit in a tuple
This makes Numpy and Tensor array behavior consistent with the behavior when passing a Dataset, generator, or Keras Sequence object with a single Tensor. In all cases, the single Tensor is passed to Model.train_step directly without modification PiperOrigin-RevId: 315431108 Change-Id: I65d7a57967ffa54ae7786029d235c7f3c37da80f
This commit is contained in:
parent
13c09da422
commit
d570632ba8
@ -1500,12 +1500,13 @@ def pack_x_y_sample_weight(x, y=None, sample_weight=None):
|
||||
|
||||
>>> x = tf.ones((10, 1))
|
||||
>>> data = tf.keras.utils.pack_x_y_sample_weight(x)
|
||||
>>> len(data)
|
||||
1
|
||||
>>> isinstance(data, tf.Tensor)
|
||||
True
|
||||
>>> y = tf.ones((10, 1))
|
||||
>>> data = tf.keras.utils.pack_x_y_sample_weight(x, y)
|
||||
>>> len(data)
|
||||
2
|
||||
>>> isinstance(data, tuple)
|
||||
True
|
||||
>>> x, y = data
|
||||
|
||||
Arguments:
|
||||
x: Features to pass to `Model`.
|
||||
@ -1516,6 +1517,13 @@ def pack_x_y_sample_weight(x, y=None, sample_weight=None):
|
||||
Tuple in the format used in `Model.fit`.
|
||||
"""
|
||||
if y is None:
|
||||
# For single x-input, we do no tuple wrapping since in this case
|
||||
# there is no ambiguity. This also makes NumPy and Dataset
|
||||
# consistent in that the user does not have to wrap their Dataset
|
||||
# data in an unecessary tuple
|
||||
if not nest.is_sequence(x):
|
||||
return x
|
||||
else:
|
||||
return (x,)
|
||||
elif sample_weight is None:
|
||||
return (x, y)
|
||||
|
@ -278,7 +278,7 @@ class TensorLikeDataAdapterTest(DataAdapterTestBase):
|
||||
def _get_epoch(ds_iter):
|
||||
ds_data = []
|
||||
for _ in range(int(math.ceil(num_samples / batch_size))):
|
||||
ds_data.append(next(ds_iter)[0].numpy())
|
||||
ds_data.append(next(ds_iter).numpy())
|
||||
return np.concatenate(ds_data)
|
||||
|
||||
ds_iter = iter(adapter.get_dataset())
|
||||
@ -507,7 +507,7 @@ class GenericArrayLikeDataAdapterTest(DataAdapterTestBase):
|
||||
def _get_epoch(ds_iter):
|
||||
ds_data = []
|
||||
for _ in range(int(math.ceil(num_samples / batch_size))):
|
||||
ds_data.append(next(ds_iter)[0].numpy())
|
||||
ds_data.append(next(ds_iter).numpy())
|
||||
return np.concatenate(ds_data)
|
||||
|
||||
ds_iter = iter(adapter.get_dataset())
|
||||
@ -981,6 +981,22 @@ class DataHandlerTest(keras_parameterized.TestCase):
|
||||
2: 1.5
|
||||
})
|
||||
|
||||
@parameterized.named_parameters(('numpy', True), ('dataset', False))
|
||||
def test_single_x_input_no_tuple_wrapping(self, use_numpy):
|
||||
x = np.ones((10, 1))
|
||||
|
||||
if use_numpy:
|
||||
batch_size = 2
|
||||
else:
|
||||
x = dataset_ops.Dataset.from_tensor_slices(x).batch(2)
|
||||
batch_size = None
|
||||
|
||||
data_handler = data_adapter.DataHandler(x, batch_size=batch_size)
|
||||
for _, iterator in data_handler.enumerate_epochs():
|
||||
for _ in data_handler.steps():
|
||||
# Check that single x input is not wrapped in a tuple.
|
||||
self.assertIsInstance(next(iterator), ops.Tensor)
|
||||
|
||||
|
||||
class TestValidationSplit(keras_parameterized.TestCase):
|
||||
|
||||
|
@ -1558,6 +1558,43 @@ class TrainingTest(keras_parameterized.TestCase):
|
||||
# assign_add not called.
|
||||
self.assertEqual(self.evaluate(layer.v), 1.)
|
||||
|
||||
@keras_parameterized.run_all_keras_modes(always_skip_v1=True)
|
||||
@parameterized.named_parameters(('numpy', True), ('dataset', False))
|
||||
def test_single_input_no_tuple_wrapping(self, use_numpy):
|
||||
x = np.ones((10, 1))
|
||||
|
||||
if use_numpy:
|
||||
batch_size = 3
|
||||
else:
|
||||
x = dataset_ops.Dataset.from_tensor_slices(x).batch(3)
|
||||
batch_size = None
|
||||
|
||||
test_case = self
|
||||
|
||||
class MyModel(training_module.Model):
|
||||
|
||||
def train_step(self, data):
|
||||
# No tuple wrapping for single x input and no targets.
|
||||
test_case.assertIsInstance(data, ops.Tensor)
|
||||
return super(MyModel, self).train_step(data)
|
||||
|
||||
def test_step(self, data):
|
||||
test_case.assertIsInstance(data, ops.Tensor)
|
||||
return super(MyModel, self).test_step(data)
|
||||
|
||||
def predict_step(self, data):
|
||||
test_case.assertIsInstance(data, ops.Tensor)
|
||||
return super(MyModel, self).predict_step(data)
|
||||
|
||||
inputs = layers_module.Input(1)
|
||||
outputs = layers_module.Dense(1)(inputs)
|
||||
model = MyModel(inputs, outputs)
|
||||
model.add_loss(math_ops.reduce_sum(outputs))
|
||||
model.compile('sgd', 'mse')
|
||||
model.fit(x, batch_size=batch_size)
|
||||
model.evaluate(x, batch_size=batch_size)
|
||||
model.predict(x, batch_size=batch_size)
|
||||
|
||||
|
||||
class TestExceptionsAndWarnings(keras_parameterized.TestCase):
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user