From 026bc91eae3078b0fa0463615bec0fa171ceb8a3 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" <gardener@tensorflow.org>
Date: Wed, 19 Jun 2019 21:27:36 -0700
Subject: [PATCH] Adding radial constraint for Conv2D kernels to make them
 rotation invariant.

PiperOrigin-RevId: 254129632
---
 tensorflow/python/keras/constraints.py        | 82 +++++++++++++++++++
 tensorflow/python/keras/constraints_test.py   | 18 ++++
 ...keras.constraints.-radial-constraint.pbtxt | 13 +++
 .../v1/tensorflow.keras.constraints.pbtxt     |  8 ++
 ....keras.constraints.radial_constraint.pbtxt | 13 +++
 ...keras.constraints.-radial-constraint.pbtxt | 13 +++
 .../v2/tensorflow.keras.constraints.pbtxt     |  8 ++
 ....keras.constraints.radial_constraint.pbtxt | 13 +++
 8 files changed, 168 insertions(+)
 create mode 100644 tensorflow/tools/api/golden/v1/tensorflow.keras.constraints.-radial-constraint.pbtxt
 create mode 100644 tensorflow/tools/api/golden/v1/tensorflow.keras.constraints.radial_constraint.pbtxt
 create mode 100644 tensorflow/tools/api/golden/v2/tensorflow.keras.constraints.-radial-constraint.pbtxt
 create mode 100644 tensorflow/tools/api/golden/v2/tensorflow.keras.constraints.radial_constraint.pbtxt

diff --git a/tensorflow/python/keras/constraints.py b/tensorflow/python/keras/constraints.py
index 334d072d5a2..043ceb8dd6d 100644
--- a/tensorflow/python/keras/constraints.py
+++ b/tensorflow/python/keras/constraints.py
@@ -21,9 +21,12 @@ from __future__ import print_function
 
 import six
 
+from tensorflow.python.framework import tensor_shape
 from tensorflow.python.keras import backend as K
 from tensorflow.python.keras.utils.generic_utils import deserialize_keras_object
 from tensorflow.python.keras.utils.generic_utils import serialize_keras_object
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.util.tf_export import keras_export
 
@@ -168,12 +171,91 @@ class MinMaxNorm(Constraint):
     }
 
 
+@keras_export('keras.constraints.RadialConstraint',
+              'keras.constraints.radial_constraint')
+class RadialConstraint(Constraint):
+  """Constrains `Conv2D` kernel weights to be the same for each radius.
+
+  For example, the desired output for the following 4-by-4 kernel::
+
+  ```
+      kernel = [[v_00, v_01, v_02, v_03],
+                [v_10, v_11, v_12, v_13],
+                [v_20, v_21, v_22, v_23],
+                [v_30, v_31, v_32, v_33]]
+  ```
+
+  is this::
+
+  ```
+      kernel = [[v_11, v_11, v_11, v_11],
+                [v_11, v_33, v_33, v_11],
+                [v_11, v_33, v_33, v_11],
+                [v_11, v_11, v_11, v_11]]
+  ```
+
+  This constraint can be applied to any `Conv2D` layer version, including
+  `Conv2DTranspose` and `SeparableConv2D`, and with either `"channels_last"` or
+  `"channels_first"` data format. The method assumes the weight tensor is of
+  shape `(rows, cols, input_depth, output_depth)`.
+  """
+
+  def __call__(self, w):
+    w_shape = w.shape
+    if w_shape.rank is None or w_shape.rank != 4:
+      raise ValueError(
+          'The weight tensor must be of rank 4, but is of shape: %s' % w_shape)
+
+    height, width, channels, kernels = w_shape
+    w = K.reshape(w, (height, width, channels * kernels))
+    # TODO(cpeter): Switch map_fn for a faster tf.vectorized_map once K.switch
+    # is supported.
+    w = K.map_fn(
+        self._kernel_constraint,
+        K.stack(array_ops.unstack(w, axis=-1), axis=0))
+    return K.reshape(K.stack(array_ops.unstack(w, axis=0), axis=-1),
+                     (height, width, channels, kernels))
+
+  def _kernel_constraint(self, kernel):
+    """Radially constraints a kernel with shape (height, width, channels)."""
+    padding = K.constant([[1, 1], [1, 1]], dtype='int32')
+
+    kernel_shape = K.shape(kernel)[0]
+    start = K.cast(kernel_shape / 2, 'int32')
+
+    kernel_new = K.switch(
+        K.cast(math_ops.floormod(kernel_shape, 2), 'bool'),
+        lambda: kernel[start - 1:start, start - 1:start],
+        lambda: kernel[start - 1:start, start - 1:start] + K.zeros(  # pylint: disable=g-long-lambda
+            (2, 2), dtype=kernel.dtype))
+    index = K.switch(
+        K.cast(math_ops.floormod(kernel_shape, 2), 'bool'),
+        lambda: K.constant(0, dtype='int32'),
+        lambda: K.constant(1, dtype='int32'))
+    while_condition = lambda index, *args: K.less(index, start)
+
+    def body_fn(i, array):
+      return i + 1, array_ops.pad(
+          array,
+          padding,
+          constant_values=kernel[start + i, start + i])
+
+    _, kernel_new = control_flow_ops.while_loop(
+        while_condition,
+        body_fn,
+        [index, kernel_new],
+        shape_invariants=[index.get_shape(),
+                          tensor_shape.TensorShape([None, None])])
+    return kernel_new
+
+
 # Aliases.
 
 max_norm = MaxNorm
 non_neg = NonNeg
 unit_norm = UnitNorm
 min_max_norm = MinMaxNorm
