K-FAC: _check_registration() supports multiple towers.

PiperOrigin-RevId: 173115870
This commit is contained in:
A. Unique TensorFlower 2017-10-23 08:00:39 -07:00 committed by TensorFlower Gardener
parent 670dddf4ad
commit 434695921d
4 changed files with 54 additions and 6 deletions
tensorflow/contrib/kfac/python

View File

@ -313,10 +313,20 @@ class LayerCollectionTest(test.TestCase):
self.assertTrue(all([var.name.startswith(scope) for var in variables]))
def testGetUseCountMap(self):
"""Ensure get_use_count_map() sums 'num_registered_minibatches'."""
class MockFisherBlock(object):
num_registered_minibatches = 2
lc = layer_collection.LayerCollection()
lc.fisher_blocks = {'a': 1, ('a', 'c'): 2, ('b', 'c'): 2}
lc.fisher_blocks = {
'a': MockFisherBlock(),
('a', 'c'): MockFisherBlock(),
('b', 'c'): MockFisherBlock()
}
use_count_map = lc.get_use_count_map()
self.assertDictEqual({'a': 2, 'b': 1, 'c': 2}, use_count_map)
self.assertDictEqual({'a': 4, 'b': 2, 'c': 4}, use_count_map)
if __name__ == '__main__':

View File

@ -113,7 +113,9 @@ py_library(
"//tensorflow/python:framework_ops",
"//tensorflow/python:math_ops",
"//tensorflow/python:platform",
"//tensorflow/python:util",
"//tensorflow/python:variable_scope",
"@six_archive//:six",
],
)

View File

@ -114,6 +114,14 @@ class FisherBlock(object):
"""
pass
@abc.abstractproperty
def num_registered_minibatches(self):
"""Number of minibatches registered for this FisherBlock.
Typically equal to the number of towers in a multi-tower setup.
"""
pass
class FullFB(FisherBlock):
"""FisherBlock using a full matrix estimate (no approximations).
@ -164,6 +172,10 @@ class FullFB(FisherBlock):
def tensors_to_compute_grads(self):
return self._params
@property
def num_registered_minibatches(self):
return 1 # Multiple minibatches not supported.
class NaiveDiagonalFB(FisherBlock):
"""FisherBlock using a diagonal matrix approximation.
@ -209,6 +221,10 @@ class NaiveDiagonalFB(FisherBlock):
def tensors_to_compute_grads(self):
return self._params
@property
def num_registered_minibatches(self):
return 1 # Multiple minibatches not supported.
class FullyConnectedDiagonalFB(FisherBlock):
"""FisherBlock for fully-connected (dense) layers using a diagonal approx.
@ -305,6 +321,12 @@ class FullyConnectedDiagonalFB(FisherBlock):
self._inputs.append(inputs)
self._outputs.append(outputs)
@property
def num_registered_minibatches(self):
result = len(self._inputs)
assert result == len(self._outputs)
return result
class ConvDiagonalFB(FisherBlock):
"""FisherBlock for convolutional layers using a diagonal approx.
@ -400,6 +422,10 @@ class ConvDiagonalFB(FisherBlock):
self._inputs.append(inputs)
self._outputs.append(outputs)
@property
def num_registered_minibatches(self):
return len(self._inputs)
class KroneckerProductFB(FisherBlock):
"""A base class for FisherBlocks with separate input and output factors.
@ -532,6 +558,10 @@ class FullyConnectedKFACBasicFB(KroneckerProductFB):
self._inputs.append(inputs)
self._outputs.append(outputs)
@property
def num_registered_minibatches(self):
return 1 # Multiple minibatches not supported.
class ConvKFCBasicFB(KroneckerProductFB):
"""FisherBlock for 2D convolutional layers using the basic KFC approx.
@ -591,6 +621,10 @@ class ConvKFCBasicFB(KroneckerProductFB):
def tensors_to_compute_grads(self):
return self._outputs
@property
def num_registered_minibatches(self):
return 1 # Multiple minibatches not supported.
def _concat_along_batch_dim(tensor_list):
"""Concatenate tensors along batch (first) dimension.

View File

@ -27,6 +27,8 @@ from __future__ import print_function
from collections import defaultdict
from collections import OrderedDict
import six
from tensorflow.contrib.kfac.python.ops import fisher_blocks as fb
from tensorflow.contrib.kfac.python.ops import loss_functions as lf
from tensorflow.contrib.kfac.python.ops import utils
@ -82,8 +84,8 @@ class LayerParametersDict(OrderedDict):
return key
# TODO(duckworthd): add capability for LayerCollection to be "finalized"
# and do this when it gets used by FisherEstimator / KfacOptimizer
# TODO(b/68034464): add capability for LayerCollection to be "finalized"
# and do this when it gets used by FisherEstimator / KfacOptimizer.
class LayerCollection(object):
@ -211,10 +213,10 @@ class LayerCollection(object):
def get_use_count_map(self):
"""Returns a dict of variables to their number of registrations."""
vars_to_uses = defaultdict(int)
for key in self.fisher_blocks.keys():
for key, block in six.iteritems(self.fisher_blocks):
key = key if isinstance(key, (tuple, list)) else (key,)
for k in key:
vars_to_uses[k] += 1
vars_to_uses[k] += block.num_registered_minibatches
return vars_to_uses
def get_blocks(self):