From fcba37557a30e707b1368ee270a29e41410ad546 Mon Sep 17 00:00:00 2001 From: Mark Daoust Date: Fri, 19 Jun 2020 20:36:23 -0700 Subject: [PATCH] In `.adapt`, don't freeze any shape values, only the number of dimensions. Currently, this freezes the batch size and other shape elements. This is too strict, leads to downstream failures if you have: * a different batch_size for `adapt()` and `fit()`. * variable sequence length or image size. PiperOrigin-RevId: 317425544 Change-Id: I8cfeceeb6816d2f70ed112a04f51fdd15e6658bf --- .../keras/engine/base_preprocessing_layer.py | 19 +++++++--- .../engine/base_preprocessing_layer_test.py | 38 +++++++++++++++++-- 2 files changed, 48 insertions(+), 9 deletions(-) diff --git a/tensorflow/python/keras/engine/base_preprocessing_layer.py b/tensorflow/python/keras/engine/base_preprocessing_layer.py index b2ab0880422..c8ba1229ff5 100644 --- a/tensorflow/python/keras/engine/base_preprocessing_layer.py +++ b/tensorflow/python/keras/engine/base_preprocessing_layer.py @@ -185,14 +185,21 @@ class CombinerPreprocessingLayer(PreprocessingLayer): if not self.built: try: # If this is a Numpy array or tensor, we can get shape from .shape. - # If not, an attribute error will be thrown (and we can assume the - # input data is a scalar with shape None. - shape = data_element.shape + # If not, an attribute error will be thrown. + data_shape = data_element.shape + data_shape_nones = tuple([None]*len(data_element.shape)) except AttributeError: - shape = None + # The input has an unknown number of dimensions. + data_shape = None + data_shape_nones = None + # TODO (b/159261555): move this to base layer build. - self._batch_input_shape = shape - self.build(shape) + batch_input_shape = getattr(self, '_batch_input_shape', None) + if batch_input_shape is None: + # Set the number of dimensions. + self._batch_input_shape = data_shape_nones + + self.build(data_shape) # Once we have built the Layer, we can process the input data. We do so # until we've gotten an exception indicating that we have no more data. diff --git a/tensorflow/python/keras/engine/base_preprocessing_layer_test.py b/tensorflow/python/keras/engine/base_preprocessing_layer_test.py index 70d088cf3d3..a3a36a9bf11 100644 --- a/tensorflow/python/keras/engine/base_preprocessing_layer_test.py +++ b/tensorflow/python/keras/engine/base_preprocessing_layer_test.py @@ -122,11 +122,11 @@ class AddingPreprocessingLayerV1( pass -def get_layer(): +def get_layer(**kwargs): if context.executing_eagerly(): - return AddingPreprocessingLayer() + return AddingPreprocessingLayer(**kwargs) else: - return AddingPreprocessingLayerV1() + return AddingPreprocessingLayerV1(**kwargs) @keras_parameterized.run_all_keras_modes @@ -366,6 +366,38 @@ class PreprocessingLayerTest(keras_parameterized.TestCase): with self.assertRaisesRegex(RuntimeError, "Unable to restore a layer of"): _ = keras.models.load_model(output_path) + def test_adapt_sets_input_shape_rank(self): + """Check that `.adapt()` sets the `input_shape`'s rank.""" + # Shape: (3,1,2) + adapt_dataset = np.array([[[1., 2.]], + [[3., 4.]], + [[5., 6.]]], dtype=np.float32) + + layer = get_layer() + layer.adapt(adapt_dataset) + + input_dataset = np.array([[[1., 2.], [3., 4.]], + [[3., 4.], [5., 6.]]], dtype=np.float32) + layer(input_dataset) + + model = keras.Sequential([layer]) + self.assertTrue(model.built) + self.assertEqual(model.input_shape, (None, None, None)) + + def test_adapt_doesnt_overwrite_input_shape(self): + """Check that `.adapt()` doesn't change the `input_shape`.""" + # Shape: (3, 1, 2) + adapt_dataset = np.array([[[1., 2.]], + [[3., 4.]], + [[5., 6.]]], dtype=np.float32) + + layer = get_layer(input_shape=[1, 2]) + layer.adapt(adapt_dataset) + + model = keras.Sequential([layer]) + self.assertTrue(model.built) + self.assertEqual(model.input_shape, (None, 1, 2)) + @keras_parameterized.run_all_keras_modes class ConvertToListTest(keras_parameterized.TestCase):