From 4f7503a876e20e6d58c9aec3f44214b98bcfdbbb Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" <gardener@tensorflow.org>
Date: Mon, 23 Oct 2017 08:55:00 -0700
Subject: [PATCH] K-FAC: Support for registering multiple minibatches with
 register_fully_connected()

PiperOrigin-RevId: 173121735
---
 .../kernel_tests/layer_collection_test.py     | 67 +++++++++++++++++++
 .../kfac/python/ops/layer_collection.py       | 66 +++++++++++++++---
 .../kfac/python/ops/layer_collection_lib.py   |  1 +
 3 files changed, 123 insertions(+), 11 deletions(-)

diff --git a/tensorflow/contrib/kfac/python/kernel_tests/layer_collection_test.py b/tensorflow/contrib/kfac/python/kernel_tests/layer_collection_test.py
index 1da811dc0a3..432937d8032 100644
--- a/tensorflow/contrib/kfac/python/kernel_tests/layer_collection_test.py
+++ b/tensorflow/contrib/kfac/python/kernel_tests/layer_collection_test.py
@@ -282,6 +282,73 @@ class LayerCollectionTest(test.TestCase):
       single_loss = sess.run(lc.total_loss())
       self.assertAlmostEqual(7.6983433, single_loss)
 
+  def testRegisterFullyConnectedReuse(self):
+    """Ensure the 'reuse' keyword argument function as intended."""
+    with ops.Graph().as_default():
+      inputs = [
+          array_ops.ones([2, 10]),  #
+          array_ops.zeros([5, 10])
+      ]
+      outputs = [
+          array_ops.zeros([2, 5]),  #
+          array_ops.ones([5, 5])
+      ]
+      params = (
+          variable_scope.get_variable('w', [10, 5]),  #
+          variable_scope.get_variable('b', [5]))
+
+      # Fails on second if reuse=False.
+      lc = layer_collection.LayerCollection()
+      lc.register_fully_connected(params, inputs[0], outputs[0])
+      with self.assertRaises(ValueError):
+        lc.register_fully_connected(params, inputs[1], outputs[1], reuse=False)
+
+      # Succeeds on second if reuse=True.
+      lc = layer_collection.LayerCollection()
+      lc.register_fully_connected(params, inputs[0], outputs[0])
+      lc.register_fully_connected(params, inputs[1], outputs[1], reuse=True)
+
+      # Fails on second if reuse=VARIABLE_SCOPE and no variable reuse.
+      lc = layer_collection.LayerCollection()
+      lc.register_fully_connected(params, inputs[0], outputs[0])
+      with self.assertRaises(ValueError):
+        lc.register_fully_connected(
+            params,
+            inputs[1],
+            outputs[1],
+            reuse=layer_collection.VARIABLE_SCOPE)
+
+      # Succeeds on second if reuse=VARIABLE_SCOPE and variable reuse.
+      lc = layer_collection.LayerCollection()
+      lc.register_fully_connected(params, inputs[0], outputs[0])
+      with variable_scope.variable_scope(
+          variable_scope.get_variable_scope(), reuse=True):
+        lc.register_fully_connected(
+            params,
+            inputs[1],
+            outputs[1],
+            reuse=layer_collection.VARIABLE_SCOPE)
+
+      # Fails if block type changes.
+      lc = layer_collection.LayerCollection()
+      lc.register_fully_connected(
+          params,
+          inputs[0],
+          outputs[0],
+          approx=layer_collection.APPROX_KRONECKER_NAME)
+      with self.assertRaises(ValueError):
+        lc.register_fully_connected(
+            params,
+            inputs[1],
+            outputs[1],
+            approx=layer_collection.APPROX_DIAGONAL_NAME,
+            reuse=True)
+
+      # Fails if reuse requested but no FisherBlock exists.
+      lc = layer_collection.LayerCollection()
+      with self.assertRaises(KeyError):
+        lc.register_fully_connected(params, inputs[0], outputs[0], reuse=True)
+
   def testMakeOrGetFactor(self):
     with ops.Graph().as_default():
       random_seed.set_random_seed(200)
diff --git a/tensorflow/contrib/kfac/python/ops/layer_collection.py b/tensorflow/contrib/kfac/python/ops/layer_collection.py
index 49279954dc8..cd711d05610 100644
--- a/tensorflow/contrib/kfac/python/ops/layer_collection.py
+++ b/tensorflow/contrib/kfac/python/ops/layer_collection.py
@@ -39,10 +39,15 @@ from tensorflow.python.platform import tf_logging as logging
 from tensorflow.python.util import nest
 
 
+# Names for various approximations that can be requested for Fisher blocks.
 APPROX_KRONECKER_NAME = "kron"
 APPROX_DIAGONAL_NAME = "diagonal"
 APPROX_FULL_NAME = "full"
 
+# Possible value for 'reuse' keyword argument. Sets 'reuse' to
+# tf.get_variable_scope().reuse.
+VARIABLE_SCOPE = "VARIABLE_SCOPE"
+
 # TODO(jamesmartens): need to add find_canonical_output back into this somewhere
 
 
@@ -254,19 +259,58 @@ class LayerCollection(object):
                                params,
                                inputs,
                                outputs,
-                               approx=APPROX_KRONECKER_NAME):
-    has_bias = isinstance(params, (tuple, list))
-    if approx == APPROX_KRONECKER_NAME:
-      block = fb.FullyConnectedKFACBasicFB(self, has_bias)
-      block.register_additional_minibatch(inputs, outputs)
-      self.register_block(params, block)
-    elif approx == APPROX_DIAGONAL_NAME:
-      block = fb.FullyConnectedDiagonalFB(self, has_bias)
-      block.register_additional_minibatch(inputs, outputs)
-      self.register_block(params, block)
-    else:
+                               approx=APPROX_KRONECKER_NAME,
+                               reuse=VARIABLE_SCOPE):
+    """Registers a fully connnected layer.
+
+    Args:
+      params: Tensor or 2-tuple of Tensors corresponding to weight and bias of
+        this layer. Weight matrix should have shape [input_size, output_size].
+        Bias should have shape [output_size].
+      inputs: Tensor of shape [batch_size, input_size]. Inputs to layer.
+      outputs: Tensor of shape [batch_size, output_size]. Preactivations
+        produced by layer.
+      approx: str. One of APPROX_KRONECKER_NAME or APPROX_DIAGONAL_NAME.
+      reuse: bool or str.  If True, reuse an existing FisherBlock. If False,
+        create a new FisherBlock.  If VARIABLE_SCOPE, use
+        tf.get_variable_scope().reuse.
+
+    Raises:
+      ValueError: For improper value to 'approx'.
+      KeyError: If reuse == True but no FisherBlock found for 'params'.
+      ValueError: If reuse == True and FisherBlock found but of the wrong type.
+    """
+    approx_to_block_types = {
+        APPROX_KRONECKER_NAME: fb.FullyConnectedKFACBasicFB,
+        APPROX_DIAGONAL_NAME: fb.FullyConnectedDiagonalFB,
+    }
+
+    if approx not in approx_to_block_types:
       raise ValueError("Bad value {} for approx.".format(approx))
 
+    block_type = approx_to_block_types[approx]
+    has_bias = isinstance(params, (tuple, list))
+
+    if reuse == VARIABLE_SCOPE:
+      reuse = variable_scope.get_variable_scope().reuse
+
+    if reuse:
+      block = self.fisher_blocks.get(params, None)
+      if block is None:
+        raise KeyError(
+            "Reuse requested but no FisherBlock found for params {}.".format(
+                params))
+      if not isinstance(block, block_type):
+        raise ValueError(
+            "Requested block of type {} but block of type {} already exists "
+            "for params {}.".format(block_type, type(block), params))
+
+    else:
+      block = block_type(self, has_bias)
+      self.register_block(params, block)
+
+    block.register_additional_minibatch(inputs, outputs)
+
   def register_conv2d(self, params, strides, padding, inputs, outputs,
                       approx=APPROX_KRONECKER_NAME):
 
diff --git a/tensorflow/contrib/kfac/python/ops/layer_collection_lib.py b/tensorflow/contrib/kfac/python/ops/layer_collection_lib.py
index 63a9b173bc8..d6bf61a2102 100644
--- a/tensorflow/contrib/kfac/python/ops/layer_collection_lib.py
+++ b/tensorflow/contrib/kfac/python/ops/layer_collection_lib.py
@@ -35,6 +35,7 @@ _allowed_symbols = [
     "APPROX_KRONECKER_NAME",
     "APPROX_DIAGONAL_NAME",
     "APPROX_FULL_NAME",
+    "VARIABLE_SCOPE",
 ]
 
 remove_undocumented(__name__, allowed_exception_list=_allowed_symbols)