K-FAC: Support for tf.AUTO_REUSE when re-using registrations. Multi-tower support for FullFB, NaiveDiagonalFB. Removal of LayerCollection.generic_registrations.
PiperOrigin-RevId: 174092003
This commit is contained in:
		
							parent
							
								
									0a7be5a2f5
								
							
						
					
					
						commit
						453dd5848f
					
				@ -46,7 +46,8 @@ class FullFBTest(test.TestCase):
 | 
			
		||||
    with ops.Graph().as_default():
 | 
			
		||||
      random_seed.set_random_seed(200)
 | 
			
		||||
      params = (array_ops.constant([1., 2.]), array_ops.constant(3.))
 | 
			
		||||
      block = fb.FullFB(lc.LayerCollection(), params, 32)
 | 
			
		||||
      block = fb.FullFB(lc.LayerCollection(), params)
 | 
			
		||||
      block.register_additional_minibatch(32)
 | 
			
		||||
 | 
			
		||||
      self.assertAllEqual(params, block.tensors_to_compute_grads())
 | 
			
		||||
 | 
			
		||||
@ -54,7 +55,8 @@ class FullFBTest(test.TestCase):
 | 
			
		||||
    with ops.Graph().as_default():
 | 
			
		||||
      random_seed.set_random_seed(200)
 | 
			
		||||
      params = (array_ops.constant([1., 2.]), array_ops.constant(3.))
 | 
			
		||||
      block = fb.FullFB(lc.LayerCollection(), params, 32)
 | 
			
		||||
      block = fb.FullFB(lc.LayerCollection(), params)
 | 
			
		||||
      block.register_additional_minibatch(32)
 | 
			
		||||
 | 
			
		||||
      self.assertAllEqual(params, block.tensors_to_compute_grads())
 | 
			
		||||
 | 
			
		||||
@ -62,7 +64,8 @@ class FullFBTest(test.TestCase):
 | 
			
		||||
    with ops.Graph().as_default():
 | 
			
		||||
      random_seed.set_random_seed(200)
 | 
			
		||||
      params = (array_ops.constant([1., 2.]), array_ops.constant(3.))
 | 
			
		||||
      block = fb.FullFB(lc.LayerCollection(), params, 32)
 | 
			
		||||
      block = fb.FullFB(lc.LayerCollection(), params)
 | 
			
		||||
      block.register_additional_minibatch(32)
 | 
			
		||||
 | 
			
		||||
      grads = (params[0]**2, math_ops.sqrt(params[1]))
 | 
			
		||||
      block.instantiate_factors(grads, 0.5)
 | 
			
		||||
@ -71,7 +74,8 @@ class FullFBTest(test.TestCase):
 | 
			
		||||
    with ops.Graph().as_default(), self.test_session() as sess:
 | 
			
		||||
      random_seed.set_random_seed(200)
 | 
			
		||||
      params = (array_ops.constant([1., 2.]), array_ops.constant(3.))
 | 
			
		||||
      block = fb.FullFB(lc.LayerCollection(), params, 32)
 | 
			
		||||
      block = fb.FullFB(lc.LayerCollection(), params)
 | 
			
		||||
      block.register_additional_minibatch(32)
 | 
			
		||||
      grads = (params[0]**2, math_ops.sqrt(params[1]))
 | 
			
		||||
      block.instantiate_factors((grads,), 0.5)
 | 
			
		||||
 | 
			
		||||
@ -88,7 +92,8 @@ class FullFBTest(test.TestCase):
 | 
			
		||||
    with ops.Graph().as_default(), self.test_session() as sess:
 | 
			
		||||
      random_seed.set_random_seed(200)
 | 
			
		||||
      params = array_ops.constant([[1.], [2.]])
 | 
			
		||||
      block = fb.FullFB(lc.LayerCollection(), params, 32)
 | 
			
		||||
      block = fb.FullFB(lc.LayerCollection(), params)
 | 
			
		||||
      block.register_additional_minibatch(32)
 | 
			
		||||
      grads = params**2
 | 
			
		||||
      block.instantiate_factors((grads,), 0.5)
 | 
			
		||||
 | 
			
		||||
