diff --git a/tensorflow/python/keras/constraints.py b/tensorflow/python/keras/constraints.py index 334d072d5a2..043ceb8dd6d 100644 --- a/tensorflow/python/keras/constraints.py +++ b/tensorflow/python/keras/constraints.py @@ -21,9 +21,12 @@ from __future__ import print_function import six +from tensorflow.python.framework import tensor_shape from tensorflow.python.keras import backend as K from tensorflow.python.keras.utils.generic_utils import deserialize_keras_object from tensorflow.python.keras.utils.generic_utils import serialize_keras_object +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.util.tf_export import keras_export @@ -168,12 +171,91 @@ class MinMaxNorm(Constraint): } +@keras_export('keras.constraints.RadialConstraint', + 'keras.constraints.radial_constraint') +class RadialConstraint(Constraint): + """Constrains `Conv2D` kernel weights to be the same for each radius. + + For example, the desired output for the following 4-by-4 kernel:: + + ``` + kernel = [[v_00, v_01, v_02, v_03], + [v_10, v_11, v_12, v_13], + [v_20, v_21, v_22, v_23], + [v_30, v_31, v_32, v_33]] + ``` + + is this:: + + ``` + kernel = [[v_11, v_11, v_11, v_11], + [v_11, v_33, v_33, v_11], + [v_11, v_33, v_33, v_11], + [v_11, v_11, v_11, v_11]] + ``` + + This constraint can be applied to any `Conv2D` layer version, including + `Conv2DTranspose` and `SeparableConv2D`, and with either `"channels_last"` or + `"channels_first"` data format. The method assumes the weight tensor is of + shape `(rows, cols, input_depth, output_depth)`. + """ + + def __call__(self, w): + w_shape = w.shape + if w_shape.rank is None or w_shape.rank != 4: + raise ValueError( + 'The weight tensor must be of rank 4, but is of shape: %s' % w_shape) + + height, width, channels, kernels = w_shape + w = K.reshape(w, (height, width, channels * kernels)) + # TODO(cpeter): Switch map_fn for a faster tf.vectorized_map once K.switch + # is supported. + w = K.map_fn( + self._kernel_constraint, + K.stack(array_ops.unstack(w, axis=-1), axis=0)) + return K.reshape(K.stack(array_ops.unstack(w, axis=0), axis=-1), + (height, width, channels, kernels)) + + def _kernel_constraint(self, kernel): + """Radially constraints a kernel with shape (height, width, channels).""" + padding = K.constant([[1, 1], [1, 1]], dtype='int32') + + kernel_shape = K.shape(kernel)[0] + start = K.cast(kernel_shape / 2, 'int32') + + kernel_new = K.switch( + K.cast(math_ops.floormod(kernel_shape, 2), 'bool'), + lambda: kernel[start - 1:start, start - 1:start], + lambda: kernel[start - 1:start, start - 1:start] + K.zeros( # pylint: disable=g-long-lambda + (2, 2), dtype=kernel.dtype)) + index = K.switch( + K.cast(math_ops.floormod(kernel_shape, 2), 'bool'), + lambda: K.constant(0, dtype='int32'), + lambda: K.constant(1, dtype='int32')) + while_condition = lambda index, *args: K.less(index, start) + + def body_fn(i, array): + return i + 1, array_ops.pad( + array, + padding, + constant_values=kernel[start + i, start + i]) + + _, kernel_new = control_flow_ops.while_loop( + while_condition, + body_fn, + [index, kernel_new], + shape_invariants=[index.get_shape(), + tensor_shape.TensorShape([None, None])]) + return kernel_new + + # Aliases. max_norm = MaxNorm non_neg = NonNeg unit_norm = UnitNorm min_max_norm = MinMaxNorm +radial_constraint = RadialConstraint # Legacy aliases. maxnorm = max_norm diff --git a/tensorflow/python/keras/constraints_test.py b/tensorflow/python/keras/constraints_test.py index 92bc4852cff..741be34530a 100644 --- a/tensorflow/python/keras/constraints_test.py +++ b/tensorflow/python/keras/constraints_test.py @@ -18,6 +18,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import math + import numpy as np from tensorflow.python import keras @@ -36,6 +38,12 @@ def get_example_array(): return example_array +def get_example_kernel(width): + np.random.seed(3537) + example_array = np.random.rand(width, width, 2, 2) + return example_array + + @test_util.run_all_in_graph_and_eager_modes class KerasConstraintsTest(test.TestCase): @@ -93,6 +101,16 @@ class KerasConstraintsTest(test.TestCase): assert not l2[l2 < m] assert not l2[l2 > m * 2 + 1e-5] + def test_conv2d_radial_constraint(self): + for width in (3, 4, 5, 6): + array = get_example_kernel(width) + norm_instance = keras.constraints.radial_constraint() + normed = norm_instance(keras.backend.variable(array)) + value = keras.backend.eval(normed) + assert np.all(value.shape == array.shape) + assert np.all(value[0:, 0, 0, 0] == value[-1:, 0, 0, 0]) + assert len(set(value[..., 0, 0].flatten())) == math.ceil(float(width) / 2) + if __name__ == '__main__': test.main() diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.constraints.-radial-constraint.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.constraints.-radial-constraint.pbtxt new file mode 100644 index 00000000000..826cfb92932 --- /dev/null +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.constraints.-radial-constraint.pbtxt @@ -0,0 +1,13 @@ +path: "tensorflow.keras.constraints.RadialConstraint" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.constraints.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.constraints.pbtxt index 655685956f0..29444ef3405 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.constraints.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.constraints.pbtxt @@ -16,6 +16,10 @@ tf_module { name: "NonNeg" mtype: "" } + member { + name: "RadialConstraint" + mtype: "" + } member { name: "UnitNorm" mtype: "" @@ -32,6 +36,10 @@ tf_module { name: "non_neg" mtype: "" } + member { + name: "radial_constraint" + mtype: "" + } member { name: "unit_norm" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.constraints.radial_constraint.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.constraints.radial_constraint.pbtxt new file mode 100644 index 00000000000..3040111b324 --- /dev/null +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.constraints.radial_constraint.pbtxt @@ -0,0 +1,13 @@ +path: "tensorflow.keras.constraints.radial_constraint" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.constraints.-radial-constraint.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.constraints.-radial-constraint.pbtxt new file mode 100644 index 00000000000..826cfb92932 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.constraints.-radial-constraint.pbtxt @@ -0,0 +1,13 @@ +path: "tensorflow.keras.constraints.RadialConstraint" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.constraints.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.constraints.pbtxt index 655685956f0..29444ef3405 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.constraints.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.constraints.pbtxt @@ -16,6 +16,10 @@ tf_module { name: "NonNeg" mtype: "" } + member { + name: "RadialConstraint" + mtype: "" + } member { name: "UnitNorm" mtype: "" @@ -32,6 +36,10 @@ tf_module { name: "non_neg" mtype: "" } + member { + name: "radial_constraint" + mtype: "" + } member { name: "unit_norm" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.constraints.radial_constraint.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.constraints.radial_constraint.pbtxt new file mode 100644 index 00000000000..3040111b324 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.constraints.radial_constraint.pbtxt @@ -0,0 +1,13 @@ +path: "tensorflow.keras.constraints.radial_constraint" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } +}