Introduce conv3d_transpose in tf.layers (#8461)

* Add Conv3DTranspose class implementation.

- Overrides call method of base class, still lacks
  documentation and tests.

* Add functional interface for conv3d transpose layer.

* Add tests for conv 3d transpose layer.

* Declare aliases and add docstrings for Conv3DTranspose.

* Shift Conv3DTranspose's get_deconv_dim to layers/utils.py

* Make one-line docstrings fit in one line.

* Add RELEASE.md entry and expose layer through tf.layers

* Restore edited conv3d_transpose docstring.
This commit is contained in:
Karan Desai 2017-04-26 21:42:23 +05:30 committed by Martin Wicke
parent 7bd133cd21
commit b3d1b4e220
7 changed files with 461 additions and 22 deletions

View File

@ -1,6 +1,7 @@
# Changes since the last release
## Major Features and Improvements
* Added `tf.layers.conv3d_transpose` layer for spatio temporal deconvolution.
* Added `tf.Session.make_callable()`, which provides a lower overhead means of running a similar step multiple times.
* Added ibverbs-based RDMA support to contrib (courtesy @junshi15 from Yahoo).
* `RNNCell` objects now subclass `tf.layers._Layer`. The strictness described

View File

@ -970,7 +970,7 @@ def separable_conv2d(inputs,
class Conv2DTranspose(Conv2D):
"""Transposed convolution layer (sometimes called Deconvolution).
"""Transposed 2D convolution layer (sometimes called 2D Deconvolution).
The need for transposed convolutions generally arises
from the desire to use a transformation going in the opposite direction
@ -1083,19 +1083,9 @@ class Conv2DTranspose(Conv2D):
kernel_h, kernel_w = self.kernel_size
stride_h, stride_w = self.strides
def get_deconv_dim(dim_size, stride_size, kernel_size, padding):
if isinstance(dim_size, ops.Tensor):
dim_size = math_ops.multiply(dim_size, stride_size)
elif dim_size is not None:
dim_size *= stride_size
if padding == 'valid' and dim_size is not None:
dim_size += max(kernel_size - stride_size, 0)
return dim_size
# Infer the dynamic output shape:
out_height = get_deconv_dim(height, stride_h, kernel_h, self.padding)
out_width = get_deconv_dim(width, stride_w, kernel_w, self.padding)
out_height = utils.get_deconv_dim(height, stride_h, kernel_h, self.padding)
out_width = utils.get_deconv_dim(width, stride_w, kernel_w, self.padding)
if self.data_format == 'channels_first':
output_shape = (batch_size, self.filters, out_height, out_width)
@ -1116,9 +1106,9 @@ class Conv2DTranspose(Conv2D):
# Infer the static output shape:
out_shape = inputs.get_shape().as_list()
out_shape[c_axis] = self.filters
out_shape[h_axis] = get_deconv_dim(
out_shape[h_axis] = utils.get_deconv_dim(
out_shape[h_axis], stride_h, kernel_h, self.padding)
out_shape[w_axis] = get_deconv_dim(
out_shape[w_axis] = utils.get_deconv_dim(
out_shape[w_axis], stride_w, kernel_w, self.padding)
outputs.set_shape(out_shape)
@ -1149,7 +1139,7 @@ def conv2d_transpose(inputs,
trainable=True,
name=None,
reuse=None):
"""Transposed convolution layer (sometimes called Deconvolution).
"""Functional interface for transposed 2D convolution layer.
The need for transposed convolutions generally arises
from the desire to use a transformation going in the opposite direction
@ -1174,12 +1164,12 @@ def conv2d_transpose(inputs,
`channels_last` corresponds to inputs with shape
`(batch, height, width, channels)` while `channels_first` corresponds to
inputs with shape `(batch, channels, height, width)`.
activation: Activation function. Set it to None to maintain a
activation: Activation function. Set it to `None` to maintain a
linear activation.
use_bias: Boolean, whether the layer uses a bias.
kernel_initializer: An initializer for the convolution kernel.
bias_initializer: An initializer for the bias vector. If None, no bias will
be applied.
bias_initializer: An initializer for the bias vector. If `None`, then no
bias will be applied.
kernel_regularizer: Optional regularizer for the convolution kernel.
bias_regularizer: Optional regularizer for the bias vector.
activity_regularizer: Regularizer function for the output.
@ -1212,6 +1202,246 @@ def conv2d_transpose(inputs,
return layer.apply(inputs)
class Conv3DTranspose(Conv3D):
"""Transposed 3D convolution layer (sometimes called 3D Deconvolution).
Arguments:
filters: Integer, the dimensionality of the output space (i.e. the number
of filters in the convolution).
kernel_size: An integer or tuple/list of 3 integers, specifying the
depth, height and width of the 3D convolution window.
Can be a single integer to specify the same value for all spatial
dimensions.
strides: An integer or tuple/list of 3 integers, specifying the strides
of the convolution along the depth, height and width.
Can be a single integer to specify the same value for all spatial
dimensions.
padding: One of `"valid"` or `"same"` (case-insensitive).
data_format: A string, one of `channels_last` (default) or `channels_first`.
The ordering of the dimensions in the inputs.
`channels_last` corresponds to inputs with shape
`(batch, depth, height, width, channels)` while `channels_first`
corresponds to inputs with shape
`(batch, channels, depth, height, width)`.
activation: Activation function. Set it to `None` to maintain a
linear activation.
use_bias: Boolean, whether the layer uses a bias.
kernel_initializer: An initializer for the convolution kernel.
bias_initializer: An initializer for the bias vector. If `None`, then no
bias will be applied.
kernel_regularizer: Optional regularizer for the convolution kernel.
bias_regularizer: Optional regularizer for the bias vector.
activity_regularizer: Regularizer function for the output.
trainable: Boolean, if `True` also add variables to the graph collection
`GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
name: A string, the name of the layer.
"""
def __init__(self, filters,
kernel_size,
strides=(1, 1, 1),
padding='valid',
data_format='channels_last',
activation=None,
use_bias=True,
kernel_initializer=None,
bias_initializer=init_ops.zeros_initializer(),
kernel_regularizer=None,
bias_regularizer=None,
activity_regularizer=None,
trainable=True,
name=None,
**kwargs):
super(Conv3DTranspose, self).__init__(
filters=filters,
kernel_size=kernel_size,
strides=strides,
padding=padding,
data_format=data_format,
activation=activation,
use_bias=use_bias,
kernel_initializer=kernel_initializer,
bias_initializer=bias_initializer,
kernel_regularizer=kernel_regularizer,
bias_regularizer=bias_regularizer,
activity_regularizer=activity_regularizer,
trainable=trainable,
name=name, **kwargs)
def build(self, input_shape):
if len(input_shape) != 5:
raise ValueError('Inputs should have rank 5, ' +
'received input shape:', str(input_shape))
if self.data_format == 'channels_first':
channel_axis = 1
else:
channel_axis = -1
if input_shape[channel_axis] is None:
raise ValueError('The channel dimension of the inputs '
'should be defined, found None: ' + str(input_shape))
input_dim = input_shape[channel_axis]
kernel_shape = self.kernel_size + (self.filters, input_dim)
self.kernel = vs.get_variable('kernel',
shape=kernel_shape,
initializer=self.kernel_initializer,
regularizer=self.kernel_regularizer,
trainable=True,
dtype=self.dtype)
if self.use_bias:
self.bias = vs.get_variable('bias',
shape=(self.filters,),
initializer=self.bias_initializer,
regularizer=self.bias_regularizer,
trainable=True,
dtype=self.dtype)
else:
self.bias = None
def call(self, inputs):
inputs_shape = array_ops.shape(inputs)
batch_size = inputs_shape[0]
if self.data_format == 'channels_first':
c_axis, d_axis, h_axis, w_axis = 1, 2, 3, 4
else:
c_axis, d_axis, h_axis, w_axis = 4, 1, 2, 3
depth = inputs_shape[d_axis]
height = inputs_shape[h_axis]
width = inputs_shape[w_axis]
kernel_d, kernel_h, kernel_w = self.kernel_size
stride_d, stride_h, stride_w = self.strides
# Infer the dynamic output shape:
out_depth = utils.get_deconv_dim(depth, stride_d, kernel_d, self.padding)
out_height = utils.get_deconv_dim(height, stride_h, kernel_h, self.padding)
out_width = utils.get_deconv_dim(width, stride_w, kernel_w, self.padding)
if self.data_format == 'channels_first':
output_shape = (batch_size, self.filters, out_depth, out_height,
out_width)
strides = (1, 1, stride_d, stride_h, stride_w)
else:
output_shape = (batch_size, out_depth, out_height, out_width,
self.filters)
strides = (1, stride_d, stride_h, stride_w, 1)
output_shape_tensor = array_ops.stack(output_shape)
outputs = nn.conv3d_transpose(
inputs,
self.kernel,
output_shape_tensor,
strides,
data_format=utils.convert_data_format(self.data_format, ndim=5),
padding=self.padding.upper())
# Infer the static output shape:
out_shape = inputs.get_shape().as_list()
out_shape[c_axis] = self.filters
out_shape[d_axis] = utils.get_deconv_dim(
out_shape[d_axis], stride_d, kernel_d, self.padding)
out_shape[h_axis] = utils.get_deconv_dim(
out_shape[h_axis], stride_h, kernel_h, self.padding)
out_shape[w_axis] = utils.get_deconv_dim(
out_shape[w_axis], stride_w, kernel_w, self.padding)
outputs.set_shape(out_shape)
if self.bias:
outputs_shape = outputs.shape.as_list()
if self.data_format == 'channels_first':
outputs_4d = array_ops.reshape(outputs,
[outputs_shape[0], outputs_shape[1],
outputs_shape[2] * outputs_shape[3],
outputs_shape[4]])
else:
outputs_4d = array_ops.reshape(outputs,
[outputs_shape[0],
outputs_shape[1] * outputs_shape[2],
outputs_shape[3], outputs_shape[4]])
outputs_4d = nn.bias_add(
outputs_4d,
self.bias,
data_format=utils.convert_data_format(self.data_format, ndim=4))
outputs = array_ops.reshape(outputs_4d, outputs_shape)
if self.activation is not None:
return self.activation(outputs)
return outputs
def conv3d_transpose(inputs,
filters,
kernel_size,
strides=(1, 1, 1),
padding='valid',
data_format='channels_last',
activation=None,
use_bias=True,
kernel_initializer=None,
bias_initializer=init_ops.zeros_initializer(),
kernel_regularizer=None,
bias_regularizer=None,
activity_regularizer=None,
trainable=True,
name=None,
reuse=None):
"""Functional interface for transposed 3D convolution layer.
Arguments:
inputs: Input tensor.
filters: Integer, the dimensionality of the output space (i.e. the number
of filters in the convolution).
kernel_size: A tuple or list of 3 positive integers specifying the spatial
dimensions of of the filters. Can be a single integer to specify the same
value for all spatial dimensions.
strides: A tuple or list of 3 positive integers specifying the strides
of the convolution. Can be a single integer to specify the same value for
all spatial dimensions.
padding: one of `"valid"` or `"same"` (case-insensitive).
data_format: A string, one of `channels_last` (default) or `channels_first`.
The ordering of the dimensions in the inputs.
`channels_last` corresponds to inputs with shape
`(batch, height, width, channels)` while `channels_first` corresponds to
inputs with shape `(batch, channels, height, width)`.
activation: Activation function. Set it to None to maintain a
linear activation.
use_bias: Boolean, whether the layer uses a bias.
kernel_initializer: An initializer for the convolution kernel.
bias_initializer: An initializer for the bias vector. If None, no bias will
be applied.
kernel_regularizer: Optional regularizer for the convolution kernel.
bias_regularizer: Optional regularizer for the bias vector.
activity_regularizer: Regularizer function for the output.
trainable: Boolean, if `True` also add variables to the graph collection
`GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
name: A string, the name of the layer.
reuse: Boolean, whether to reuse the weights of a previous layer
by the same name.
Returns:
Output tensor.
"""
layer = Conv3DTranspose(
filters=filters,
kernel_size=kernel_size,
strides=strides,
padding=padding,
data_format=data_format,
activation=activation,
use_bias=use_bias,
kernel_initializer=kernel_initializer,
bias_initializer=bias_initializer,
kernel_regularizer=kernel_regularizer,
bias_regularizer=bias_regularizer,
activity_regularizer=activity_regularizer,
trainable=trainable,
name=name,
_reuse=reuse,
_scope=name)
return layer.apply(inputs)
# Aliases
Convolution1D = Conv1D
@ -1219,9 +1449,11 @@ Convolution2D = Conv2D
Convolution3D = Conv3D
SeparableConvolution2D = SeparableConv2D
Convolution2DTranspose = Deconvolution2D = Deconv2D = Conv2DTranspose
Convolution3DTranspose = Deconvolution3D = Deconv3D = Conv3DTranspose
convolution1d = conv1d
convolution2d = conv2d
convolution3d = conv3d
separable_convolution2d = separable_conv2d
convolution2d_transpose = deconvolution2d = deconv2d = conv2d_transpose
convolution3d_transpose = deconvolution3d = deconv3d = conv3d_transpose

View File

@ -651,5 +651,175 @@ class Conv2DTransposeTest(test.TestCase):
self.assertEqual(len(variables.trainable_variables()), 4)
class Conv3DTransposeTest(test.TestCase):
def testInvalidDataFormat(self):
depth, height, width = 5, 7, 9
volumes = random_ops.random_uniform((5, depth, height, width, 32), seed=1)
with self.assertRaisesRegexp(ValueError, 'data_format'):
conv_layers.conv3d_transpose(volumes, 4, 3, data_format='invalid')
def testInvalidStrides(self):
depth, height, width = 5, 7, 9
volumes = random_ops.random_uniform((5, depth, height, width, 32), seed=1)
with self.assertRaisesRegexp(ValueError, 'strides'):
conv_layers.conv3d_transpose(volumes, 4, 3, strides=(1, 2))
with self.assertRaisesRegexp(ValueError, 'strides'):
conv_layers.conv3d_transpose(volumes, 4, 3, strides=None)
def testInvalidKernelSize(self):
depth, height, width = 5, 7, 9
volumes = random_ops.random_uniform((5, depth, height, width, 32), seed=1)
with self.assertRaisesRegexp(ValueError, 'kernel_size'):
conv_layers.conv3d_transpose(volumes, 4, (1, 2))
with self.assertRaisesRegexp(ValueError, 'kernel_size'):
conv_layers.conv3d_transpose(volumes, 4, None)
def testCreateConv3DTranspose(self):
depth, height, width = 5, 7, 9
volumes = random_ops.random_uniform((5, depth, height, width, 32))
layer = conv_layers.Conv3DTranspose(4, [3, 3, 3], activation=nn_ops.relu)
output = layer.apply(volumes)
self.assertEqual(output.op.name, 'conv3d_transpose/Relu')
self.assertListEqual(output.get_shape().as_list(),
[5, depth + 2, height + 2, width + 2, 4])
self.assertListEqual(layer.kernel.get_shape().as_list(), [3, 3, 3, 4, 32])
self.assertListEqual(layer.bias.get_shape().as_list(), [4])
def testCreateConv3DTransposeIntegerKernelSize(self):
depth, height, width = 5, 7, 9
volumes = random_ops.random_uniform((5, depth, height, width, 32))
layer = conv_layers.Conv3DTranspose(4, 3)
output = layer.apply(volumes)
self.assertListEqual(output.get_shape().as_list(),
[5, depth + 2, height + 2, width + 2, 4])
self.assertListEqual(layer.kernel.get_shape().as_list(), [3, 3, 3, 4, 32])
self.assertListEqual(layer.bias.get_shape().as_list(), [4])
def testCreateConv3DTransposeChannelsFirst(self):
depth, height, width = 5, 7, 9
volumes = random_ops.random_uniform((5, 32, depth, height, width))
layer = conv_layers.Conv3DTranspose(
4, [3, 3, 3], data_format='channels_first')
output = layer.apply(volumes)
self.assertListEqual(output.get_shape().as_list(),
[5, 4, depth + 2, height + 2, width + 2])
self.assertListEqual(layer.kernel.get_shape().as_list(), [3, 3, 3, 4, 32])
self.assertListEqual(layer.bias.get_shape().as_list(), [4])
def testConv3DTransposePaddingSame(self):
depth, height, width = 5, 7, 9
volumes = random_ops.random_uniform((5, depth, height, width, 64), seed=1)
layer = conv_layers.Conv3DTranspose(
32, volumes.get_shape()[1:4], padding='same')
output = layer.apply(volumes)
self.assertListEqual(output.get_shape().as_list(), [5, depth, height,
width, 32])
def testCreateConv3DTransposeWithStrides(self):
depth, height, width = 4, 6, 8
# Test strides tuple.
volumes = random_ops.random_uniform((5, depth, height, width, 32), seed=1)
layer = conv_layers.Conv3DTranspose(
4, [3, 3, 3], strides=(2, 2, 2), padding='same')
output = layer.apply(volumes)
self.assertListEqual(output.get_shape().as_list(),
[5, depth * 2, height * 2, width * 2, 4])
# Test strides integer.
layer = conv_layers.Conv3DTranspose(4, [3, 3, 3], strides=2,
padding='same')
output = layer.apply(volumes)
self.assertListEqual(output.get_shape().as_list(),
[5, depth * 2, height * 2, width * 2, 4])
# Test unequal strides.
layer = conv_layers.Conv3DTranspose(
4, [3, 3, 3], strides=(2, 1, 1), padding='same')
output = layer.apply(volumes)
self.assertListEqual(output.get_shape().as_list(),
[5, depth * 2, height, width, 4])
def testConv3DTransposeKernelRegularizer(self):
depth, height, width = 5, 7, 9
volumes = random_ops.random_uniform((5, depth, height, width, 32))
reg = lambda x: 0.1 * math_ops.reduce_sum(x)
layer = conv_layers.Conv3DTranspose(4, [3, 3, 3], kernel_regularizer=reg)
layer.apply(volumes)
loss_keys = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)
self.assertEqual(len(loss_keys), 1)
self.assertListEqual(layer.losses, loss_keys)
def testConv3DTransposeBiasRegularizer(self):
depth, height, width = 5, 7, 9
volumes = random_ops.random_uniform((5, depth, height, width, 32))
reg = lambda x: 0.1 * math_ops.reduce_sum(x)
layer = conv_layers.Conv3DTranspose(4, [3, 3, 3], bias_regularizer=reg)
layer.apply(volumes)
loss_keys = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)
self.assertEqual(len(loss_keys), 1)
self.assertListEqual(layer.losses, loss_keys)
def testConv3DTransposeNoBias(self):
depth, height, width = 5, 7, 9
volumes = random_ops.random_uniform((5, depth, height, width, 32))
layer = conv_layers.Conv3DTranspose(
4, [3, 3, 3], activation=nn_ops.relu, use_bias=False)
output = layer.apply(volumes)
self.assertEqual(output.op.name, 'conv3d_transpose/Relu')
self.assertListEqual(output.get_shape().as_list(),
[5, depth + 2, height + 2, width + 2, 4])
self.assertListEqual(layer.kernel.get_shape().as_list(), [3, 3, 3, 4, 32])
self.assertEqual(layer.bias, None)
def testFunctionalConv3DTransposeReuse(self):
depth, height, width = 5, 7, 9
volumes = random_ops.random_uniform((5, depth, height, width, 32), seed=1)
conv_layers.conv3d_transpose(volumes, 4, [3, 3, 3], name='deconv1')
self.assertEqual(len(variables.trainable_variables()), 2)
conv_layers.conv3d_transpose(volumes, 4, [3, 3, 3], name='deconv1', reuse=True)
self.assertEqual(len(variables.trainable_variables()), 2)
def testFunctionalConv3DTransposeReuseFromScope(self):
with variable_scope.variable_scope('scope'):
depth, height, width = 5, 7, 9
volumes = random_ops.random_uniform((5, depth, height, width, 32),
seed=1)
conv_layers.conv3d_transpose(volumes, 4, [3, 3, 3], name='deconv1')
self.assertEqual(len(variables.trainable_variables()), 2)
with variable_scope.variable_scope('scope', reuse=True):
conv_layers.conv3d_transpose(volumes, 4, [3, 3, 3], name='deconv1')
self.assertEqual(len(variables.trainable_variables()), 2)
def testFunctionalConv3DTransposeInitializerFromScope(self):
with self.test_session() as sess:
with variable_scope.variable_scope(
'scope', initializer=init_ops.ones_initializer()):
depth, height, width = 5, 7, 9
volumes = random_ops.random_uniform((5, depth, height, width, 32),
seed=1)
conv_layers.conv3d_transpose(volumes, 4, [3, 3, 3], name='deconv1')
weights = variables.trainable_variables()
# Check the names of weights in order.
self.assertTrue('kernel' in weights[0].name)
self.assertTrue('bias' in weights[1].name)
sess.run(variables.global_variables_initializer())
weights = sess.run(weights)
# Check that the kernel weights got initialized to ones (from scope)
self.assertAllClose(weights[0], np.ones((3, 3, 3, 4, 32)))
# Check that the bias still got initialized to zeros.
self.assertAllClose(weights[1], np.zeros((4)))
def testFunctionalConv3DTransposeNoReuse(self):
depth, height, width = 5, 7, 9
volumes = random_ops.random_uniform((5, depth, height, width, 32), seed=1)
conv_layers.conv3d_transpose(volumes, 4, [3, 3, 3])
self.assertEqual(len(variables.trainable_variables()), 2)
conv_layers.conv3d_transpose(volumes, 4, [3, 3, 3])
self.assertEqual(len(variables.trainable_variables()), 4)
if __name__ == '__main__':
test.main()

View File

@ -23,6 +23,7 @@
@@conv3d
@@separable_conv2d
@@conv2d_transpose
@@conv3d_transpose
@@average_pooling1d
@@max_pooling1d
@@average_pooling2d
@ -50,6 +51,7 @@ from tensorflow.python.layers.convolutional import conv2d
from tensorflow.python.layers.convolutional import conv3d
from tensorflow.python.layers.convolutional import separable_conv2d
from tensorflow.python.layers.convolutional import conv2d_transpose
from tensorflow.python.layers.convolutional import conv3d_transpose
# Pooling layers.
from tensorflow.python.layers.pooling import average_pooling1d

View File

@ -26,6 +26,7 @@ import numpy as np
from tensorflow.python.ops import variables
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
@ -164,3 +165,28 @@ def constant_value(pred):
else:
raise TypeError('`pred` must be a Tensor, a Variable, or a Python bool.')
return pred_value
def get_deconv_dim(dim_size, stride_size, kernel_size, padding):
"""Return output dimension of a deconv layer, based on input dimension.
Arguments:
dim_size: An int representing size of dimension, can be height, width
or depth.
stride_size: An int representing the stride of deconvolution filters
along the same dimension.
kernel_size: An int representing size of deconv kernel (filter) along
the same dimension.
padding: one of `"valid"` or `"same"` (case-insensitive).
Returns:
An int representing the size of output dimension of the layer.
"""
if isinstance(dim_size, ops.Tensor):
dim_size = math_ops.multiply(dim_size, stride_size)
elif dim_size is not None:
dim_size *= stride_size
if padding == 'valid' and dim_size is not None:
dim_size += max(kernel_size - stride_size, 0)
return dim_size

View File

@ -62,6 +62,13 @@ class ConvUtilsTest(test.TestCase):
with self.assertRaises(ValueError):
utils.normalize_padding('invalid')
def testGetDeconvDim(self):
self.assertEqual(utils.get_deconv_dim(30, 1, 3, 'valid'), 32)
self.assertEqual(utils.get_deconv_dim(28, 1, 5, 'valid'), 32)
self.assertEqual(utils.get_deconv_dim(28, 2, 5, 'valid'), 59)
self.assertEqual(utils.get_deconv_dim(32, 1, 3, 'same'), 32)
self.assertEqual(utils.get_deconv_dim(32, 1, 5, 'same'), 32)
self.assertEqual(utils.get_deconv_dim(32, 2, 5, 'same'), 64)
if __name__ == '__main__':
test.main()

View File

@ -1272,7 +1272,7 @@ def conv3d_transpose(value,
output_shape,
strides,
padding="SAME",
data_format=None,
data_format="NDHWC",
name=None):
"""The transpose of `conv3d`.
@ -1308,9 +1308,10 @@ def conv3d_transpose(value,
[value, filter, output_shape]) as name:
value = ops.convert_to_tensor(value, name="value")
filter = ops.convert_to_tensor(filter, name="filter")
if not value.get_shape()[4].is_compatible_with(filter.get_shape()[4]):
axis = 1 if data_format == "NCDHW" else 4
if not value.get_shape()[axis].is_compatible_with(filter.get_shape()[4]):
raise ValueError("input channels does not match filter's input channels, "
"{} != {}".format(value.get_shape()[4], filter.get_shape(
"{} != {}".format(value.get_shape()[axis], filter.get_shape(
)[4]))
output_shape_ = ops.convert_to_tensor(output_shape, name="output_shape")