diff --git a/tensorflow/contrib/kfac/python/kernel_tests/fisher_blocks_test.py b/tensorflow/contrib/kfac/python/kernel_tests/fisher_blocks_test.py
index 85ac08a1eb7..dbf40fccc82 100644
--- a/tensorflow/contrib/kfac/python/kernel_tests/fisher_blocks_test.py
+++ b/tensorflow/contrib/kfac/python/kernel_tests/fisher_blocks_test.py
@@ -652,10 +652,10 @@ class ConvKFCBasicFBTest(test.TestCase):
       params = array_ops.constant(params)
       inputs = random_ops.random_normal((2, 2, 2))
       outputs = random_ops.random_normal((2, 2, 2))
-      block = fb.ConvKFCBasicFB(lc.LayerCollection(), params, inputs, outputs,
-                                [1, 1, 1], 'SAME')
+      block = fb.ConvKFCBasicFB(lc.LayerCollection(), params, [1, 1, 1], 'SAME')
+      block.register_additional_minibatch(inputs, outputs)
 
-      self.assertAllEqual(outputs, block.tensors_to_compute_grads())
+      self.assertAllEqual([outputs], block.tensors_to_compute_grads())
 
   def testConvKFCBasicFBInitParamsParamsTuple(self):
     self._testConvKFCBasicFBInitParams([np.array([1., 2.]), np.array(3.)])
@@ -669,10 +669,11 @@ class ConvKFCBasicFBTest(test.TestCase):
       params = random_ops.random_normal((2, 2, 2, 2))
       inputs = random_ops.random_normal((2, 2, 2, 2))
       outputs = random_ops.random_normal((2, 2, 2, 2))
-      block = fb.ConvKFCBasicFB(lc.LayerCollection(), params, inputs, outputs,
-                                (1, 1, 1, 1), 'SAME')
+      block = fb.ConvKFCBasicFB(lc.LayerCollection(), params, (1, 1, 1, 1),
+                                'SAME')
+      block.register_additional_minibatch(inputs, outputs)
       grads = outputs**2
-      block.instantiate_factors((grads,), 0.5)
+      block.instantiate_factors(([grads],), 0.5)
 
       # Make sure our inverse is something other than the identity.
       sess.run(tf_variables.global_variables_initializer())
@@ -694,11 +695,12 @@ class ConvKFCBasicFBTest(test.TestCase):
       params = random_ops.random_normal((2, 2, 2, 2))
       inputs = random_ops.random_normal((2, 2, 2, 2))
       outputs = random_ops.random_normal((2, 2, 2, 2))
-      block = fb.ConvKFCBasicFB(lc.LayerCollection(), params, inputs, outputs,
-                                (1, 1, 1, 1), 'SAME')
+      block = fb.ConvKFCBasicFB(lc.LayerCollection(), params, (1, 1, 1, 1),
+                                'SAME')
+      block.register_additional_minibatch(inputs, outputs)
       self.assertFalse(block._has_bias)
       grads = outputs**2
-      block.instantiate_factors((grads,), 0.5)
+      block.instantiate_factors(([grads],), 0.5)
 
       # Make sure our inverse is something other than the identity.
       sess.run(tf_variables.global_variables_initializer())
@@ -716,11 +718,12 @@ class ConvKFCBasicFBTest(test.TestCase):
       params = [random_ops.random_normal((2, 2, 2, 2))]
       inputs = random_ops.random_normal((2, 2, 2, 2))
       outputs = random_ops.random_normal((2, 2, 2, 2))
-      block = fb.ConvKFCBasicFB(lc.LayerCollection(), params, inputs, outputs,
-                                (1, 1, 1, 1), 'SAME')
+      block = fb.ConvKFCBasicFB(lc.LayerCollection(), params, (1, 1, 1, 1),
+                                'SAME')
+      block.register_additional_minibatch(inputs, outputs)
       self.assertTrue(block._has_bias)
       grads = outputs**2
-      block.instantiate_factors((grads,), 0.5)
+      block.instantiate_factors(([grads],), 0.5)
 
       # Make sure our inverse is something other than the identity.
       sess.run(tf_variables.global_variables_initializer())
@@ -738,11 +741,12 @@ class ConvKFCBasicFBTest(test.TestCase):
       params = array_ops.zeros((2, 2, 2, 2))
       inputs = array_ops.zeros((2, 2, 2, 2))
       outputs = array_ops.zeros((2, 2, 2, 2))
-      block = fb.ConvKFCBasicFB(lc.LayerCollection(), params, inputs, outputs,
-                                (1, 1, 1, 1), 'SAME')
+      block = fb.ConvKFCBasicFB(lc.LayerCollection(), params, (1, 1, 1, 1),
+                                'SAME')
+      block.register_additional_minibatch(inputs, outputs)
       grads = outputs**2
       damping = 0.  # This test is only valid without damping.
