In `.adapt()`, don't freeze any shape values, only the number of dimensions.
Currently, `.adapt()` freezes the batch size and all other shape elements. This is too strict and leads to downstream failures if you have:

* a different batch_size for `adapt()` and `fit()`, or
* a variable sequence length or image size.

PiperOrigin-RevId: 317425544
Change-Id: I8cfeceeb6816d2f70ed112a04f51fdd15e6658bf
parent 09ec15539e
commit fcba37557a
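To make the motivation concrete, here is a minimal sketch of the failure mode being fixed, assuming the experimental `Normalization` layer (a `CombinerPreprocessingLayer` subclass in the TF 2.x API of the time); the layer choice and shapes are illustrative, not part of this commit:

    import numpy as np
    import tensorflow as tf

    layer = tf.keras.layers.experimental.preprocessing.Normalization()

    # adapt() on batches of shape (32, 10).
    layer.adapt(np.zeros((32, 10), dtype=np.float32))

    # Then use a different batch size afterwards. Before this change, adapt()
    # froze the full shape (32, 10) as the layer's batch input shape, which
    # could break a later fit()/predict() with batch_size 64 or a different
    # sequence length; now only the rank is kept, i.e. (None, None).
    model = tf.keras.Sequential([layer])
    out = model.predict(np.zeros((64, 10), dtype=np.float32))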
@@ -185,14 +185,21 @@ class CombinerPreprocessingLayer(PreprocessingLayer):
     if not self.built:
       try:
         # If this is a Numpy array or tensor, we can get shape from .shape.
-        # If not, an attribute error will be thrown (and we can assume the
-        # input data is a scalar with shape None.
-        shape = data_element.shape
+        # If not, an attribute error will be thrown.
+        data_shape = data_element.shape
+        data_shape_nones = tuple([None]*len(data_element.shape))
       except AttributeError:
-        shape = None
+        # The input has an unknown number of dimensions.
+        data_shape = None
+        data_shape_nones = None
+
       # TODO (b/159261555): move this to base layer build.
-      self._batch_input_shape = shape
-      self.build(shape)
+      batch_input_shape = getattr(self, '_batch_input_shape', None)
+      if batch_input_shape is None:
+        # Set the number of dimensions.
+        self._batch_input_shape = data_shape_nones
+
+      self.build(data_shape)
 
     # Once we have built the Layer, we can process the input data. We do so
     # until we've gotten an exception indicating that we have no more data.
@@ -122,11 +122,11 @@ class AddingPreprocessingLayerV1(
     pass
 
 
-def get_layer():
+def get_layer(**kwargs):
   if context.executing_eagerly():
-    return AddingPreprocessingLayer()
+    return AddingPreprocessingLayer(**kwargs)
   else:
-    return AddingPreprocessingLayerV1()
+    return AddingPreprocessingLayerV1(**kwargs)
 
 
 @keras_parameterized.run_all_keras_modes
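The `**kwargs` forwarding is what lets the new tests below pin an input shape before `adapt()` runs; a one-line usage sketch (eager mode assumed, so the V2 layer is returned):

    layer = get_layer(input_shape=[1, 2])  # kwargs reach the Layer constructor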
@@ -366,6 +366,38 @@ class PreprocessingLayerTest(keras_parameterized.TestCase):
     with self.assertRaisesRegex(RuntimeError, "Unable to restore a layer of"):
       _ = keras.models.load_model(output_path)
 
+  def test_adapt_sets_input_shape_rank(self):
+    """Check that `.adapt()` sets the `input_shape`'s rank."""
+    # Shape: (3,1,2)
+    adapt_dataset = np.array([[[1., 2.]],
+                              [[3., 4.]],
+                              [[5., 6.]]], dtype=np.float32)
+
+    layer = get_layer()
+    layer.adapt(adapt_dataset)
+
+    input_dataset = np.array([[[1., 2.], [3., 4.]],
+                              [[3., 4.], [5., 6.]]], dtype=np.float32)
+    layer(input_dataset)
+
+    model = keras.Sequential([layer])
+    self.assertTrue(model.built)
+    self.assertEqual(model.input_shape, (None, None, None))
+
+  def test_adapt_doesnt_overwrite_input_shape(self):
+    """Check that `.adapt()` doesn't change the `input_shape`."""
+    # Shape: (3, 1, 2)
+    adapt_dataset = np.array([[[1., 2.]],
+                              [[3., 4.]],
+                              [[5., 6.]]], dtype=np.float32)
+
+    layer = get_layer(input_shape=[1, 2])
+    layer.adapt(adapt_dataset)
+
+    model = keras.Sequential([layer])
+    self.assertTrue(model.built)
+    self.assertEqual(model.input_shape, (None, 1, 2))
+
 
 @keras_parameterized.run_all_keras_modes
 class ConvertToListTest(keras_parameterized.TestCase):