diff --git a/tensorflow/contrib/kfac/python/kernel_tests/layer_collection_test.py b/tensorflow/contrib/kfac/python/kernel_tests/layer_collection_test.py
index b444e871701..1da811dc0a3 100644
--- a/tensorflow/contrib/kfac/python/kernel_tests/layer_collection_test.py
+++ b/tensorflow/contrib/kfac/python/kernel_tests/layer_collection_test.py
@@ -313,10 +313,20 @@ class LayerCollectionTest(test.TestCase):
     self.assertTrue(all([var.name.startswith(scope) for var in variables]))
 
   def testGetUseCountMap(self):
+    """Ensure get_use_count_map() sums 'num_registered_minibatches'."""
+
+    class MockFisherBlock(object):
+
+      num_registered_minibatches = 2
+
     lc = layer_collection.LayerCollection()
-    lc.fisher_blocks = {'a': 1, ('a', 'c'): 2, ('b', 'c'): 2}
+    lc.fisher_blocks = {
+        'a': MockFisherBlock(),
+        ('a', 'c'): MockFisherBlock(),
+        ('b', 'c'): MockFisherBlock()
+    }
     use_count_map = lc.get_use_count_map()
-    self.assertDictEqual({'a': 2, 'b': 1, 'c': 2}, use_count_map)
+    self.assertDictEqual({'a': 4, 'b': 2, 'c': 4}, use_count_map)
 
 
 if __name__ == '__main__':
diff --git a/tensorflow/contrib/kfac/python/ops/BUILD b/tensorflow/contrib/kfac/python/ops/BUILD
index 8b82f6e3147..5d5046c9ec6 100644
--- a/tensorflow/contrib/kfac/python/ops/BUILD
+++ b/tensorflow/contrib/kfac/python/ops/BUILD
@@ -113,7 +113,9 @@ py_library(
         "//tensorflow/python:framework_ops",
         "//tensorflow/python:math_ops",
         "//tensorflow/python:platform",
+        "//tensorflow/python:util",
         "//tensorflow/python:variable_scope",
+        "@six_archive//:six",
     ],
 )
 
diff --git a/tensorflow/contrib/kfac/python/ops/fisher_blocks.py b/tensorflow/contrib/kfac/python/ops/fisher_blocks.py
index 754c2cc853b..7ef755c35ed 100644
--- a/tensorflow/contrib/kfac/python/ops/fisher_blocks.py
+++ b/tensorflow/contrib/kfac/python/ops/fisher_blocks.py
@@ -114,6 +114,14 @@ class FisherBlock(object):
     """
     pass
 
+  @abc.abstractproperty
+  def num_registered_minibatches(self):
+    """Number of minibatches registered for this FisherBlock.
+
+    Typically equal to the number of towers in a multi-tower setup.
+    """
+    pass
+
 
 class FullFB(FisherBlock):
   """FisherBlock using a full matrix estimate (no approximations).
@@ -164,6 +172,10 @@ class FullFB(FisherBlock):
   def tensors_to_compute_grads(self):
     return self._params
 
+  @property
+  def num_registered_minibatches(self):
+    return 1  # Multiple minibatches not supported.
+
 
 class NaiveDiagonalFB(FisherBlock):
   """FisherBlock using a diagonal matrix approximation.
@@ -209,6 +221,10 @@ class NaiveDiagonalFB(FisherBlock):
   def tensors_to_compute_grads(self):
     return self._params
 
+  @property
+  def num_registered_minibatches(self):
+    return 1  # Multiple minibatches not supported.
+
 
 class FullyConnectedDiagonalFB(FisherBlock):
   """FisherBlock for fully-connected (dense) layers using a diagonal approx.
@@ -305,6 +321,12 @@ class FullyConnectedDiagonalFB(FisherBlock):
     self._inputs.append(inputs)
     self._outputs.append(outputs)
 
+  @property
+  def num_registered_minibatches(self):
+    result = len(self._inputs)
+    assert result == len(self._outputs)
+    return result
+
 
 class ConvDiagonalFB(FisherBlock):
   """FisherBlock for convolutional layers using a diagonal approx.
@@ -400,6 +422,10 @@ class ConvDiagonalFB(FisherBlock):
     self._inputs.append(inputs)
     self._outputs.append(outputs)
 
+  @property
+  def num_registered_minibatches(self):
+    return len(self._inputs)
+
 
 class KroneckerProductFB(FisherBlock):
   """A base class for FisherBlocks with separate input and output factors.
@@ -532,6 +558,10 @@ class FullyConnectedKFACBasicFB(KroneckerProductFB):
     self._inputs.append(inputs)
     self._outputs.append(outputs)
 
+  @property
+  def num_registered_minibatches(self):
+    return 1  # Multiple minibatches not supported.
+
 
 class ConvKFCBasicFB(KroneckerProductFB):
   """FisherBlock for 2D convolutional layers using the basic KFC approx.
@@ -591,6 +621,10 @@ class ConvKFCBasicFB(KroneckerProductFB):
   def tensors_to_compute_grads(self):
     return self._outputs
 
+  @property
+  def num_registered_minibatches(self):
+    return 1  # Multiple minibatches not supported.
+
 
 def _concat_along_batch_dim(tensor_list):
   """Concatenate tensors along batch (first) dimension.
diff --git a/tensorflow/contrib/kfac/python/ops/layer_collection.py b/tensorflow/contrib/kfac/python/ops/layer_collection.py
index ceb1131f286..49279954dc8 100644
--- a/tensorflow/contrib/kfac/python/ops/layer_collection.py
+++ b/tensorflow/contrib/kfac/python/ops/layer_collection.py
@@ -27,6 +27,8 @@ from __future__ import print_function
 from collections import defaultdict
 from collections import OrderedDict
 
+import six
+
 from tensorflow.contrib.kfac.python.ops import fisher_blocks as fb
 from tensorflow.contrib.kfac.python.ops import loss_functions as lf
 from tensorflow.contrib.kfac.python.ops import utils
@@ -82,8 +84,8 @@ class LayerParametersDict(OrderedDict):
     return key
 
 
-# TODO(duckworthd): add capability for LayerCollection to be "finalized"
-# and do this when it gets used by FisherEstimator / KfacOptimizer
+# TODO(b/68034464): add capability for LayerCollection to be "finalized"
+# and do this when it gets used by FisherEstimator / KfacOptimizer.
 class LayerCollection(object):
 
 
@@ -211,10 +213,10 @@ class LayerCollection(object):
   def get_use_count_map(self):
     """Returns a dict of variables to their number of registrations."""
     vars_to_uses = defaultdict(int)
-    for key in self.fisher_blocks.keys():
+    for key, block in six.iteritems(self.fisher_blocks):
       key = key if isinstance(key, (tuple, list)) else (key,)
       for k in key:
-        vars_to_uses[k] += 1
+        vars_to_uses[k] += block.num_registered_minibatches
     return vars_to_uses
 
   def get_blocks(self):