K-FAC: Support for tf.AUTO_REUSE when re-using registrations. Multi-tower support for FullFB, NaiveDiagonalFB. Removal of LayerCollection.generic_registrations.

PiperOrigin-RevId: 174092003
A. Unique TensorFlower 2017-10-31 14:26:27 -07:00 committed by TensorFlower Gardener
parent 0a7be5a2f5
commit 453dd5848f
4 changed files with 182 additions and 98 deletions
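The core API change: FullFB and NaiveDiagonalFB no longer take a batch_size at construction; each tower registers its own minibatch on the block afterwards, and the block sums the registered sizes. A minimal sketch of the new pattern, mirroring the updated tests below (module aliases follow the tests; the contrib/kfac import paths and batch sizes are assumptions for illustration):

from tensorflow.contrib.kfac.python.ops import fisher_blocks as fb
from tensorflow.contrib.kfac.python.ops import layer_collection as lc
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops

with ops.Graph().as_default():
  params = (array_ops.constant([1., 2.]), array_ops.constant(3.))
  # batch_size is no longer a constructor argument.
  block = fb.FullFB(lc.LayerCollection(), params)
  # One call per tower/minibatch; the sizes are summed internally.
  block.register_additional_minibatch(32)
  block.register_additional_minibatch(16)
  assert block.num_registered_minibatches == 2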

View File

@@ -46,7 +46,8 @@ class FullFBTest(test.TestCase):
with ops.Graph().as_default():
random_seed.set_random_seed(200)
params = (array_ops.constant([1., 2.]), array_ops.constant(3.))
block = fb.FullFB(lc.LayerCollection(), params, 32)
block = fb.FullFB(lc.LayerCollection(), params)
block.register_additional_minibatch(32)
self.assertAllEqual(params, block.tensors_to_compute_grads())
@@ -54,7 +55,8 @@ class FullFBTest(test.TestCase):
with ops.Graph().as_default():
random_seed.set_random_seed(200)
params = (array_ops.constant([1., 2.]), array_ops.constant(3.))
block = fb.FullFB(lc.LayerCollection(), params, 32)
block = fb.FullFB(lc.LayerCollection(), params)
block.register_additional_minibatch(32)
self.assertAllEqual(params, block.tensors_to_compute_grads())
@@ -62,7 +64,8 @@ class FullFBTest(test.TestCase):
with ops.Graph().as_default():
random_seed.set_random_seed(200)
params = (array_ops.constant([1., 2.]), array_ops.constant(3.))
block = fb.FullFB(lc.LayerCollection(), params, 32)
block = fb.FullFB(lc.LayerCollection(), params)
block.register_additional_minibatch(32)
grads = (params[0]**2, math_ops.sqrt(params[1]))
block.instantiate_factors(grads, 0.5)
@@ -71,7 +74,8 @@ class FullFBTest(test.TestCase):
with ops.Graph().as_default(), self.test_session() as sess:
random_seed.set_random_seed(200)
params = (array_ops.constant([1., 2.]), array_ops.constant(3.))
block = fb.FullFB(lc.LayerCollection(), params, 32)
block = fb.FullFB(lc.LayerCollection(), params)
block.register_additional_minibatch(32)
grads = (params[0]**2, math_ops.sqrt(params[1]))
block.instantiate_factors((grads,), 0.5)
@@ -88,7 +92,8 @@ class FullFBTest(test.TestCase):
with ops.Graph().as_default(), self.test_session() as sess:
random_seed.set_random_seed(200)
params = array_ops.constant([[1.], [2.]])
block = fb.FullFB(lc.LayerCollection(), params, 32)
block = fb.FullFB(lc.LayerCollection(), params)
block.register_additional_minibatch(32)
grads = params**2
block.instantiate_factors((grads,), 0.5)
@@ -105,7 +110,8 @@ class FullFBTest(test.TestCase):
with ops.Graph().as_default(), self.test_session() as sess:
random_seed.set_random_seed(200)
params = (array_ops.constant([1., 2.]), array_ops.constant(3.))
block = fb.FullFB(lc.LayerCollection(), params, 32)
block = fb.FullFB(lc.LayerCollection(), params)
block.register_additional_minibatch(32)
grads = (array_ops.constant([2., 3.]), array_ops.constant(4.))
damping = 0.5
block.instantiate_factors((grads,), damping)
@@ -131,7 +137,8 @@ class NaiveDiagonalFBTest(test.TestCase):
with ops.Graph().as_default():
random_seed.set_random_seed(200)
params = (array_ops.constant([1., 2.]), array_ops.constant(3.))
block = fb.NaiveDiagonalFB(lc.LayerCollection(), params, 32)
block = fb.NaiveDiagonalFB(lc.LayerCollection(), params)
block.register_additional_minibatch(32)
self.assertAllEqual(params, block.tensors_to_compute_grads())
@@ -139,7 +146,8 @@ class NaiveDiagonalFBTest(test.TestCase):
with ops.Graph().as_default():
random_seed.set_random_seed(200)
params = (array_ops.constant([1., 2.]), array_ops.constant(3.))
block = fb.NaiveDiagonalFB(lc.LayerCollection(), params, 32)
block = fb.NaiveDiagonalFB(lc.LayerCollection(), params)
block.register_additional_minibatch(32)
self.assertAllEqual(params, block.tensors_to_compute_grads())
@@ -147,7 +155,8 @@ class NaiveDiagonalFBTest(test.TestCase):
with ops.Graph().as_default():
random_seed.set_random_seed(200)
params = (array_ops.constant([1., 2.]), array_ops.constant(3.))
block = fb.NaiveDiagonalFB(lc.LayerCollection(), params, 32)
block = fb.NaiveDiagonalFB(lc.LayerCollection(), params)
block.register_additional_minibatch(32)
grads = (params[0]**2, math_ops.sqrt(params[1]))
block.instantiate_factors(grads, 0.5)
@@ -156,7 +165,8 @@ class NaiveDiagonalFBTest(test.TestCase):
with ops.Graph().as_default(), self.test_session() as sess:
random_seed.set_random_seed(200)
params = (array_ops.constant([1., 2.]), array_ops.constant(3.))
block = fb.NaiveDiagonalFB(lc.LayerCollection(), params, 32)
block = fb.NaiveDiagonalFB(lc.LayerCollection(), params)
block.register_additional_minibatch(32)
grads = (params[0]**2, math_ops.sqrt(params[1]))
block.instantiate_factors((grads,), 0.5)
@@ -173,7 +183,8 @@ class NaiveDiagonalFBTest(test.TestCase):
with ops.Graph().as_default(), self.test_session() as sess:
random_seed.set_random_seed(200)
params = array_ops.constant([[1.], [2.]])
block = fb.NaiveDiagonalFB(lc.LayerCollection(), params, 32)
block = fb.NaiveDiagonalFB(lc.LayerCollection(), params)
block.register_additional_minibatch(32)
grads = params**2
block.instantiate_factors((grads,), 0.5)
@@ -189,7 +200,8 @@ class NaiveDiagonalFBTest(test.TestCase):
with ops.Graph().as_default(), self.test_session() as sess:
random_seed.set_random_seed(200)
params = (array_ops.constant([1., 2.]), array_ops.constant(3.))
block = fb.NaiveDiagonalFB(lc.LayerCollection(), params, 32)
block = fb.NaiveDiagonalFB(lc.LayerCollection(), params)
block.register_additional_minibatch(32)
grads = (params[0]**2, math_ops.sqrt(params[1]))
damping = 0.5
block.instantiate_factors((grads,), damping)

View File

@@ -30,6 +30,21 @@ from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import test
class MockFisherBlock(object):
"""A fake FisherBlock."""
num_registered_minibatches = 2
def __init__(self, name='MockFisherBlock'):
self.name = name
def __eq__(self, other):
return isinstance(other, MockFisherBlock) and other.name == self.name
def __hash__(self):
return hash(self.name)
class LayerParametersDictTest(test.TestCase):
def testSetItem(self):
@@ -172,10 +187,12 @@ class LayerCollectionTest(test.TestCase):
y = variable_scope.get_variable('y', initializer=array_ops.constant(1,))
z = variable_scope.get_variable('z', initializer=array_ops.constant(1,))
lc = layer_collection.LayerCollection()
lc.fisher_blocks = {x: '1', z: '2'}
lc.fisher_blocks = {x: MockFisherBlock('1'), z: MockFisherBlock('2')}
lc.register_block((x, y), 'foo')
self.assertEqual(set(['2', 'foo']), set(lc.get_blocks()))
lc.register_block((x, y), MockFisherBlock('foo'))
self.assertEqual(
set([MockFisherBlock('2'), MockFisherBlock('foo')]),
set(lc.get_blocks()))
def testRegisterTupleVarSomeRegisteredInOtherTuples(self):
x = variable_scope.get_variable('x', initializer=array_ops.constant(1,))
@@ -438,11 +455,6 @@ class LayerCollectionTest(test.TestCase):
def testGetUseCountMap(self):
"""Ensure get_use_count_map() sums 'num_registered_minibatches'."""
class MockFisherBlock(object):
num_registered_minibatches = 2
lc = layer_collection.LayerCollection()
lc.fisher_blocks = {
'a': MockFisherBlock(),

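In the LayerCollection tests, MockFisherBlock moves to module scope and gains name-based __eq__ and __hash__, because register_block now returns FisherBlock instances and the test compares sets of blocks by value rather than by identity. A small, self-contained sketch of why the equality methods matter (plain Python; the names are illustrative):

class MockFisherBlock(object):
  """A fake FisherBlock identified only by its name."""
  num_registered_minibatches = 2

  def __init__(self, name='MockFisherBlock'):
    self.name = name

  def __eq__(self, other):
    return isinstance(other, MockFisherBlock) and other.name == self.name

  def __hash__(self):
    return hash(self.name)

# Distinct instances with the same name compare equal and collapse in a set,
# which is what the assertEqual over set(lc.get_blocks()) above relies on.
assert MockFisherBlock('foo') == MockFisherBlock('foo')
assert len({MockFisherBlock('2'), MockFisherBlock('2')}) == 1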
View File

@@ -133,16 +133,15 @@ class FullFB(FisherBlock):
to any type of parameter in principle, but has very high variance.
"""
def __init__(self, layer_collection, params, batch_size):
def __init__(self, layer_collection, params):
"""Creates a FullFB block.
Args:
layer_collection: The collection of all layers in the K-FAC approximate
Fisher information matrix to which this FisherBlock belongs.
params: The parameters of this layer (Tensor or tuple of Tensors).
batch_size: The batch size, used in the covariance estimator.
"""
self._batch_size = batch_size
self._batch_sizes = []
self._params = params
super(FullFB, self).__init__(layer_collection)
@@ -172,9 +171,21 @@ class FullFB(FisherBlock):
def tensors_to_compute_grads(self):
return self._params
def register_additional_minibatch(self, batch_size):
"""Register an additional minibatch.
Args:
batch_size: The batch size, used in the covariance estimator.
"""
self._batch_sizes.append(batch_size)
@property
def num_registered_minibatches(self):
return 1 # Multiple minibatches not supported.
return len(self._batch_sizes)
@property
def _batch_size(self):
return math_ops.reduce_sum(self._batch_sizes)
class NaiveDiagonalFB(FisherBlock):
@@ -186,17 +197,16 @@ class NaiveDiagonalFB(FisherBlock):
to any type of parameter in principle, but has very high variance.
"""
def __init__(self, layer_collection, params, batch_size):
def __init__(self, layer_collection, params):
"""Creates a NaiveDiagonalFB block.
Args:
layer_collection: The collection of all layers in the K-FAC approximate
Fisher information matrix to which this FisherBlock belongs.
params: The parameters of this layer (Tensor or tuple of Tensors).
batch_size: The batch size, used in the covariance estimator.
"""
self._params = params
self._batch_size = batch_size
self._batch_sizes = []
super(NaiveDiagonalFB, self).__init__(layer_collection)
@@ -221,9 +231,21 @@ class NaiveDiagonalFB(FisherBlock):
def tensors_to_compute_grads(self):
return self._params
def register_additional_minibatch(self, batch_size):
"""Register an additional minibatch.
Args:
batch_size: The batch size, used in the covariance estimator.
"""
self._batch_sizes.append(batch_size)
@property
def num_registered_minibatches(self):
return 1 # Multiple minibatches not supported.
return len(self._batch_sizes)
@property
def _batch_size(self):
return math_ops.reduce_sum(self._batch_sizes)
class FullyConnectedDiagonalFB(FisherBlock):

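With per-tower registration, a block's effective batch size becomes the sum of every size registered so far: the new _batch_size property is a math_ops.reduce_sum over the recorded list, so Python ints and 0-D Tensors can be mixed. A hedged sketch of just that accumulation step (a standalone illustration, not the class itself):

from tensorflow.python.client import session
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops

with ops.Graph().as_default(), session.Session() as sess:
  batch_sizes = []                            # plays the role of FullFB._batch_sizes
  batch_sizes.append(32)                      # tower 0: a Python int
  batch_sizes.append(array_ops.constant(16))  # tower 1: a 0-D Tensor
  total = math_ops.reduce_sum(batch_sizes)    # what the _batch_size property computes
  assert sess.run(total) == 48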
View File

@@ -103,10 +103,6 @@ class LayerCollection(object):
fisher_blocks: a LayersParamsDict (subclass of OrderedDict) mapping layer
parameters (Tensors or tuples of Tensors) to FisherBlock instances.
fisher_factors: an OrderedDict mapping tuples to FisherFactor instances.
generic_registrations: a list of variables registered via a generic layer
registration. Generic registrations handle any and all of the ways a
variable is used in the graph, which means we don't need to check
their registration when verifying the correctness of the graph.
losses: a list of LossFunction objects. The loss to be optimized is their
sum.
"""
@@ -114,7 +110,6 @@
def __init__(self, graph=None, name="LayerCollection"):
self.fisher_blocks = LayerParametersDict()
self.fisher_factors = OrderedDict()
self._generic_registrations = set()
self._graph = graph or ops.get_default_graph()
self._loss_dict = {} # {str: LossFunction}
self._subgraph = None
@@ -127,7 +122,7 @@
"""LossFunctions registered with this LayerCollection."""
return list(self._loss_dict.values())
def register_block(self, layer_key, fisher_block):
def register_block(self, layer_key, fisher_block, reuse=VARIABLE_SCOPE):
"""Validates and registers the layer_key associated with the fisher_block.
Validation consists of checking whether the key was already registered or
@@ -153,20 +148,43 @@
layer_key: The key to check for in existing registrations and to register
if valid.
fisher_block: The associated fisher block.
reuse: Method to use for inserting new FisherBlocks. One of True, False,
tf.AUTO_REUSE, or VARIABLE_SCOPE.
Raises:
ValueError: If the layer_key was already registered, or if a subset of the
layer_key has already been registered as part of a different tuple.
Returns:
The FisherBlock registered under 'layer_key'. It may or may not be the
same object as 'fisher_block'.
"""
if reuse is VARIABLE_SCOPE:
reuse = variable_scope.get_variable_scope().reuse
if reuse is True or (reuse is variable_scope.AUTO_REUSE and
layer_key in self.fisher_blocks):
result = self.fisher_blocks[layer_key]
if type(result) != type(fisher_block): # pylint: disable=unidiomatic-typecheck
raise ValueError(
"Attempted to register FisherBlock of type %s when existing "
"FisherBlock has type %s." % (type(fisher_block), type(result)))
return result
if reuse is False and layer_key in self.fisher_blocks:
raise ValueError("FisherBlock for %s is already in LayerCollection." %
(layer_key,))
# Insert fisher_block into self.fisher_blocks.
if layer_key in self.fisher_blocks:
raise ValueError("Duplicate registration: {}".format(layer_key))
if isinstance(layer_key, (tuple, list)):
self._register_block_with_sequence_key(layer_key, fisher_block)
return self._register_block_with_sequence_key(layer_key, fisher_block)
else:
self._register_block_with_nonsequence_key(layer_key, fisher_block)
return self._register_block_with_nonsequence_key(layer_key, fisher_block)
def _register_block_with_sequence_key(self, layer_key, fisher_block):
"""Validates and registers the layer_key if it's a sequence."""
# Find all keys that are either supersets or subsets of 'layer_key'.
inclusions = {
fisher_elt
for layer_elt in layer_key for fisher_elt in self.fisher_blocks
@@ -175,24 +193,60 @@
if not inclusions:
self.fisher_blocks[layer_key] = fisher_block
return
return fisher_block
result_key = None
for key in inclusions:
fisher_block_key = key if isinstance(key, (tuple, list)) else (key,)
if set(layer_key).issubset(fisher_block_key):
logging.warning("Graph Registration Warning: tried to register "
"a subset ({}) of an already registered tuple "
"({}), skipping".format(layer_key, fisher_block_key))
return
if not set(fisher_block_key).issubset(layer_key):
in_existing_only = set(fisher_block_key) - set(layer_key)
in_new_only = set(layer_key) - set(fisher_block_key)
if in_existing_only and in_new_only:
# Existing and new key have an intersection but neither is a subset of
# the other. This is an error.
raise ValueError(
"Inconsistent registration, expected new key to be a subset or "
"superset of the existing key: existing is {}, new is {}".format(
key, layer_key))
else:
elif in_existing_only and not in_new_only:
# Existing key is strict superset of new key. Return existing
# FisherBlock.
logging.warning("Graph Registration Warning: tried to register "
"a subset ({}) of an already registered tuple "
"({}), skipping".format(layer_key, fisher_block_key))
assert result_key is None
result_key = key
elif in_new_only and not in_existing_only:
# Existing key is a strict subset of new key. Replace existing
# FisherBlock with new one.
#
# TODO(b/68715045): This is dangerous. If there are existing
# registrations for a minibatch from elsewhere in the graph, they won't
# be re-registered with this new FisherBlock. The type of FisherBlock
# could also change here.
logging.warning(
"Replacing existing FisherBlock for key {} with new FisherBlock "
"for key {}. {} registered minibatches from the existing "
"FisherBlock will not be migrated.".format(
key, layer_key,
self.fisher_blocks[key].num_registered_minibatches))
self.fisher_blocks.pop(key)
self.fisher_blocks[layer_key] = fisher_block
assert result_key is None
result_key = layer_key
elif not in_new_only and not in_existing_only:
# Existing and new are identical. Reuse the old FisherBlock.
#
# TODO(b/68715045): This is dangerous. If the new FisherBlock has
# existing registered minibatches, they will not be migrated to the
# existing FisherBlock.
assert result_key is None
result_key = key
else:
raise ValueError("Unexpected layer key conflict: {} vs. {}".format(
layer_key, key))
self.fisher_blocks[layer_key] = fisher_block
return self.fisher_blocks[result_key]
def _register_block_with_nonsequence_key(self, layer_key, fisher_block):
"""Validates and registers the layer_key if it's not a sequence."""
@@ -209,6 +263,8 @@ class LayerCollection(object):
"variable ({}) but a containing tuple was already "
"registered ({}), skipping".format(layer_key, inclusions))
return fisher_block
def _equal_or_subset(self, elt1, elt2):
"""Checks if the elements are equal or one is contained in the other."""
return (elt1 == elt2 or (isinstance(elt1,
@@ -230,10 +286,6 @@ class LayerCollection(object):
def get_factors(self):
return self.fisher_factors.values()
@property
def generic_registrations(self):
return self._generic_registrations
@property
def graph(self):
return self._graph
@@ -291,24 +343,7 @@ class LayerCollection(object):
block_type = approx_to_block_types[approx]
has_bias = isinstance(params, (tuple, list))
if reuse == VARIABLE_SCOPE:
reuse = variable_scope.get_variable_scope().reuse
if reuse:
block = self.fisher_blocks.get(params, None)
if block is None:
raise KeyError(
"Reuse requested but no FisherBlock found for params {}.".format(
params))
if not isinstance(block, block_type):
raise ValueError(
"Requested block of type {} but block of type {} already exists "
"for params {}.".format(block_type, type(block), params))
else:
block = block_type(self, has_bias)
self.register_block(params, block)
block = self.register_block(params, block_type(self, has_bias), reuse=reuse)
block.register_additional_minibatch(inputs, outputs)
def register_conv2d(self,
@@ -351,42 +386,45 @@ class LayerCollection(object):
raise ValueError("Bad value {} for approx.".format(approx))
block_type = approx_to_block_types[approx]
if reuse == VARIABLE_SCOPE:
reuse = variable_scope.get_variable_scope().reuse
if reuse:
block = self.fisher_blocks.get(params, None)
if block is None:
raise KeyError(
"Reuse requested but no FisherBlock found for params {}.".format(
params))
if not isinstance(block, block_type):
raise ValueError(
"Requested block of type {} but block of type {} already exists "
"for params {}.".format(block_type, type(block), params))
else:
block = block_type(self, params, strides, padding)
self.register_block(params, block)
block = self.register_block(
params, block_type(self, params, strides, padding), reuse=reuse)
block.register_additional_minibatch(inputs, outputs)
def register_generic(self, params, batch_size, approx=APPROX_DIAGONAL_NAME):
params = params if isinstance(params, (tuple, list)) else (params,)
self._generic_registrations |= set(params)
def register_generic(self,
params,
batch_size,
approx=APPROX_DIAGONAL_NAME,
reuse=VARIABLE_SCOPE):
"""Registers a generic layer.
# Generic registrations do not need special registration rules because we do
# not care about multiple generic registrations. Add them to the
# fisher_block dictionary manually rather than going through the logic in
# self.register_block.
if approx == APPROX_FULL_NAME:
self.fisher_blocks[params] = fb.FullFB(self, params, batch_size)
elif approx == APPROX_DIAGONAL_NAME:
self.fisher_blocks[params] = fb.NaiveDiagonalFB(self, params, batch_size)
else:
Args:
params: The parameters of this layer (Tensor or tuple of Tensors).
batch_size: 0-D Tensor. Size of the minibatch.
approx: str. One of APPROX_FULL_NAME or APPROX_DIAGONAL_NAME.
reuse: bool or str. If True, reuse an existing FisherBlock. If False,
create a new FisherBlock. If VARIABLE_SCOPE, use
tf.get_variable_scope().reuse.
Raises:
ValueError: For improper value to 'approx'.
KeyError: If reuse == True but no FisherBlock found for 'params'.
ValueError: If reuse == True and FisherBlock found but of the wrong type.
"""
approx_to_block_types = {
APPROX_FULL_NAME: fb.FullFB,
APPROX_DIAGONAL_NAME: fb.NaiveDiagonalFB,
}
if approx not in approx_to_block_types:
raise ValueError("Bad value {} for approx.".format(approx))
block_type = approx_to_block_types[approx]
block = self.register_block(params, block_type(self, params), reuse=reuse)
block.register_additional_minibatch(batch_size)
def register_categorical_predictive_distribution(self,
logits,
seed=None,
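Since generic registrations now go through register_block like every other layer type, the commit's tf.AUTO_REUSE support applies to them as well: the first tower creates the FisherBlock, and later towers reuse it and merely add their minibatches. A hedged usage sketch (scope name, tower count, and batch size are illustrative; the contrib/kfac import path is assumed):

from tensorflow.contrib.kfac.python.ops import layer_collection as lc
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import variable_scope

with ops.Graph().as_default():
  layer_collection = lc.LayerCollection()

  for _ in range(2):  # two towers
    with variable_scope.variable_scope(
        'shared', reuse=variable_scope.AUTO_REUSE):
      w = variable_scope.get_variable(
          'w', initializer=array_ops.constant([1., 2.]))
      # reuse defaults to VARIABLE_SCOPE, so the AUTO_REUSE flag of the
      # enclosing scope is picked up: the first call builds a NaiveDiagonalFB,
      # the second just registers another minibatch on the same block.
      layer_collection.register_generic(w, batch_size=32)

  block = layer_collection.fisher_blocks[w]
  assert block.num_registered_minibatches == 2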