From d570632ba89e7352a3d08acfc2c1737f9e5983af Mon Sep 17 00:00:00 2001 From: Thomas O'Malley Date: Mon, 8 Jun 2020 23:59:21 -0700 Subject: [PATCH] Do not wrap a single x input passed to Model.fit in a tuple This makes Numpy and Tensor array behavior consistent with the behavior when passing a Dataset, generator, or Keras Sequence object with a single Tensor. In all cases, the single Tensor is passed to Model.train_step directly without modification PiperOrigin-RevId: 315431108 Change-Id: I65d7a57967ffa54ae7786029d235c7f3c37da80f --- .../python/keras/engine/data_adapter.py | 18 ++++++--- .../python/keras/engine/data_adapter_test.py | 20 +++++++++- .../python/keras/engine/training_test.py | 37 +++++++++++++++++++ 3 files changed, 68 insertions(+), 7 deletions(-) diff --git a/tensorflow/python/keras/engine/data_adapter.py b/tensorflow/python/keras/engine/data_adapter.py index bf0bbb7d994..469355dd722 100644 --- a/tensorflow/python/keras/engine/data_adapter.py +++ b/tensorflow/python/keras/engine/data_adapter.py @@ -1500,12 +1500,13 @@ def pack_x_y_sample_weight(x, y=None, sample_weight=None): >>> x = tf.ones((10, 1)) >>> data = tf.keras.utils.pack_x_y_sample_weight(x) - >>> len(data) - 1 + >>> isinstance(data, tf.Tensor) + True >>> y = tf.ones((10, 1)) >>> data = tf.keras.utils.pack_x_y_sample_weight(x, y) - >>> len(data) - 2 + >>> isinstance(data, tuple) + True + >>> x, y = data Arguments: x: Features to pass to `Model`. @@ -1516,7 +1517,14 @@ def pack_x_y_sample_weight(x, y=None, sample_weight=None): Tuple in the format used in `Model.fit`. """ if y is None: - return (x,) + # For single x-input, we do no tuple wrapping since in this case + # there is no ambiguity. This also makes NumPy and Dataset + # consistent in that the user does not have to wrap their Dataset + # data in an unecessary tuple + if not nest.is_sequence(x): + return x + else: + return (x,) elif sample_weight is None: return (x, y) else: diff --git a/tensorflow/python/keras/engine/data_adapter_test.py b/tensorflow/python/keras/engine/data_adapter_test.py index 3f4e6d0cb83..be9c6d79193 100644 --- a/tensorflow/python/keras/engine/data_adapter_test.py +++ b/tensorflow/python/keras/engine/data_adapter_test.py @@ -278,7 +278,7 @@ class TensorLikeDataAdapterTest(DataAdapterTestBase): def _get_epoch(ds_iter): ds_data = [] for _ in range(int(math.ceil(num_samples / batch_size))): - ds_data.append(next(ds_iter)[0].numpy()) + ds_data.append(next(ds_iter).numpy()) return np.concatenate(ds_data) ds_iter = iter(adapter.get_dataset()) @@ -507,7 +507,7 @@ class GenericArrayLikeDataAdapterTest(DataAdapterTestBase): def _get_epoch(ds_iter): ds_data = [] for _ in range(int(math.ceil(num_samples / batch_size))): - ds_data.append(next(ds_iter)[0].numpy()) + ds_data.append(next(ds_iter).numpy()) return np.concatenate(ds_data) ds_iter = iter(adapter.get_dataset()) @@ -981,6 +981,22 @@ class DataHandlerTest(keras_parameterized.TestCase): 2: 1.5 }) + @parameterized.named_parameters(('numpy', True), ('dataset', False)) + def test_single_x_input_no_tuple_wrapping(self, use_numpy): + x = np.ones((10, 1)) + + if use_numpy: + batch_size = 2 + else: + x = dataset_ops.Dataset.from_tensor_slices(x).batch(2) + batch_size = None + + data_handler = data_adapter.DataHandler(x, batch_size=batch_size) + for _, iterator in data_handler.enumerate_epochs(): + for _ in data_handler.steps(): + # Check that single x input is not wrapped in a tuple. + self.assertIsInstance(next(iterator), ops.Tensor) + class TestValidationSplit(keras_parameterized.TestCase): diff --git a/tensorflow/python/keras/engine/training_test.py b/tensorflow/python/keras/engine/training_test.py index bc63c3acec6..523349faf58 100644 --- a/tensorflow/python/keras/engine/training_test.py +++ b/tensorflow/python/keras/engine/training_test.py @@ -1558,6 +1558,43 @@ class TrainingTest(keras_parameterized.TestCase): # assign_add not called. self.assertEqual(self.evaluate(layer.v), 1.) + @keras_parameterized.run_all_keras_modes(always_skip_v1=True) + @parameterized.named_parameters(('numpy', True), ('dataset', False)) + def test_single_input_no_tuple_wrapping(self, use_numpy): + x = np.ones((10, 1)) + + if use_numpy: + batch_size = 3 + else: + x = dataset_ops.Dataset.from_tensor_slices(x).batch(3) + batch_size = None + + test_case = self + + class MyModel(training_module.Model): + + def train_step(self, data): + # No tuple wrapping for single x input and no targets. + test_case.assertIsInstance(data, ops.Tensor) + return super(MyModel, self).train_step(data) + + def test_step(self, data): + test_case.assertIsInstance(data, ops.Tensor) + return super(MyModel, self).test_step(data) + + def predict_step(self, data): + test_case.assertIsInstance(data, ops.Tensor) + return super(MyModel, self).predict_step(data) + + inputs = layers_module.Input(1) + outputs = layers_module.Dense(1)(inputs) + model = MyModel(inputs, outputs) + model.add_loss(math_ops.reduce_sum(outputs)) + model.compile('sgd', 'mse') + model.fit(x, batch_size=batch_size) + model.evaluate(x, batch_size=batch_size) + model.predict(x, batch_size=batch_size) + class TestExceptionsAndWarnings(keras_parameterized.TestCase):