@ -105,7 +110,8 @@ class FullFBTest(test.TestCase):
 | 
			
		||||
    with ops.Graph().as_default(), self.test_session() as sess:
 | 
			
		||||
      random_seed.set_random_seed(200)
 | 
			
		||||
      params = (array_ops.constant([1., 2.]), array_ops.constant(3.))
 | 
			
		||||
      block = fb.FullFB(lc.LayerCollection(), params, 32)
 | 
			
		||||
      block = fb.FullFB(lc.LayerCollection(), params)
 | 
			
		||||
      block.register_additional_minibatch(32)
 | 
			
		||||
      grads = (array_ops.constant([2., 3.]), array_ops.constant(4.))
 | 
			
		||||
      damping = 0.5
 | 
			
		||||
      block.instantiate_factors((grads,), damping)
 | 
			
		||||
@ -131,7 +137,8 @@ class NaiveDiagonalFBTest(test.TestCase):
 | 
			
		||||
    with ops.Graph().as_default():
 | 
			
		||||
      random_seed.set_random_seed(200)
 | 
			
		||||
      params = (array_ops.constant([1., 2.]), array_ops.constant(3.))
 | 
			
		||||
      block = fb.NaiveDiagonalFB(lc.LayerCollection(), params, 32)
 | 
			
		||||
      block = fb.NaiveDiagonalFB(lc.LayerCollection(), params)
 | 
			
		||||
      block.register_additional_minibatch(32)
 | 
			
		||||
 | 
			
		||||
      self.assertAllEqual(params, block.tensors_to_compute_grads())
 | 
			
		||||
 | 
			
		||||
@ -139,7 +146,8 @@ class NaiveDiagonalFBTest(test.TestCase):
 | 
			
		||||
    with ops.Graph().as_default():
 | 
			
		||||
      random_seed.set_random_seed(200)
 | 
			
		||||
      params = (array_ops.constant([1., 2.]), array_ops.constant(3.))
 | 
			
		||||
      block = fb.NaiveDiagonalFB(lc.LayerCollection(), params, 32)
 | 
			
		||||
      block = fb.NaiveDiagonalFB(lc.LayerCollection(), params)
 | 
			
		||||
      block.register_additional_minibatch(32)
 | 
			
		||||
 | 
			
		||||
      self.assertAllEqual(params, block.tensors_to_compute_grads())
 | 
			
		||||
 | 
			
		||||
@ -147,7 +155,8 @@ class NaiveDiagonalFBTest(test.TestCase):
 | 
			
		||||
    with ops.Graph().as_default():
 | 
			
		||||
      random_seed.set_random_seed(200)
 | 
			
		||||
      params = (array_ops.constant([1., 2.]), array_ops.constant(3.))
 | 
			
		||||
      block = fb.NaiveDiagonalFB(lc.LayerCollection(), params, 32)
 | 
			
		||||
      block = fb.NaiveDiagonalFB(lc.LayerCollection(), params)
 | 
			
		||||
      block.register_additional_minibatch(32)
 | 
			
		||||
 | 
			
		||||
      grads = (params[0]**2, math_ops.sqrt(params[1]))
 | 
			
		||||
      block.instantiate_factors(grads, 0.5)
 | 
			
		||||
@ -156,7 +165,8 @@ class NaiveDiagonalFBTest(test.TestCase):
 | 
			
		||||
    with ops.Graph().as_default(), self.test_session() as sess:
 | 
			
		||||
      random_seed.set_random_seed(200)
 | 
			
		||||
      params = (array_ops.constant([1., 2.]), array_ops.constant(3.))
 | 
			
		||||
      block = fb.NaiveDiagonalFB(lc.LayerCollection(), params, 32)
 | 
			
		||||
      block = fb.NaiveDiagonalFB(lc.LayerCollection(), params)
 | 
			
		||||
      block.register_additional_minibatch(32)
 | 
			
		||||
      grads = (params[0]**2, math_ops.sqrt(params[1]))
 | 
			
		||||
      block.instantiate_factors((grads,), 0.5)
 | 
			
		||||
 | 
			
		||||
@ -173,7 +183,8 @@ class NaiveDiagonalFBTest(test.TestCase):
 | 
			
		||||
    with ops.Graph().as_default(), self.test_session() as sess:
 | 
			
		||||
      random_seed.set_random_seed(200)
 | 
			
		||||
      params = array_ops.constant([[1.], [2.]])
 | 
			
		||||
      block = fb.NaiveDiagonalFB(lc.LayerCollection(), params, 32)
 | 
			
		||||
      block = fb.NaiveDiagonalFB(lc.LayerCollection(), params)
 | 
			
		||||
      block.register_additional_minibatch(32)
 | 
			
		||||
      grads = params**2
 | 
			
		||||
      block.instantiate_factors((grads,), 0.5)
 | 
			
		||||
 | 
			
		||||
