commit 434695921d
parent 670dddf4ad

K-FAC: _check_registration() supports multiple towers.

PiperOrigin-RevId: 173115870

Diff (files under tensorflow/contrib/kfac/python):

@@ -313,10 +313,20 @@ class LayerCollectionTest(test.TestCase):
     self.assertTrue(all([var.name.startswith(scope) for var in variables]))
 
   def testGetUseCountMap(self):
+    """Ensure get_use_count_map() sums 'num_registered_minibatches'."""
+
+    class MockFisherBlock(object):
+
+      num_registered_minibatches = 2
+
     lc = layer_collection.LayerCollection()
-    lc.fisher_blocks = {'a': 1, ('a', 'c'): 2, ('b', 'c'): 2}
+    lc.fisher_blocks = {
+        'a': MockFisherBlock(),
+        ('a', 'c'): MockFisherBlock(),
+        ('b', 'c'): MockFisherBlock()
+    }
     use_count_map = lc.get_use_count_map()
-    self.assertDictEqual({'a': 2, 'b': 1, 'c': 2}, use_count_map)
+    self.assertDictEqual({'a': 4, 'b': 2, 'c': 4}, use_count_map)
 
 
 if __name__ == '__main__':

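For intuition, a minimal standalone sketch (not part of the diff) of the counting the updated test expects: each key in fisher_blocks contributes its block's num_registered_minibatches to every variable named in that key.

# Standalone sketch; MockFisherBlock mirrors the mock in the test above.
from collections import defaultdict


class MockFisherBlock(object):
  num_registered_minibatches = 2


fisher_blocks = {
    'a': MockFisherBlock(),
    ('a', 'c'): MockFisherBlock(),
    ('b', 'c'): MockFisherBlock(),
}

vars_to_uses = defaultdict(int)
for key, block in fisher_blocks.items():
  key = key if isinstance(key, (tuple, list)) else (key,)
  for k in key:
    # Each registration counts once per registered minibatch (tower).
    vars_to_uses[k] += block.num_registered_minibatches

# 'a' appears in the keys 'a' and ('a', 'c'): 2 + 2 = 4, and so on.
assert dict(vars_to_uses) == {'a': 4, 'b': 2, 'c': 4}
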
@@ -113,7 +113,9 @@ py_library(
         "//tensorflow/python:framework_ops",
         "//tensorflow/python:math_ops",
         "//tensorflow/python:platform",
         "//tensorflow/python:util",
         "//tensorflow/python:variable_scope",
+        "@six_archive//:six",
     ],
 )

@@ -114,6 +114,14 @@ class FisherBlock(object):
     """
     pass
 
+  @abc.abstractproperty
+  def num_registered_minibatches(self):
+    """Number of minibatches registered for this FisherBlock.
+
+    Typically equal to the number of towers in a multi-tower setup.
+    """
+    pass
+
 
 class FullFB(FisherBlock):
   """FisherBlock using a full matrix estimate (no approximations).

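A short aside on what the abstract property enforces, as a sketch; the six.add_metaclass/abc.ABCMeta wiring is an assumption here rather than shown in the diff. A concrete block that fails to define num_registered_minibatches cannot be instantiated.

import abc

import six


@six.add_metaclass(abc.ABCMeta)
class Block(object):
  """Stand-in for the FisherBlock base class."""

  @abc.abstractproperty
  def num_registered_minibatches(self):
    pass


class GoodBlock(Block):

  @property
  def num_registered_minibatches(self):
    return 1


GoodBlock()  # Fine: the abstract property is implemented.
try:
  Block()  # TypeError: abstract property not implemented.
except TypeError:
  pass
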
@@ -164,6 +172,10 @@ class FullFB(FisherBlock):
   def tensors_to_compute_grads(self):
     return self._params
 
+  @property
+  def num_registered_minibatches(self):
+    return 1  # Multiple minibatches not supported.
+
 
 class NaiveDiagonalFB(FisherBlock):
   """FisherBlock using a diagonal matrix approximation.

@@ -209,6 +221,10 @@ class NaiveDiagonalFB(FisherBlock):
   def tensors_to_compute_grads(self):
     return self._params
 
+  @property
+  def num_registered_minibatches(self):
+    return 1  # Multiple minibatches not supported.
+
 
 class FullyConnectedDiagonalFB(FisherBlock):
   """FisherBlock for fully-connected (dense) layers using a diagonal approx.

@@ -305,6 +321,12 @@ class FullyConnectedDiagonalFB(FisherBlock):
     self._inputs.append(inputs)
     self._outputs.append(outputs)
 
+  @property
+  def num_registered_minibatches(self):
+    result = len(self._inputs)
+    assert result == len(self._outputs)
+    return result
+
 
 class ConvDiagonalFB(FisherBlock):
   """FisherBlock for convolutional layers using a diagonal approx.

@@ -400,6 +422,10 @@ class ConvDiagonalFB(FisherBlock):
     self._inputs.append(inputs)
     self._outputs.append(outputs)
 
+  @property
+  def num_registered_minibatches(self):
+    return len(self._inputs)
+
 
 class KroneckerProductFB(FisherBlock):
   """A base class for FisherBlocks with separate input and output factors.

@@ -532,6 +558,10 @@ class FullyConnectedKFACBasicFB(KroneckerProductFB):
     self._inputs.append(inputs)
     self._outputs.append(outputs)
 
+  @property
+  def num_registered_minibatches(self):
+    return 1  # Multiple minibatches not supported.
+
 
 class ConvKFCBasicFB(KroneckerProductFB):
   """FisherBlock for 2D convolutional layers using the basic KFC approx.

@@ -591,6 +621,10 @@ class ConvKFCBasicFB(KroneckerProductFB):
   def tensors_to_compute_grads(self):
     return self._outputs
 
+  @property
+  def num_registered_minibatches(self):
+    return 1  # Multiple minibatches not supported.
+
 
 def _concat_along_batch_dim(tensor_list):
   """Concatenate tensors along batch (first) dimension.

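The diagonal blocks above record one minibatch per tower. A toy sketch of that pattern follows; the method name register_additional_minibatch is an assumption, since the diff does not show the enclosing method signature.

class ToyDiagonalFB(object):
  """Toy per-tower bookkeeping mirroring the append/len logic above."""

  def __init__(self):
    self._inputs = []
    self._outputs = []

  def register_additional_minibatch(self, inputs, outputs):
    # Called once per tower; each call records one minibatch.
    self._inputs.append(inputs)
    self._outputs.append(outputs)

  @property
  def num_registered_minibatches(self):
    result = len(self._inputs)
    assert result == len(self._outputs)
    return result


block = ToyDiagonalFB()
for tower_id in range(3):
  block.register_additional_minibatch('acts_%d' % tower_id,
                                      'grads_%d' % tower_id)
assert block.num_registered_minibatches == 3
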
@@ -27,6 +27,8 @@ from __future__ import print_function
 from collections import defaultdict
 from collections import OrderedDict
 
+import six
+
 from tensorflow.contrib.kfac.python.ops import fisher_blocks as fb
 from tensorflow.contrib.kfac.python.ops import loss_functions as lf
 from tensorflow.contrib.kfac.python.ops import utils

@@ -82,8 +84,8 @@ class LayerParametersDict(OrderedDict):
     return key
 
 
-# TODO(duckworthd): add capability for LayerCollection to be "finalized"
-# and do this when it gets used by FisherEstimator / KfacOptimizer
+# TODO(b/68034464): add capability for LayerCollection to be "finalized"
+# and do this when it gets used by FisherEstimator / KfacOptimizer.
 
 
 class LayerCollection(object):

@@ -211,10 +213,10 @@ class LayerCollection(object):
   def get_use_count_map(self):
     """Returns a dict of variables to their number of registrations."""
     vars_to_uses = defaultdict(int)
-    for key in self.fisher_blocks.keys():
+    for key, block in six.iteritems(self.fisher_blocks):
       key = key if isinstance(key, (tuple, list)) else (key,)
       for k in key:
-        vars_to_uses[k] += 1
+        vars_to_uses[k] += block.num_registered_minibatches
     return vars_to_uses
 
   def get_blocks(self):

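The commit title refers to _check_registration(), which is not part of this diff. Purely as a hypothetical illustration, a check built on the new use-count map could require each registered variable's count to equal the number of towers:

# Hypothetical illustration only; not the actual _check_registration().
def check_registration(use_count_map, num_towers):
  for var, count in use_count_map.items():
    if count != num_towers:
      raise ValueError(
          'Variable %s registered %d time(s); expected %d (one per tower).'
          % (var, count, num_towers))


check_registration({'w': 2, 'b': 2}, num_towers=2)  # Passes silently.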