-      block.instantiate_factors((grads,), damping)
+      block.instantiate_factors(([grads],), damping)
 
       sess.run(state_ops.assign(block._input_factor._cov, _make_psd(8)))
       sess.run(state_ops.assign(block._output_factor._cov, _make_psd(2)))
diff --git a/tensorflow/contrib/kfac/python/ops/fisher_blocks.py b/tensorflow/contrib/kfac/python/ops/fisher_blocks.py
index 7ef755c35ed..efffaaef8d5 100644
--- a/tensorflow/contrib/kfac/python/ops/fisher_blocks.py
+++ b/tensorflow/contrib/kfac/python/ops/fisher_blocks.py
@@ -454,6 +454,14 @@ class KroneckerProductFB(FisherBlock):
 
   @property
   def _renorm_coeff(self):
+    """Kronecker factor multiplier coefficient.
+
+    If this FisherBlock is represented as 'FB = c * kron(left, right)', then
+    this is 'c'.
+
+    Returns:
+      0-D Tensor.
+    """
     return 1.0
 
   def multiply_inverse(self, vector):
@@ -560,17 +568,34 @@ class FullyConnectedKFACBasicFB(KroneckerProductFB):
 
   @property
   def num_registered_minibatches(self):
-    return 1  # Multiple minibatches not supported.
+    return len(self._inputs)
 
 
 class ConvKFCBasicFB(KroneckerProductFB):
   """FisherBlock for 2D convolutional layers using the basic KFC approx.
 
-  See https://arxiv.org/abs/1602.01407 for details.
+  Estimates the Fisher Information matrix's block for a convolutional
+  layer.
+
+  Consider a convolutional layer in this model with (unshared) filter matrix
+  'w'. For a minibatch that produces inputs 'a' and output preactivations 's',
+  this FisherBlock estimates,
+
+    F(w) = #locations * kronecker(E[flat(a) flat(a)^T],
+                                  E[flat(ds) flat(ds)^T])
+
+  where
+
+    ds = (d / ds) log p(y | x, w)
+    #locations = number of (x, y) locations where 'w' is applied.
+
+  Here the expectation is taken over all examples and locations, and flat()
+  concatenates an array's leading dimensions.
+
+  See equation 23 in https://arxiv.org/abs/1602.01407 for details.
   """
 
-  def __init__(self, layer_collection, params, inputs, outputs, strides,
-               padding):
+  def __init__(self, layer_collection, params, strides, padding):
     """Creates a ConvKFCBasicFB block.
 
     Args:
@@ -580,38 +605,43 @@ class ConvKFCBasicFB(KroneckerProductFB):
         kernel alone, a Tensor of shape [kernel_height, kernel_width,
         in_channels, out_channels]. If kernel and bias, a tuple of 2 elements
         containing the previous and a Tensor of shape [out_channels].
-      inputs: A Tensor of shape [batch_size, height, width, in_channels].
-        Input activations to this layer.
-      outputs: A Tensor of shape [batch_size, height, width, out_channels].
-        Output pre-activations from this layer.
       strides: The stride size in this layer (1-D Tensor of length 4).
       padding: The padding in this layer (e.g. "SAME").
     """
-    self._inputs = inputs
-    self._outputs = outputs
-    self._strides = strides
+    self._inputs = []
+    self._outputs = []
+    self._strides = tuple(strides) if isinstance(strides, list) else strides
     self._padding = padding
     self._has_bias = isinstance(params, (tuple, list))
 
     fltr = params[0] if self._has_bias else params
     self._filter_shape = tuple(fltr.shape.as_list())
 