@ -189,7 +200,8 @@ class NaiveDiagonalFBTest(test.TestCase):
 | 
			
		||||
    with ops.Graph().as_default(), self.test_session() as sess:
 | 
			
		||||
      random_seed.set_random_seed(200)
 | 
			
		||||
      params = (array_ops.constant([1., 2.]), array_ops.constant(3.))
 | 
			
		||||
      block = fb.NaiveDiagonalFB(lc.LayerCollection(), params, 32)
 | 
			
		||||
      block = fb.NaiveDiagonalFB(lc.LayerCollection(), params)
 | 
			
		||||
      block.register_additional_minibatch(32)
 | 
			
		||||
      grads = (params[0]**2, math_ops.sqrt(params[1]))
 | 
			
		||||
      damping = 0.5
 | 
			
		||||
      block.instantiate_factors((grads,), damping)
 | 
			
		||||
 | 
			
		||||
@ -30,6 +30,21 @@ from tensorflow.python.ops import variable_scope
 | 
			
		||||
from tensorflow.python.platform import test
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class MockFisherBlock(object):
 | 
			
		||||
  """A fake FisherBlock."""
 | 
			
		||||
 | 
			
		||||
  num_registered_minibatches = 2
 | 
			
		||||
 | 
			
		||||
  def __init__(self, name='MockFisherBlock'):
 | 
			
		||||
    self.name = name
 | 
			
		||||
 | 
			
		||||
  def __eq__(self, other):
 | 
			
		||||
    return isinstance(other, MockFisherBlock) and other.name == self.name
 | 
			
		||||
 | 
			
		||||
  def __hash__(self):
 | 
			
		||||
    return hash(self.name)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class LayerParametersDictTest(test.TestCase):
 | 
			
		||||
 | 
			
		||||
  def testSetItem(self):
 | 
			
		||||
@ -172,10 +187,12 @@ class LayerCollectionTest(test.TestCase):
 | 
			
		||||
    y = variable_scope.get_variable('y', initializer=array_ops.constant(1,))
 | 
			
		||||
    z = variable_scope.get_variable('z', initializer=array_ops.constant(1,))
 | 
			
		||||
    lc = layer_collection.LayerCollection()
 | 
			
		||||
    lc.fisher_blocks = {x: '1', z: '2'}
 | 
			
		||||
    lc.fisher_blocks = {x: MockFisherBlock('1'), z: MockFisherBlock('2')}
 | 
			
		||||
 | 
			
		||||
    lc.register_block((x, y), 'foo')
 | 
			
		||||
    self.assertEqual(set(['2', 'foo']), set(lc.get_blocks()))
 | 
			
		||||
    lc.register_block((x, y), MockFisherBlock('foo'))
 | 
			
		||||
    self.assertEqual(
 | 
			
		||||
        set([MockFisherBlock('2'), MockFisherBlock('foo')]),
 | 
			
		||||
        set(lc.get_blocks()))
 | 
			
		||||
 | 
			
		||||
  def testRegisterTupleVarSomeRegisteredInOtherTuples(self):
 | 
			
		||||
    x = variable_scope.get_variable('x', initializer=array_ops.constant(1,))
 | 
			
		||||
