Fix shape validation error with tf.nn.conv3d_transpose (#18465)

* Fix shape validation error with tf.nn.conv3d_transpose

This fix tries to address the issue raised in 18460.
In `tf.nn.conv3d_transpose` when list or np array is passed,
the validate of the output shape with filter shape uses
`output_shape[4]` (channel). This will not work with
`data_format='NCDHW'`.

This fix fixes the issue by replace `output_shape[4]` with `output_shape[axis]`.

This fix also adds a test case. Before this fix, the test case will fail.

This fix fixes 18460.

Signed-off-by: Yong Tang <yong.tang.github@outlook.com>

* Add test case for output and filter shape check in conv3d_transpose

Signed-off-by: Yong Tang <yong.tang.github@outlook.com>

* Fix pylint issue

Signed-off-by: Yong Tang <yong.tang.github@outlook.com>

* Also fix the error message
This commit is contained in:
Yong Tang 2018-04-17 16:44:21 -07:00 committed by Jonathan Hseu
parent 9477835866
commit f185600509
2 changed files with 14 additions and 2 deletions

View File

@ -119,6 +119,18 @@ class Conv3DTransposeTest(test.TestCase):
target = 3.0
self.assertAllClose(target, value[n, d, h, w, k])
def testConv3DTransposeShapeMismatch(self):
# Test case for GitHub issue 18460
x_shape = [2, 2, 3, 4, 3]
f_shape = [3, 3, 3, 2, 2]
y_shape = [2, 2, 6, 8, 6]
strides = [1, 1, 2, 2, 2]
np.random.seed(1)
x_value = np.random.random_sample(x_shape).astype(np.float64)
f_value = np.random.random_sample(f_shape).astype(np.float64)
nn_ops.conv3d_transpose(
x_value, f_value, y_shape, strides, data_format='NCDHW')
def testConv3DTransposeValid(self):
with self.test_session():
strides = [1, 2, 2, 2, 1]

View File

@ -1458,10 +1458,10 @@ def conv3d_transpose(
if isinstance(output_shape, (list, np.ndarray)):
# output_shape's shape should be == [5] if reached this point.
if not filter.get_shape()[3].is_compatible_with(output_shape[4]):
if not filter.get_shape()[3].is_compatible_with(output_shape[axis]):
raise ValueError(
"output_shape does not match filter's output channels, "
"{} != {}".format(output_shape[4],
"{} != {}".format(output_shape[axis],
filter.get_shape()[3]))
if padding != "VALID" and padding != "SAME":