-    input_shape = tuple(inputs.shape.as_list())
-    self._num_locations = (
-        input_shape[1] * input_shape[2] // (strides[1] * strides[2]))
-
     super(ConvKFCBasicFB, self).__init__(layer_collection)
 
   def instantiate_factors(self, grads_list, damping):
+    # TODO(b/68033310): Validate which of
+    #   (1) summing on a single device (as below), or
+    #   (2) summing on each device in isolation and then aggregating
+    # is faster.
+    inputs = _concat_along_batch_dim(self._inputs)
+    grads_list = tuple(_concat_along_batch_dim(grads) for grads in grads_list)
+
+    # Infer the number of locations upon which the convolution is applied.
+    self._num_locations = _num_conv_locations(inputs.shape.as_list(),
+                                              self._strides)
+
     self._input_factor = self._layer_collection.make_or_get_factor(
         fisher_factors.ConvInputKroneckerFactor,
-        (self._inputs, self._filter_shape, self._strides, self._padding,
+        (inputs, self._filter_shape, self._strides, self._padding,
          self._has_bias))
     self._output_factor = self._layer_collection.make_or_get_factor(
         fisher_factors.ConvOutputKroneckerFactor, (grads_list,))
 
     if NORMALIZE_DAMPING_POWER:
       damping /= self._num_locations**NORMALIZE_DAMPING_POWER
+    self._damping = damping
+
     self._register_damped_input_and_output_inverses(damping)
 
   @property
@@ -621,9 +651,21 @@ class ConvKFCBasicFB(KroneckerProductFB):
   def tensors_to_compute_grads(self):
     return self._outputs
 
+  def register_additional_minibatch(self, inputs, outputs):
+    """Registers an additional minibatch with this FisherBlock.
+
+    Args:
+      inputs: Tensor of shape [batch_size, height, width, input_size]. Inputs
+        to the convolution.
+      outputs: Tensor of shape [batch_size, height, width, output_size]. Layer
+        preactivations.
+    """
+    self._inputs.append(inputs)
+    self._outputs.append(outputs)
+
   @property
   def num_registered_minibatches(self):
-    return 1  # Multiple minibatches not supported.
+    return len(self._inputs)
 
 
 def _concat_along_batch_dim(tensor_list):
@@ -651,3 +693,8 @@ def _concat_along_batch_dim(tensor_list):
   else:
     # [tensor1, tensor2] --> tensor
     return array_ops.concat(tensor_list, axis=0)
+
+
+def _num_conv_locations(input_shape, strides):
+  """Returns the number of locations a Conv kernel is applied to."""
+  return input_shape[1] * input_shape[2] // (strides[1] * strides[2])
diff --git a/tensorflow/contrib/kfac/python/ops/fisher_factors.py b/tensorflow/contrib/kfac/python/ops/fisher_factors.py
index b8b524406c3..4e36813369e 100644
--- a/tensorflow/contrib/kfac/python/ops/fisher_factors.py
+++ b/tensorflow/contrib/kfac/python/ops/fisher_factors.py
@@ -609,9 +609,28 @@ class FullyConnectedKroneckerFactor(InverseProvidingFactor):
 
 
 class ConvInputKroneckerFactor(InverseProvidingFactor):
-  """Kronecker factor for the input side of a convolutional layer."""
+  r"""Kronecker factor for the input side of a convolutional layer.
+
+  Estimates E[ a a^T ] where a is the input to a convolutional layer, given
+  example x. The expectation is taken over all examples and locations.
+
+  Equivalent to \Omega in https://arxiv.org/abs/1602.01407. See Section 3.1,
+  "Estimating the factors", for details.
+  """
 
   def __init__(self, inputs, filter_shape, strides, padding, has_bias=False):
+    """Initializes ConvInputKroneckerFactor.
+
+    Args:
+      inputs: Tensor of shape [batch_size, height, width, in_channels]. Inputs
+        to layer.
+      filter_shape: 1-D Tensor of length 4. Contains [kernel_height,
+        kernel_width, in_channels, out_channels].
+      strides: 1-D Tensor of length 4. Contains [batch_stride, height_stride,
+        width_stride, in_channel_stride].
+      padding: str. Padding method for layer. "SAME" or "VALID".
+      has_bias: bool. If True, appends a 1 to each input patch (for the bias).
+ """ self._filter_shape = filter_shape self._strides = strides self._padding = padding @@ -659,9 +678,23 @@ class ConvInputKroneckerFactor(InverseProvidingFactor): class ConvOutputKroneckerFactor(InverseProvidingFactor): - """Kronecker factor for the output side of a convolutional layer.""" + r"""Kronecker factor for the output side of a convolutional layer. + + Estimates E[ ds ds^T ] where s is the preactivations of a convolutional layer + given example x and ds = (d / d s) log(p(y|x, w)). Expectation is taken over + all examples and locations. + + Equivalent to \Gamma in https://arxiv.org/abs/1602.01407 for details. See + Section 3.1 Estimating the factors. + """ def __init__(self, outputs_grads): + """Initializes ConvOutputKroneckerFactor. + + Args: + outputs_grads: list of Tensors. Each Tensor is of shape + [batch_size, height, width, out_channels]. + """ self._out_channels = outputs_grads[0].shape.as_list()[3] self._outputs_grads = outputs_grads super(ConvOutputKroneckerFactor, self).__init__() diff --git a/tensorflow/contrib/kfac/python/ops/layer_collection.py b/tensorflow/contrib/kfac/python/ops/layer_collection.py index 2b9958a46a6..77ddd19e59a 100644 --- a/tensorflow/contrib/kfac/python/ops/layer_collection.py +++ b/tensorflow/contrib/kfac/python/ops/layer_collection.py @@ -315,9 +315,9 @@ class LayerCollection(object): approx=APPROX_KRONECKER_NAME): if approx == APPROX_KRONECKER_NAME: - self.register_block(params, - fb.ConvKFCBasicFB(self, params, inputs, outputs, - strides, padding)) + block = fb.ConvKFCBasicFB(self, params, strides, padding) + block.register_additional_minibatch(inputs, outputs) + self.register_block(params, block) elif approx == APPROX_DIAGONAL_NAME: block = fb.ConvDiagonalFB(self, params, strides, padding) block.register_additional_minibatch(inputs, outputs)