Bugfix: number of input channels is not necessarily in the last dimension, after introduction of data_format param.
PiperOrigin-RevId: 164889729
This commit is contained in:
parent
8f9b1af8ae
commit
58c4a4cb1b
@ -2548,7 +2548,8 @@ def separable_convolution2d(
|
||||
dtype = inputs.dtype.base_dtype
|
||||
kernel_h, kernel_w = utils.two_element_tuple(kernel_size)
|
||||
stride_h, stride_w = utils.two_element_tuple(stride)
|
||||
num_filters_in = utils.last_dimension(inputs.get_shape(), min_rank=4)
|
||||
num_filters_in = utils.channel_dimension(
|
||||
inputs.get_shape(), df, min_rank=4)
|
||||
weights_collections = utils.get_variable_collections(
|
||||
variables_collections, 'weights')
|
||||
|
||||
|
@ -3230,12 +3230,13 @@ class SeparableConv2dTest(test.TestCase):
|
||||
def testConvNCHW(self):
|
||||
for num_filters, correct_output_filters in [(None, 6), (8, 8)]:
|
||||
with self.test_session():
|
||||
height, width = 3, 3
|
||||
images = random_ops.random_uniform((5, 3, height, width), seed=1)
|
||||
batch, height, width = 4, 5, 6
|
||||
images = random_ops.random_uniform((batch, 3, height, width), seed=1)
|
||||
output = layers_lib.separable_conv2d(
|
||||
images, num_filters, [3, 3], 2, padding='VALID', data_format='NCHW')
|
||||
self.assertListEqual(
|
||||
output.get_shape().as_list(), [5, correct_output_filters, 1, 1])
|
||||
output.get_shape().as_list(), [batch, correct_output_filters,
|
||||
height - 2, width - 2])
|
||||
|
||||
|
||||
class ScaleGradientTests(test.TestCase):
|
||||
|
@ -33,8 +33,8 @@ __all__ = ['collect_named_outputs',
|
||||
'get_variable_collections',
|
||||
'two_element_tuple',
|
||||
'n_positive_integers',
|
||||
'last_dimension',
|
||||
'first_dimension']
|
||||
'channel_dimension',
|
||||
'last_dimension']
|
||||
|
||||
NamedOutputs = namedtuple('NamedOutputs', ['name', 'outputs'])
|
||||
|
||||
@ -220,15 +220,16 @@ def get_variable_collections(variables_collections, name):
|
||||
return variable_collections
|
||||
|
||||
|
||||
def first_dimension(shape, min_rank=1):
|
||||
"""Returns the first dimension of shape while checking it has min_rank.
|
||||
def _get_dimension(shape, dim, min_rank=1):
|
||||
"""Returns the `dim` dimension of `shape`, while checking it has `min_rank`.
|
||||
|
||||
Args:
|
||||
shape: A `TensorShape`.
|
||||
dim: Integer, which dimension to return.
|
||||
min_rank: Integer, minimum rank of shape.
|
||||
|
||||
Returns:
|
||||
The value of the first dimension.
|
||||
The value of the `dim` dimension.
|
||||
|
||||
Raises:
|
||||
ValueError: if inputs don't have at least min_rank dimensions, or if the
|
||||
@ -240,12 +241,32 @@ def first_dimension(shape, min_rank=1):
|
||||
if len(dims) < min_rank:
|
||||
raise ValueError('rank of shape must be at least %d not: %d' % (min_rank,
|
||||
len(dims)))
|
||||
value = dims[0].value
|
||||
value = dims[dim].value
|
||||
if value is None:
|
||||
raise ValueError('first dimension shape must be known but is None')
|
||||
raise ValueError(
|
||||
'dimension %d of shape must be known but is None: %s' % (dim, shape))
|
||||
return value
|
||||
|
||||
|
||||
def channel_dimension(shape, data_format, min_rank=1):
|
||||
"""Returns the channel dimension of shape, while checking it has min_rank.
|
||||
|
||||
Args:
|
||||
shape: A `TensorShape`.
|
||||
data_format: `channels_first` or `channels_last`.
|
||||
min_rank: Integer, minimum rank of shape.
|
||||
|
||||
Returns:
|
||||
The value of the first dimension.
|
||||
|
||||
Raises:
|
||||
ValueError: if inputs don't have at least min_rank dimensions, or if the
|
||||
first dimension value is not defined.
|
||||
"""
|
||||
return _get_dimension(shape, 1 if data_format == 'channels_first' else -1,
|
||||
min_rank=min_rank)
|
||||
|
||||
|
||||
def last_dimension(shape, min_rank=1):
|
||||
"""Returns the last dimension of shape while checking it has min_rank.
|
||||
|
||||
@ -260,16 +281,7 @@ def last_dimension(shape, min_rank=1):
|
||||
ValueError: if inputs don't have at least min_rank dimensions, or if the
|
||||
last dimension value is not defined.
|
||||
"""
|
||||
dims = shape.dims
|
||||
if dims is None:
|
||||
raise ValueError('dims of shape must be known but is None')
|
||||
if len(dims) < min_rank:
|
||||
raise ValueError('rank of shape must be at least %d not: %d' % (min_rank,
|
||||
len(dims)))
|
||||
value = dims[-1].value
|
||||
if value is None:
|
||||
raise ValueError('last dimension shape must be known but is None')
|
||||
return value
|
||||
return _get_dimension(shape, -1, min_rank=min_rank)
|
||||
|
||||
|
||||
def two_element_tuple(int_or_tuple):
|
||||
|
Loading…
Reference in New Issue
Block a user