Adding radial constraint for Conv2D kernels to make them rotation invariant.
PiperOrigin-RevId: 254129632
This commit is contained in:
parent
c1ca3700ed
commit
026bc91eae
@ -21,9 +21,12 @@ from __future__ import print_function
|
||||
|
||||
import six
|
||||
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.keras import backend as K
|
||||
from tensorflow.python.keras.utils.generic_utils import deserialize_keras_object
|
||||
from tensorflow.python.keras.utils.generic_utils import serialize_keras_object
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.util.tf_export import keras_export
|
||||
|
||||
@ -168,12 +171,91 @@ class MinMaxNorm(Constraint):
|
||||
}
|
||||
|
||||
|
||||
@keras_export('keras.constraints.RadialConstraint',
|
||||
'keras.constraints.radial_constraint')
|
||||
class RadialConstraint(Constraint):
|
||||
"""Constrains `Conv2D` kernel weights to be the same for each radius.
|
||||
|
||||
For example, the desired output for the following 4-by-4 kernel::
|
||||
|
||||
```
|
||||
kernel = [[v_00, v_01, v_02, v_03],
|
||||
[v_10, v_11, v_12, v_13],
|
||||
[v_20, v_21, v_22, v_23],
|
||||
[v_30, v_31, v_32, v_33]]
|
||||
```
|
||||
|
||||
is this::
|
||||
|
||||
```
|
||||
kernel = [[v_11, v_11, v_11, v_11],
|
||||
[v_11, v_33, v_33, v_11],
|
||||
[v_11, v_33, v_33, v_11],
|
||||
[v_11, v_11, v_11, v_11]]
|
||||
```
|
||||
|
||||
This constraint can be applied to any `Conv2D` layer version, including
|
||||
`Conv2DTranspose` and `SeparableConv2D`, and with either `"channels_last"` or
|
||||
`"channels_first"` data format. The method assumes the weight tensor is of
|
||||
shape `(rows, cols, input_depth, output_depth)`.
|
||||
"""
|
||||
|
||||
def __call__(self, w):
|
||||
w_shape = w.shape
|
||||
if w_shape.rank is None or w_shape.rank != 4:
|
||||
raise ValueError(
|
||||
'The weight tensor must be of rank 4, but is of shape: %s' % w_shape)
|
||||
|
||||
height, width, channels, kernels = w_shape
|
||||
w = K.reshape(w, (height, width, channels * kernels))
|
||||
# TODO(cpeter): Switch map_fn for a faster tf.vectorized_map once K.switch
|
||||
# is supported.
|
||||
w = K.map_fn(
|
||||
self._kernel_constraint,
|
||||
K.stack(array_ops.unstack(w, axis=-1), axis=0))
|
||||
return K.reshape(K.stack(array_ops.unstack(w, axis=0), axis=-1),
|
||||
(height, width, channels, kernels))
|
||||
|
||||
def _kernel_constraint(self, kernel):
|
||||
"""Radially constraints a kernel with shape (height, width, channels)."""
|
||||
padding = K.constant([[1, 1], [1, 1]], dtype='int32')
|
||||
|
||||
kernel_shape = K.shape(kernel)[0]
|
||||
start = K.cast(kernel_shape / 2, 'int32')
|
||||
|
||||
kernel_new = K.switch(
|
||||
K.cast(math_ops.floormod(kernel_shape, 2), 'bool'),
|
||||
lambda: kernel[start - 1:start, start - 1:start],
|
||||
lambda: kernel[start - 1:start, start - 1:start] + K.zeros( # pylint: disable=g-long-lambda
|
||||
(2, 2), dtype=kernel.dtype))
|
||||
index = K.switch(
|
||||
K.cast(math_ops.floormod(kernel_shape, 2), 'bool'),
|
||||
lambda: K.constant(0, dtype='int32'),
|
||||
lambda: K.constant(1, dtype='int32'))
|
||||
while_condition = lambda index, *args: K.less(index, start)
|
||||
|
||||
def body_fn(i, array):
|
||||
return i + 1, array_ops.pad(
|
||||
array,
|
||||
padding,
|
||||
constant_values=kernel[start + i, start + i])
|
||||
|
||||
_, kernel_new = control_flow_ops.while_loop(
|
||||
while_condition,
|
||||
body_fn,
|
||||
[index, kernel_new],
|
||||
shape_invariants=[index.get_shape(),
|
||||
tensor_shape.TensorShape([None, None])])
|
||||
return kernel_new
|
||||
|
||||
|
||||
# Aliases.
|
||||
|
||||
max_norm = MaxNorm
|
||||
non_neg = NonNeg
|
||||
unit_norm = UnitNorm
|
||||
min_max_norm = MinMaxNorm
|
||||
radial_constraint = RadialConstraint
|
||||
|
||||
# Legacy aliases.
|
||||
maxnorm = max_norm
|
||||
|
@ -18,6 +18,8 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import math
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.python import keras
|
||||
@ -36,6 +38,12 @@ def get_example_array():
|
||||
return example_array
|
||||
|
||||
|
||||
def get_example_kernel(width):
|
||||
np.random.seed(3537)
|
||||
example_array = np.random.rand(width, width, 2, 2)
|
||||
return example_array
|
||||
|
||||
|
||||
@test_util.run_all_in_graph_and_eager_modes
|
||||
class KerasConstraintsTest(test.TestCase):
|
||||
|
||||
@ -93,6 +101,16 @@ class KerasConstraintsTest(test.TestCase):
|
||||
assert not l2[l2 < m]
|
||||
assert not l2[l2 > m * 2 + 1e-5]
|
||||
|
||||
def test_conv2d_radial_constraint(self):
|
||||
for width in (3, 4, 5, 6):
|
||||
array = get_example_kernel(width)
|
||||
norm_instance = keras.constraints.radial_constraint()
|
||||
normed = norm_instance(keras.backend.variable(array))
|
||||
value = keras.backend.eval(normed)
|
||||
assert np.all(value.shape == array.shape)
|
||||
assert np.all(value[0:, 0, 0, 0] == value[-1:, 0, 0, 0])
|
||||
assert len(set(value[..., 0, 0].flatten())) == math.ceil(float(width) / 2)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test.main()
|
||||
|
@ -0,0 +1,13 @@
|
||||
path: "tensorflow.keras.constraints.RadialConstraint"
|
||||
tf_class {
|
||||
is_instance: "<class \'tensorflow.python.keras.constraints.RadialConstraint\'>"
|
||||
is_instance: "<class \'tensorflow.python.keras.constraints.Constraint\'>"
|
||||
is_instance: "<type \'object\'>"
|
||||
member_method {
|
||||
name: "__init__"
|
||||
}
|
||||
member_method {
|
||||
name: "get_config"
|
||||
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
}
|
@ -16,6 +16,10 @@ tf_module {
|
||||
name: "NonNeg"
|
||||
mtype: "<type \'type\'>"
|
||||
}
|
||||
member {
|
||||
name: "RadialConstraint"
|
||||
mtype: "<type \'type\'>"
|
||||
}
|
||||
member {
|
||||
name: "UnitNorm"
|
||||
mtype: "<type \'type\'>"
|
||||
@ -32,6 +36,10 @@ tf_module {
|
||||
name: "non_neg"
|
||||
mtype: "<type \'type\'>"
|
||||
}
|
||||
member {
|
||||
name: "radial_constraint"
|
||||
mtype: "<type \'type\'>"
|
||||
}
|
||||
member {
|
||||
name: "unit_norm"
|
||||
mtype: "<type \'type\'>"
|
||||
|
@ -0,0 +1,13 @@
|
||||
path: "tensorflow.keras.constraints.radial_constraint"
|
||||
tf_class {
|
||||
is_instance: "<class \'tensorflow.python.keras.constraints.RadialConstraint\'>"
|
||||
is_instance: "<class \'tensorflow.python.keras.constraints.Constraint\'>"
|
||||
is_instance: "<type \'object\'>"
|
||||
member_method {
|
||||
name: "__init__"
|
||||
}
|
||||
member_method {
|
||||
name: "get_config"
|
||||
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
}
|
@ -0,0 +1,13 @@
|
||||
path: "tensorflow.keras.constraints.RadialConstraint"
|
||||
tf_class {
|
||||
is_instance: "<class \'tensorflow.python.keras.constraints.RadialConstraint\'>"
|
||||
is_instance: "<class \'tensorflow.python.keras.constraints.Constraint\'>"
|
||||
is_instance: "<type \'object\'>"
|
||||
member_method {
|
||||
name: "__init__"
|
||||
}
|
||||
member_method {
|
||||
name: "get_config"
|
||||
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
}
|
@ -16,6 +16,10 @@ tf_module {
|
||||
name: "NonNeg"
|
||||
mtype: "<type \'type\'>"
|
||||
}
|
||||
member {
|
||||
name: "RadialConstraint"
|
||||
mtype: "<type \'type\'>"
|
||||
}
|
||||
member {
|
||||
name: "UnitNorm"
|
||||
mtype: "<type \'type\'>"
|
||||
@ -32,6 +36,10 @@ tf_module {
|
||||
name: "non_neg"
|
||||
mtype: "<type \'type\'>"
|
||||
}
|
||||
member {
|
||||
name: "radial_constraint"
|
||||
mtype: "<type \'type\'>"
|
||||
}
|
||||
member {
|
||||
name: "unit_norm"
|
||||
mtype: "<type \'type\'>"
|
||||
|
@ -0,0 +1,13 @@
|
||||
path: "tensorflow.keras.constraints.radial_constraint"
|
||||
tf_class {
|
||||
is_instance: "<class \'tensorflow.python.keras.constraints.RadialConstraint\'>"
|
||||
is_instance: "<class \'tensorflow.python.keras.constraints.Constraint\'>"
|
||||
is_instance: "<type \'object\'>"
|
||||
member_method {
|
||||
name: "__init__"
|
||||
}
|
||||
member_method {
|
||||
name: "get_config"
|
||||
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue
Block a user