From 434695921de7cfd713b789533173e1e0c3fc7691 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" <gardener@tensorflow.org>
Date: Mon, 23 Oct 2017 08:00:39 -0700
Subject: [PATCH] K-FAC: _check_registration() supports multiple towers.

PiperOrigin-RevId: 173115870
---
 .../kernel_tests/layer_collection_test.py     | 14 ++++++--
 tensorflow/contrib/kfac/python/ops/BUILD      |  2 ++
 .../contrib/kfac/python/ops/fisher_blocks.py  | 34 +++++++++++++++++++
 .../kfac/python/ops/layer_collection.py       | 10 +++---
 4 files changed, 54 insertions(+), 6 deletions(-)

diff --git a/tensorflow/contrib/kfac/python/kernel_tests/layer_collection_test.py b/tensorflow/contrib/kfac/python/kernel_tests/layer_collection_test.py
index b444e871701..1da811dc0a3 100644
--- a/tensorflow/contrib/kfac/python/kernel_tests/layer_collection_test.py
+++ b/tensorflow/contrib/kfac/python/kernel_tests/layer_collection_test.py
@@ -313,10 +313,20 @@ class LayerCollectionTest(test.TestCase):
       self.assertTrue(all([var.name.startswith(scope) for var in variables]))
 
   def testGetUseCountMap(self):
+    """Ensure get_use_count_map() sums 'num_registered_minibatches'."""
+
+    class MockFisherBlock(object):
+
+      num_registered_minibatches = 2
+
     lc = layer_collection.LayerCollection()
-    lc.fisher_blocks = {'a': 1, ('a', 'c'): 2, ('b', 'c'): 2}
+    lc.fisher_blocks = {
+        'a': MockFisherBlock(),
+        ('a', 'c'): MockFisherBlock(),
+        ('b', 'c'): MockFisherBlock()
+    }
     use_count_map = lc.get_use_count_map()
-    self.assertDictEqual({'a': 2, 'b': 1, 'c': 2}, use_count_map)
+    self.assertDictEqual({'a': 4, 'b': 2, 'c': 4}, use_count_map)
 
 
 if __name__ == '__main__':
diff --git a/tensorflow/contrib/kfac/python/ops/BUILD b/tensorflow/contrib/kfac/python/ops/BUILD
index 8b82f6e3147..5d5046c9ec6 100644
--- a/tensorflow/contrib/kfac/python/ops/BUILD
+++ b/tensorflow/contrib/kfac/python/ops/BUILD
@@ -113,7 +113,9 @@ py_library(
         "//tensorflow/python:framework_ops",
         "//tensorflow/python:math_ops",
         "//tensorflow/python:platform",
+        "//tensorflow/python:util",
         "//tensorflow/python:variable_scope",
+        "@six_archive//:six",
     ],
 )
 
diff --git a/tensorflow/contrib/kfac/python/ops/fisher_blocks.py b/tensorflow/contrib/kfac/python/ops/fisher_blocks.py
index 754c2cc853b..7ef755c35ed 100644
--- a/tensorflow/contrib/kfac/python/ops/fisher_blocks.py
+++ b/tensorflow/contrib/kfac/python/ops/fisher_blocks.py
@@ -114,6 +114,14 @@ class FisherBlock(object):
     """
     pass
 
+  @abc.abstractproperty
+  def num_registered_minibatches(self):
+    """Number of minibatches registered for this FisherBlock.
+
+    Typically equal to the number of towers in a multi-tower setup.
+    """
+    pass
+
 
 class FullFB(FisherBlock):
   """FisherBlock using a full matrix estimate (no approximations).
@@ -164,6 +172,10 @@ class FullFB(FisherBlock):
   def tensors_to_compute_grads(self):
     return self._params
 
+  @property
+  def num_registered_minibatches(self):
+    return 1  # Multiple minibatches not supported.
+
 
 class NaiveDiagonalFB(FisherBlock):
   """FisherBlock using a diagonal matrix approximation.
@@ -209,6 +221,10 @@ class NaiveDiagonalFB(FisherBlock):
   def tensors_to_compute_grads(self):
     return self._params
 
+  @property
+  def num_registered_minibatches(self):
+    return 1  # Multiple minibatches not supported.
+
 
 class FullyConnectedDiagonalFB(FisherBlock):
   """FisherBlock for fully-connected (dense) layers using a diagonal approx.
@@ -305,6 +321,12 @@ class FullyConnectedDiagonalFB(FisherBlock):
     self._inputs.append(inputs)
     self._outputs.append(outputs)
 
+  @property
+  def num_registered_minibatches(self):
+    result = len(self._inputs)
+    assert result == len(self._outputs)
+    return result
+
 
 class ConvDiagonalFB(FisherBlock):
   """FisherBlock for convolutional layers using a diagonal approx.
@@ -400,6 +422,10 @@ class ConvDiagonalFB(FisherBlock):
     self._inputs.append(inputs)
     self._outputs.append(outputs)
 
+  @property
+  def num_registered_minibatches(self):
+    return len(self._inputs)
+
 
 class KroneckerProductFB(FisherBlock):
   """A base class for FisherBlocks with separate input and output factors.
@@ -532,6 +558,10 @@ class FullyConnectedKFACBasicFB(KroneckerProductFB):
     self._inputs.append(inputs)
     self._outputs.append(outputs)
 
+  @property
+  def num_registered_minibatches(self):
+    return 1  # Multiple minibatches not supported.
+
 
 class ConvKFCBasicFB(KroneckerProductFB):
   """FisherBlock for 2D convolutional layers using the basic KFC approx.
@@ -591,6 +621,10 @@ class ConvKFCBasicFB(KroneckerProductFB):
   def tensors_to_compute_grads(self):
     return self._outputs
 
+  @property
+  def num_registered_minibatches(self):
+    return 1  # Multiple minibatches not supported.
+
 
 def _concat_along_batch_dim(tensor_list):
   """Concatenate tensors along batch (first) dimension.
diff --git a/tensorflow/contrib/kfac/python/ops/layer_collection.py b/tensorflow/contrib/kfac/python/ops/layer_collection.py
index ceb1131f286..49279954dc8 100644
--- a/tensorflow/contrib/kfac/python/ops/layer_collection.py
+++ b/tensorflow/contrib/kfac/python/ops/layer_collection.py
@@ -27,6 +27,8 @@ from __future__ import print_function
 from collections import defaultdict
 from collections import OrderedDict
 
+import six
+
 from tensorflow.contrib.kfac.python.ops import fisher_blocks as fb
 from tensorflow.contrib.kfac.python.ops import loss_functions as lf
 from tensorflow.contrib.kfac.python.ops import utils
@@ -82,8 +84,8 @@ class LayerParametersDict(OrderedDict):
     return key
 
 
-# TODO(duckworthd): add capability for LayerCollection to be "finalized"
-# and do this when it gets used by FisherEstimator / KfacOptimizer
+# TODO(b/68034464): add capability for LayerCollection to be "finalized"
+# and do this when it gets used by FisherEstimator / KfacOptimizer.
 
 
 class LayerCollection(object):
@@ -211,10 +213,10 @@ class LayerCollection(object):
   def get_use_count_map(self):
     """Returns a dict of variables to their number of registrations."""
     vars_to_uses = defaultdict(int)
-    for key in self.fisher_blocks.keys():
+    for key, block in six.iteritems(self.fisher_blocks):
       key = key if isinstance(key, (tuple, list)) else (key,)
       for k in key:
-        vars_to_uses[k] += 1
+        vars_to_uses[k] += block.num_registered_minibatches
     return vars_to_uses
 
   def get_blocks(self):