From 670dddf4ad81c67fc76b370bf7b9d77263824358 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower"
Date: Mon, 23 Oct 2017 06:53:02 -0700
Subject: [PATCH] Multi-minibatch support for
 tf.contrib.kfac.fisher_blocks.FullyConnectedKFACBasicFB.

PiperOrigin-RevId: 173109677
---
 .../python/kernel_tests/fisher_blocks_test.py | 41 ++++++++--------
 .../contrib/kfac/python/ops/fisher_blocks.py  | 48 ++++++++++++++-----
 .../contrib/kfac/python/ops/fisher_factors.py |  8 ++++
 .../kfac/python/ops/layer_collection.py       |  6 +--
 4 files changed, 69 insertions(+), 34 deletions(-)

diff --git a/tensorflow/contrib/kfac/python/kernel_tests/fisher_blocks_test.py b/tensorflow/contrib/kfac/python/kernel_tests/fisher_blocks_test.py
index 80855da2e92..85ac08a1eb7 100644
--- a/tensorflow/contrib/kfac/python/kernel_tests/fisher_blocks_test.py
+++ b/tensorflow/contrib/kfac/python/kernel_tests/fisher_blocks_test.py
@@ -356,50 +356,51 @@ class FullyConnectedKFACBasicFBTest(test.TestCase):
       random_seed.set_random_seed(200)
       inputs = array_ops.constant([1., 2.])
       outputs = array_ops.constant([3., 4.])
-      block = fb.FullyConnectedKFACBasicFB(lc.LayerCollection(), inputs,
-                                           outputs)
+      block = fb.FullyConnectedKFACBasicFB(lc.LayerCollection())
+      block.register_additional_minibatch(inputs, outputs)
 
-      self.assertAllEqual(outputs, block.tensors_to_compute_grads())
+      self.assertAllEqual([outputs], block.tensors_to_compute_grads())
 
   def testInstantiateFactorsHasBias(self):
     with ops.Graph().as_default():
       random_seed.set_random_seed(200)
       inputs = array_ops.constant([[1., 2.], [3., 4.]])
       outputs = array_ops.constant([[3., 4.], [5., 6.]])
-      block = fb.FullyConnectedKFACBasicFB(
-          lc.LayerCollection(), inputs, outputs, has_bias=True)
+      block = fb.FullyConnectedKFACBasicFB(lc.LayerCollection(), has_bias=True)
+      block.register_additional_minibatch(inputs, outputs)
 
       grads = outputs**2
-      block.instantiate_factors((grads,), 0.5)
+      block.instantiate_factors(([grads],), 0.5)
 
   def testInstantiateFactorsNoBias(self):
     with ops.Graph().as_default():
       random_seed.set_random_seed(200)
       inputs = array_ops.constant([[1., 2.], [3., 4.]])
       outputs = array_ops.constant([[3., 4.], [5., 6.]])
-      block = fb.FullyConnectedKFACBasicFB(
-          lc.LayerCollection(), inputs, outputs, has_bias=False)
+      block = fb.FullyConnectedKFACBasicFB(lc.LayerCollection(), has_bias=False)
+      block.register_additional_minibatch(inputs, outputs)
 
       grads = outputs**2
-      block.instantiate_factors((grads,), 0.5)
+      block.instantiate_factors(([grads],), 0.5)
 
   def testMultiplyInverseTuple(self):
     with ops.Graph().as_default(), self.test_session() as sess:
       random_seed.set_random_seed(200)
       inputs = array_ops.constant([[1., 2., 3.], [3., 4., 5.], [5., 6., 7.]])
       outputs = array_ops.constant([[3., 4.], [5., 6.]])
-      block = fb.FullyConnectedKFACBasicFB(
-          lc.LayerCollection(), inputs, outputs, has_bias=False)
+      block = fb.FullyConnectedKFACBasicFB(lc.LayerCollection(), has_bias=False)
+      block.register_additional_minibatch(inputs, outputs)
 
       grads = outputs**2
-      block.instantiate_factors((grads,), 0.5)
+      block.instantiate_factors(([grads],), 0.5)
 
       # Make sure our inverse is something other than the identity.
      sess.run(tf_variables.global_variables_initializer())
      sess.run(block._input_factor.make_inverse_update_ops())
      sess.run(block._output_factor.make_inverse_update_ops())
 
-      vector = (np.arange(2, 6).reshape(2, 2).astype(np.float32), np.arange(
-          1, 3).reshape(2, 1).astype(np.float32))
+      vector = (
+          np.arange(2, 6).reshape(2, 2).astype(np.float32),  #
+          np.arange(1, 3).reshape(2, 1).astype(np.float32))
       output = block.multiply_inverse((array_ops.constant(vector[0]),
                                        array_ops.constant(vector[1])))
 
@@ -413,10 +414,10 @@ class FullyConnectedKFACBasicFBTest(test.TestCase):
       random_seed.set_random_seed(200)
       inputs = array_ops.constant([[1., 2.], [3., 4.]])
       outputs = array_ops.constant([[3., 4.], [5., 6.]])
-      block = fb.FullyConnectedKFACBasicFB(
-          lc.LayerCollection(), inputs, outputs, has_bias=False)
+      block = fb.FullyConnectedKFACBasicFB(lc.LayerCollection(), has_bias=False)
+      block.register_additional_minibatch(inputs, outputs)
 
       grads = outputs**2
-      block.instantiate_factors((grads,), 0.5)
+      block.instantiate_factors(([grads],), 0.5)
 
       # Make sure our inverse is something other than the identity.
       sess.run(tf_variables.global_variables_initializer())
@@ -436,11 +437,11 @@ class FullyConnectedKFACBasicFBTest(test.TestCase):
       inputs = array_ops.zeros([32, input_dim])
       outputs = array_ops.zeros([32, output_dim])
       params = array_ops.zeros([input_dim, output_dim])
-      block = fb.FullyConnectedKFACBasicFB(
-          lc.LayerCollection(), inputs, outputs, has_bias=False)
+      block = fb.FullyConnectedKFACBasicFB(lc.LayerCollection(), has_bias=False)
+      block.register_additional_minibatch(inputs, outputs)
 
       grads = outputs**2
       damping = 0.  # This test is only valid without damping.
-      block.instantiate_factors((grads,), damping)
+      block.instantiate_factors(([grads],), damping)
 
       sess.run(state_ops.assign(block._input_factor._cov, _make_psd(3)))
       sess.run(state_ops.assign(block._output_factor._cov, _make_psd(2)))
diff --git a/tensorflow/contrib/kfac/python/ops/fisher_blocks.py b/tensorflow/contrib/kfac/python/ops/fisher_blocks.py
index 5e822b5fe32..754c2cc853b 100644
--- a/tensorflow/contrib/kfac/python/ops/fisher_blocks.py
+++ b/tensorflow/contrib/kfac/python/ops/fisher_blocks.py
@@ -367,7 +367,7 @@ class ConvDiagonalFB(FisherBlock):
         (self._strides[1] * self._strides[2]))
 
     if NORMALIZE_DAMPING_POWER:
-      damping /= self._num_locations**NORMALIZE_DAMPING_POWER
+      damping /= self._num_locations ** NORMALIZE_DAMPING_POWER
     self._damping = damping
 
     self._factor = self._layer_collection.make_or_get_factor(
@@ -478,34 +478,60 @@ class FullyConnectedKFACBasicFB(KroneckerProductFB):
   K-FAC paper (https://arxiv.org/abs/1503.05671)
   """
 
-  def __init__(self, layer_collection, inputs, outputs, has_bias=False):
+  def __init__(self, layer_collection, has_bias=False):
     """Creates a FullyConnectedKFACBasicFB block.
 
     Args:
       layer_collection: The collection of all layers in the K-FAC approximate
        Fisher information matrix to which this FisherBlock belongs.
-      inputs: The Tensor of input activations to this layer.
-      outputs: The Tensor of output pre-activations from this layer.
       has_bias: Whether the component Kronecker factors have an additive bias.
          (Default: False)
     """
-    self._inputs = inputs
-    self._outputs = outputs
+    self._inputs = []
+    self._outputs = []
     self._has_bias = has_bias
 
     super(FullyConnectedKFACBasicFB, self).__init__(layer_collection)
 
   def instantiate_factors(self, grads_list, damping):
-    self._input_factor = self._layer_collection.make_or_get_factor(
-        fisher_factors.FullyConnectedKroneckerFactor,
-        ((self._inputs,), self._has_bias))
-    self._output_factor = self._layer_collection.make_or_get_factor(
-        fisher_factors.FullyConnectedKroneckerFactor, (grads_list,))
+    """Instantiate Kronecker Factors for this FisherBlock.
+
+    Args:
+      grads_list: List of list of Tensors. grads_list[i][j] is the
+        gradient of the loss with respect to 'outputs' from source 'i' and
+        tower 'j'. Each Tensor has shape [tower_minibatch_size, output_size].
+      damping: 0-D Tensor or float. 'damping' * identity is approximately added
+        to this FisherBlock's Fisher approximation.
+    """
+    # TODO(b/68033310): Validate which of,
+    #   (1) summing on a single device (as below), or
+    #   (2) on each device in isolation and aggregating
+    # is faster.
+    inputs = _concat_along_batch_dim(self._inputs)
+    grads_list = tuple(_concat_along_batch_dim(grads) for grads in grads_list)
+
+    self._input_factor = self._layer_collection.make_or_get_factor(  #
+        fisher_factors.FullyConnectedKroneckerFactor,  #
+        ((inputs,), self._has_bias))
+    self._output_factor = self._layer_collection.make_or_get_factor(  #
+        fisher_factors.FullyConnectedKroneckerFactor,  #
+        (grads_list,))
     self._register_damped_input_and_output_inverses(damping)
 
   def tensors_to_compute_grads(self):
     return self._outputs
 
+  def register_additional_minibatch(self, inputs, outputs):
+    """Registers an additional minibatch to the FisherBlock.
+
+    Args:
+      inputs: Tensor of shape [batch_size, input_size]. Inputs to the
+        matrix-multiply.
+      outputs: Tensor of shape [batch_size, output_size]. Layer
+        preactivations.
+    """
+    self._inputs.append(inputs)
+    self._outputs.append(outputs)
+
 
 class ConvKFCBasicFB(KroneckerProductFB):
   """FisherBlock for 2D convolutional layers using the basic KFC approx.
diff --git a/tensorflow/contrib/kfac/python/ops/fisher_factors.py b/tensorflow/contrib/kfac/python/ops/fisher_factors.py
index 86a1782fcf0..b8b524406c3 100644
--- a/tensorflow/contrib/kfac/python/ops/fisher_factors.py
+++ b/tensorflow/contrib/kfac/python/ops/fisher_factors.py
@@ -573,6 +573,14 @@ class FullyConnectedKroneckerFactor(InverseProvidingFactor):
   """
 
   def __init__(self, tensors, has_bias=False):
+    """Instantiate FullyConnectedKroneckerFactor.
+
+    Args:
+      tensors: List of Tensors of shape [batch_size, n]. Represents either a
+        layer's inputs or its output's gradients.
+      has_bias: bool. If True, assume this factor is for the layer's inputs and
+        append '1' to each row.
+    """
     # The tensor argument is either a tensor of input activations or a tensor
     # of output pre-activation gradients.
     self._has_bias = has_bias
diff --git a/tensorflow/contrib/kfac/python/ops/layer_collection.py b/tensorflow/contrib/kfac/python/ops/layer_collection.py
index 10ef5543516..ceb1131f286 100644
--- a/tensorflow/contrib/kfac/python/ops/layer_collection.py
+++ b/tensorflow/contrib/kfac/python/ops/layer_collection.py
@@ -255,9 +255,9 @@ class LayerCollection(object):
                                approx=APPROX_KRONECKER_NAME):
     has_bias = isinstance(params, (tuple, list))
     if approx == APPROX_KRONECKER_NAME:
-      self.register_block(params,
-                          fb.FullyConnectedKFACBasicFB(self, inputs, outputs,
-                                                       has_bias))
+      block = fb.FullyConnectedKFACBasicFB(self, has_bias)
+      block.register_additional_minibatch(inputs, outputs)
+      self.register_block(params, block)
     elif approx == APPROX_DIAGONAL_NAME:
       block = fb.FullyConnectedDiagonalFB(self, has_bias)
       block.register_additional_minibatch(inputs, outputs)
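
Reviewer note (not part of the patch): a minimal sketch of how the new
multi-minibatch API composes, mirroring the updated tests above. One block
now accepts one (inputs, outputs) pair per tower via
register_additional_minibatch(), and instantiate_factors() takes a list of
per-tower gradient Tensors for each source. The constant values below are
illustrative placeholders, and the squared outputs stand in for real loss
gradients exactly as in the tests.

    # Usage sketch, assuming the same imports the updated tests use.
    from tensorflow.contrib.kfac.python.ops import fisher_blocks as fb
    from tensorflow.contrib.kfac.python.ops import layer_collection as lc
    from tensorflow.python.ops import array_ops

    block = fb.FullyConnectedKFACBasicFB(lc.LayerCollection(), has_bias=False)

    # Register one minibatch per tower; instantiate_factors() concatenates
    # them along the batch dimension (via _concat_along_batch_dim).
    inputs0 = array_ops.constant([[1., 2.], [3., 4.]])
    outputs0 = array_ops.constant([[3., 4.], [5., 6.]])
    inputs1 = array_ops.constant([[5., 6.], [7., 8.]])
    outputs1 = array_ops.constant([[7., 8.], [9., 10.]])
    block.register_additional_minibatch(inputs0, outputs0)
    block.register_additional_minibatch(inputs1, outputs1)

    # grads_list[i][j] is the gradient w.r.t. 'outputs' from source i,
    # tower j; here a single source with two towers, damping 0.5.
    block.instantiate_factors(([outputs0**2, outputs1**2],), 0.5)

    # tensors_to_compute_grads() now returns the list of registered outputs.
    assert block.tensors_to_compute_grads() == [outputs0, outputs1]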