In .adapt
, don't freeze any shape values, only the number of dimensions.
Currently, this freezes the batch size and other shape elements. This is too strict, leads to downstream failures if you have: * a different batch_size for `adapt()` and `fit()`. * variable sequence length or image size. PiperOrigin-RevId: 317425544 Change-Id: I8cfeceeb6816d2f70ed112a04f51fdd15e6658bf
This commit is contained in:
parent
09ec15539e
commit
fcba37557a
@ -185,14 +185,21 @@ class CombinerPreprocessingLayer(PreprocessingLayer):
|
|||||||
if not self.built:
|
if not self.built:
|
||||||
try:
|
try:
|
||||||
# If this is a Numpy array or tensor, we can get shape from .shape.
|
# If this is a Numpy array or tensor, we can get shape from .shape.
|
||||||
# If not, an attribute error will be thrown (and we can assume the
|
# If not, an attribute error will be thrown.
|
||||||
# input data is a scalar with shape None.
|
data_shape = data_element.shape
|
||||||
shape = data_element.shape
|
data_shape_nones = tuple([None]*len(data_element.shape))
|
||||||
except AttributeError:
|
except AttributeError:
|
||||||
shape = None
|
# The input has an unknown number of dimensions.
|
||||||
|
data_shape = None
|
||||||
|
data_shape_nones = None
|
||||||
|
|
||||||
# TODO (b/159261555): move this to base layer build.
|
# TODO (b/159261555): move this to base layer build.
|
||||||
self._batch_input_shape = shape
|
batch_input_shape = getattr(self, '_batch_input_shape', None)
|
||||||
self.build(shape)
|
if batch_input_shape is None:
|
||||||
|
# Set the number of dimensions.
|
||||||
|
self._batch_input_shape = data_shape_nones
|
||||||
|
|
||||||
|
self.build(data_shape)
|
||||||
|
|
||||||
# Once we have built the Layer, we can process the input data. We do so
|
# Once we have built the Layer, we can process the input data. We do so
|
||||||
# until we've gotten an exception indicating that we have no more data.
|
# until we've gotten an exception indicating that we have no more data.
|
||||||
|
@ -122,11 +122,11 @@ class AddingPreprocessingLayerV1(
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
def get_layer():
|
def get_layer(**kwargs):
|
||||||
if context.executing_eagerly():
|
if context.executing_eagerly():
|
||||||
return AddingPreprocessingLayer()
|
return AddingPreprocessingLayer(**kwargs)
|
||||||
else:
|
else:
|
||||||
return AddingPreprocessingLayerV1()
|
return AddingPreprocessingLayerV1(**kwargs)
|
||||||
|
|
||||||
|
|
||||||
@keras_parameterized.run_all_keras_modes
|
@keras_parameterized.run_all_keras_modes
|
||||||
@ -366,6 +366,38 @@ class PreprocessingLayerTest(keras_parameterized.TestCase):
|
|||||||
with self.assertRaisesRegex(RuntimeError, "Unable to restore a layer of"):
|
with self.assertRaisesRegex(RuntimeError, "Unable to restore a layer of"):
|
||||||
_ = keras.models.load_model(output_path)
|
_ = keras.models.load_model(output_path)
|
||||||
|
|
||||||
|
def test_adapt_sets_input_shape_rank(self):
|
||||||
|
"""Check that `.adapt()` sets the `input_shape`'s rank."""
|
||||||
|
# Shape: (3,1,2)
|
||||||
|
adapt_dataset = np.array([[[1., 2.]],
|
||||||
|
[[3., 4.]],
|
||||||
|
[[5., 6.]]], dtype=np.float32)
|
||||||
|
|
||||||
|
layer = get_layer()
|
||||||
|
layer.adapt(adapt_dataset)
|
||||||
|
|
||||||
|
input_dataset = np.array([[[1., 2.], [3., 4.]],
|
||||||
|
[[3., 4.], [5., 6.]]], dtype=np.float32)
|
||||||
|
layer(input_dataset)
|
||||||
|
|
||||||
|
model = keras.Sequential([layer])
|
||||||
|
self.assertTrue(model.built)
|
||||||
|
self.assertEqual(model.input_shape, (None, None, None))
|
||||||
|
|
||||||
|
def test_adapt_doesnt_overwrite_input_shape(self):
|
||||||
|
"""Check that `.adapt()` doesn't change the `input_shape`."""
|
||||||
|
# Shape: (3, 1, 2)
|
||||||
|
adapt_dataset = np.array([[[1., 2.]],
|
||||||
|
[[3., 4.]],
|
||||||
|
[[5., 6.]]], dtype=np.float32)
|
||||||
|
|
||||||
|
layer = get_layer(input_shape=[1, 2])
|
||||||
|
layer.adapt(adapt_dataset)
|
||||||
|
|
||||||
|
model = keras.Sequential([layer])
|
||||||
|
self.assertTrue(model.built)
|
||||||
|
self.assertEqual(model.input_shape, (None, 1, 2))
|
||||||
|
|
||||||
|
|
||||||
@keras_parameterized.run_all_keras_modes
|
@keras_parameterized.run_all_keras_modes
|
||||||
class ConvertToListTest(keras_parameterized.TestCase):
|
class ConvertToListTest(keras_parameterized.TestCase):
|
||||||
|
Loading…
Reference in New Issue
Block a user