+radial_constraint = RadialConstraint
 
 # Legacy aliases.
 maxnorm = max_norm
diff --git a/tensorflow/python/keras/constraints_test.py b/tensorflow/python/keras/constraints_test.py
index 92bc4852cff..741be34530a 100644
--- a/tensorflow/python/keras/constraints_test.py
+++ b/tensorflow/python/keras/constraints_test.py
@@ -18,6 +18,8 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import math
+
 import numpy as np
 
 from tensorflow.python import keras
@@ -36,6 +38,12 @@ def get_example_array():
   return example_array
 
 
+def get_example_kernel(width):
+  np.random.seed(3537)
+  example_array = np.random.rand(width, width, 2, 2)
+  return example_array
+
+
 @test_util.run_all_in_graph_and_eager_modes
 class KerasConstraintsTest(test.TestCase):
 
@@ -93,6 +101,16 @@ class KerasConstraintsTest(test.TestCase):
       assert not l2[l2 < m]
       assert not l2[l2 > m * 2 + 1e-5]
 
+  def test_conv2d_radial_constraint(self):
+    for width in (3, 4, 5, 6):
+      array = get_example_kernel(width)
+      norm_instance = keras.constraints.radial_constraint()
+      normed = norm_instance(keras.backend.variable(array))
+      value = keras.backend.eval(normed)
+      assert np.all(value.shape == array.shape)
+      assert np.all(value[0:, 0, 0, 0] == value[-1:, 0, 0, 0])
+      assert len(set(value[..., 0, 0].flatten())) == math.ceil(float(width) / 2)
+
 
 if __name__ == '__main__':
   test.main()
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.constraints.-radial-constraint.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.constraints.-radial-constraint.pbtxt
new file mode 100644
index 00000000000..826cfb92932
--- /dev/null
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.constraints.-radial-constraint.pbtxt
@@ -0,0 +1,13 @@
+path: "tensorflow.keras.constraints.RadialConstraint"
+tf_class {
+  is_instance: "<class \'tensorflow.python.keras.constraints.RadialConstraint\'>"
+  is_instance: "<class \'tensorflow.python.keras.constraints.Constraint\'>"
+  is_instance: "<type \'object\'>"
+  member_method {
+    name: "__init__"
+  }
+  member_method {
+    name: "get_config"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.constraints.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.constraints.pbtxt
index 655685956f0..29444ef3405 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.constraints.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.constraints.pbtxt
@@ -16,6 +16,10 @@ tf_module {
     name: "NonNeg"
     mtype: "<type \'type\'>"
   }
+  member {
+    name: "RadialConstraint"
+    mtype: "<type \'type\'>"
+  }
   member {
     name: "UnitNorm"
     mtype: "<type \'type\'>"
@@ -32,6 +36,10 @@ tf_module {
     name: "non_neg"
     mtype: "<type \'type\'>"
   }
+  member {
+    name: "radial_constraint"
+    mtype: "<type \'type\'>"
+  }
   member {
     name: "unit_norm"
     mtype: "<type \'type\'>"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.constraints.radial_constraint.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.constraints.radial_constraint.pbtxt
new file mode 100644
index 00000000000..3040111b324
--- /dev/null
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.constraints.radial_constraint.pbtxt
@@ -0,0 +1,13 @@
+path: "tensorflow.keras.constraints.radial_constraint"
+tf_class {
+  is_instance: "<class \'tensorflow.python.keras.constraints.RadialConstraint\'>"
+  is_instance: "<class \'tensorflow.python.keras.constraints.Constraint\'>"
+  is_instance: "<type \'object\'>"
+  member_method {
+    name: "__init__"
+  }
+  member_method {
+    name: "get_config"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.constraints.-radial-constraint.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.constraints.-radial-constraint.pbtxt
new file mode 100644
index 00000000000..826cfb92932
--- /dev/null
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.constraints.-radial-constraint.pbtxt
@@ -0,0 +1,13 @@
+path: "tensorflow.keras.constraints.RadialConstraint"
+tf_class {
+  is_instance: "<class \'tensorflow.python.keras.constraints.RadialConstraint\'>"
+  is_instance: "<class \'tensorflow.python.keras.constraints.Constraint\'>"
+  is_instance: "<type \'object\'>"
+  member_method {
+    name: "__init__"
+  }
+  member_method {
+    name: "get_config"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.constraints.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.constraints.pbtxt
index 655685956f0..29444ef3405 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.constraints.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.constraints.pbtxt
@@ -16,6 +16,10 @@ tf_module {
     name: "NonNeg"
     mtype: "<type \'type\'>"
   }
+  member {
+    name: "RadialConstraint"
+    mtype: "<type \'type\'>"
+  }
   member {
     name: "UnitNorm"
     mtype: "<type \'type\'>"
@@ -32,6 +36,10 @@ tf_module {
     name: "non_neg"
     mtype: "<type \'type\'>"
   }
+  member {
+    name: "radial_constraint"
+    mtype: "<type \'type\'>"
+  }
   member {
     name: "unit_norm"
     mtype: "<type \'type\'>"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.constraints.radial_constraint.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.constraints.radial_constraint.pbtxt
new file mode 100644
index 00000000000..3040111b324
--- /dev/null
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.constraints.radial_constraint.pbtxt
@@ -0,0 +1,13 @@
+path: "tensorflow.keras.constraints.radial_constraint"
+tf_class {
+  is_instance: "<class \'tensorflow.python.keras.constraints.RadialConstraint\'>"
+  is_instance: "<class \'tensorflow.python.keras.constraints.Constraint\'>"
+  is_instance: "<type \'object\'>"
+  member_method {
+    name: "__init__"
+  }
+  member_method {
+    name: "get_config"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+}