Merge pull request #40691 from xingyu-long:fix_zero_padding_3d
PiperOrigin-RevId: 318843722 Change-Id: I5f30799f78f948a12ac0707bb4ba05bf0619dea7
This commit is contained in:
commit
177f89fb3c
@ -2977,30 +2977,30 @@ class ZeroPadding3D(Layer):
|
||||
input_shape = tensor_shape.TensorShape(input_shape).as_list()
|
||||
if self.data_format == 'channels_first':
|
||||
if input_shape[2] is not None:
|
||||
dim1 = input_shape[2] + 2 * self.padding[0][0]
|
||||
dim1 = input_shape[2] + self.padding[0][0] + self.padding[0][1]
|
||||
else:
|
||||
dim1 = None
|
||||
if input_shape[3] is not None:
|
||||
dim2 = input_shape[3] + 2 * self.padding[1][0]
|
||||
dim2 = input_shape[3] + self.padding[1][0] + self.padding[1][1]
|
||||
else:
|
||||
dim2 = None
|
||||
if input_shape[4] is not None:
|
||||
dim3 = input_shape[4] + 2 * self.padding[2][0]
|
||||
dim3 = input_shape[4] + self.padding[2][0] + self.padding[2][1]
|
||||
else:
|
||||
dim3 = None
|
||||
return tensor_shape.TensorShape(
|
||||
[input_shape[0], input_shape[1], dim1, dim2, dim3])
|
||||
elif self.data_format == 'channels_last':
|
||||
if input_shape[1] is not None:
|
||||
dim1 = input_shape[1] + 2 * self.padding[0][1]
|
||||
dim1 = input_shape[1] + self.padding[0][0] + self.padding[0][1]
|
||||
else:
|
||||
dim1 = None
|
||||
if input_shape[2] is not None:
|
||||
dim2 = input_shape[2] + 2 * self.padding[1][1]
|
||||
dim2 = input_shape[2] + self.padding[1][0] + self.padding[1][1]
|
||||
else:
|
||||
dim2 = None
|
||||
if input_shape[3] is not None:
|
||||
dim3 = input_shape[3] + 2 * self.padding[2][1]
|
||||
dim3 = input_shape[3] + self.padding[2][0] + self.padding[2][1]
|
||||
else:
|
||||
dim3 = None
|
||||
return tensor_shape.TensorShape(
|
||||
|
@ -723,36 +723,94 @@ class ZeroPaddingTest(keras_parameterized.TestCase):
|
||||
with self.assertRaises(ValueError):
|
||||
keras.layers.ZeroPadding2D(padding=None)
|
||||
|
||||
def test_zero_padding_3d(self):
|
||||
@parameterized.named_parameters(('channels_first', 'channels_first'),
|
||||
('channels_last', 'channels_last'))
|
||||
def test_zero_padding_3d(self, data_format):
|
||||
num_samples = 2
|
||||
stack_size = 2
|
||||
input_len_dim1 = 4
|
||||
input_len_dim2 = 5
|
||||
input_len_dim3 = 3
|
||||
|
||||
inputs = np.ones((num_samples, input_len_dim1, input_len_dim2,
|
||||
input_len_dim3, stack_size))
|
||||
if data_format == 'channels_first':
|
||||
inputs = np.ones((num_samples, stack_size, input_len_dim1, input_len_dim2,
|
||||
input_len_dim3))
|
||||
elif data_format == 'channels_last':
|
||||
inputs = np.ones((num_samples, input_len_dim1, input_len_dim2,
|
||||
input_len_dim3, stack_size))
|
||||
|
||||
with self.cached_session(use_gpu=True):
|
||||
# basic test
|
||||
testing_utils.layer_test(
|
||||
keras.layers.ZeroPadding3D,
|
||||
kwargs={'padding': (2, 2, 2)},
|
||||
kwargs={
|
||||
'padding': (2, 2, 2),
|
||||
'data_format': data_format
|
||||
},
|
||||
input_shape=inputs.shape)
|
||||
testing_utils.layer_test(
|
||||
keras.layers.ZeroPadding3D,
|
||||
kwargs={
|
||||
'padding': ((1, 2), (3, 4), (0, 2)),
|
||||
'data_format': data_format
|
||||
},
|
||||
input_shape=inputs.shape)
|
||||
|
||||
with self.cached_session(use_gpu=True):
|
||||
# correctness test
|
||||
layer = keras.layers.ZeroPadding3D(padding=(2, 2, 2))
|
||||
layer = keras.layers.ZeroPadding3D(
|
||||
padding=(2, 2, 2), data_format=data_format)
|
||||
layer.build(inputs.shape)
|
||||
output = layer(keras.backend.variable(inputs))
|
||||
if context.executing_eagerly():
|
||||
np_output = output.numpy()
|
||||
else:
|
||||
np_output = keras.backend.eval(output)
|
||||
for offset in [0, 1, -1, -2]:
|
||||
np.testing.assert_allclose(np_output[:, offset, :, :, :], 0.)
|
||||
np.testing.assert_allclose(np_output[:, :, offset, :, :], 0.)
|
||||
np.testing.assert_allclose(np_output[:, :, :, offset, :], 0.)
|
||||
np.testing.assert_allclose(np_output[:, 2:-2, 2:-2, 2:-2, :], 1.)
|
||||
if data_format == 'channels_last':
|
||||
for offset in [0, 1, -1, -2]:
|
||||
np.testing.assert_allclose(np_output[:, offset, :, :, :], 0.)
|
||||
np.testing.assert_allclose(np_output[:, :, offset, :, :], 0.)
|
||||
np.testing.assert_allclose(np_output[:, :, :, offset, :], 0.)
|
||||
np.testing.assert_allclose(np_output[:, 2:-2, 2:-2, 2:-2, :], 1.)
|
||||
elif data_format == 'channels_first':
|
||||
for offset in [0, 1, -1, -2]:
|
||||
np.testing.assert_allclose(np_output[:, :, offset, :, :], 0.)
|
||||
np.testing.assert_allclose(np_output[:, :, :, offset, :], 0.)
|
||||
np.testing.assert_allclose(np_output[:, :, :, :, offset], 0.)
|
||||
np.testing.assert_allclose(np_output[:, :, 2:-2, 2:-2, 2:-2], 1.)
|
||||
|
||||
layer = keras.layers.ZeroPadding3D(
|
||||
padding=((1, 2), (3, 4), (0, 2)), data_format=data_format)
|
||||
layer.build(inputs.shape)
|
||||
output = layer(keras.backend.variable(inputs))
|
||||
if context.executing_eagerly():
|
||||
np_output = output.numpy()
|
||||
else:
|
||||
np_output = keras.backend.eval(output)
|
||||
if data_format == 'channels_last':
|
||||
for offset in [0]:
|
||||
np.testing.assert_allclose(np_output[:, offset, :, :, :], 0.)
|
||||
for offset in [-1, -2]:
|
||||
np.testing.assert_allclose(np_output[:, offset, :, :, :], 0.)
|
||||
for offset in [0, 1, 2]:
|
||||
np.testing.assert_allclose(np_output[:, :, offset, :, :], 0.)
|
||||
for offset in [-1, -2, -3, -4]:
|
||||
np.testing.assert_allclose(np_output[:, :, offset, :, :], 0.)
|
||||
for offset in [-1, -2]:
|
||||
np.testing.assert_allclose(np_output[:, :, :, offset, :], 0.)
|
||||
np.testing.assert_allclose(np_output[:, 1:-2, 3:-4, 0:-2, :], 1.)
|
||||
elif data_format == 'channels_first':
|
||||
for offset in [0]:
|
||||
np.testing.assert_allclose(np_output[:, :, offset, :, :], 0.)
|
||||
for offset in [-1, -2]:
|
||||
np.testing.assert_allclose(np_output[:, :, offset, :, :], 0.)
|
||||
for offset in [0, 1, 2]:
|
||||
np.testing.assert_allclose(np_output[:, :, :, offset, :], 0.)
|
||||
for offset in [-1, -2, -3, -4]:
|
||||
np.testing.assert_allclose(np_output[:, :, :, offset, :], 0.)
|
||||
for offset in [-1, -2]:
|
||||
np.testing.assert_allclose(np_output[:, :, :, :, offset], 0.)
|
||||
np.testing.assert_allclose(np_output[:, :, 1:-2, 3:-4, 0:-2], 1.)
|
||||
|
||||
# test incorrect use
|
||||
with self.assertRaises(ValueError):
|
||||
|
Loading…
Reference in New Issue
Block a user