Adding radial constraint for Conv2D kernels to make them rotation invariant.

PiperOrigin-RevId: 254129632
This commit is contained in:
A. Unique TensorFlower 2019-06-19 21:27:36 -07:00 committed by TensorFlower Gardener
parent c1ca3700ed
commit 026bc91eae
8 changed files with 168 additions and 0 deletions

View File

@ -21,9 +21,12 @@ from __future__ import print_function
import six import six
from tensorflow.python.framework import tensor_shape
from tensorflow.python.keras import backend as K from tensorflow.python.keras import backend as K
from tensorflow.python.keras.utils.generic_utils import deserialize_keras_object from tensorflow.python.keras.utils.generic_utils import deserialize_keras_object
from tensorflow.python.keras.utils.generic_utils import serialize_keras_object from tensorflow.python.keras.utils.generic_utils import serialize_keras_object
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops from tensorflow.python.ops import math_ops
from tensorflow.python.util.tf_export import keras_export from tensorflow.python.util.tf_export import keras_export
@ -168,12 +171,91 @@ class MinMaxNorm(Constraint):
} }
@keras_export('keras.constraints.RadialConstraint',
'keras.constraints.radial_constraint')
class RadialConstraint(Constraint):
"""Constrains `Conv2D` kernel weights to be the same for each radius.
For example, the desired output for the following 4-by-4 kernel::
```
kernel = [[v_00, v_01, v_02, v_03],
[v_10, v_11, v_12, v_13],
[v_20, v_21, v_22, v_23],
[v_30, v_31, v_32, v_33]]
```
is this::
```
kernel = [[v_11, v_11, v_11, v_11],
[v_11, v_33, v_33, v_11],
[v_11, v_33, v_33, v_11],
[v_11, v_11, v_11, v_11]]
```
This constraint can be applied to any `Conv2D` layer version, including
`Conv2DTranspose` and `SeparableConv2D`, and with either `"channels_last"` or
`"channels_first"` data format. The method assumes the weight tensor is of
shape `(rows, cols, input_depth, output_depth)`.
"""
def __call__(self, w):
w_shape = w.shape
if w_shape.rank is None or w_shape.rank != 4:
raise ValueError(
'The weight tensor must be of rank 4, but is of shape: %s' % w_shape)
height, width, channels, kernels = w_shape
w = K.reshape(w, (height, width, channels * kernels))
# TODO(cpeter): Switch map_fn for a faster tf.vectorized_map once K.switch
# is supported.
w = K.map_fn(
self._kernel_constraint,
K.stack(array_ops.unstack(w, axis=-1), axis=0))
return K.reshape(K.stack(array_ops.unstack(w, axis=0), axis=-1),
(height, width, channels, kernels))
def _kernel_constraint(self, kernel):
"""Radially constraints a kernel with shape (height, width, channels)."""
padding = K.constant([[1, 1], [1, 1]], dtype='int32')
kernel_shape = K.shape(kernel)[0]
start = K.cast(kernel_shape / 2, 'int32')
kernel_new = K.switch(
K.cast(math_ops.floormod(kernel_shape, 2), 'bool'),
lambda: kernel[start - 1:start, start - 1:start],
lambda: kernel[start - 1:start, start - 1:start] + K.zeros( # pylint: disable=g-long-lambda
(2, 2), dtype=kernel.dtype))
index = K.switch(
K.cast(math_ops.floormod(kernel_shape, 2), 'bool'),
lambda: K.constant(0, dtype='int32'),
lambda: K.constant(1, dtype='int32'))
while_condition = lambda index, *args: K.less(index, start)
def body_fn(i, array):
return i + 1, array_ops.pad(
array,
padding,
constant_values=kernel[start + i, start + i])
_, kernel_new = control_flow_ops.while_loop(
while_condition,
body_fn,
[index, kernel_new],
shape_invariants=[index.get_shape(),
tensor_shape.TensorShape([None, None])])
return kernel_new
# Aliases. # Aliases.
max_norm = MaxNorm max_norm = MaxNorm
non_neg = NonNeg non_neg = NonNeg
unit_norm = UnitNorm unit_norm = UnitNorm
min_max_norm = MinMaxNorm min_max_norm = MinMaxNorm
radial_constraint = RadialConstraint
# Legacy aliases. # Legacy aliases.
maxnorm = max_norm maxnorm = max_norm

View File

