In .adapt, don't freeze any shape values, only the number of dimensions.

Currently, `adapt()` freezes the batch size and the other shape elements. This is too strict and leads to downstream failures if you have:

* a different `batch_size` for `adapt()` and `fit()`.
* a variable sequence length or image size.
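
For example, after this change a layer can be adapted with one batch size and then called with another. A minimal sketch of that first case (not part of this commit), assuming the experimental `Normalization` preprocessing layer available at the time:

import numpy as np
import tensorflow as tf

norm = tf.keras.layers.experimental.preprocessing.Normalization()

# adapt() on a batch of 32 feature vectors of length 4...
norm.adapt(np.random.rand(32, 4).astype(np.float32))

# ...now freezes only the rank of the input (2 here), not the batch
# size, so the layer also accepts a batch of 8 at fit/call time:
out = norm(np.random.rand(8, 4).astype(np.float32))
print(out.shape)  # (8, 4)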

PiperOrigin-RevId: 317425544
Change-Id: I8cfeceeb6816d2f70ed112a04f51fdd15e6658bf
Mark Daoust authored on 2020-06-19 20:36:23 -07:00; committed by TensorFlower Gardener
parent 09ec15539e
commit fcba37557a

2 changed files with 48 additions and 9 deletions


@@ -185,14 +185,21 @@ class CombinerPreprocessingLayer(PreprocessingLayer):
     if not self.built:
       try:
         # If this is a Numpy array or tensor, we can get shape from .shape.
-        # If not, an attribute error will be thrown (and we can assume the
-        # input data is a scalar with shape None.
-        shape = data_element.shape
+        # If not, an attribute error will be thrown.
+        data_shape = data_element.shape
+        data_shape_nones = tuple([None]*len(data_element.shape))
       except AttributeError:
-        shape = None
+        # The input has an unknown number of dimensions.
+        data_shape = None
+        data_shape_nones = None
+
       # TODO (b/159261555): move this to base layer build.
-      self._batch_input_shape = shape
-      self.build(shape)
+      batch_input_shape = getattr(self, '_batch_input_shape', None)
+      if batch_input_shape is None:
+        # Set the number of dimensions.
+        self._batch_input_shape = data_shape_nones
+
+      self.build(data_shape)
 
       # Once we have built the Layer, we can process the input data. We do so
       # until we've gotten an exception indicating that we have no more data.
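
For reference (this snippet is an illustration, not part of the diff), the rank-only shape that `adapt()` now records works out as follows:

import numpy as np

data_element = np.zeros((32, 128, 3), dtype=np.float32)  # e.g. a batch of images
data_shape = data_element.shape                           # (32, 128, 3)
data_shape_nones = tuple([None] * len(data_shape))        # (None, None, None)
# `_batch_input_shape` becomes (None, None, None): the rank (3) is kept,
# but no individual dimension size is frozen.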


@@ -122,11 +122,11 @@ class AddingPreprocessingLayerV1(
   pass
 
 
-def get_layer():
+def get_layer(**kwargs):
   if context.executing_eagerly():
-    return AddingPreprocessingLayer()
+    return AddingPreprocessingLayer(**kwargs)
   else:
-    return AddingPreprocessingLayerV1()
+    return AddingPreprocessingLayerV1(**kwargs)
 
 
 @keras_parameterized.run_all_keras_modes
@@ -366,6 +366,38 @@ class PreprocessingLayerTest(keras_parameterized.TestCase):
     with self.assertRaisesRegex(RuntimeError, "Unable to restore a layer of"):
       _ = keras.models.load_model(output_path)
 
+  def test_adapt_sets_input_shape_rank(self):
+    """Check that `.adapt()` sets the `input_shape`'s rank."""
+    # Shape: (3, 1, 2)
+    adapt_dataset = np.array([[[1., 2.]],
+                              [[3., 4.]],
+                              [[5., 6.]]], dtype=np.float32)
+
+    layer = get_layer()
+    layer.adapt(adapt_dataset)
+
+    input_dataset = np.array([[[1., 2.], [3., 4.]],
+                              [[3., 4.], [5., 6.]]], dtype=np.float32)
+    layer(input_dataset)
+
+    model = keras.Sequential([layer])
+    self.assertTrue(model.built)
+    self.assertEqual(model.input_shape, (None, None, None))
+
+  def test_adapt_doesnt_overwrite_input_shape(self):
+    """Check that `.adapt()` doesn't change the `input_shape`."""
+    # Shape: (3, 1, 2)
+    adapt_dataset = np.array([[[1., 2.]],
+                              [[3., 4.]],
+                              [[5., 6.]]], dtype=np.float32)
+
+    layer = get_layer(input_shape=[1, 2])
+    layer.adapt(adapt_dataset)
+
+    model = keras.Sequential([layer])
+    self.assertTrue(model.built)
+    self.assertEqual(model.input_shape, (None, 1, 2))
+
 
 @keras_parameterized.run_all_keras_modes
 class ConvertToListTest(keras_parameterized.TestCase):