K-FAC: _check_registration() supports multiple towers.
PiperOrigin-RevId: 173115870
This commit is contained in:
parent
670dddf4ad
commit
434695921d
@ -313,10 +313,20 @@ class LayerCollectionTest(test.TestCase):
|
|||||||
self.assertTrue(all([var.name.startswith(scope) for var in variables]))
|
self.assertTrue(all([var.name.startswith(scope) for var in variables]))
|
||||||
|
|
||||||
def testGetUseCountMap(self):
|
def testGetUseCountMap(self):
|
||||||
|
"""Ensure get_use_count_map() sums 'num_registered_minibatches'."""
|
||||||
|
|
||||||
|
class MockFisherBlock(object):
|
||||||
|
|
||||||
|
num_registered_minibatches = 2
|
||||||
|
|
||||||
lc = layer_collection.LayerCollection()
|
lc = layer_collection.LayerCollection()
|
||||||
lc.fisher_blocks = {'a': 1, ('a', 'c'): 2, ('b', 'c'): 2}
|
lc.fisher_blocks = {
|
||||||
|
'a': MockFisherBlock(),
|
||||||
|
('a', 'c'): MockFisherBlock(),
|
||||||
|
('b', 'c'): MockFisherBlock()
|
||||||
|
}
|
||||||
use_count_map = lc.get_use_count_map()
|
use_count_map = lc.get_use_count_map()
|
||||||
self.assertDictEqual({'a': 2, 'b': 1, 'c': 2}, use_count_map)
|
self.assertDictEqual({'a': 4, 'b': 2, 'c': 4}, use_count_map)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
@ -113,7 +113,9 @@ py_library(
|
|||||||
"//tensorflow/python:framework_ops",
|
"//tensorflow/python:framework_ops",
|
||||||
"//tensorflow/python:math_ops",
|
"//tensorflow/python:math_ops",
|
||||||
"//tensorflow/python:platform",
|
"//tensorflow/python:platform",
|
||||||
|
"//tensorflow/python:util",
|
||||||
"//tensorflow/python:variable_scope",
|
"//tensorflow/python:variable_scope",
|
||||||
|
"@six_archive//:six",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -114,6 +114,14 @@ class FisherBlock(object):
|
|||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@abc.abstractproperty
|
||||||
|
def num_registered_minibatches(self):
|
||||||
|
"""Number of minibatches registered for this FisherBlock.
|
||||||
|
|
||||||
|
Typically equal to the number of towers in a multi-tower setup.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
class FullFB(FisherBlock):
|
class FullFB(FisherBlock):
|
||||||
"""FisherBlock using a full matrix estimate (no approximations).
|
"""FisherBlock using a full matrix estimate (no approximations).
|
||||||
@ -164,6 +172,10 @@ class FullFB(FisherBlock):
|
|||||||
def tensors_to_compute_grads(self):
|
def tensors_to_compute_grads(self):
|
||||||
return self._params
|
return self._params
|
||||||
|
|
||||||
|
@property
|
||||||
|
def num_registered_minibatches(self):
|
||||||
|
return 1 # Multiple minibatches not supported.
|
||||||
|
|
||||||
|
|
||||||
class NaiveDiagonalFB(FisherBlock):
|
class NaiveDiagonalFB(FisherBlock):
|
||||||
"""FisherBlock using a diagonal matrix approximation.
|
"""FisherBlock using a diagonal matrix approximation.
|
||||||
@ -209,6 +221,10 @@ class NaiveDiagonalFB(FisherBlock):
|
|||||||
def tensors_to_compute_grads(self):
|
def tensors_to_compute_grads(self):
|
||||||
return self._params
|
return self._params
|
||||||
|
|
||||||
|
@property
|
||||||
|
def num_registered_minibatches(self):
|
||||||
|
return 1 # Multiple minibatches not supported.
|
||||||
|
|
||||||
|
|
||||||
class FullyConnectedDiagonalFB(FisherBlock):
|
class FullyConnectedDiagonalFB(FisherBlock):
|
||||||
"""FisherBlock for fully-connected (dense) layers using a diagonal approx.
|
"""FisherBlock for fully-connected (dense) layers using a diagonal approx.
|
||||||
@ -305,6 +321,12 @@ class FullyConnectedDiagonalFB(FisherBlock):
|
|||||||
self._inputs.append(inputs)
|
self._inputs.append(inputs)
|
||||||
self._outputs.append(outputs)
|
self._outputs.append(outputs)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def num_registered_minibatches(self):
|
||||||
|
result = len(self._inputs)
|
||||||
|
assert result == len(self._outputs)
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
class ConvDiagonalFB(FisherBlock):
|
class ConvDiagonalFB(FisherBlock):
|
||||||
"""FisherBlock for convolutional layers using a diagonal approx.
|
"""FisherBlock for convolutional layers using a diagonal approx.
|
||||||
@ -400,6 +422,10 @@ class ConvDiagonalFB(FisherBlock):
|
|||||||
self._inputs.append(inputs)
|
self._inputs.append(inputs)
|
||||||
self._outputs.append(outputs)
|
self._outputs.append(outputs)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def num_registered_minibatches(self):
|
||||||
|
return len(self._inputs)
|
||||||
|
|
||||||
|
|
||||||
class KroneckerProductFB(FisherBlock):
|
class KroneckerProductFB(FisherBlock):
|
||||||
"""A base class for FisherBlocks with separate input and output factors.
|
"""A base class for FisherBlocks with separate input and output factors.
|
||||||
@ -532,6 +558,10 @@ class FullyConnectedKFACBasicFB(KroneckerProductFB):
|
|||||||
self._inputs.append(inputs)
|
self._inputs.append(inputs)
|
||||||
self._outputs.append(outputs)
|
self._outputs.append(outputs)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def num_registered_minibatches(self):
|
||||||
|
return 1 # Multiple minibatches not supported.
|
||||||
|
|
||||||
|
|
||||||
class ConvKFCBasicFB(KroneckerProductFB):
|
class ConvKFCBasicFB(KroneckerProductFB):
|
||||||
"""FisherBlock for 2D convolutional layers using the basic KFC approx.
|
"""FisherBlock for 2D convolutional layers using the basic KFC approx.
|
||||||
@ -591,6 +621,10 @@ class ConvKFCBasicFB(KroneckerProductFB):
|
|||||||
def tensors_to_compute_grads(self):
|
def tensors_to_compute_grads(self):
|
||||||
return self._outputs
|
return self._outputs
|
||||||
|
|
||||||
|
@property
|
||||||
|
def num_registered_minibatches(self):
|
||||||
|
return 1 # Multiple minibatches not supported.
|
||||||
|
|
||||||
|
|
||||||
def _concat_along_batch_dim(tensor_list):
|
def _concat_along_batch_dim(tensor_list):
|
||||||
"""Concatenate tensors along batch (first) dimension.
|
"""Concatenate tensors along batch (first) dimension.
|
||||||
|
@ -27,6 +27,8 @@ from __future__ import print_function
|
|||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
|
|
||||||
|
import six
|
||||||
|
|
||||||
from tensorflow.contrib.kfac.python.ops import fisher_blocks as fb
|
from tensorflow.contrib.kfac.python.ops import fisher_blocks as fb
|
||||||
from tensorflow.contrib.kfac.python.ops import loss_functions as lf
|
from tensorflow.contrib.kfac.python.ops import loss_functions as lf
|
||||||
from tensorflow.contrib.kfac.python.ops import utils
|
from tensorflow.contrib.kfac.python.ops import utils
|
||||||
@ -82,8 +84,8 @@ class LayerParametersDict(OrderedDict):
|
|||||||
return key
|
return key
|
||||||
|
|
||||||
|
|
||||||
# TODO(duckworthd): add capability for LayerCollection to be "finalized"
|
# TODO(b/68034464): add capability for LayerCollection to be "finalized"
|
||||||
# and do this when it gets used by FisherEstimator / KfacOptimizer
|
# and do this when it gets used by FisherEstimator / KfacOptimizer.
|
||||||
|
|
||||||
|
|
||||||
class LayerCollection(object):
|
class LayerCollection(object):
|
||||||
@ -211,10 +213,10 @@ class LayerCollection(object):
|
|||||||
def get_use_count_map(self):
|
def get_use_count_map(self):
|
||||||
"""Returns a dict of variables to their number of registrations."""
|
"""Returns a dict of variables to their number of registrations."""
|
||||||
vars_to_uses = defaultdict(int)
|
vars_to_uses = defaultdict(int)
|
||||||
for key in self.fisher_blocks.keys():
|
for key, block in six.iteritems(self.fisher_blocks):
|
||||||
key = key if isinstance(key, (tuple, list)) else (key,)
|
key = key if isinstance(key, (tuple, list)) else (key,)
|
||||||
for k in key:
|
for k in key:
|
||||||
vars_to_uses[k] += 1
|
vars_to_uses[k] += block.num_registered_minibatches
|
||||||
return vars_to_uses
|
return vars_to_uses
|
||||||
|
|
||||||
def get_blocks(self):
|
def get_blocks(self):
|
||||||
|
Loading…
x
Reference in New Issue
Block a user