Introduce conv3d_transpose in tf.layers (#8461)

* Add Conv3DTranspose class implementation.

- Overrides call method of base class, still lacks
  documentation and tests.

* Add functional interface for conv3d transpose layer.

* Add tests for conv 3d transpose layer.

* Declare aliases and add docstrings for Conv3DTranspose.

* Shift Conv3DTranspose's get_deconv_dim to layers/utils.py

* Make one-line docstrings fit in one line.

* Add RELEASE.md entry and expose layer through tf.layers

* Restore edited conv3d_transpose docstring.
Karan Desai 2017-04-26 21:42:23 +05:30 committed by Martin Wicke
parent 7bd133cd21
commit b3d1b4e220
7 changed files with 461 additions and 22 deletions
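Before the diffs, a minimal usage sketch of the API this commit adds. The input shape, hyperparameters, and graph-mode setup below are illustrative, not part of the commit:

import tensorflow as tf

# A batch of 8 single-channel volumes in channels_last layout:
# (batch, depth, height, width, channels).
volumes = tf.placeholder(tf.float32, [8, 16, 16, 16, 1])

# With padding='same' and stride 2, each spatial dimension doubles:
# out_dim = in_dim * stride, so the output is (8, 32, 32, 32, 32).
upsampled = tf.layers.conv3d_transpose(
    volumes,
    filters=32,
    kernel_size=(3, 3, 3),
    strides=(2, 2, 2),
    padding='same',
    activation=tf.nn.relu)

print(upsampled.get_shape().as_list())  # [8, 32, 32, 32, 32]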


@@ -1,6 +1,7 @@
# Changes since the last release
## Major Features and Improvements
* Added `tf.layers.conv3d_transpose` layer for spatio-temporal deconvolution.
* Added `tf.Session.make_callable()`, which provides a lower overhead means of running a similar step multiple times.
* Added ibverbs-based RDMA support to contrib (courtesy @junshi15 from Yahoo).
* `RNNCell` objects now subclass `tf.layers._Layer`. The strictness described


@@ -970,7 +970,7 @@ def separable_conv2d(inputs,

class Conv2DTranspose(Conv2D):
-  """Transposed convolution layer (sometimes called Deconvolution).
+  """Transposed 2D convolution layer (sometimes called 2D Deconvolution).

  The need for transposed convolutions generally arises
  from the desire to use a transformation going in the opposite direction

@@ -1083,19 +1083,9 @@ class Conv2DTranspose(Conv2D):
    kernel_h, kernel_w = self.kernel_size
    stride_h, stride_w = self.strides
-    def get_deconv_dim(dim_size, stride_size, kernel_size, padding):
-      if isinstance(dim_size, ops.Tensor):
-        dim_size = math_ops.multiply(dim_size, stride_size)
-      elif dim_size is not None:
-        dim_size *= stride_size
-      if padding == 'valid' and dim_size is not None:
-        dim_size += max(kernel_size - stride_size, 0)
-      return dim_size
    # Infer the dynamic output shape:
-    out_height = get_deconv_dim(height, stride_h, kernel_h, self.padding)
+    out_height = utils.get_deconv_dim(height, stride_h, kernel_h, self.padding)
-    out_width = get_deconv_dim(width, stride_w, kernel_w, self.padding)
+    out_width = utils.get_deconv_dim(width, stride_w, kernel_w, self.padding)

    if self.data_format == 'channels_first':
      output_shape = (batch_size, self.filters, out_height, out_width)
@@ -1116,9 +1106,9 @@ class Conv2DTranspose(Conv2D):
    # Infer the static output shape:
    out_shape = inputs.get_shape().as_list()
    out_shape[c_axis] = self.filters
-    out_shape[h_axis] = get_deconv_dim(
+    out_shape[h_axis] = utils.get_deconv_dim(
        out_shape[h_axis], stride_h, kernel_h, self.padding)
-    out_shape[w_axis] = get_deconv_dim(
+    out_shape[w_axis] = utils.get_deconv_dim(
        out_shape[w_axis], stride_w, kernel_w, self.padding)
    outputs.set_shape(out_shape)
@@ -1149,7 +1139,7 @@ def conv2d_transpose(inputs,
                     trainable=True,
                     name=None,
                     reuse=None):
-  """Transposed convolution layer (sometimes called Deconvolution).
+  """Functional interface for transposed 2D convolution layer.

  The need for transposed convolutions generally arises
  from the desire to use a transformation going in the opposite direction
@@ -1174,12 +1164,12 @@ def conv2d_transpose(inputs,
      `channels_last` corresponds to inputs with shape
      `(batch, height, width, channels)` while `channels_first` corresponds to
      inputs with shape `(batch, channels, height, width)`.
-    activation: Activation function. Set it to None to maintain a
+    activation: Activation function. Set it to `None` to maintain a
      linear activation.
    use_bias: Boolean, whether the layer uses a bias.
    kernel_initializer: An initializer for the convolution kernel.
-    bias_initializer: An initializer for the bias vector. If None, no bias will
-      be applied.
+    bias_initializer: An initializer for the bias vector. If `None`, then no
+      bias will be applied.
    kernel_regularizer: Optional regularizer for the convolution kernel.
    bias_regularizer: Optional regularizer for the bias vector.
    activity_regularizer: Regularizer function for the output.
@@ -1212,6 +1202,246 @@ def conv2d_transpose(inputs,
  return layer.apply(inputs)

class Conv3DTranspose(Conv3D):
  """Transposed 3D convolution layer (sometimes called 3D Deconvolution).

  Arguments:
    filters: Integer, the dimensionality of the output space (i.e. the number
      of filters in the convolution).
    kernel_size: An integer or tuple/list of 3 integers, specifying the
      depth, height and width of the 3D convolution window.
      Can be a single integer to specify the same value for all spatial
      dimensions.
    strides: An integer or tuple/list of 3 integers, specifying the strides
      of the convolution along the depth, height and width.
      Can be a single integer to specify the same value for all spatial
      dimensions.
    padding: One of `"valid"` or `"same"` (case-insensitive).
    data_format: A string, one of `channels_last` (default) or `channels_first`.
      The ordering of the dimensions in the inputs.
      `channels_last` corresponds to inputs with shape
      `(batch, depth, height, width, channels)` while `channels_first`
      corresponds to inputs with shape
      `(batch, channels, depth, height, width)`.
    activation: Activation function. Set it to `None` to maintain a
      linear activation.
    use_bias: Boolean, whether the layer uses a bias.
    kernel_initializer: An initializer for the convolution kernel.
    bias_initializer: An initializer for the bias vector. If `None`, then no
      bias will be applied.
    kernel_regularizer: Optional regularizer for the convolution kernel.
    bias_regularizer: Optional regularizer for the bias vector.
    activity_regularizer: Regularizer function for the output.
    trainable: Boolean, if `True` also add variables to the graph collection
      `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
    name: A string, the name of the layer.
  """

  def __init__(self, filters,
               kernel_size,
               strides=(1, 1, 1),
               padding='valid',
               data_format='channels_last',
               activation=None,
               use_bias=True,
               kernel_initializer=None,
               bias_initializer=init_ops.zeros_initializer(),
               kernel_regularizer=None,
               bias_regularizer=None,
               activity_regularizer=None,
               trainable=True,
               name=None,
               **kwargs):
    super(Conv3DTranspose, self).__init__(
        filters=filters,
        kernel_size=kernel_size,
        strides=strides,
        padding=padding,
        data_format=data_format,
        activation=activation,
        use_bias=use_bias,
        kernel_initializer=kernel_initializer,
        bias_initializer=bias_initializer,
        kernel_regularizer=kernel_regularizer,
        bias_regularizer=bias_regularizer,
        activity_regularizer=activity_regularizer,
        trainable=trainable,
        name=name, **kwargs)

  def build(self, input_shape):
    if len(input_shape) != 5:
      raise ValueError('Inputs should have rank 5, ' +
                       'received input shape:', str(input_shape))
    if self.data_format == 'channels_first':
      channel_axis = 1
    else:
      channel_axis = -1
    if input_shape[channel_axis] is None:
      raise ValueError('The channel dimension of the inputs '
                       'should be defined, found None: ' + str(input_shape))
    input_dim = input_shape[channel_axis]
    kernel_shape = self.kernel_size + (self.filters, input_dim)

    self.kernel = vs.get_variable('kernel',
                                  shape=kernel_shape,
                                  initializer=self.kernel_initializer,
                                  regularizer=self.kernel_regularizer,
                                  trainable=True,
                                  dtype=self.dtype)
    if self.use_bias:
      self.bias = vs.get_variable('bias',
                                  shape=(self.filters,),
                                  initializer=self.bias_initializer,
                                  regularizer=self.bias_regularizer,
                                  trainable=True,
                                  dtype=self.dtype)
    else:
      self.bias = None

  def call(self, inputs):
    inputs_shape = array_ops.shape(inputs)
    batch_size = inputs_shape[0]
    if self.data_format == 'channels_first':
      c_axis, d_axis, h_axis, w_axis = 1, 2, 3, 4
    else:
      c_axis, d_axis, h_axis, w_axis = 4, 1, 2, 3

    depth = inputs_shape[d_axis]
    height = inputs_shape[h_axis]
    width = inputs_shape[w_axis]

    kernel_d, kernel_h, kernel_w = self.kernel_size
    stride_d, stride_h, stride_w = self.strides

    # Infer the dynamic output shape:
    out_depth = utils.get_deconv_dim(depth, stride_d, kernel_d, self.padding)
    out_height = utils.get_deconv_dim(height, stride_h, kernel_h, self.padding)
    out_width = utils.get_deconv_dim(width, stride_w, kernel_w, self.padding)

    if self.data_format == 'channels_first':
      output_shape = (batch_size, self.filters, out_depth, out_height,
                      out_width)
      strides = (1, 1, stride_d, stride_h, stride_w)
    else:
      output_shape = (batch_size, out_depth, out_height, out_width,
                      self.filters)
      strides = (1, stride_d, stride_h, stride_w, 1)

    output_shape_tensor = array_ops.stack(output_shape)
    outputs = nn.conv3d_transpose(
        inputs,
        self.kernel,
        output_shape_tensor,
        strides,
        data_format=utils.convert_data_format(self.data_format, ndim=5),
        padding=self.padding.upper())

    # Infer the static output shape:
    out_shape = inputs.get_shape().as_list()
    out_shape[c_axis] = self.filters
    out_shape[d_axis] = utils.get_deconv_dim(
        out_shape[d_axis], stride_d, kernel_d, self.padding)
    out_shape[h_axis] = utils.get_deconv_dim(
        out_shape[h_axis], stride_h, kernel_h, self.padding)
    out_shape[w_axis] = utils.get_deconv_dim(
        out_shape[w_axis], stride_w, kernel_w, self.padding)
    outputs.set_shape(out_shape)

    if self.bias:
      outputs_shape = outputs.shape.as_list()
      if self.data_format == 'channels_first':
        outputs_4d = array_ops.reshape(outputs,
                                       [outputs_shape[0], outputs_shape[1],
                                        outputs_shape[2] * outputs_shape[3],
                                        outputs_shape[4]])
      else:
        outputs_4d = array_ops.reshape(outputs,
                                       [outputs_shape[0],
                                        outputs_shape[1] * outputs_shape[2],
                                        outputs_shape[3], outputs_shape[4]])
      outputs_4d = nn.bias_add(
          outputs_4d,
          self.bias,
          data_format=utils.convert_data_format(self.data_format, ndim=4))
      outputs = array_ops.reshape(outputs_4d, outputs_shape)

    if self.activation is not None:
      return self.activation(outputs)
    return outputs

def conv3d_transpose(inputs,
                     filters,
                     kernel_size,
                     strides=(1, 1, 1),
                     padding='valid',
                     data_format='channels_last',
                     activation=None,
                     use_bias=True,
                     kernel_initializer=None,
                     bias_initializer=init_ops.zeros_initializer(),
                     kernel_regularizer=None,
                     bias_regularizer=None,
                     activity_regularizer=None,
                     trainable=True,
                     name=None,
                     reuse=None):
  """Functional interface for transposed 3D convolution layer.

  Arguments:
    inputs: Input tensor.
    filters: Integer, the dimensionality of the output space (i.e. the number
      of filters in the convolution).
    kernel_size: A tuple or list of 3 positive integers specifying the spatial
      dimensions of the filters. Can be a single integer to specify the same
      value for all spatial dimensions.
    strides: A tuple or list of 3 positive integers specifying the strides
      of the convolution. Can be a single integer to specify the same value for
      all spatial dimensions.
    padding: one of `"valid"` or `"same"` (case-insensitive).
    data_format: A string, one of `channels_last` (default) or `channels_first`.
      The ordering of the dimensions in the inputs.
      `channels_last` corresponds to inputs with shape
      `(batch, depth, height, width, channels)` while `channels_first`
      corresponds to inputs with shape
      `(batch, channels, depth, height, width)`.
    activation: Activation function. Set it to `None` to maintain a
      linear activation.
    use_bias: Boolean, whether the layer uses a bias.
    kernel_initializer: An initializer for the convolution kernel.
    bias_initializer: An initializer for the bias vector. If `None`, then no
      bias will be applied.
    kernel_regularizer: Optional regularizer for the convolution kernel.
    bias_regularizer: Optional regularizer for the bias vector.
    activity_regularizer: Regularizer function for the output.
    trainable: Boolean, if `True` also add variables to the graph collection
      `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
    name: A string, the name of the layer.
    reuse: Boolean, whether to reuse the weights of a previous layer
      by the same name.

  Returns:
    Output tensor.
  """
  layer = Conv3DTranspose(
      filters=filters,
      kernel_size=kernel_size,
      strides=strides,
      padding=padding,
      data_format=data_format,
      activation=activation,
      use_bias=use_bias,
      kernel_initializer=kernel_initializer,
      bias_initializer=bias_initializer,
      kernel_regularizer=kernel_regularizer,
      bias_regularizer=bias_regularizer,
      activity_regularizer=activity_regularizer,
      trainable=trainable,
      name=name,
      _reuse=reuse,
      _scope=name)
  return layer.apply(inputs)

# Aliases
Convolution1D = Conv1D
@@ -1219,9 +1449,11 @@ Convolution2D = Conv2D
Convolution3D = Conv3D
SeparableConvolution2D = SeparableConv2D
Convolution2DTranspose = Deconvolution2D = Deconv2D = Conv2DTranspose
Convolution3DTranspose = Deconvolution3D = Deconv3D = Conv3DTranspose
convolution1d = conv1d
convolution2d = conv2d
convolution3d = conv3d
separable_convolution2d = separable_conv2d
convolution2d_transpose = deconvolution2d = deconv2d = conv2d_transpose
convolution3d_transpose = deconvolution3d = deconv3d = conv3d_transpose
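A note on the bias handling in Conv3DTranspose.call above: nn.bias_add did not accept 5-D channels_first inputs at this point, which appears to be why the layer folds two spatial dimensions into one, applies the bias on the resulting 4-D tensor, and restores the shape. A small NumPy sketch (illustrative, not from the commit) of why the fold is lossless:

import numpy as np

batch, channels, depth, height, width = 2, 4, 3, 5, 7
x = np.random.randn(batch, channels, depth, height, width)
bias = np.random.randn(channels)

# Reshape moves no elements and keeps the channel axis intact, so adding
# the bias on the folded 4-D view equals adding it on the 5-D tensor.
x4d = x.reshape(batch, channels, depth * height, width)
y = (x4d + bias[None, :, None, None]).reshape(x.shape)

assert np.allclose(y, x + bias[None, :, None, None, None])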


@@ -651,5 +651,175 @@ class Conv2DTransposeTest(test.TestCase):
    self.assertEqual(len(variables.trainable_variables()), 4)

class Conv3DTransposeTest(test.TestCase):

  def testInvalidDataFormat(self):
    depth, height, width = 5, 7, 9
    volumes = random_ops.random_uniform((5, depth, height, width, 32), seed=1)
    with self.assertRaisesRegexp(ValueError, 'data_format'):
      conv_layers.conv3d_transpose(volumes, 4, 3, data_format='invalid')

  def testInvalidStrides(self):
    depth, height, width = 5, 7, 9
    volumes = random_ops.random_uniform((5, depth, height, width, 32), seed=1)
    with self.assertRaisesRegexp(ValueError, 'strides'):
      conv_layers.conv3d_transpose(volumes, 4, 3, strides=(1, 2))
    with self.assertRaisesRegexp(ValueError, 'strides'):
      conv_layers.conv3d_transpose(volumes, 4, 3, strides=None)

  def testInvalidKernelSize(self):
    depth, height, width = 5, 7, 9
    volumes = random_ops.random_uniform((5, depth, height, width, 32), seed=1)
    with self.assertRaisesRegexp(ValueError, 'kernel_size'):
      conv_layers.conv3d_transpose(volumes, 4, (1, 2))
    with self.assertRaisesRegexp(ValueError, 'kernel_size'):
      conv_layers.conv3d_transpose(volumes, 4, None)

  def testCreateConv3DTranspose(self):
    depth, height, width = 5, 7, 9
    volumes = random_ops.random_uniform((5, depth, height, width, 32))
    layer = conv_layers.Conv3DTranspose(4, [3, 3, 3], activation=nn_ops.relu)
    output = layer.apply(volumes)
    self.assertEqual(output.op.name, 'conv3d_transpose/Relu')
    self.assertListEqual(output.get_shape().as_list(),
                         [5, depth + 2, height + 2, width + 2, 4])
    self.assertListEqual(layer.kernel.get_shape().as_list(), [3, 3, 3, 4, 32])
    self.assertListEqual(layer.bias.get_shape().as_list(), [4])

  def testCreateConv3DTransposeIntegerKernelSize(self):
    depth, height, width = 5, 7, 9
    volumes = random_ops.random_uniform((5, depth, height, width, 32))
    layer = conv_layers.Conv3DTranspose(4, 3)
    output = layer.apply(volumes)
    self.assertListEqual(output.get_shape().as_list(),
                         [5, depth + 2, height + 2, width + 2, 4])
    self.assertListEqual(layer.kernel.get_shape().as_list(), [3, 3, 3, 4, 32])
    self.assertListEqual(layer.bias.get_shape().as_list(), [4])

  def testCreateConv3DTransposeChannelsFirst(self):
    depth, height, width = 5, 7, 9
    volumes = random_ops.random_uniform((5, 32, depth, height, width))
    layer = conv_layers.Conv3DTranspose(
        4, [3, 3, 3], data_format='channels_first')
    output = layer.apply(volumes)
    self.assertListEqual(output.get_shape().as_list(),
                         [5, 4, depth + 2, height + 2, width + 2])
    self.assertListEqual(layer.kernel.get_shape().as_list(), [3, 3, 3, 4, 32])
    self.assertListEqual(layer.bias.get_shape().as_list(), [4])

  def testConv3DTransposePaddingSame(self):
    depth, height, width = 5, 7, 9
    volumes = random_ops.random_uniform((5, depth, height, width, 64), seed=1)
    layer = conv_layers.Conv3DTranspose(
        32, volumes.get_shape()[1:4], padding='same')
    output = layer.apply(volumes)
    self.assertListEqual(output.get_shape().as_list(),
                         [5, depth, height, width, 32])

  def testCreateConv3DTransposeWithStrides(self):
    depth, height, width = 4, 6, 8
    # Test strides tuple.
    volumes = random_ops.random_uniform((5, depth, height, width, 32), seed=1)
    layer = conv_layers.Conv3DTranspose(
        4, [3, 3, 3], strides=(2, 2, 2), padding='same')
    output = layer.apply(volumes)
    self.assertListEqual(output.get_shape().as_list(),
                         [5, depth * 2, height * 2, width * 2, 4])
    # Test strides integer.
    layer = conv_layers.Conv3DTranspose(4, [3, 3, 3], strides=2,
                                        padding='same')
    output = layer.apply(volumes)
    self.assertListEqual(output.get_shape().as_list(),
                         [5, depth * 2, height * 2, width * 2, 4])
    # Test unequal strides.
    layer = conv_layers.Conv3DTranspose(
        4, [3, 3, 3], strides=(2, 1, 1), padding='same')
    output = layer.apply(volumes)
    self.assertListEqual(output.get_shape().as_list(),
                         [5, depth * 2, height, width, 4])

  def testConv3DTransposeKernelRegularizer(self):
    depth, height, width = 5, 7, 9
    volumes = random_ops.random_uniform((5, depth, height, width, 32))
    reg = lambda x: 0.1 * math_ops.reduce_sum(x)
    layer = conv_layers.Conv3DTranspose(4, [3, 3, 3], kernel_regularizer=reg)
    layer.apply(volumes)
    loss_keys = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)
    self.assertEqual(len(loss_keys), 1)
    self.assertListEqual(layer.losses, loss_keys)

  def testConv3DTransposeBiasRegularizer(self):
    depth, height, width = 5, 7, 9
    volumes = random_ops.random_uniform((5, depth, height, width, 32))
    reg = lambda x: 0.1 * math_ops.reduce_sum(x)
    layer = conv_layers.Conv3DTranspose(4, [3, 3, 3], bias_regularizer=reg)
    layer.apply(volumes)
    loss_keys = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)
    self.assertEqual(len(loss_keys), 1)
    self.assertListEqual(layer.losses, loss_keys)

  def testConv3DTransposeNoBias(self):
    depth, height, width = 5, 7, 9
    volumes = random_ops.random_uniform((5, depth, height, width, 32))
    layer = conv_layers.Conv3DTranspose(
        4, [3, 3, 3], activation=nn_ops.relu, use_bias=False)
    output = layer.apply(volumes)
    self.assertEqual(output.op.name, 'conv3d_transpose/Relu')
    self.assertListEqual(output.get_shape().as_list(),
                         [5, depth + 2, height + 2, width + 2, 4])
    self.assertListEqual(layer.kernel.get_shape().as_list(), [3, 3, 3, 4, 32])
    self.assertEqual(layer.bias, None)

  def testFunctionalConv3DTransposeReuse(self):
    depth, height, width = 5, 7, 9
    volumes = random_ops.random_uniform((5, depth, height, width, 32), seed=1)
    conv_layers.conv3d_transpose(volumes, 4, [3, 3, 3], name='deconv1')
    self.assertEqual(len(variables.trainable_variables()), 2)
    conv_layers.conv3d_transpose(
        volumes, 4, [3, 3, 3], name='deconv1', reuse=True)
    self.assertEqual(len(variables.trainable_variables()), 2)

  def testFunctionalConv3DTransposeReuseFromScope(self):
    with variable_scope.variable_scope('scope'):
      depth, height, width = 5, 7, 9
      volumes = random_ops.random_uniform((5, depth, height, width, 32),
                                          seed=1)
      conv_layers.conv3d_transpose(volumes, 4, [3, 3, 3], name='deconv1')
      self.assertEqual(len(variables.trainable_variables()), 2)
    with variable_scope.variable_scope('scope', reuse=True):
      conv_layers.conv3d_transpose(volumes, 4, [3, 3, 3], name='deconv1')
      self.assertEqual(len(variables.trainable_variables()), 2)

  def testFunctionalConv3DTransposeInitializerFromScope(self):
    with self.test_session() as sess:
      with variable_scope.variable_scope(
          'scope', initializer=init_ops.ones_initializer()):
        depth, height, width = 5, 7, 9
        volumes = random_ops.random_uniform((5, depth, height, width, 32),
                                            seed=1)
        conv_layers.conv3d_transpose(volumes, 4, [3, 3, 3], name='deconv1')
        weights = variables.trainable_variables()
        # Check the names of weights in order.
        self.assertTrue('kernel' in weights[0].name)
        self.assertTrue('bias' in weights[1].name)
        sess.run(variables.global_variables_initializer())
        weights = sess.run(weights)
        # Check that the kernel weights got initialized to ones (from scope).
        self.assertAllClose(weights[0], np.ones((3, 3, 3, 4, 32)))
        # Check that the bias still got initialized to zeros.
        self.assertAllClose(weights[1], np.zeros((4)))

  def testFunctionalConv3DTransposeNoReuse(self):
    depth, height, width = 5, 7, 9
    volumes = random_ops.random_uniform((5, depth, height, width, 32), seed=1)
    conv_layers.conv3d_transpose(volumes, 4, [3, 3, 3])
    self.assertEqual(len(variables.trainable_variables()), 2)
    conv_layers.conv3d_transpose(volumes, 4, [3, 3, 3])
    self.assertEqual(len(variables.trainable_variables()), 4)

if __name__ == '__main__':
  test.main()


@@ -23,6 +23,7 @@
@@conv3d
@@separable_conv2d
@@conv2d_transpose
@@conv3d_transpose
@@average_pooling1d
@@max_pooling1d
@@average_pooling2d

@@ -50,6 +51,7 @@ from tensorflow.python.layers.convolutional import conv2d
from tensorflow.python.layers.convolutional import conv3d
from tensorflow.python.layers.convolutional import separable_conv2d
from tensorflow.python.layers.convolutional import conv2d_transpose
from tensorflow.python.layers.convolutional import conv3d_transpose

# Pooling layers.
from tensorflow.python.layers.pooling import average_pooling1d

@@ -26,6 +26,7 @@ import numpy as np
from tensorflow.python.ops import variables
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util

@@ -164,3 +165,28 @@ def constant_value(pred):
  else:
    raise TypeError('`pred` must be a Tensor, a Variable, or a Python bool.')
  return pred_value

def get_deconv_dim(dim_size, stride_size, kernel_size, padding):
  """Return output dimension of a deconv layer, based on input dimension.

  Arguments:
    dim_size: An int representing size of dimension, can be height, width
      or depth.
    stride_size: An int representing the stride of deconvolution filters
      along the same dimension.
    kernel_size: An int representing size of deconv kernel (filter) along
      the same dimension.
    padding: one of `"valid"` or `"same"` (case-insensitive).

  Returns:
    An int representing the size of output dimension of the layer.
  """
  if isinstance(dim_size, ops.Tensor):
    dim_size = math_ops.multiply(dim_size, stride_size)
  elif dim_size is not None:
    dim_size *= stride_size

  if padding == 'valid' and dim_size is not None:
    dim_size += max(kernel_size - stride_size, 0)
  return dim_size
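To make the arithmetic concrete, a simplified restatement (ints only, no Tensor or None handling) with the same values the test below checks:

def deconv_dim(dim, stride, kernel, padding):
  dim *= stride
  if padding == 'valid':
    dim += max(kernel - stride, 0)
  return dim

assert deconv_dim(30, 1, 3, 'valid') == 32  # 30 * 1 + (3 - 1)
assert deconv_dim(28, 2, 5, 'valid') == 59  # 28 * 2 + (5 - 2)
assert deconv_dim(32, 2, 5, 'same') == 64   # 32 * 2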


@@ -62,6 +62,13 @@ class ConvUtilsTest(test.TestCase):
    with self.assertRaises(ValueError):
      utils.normalize_padding('invalid')

  def testGetDeconvDim(self):
    self.assertEqual(utils.get_deconv_dim(30, 1, 3, 'valid'), 32)
    self.assertEqual(utils.get_deconv_dim(28, 1, 5, 'valid'), 32)
    self.assertEqual(utils.get_deconv_dim(28, 2, 5, 'valid'), 59)
    self.assertEqual(utils.get_deconv_dim(32, 1, 3, 'same'), 32)
    self.assertEqual(utils.get_deconv_dim(32, 1, 5, 'same'), 32)
    self.assertEqual(utils.get_deconv_dim(32, 2, 5, 'same'), 64)

if __name__ == '__main__':
  test.main()


@@ -1272,7 +1272,7 @@ def conv3d_transpose(value,
                     output_shape,
                     strides,
                     padding="SAME",
-                    data_format=None,
+                    data_format="NDHWC",
                     name=None):
  """The transpose of `conv3d`.

@@ -1308,9 +1308,10 @@ def conv3d_transpose(value,
                      [value, filter, output_shape]) as name:
    value = ops.convert_to_tensor(value, name="value")
    filter = ops.convert_to_tensor(filter, name="filter")
-    if not value.get_shape()[4].is_compatible_with(filter.get_shape()[4]):
+    axis = 1 if data_format == "NCDHW" else 4
+    if not value.get_shape()[axis].is_compatible_with(filter.get_shape()[4]):
      raise ValueError("input channels does not match filter's input channels, "
-                       "{} != {}".format(value.get_shape()[4], filter.get_shape(
+                       "{} != {}".format(value.get_shape()[axis], filter.get_shape(
                       )[4]))
    output_shape_ = ops.convert_to_tensor(output_shape, name="output_shape")
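With the default now explicit and the channel axis chosen per data_format, an NCDHW call checks channels at axis 1 rather than 4. A hedged sketch of the call this enables (shapes illustrative; actual device support for NCDHW may vary):

import tensorflow as tf

value = tf.placeholder(tf.float32, [4, 8, 5, 5, 5])    # N, C, D, H, W
kernel = tf.placeholder(tf.float32, [3, 3, 3, 16, 8])  # kD, kH, kW, out, in

out = tf.nn.conv3d_transpose(
    value, kernel,
    output_shape=[4, 16, 10, 10, 10],  # channel axis is 1 for NCDHW
    strides=[1, 1, 2, 2, 2],
    padding="SAME",
    data_format="NCDHW")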