Internal change
PiperOrigin-RevId: 308757298 Change-Id: I9ec9dfdd236a1ed796f82fd77e0afee777ca626d
This commit is contained in:
parent
fb7ea8f0e6
commit
7b357dd06b
@ -305,8 +305,9 @@ class Conv(Layer):
|
||||
"""Recreate conv_op if necessary.
|
||||
|
||||
Check if the input_shape in call() is different from that in build().
|
||||
For the values that are not None, if they are different, recreate
|
||||
the _convolution_op to avoid the stateful behavior.
|
||||
If the most-specific input shape describing the build and call shapes is not
|
||||
equal to the shape we currently built with, then we need to rebuild the
|
||||
_convolution_op to avoid incorrect behavior.
|
||||
|
||||
Args:
|
||||
inputs: The input data to call() method.
|
||||
@ -315,12 +316,10 @@ class Conv(Layer):
|
||||
`True` or `False` to indicate whether to recreate the conv_op.
|
||||
"""
|
||||
call_input_shape = inputs.get_shape()
|
||||
for axis in range(1, len(call_input_shape)):
|
||||
if (call_input_shape[axis] is not None
|
||||
and self._build_conv_op_input_shape[axis] is not None
|
||||
and call_input_shape[axis] != self._build_conv_op_input_shape[axis]):
|
||||
return True
|
||||
return False
|
||||
# If the most specific compatible shape between _build_input_shape and
|
||||
# call_input_shape is not _build_input_shape then we must re-build.
|
||||
return self._build_conv_op_input_shape.most_specific_compatible_shape(
|
||||
call_input_shape) != self._build_conv_op_input_shape
|
||||
|
||||
|
||||
@keras_export('keras.layers.Conv1D', 'keras.layers.Convolution1D')
|
||||
|
@ -23,6 +23,8 @@ import numpy as np
|
||||
|
||||
from tensorflow.python import keras
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.eager import def_function
|
||||
from tensorflow.python.framework import tensor_spec
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.keras import keras_parameterized
|
||||
from tensorflow.python.keras import testing_utils
|
||||
@ -108,6 +110,26 @@ class Conv1DTest(keras_parameterized.TestCase):
|
||||
_ = layer(inpt2).shape
|
||||
self.assertEqual(outp1_shape, layer(inpt1).shape)
|
||||
|
||||
def test_conv1d_recreate_conv_unknown_dims(self):
|
||||
with self.cached_session(use_gpu=True):
|
||||
layer = keras.layers.Conv1D(filters=1,
|
||||
kernel_size=3,
|
||||
strides=1,
|
||||
dilation_rate=2,
|
||||
padding='causal')
|
||||
|
||||
inpt1 = np.random.normal(size=[1, 9, 1]).astype(np.float32)
|
||||
inpt2 = np.random.normal(size=[1, 2, 1]).astype(np.float32)
|
||||
outp1_shape = layer(inpt1).shape
|
||||
|
||||
@def_function.function(input_signature=[
|
||||
tensor_spec.TensorSpec([1, None, 1])])
|
||||
def fn(inpt):
|
||||
return layer(inpt)
|
||||
|
||||
fn(inpt2)
|
||||
self.assertEqual(outp1_shape, layer(inpt1).shape)
|
||||
|
||||
|
||||
@keras_parameterized.run_all_keras_modes
|
||||
class Conv2DTest(keras_parameterized.TestCase):
|
||||
|
Loading…
Reference in New Issue
Block a user