@ -438,11 +455,6 @@ class LayerCollectionTest(test.TestCase):
 | 
			
		||||
 | 
			
		||||
  def testGetUseCountMap(self):
 | 
			
		||||
    """Ensure get_use_count_map() sums 'num_registered_minibatches'."""
 | 
			
		||||
 | 
			
		||||
    class MockFisherBlock(object):
 | 
			
		||||
 | 
			
		||||
      num_registered_minibatches = 2
 | 
			
		||||
 | 
			
		||||
    lc = layer_collection.LayerCollection()
 | 
			
		||||
    lc.fisher_blocks = {
 | 
			
		||||
        'a': MockFisherBlock(),
 | 
			
		||||
 | 
			
		||||
@ -133,16 +133,15 @@ class FullFB(FisherBlock):
 | 
			
		||||
  to any type of parameter in principle, but has very high variance.
 | 
			
		||||
  """
 | 
			
		||||
 | 
			
		||||
  def __init__(self, layer_collection, params, batch_size):
 | 
			
		||||
  def __init__(self, layer_collection, params):
 | 
			
		||||
    """Creates a FullFB block.
 | 
			
		||||
 | 
			
		||||
    Args:
 | 
			
		||||
      layer_collection: The collection of all layers in the K-FAC approximate
 | 
			
		||||
          Fisher information matrix to which this FisherBlock belongs.
 | 
			
		||||
      params: The parameters of this layer (Tensor or tuple of Tensors).
 | 
			
		||||
      batch_size: The batch size, used in the covariance estimator.
 | 
			
		||||
    """
 | 
			
		||||
    self._batch_size = batch_size
 | 
			
		||||
    self._batch_sizes = []
 | 
			
		||||
    self._params = params
 | 
			
		||||
 | 
			
		||||
    super(FullFB, self).__init__(layer_collection)
 | 
			
		||||
@ -172,9 +171,21 @@ class FullFB(FisherBlock):
 | 
			
		||||
  def tensors_to_compute_grads(self):
 | 
			
		||||
    return self._params
 | 
			
		||||
 | 
			
		||||
  def register_additional_minibatch(self, batch_size):
 | 
			
		||||
    """Register an additional minibatch.
 | 
			
		||||
 | 
			
		||||
    Args:
 | 
			
		||||
      batch_size: The batch size, used in the covariance estimator.
 | 
			
		||||
    """
 | 
			
		||||
    self._batch_sizes.append(batch_size)
 | 
			
		||||
 | 
			
		||||
  @property
 | 
			
		||||
  def num_registered_minibatches(self):
 | 
			
		||||
    return 1  # Multiple minibatches not supported.
 | 
			
		||||
    return len(self._batch_sizes)
 | 
			
		||||
 | 
			
		||||
  @property
 | 
			
		||||
  def _batch_size(self):
 | 
			
		||||
    return math_ops.reduce_sum(self._batch_sizes)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class NaiveDiagonalFB(FisherBlock):
 | 
			
		||||
@ -186,17 +197,16 @@ class NaiveDiagonalFB(FisherBlock):
 | 
			
		||||
  to any type of parameter in principle, but has very high variance.
 | 
			
		||||
  """
 | 
			
		||||
 | 
			
		||||
  def __init__(self, layer_collection, params, batch_size):
 | 
			
		||||
  def __init__(self, layer_collection, params):
 | 
			
		||||
    """Creates a NaiveDiagonalFB block.
 | 
			
		||||
 | 
			
		||||
    Args:
 | 
			
		||||
      layer_collection: The collection of all layers in the K-FAC approximate
 | 
			
		||||
          Fisher information matrix to which this FisherBlock belongs.
 | 
			
		||||
      params: The parameters of this layer (Tensor or tuple of Tensors).
 | 
			
		||||
      batch_size: The batch size, used in the covariance estimator.
 | 
			
		||||
    """
 | 
			
		||||
    self._params = params
 | 
			
		||||
    self._batch_size = batch_size
 | 
			
		||||
    self._batch_sizes = []
 | 
			
		||||
 | 
			
		||||
    super(NaiveDiagonalFB, self).__init__(layer_collection)
 | 
			
		||||
 | 
			
		||||
@ -221,9 +231,21 @@ class NaiveDiagonalFB(FisherBlock):
 | 
			
		||||
  def tensors_to_compute_grads(self):
 | 
			
		||||
    return self._params
 | 
			
		||||
 | 
			
		||||
  def register_additional_minibatch(self, batch_size):
 | 
			
		||||
    """Register an additional minibatch.
 | 
			
		||||
 | 
			
		||||
    Args:
 | 
			
		||||
      batch_size: The batch size, used in the covariance estimator.
 | 
			
		||||
    """
 | 
			
		||||
    self._batch_sizes.append(batch_size)
 | 
			
		||||
 | 
			
		||||
  @property
 | 
			
		||||
  def num_registered_minibatches(self):
 | 
			
		||||
    return 1  # Multiple minibatches not supported.
 | 
			
		||||
    return len(self._batch_sizes)
 | 
			
		||||
 | 
			
		||||
  @property
 | 
			
		||||
  def _batch_size(self):
 | 
			
		||||
    return math_ops.reduce_sum(self._batch_sizes)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class FullyConnectedDiagonalFB(FisherBlock):
 | 
			
		||||
 | 
			
		||||
@ -103,10 +103,6 @@ class LayerCollection(object):
 | 
			
		||||
    fisher_blocks: a LayersParamsDict (subclass of OrderedDict) mapping layer
 | 
			
		||||
        parameters (Tensors or tuples of Tensors) to FisherBlock instances.
 | 
			
		||||
    fisher_factors: an OrderedDict mapping tuples to FisherFactor instances.
 | 
			
		||||
    generic_registrations: a list of variables registered via a generic layer
 | 
			
		||||
        registration. Generic registrations handle any and all of the ways a
 | 
			
		||||
        variable is used in the graph, which means we don't need to check
 | 
			
		||||
        their registration when verifying the correctness of the graph.
 | 
			
		||||
    losses: a list of LossFunction objects. The loss to be optimized is their
 | 
			
		||||
        sum.
 | 
			
		||||
  """
 | 
			
		||||
@ -114,7 +110,6 @@ class LayerCollection(object):
 | 
			
		||||
  def __init__(self, graph=None, name="LayerCollection"):
 | 
			
		||||
    self.fisher_blocks = LayerParametersDict()
 | 
			
		||||
    self.fisher_factors = OrderedDict()
 | 
			
		||||
    self._generic_registrations = set()
 | 
			
		||||
    self._graph = graph or ops.get_default_graph()
 | 
			
		||||
    self._loss_dict = {}  # {str: LossFunction}
 | 
			
		||||
    self._subgraph = None
 | 
			
		||||
@ -127,7 +122,7 @@ class LayerCollection(object):
 | 
			
		||||
    """LossFunctions registered with this LayerCollection."""
 | 
			
		||||
    return list(self._loss_dict.values())
 | 
			
		||||
 | 
			
		||||
  def register_block(self, layer_key, fisher_block):
 | 
			
		||||
  def register_block(self, layer_key, fisher_block, reuse=VARIABLE_SCOPE):
 | 
			
		||||
    """Validates and registers the layer_key associated with the fisher_block.
 | 
			
		||||
 | 
			
		||||
    Validation consists of checking whether the key was already registered or
 | 
			
		||||
@ -153,20 +148,43 @@ class LayerCollection(object):
 | 
			
		||||
      layer_key: The key to check for in existing registrations and to register
 | 
			
		||||
          if valid.
 | 
			
		||||
      fisher_block: The associated fisher block.
 | 
			
		||||
      reuse: Method to use for inserting new FisherBlocks. One of True, False,
 | 
			
		||||
        or VARIABLE_SCOPE.
 | 
			
		||||
 | 
			
		||||
    Raises:
 | 
			
		||||
      ValueError: If the layer_key was already registered, or if a subset of the
 | 
			
		||||
          layer_key has already been registered as part of a different tuple.
 | 
			
		||||
 | 
			
		||||
    Returns:
 | 
			
		||||
      FisherBlock registered under 'layer_key'. May or may not be the same as
 | 
			
		||||
      'fisher_block'.
 | 
			
		||||
    """
 | 
			
		||||
    if reuse is VARIABLE_SCOPE:
 | 
			
		||||
      reuse = variable_scope.get_variable_scope().reuse
 | 
			
		||||
 | 
			
		||||
    if reuse is True or (reuse is variable_scope.AUTO_REUSE and
 | 
			
		||||
                         layer_key in self.fisher_blocks):
 | 
			
		||||
      result = self.fisher_blocks[layer_key]
 | 
			
		||||
      if type(result) != type(fisher_block):  # pylint: disable=unidiomatic-typecheck
 | 
			
		||||
        raise ValueError(
 | 
			
		||||
            "Attempted to register FisherBlock of type %s when existing "
 | 
			
		||||
            "FisherBlock has type %s." % (type(fisher_block), type(result)))
 | 
			
		||||
      return result
 | 
			
		||||
    if reuse is False and layer_key in self.fisher_blocks:
 | 
			
		||||
      raise ValueError("FisherBlock for %s is already in LayerCollection." %
 | 
			
		||||
                       (layer_key,))
 | 
			
		||||
 | 
			
		||||
    # Insert fisher_block into self.fisher_blocks.
 | 
			
		||||
    if layer_key in self.fisher_blocks:
 | 
			
		||||
      raise ValueError("Duplicate registration: {}".format(layer_key))
 | 
			
		||||
    if isinstance(layer_key, (tuple, list)):
 | 
			
		||||
      self._register_block_with_sequence_key(layer_key, fisher_block)
 | 
			
		||||
      return self._register_block_with_sequence_key(layer_key, fisher_block)
 | 
			
		||||
    else:
 | 
			
		||||
      self._register_block_with_nonsequence_key(layer_key, fisher_block)
 | 
			
		||||
      return self._register_block_with_nonsequence_key(layer_key, fisher_block)
 | 
			
		||||
 | 
			
		||||
  def _register_block_with_sequence_key(self, layer_key, fisher_block):
 | 
			
		||||
    """Validates and registers the layer_key if it's a sequence."""
 | 
			
		||||
    # Find all keys that are either supersets or subsets of 'layer_key'.
 | 
			
		||||
    inclusions = {
 | 
			
		||||
        fisher_elt
 | 
			
		||||
        for layer_elt in layer_key for fisher_elt in self.fisher_blocks
 | 
			
		||||
@ -175,24 +193,60 @@ class LayerCollection(object):
 | 
			
		||||
 | 
			
		||||
    if not inclusions:
 | 
			
		||||
      self.fisher_blocks[layer_key] = fisher_block
 | 
			
		||||
      return
 | 
			
		||||
      return fisher_block
 | 
			
		||||
 | 
			
		||||
    result_key = None
 | 
			
		||||
    for key in inclusions:
 | 
			
		||||
      fisher_block_key = key if isinstance(key, (tuple, list)) else (key,)
 | 
			
		||||
      if set(layer_key).issubset(fisher_block_key):
 | 
			
		||||
        logging.warning("Graph Registration Warning: tried to register "
 | 
			
		||||
                        "a subset ({}) of an already registered tuple "
 | 
			
		||||
                        "({}), skipping".format(layer_key, fisher_block_key))
 | 
			
		||||
        return
 | 
			
		||||
      if not set(fisher_block_key).issubset(layer_key):
 | 
			
		||||
      in_existing_only = set(fisher_block_key) - set(layer_key)
 | 
			
		||||
      in_new_only = set(layer_key) - set(fisher_block_key)
 | 
			
		||||
 | 
			
		||||
      if in_existing_only and in_new_only:
 | 
			
		||||
        # Existing and new key have an intersection but neither is a subset of
 | 
			
		||||
        # the other. This is an error.
 | 
			
		||||
        raise ValueError(
 | 
			
		||||
            "Inconsistent registration, expected new key to be a subset or "
 | 
			
		||||
            "superset of the existing key: existing is {}, new is {}".format(
 | 
			
		||||
                key, layer_key))
 | 
			
		||||
      else:
 | 
			
		||||
      elif in_existing_only and not in_new_only:
 | 
			
		||||
        # Existing key is strict superset of new key. Return existing
 | 
			
		||||
        # FisherBlock.
 | 
			
		||||
        logging.warning("Graph Registration Warning: tried to register "
 | 
			
		||||
                        "a subset ({}) of an already registered tuple "
 | 
			
		||||
                        "({}), skipping".format(layer_key, fisher_block_key))
 | 
			
		||||
        assert result_key is None
 | 
			
		||||
        result_key = key
 | 
			
		||||
      elif in_new_only and not in_existing_only:
 | 
			
		||||
        # Existing key is a strict subset of new key. Replace existing
 | 
			
		||||
        # FisherBlock with new one.
 | 
			
		||||
        #
 | 
			
		||||
        # TODO(b/68715045): This is dangerous. If there are existing
 | 
			
		||||
        # registrations for a minibatch from elsewhere in the graph, they won't
 | 
			
		||||
        # be re-registered with this new FisherBlock. The type of FisherBlock
 | 
			
		||||
        # could also change here.
 | 
			
		||||
        logging.warning(
 | 
			
		||||
            "Replacing existing FisherBlock for key {} with new FisherBlock "
 | 
			
		||||
            "for key {}. {} registered minibatches from the existing "
 | 
			
		||||
            "FisherBlock will not be migrated.".format(
 | 
			
		||||
                key, layer_key,
 | 
			
		||||
                self.fisher_blocks[key].num_registered_minibatches))
 | 
			
		||||
        self.fisher_blocks.pop(key)
 | 
			
		||||
        self.fisher_blocks[layer_key] = fisher_block
 | 
			
		||||
        assert result_key is None
 | 
			
		||||
        result_key = layer_key
 | 
			
		||||
      elif not in_new_only and not in_existing_only:
 | 
			
		||||
        # Existing and new are identical. Reuse the old FisherBlock.
 | 
			
		||||
        #
 | 
			
		||||
        # TODO(b/68715045): This is dangerous. If the new FisherBlock has
 | 
			
		||||
        # existing registered minibatches, they will not be migrated to the
 | 
			
		||||
        # existing FisherBlock.
 | 
			
		||||
        assert result_key is None
 | 
			
		||||
        result_key = key
 | 
			
		||||
      else:
 | 
			
		||||
        raise ValueError("Unexpected layer key conflict: {} vs. {}".format(
 | 
			
		||||
            layer_key, key))
 | 
			
		||||
 | 
			
		||||
    self.fisher_blocks[layer_key] = fisher_block
 | 
			
		||||
    return self.fisher_blocks[result_key]
 | 
			
		||||
 | 
			
		||||
  def _register_block_with_nonsequence_key(self, layer_key, fisher_block):
 | 
			
		||||
    """Validates and registers the layer_key if it's not a sequence."""
 | 
			
		||||
@ -209,6 +263,8 @@ class LayerCollection(object):
 | 
			
		||||
                      "variable ({}) but a containing tuple was already "
 | 
			
		||||
                      "registered ({}), skipping".format(layer_key, inclusions))
 | 
			
		||||
 | 
			
		||||
    return fisher_block
 | 
			
		||||
 | 
			
		||||
  def _equal_or_subset(self, elt1, elt2):
 | 
			
		||||
    """Checks if the elements are equal or one is contained in the other."""
 | 
			
		||||
    return (elt1 == elt2 or (isinstance(elt1,
 | 
			
		||||
@ -230,10 +286,6 @@ class LayerCollection(object):
 | 
			
		||||
  def get_factors(self):
 | 
			
		||||
    return self.fisher_factors.values()
 | 
			
		||||
 | 
			
		||||
  @property
 | 
			
		||||
  def generic_registrations(self):
 | 
			
		||||
    return self._generic_registrations
 | 
			
		||||
 | 
			
		||||
  @property
 | 
			
		||||
  def graph(self):
 | 
			
		||||
    return self._graph
 | 
			
		||||
@ -291,24 +343,7 @@ class LayerCollection(object):
 | 
			
		||||
    block_type = approx_to_block_types[approx]
 | 
			
		||||
    has_bias = isinstance(params, (tuple, list))
 | 
			
		||||
 | 
			
		||||
    if reuse == VARIABLE_SCOPE:
 | 
			
		||||
      reuse = variable_scope.get_variable_scope().reuse
 | 
			
		||||
 | 
			
		||||
    if reuse:
 | 
			
		||||
      block = self.fisher_blocks.get(params, None)
 | 
			
		||||
      if block is None:
 | 
			
		||||
        raise KeyError(
 | 
			
		||||
            "Reuse requested but no FisherBlock found for params {}.".format(
 | 
			
		||||
                params))
 | 
			
		||||
      if not isinstance(block, block_type):
 | 
			
		||||
        raise ValueError(
 | 
			
		||||
            "Requested block of type {} but block of type {} already exists "
 | 
			
		||||
            "for params {}.".format(block_type, type(block), params))
 | 
			
		||||
 | 
			
		||||
    else:
 | 
			
		||||
      block = block_type(self, has_bias)
 | 
			
		||||
      self.register_block(params, block)
 | 
			
		||||
 | 
			
		||||
    block = self.register_block(params, block_type(self, has_bias), reuse=reuse)
 | 
			
		||||
    block.register_additional_minibatch(inputs, outputs)
 | 
			
		||||
 | 
			
		||||
  def register_conv2d(self,
 | 
			
		||||
@ -351,42 +386,45 @@ class LayerCollection(object):
 | 
			
		||||
      raise ValueError("Bad value {} for approx.".format(approx))
 | 
			
		||||
 | 
			
		||||
    block_type = approx_to_block_types[approx]
 | 
			
		||||
 | 
			
		||||
    if reuse == VARIABLE_SCOPE:
 | 
			
		||||
      reuse = variable_scope.get_variable_scope().reuse
 | 
			
		||||
 | 
			
		||||
    if reuse:
 | 
			
		||||
      block = self.fisher_blocks.get(params, None)
 | 
			
		||||
      if block is None:
 | 
			
		||||
        raise KeyError(
 | 
			
		||||
            "Reuse requested but no FisherBlock found for params {}.".format(
 | 
			
		||||
                params))
 | 
			
		||||
      if not isinstance(block, block_type):
 | 
			
		||||
        raise ValueError(
 | 
			
		||||
            "Requested block of type {} but block of type {} already exists "
 | 
			
		||||
            "for params {}.".format(block_type, type(block), params))
 | 
			
		||||
 | 
			
		||||
    else:
 | 
			
		||||
      block = block_type(self, params, strides, padding)
 | 
			
		||||
      self.register_block(params, block)
 | 
			
		||||
 | 
			
		||||
    block = self.register_block(
 | 
			
		||||
        params, block_type(self, params, strides, padding), reuse=reuse)
 | 
			
		||||
    block.register_additional_minibatch(inputs, outputs)
 | 
			
		||||
 | 
			
		||||
  def register_generic(self, params, batch_size, approx=APPROX_DIAGONAL_NAME):
 | 
			
		||||
    params = params if isinstance(params, (tuple, list)) else (params,)
 | 
			
		||||
    self._generic_registrations |= set(params)
 | 
			
		||||
  def register_generic(self,
 | 
			
		||||
                       params,
 | 
			
		||||
                       batch_size,
 | 
			
		||||
                       approx=APPROX_DIAGONAL_NAME,
 | 
			
		||||
                       reuse=VARIABLE_SCOPE):
 | 
			
		||||
    """Registers a generic layer.
 | 
			
		||||
 | 
			
		||||
    # Generic registrations do not need special registration rules because we do
 | 
			
		||||
    # not care about multiple generic registrations. Add them to the
 | 
			
		||||
    # fisher_block dictionary manually rather than going through the logic in
 | 
			
		||||
    # self.register_block.
 | 
			
		||||
    if approx == APPROX_FULL_NAME:
 | 
			
		||||
      self.fisher_blocks[params] = fb.FullFB(self, params, batch_size)
 | 
			
		||||
    elif approx == APPROX_DIAGONAL_NAME:
 | 
			
		||||
      self.fisher_blocks[params] = fb.NaiveDiagonalFB(self, params, batch_size)
 | 
			
		||||
    else:
 | 
			
		||||
    Args:
 | 
			
		||||
      params: Tensor or 2-tuple of Tensors corresponding to weight and bias of
 | 
			
		||||
        this layer. Weight matrix should have shape [kernel_height,
 | 
			
		||||
        kernel_width, in_channels, out_channels].  Bias should have shape
 | 
			
		||||
        [out_channels].
 | 
			
		||||
      batch_size: 0-D Tensor. Size of the minibatch.
 | 
			
		||||
      approx: str. One of APPROX_KRONECKER_NAME or APPROX_DIAGONAL_NAME.
 | 
			
		||||
      reuse: bool or str.  If True, reuse an existing FisherBlock. If False,
 | 
			
		||||
        create a new FisherBlock.  If VARIABLE_SCOPE, use
 | 
			
		||||
        tf.get_variable_scope().reuse.
 | 
			
		||||
 | 
			
		||||
    Raises:
 | 
			
		||||
      ValueError: For improper value to 'approx'.
 | 
			
		||||
      KeyError: If reuse == True but no FisherBlock found for 'params'.
 | 
			
		||||
      ValueError: If reuse == True and FisherBlock found but of the wrong type.
 | 
			
		||||
    """
 | 
			
		||||
    approx_to_block_types = {
 | 
			
		||||
        APPROX_FULL_NAME: fb.FullFB,
 | 
			
		||||
        APPROX_DIAGONAL_NAME: fb.NaiveDiagonalFB,
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    if approx not in approx_to_block_types:
 | 
			
		||||
      raise ValueError("Bad value {} for approx.".format(approx))
 | 
			
		||||
 | 
			
		||||
    block_type = approx_to_block_types[approx]
 | 
			
		||||
    block = self.register_block(params, block_type(self, params), reuse=reuse)
 | 
			
		||||
    block.register_additional_minibatch(batch_size)
 | 
			
		||||
 | 
			
		||||
  def register_categorical_predictive_distribution(self,
 | 
			
		||||
                                                   logits,
 | 
			
		||||
                                                   seed=None,
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user