K-FAC: Support for tf.AUTO_REUSE when re-using registrations. Multi-tower support for FullFB, NaiveDiagonalFB. Removal of LayerCollection.generic_registrations.
PiperOrigin-RevId: 174092003
parent 0a7be5a2f5
commit 453dd5848f
@@ -46,7 +46,8 @@ class FullFBTest(test.TestCase):
     with ops.Graph().as_default():
       random_seed.set_random_seed(200)
       params = (array_ops.constant([1., 2.]), array_ops.constant(3.))
-      block = fb.FullFB(lc.LayerCollection(), params, 32)
+      block = fb.FullFB(lc.LayerCollection(), params)
+      block.register_additional_minibatch(32)
 
       self.assertAllEqual(params, block.tensors_to_compute_grads())
 
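The change repeated throughout these test hunks is the new two-step registration API: the batch size moves out of the FullFB constructor and into register_additional_minibatch. A minimal migration sketch (the import paths are assumptions based on the tf.contrib.kfac layout at this revision):

```python
# Hypothetical usage sketch; module paths are assumed, not part of the diff.
import tensorflow as tf
from tensorflow.contrib.kfac.python.ops import fisher_blocks as fb
from tensorflow.contrib.kfac.python.ops import layer_collection as lc

params = (tf.constant([1., 2.]), tf.constant(3.))

# Old API: batch size fixed at construction time.
#   block = fb.FullFB(lc.LayerCollection(), params, 32)

# New API: construct without a batch size, then register each minibatch.
block = fb.FullFB(lc.LayerCollection(), params)
block.register_additional_minibatch(32)
```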
@@ -54,7 +55,8 @@ class FullFBTest(test.TestCase):
     with ops.Graph().as_default():
       random_seed.set_random_seed(200)
       params = (array_ops.constant([1., 2.]), array_ops.constant(3.))
-      block = fb.FullFB(lc.LayerCollection(), params, 32)
+      block = fb.FullFB(lc.LayerCollection(), params)
+      block.register_additional_minibatch(32)
 
       self.assertAllEqual(params, block.tensors_to_compute_grads())
 
@@ -62,7 +64,8 @@ class FullFBTest(test.TestCase):
     with ops.Graph().as_default():
       random_seed.set_random_seed(200)
       params = (array_ops.constant([1., 2.]), array_ops.constant(3.))
-      block = fb.FullFB(lc.LayerCollection(), params, 32)
+      block = fb.FullFB(lc.LayerCollection(), params)
+      block.register_additional_minibatch(32)
 
       grads = (params[0]**2, math_ops.sqrt(params[1]))
       block.instantiate_factors(grads, 0.5)
@@ -71,7 +74,8 @@ class FullFBTest(test.TestCase):
     with ops.Graph().as_default(), self.test_session() as sess:
       random_seed.set_random_seed(200)
       params = (array_ops.constant([1., 2.]), array_ops.constant(3.))
-      block = fb.FullFB(lc.LayerCollection(), params, 32)
+      block = fb.FullFB(lc.LayerCollection(), params)
+      block.register_additional_minibatch(32)
       grads = (params[0]**2, math_ops.sqrt(params[1]))
       block.instantiate_factors((grads,), 0.5)
 
@@ -88,7 +92,8 @@ class FullFBTest(test.TestCase):
     with ops.Graph().as_default(), self.test_session() as sess:
       random_seed.set_random_seed(200)
       params = array_ops.constant([[1.], [2.]])
-      block = fb.FullFB(lc.LayerCollection(), params, 32)
+      block = fb.FullFB(lc.LayerCollection(), params)
+      block.register_additional_minibatch(32)
       grads = params**2
       block.instantiate_factors((grads,), 0.5)
 
@@ -105,7 +110,8 @@ class FullFBTest(test.TestCase):
     with ops.Graph().as_default(), self.test_session() as sess:
       random_seed.set_random_seed(200)
       params = (array_ops.constant([1., 2.]), array_ops.constant(3.))
-      block = fb.FullFB(lc.LayerCollection(), params, 32)
+      block = fb.FullFB(lc.LayerCollection(), params)
+      block.register_additional_minibatch(32)
       grads = (array_ops.constant([2., 3.]), array_ops.constant(4.))
       damping = 0.5
       block.instantiate_factors((grads,), damping)
@@ -131,7 +137,8 @@ class NaiveDiagonalFBTest(test.TestCase):
     with ops.Graph().as_default():
       random_seed.set_random_seed(200)
       params = (array_ops.constant([1., 2.]), array_ops.constant(3.))
-      block = fb.NaiveDiagonalFB(lc.LayerCollection(), params, 32)
+      block = fb.NaiveDiagonalFB(lc.LayerCollection(), params)
+      block.register_additional_minibatch(32)
 
       self.assertAllEqual(params, block.tensors_to_compute_grads())
 
@@ -139,7 +146,8 @@ class NaiveDiagonalFBTest(test.TestCase):
     with ops.Graph().as_default():
       random_seed.set_random_seed(200)
       params = (array_ops.constant([1., 2.]), array_ops.constant(3.))
-      block = fb.NaiveDiagonalFB(lc.LayerCollection(), params, 32)
+      block = fb.NaiveDiagonalFB(lc.LayerCollection(), params)
+      block.register_additional_minibatch(32)
 
       self.assertAllEqual(params, block.tensors_to_compute_grads())
 
@@ -147,7 +155,8 @@ class NaiveDiagonalFBTest(test.TestCase):
     with ops.Graph().as_default():
       random_seed.set_random_seed(200)
       params = (array_ops.constant([1., 2.]), array_ops.constant(3.))
-      block = fb.NaiveDiagonalFB(lc.LayerCollection(), params, 32)
+      block = fb.NaiveDiagonalFB(lc.LayerCollection(), params)
+      block.register_additional_minibatch(32)
 
       grads = (params[0]**2, math_ops.sqrt(params[1]))
       block.instantiate_factors(grads, 0.5)
@@ -156,7 +165,8 @@ class NaiveDiagonalFBTest(test.TestCase):
     with ops.Graph().as_default(), self.test_session() as sess:
       random_seed.set_random_seed(200)
       params = (array_ops.constant([1., 2.]), array_ops.constant(3.))
-      block = fb.NaiveDiagonalFB(lc.LayerCollection(), params, 32)
+      block = fb.NaiveDiagonalFB(lc.LayerCollection(), params)
+      block.register_additional_minibatch(32)
       grads = (params[0]**2, math_ops.sqrt(params[1]))
       block.instantiate_factors((grads,), 0.5)
 
@@ -173,7 +183,8 @@ class NaiveDiagonalFBTest(test.TestCase):
     with ops.Graph().as_default(), self.test_session() as sess:
       random_seed.set_random_seed(200)
       params = array_ops.constant([[1.], [2.]])
-      block = fb.NaiveDiagonalFB(lc.LayerCollection(), params, 32)
+      block = fb.NaiveDiagonalFB(lc.LayerCollection(), params)
+      block.register_additional_minibatch(32)
       grads = params**2
       block.instantiate_factors((grads,), 0.5)
 
@@ -189,7 +200,8 @@ class NaiveDiagonalFBTest(test.TestCase):
     with ops.Graph().as_default(), self.test_session() as sess:
       random_seed.set_random_seed(200)
       params = (array_ops.constant([1., 2.]), array_ops.constant(3.))
-      block = fb.NaiveDiagonalFB(lc.LayerCollection(), params, 32)
+      block = fb.NaiveDiagonalFB(lc.LayerCollection(), params)
+      block.register_additional_minibatch(32)
       grads = (params[0]**2, math_ops.sqrt(params[1]))
       damping = 0.5
       block.instantiate_factors((grads,), damping)
 
@@ -30,6 +30,21 @@ from tensorflow.python.ops import variable_scope
 from tensorflow.python.platform import test
 
 
+class MockFisherBlock(object):
+  """A fake FisherBlock."""
+
+  num_registered_minibatches = 2
+
+  def __init__(self, name='MockFisherBlock'):
+    self.name = name
+
+  def __eq__(self, other):
+    return isinstance(other, MockFisherBlock) and other.name == self.name
+
+  def __hash__(self):
+    return hash(self.name)
+
+
 class LayerParametersDictTest(test.TestCase):
 
   def testSetItem(self):
@@ -172,10 +187,12 @@ class LayerCollectionTest(test.TestCase):
     y = variable_scope.get_variable('y', initializer=array_ops.constant(1,))
     z = variable_scope.get_variable('z', initializer=array_ops.constant(1,))
     lc = layer_collection.LayerCollection()
-    lc.fisher_blocks = {x: '1', z: '2'}
+    lc.fisher_blocks = {x: MockFisherBlock('1'), z: MockFisherBlock('2')}
 
-    lc.register_block((x, y), 'foo')
-    self.assertEqual(set(['2', 'foo']), set(lc.get_blocks()))
+    lc.register_block((x, y), MockFisherBlock('foo'))
+    self.assertEqual(
+        set([MockFisherBlock('2'), MockFisherBlock('foo')]),
+        set(lc.get_blocks()))
 
   def testRegisterTupleVarSomeRegisteredInOtherTuples(self):
     x = variable_scope.get_variable('x', initializer=array_ops.constant(1,))
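The module-level MockFisherBlock replaces the string placeholders used previously. Because it defines __eq__ and __hash__ over its name, freshly constructed instances compare equal to stored ones, which is what lets the updated assertions build expected sets by value. A standalone demonstration, restating the class exactly as defined in the hunk above:

```python
# The test helper from the diff above; equality is by name, not identity.
class MockFisherBlock(object):
  """A fake FisherBlock."""

  num_registered_minibatches = 2

  def __init__(self, name='MockFisherBlock'):
    self.name = name

  def __eq__(self, other):
    return isinstance(other, MockFisherBlock) and other.name == self.name

  def __hash__(self):
    return hash(self.name)

# Equal names mean equal blocks, so set comparisons in the tests work.
assert MockFisherBlock('2') == MockFisherBlock('2')
assert len({MockFisherBlock('2'), MockFisherBlock('2')}) == 1
```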
@@ -438,11 +455,6 @@ class LayerCollectionTest(test.TestCase):
 
   def testGetUseCountMap(self):
     """Ensure get_use_count_map() sums 'num_registered_minibatches'."""
-
-    class MockFisherBlock(object):
-
-      num_registered_minibatches = 2
-
     lc = layer_collection.LayerCollection()
     lc.fisher_blocks = {
         'a': MockFisherBlock(),
@@ -133,16 +133,15 @@ class FullFB(FisherBlock):
   to any type of parameter in principle, but has very high variance.
   """
 
-  def __init__(self, layer_collection, params, batch_size):
+  def __init__(self, layer_collection, params):
     """Creates a FullFB block.
 
     Args:
       layer_collection: The collection of all layers in the K-FAC approximate
         Fisher information matrix to which this FisherBlock belongs.
       params: The parameters of this layer (Tensor or tuple of Tensors).
-      batch_size: The batch size, used in the covariance estimator.
     """
-    self._batch_size = batch_size
+    self._batch_sizes = []
     self._params = params
 
     super(FullFB, self).__init__(layer_collection)
@@ -172,9 +171,21 @@ class FullFB(FisherBlock):
   def tensors_to_compute_grads(self):
     return self._params
 
+  def register_additional_minibatch(self, batch_size):
+    """Register an additional minibatch.
+
+    Args:
+      batch_size: The batch size, used in the covariance estimator.
+    """
+    self._batch_sizes.append(batch_size)
+
   @property
   def num_registered_minibatches(self):
-    return 1  # Multiple minibatches not supported.
+    return len(self._batch_sizes)
+
+  @property
+  def _batch_size(self):
+    return math_ops.reduce_sum(self._batch_sizes)
 
 
 class NaiveDiagonalFB(FisherBlock):
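With batch sizes now accumulated in a list, a FullFB (and, below, a NaiveDiagonalFB) can serve several towers: num_registered_minibatches reports the number of registrations, and the private _batch_size property becomes the sum of all registered sizes, which is what the covariance estimator normalizes by. A sketch of the multi-tower accounting (import paths assumed as before):

```python
import tensorflow as tf
from tensorflow.contrib.kfac.python.ops import fisher_blocks as fb
from tensorflow.contrib.kfac.python.ops import layer_collection as lc

params = tf.constant([[1.], [2.]])
block = fb.FullFB(lc.LayerCollection(), params)

# One registration per tower; batch sizes accumulate.
block.register_additional_minibatch(32)  # tower 0
block.register_additional_minibatch(32)  # tower 1

assert block.num_registered_minibatches == 2
# Internally, _batch_size is now math_ops.reduce_sum([32, 32]) == 64.
```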
@@ -186,17 +197,16 @@ class NaiveDiagonalFB(FisherBlock):
   to any type of parameter in principle, but has very high variance.
   """
 
-  def __init__(self, layer_collection, params, batch_size):
+  def __init__(self, layer_collection, params):
     """Creates a NaiveDiagonalFB block.
 
     Args:
       layer_collection: The collection of all layers in the K-FAC approximate
         Fisher information matrix to which this FisherBlock belongs.
       params: The parameters of this layer (Tensor or tuple of Tensors).
-      batch_size: The batch size, used in the covariance estimator.
     """
     self._params = params
-    self._batch_size = batch_size
+    self._batch_sizes = []
 
     super(NaiveDiagonalFB, self).__init__(layer_collection)
 
@@ -221,9 +231,21 @@ class NaiveDiagonalFB(FisherBlock):
   def tensors_to_compute_grads(self):
     return self._params
 
+  def register_additional_minibatch(self, batch_size):
+    """Register an additional minibatch.
+
+    Args:
+      batch_size: The batch size, used in the covariance estimator.
+    """
+    self._batch_sizes.append(batch_size)
+
   @property
   def num_registered_minibatches(self):
-    return 1  # Multiple minibatches not supported.
+    return len(self._batch_sizes)
+
+  @property
+  def _batch_size(self):
+    return math_ops.reduce_sum(self._batch_sizes)
 
 
 class FullyConnectedDiagonalFB(FisherBlock):
@@ -103,10 +103,6 @@ class LayerCollection(object):
     fisher_blocks: a LayersParamsDict (subclass of OrderedDict) mapping layer
         parameters (Tensors or tuples of Tensors) to FisherBlock instances.
     fisher_factors: an OrderedDict mapping tuples to FisherFactor instances.
-    generic_registrations: a list of variables registered via a generic layer
-        registration. Generic registrations handle any and all of the ways a
-        variable is used in the graph, which means we don't need to check
-        their registration when verifying the correctness of the graph.
     losses: a list of LossFunction objects. The loss to be optimized is their
         sum.
   """
@@ -114,7 +110,6 @@ class LayerCollection(object):
   def __init__(self, graph=None, name="LayerCollection"):
     self.fisher_blocks = LayerParametersDict()
     self.fisher_factors = OrderedDict()
-    self._generic_registrations = set()
     self._graph = graph or ops.get_default_graph()
     self._loss_dict = {}  # {str: LossFunction}
     self._subgraph = None
@@ -127,7 +122,7 @@ class LayerCollection(object):
     """LossFunctions registered with this LayerCollection."""
     return list(self._loss_dict.values())
 
-  def register_block(self, layer_key, fisher_block):
+  def register_block(self, layer_key, fisher_block, reuse=VARIABLE_SCOPE):
     """Validates and registers the layer_key associated with the fisher_block.
 
     Validation consists of checking whether the key was already registered or
@@ -153,20 +148,43 @@ class LayerCollection(object):
       layer_key: The key to check for in existing registrations and to register
         if valid.
       fisher_block: The associated fisher block.
+      reuse: Method to use for inserting new FisherBlocks. One of True, False,
+        or VARIABLE_SCOPE.
 
     Raises:
       ValueError: If the layer_key was already registered, or if a subset of the
         layer_key has already been registered as part of a different tuple.
+
+    Returns:
+      FisherBlock registered under 'layer_key'. May or may not be the same as
+      'fisher_block'.
     """
+    if reuse is VARIABLE_SCOPE:
+      reuse = variable_scope.get_variable_scope().reuse
+
+    if reuse is True or (reuse is variable_scope.AUTO_REUSE and
+                         layer_key in self.fisher_blocks):
+      result = self.fisher_blocks[layer_key]
+      if type(result) != type(fisher_block):  # pylint: disable=unidiomatic-typecheck
+        raise ValueError(
+            "Attempted to register FisherBlock of type %s when existing "
+            "FisherBlock has type %s." % (type(fisher_block), type(result)))
+      return result
+    if reuse is False and layer_key in self.fisher_blocks:
+      raise ValueError("FisherBlock for %s is already in LayerCollection." %
+                       (layer_key,))
+
+    # Insert fisher_block into self.fisher_blocks.
-    if layer_key in self.fisher_blocks:
-      raise ValueError("Duplicate registration: {}".format(layer_key))
     if isinstance(layer_key, (tuple, list)):
-      self._register_block_with_sequence_key(layer_key, fisher_block)
+      return self._register_block_with_sequence_key(layer_key, fisher_block)
     else:
-      self._register_block_with_nonsequence_key(layer_key, fisher_block)
+      return self._register_block_with_nonsequence_key(layer_key, fisher_block)
 
   def _register_block_with_sequence_key(self, layer_key, fisher_block):
     """Validates and registers the layer_key if it's a sequence."""
     # Find all keys that are either supersets or subsets of 'layer_key'.
     inclusions = {
         fisher_elt
         for layer_elt in layer_key for fisher_elt in self.fisher_blocks
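register_block now resolves reuse the same way tf.get_variable does: an explicit True returns (and type-checks) the existing block, tf.AUTO_REUSE returns the existing block only when one is already registered and inserts otherwise, False raises on a duplicate key, and the VARIABLE_SCOPE sentinel defers to tf.get_variable_scope().reuse. A sketch of the explicit modes (imports assumed as before):

```python
import tensorflow as tf
from tensorflow.contrib.kfac.python.ops import fisher_blocks as fb
from tensorflow.contrib.kfac.python.ops import layer_collection

lc = layer_collection.LayerCollection()
w = tf.get_variable('w', initializer=tf.constant([1., 2.]))

first = lc.register_block(w, fb.FullFB(lc, w), reuse=False)

# AUTO_REUSE: the key is already present, so the existing block is returned.
again = lc.register_block(w, fb.FullFB(lc, w), reuse=tf.AUTO_REUSE)
assert again is first

# reuse=False on an already-registered key raises ValueError, as does
# reuse=True when the existing block has a different type.
```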
@@ -175,24 +193,60 @@ class LayerCollection(object):
 
     if not inclusions:
       self.fisher_blocks[layer_key] = fisher_block
-      return
+      return fisher_block
 
+    result_key = None
     for key in inclusions:
       fisher_block_key = key if isinstance(key, (tuple, list)) else (key,)
-      if set(layer_key).issubset(fisher_block_key):
-        logging.warning("Graph Registration Warning: tried to register "
-                        "a subset ({}) of an already registered tuple "
-                        "({}), skipping".format(layer_key, fisher_block_key))
-        return
-      if not set(fisher_block_key).issubset(layer_key):
+      in_existing_only = set(fisher_block_key) - set(layer_key)
+      in_new_only = set(layer_key) - set(fisher_block_key)
+
+      if in_existing_only and in_new_only:
+        # Existing and new key have an intersection but neither is a subset of
+        # the other. This is an error.
         raise ValueError(
             "Inconsistent registration, expected new key to be a subset or "
             "superset of the existing key: existing is {}, new is {}".format(
                 key, layer_key))
-      else:
+      elif in_existing_only and not in_new_only:
+        # Existing key is strict superset of new key. Return existing
+        # FisherBlock.
+        logging.warning("Graph Registration Warning: tried to register "
+                        "a subset ({}) of an already registered tuple "
+                        "({}), skipping".format(layer_key, fisher_block_key))
+        assert result_key is None
+        result_key = key
+      elif in_new_only and not in_existing_only:
+        # Existing key is a strict subset of new key. Replace existing
+        # FisherBlock with new one.
+        #
+        # TODO(b/68715045): This is dangerous. If there are existing
+        # registrations for a minibatch from elsewhere in the graph, they won't
+        # be re-registered with this new FisherBlock. The type of FisherBlock
+        # could also change here.
+        logging.warning(
+            "Replacing existing FisherBlock for key {} with new FisherBlock "
+            "for key {}. {} registered minibatches from the existing "
+            "FisherBlock will not be migrated.".format(
+                key, layer_key,
+                self.fisher_blocks[key].num_registered_minibatches))
         self.fisher_blocks.pop(key)
+        self.fisher_blocks[layer_key] = fisher_block
+        assert result_key is None
+        result_key = layer_key
+      elif not in_new_only and not in_existing_only:
+        # Existing and new are identical. Reuse the old FisherBlock.
+        #
+        # TODO(b/68715045): This is dangerous. If the new FisherBlock has
+        # existing registered minibatches, they will not be migrated to the
+        # existing FisherBlock.
+        assert result_key is None
+        result_key = key
+      else:
+        raise ValueError("Unexpected layer key conflict: {} vs. {}".format(
+            layer_key, key))
 
-    self.fisher_blocks[layer_key] = fisher_block
+    return self.fisher_blocks[result_key]
 
   def _register_block_with_nonsequence_key(self, layer_key, fisher_block):
     """Validates and registers the layer_key if it's not a sequence."""
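The rewritten sequence-key logic resolves each overlap into one of four cases: keys that overlap without containment are an error; a new key that is a subset of an existing tuple reuses the existing block; a new key that is a superset replaces the old block (dropping its registered minibatches, per the TODO); identical keys reuse the old block. The method now always returns the surviving block. An illustration of the subset/superset cases (a sketch, with imports assumed as before):

```python
import tensorflow as tf
from tensorflow.contrib.kfac.python.ops import fisher_blocks as fb
from tensorflow.contrib.kfac.python.ops import layer_collection

lc = layer_collection.LayerCollection()
w = tf.get_variable('w', shape=[2])
b = tf.get_variable('b', shape=[2])

small = lc.register_block((w,), fb.FullFB(lc, (w,)), reuse=False)

# (w, b) is a strict superset of (w,): the old block is popped and replaced,
# with a warning that its registered minibatches are not migrated.
big = lc.register_block((w, b), fb.FullFB(lc, (w, b)), reuse=False)

# Registering the subset (w,) again logs a warning and returns the existing
# superset block instead of inserting a new one.
same = lc.register_block((w,), fb.FullFB(lc, (w,)), reuse=False)
assert same is big
```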
@@ -209,6 +263,8 @@ class LayerCollection(object):
             "variable ({}) but a containing tuple was already "
             "registered ({}), skipping".format(layer_key, inclusions))
 
+    return fisher_block
+
   def _equal_or_subset(self, elt1, elt2):
     """Checks if the elements are equal or one is contained in the other."""
     return (elt1 == elt2 or (isinstance(elt1,
@@ -230,10 +286,6 @@ class LayerCollection(object):
   def get_factors(self):
     return self.fisher_factors.values()
 
-  @property
-  def generic_registrations(self):
-    return self._generic_registrations
-
   @property
   def graph(self):
     return self._graph
@@ -291,24 +343,7 @@ class LayerCollection(object):
     block_type = approx_to_block_types[approx]
     has_bias = isinstance(params, (tuple, list))
 
-    if reuse == VARIABLE_SCOPE:
-      reuse = variable_scope.get_variable_scope().reuse
-
-    if reuse:
-      block = self.fisher_blocks.get(params, None)
-      if block is None:
-        raise KeyError(
-            "Reuse requested but no FisherBlock found for params {}.".format(
-                params))
-      if not isinstance(block, block_type):
-        raise ValueError(
-            "Requested block of type {} but block of type {} already exists "
-            "for params {}.".format(block_type, type(block), params))
-
-    else:
-      block = block_type(self, has_bias)
-      self.register_block(params, block)
-
+    block = self.register_block(params, block_type(self, has_bias), reuse=reuse)
     block.register_additional_minibatch(inputs, outputs)
 
   def register_conv2d(self,
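With register_block handling reuse, the layer registration methods collapse to one call plus a per-minibatch registration. Combined with tf.AUTO_REUSE, a multi-tower setup can issue the same registration in every tower and share one FisherBlock. A sketch (assuming a register_fully_connected(params, inputs, outputs, ...) method on LayerCollection, which is not shown in this diff):

```python
import tensorflow as tf
from tensorflow.contrib.kfac.python.ops import layer_collection

lc = layer_collection.LayerCollection()
w = tf.get_variable('w', shape=[5, 3])

for tower in range(2):
  # AUTO_REUSE in the surrounding variable scope makes the first call create
  # the FisherBlock and every later call reuse it.
  with tf.variable_scope('towers', reuse=tf.AUTO_REUSE):
    x = tf.random_normal([32, 5])
    y = tf.matmul(x, w)
    lc.register_fully_connected(w, x, y)  # one minibatch per tower
```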
@@ -351,42 +386,45 @@ class LayerCollection(object):
       raise ValueError("Bad value {} for approx.".format(approx))
 
     block_type = approx_to_block_types[approx]
 
-    if reuse == VARIABLE_SCOPE:
-      reuse = variable_scope.get_variable_scope().reuse
-
-    if reuse:
-      block = self.fisher_blocks.get(params, None)
-      if block is None:
-        raise KeyError(
-            "Reuse requested but no FisherBlock found for params {}.".format(
-                params))
-      if not isinstance(block, block_type):
-        raise ValueError(
-            "Requested block of type {} but block of type {} already exists "
-            "for params {}.".format(block_type, type(block), params))
-
-    else:
-      block = block_type(self, params, strides, padding)
-      self.register_block(params, block)
-
+    block = self.register_block(
+        params, block_type(self, params, strides, padding), reuse=reuse)
     block.register_additional_minibatch(inputs, outputs)
 
-  def register_generic(self, params, batch_size, approx=APPROX_DIAGONAL_NAME):
-    params = params if isinstance(params, (tuple, list)) else (params,)
-    self._generic_registrations |= set(params)
+  def register_generic(self,
+                       params,
+                       batch_size,
+                       approx=APPROX_DIAGONAL_NAME,
+                       reuse=VARIABLE_SCOPE):
+    """Registers a generic layer.
 
-    # Generic registrations do not need special registration rules because we do
-    # not care about multiple generic registrations. Add them to the
-    # fisher_block dictionary manually rather than going through the logic in
-    # self.register_block.
-    if approx == APPROX_FULL_NAME:
-      self.fisher_blocks[params] = fb.FullFB(self, params, batch_size)
-    elif approx == APPROX_DIAGONAL_NAME:
-      self.fisher_blocks[params] = fb.NaiveDiagonalFB(self, params, batch_size)
-    else:
-      raise ValueError("Bad value {} for approx.".format(approx))
+    Args:
+      params: Tensor or tuple of Tensors corresponding to the parameters of
+        this layer.
+      batch_size: 0-D Tensor. Size of the minibatch.
+      approx: str. One of APPROX_FULL_NAME or APPROX_DIAGONAL_NAME.
+      reuse: bool or str. If True, reuse an existing FisherBlock. If False,
+        create a new FisherBlock. If VARIABLE_SCOPE, use
+        tf.get_variable_scope().reuse.
+
+    Raises:
+      ValueError: For improper value to 'approx'.
+      KeyError: If reuse == True but no FisherBlock found for 'params'.
+      ValueError: If reuse == True and FisherBlock found but of the wrong type.
+    """
+    approx_to_block_types = {
+        APPROX_FULL_NAME: fb.FullFB,
+        APPROX_DIAGONAL_NAME: fb.NaiveDiagonalFB,
+    }
+
+    if approx not in approx_to_block_types:
+      raise ValueError("Bad value {} for approx.".format(approx))
+
+    block_type = approx_to_block_types[approx]
+    block = self.register_block(params, block_type(self, params), reuse=reuse)
+    block.register_additional_minibatch(batch_size)
 
   def register_categorical_predictive_distribution(self,
                                                    logits,
                                                    seed=None,
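register_generic now routes through register_block like the other registration methods, so generic (full or diagonal) blocks support multiple towers too: each call registers another batch size, and the block's effective batch size is the sum. A closing sketch (imports assumed as before):

```python
import tensorflow as tf
from tensorflow.contrib.kfac.python.ops import layer_collection

lc = layer_collection.LayerCollection()
v = tf.get_variable('v', shape=[10])

# Tower 0 creates a NaiveDiagonalFB; tower 1 reuses it explicitly.
lc.register_generic(v, batch_size=32, reuse=False)
lc.register_generic(v, batch_size=32, reuse=True)
# The shared block now reports num_registered_minibatches == 2 and uses
# a total batch size of 64 in its covariance estimate.
```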
|
Loading…
x
Reference in New Issue
Block a user