Fix shape validation error with tf.nn.conv3d_transpose (#18465)
* Fix shape validation error with tf.nn.conv3d_transpose This fix tries to address the issue raised in 18460. In `tf.nn.conv3d_transpose` when list or np array is passed, the validate of the output shape with filter shape uses `output_shape[4]` (channel). This will not work with `data_format='NCDHW'`. This fix fixes the issue by replace `output_shape[4]` with `output_shape[axis]`. This fix also adds a test case. Before this fix, the test case will fail. This fix fixes 18460. Signed-off-by: Yong Tang <yong.tang.github@outlook.com> * Add test case for output and filter shape check in conv3d_transpose Signed-off-by: Yong Tang <yong.tang.github@outlook.com> * Fix pylint issue Signed-off-by: Yong Tang <yong.tang.github@outlook.com> * Also fix the error message
This commit is contained in:
parent
9477835866
commit
f185600509
@ -119,6 +119,18 @@ class Conv3DTransposeTest(test.TestCase):
|
||||
target = 3.0
|
||||
self.assertAllClose(target, value[n, d, h, w, k])
|
||||
|
||||
def testConv3DTransposeShapeMismatch(self):
|
||||
# Test case for GitHub issue 18460
|
||||
x_shape = [2, 2, 3, 4, 3]
|
||||
f_shape = [3, 3, 3, 2, 2]
|
||||
y_shape = [2, 2, 6, 8, 6]
|
||||
strides = [1, 1, 2, 2, 2]
|
||||
np.random.seed(1)
|
||||
x_value = np.random.random_sample(x_shape).astype(np.float64)
|
||||
f_value = np.random.random_sample(f_shape).astype(np.float64)
|
||||
nn_ops.conv3d_transpose(
|
||||
x_value, f_value, y_shape, strides, data_format='NCDHW')
|
||||
|
||||
def testConv3DTransposeValid(self):
|
||||
with self.test_session():
|
||||
strides = [1, 2, 2, 2, 1]
|
||||
|
@ -1458,10 +1458,10 @@ def conv3d_transpose(
|
||||
|
||||
if isinstance(output_shape, (list, np.ndarray)):
|
||||
# output_shape's shape should be == [5] if reached this point.
|
||||
if not filter.get_shape()[3].is_compatible_with(output_shape[4]):
|
||||
if not filter.get_shape()[3].is_compatible_with(output_shape[axis]):
|
||||
raise ValueError(
|
||||
"output_shape does not match filter's output channels, "
|
||||
"{} != {}".format(output_shape[4],
|
||||
"{} != {}".format(output_shape[axis],
|
||||
filter.get_shape()[3]))
|
||||
|
||||
if padding != "VALID" and padding != "SAME":
|
||||
|
Loading…
x
Reference in New Issue
Block a user