diff --git a/tensorflow/python/keras/layers/convolutional.py b/tensorflow/python/keras/layers/convolutional.py index 476aaba6bb5..83bfb47c5b1 100644 --- a/tensorflow/python/keras/layers/convolutional.py +++ b/tensorflow/python/keras/layers/convolutional.py @@ -305,8 +305,9 @@ class Conv(Layer): """Recreate conv_op if necessary. Check if the input_shape in call() is different from that in build(). - For the values that are not None, if they are different, recreate - the _convolution_op to avoid the stateful behavior. + If the most-specific input shape describing the build and call shapes is not + equal to the shape we currently built with, then we need to rebuild the + _convolution_op to avoid incorrect behavior. Args: inputs: The input data to call() method. @@ -315,12 +316,10 @@ class Conv(Layer): `True` or `False` to indicate whether to recreate the conv_op. """ call_input_shape = inputs.get_shape() - for axis in range(1, len(call_input_shape)): - if (call_input_shape[axis] is not None - and self._build_conv_op_input_shape[axis] is not None - and call_input_shape[axis] != self._build_conv_op_input_shape[axis]): - return True - return False + # If the most specific compatible shape between _build_input_shape and + # call_input_shape is not _build_input_shape then we must re-build. + return self._build_conv_op_input_shape.most_specific_compatible_shape( + call_input_shape) != self._build_conv_op_input_shape @keras_export('keras.layers.Conv1D', 'keras.layers.Convolution1D') diff --git a/tensorflow/python/keras/layers/convolutional_test.py b/tensorflow/python/keras/layers/convolutional_test.py index a36efd9da26..528bc14adf4 100644 --- a/tensorflow/python/keras/layers/convolutional_test.py +++ b/tensorflow/python/keras/layers/convolutional_test.py @@ -23,6 +23,8 @@ import numpy as np from tensorflow.python import keras from tensorflow.python.eager import context +from tensorflow.python.eager import def_function +from tensorflow.python.framework import tensor_spec from tensorflow.python.framework import test_util from tensorflow.python.keras import keras_parameterized from tensorflow.python.keras import testing_utils @@ -108,6 +110,26 @@ class Conv1DTest(keras_parameterized.TestCase): _ = layer(inpt2).shape self.assertEqual(outp1_shape, layer(inpt1).shape) + def test_conv1d_recreate_conv_unknown_dims(self): + with self.cached_session(use_gpu=True): + layer = keras.layers.Conv1D(filters=1, + kernel_size=3, + strides=1, + dilation_rate=2, + padding='causal') + + inpt1 = np.random.normal(size=[1, 9, 1]).astype(np.float32) + inpt2 = np.random.normal(size=[1, 2, 1]).astype(np.float32) + outp1_shape = layer(inpt1).shape + + @def_function.function(input_signature=[ + tensor_spec.TensorSpec([1, None, 1])]) + def fn(inpt): + return layer(inpt) + + fn(inpt2) + self.assertEqual(outp1_shape, layer(inpt1).shape) + @keras_parameterized.run_all_keras_modes class Conv2DTest(keras_parameterized.TestCase):