@ -18,6 +18,8 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
import math
import numpy as np import numpy as np
from tensorflow.python import keras from tensorflow.python import keras
@ -36,6 +38,12 @@ def get_example_array():
return example_array return example_array
def get_example_kernel(width):
np.random.seed(3537)
example_array = np.random.rand(width, width, 2, 2)
return example_array
@test_util.run_all_in_graph_and_eager_modes @test_util.run_all_in_graph_and_eager_modes
class KerasConstraintsTest(test.TestCase): class KerasConstraintsTest(test.TestCase):
@ -93,6 +101,16 @@ class KerasConstraintsTest(test.TestCase):
assert not l2[l2 < m] assert not l2[l2 < m]
assert not l2[l2 > m * 2 + 1e-5] assert not l2[l2 > m * 2 + 1e-5]
def test_conv2d_radial_constraint(self):
for width in (3, 4, 5, 6):
array = get_example_kernel(width)
norm_instance = keras.constraints.radial_constraint()
normed = norm_instance(keras.backend.variable(array))
value = keras.backend.eval(normed)
assert np.all(value.shape == array.shape)
assert np.all(value[0:, 0, 0, 0] == value[-1:, 0, 0, 0])
assert len(set(value[..., 0, 0].flatten())) == math.ceil(float(width) / 2)
if __name__ == '__main__': if __name__ == '__main__':
test.main() test.main()

View File

@ -0,0 +1,13 @@
path: "tensorflow.keras.constraints.RadialConstraint"
tf_class {
is_instance: "<class \'tensorflow.python.keras.constraints.RadialConstraint\'>"
is_instance: "<class \'tensorflow.python.keras.constraints.Constraint\'>"
is_instance: "<type \'object\'>"
member_method {
name: "__init__"
}
member_method {
name: "get_config"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
}

View File

@ -16,6 +16,10 @@ tf_module {
name: "NonNeg" name: "NonNeg"
mtype: "<type \'type\'>" mtype: "<type \'type\'>"
} }
member {
name: "RadialConstraint"
mtype: "<type \'type\'>"
}
member { member {
name: "UnitNorm" name: "UnitNorm"
mtype: "<type \'type\'>" mtype: "<type \'type\'>"
@ -32,6 +36,10 @@ tf_module {
name: "non_neg" name: "non_neg"
mtype: "<type \'type\'>" mtype: "<type \'type\'>"
} }
member {
name: "radial_constraint"
mtype: "<type \'type\'>"
}
member { member {
name: "unit_norm" name: "unit_norm"
mtype: "<type \'type\'>" mtype: "<type \'type\'>"

View File

@ -0,0 +1,13 @@
path: "tensorflow.keras.constraints.radial_constraint"
tf_class {
is_instance: "<class \'tensorflow.python.keras.constraints.RadialConstraint\'>"
is_instance: "<class \'tensorflow.python.keras.constraints.Constraint\'>"
is_instance: "<type \'object\'>"
member_method {
name: "__init__"
}
member_method {
name: "get_config"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
}

View File

@ -0,0 +1,13 @@
path: "tensorflow.keras.constraints.RadialConstraint"
tf_class {
is_instance: "<class \'tensorflow.python.keras.constraints.RadialConstraint\'>"
is_instance: "<class \'tensorflow.python.keras.constraints.Constraint\'>"
is_instance: "<type \'object\'>"
member_method {
name: "__init__"
}
member_method {
name: "get_config"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
}

View File

@ -16,6 +16,10 @@ tf_module {
name: "NonNeg" name: "NonNeg"
mtype: "<type \'type\'>" mtype: "<type \'type\'>"
} }
member {
name: "RadialConstraint"
mtype: "<type \'type\'>"
}
member { member {
name: "UnitNorm" name: "UnitNorm"
mtype: "<type \'type\'>" mtype: "<type \'type\'>"
@ -32,6 +36,10 @@ tf_module {
name: "non_neg" name: "non_neg"
mtype: "<type \'type\'>" mtype: "<type \'type\'>"
} }
member {
name: "radial_constraint"
mtype: "<type \'type\'>"
}
member { member {
name: "unit_norm" name: "unit_norm"
mtype: "<type \'type\'>" mtype: "<type \'type\'>"

View File

@ -0,0 +1,13 @@
path: "tensorflow.keras.constraints.radial_constraint"
tf_class {
is_instance: "<class \'tensorflow.python.keras.constraints.RadialConstraint\'>"
is_instance: "<class \'tensorflow.python.keras.constraints.Constraint\'>"
is_instance: "<type \'object\'>"
member_method {
name: "__init__"
}
member_method {
name: "get_config"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
}