diff --git a/RELEASE.md b/RELEASE.md index fe6d052640a..f078d336abb 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -1,6 +1,7 @@ # Changes since the last release ## Major Features and Improvements +* Added `tf.layers.conv3d_transpose` layer for spatio-temporal deconvolution. * Added `tf.Session.make_callable()`, which provides a lower overhead means of running a similar step multiple times. * Added ibverbs-based RDMA support to contrib (courtesy @junshi15 from Yahoo). * `RNNCell` objects now subclass `tf.layers._Layer`. The strictness described diff --git a/tensorflow/python/layers/convolutional.py b/tensorflow/python/layers/convolutional.py index 3b8959e2106..04fec38b211 100644 --- a/tensorflow/python/layers/convolutional.py +++ b/tensorflow/python/layers/convolutional.py @@ -970,7 +970,7 @@ def separable_conv2d(inputs, class Conv2DTranspose(Conv2D): - """Transposed convolution layer (sometimes called Deconvolution). + """Transposed 2D convolution layer (sometimes called 2D Deconvolution).
The need for transposed convolutions generally arises from the desire to use a transformation going in the opposite direction @@ -1083,19 +1083,9 @@ class Conv2DTranspose(Conv2D): kernel_h, kernel_w = self.kernel_size stride_h, stride_w = self.strides - def get_deconv_dim(dim_size, stride_size, kernel_size, padding): - if isinstance(dim_size, ops.Tensor): - dim_size = math_ops.multiply(dim_size, stride_size) - elif dim_size is not None: - dim_size *= stride_size - - if padding == 'valid' and dim_size is not None: - dim_size += max(kernel_size - stride_size, 0) - return dim_size - # Infer the dynamic output shape: - out_height = get_deconv_dim(height, stride_h, kernel_h, self.padding) - out_width = get_deconv_dim(width, stride_w, kernel_w, self.padding) + out_height = utils.get_deconv_dim(height, stride_h, kernel_h, self.padding) + out_width = utils.get_deconv_dim(width, stride_w, kernel_w, self.padding) if self.data_format == 'channels_first': output_shape = (batch_size, self.filters, out_height, out_width) @@ -1116,9 +1106,9 @@ class Conv2DTranspose(Conv2D): # Infer the static output shape: out_shape = inputs.get_shape().as_list() out_shape[c_axis] = self.filters - out_shape[h_axis] = get_deconv_dim( + out_shape[h_axis] = utils.get_deconv_dim( out_shape[h_axis], stride_h, kernel_h, self.padding) - out_shape[w_axis] = get_deconv_dim( + out_shape[w_axis] = utils.get_deconv_dim( out_shape[w_axis], stride_w, kernel_w, self.padding) outputs.set_shape(out_shape) @@ -1149,7 +1139,7 @@ def conv2d_transpose(inputs, trainable=True, name=None, reuse=None): - """Transposed convolution layer (sometimes called Deconvolution). + """Functional interface for transposed 2D convolution layer. 
The need for transposed convolutions generally arises from the desire to use a transformation going in the opposite direction @@ -1174,12 +1164,12 @@ def conv2d_transpose(inputs, `channels_last` corresponds to inputs with shape `(batch, height, width, channels)` while `channels_first` corresponds to inputs with shape `(batch, channels, height, width)`. - activation: Activation function. Set it to None to maintain a + activation: Activation function. Set it to `None` to maintain a linear activation. use_bias: Boolean, whether the layer uses a bias. kernel_initializer: An initializer for the convolution kernel. - bias_initializer: An initializer for the bias vector. If None, no bias will - be applied. + bias_initializer: An initializer for the bias vector. If `None`, then no + bias will be applied. kernel_regularizer: Optional regularizer for the convolution kernel. bias_regularizer: Optional regularizer for the bias vector. activity_regularizer: Regularizer function for the output. @@ -1212,6 +1202,246 @@ def conv2d_transpose(inputs, return layer.apply(inputs) +class Conv3DTranspose(Conv3D): + """Transposed 3D convolution layer (sometimes called 3D Deconvolution). + + Arguments: + filters: Integer, the dimensionality of the output space (i.e. the number + of filters in the convolution). + kernel_size: An integer or tuple/list of 3 integers, specifying the + depth, height and width of the 3D convolution window. + Can be a single integer to specify the same value for all spatial + dimensions. + strides: An integer or tuple/list of 3 integers, specifying the strides + of the convolution along the depth, height and width. + Can be a single integer to specify the same value for all spatial + dimensions. + padding: One of `"valid"` or `"same"` (case-insensitive). + data_format: A string, one of `channels_last` (default) or `channels_first`. + The ordering of the dimensions in the inputs. 
+ `channels_last` corresponds to inputs with shape + `(batch, depth, height, width, channels)` while `channels_first` + corresponds to inputs with shape + `(batch, channels, depth, height, width)`. + activation: Activation function. Set it to `None` to maintain a + linear activation. + use_bias: Boolean, whether the layer uses a bias. + kernel_initializer: An initializer for the convolution kernel. + bias_initializer: An initializer for the bias vector. If `None`, then no + bias will be applied. + kernel_regularizer: Optional regularizer for the convolution kernel. + bias_regularizer: Optional regularizer for the bias vector. + activity_regularizer: Regularizer function for the output. + trainable: Boolean, if `True` also add variables to the graph collection + `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). + name: A string, the name of the layer. + """ + + def __init__(self, filters, + kernel_size, + strides=(1, 1, 1), + padding='valid', + data_format='channels_last', + activation=None, + use_bias=True, + kernel_initializer=None, + bias_initializer=init_ops.zeros_initializer(), + kernel_regularizer=None, + bias_regularizer=None, + activity_regularizer=None, + trainable=True, + name=None, + **kwargs): + super(Conv3DTranspose, self).__init__( + filters=filters, + kernel_size=kernel_size, + strides=strides, + padding=padding, + data_format=data_format, + activation=activation, + use_bias=use_bias, + kernel_initializer=kernel_initializer, + bias_initializer=bias_initializer, + kernel_regularizer=kernel_regularizer, + bias_regularizer=bias_regularizer, + activity_regularizer=activity_regularizer, + trainable=trainable, + name=name, **kwargs) + + def build(self, input_shape): + if len(input_shape) != 5: + raise ValueError('Inputs should have rank 5, ' + + 'received input shape: ' + str(input_shape)) + if self.data_format == 'channels_first': + channel_axis = 1 + else: + channel_axis = -1 + if input_shape[channel_axis] is None: + raise ValueError('The channel
dimension of the inputs ' + 'should be defined, found None: ' + str(input_shape)) + input_dim = input_shape[channel_axis] + kernel_shape = self.kernel_size + (self.filters, input_dim) + + self.kernel = vs.get_variable('kernel', + shape=kernel_shape, + initializer=self.kernel_initializer, + regularizer=self.kernel_regularizer, + trainable=True, + dtype=self.dtype) + if self.use_bias: + self.bias = vs.get_variable('bias', + shape=(self.filters,), + initializer=self.bias_initializer, + regularizer=self.bias_regularizer, + trainable=True, + dtype=self.dtype) + else: + self.bias = None + + def call(self, inputs): + inputs_shape = array_ops.shape(inputs) + batch_size = inputs_shape[0] + if self.data_format == 'channels_first': + c_axis, d_axis, h_axis, w_axis = 1, 2, 3, 4 + else: + c_axis, d_axis, h_axis, w_axis = 4, 1, 2, 3 + + depth = inputs_shape[d_axis] + height = inputs_shape[h_axis] + width = inputs_shape[w_axis] + + kernel_d, kernel_h, kernel_w = self.kernel_size + stride_d, stride_h, stride_w = self.strides + + # Infer the dynamic output shape: + out_depth = utils.get_deconv_dim(depth, stride_d, kernel_d, self.padding) + out_height = utils.get_deconv_dim(height, stride_h, kernel_h, self.padding) + out_width = utils.get_deconv_dim(width, stride_w, kernel_w, self.padding) + + if self.data_format == 'channels_first': + output_shape = (batch_size, self.filters, out_depth, out_height, + out_width) + strides = (1, 1, stride_d, stride_h, stride_w) + else: + output_shape = (batch_size, out_depth, out_height, out_width, + self.filters) + strides = (1, stride_d, stride_h, stride_w, 1) + + output_shape_tensor = array_ops.stack(output_shape) + outputs = nn.conv3d_transpose( + inputs, + self.kernel, + output_shape_tensor, + strides, + data_format=utils.convert_data_format(self.data_format, ndim=5), + padding=self.padding.upper()) + + # Infer the static output shape: + out_shape = inputs.get_shape().as_list() + out_shape[c_axis] = self.filters + out_shape[d_axis] = 
utils.get_deconv_dim( + out_shape[d_axis], stride_d, kernel_d, self.padding) + out_shape[h_axis] = utils.get_deconv_dim( + out_shape[h_axis], stride_h, kernel_h, self.padding) + out_shape[w_axis] = utils.get_deconv_dim( + out_shape[w_axis], stride_w, kernel_w, self.padding) + outputs.set_shape(out_shape) + + if self.bias is not None: + outputs_shape = outputs.shape.as_list() + if self.data_format == 'channels_first': + outputs_4d = array_ops.reshape(outputs, + [outputs_shape[0], outputs_shape[1], + outputs_shape[2] * outputs_shape[3], + outputs_shape[4]]) + else: + outputs_4d = array_ops.reshape(outputs, + [outputs_shape[0], + outputs_shape[1] * outputs_shape[2], + outputs_shape[3], outputs_shape[4]]) + outputs_4d = nn.bias_add( + outputs_4d, + self.bias, + data_format=utils.convert_data_format(self.data_format, ndim=4)) + outputs = array_ops.reshape(outputs_4d, outputs_shape) + + if self.activation is not None: + return self.activation(outputs) + return outputs + + +def conv3d_transpose(inputs, + filters, + kernel_size, + strides=(1, 1, 1), + padding='valid', + data_format='channels_last', + activation=None, + use_bias=True, + kernel_initializer=None, + bias_initializer=init_ops.zeros_initializer(), + kernel_regularizer=None, + bias_regularizer=None, + activity_regularizer=None, + trainable=True, + name=None, + reuse=None): + """Functional interface for transposed 3D convolution layer. + + Arguments: + inputs: Input tensor. + filters: Integer, the dimensionality of the output space (i.e. the number + of filters in the convolution). + kernel_size: A tuple or list of 3 positive integers specifying the spatial + dimensions of the filters. Can be a single integer to specify the same + value for all spatial dimensions. + strides: A tuple or list of 3 positive integers specifying the strides + of the convolution. Can be a single integer to specify the same value for + all spatial dimensions. + padding: one of `"valid"` or `"same"` (case-insensitive).
+ data_format: A string, one of `channels_last` (default) or `channels_first`. + The ordering of the dimensions in the inputs. + `channels_last` corresponds to inputs with shape + `(batch, depth, height, width, channels)` while `channels_first` + corresponds to inputs with shape `(batch, channels, depth, height, width)`. + activation: Activation function. Set it to `None` to maintain a + linear activation. + use_bias: Boolean, whether the layer uses a bias. + kernel_initializer: An initializer for the convolution kernel. + bias_initializer: An initializer for the bias vector. If `None`, then no + bias will be applied. + kernel_regularizer: Optional regularizer for the convolution kernel. + bias_regularizer: Optional regularizer for the bias vector. + activity_regularizer: Regularizer function for the output. + trainable: Boolean, if `True` also add variables to the graph collection + `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). + name: A string, the name of the layer. + reuse: Boolean, whether to reuse the weights of a previous layer + by the same name. + + Returns: + Output tensor.
+ """ + layer = Conv3DTranspose( + filters=filters, + kernel_size=kernel_size, + strides=strides, + padding=padding, + data_format=data_format, + activation=activation, + use_bias=use_bias, + kernel_initializer=kernel_initializer, + bias_initializer=bias_initializer, + kernel_regularizer=kernel_regularizer, + bias_regularizer=bias_regularizer, + activity_regularizer=activity_regularizer, + trainable=trainable, + name=name, + _reuse=reuse, + _scope=name) + return layer.apply(inputs) + + # Aliases Convolution1D = Conv1D @@ -1219,9 +1449,11 @@ Convolution2D = Conv2D Convolution3D = Conv3D SeparableConvolution2D = SeparableConv2D Convolution2DTranspose = Deconvolution2D = Deconv2D = Conv2DTranspose +Convolution3DTranspose = Deconvolution3D = Deconv3D = Conv3DTranspose convolution1d = conv1d convolution2d = conv2d convolution3d = conv3d separable_convolution2d = separable_conv2d convolution2d_transpose = deconvolution2d = deconv2d = conv2d_transpose +convolution3d_transpose = deconvolution3d = deconv3d = conv3d_transpose diff --git a/tensorflow/python/layers/convolutional_test.py b/tensorflow/python/layers/convolutional_test.py index da962b2f99e..635cc24714c 100644 --- a/tensorflow/python/layers/convolutional_test.py +++ b/tensorflow/python/layers/convolutional_test.py @@ -651,5 +651,175 @@ class Conv2DTransposeTest(test.TestCase): self.assertEqual(len(variables.trainable_variables()), 4) +class Conv3DTransposeTest(test.TestCase): + + def testInvalidDataFormat(self): + depth, height, width = 5, 7, 9 + volumes = random_ops.random_uniform((5, depth, height, width, 32), seed=1) + with self.assertRaisesRegexp(ValueError, 'data_format'): + conv_layers.conv3d_transpose(volumes, 4, 3, data_format='invalid') + + def testInvalidStrides(self): + depth, height, width = 5, 7, 9 + volumes = random_ops.random_uniform((5, depth, height, width, 32), seed=1) + with self.assertRaisesRegexp(ValueError, 'strides'): + conv_layers.conv3d_transpose(volumes, 4, 3, strides=(1, 2)) + + with 
self.assertRaisesRegexp(ValueError, 'strides'): + conv_layers.conv3d_transpose(volumes, 4, 3, strides=None) + + def testInvalidKernelSize(self): + depth, height, width = 5, 7, 9 + volumes = random_ops.random_uniform((5, depth, height, width, 32), seed=1) + with self.assertRaisesRegexp(ValueError, 'kernel_size'): + conv_layers.conv3d_transpose(volumes, 4, (1, 2)) + + with self.assertRaisesRegexp(ValueError, 'kernel_size'): + conv_layers.conv3d_transpose(volumes, 4, None) + + def testCreateConv3DTranspose(self): + depth, height, width = 5, 7, 9 + volumes = random_ops.random_uniform((5, depth, height, width, 32)) + layer = conv_layers.Conv3DTranspose(4, [3, 3, 3], activation=nn_ops.relu) + output = layer.apply(volumes) + self.assertEqual(output.op.name, 'conv3d_transpose/Relu') + self.assertListEqual(output.get_shape().as_list(), + [5, depth + 2, height + 2, width + 2, 4]) + self.assertListEqual(layer.kernel.get_shape().as_list(), [3, 3, 3, 4, 32]) + self.assertListEqual(layer.bias.get_shape().as_list(), [4]) + + def testCreateConv3DTransposeIntegerKernelSize(self): + depth, height, width = 5, 7, 9 + volumes = random_ops.random_uniform((5, depth, height, width, 32)) + layer = conv_layers.Conv3DTranspose(4, 3) + output = layer.apply(volumes) + self.assertListEqual(output.get_shape().as_list(), + [5, depth + 2, height + 2, width + 2, 4]) + self.assertListEqual(layer.kernel.get_shape().as_list(), [3, 3, 3, 4, 32]) + self.assertListEqual(layer.bias.get_shape().as_list(), [4]) + + def testCreateConv3DTransposeChannelsFirst(self): + depth, height, width = 5, 7, 9 + volumes = random_ops.random_uniform((5, 32, depth, height, width)) + layer = conv_layers.Conv3DTranspose( + 4, [3, 3, 3], data_format='channels_first') + output = layer.apply(volumes) + self.assertListEqual(output.get_shape().as_list(), + [5, 4, depth + 2, height + 2, width + 2]) + self.assertListEqual(layer.kernel.get_shape().as_list(), [3, 3, 3, 4, 32]) + self.assertListEqual(layer.bias.get_shape().as_list(), 
[4]) + + def testConv3DTransposePaddingSame(self): + depth, height, width = 5, 7, 9 + volumes = random_ops.random_uniform((5, depth, height, width, 64), seed=1) + layer = conv_layers.Conv3DTranspose( + 32, volumes.get_shape()[1:4], padding='same') + output = layer.apply(volumes) + self.assertListEqual(output.get_shape().as_list(), [5, depth, height, + width, 32]) + + def testCreateConv3DTransposeWithStrides(self): + depth, height, width = 4, 6, 8 + # Test strides tuple. + volumes = random_ops.random_uniform((5, depth, height, width, 32), seed=1) + layer = conv_layers.Conv3DTranspose( + 4, [3, 3, 3], strides=(2, 2, 2), padding='same') + output = layer.apply(volumes) + self.assertListEqual(output.get_shape().as_list(), + [5, depth * 2, height * 2, width * 2, 4]) + + # Test strides integer. + layer = conv_layers.Conv3DTranspose(4, [3, 3, 3], strides=2, + padding='same') + output = layer.apply(volumes) + self.assertListEqual(output.get_shape().as_list(), + [5, depth * 2, height * 2, width * 2, 4]) + + # Test unequal strides. 
+ layer = conv_layers.Conv3DTranspose( + 4, [3, 3, 3], strides=(2, 1, 1), padding='same') + output = layer.apply(volumes) + self.assertListEqual(output.get_shape().as_list(), + [5, depth * 2, height, width, 4]) + + def testConv3DTransposeKernelRegularizer(self): + depth, height, width = 5, 7, 9 + volumes = random_ops.random_uniform((5, depth, height, width, 32)) + reg = lambda x: 0.1 * math_ops.reduce_sum(x) + layer = conv_layers.Conv3DTranspose(4, [3, 3, 3], kernel_regularizer=reg) + layer.apply(volumes) + loss_keys = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES) + self.assertEqual(len(loss_keys), 1) + self.assertListEqual(layer.losses, loss_keys) + + def testConv3DTransposeBiasRegularizer(self): + depth, height, width = 5, 7, 9 + volumes = random_ops.random_uniform((5, depth, height, width, 32)) + reg = lambda x: 0.1 * math_ops.reduce_sum(x) + layer = conv_layers.Conv3DTranspose(4, [3, 3, 3], bias_regularizer=reg) + layer.apply(volumes) + loss_keys = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES) + self.assertEqual(len(loss_keys), 1) + self.assertListEqual(layer.losses, loss_keys) + + def testConv3DTransposeNoBias(self): + depth, height, width = 5, 7, 9 + volumes = random_ops.random_uniform((5, depth, height, width, 32)) + layer = conv_layers.Conv3DTranspose( + 4, [3, 3, 3], activation=nn_ops.relu, use_bias=False) + output = layer.apply(volumes) + self.assertEqual(output.op.name, 'conv3d_transpose/Relu') + self.assertListEqual(output.get_shape().as_list(), + [5, depth + 2, height + 2, width + 2, 4]) + self.assertListEqual(layer.kernel.get_shape().as_list(), [3, 3, 3, 4, 32]) + self.assertEqual(layer.bias, None) + + def testFunctionalConv3DTransposeReuse(self): + depth, height, width = 5, 7, 9 + volumes = random_ops.random_uniform((5, depth, height, width, 32), seed=1) + conv_layers.conv3d_transpose(volumes, 4, [3, 3, 3], name='deconv1') + self.assertEqual(len(variables.trainable_variables()), 2) + conv_layers.conv3d_transpose(volumes, 4, [3, 3, 
3], name='deconv1', reuse=True) + self.assertEqual(len(variables.trainable_variables()), 2) + + def testFunctionalConv3DTransposeReuseFromScope(self): + with variable_scope.variable_scope('scope'): + depth, height, width = 5, 7, 9 + volumes = random_ops.random_uniform((5, depth, height, width, 32), + seed=1) + conv_layers.conv3d_transpose(volumes, 4, [3, 3, 3], name='deconv1') + self.assertEqual(len(variables.trainable_variables()), 2) + with variable_scope.variable_scope('scope', reuse=True): + conv_layers.conv3d_transpose(volumes, 4, [3, 3, 3], name='deconv1') + self.assertEqual(len(variables.trainable_variables()), 2) + + def testFunctionalConv3DTransposeInitializerFromScope(self): + with self.test_session() as sess: + with variable_scope.variable_scope( + 'scope', initializer=init_ops.ones_initializer()): + depth, height, width = 5, 7, 9 + volumes = random_ops.random_uniform((5, depth, height, width, 32), + seed=1) + conv_layers.conv3d_transpose(volumes, 4, [3, 3, 3], name='deconv1') + weights = variables.trainable_variables() + # Check the names of weights in order. + self.assertTrue('kernel' in weights[0].name) + self.assertTrue('bias' in weights[1].name) + sess.run(variables.global_variables_initializer()) + weights = sess.run(weights) + # Check that the kernel weights got initialized to ones (from scope) + self.assertAllClose(weights[0], np.ones((3, 3, 3, 4, 32))) + # Check that the bias still got initialized to zeros. 
+ self.assertAllClose(weights[1], np.zeros((4))) + + def testFunctionalConv3DTransposeNoReuse(self): + depth, height, width = 5, 7, 9 + volumes = random_ops.random_uniform((5, depth, height, width, 32), seed=1) + conv_layers.conv3d_transpose(volumes, 4, [3, 3, 3]) + self.assertEqual(len(variables.trainable_variables()), 2) + conv_layers.conv3d_transpose(volumes, 4, [3, 3, 3]) + self.assertEqual(len(variables.trainable_variables()), 4) + + if __name__ == '__main__': test.main() diff --git a/tensorflow/python/layers/layers.py b/tensorflow/python/layers/layers.py index 9f02757d5bc..aa46eb5d27d 100644 --- a/tensorflow/python/layers/layers.py +++ b/tensorflow/python/layers/layers.py @@ -23,6 +23,7 @@ @@conv3d @@separable_conv2d @@conv2d_transpose +@@conv3d_transpose @@average_pooling1d @@max_pooling1d @@average_pooling2d @@ -50,6 +51,7 @@ from tensorflow.python.layers.convolutional import conv2d from tensorflow.python.layers.convolutional import conv3d from tensorflow.python.layers.convolutional import separable_conv2d from tensorflow.python.layers.convolutional import conv2d_transpose +from tensorflow.python.layers.convolutional import conv3d_transpose # Pooling layers. 
from tensorflow.python.layers.pooling import average_pooling1d diff --git a/tensorflow/python/layers/utils.py b/tensorflow/python/layers/utils.py index 666d475690b..64b948c70f5 100644 --- a/tensorflow/python/layers/utils.py +++ b/tensorflow/python/layers/utils.py @@ -26,6 +26,7 @@ import numpy as np from tensorflow.python.ops import variables from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import math_ops from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_util @@ -164,3 +165,28 @@ def constant_value(pred): else: raise TypeError('`pred` must be a Tensor, a Variable, or a Python bool.') return pred_value + + +def get_deconv_dim(dim_size, stride_size, kernel_size, padding): + """Return output dimension of a deconv layer, based on input dimension. + + Arguments: + dim_size: An int representing size of dimension, can be height, width + or depth. + stride_size: An int representing the stride of deconvolution filters + along the same dimension. + kernel_size: An int representing size of deconv kernel (filter) along + the same dimension. + padding: one of `"valid"` or `"same"` (case-insensitive). + + Returns: + An int representing the size of output dimension of the layer. 
+ """ + if isinstance(dim_size, ops.Tensor): + dim_size = math_ops.multiply(dim_size, stride_size) + elif dim_size is not None: + dim_size *= stride_size + + if padding == 'valid' and dim_size is not None: + dim_size += max(kernel_size - stride_size, 0) + return dim_size diff --git a/tensorflow/python/layers/utils_test.py b/tensorflow/python/layers/utils_test.py index 54e757c112b..7969e957d8d 100644 --- a/tensorflow/python/layers/utils_test.py +++ b/tensorflow/python/layers/utils_test.py @@ -62,6 +62,13 @@ class ConvUtilsTest(test.TestCase): with self.assertRaises(ValueError): utils.normalize_padding('invalid') + def testGetDeconvDim(self): + self.assertEqual(utils.get_deconv_dim(30, 1, 3, 'valid'), 32) + self.assertEqual(utils.get_deconv_dim(28, 1, 5, 'valid'), 32) + self.assertEqual(utils.get_deconv_dim(28, 2, 5, 'valid'), 59) + self.assertEqual(utils.get_deconv_dim(32, 1, 3, 'same'), 32) + self.assertEqual(utils.get_deconv_dim(32, 1, 5, 'same'), 32) + self.assertEqual(utils.get_deconv_dim(32, 2, 5, 'same'), 64) if __name__ == '__main__': test.main() diff --git a/tensorflow/python/ops/nn_ops.py b/tensorflow/python/ops/nn_ops.py index ccce9402c77..9b54b937540 100644 --- a/tensorflow/python/ops/nn_ops.py +++ b/tensorflow/python/ops/nn_ops.py @@ -1272,7 +1272,7 @@ def conv3d_transpose(value, output_shape, strides, padding="SAME", - data_format=None, + data_format="NDHWC", name=None): """The transpose of `conv3d`. 
@@ -1308,9 +1308,10 @@ def conv3d_transpose(value, [value, filter, output_shape]) as name: value = ops.convert_to_tensor(value, name="value") filter = ops.convert_to_tensor(filter, name="filter") - if not value.get_shape()[4].is_compatible_with(filter.get_shape()[4]): + axis = 1 if data_format == "NCDHW" else 4 + if not value.get_shape()[axis].is_compatible_with(filter.get_shape()[4]): raise ValueError("input channels does not match filter's input channels, " - "{} != {}".format(value.get_shape()[4], filter.get_shape( + "{} != {}".format(value.get_shape()[axis], filter.get_shape( )[4])) output_shape_ = ops.convert_to_tensor(output_shape, name="output_shape")