- Added support for multiplying vectors by Cholesky factors and their inverses.

- Refactored FisherFactor to use LinearOperator classes that know how to multiply themselves, compute their own trace, etc. This addresses the feature request: b/73356352
- Fixed some problems with FisherEstimator construction.
- Added more careful casting of damping constants before they are used.

PiperOrigin-RevId: 194379298
James Martens 2018-04-26 04:37:28 -07:00 committed by TensorFlower Gardener
parent 8148895adc
commit 481f229881
11 changed files with 637 additions and 282 deletions
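The refactor described above replaces the factors' bespoke left/right multiply methods with operator objects that know how to multiply themselves and take their own trace. A toy, numpy-only sketch of that design (class and method names mirror the diff below, but this is illustrative, not the TensorFlow implementation):

```python
import numpy as np

class ToyLinearOperatorDiag(object):
  """Diagonal operator: stores only the diagonal, like LinearOperatorDiag."""

  def __init__(self, diag):
    self._diag = np.asarray(diag)

  def matmul(self, x):
    # Multiplying by a diagonal matrix just scales the rows of x.
    return self._diag[:, None] * x

  def trace(self):
    return self._diag.sum()

  def to_dense(self):
    return np.diag(self._diag)

op = ToyLinearOperatorDiag([1., 2., 0., 1.])
x = np.ones((4, 2))
assert np.allclose(op.matmul(x), op.to_dense() @ x)
assert op.trace() == 4.
```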

View File

@@ -58,6 +58,7 @@ py_test(
     deps = [
         "//tensorflow/contrib/kfac/python/ops:fisher_blocks",
         "//tensorflow/contrib/kfac/python/ops:layer_collection",
+        "//tensorflow/contrib/kfac/python/ops:linear_operator",
         "//tensorflow/contrib/kfac/python/ops:utils",
         "//tensorflow/python:array_ops",
         "//tensorflow/python:client_testlib",

View File

@@ -22,6 +22,7 @@ import numpy as np

 from tensorflow.contrib.kfac.python.ops import fisher_blocks as fb
 from tensorflow.contrib.kfac.python.ops import layer_collection as lc
+from tensorflow.contrib.kfac.python.ops import linear_operator as lo
 from tensorflow.contrib.kfac.python.ops import utils
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import random_seed

@@ -46,8 +47,9 @@ class UtilsTest(test.TestCase):
   def testComputePiTracenorm(self):
     with ops.Graph().as_default(), self.test_session() as sess:
       random_seed.set_random_seed(200)
-      left_factor = array_ops.diag([1., 2., 0., 1.])
-      right_factor = array_ops.ones([2., 2.])
+      diag = ops.convert_to_tensor([1., 2., 0., 1.])
+      left_factor = lo.LinearOperatorDiag(diag)
+      right_factor = lo.LinearOperatorFullMatrix(array_ops.ones([2, 2]))

       # pi is the sqrt of the left trace norm divided by the right trace norm
       pi = fb.compute_pi_tracenorm(left_factor, right_factor)

@@ -245,7 +247,6 @@ class NaiveDiagonalFBTest(test.TestCase):
       full = sess.run(block.full_fisher_block())
       explicit = np.dot(np.linalg.inv(full + damping * np.eye(3)), v_flat)
       self.assertAllClose(output_flat, explicit)

View File

@@ -70,18 +70,6 @@ class FisherFactorTestingDummy(ff.FisherFactor):
   def get_cov(self):
     return NotImplementedError

-  def left_multiply(self, x, damping):
-    return NotImplementedError
-
-  def right_multiply(self, x, damping):
-    return NotImplementedError
-
-  def left_multiply_matpower(self, x, exp, damping):
-    return NotImplementedError
-
-  def right_multiply_matpower(self, x, exp, damping):
-    return NotImplementedError
-
   def instantiate_inv_variables(self):
     return NotImplementedError

@@ -91,14 +79,35 @@ class FisherFactorTestingDummy(ff.FisherFactor):
   def _get_data_device(self):
     raise NotImplementedError

+  def register_matpower(self, exp, damping_func):
+    raise NotImplementedError
+
+  def register_cholesky(self, damping_func):
+    raise NotImplementedError
+
+  def register_cholesky_inverse(self, damping_func):
+    raise NotImplementedError
+
+  def get_matpower(self, exp, damping_func):
+    raise NotImplementedError
+
+  def get_cholesky(self, damping_func):
+    raise NotImplementedError
+
+  def get_cholesky_inverse(self, damping_func):
+    raise NotImplementedError
+
+  def get_cov_as_linear_operator(self):
+    raise NotImplementedError
+

-class InverseProvidingFactorTestingDummy(ff.InverseProvidingFactor):
-  """Dummy class to test the non-abstract methods on ff.InverseProvidingFactor.
+class DenseSquareMatrixFactorTestingDummy(ff.DenseSquareMatrixFactor):
+  """Dummy class to test the non-abstract methods on ff.DenseSquareMatrixFactor.
   """

   def __init__(self, shape):
     self._shape = shape
-    super(InverseProvidingFactorTestingDummy, self).__init__()
+    super(DenseSquareMatrixFactorTestingDummy, self).__init__()

   @property
   def _var_scope(self):

@@ -230,13 +239,13 @@ class FisherFactorTest(test.TestCase):
     self.assertEqual(0, len(factor.make_inverse_update_ops()))


-class InverseProvidingFactorTest(test.TestCase):
+class DenseSquareMatrixFactorTest(test.TestCase):

   def testRegisterDampedInverse(self):
     with tf_ops.Graph().as_default():
       random_seed.set_random_seed(200)
       shape = [2, 2]
-      factor = InverseProvidingFactorTestingDummy(shape)
+      factor = DenseSquareMatrixFactorTestingDummy(shape)
       factor_var_scope = 'dummy/a_b_c'

       damping_funcs = [make_damping_func(0.1),

@@ -248,22 +257,25 @@ class InverseProvidingFactorTest(test.TestCase):
       factor.instantiate_inv_variables()

-      inv = factor.get_inverse(damping_funcs[0])
-      self.assertEqual(inv, factor.get_inverse(damping_funcs[1]))
-      self.assertNotEqual(inv, factor.get_inverse(damping_funcs[2]))
-      self.assertEqual(factor.get_inverse(damping_funcs[2]),
-                       factor.get_inverse(damping_funcs[3]))
+      inv = factor.get_inverse(damping_funcs[0]).to_dense()
+      self.assertEqual(inv, factor.get_inverse(damping_funcs[1]).to_dense())
+      self.assertNotEqual(inv, factor.get_inverse(damping_funcs[2]).to_dense())
+      self.assertEqual(factor.get_inverse(damping_funcs[2]).to_dense(),
+                       factor.get_inverse(damping_funcs[3]).to_dense())

       factor_vars = tf_ops.get_collection(tf_ops.GraphKeys.GLOBAL_VARIABLES,
                                           factor_var_scope)
-      self.assertEqual(set([inv, factor.get_inverse(damping_funcs[2])]),
-                       set(factor_vars))
+      factor_tensors = (tf_ops.convert_to_tensor(var) for var in factor_vars)
+      self.assertEqual(set([inv,
+                            factor.get_inverse(damping_funcs[2]).to_dense()]),
+                       set(factor_tensors))

       self.assertEqual(shape, inv.get_shape())

   def testRegisterMatpower(self):
     with tf_ops.Graph().as_default():
       random_seed.set_random_seed(200)
       shape = [3, 3]
-      factor = InverseProvidingFactorTestingDummy(shape)
+      factor = DenseSquareMatrixFactorTestingDummy(shape)
       factor_var_scope = 'dummy/a_b_c'

       # TODO(b/74201126): Change to using the same func for both once

@@ -278,10 +290,13 @@ class InverseProvidingFactorTest(test.TestCase):
       factor_vars = tf_ops.get_collection(tf_ops.GraphKeys.GLOBAL_VARIABLES,
                                           factor_var_scope)

-      matpower1 = factor.get_matpower(-0.5, damping_func_1)
-      matpower2 = factor.get_matpower(2, damping_func_2)
-      self.assertEqual(set([matpower1, matpower2]), set(factor_vars))
+      factor_tensors = (tf_ops.convert_to_tensor(var) for var in factor_vars)
+
+      matpower1 = factor.get_matpower(-0.5, damping_func_1).to_dense()
+      matpower2 = factor.get_matpower(2, damping_func_2).to_dense()
+
+      self.assertEqual(set([matpower1, matpower2]), set(factor_tensors))

       self.assertEqual(shape, matpower1.get_shape())
       self.assertEqual(shape, matpower2.get_shape())

@@ -297,7 +312,7 @@ class InverseProvidingFactorTest(test.TestCase):
     with tf_ops.Graph().as_default(), self.test_session() as sess:
       random_seed.set_random_seed(200)
       cov = np.array([[1., 2.], [3., 4.]])
-      factor = InverseProvidingFactorTestingDummy(cov.shape)
+      factor = DenseSquareMatrixFactorTestingDummy(cov.shape)
       factor._cov = array_ops.constant(cov, dtype=dtypes.float32)

       damping_funcs = []

@@ -316,7 +331,8 @@ class InverseProvidingFactorTest(test.TestCase):
       sess.run(ops)
       for i in range(ff.EIGENVALUE_DECOMPOSITION_THRESHOLD):
         # The inverse op will assign the damped inverse of cov to the inv var.
-        new_invs.append(sess.run(factor.get_inverse(damping_funcs[i])))
+        new_invs.append(
+            sess.run(factor.get_inverse(damping_funcs[i]).to_dense()))

       # We want to see that the new invs are all different from each other.
       for i in range(len(new_invs)):

@@ -328,7 +344,7 @@ class InverseProvidingFactorTest(test.TestCase):
     with tf_ops.Graph().as_default(), self.test_session() as sess:
       random_seed.set_random_seed(200)
       cov = np.array([[6., 2.], [2., 4.]])
-      factor = InverseProvidingFactorTestingDummy(cov.shape)
+      factor = DenseSquareMatrixFactorTestingDummy(cov.shape)
       factor._cov = array_ops.constant(cov, dtype=dtypes.float32)
       exp = 2  # NOTE(mattjj): must be int to test with np.linalg.matrix_power
       damping = 0.5

@@ -341,7 +357,7 @@ class InverseProvidingFactorTest(test.TestCase):
       sess.run(tf_variables.global_variables_initializer())
       sess.run(ops[0])

-      matpower = sess.run(factor.get_matpower(exp, damping_func))
+      matpower = sess.run(factor.get_matpower(exp, damping_func).to_dense())
       matpower_np = np.linalg.matrix_power(cov + np.eye(2) * damping, exp)
       self.assertAllClose(matpower, matpower_np)

@@ -349,7 +365,7 @@ class InverseProvidingFactorTest(test.TestCase):
     with tf_ops.Graph().as_default(), self.test_session() as sess:
       random_seed.set_random_seed(200)
       cov = np.array([[5., 2.], [2., 4.]])  # NOTE(mattjj): must be symmetric
-      factor = InverseProvidingFactorTestingDummy(cov.shape)
+      factor = DenseSquareMatrixFactorTestingDummy(cov.shape)
       factor._cov = array_ops.constant(cov, dtype=dtypes.float32)

       damping_func = make_damping_func(0)

@@ -361,12 +377,12 @@ class InverseProvidingFactorTest(test.TestCase):
       sess.run(tf_variables.global_variables_initializer())

       # The inverse op will assign the damped inverse of cov to the inv var.
-      old_inv = sess.run(factor.get_inverse(damping_func))
+      old_inv = sess.run(factor.get_inverse(damping_func).to_dense())
       self.assertAllClose(
           sess.run(ff.inverse_initializer(cov.shape, dtypes.float32)), old_inv)

       sess.run(ops)
-      new_inv = sess.run(factor.get_inverse(damping_func))
+      new_inv = sess.run(factor.get_inverse(damping_func).to_dense())
       self.assertAllClose(new_inv, np.linalg.inv(cov))

@@ -411,7 +427,7 @@ class NaiveDiagonalFactorTest(test.TestCase):
       tensor = array_ops.ones((2, 3), name='a/b/c')
       factor = ff.NaiveDiagonalFactor((tensor,), 32)
       factor.instantiate_cov_variables()
-      self.assertEqual([6, 1], factor.get_cov_var().get_shape().as_list())
+      self.assertEqual([6, 1], factor.get_cov().get_shape().as_list())

   def testNaiveDiagonalFactorInitFloat64(self):
     with tf_ops.Graph().as_default():

@@ -420,7 +436,7 @@ class NaiveDiagonalFactorTest(test.TestCase):
       tensor = array_ops.ones((2, 3), dtype=dtype, name='a/b/c')
       factor = ff.NaiveDiagonalFactor((tensor,), 32)
       factor.instantiate_cov_variables()
-      cov = factor.get_cov_var()
+      cov = factor.get_cov()
       self.assertEqual(cov.dtype, dtype)
       self.assertEqual([6, 1], cov.get_shape().as_list())

@@ -444,7 +460,7 @@ class EmbeddingInputKroneckerFactorTest(test.TestCase):
       vocab_size = 5
       factor = ff.EmbeddingInputKroneckerFactor((input_ids,), vocab_size)
       factor.instantiate_cov_variables()
-      cov = factor.get_cov_var()
+      cov = factor.get_cov()
       self.assertEqual(cov.shape.as_list(), [vocab_size])

   def testCovarianceUpdateOp(self):

@@ -502,7 +518,7 @@ class ConvDiagonalFactorTest(test.TestCase):
              self.kernel_height * self.kernel_width * self.in_channels,
              self.out_channels
          ],
-          factor.get_cov_var().shape.as_list())
+          factor.get_cov().shape.as_list())

   def testMakeCovarianceUpdateOp(self):
     with tf_ops.Graph().as_default():

@@ -564,7 +580,7 @@ class ConvDiagonalFactorTest(test.TestCase):
              self.kernel_height * self.kernel_width * self.in_channels + 1,
              self.out_channels
          ],
-          factor.get_cov_var().shape.as_list())
+          factor.get_cov().shape.as_list())

       # Ensure update op doesn't crash.
       cov_update_op = factor.make_covariance_update_op(0.0)

@@ -654,13 +670,13 @@ class ConvInputKroneckerFactorTest(ConvFactorTestCase):
       # Ensure shape of covariance matches input size of filter.
       input_size = in_channels * (width**3)
       self.assertEqual([input_size, input_size],
-                       factor.get_cov_var().shape.as_list())
+                       factor.get_cov().shape.as_list())

       # Ensure cov_update_op doesn't crash.
       with self.test_session() as sess:
         sess.run(tf_variables.global_variables_initializer())
         sess.run(factor.make_covariance_update_op(0.0))
-        cov = sess.run(factor.get_cov_var())
+        cov = sess.run(factor.get_cov())

       # Cov should be rank-8, as the filter will be applied at each corner of
       # the 4-D cube.

@@ -685,13 +701,13 @@ class ConvInputKroneckerFactorTest(ConvFactorTestCase):
       # Ensure shape of covariance matches input size of filter.
       self.assertEqual([in_channels, in_channels],
-                       factor.get_cov_var().shape.as_list())
+                       factor.get_cov().shape.as_list())

       # Ensure cov_update_op doesn't crash.
       with self.test_session() as sess:
         sess.run(tf_variables.global_variables_initializer())
         sess.run(factor.make_covariance_update_op(0.0))
-        cov = sess.run(factor.get_cov_var())
+        cov = sess.run(factor.get_cov())

       # Cov should be rank-9, as the filter will be applied at each location.
       self.assertMatrixRank(9, cov)

@@ -716,7 +732,7 @@ class ConvInputKroneckerFactorTest(ConvFactorTestCase):
       with self.test_session() as sess:
         sess.run(tf_variables.global_variables_initializer())
         sess.run(factor.make_covariance_update_op(0.0))
-        cov = sess.run(factor.get_cov_var())
+        cov = sess.run(factor.get_cov())

       # Cov should be the sum of 3 * 2 = 6 outer products.
       self.assertMatrixRank(6, cov)

@@ -742,7 +758,7 @@ class ConvInputKroneckerFactorTest(ConvFactorTestCase):
       with self.test_session() as sess:
         sess.run(tf_variables.global_variables_initializer())
         sess.run(factor.make_covariance_update_op(0.0))
-        cov = sess.run(factor.get_cov_var())
+        cov = sess.run(factor.get_cov())

       # Cov should be rank = in_channels, as only the center of the filter
       # receives non-zero input for each input channel.
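The matpower tests above pin down the contract that the stored variables converge to (C + damping * I)**exp. A numpy restatement of that contract, using the same matrix as testMatpowerUpdateOps (illustrative only, not the kfac API):

```python
import numpy as np

cov = np.array([[6., 2.], [2., 4.]])
damping = 0.5
exp = 2

damped = cov + damping * np.eye(2)

# get_matpower(exp, ...) should serve this value once the update op has run.
matpower = np.linalg.matrix_power(damped, exp)

# exp = -1 is the damped inverse served by get_inverse()/get_matpower(-1, ...).
inv = np.linalg.inv(damped)
assert np.allclose(inv @ damped, np.eye(2))
```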

View File

@@ -35,6 +35,7 @@ py_library(
     srcs = ["fisher_factors.py"],
     srcs_version = "PY2AND3",
     deps = [
+        ":linear_operator",
         ":utils",
         "//tensorflow/python:array_ops",
         "//tensorflow/python:control_flow_ops",

@@ -63,6 +64,19 @@ py_library(
     ],
 )

+py_library(
+    name = "linear_operator",
+    srcs = ["linear_operator.py"],
+    srcs_version = "PY2AND3",
+    deps = [
+        ":utils",
+        "//tensorflow/python:framework_ops",
+        "//tensorflow/python:math_ops",
+        "//tensorflow/python/ops/linalg",
+        "@six_archive//:six",
+    ],
+)
+
 py_library(
     name = "loss_functions",
     srcs = ["loss_functions.py"],

View File

@@ -57,8 +57,8 @@ def make_fisher_estimator(placement_strategy=None, **kwargs):
   if placement_strategy in [None, "round_robin"]:
     return FisherEstimatorRoundRobin(**kwargs)
   else:
-    raise ValueError("Unimplemented vars and ops placement strategy : %s",
-                     placement_strategy)
+    raise ValueError("Unimplemented vars and ops "
+                     "placement strategy : {}".format(placement_strategy))
 # pylint: enable=abstract-class-instantiated

@@ -81,7 +81,9 @@ class FisherEstimator(object):
                exps=(-1,),
                estimation_mode="gradients",
                colocate_gradients_with_ops=True,
-               name="FisherEstimator"):
+               name="FisherEstimator",
+               compute_cholesky=False,
+               compute_cholesky_inverse=False):
     """Create a FisherEstimator object.

     Args:

@@ -124,6 +126,12 @@ class FisherEstimator(object):
       name: A string. A name given to this estimator, which is added to the
         variable scope when constructing variables and ops.
         (Default: "FisherEstimator")
+      compute_cholesky: Bool. Whether or not the FisherEstimator will be
+        able to multiply vectors by the Cholesky factor.
+        (Default: False)
+      compute_cholesky_inverse: Bool. Whether or not the FisherEstimator
+        will be able to multiply vectors by the Cholesky factor inverse.
+        (Default: False)

     Raises:
       ValueError: If no losses have been registered with layer_collection.
     """

@@ -142,6 +150,8 @@ class FisherEstimator(object):
     self._made_vars = False
     self._exps = exps
+    self._compute_cholesky = compute_cholesky
+    self._compute_cholesky_inverse = compute_cholesky_inverse
     self._name = name

@@ -300,9 +310,54 @@ class FisherEstimator(object):
       A list of (transformed vector, var) pairs in the same order as
       vecs_and_vars.
     """
+    assert exp in self._exps
+
     fcn = lambda fb, vec: fb.multiply_matpower(vec, exp)
     return self._apply_transformation(vecs_and_vars, fcn)

+  def multiply_cholesky(self, vecs_and_vars, transpose=False):
+    """Multiplies the vecs by the corresponding Cholesky factors.
+
+    Args:
+      vecs_and_vars: List of (vector, variable) pairs.
+      transpose: Bool. If true the Cholesky factors are transposed before
+        multiplying the vecs. (Default: False)
+
+    Returns:
+      A list of (transformed vector, var) pairs in the same order as
+      vecs_and_vars.
+    """
+    assert self._compute_cholesky
+
+    fcn = lambda fb, vec: fb.multiply_cholesky(vec, transpose=transpose)
+    return self._apply_transformation(vecs_and_vars, fcn)
+
+  def multiply_cholesky_inverse(self, vecs_and_vars, transpose=False):
+    """Multiplies the vecs by the inverses of the corresponding Cholesky factors.
+
+    Note: if you are using Cholesky inverse multiplication to sample from
+    a matrix-variate Gaussian you will want to multiply by the transpose.
+    Let L be the Cholesky factor of F and observe that
+
+      L^-T * L^-1 = (L * L^T)^-1 = F^-1 .
+
+    Thus we want to multiply by L^-T in order to sample from a Gaussian with
+    covariance F^-1.
+
+    Args:
+      vecs_and_vars: List of (vector, variable) pairs.
+      transpose: Bool. If true the Cholesky factor inverses are transposed
+        before multiplying the vecs. (Default: False)
+
+    Returns:
+      A list of (transformed vector, var) pairs in the same order as
+      vecs_and_vars.
+    """
+    assert self._compute_cholesky_inverse
+
+    fcn = lambda fb, vec: fb.multiply_cholesky_inverse(vec, transpose=transpose)
+    return self._apply_transformation(vecs_and_vars, fcn)
+
   def _instantiate_factors(self):
     """Instantiates FisherFactors' variables.

@@ -333,9 +388,13 @@ class FisherEstimator(object):
     return self._made_vars

   def _register_matrix_functions(self):
-    for exp in self._exps:
-      for block in self.blocks:
+    for block in self.blocks:
+      for exp in self._exps:
         block.register_matpower(exp)
+      if self._compute_cholesky:
+        block.register_cholesky()
+      if self._compute_cholesky_inverse:
+        block.register_cholesky_inverse()

   def _finalize_layer_collection(self):
     self._layers.create_subgraph()
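The multiply_cholesky_inverse docstring above relies on the identity L^-T L^-1 = (L L^T)^-1 = F^-1 for a Cholesky factor L of F. A quick numpy check of that identity and of the sampling recipe it justifies (illustrative only; this does not use the estimator API):

```python
import numpy as np

rng = np.random.RandomState(0)
F = np.array([[4., 1.], [1., 3.]])  # symmetric positive definite
L = np.linalg.cholesky(F)           # F = L @ L.T
L_inv = np.linalg.inv(L)

# The identity from the docstring: L^-T L^-1 = F^-1.
assert np.allclose(L_inv.T @ L_inv, np.linalg.inv(F))

# x = L^-T z with z ~ N(0, I) has covariance L^-T I L^-1 = F^-1, which is why
# sampling uses transpose=True.
z = rng.randn(2, 100000)
x = L_inv.T @ z
assert np.allclose(np.cov(x), np.linalg.inv(F), atol=1e-2)
```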

View File

@@ -25,6 +25,7 @@ from tensorflow.python.util.all_util import remove_undocumented

 _allowed_symbols = [
     'FisherEstimator',
+    'make_fisher_estimator',
 ]

 remove_undocumented(__name__, allowed_exception_list=_allowed_symbols)

View File

@@ -83,34 +83,22 @@ def normalize_damping(damping, num_replications):

 def compute_pi_tracenorm(left_cov, right_cov):
-  """Computes the scalar constant pi for Tikhonov regularization/damping.
+  r"""Computes the scalar constant pi for Tikhonov regularization/damping.

   $$\pi = \sqrt{ (trace(A) / dim(A)) / (trace(B) / dim(B)) }$$

   See section 6.3 of https://arxiv.org/pdf/1503.05671.pdf for details.

   Args:
-    left_cov: The left Kronecker factor "covariance".
-    right_cov: The right Kronecker factor "covariance".
+    left_cov: A LinearOperator object. The left Kronecker factor "covariance".
+    right_cov: A LinearOperator object. The right Kronecker factor
+      "covariance".

   Returns:
     The computed scalar constant pi for these Kronecker Factors (as a Tensor).
   """
-  def _trace(cov):
-    if len(cov.shape) == 1:
-      # Diagonal matrix.
-      return math_ops.reduce_sum(cov)
-    elif len(cov.shape) == 2:
-      # Full matrix.
-      return math_ops.trace(cov)
-    else:
-      raise ValueError(
-          "What's the trace of a Tensor of rank %d?" % len(cov.shape))
-
   # Instead of dividing by the dim of the norm, we multiply by the dim of the
   # other norm. This works out the same in the ratio.
-  left_norm = _trace(left_cov) * right_cov.shape.as_list()[0]
-  right_norm = _trace(right_cov) * left_cov.shape.as_list()[0]
+  left_norm = left_cov.trace() * int(right_cov.domain_dimension)
+  right_norm = right_cov.trace() * int(left_cov.domain_dimension)
   return math_ops.sqrt(left_norm / right_norm)
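A numpy restatement of the pi formula implemented above, using the same factors as the updated testComputePiTracenorm (illustrative only, not the kfac API):

```python
import numpy as np

left_cov = np.diag([1., 2., 0., 1.])  # trace 4, dim 4
right_cov = np.ones((2, 2))           # trace 2, dim 2

# Multiply by the other factor's dimension instead of dividing by our own;
# the ratio works out the same.
left_norm = np.trace(left_cov) * right_cov.shape[0]
right_norm = np.trace(right_cov) * left_cov.shape[0]
pi = np.sqrt(left_norm / right_norm)  # sqrt((4 * 2) / (2 * 4)) = 1.0
assert pi == 1.0
```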
@@ -188,6 +176,16 @@ class FisherBlock(object):
     """
     pass

+  @abc.abstractmethod
+  def register_cholesky(self):
+    """Registers a Cholesky factor to be computed by the block."""
+    pass
+
+  @abc.abstractmethod
+  def register_cholesky_inverse(self):
+    """Registers an inverse Cholesky factor to be computed by the block."""
+    pass
+
   def register_inverse(self):
     """Registers a matrix inverse to be computed by the block."""
     self.register_matpower(-1)

@@ -228,6 +226,33 @@ class FisherBlock(object):
     """
     return self.multiply_matpower(vector, 1)

+  @abc.abstractmethod
+  def multiply_cholesky(self, vector, transpose=False):
+    """Multiplies the vector by the (damped) Cholesky-factor of the block.
+
+    Args:
+      vector: The vector (a Tensor or tuple of Tensors) to be multiplied.
+      transpose: Bool. If true the Cholesky factor is transposed before
+        multiplying the vector. (Default: False)
+
+    Returns:
+      The vector left-multiplied by the (damped) Cholesky-factor of the block.
+    """
+    pass
+
+  @abc.abstractmethod
+  def multiply_cholesky_inverse(self, vector, transpose=False):
+    """Multiplies vector by the (damped) inverse Cholesky-factor of the block.
+
+    Args:
+      vector: The vector (a Tensor or tuple of Tensors) to be multiplied.
+      transpose: Bool. If true the Cholesky factor inverse is transposed
+        before multiplying the vector. (Default: False)
+
+    Returns:
+      Vector left-multiplied by (damped) inverse Cholesky-factor of the block.
+    """
+    pass
+
   @abc.abstractmethod
   def tensors_to_compute_grads(self):
     """Returns the Tensor(s) with respect to which this FisherBlock needs grads.

@@ -275,15 +300,32 @@ class FullFB(FisherBlock):
   def register_matpower(self, exp):
     self._factor.register_matpower(exp, self._damping_func)

-  def multiply_matpower(self, vector, exp):
+  def register_cholesky(self):
+    self._factor.register_cholesky(self._damping_func)
+
+  def register_cholesky_inverse(self):
+    self._factor.register_cholesky_inverse(self._damping_func)
+
+  def _multiply_matrix(self, matrix, vector, transpose=False):
     vector_flat = utils.tensors_to_column(vector)
-    out_flat = self._factor.left_multiply_matpower(
-        vector_flat, exp, self._damping_func)
+    out_flat = matrix.matmul(vector_flat, adjoint=transpose)
     return utils.column_to_tensors(vector, out_flat)

+  def multiply_matpower(self, vector, exp):
+    matrix = self._factor.get_matpower(exp, self._damping_func)
+    return self._multiply_matrix(matrix, vector)
+
+  def multiply_cholesky(self, vector, transpose=False):
+    matrix = self._factor.get_cholesky(self._damping_func)
+    return self._multiply_matrix(matrix, vector, transpose=transpose)
+
+  def multiply_cholesky_inverse(self, vector, transpose=False):
+    matrix = self._factor.get_cholesky_inverse(self._damping_func)
+    return self._multiply_matrix(matrix, vector, transpose=transpose)
+
   def full_fisher_block(self):
     """Explicitly constructs the full Fisher block."""
-    return self._factor.get_cov()
+    return self._factor.get_cov_as_linear_operator().to_dense()

   def tensors_to_compute_grads(self):
     return self._params

@@ -305,7 +347,47 @@ class FullFB(FisherBlock):
     return math_ops.reduce_sum(self._batch_sizes)


-class NaiveDiagonalFB(FisherBlock):
+@six.add_metaclass(abc.ABCMeta)
+class DiagonalFB(FisherBlock):
+  """A base class for FisherBlocks that use diagonal approximations."""
+
+  def register_matpower(self, exp):
+    # Not needed for this. Matrix powers are computed on demand in the
+    # diagonal case.
+    pass
+
+  def register_cholesky(self):
+    # Not needed for this. Cholesky factors are computed on demand in the
+    # diagonal case.
+    pass
+
+  def register_cholesky_inverse(self):
+    # Not needed for this. Cholesky inverses are computed on demand in the
+    # diagonal case.
+    pass
+
+  def _multiply_matrix(self, matrix, vector):
+    vector_flat = utils.tensors_to_column(vector)
+    out_flat = matrix.matmul(vector_flat)
+    return utils.column_to_tensors(vector, out_flat)
+
+  def multiply_matpower(self, vector, exp):
+    matrix = self._factor.get_matpower(exp, self._damping_func)
+    return self._multiply_matrix(matrix, vector)
+
+  def multiply_cholesky(self, vector, transpose=False):
+    matrix = self._factor.get_cholesky(self._damping_func)
+    return self._multiply_matrix(matrix, vector)
+
+  def multiply_cholesky_inverse(self, vector, transpose=False):
+    matrix = self._factor.get_cholesky_inverse(self._damping_func)
+    return self._multiply_matrix(matrix, vector)
+
+  def full_fisher_block(self):
+    return self._factor.get_cov_as_linear_operator().to_dense()
+
+
+class NaiveDiagonalFB(DiagonalFB):
   """FisherBlock using a diagonal matrix approximation.

   This type of approximation is generically applicable but quite primitive.

@@ -333,20 +415,6 @@ class NaiveDiagonalFB(FisherBlock):
     self._factor = self._layer_collection.make_or_get_factor(
         fisher_factors.NaiveDiagonalFactor, (grads_list, self._batch_size))

-  def register_matpower(self, exp):
-    # Not needed for this. Matrix powers are computed on demand in the
-    # diagonal case.
-    pass
-
-  def multiply_matpower(self, vector, exp):
-    vector_flat = utils.tensors_to_column(vector)
-    out_flat = self._factor.left_multiply_matpower(
-        vector_flat, exp, self._damping_func)
-    return utils.column_to_tensors(vector, out_flat)
-
-  def full_fisher_block(self):
-    return self._factor.get_cov()
-
   def tensors_to_compute_grads(self):
     return self._params

@@ -452,7 +520,7 @@ class InputOutputMultiTower(object):
     return self.__outputs


-class FullyConnectedDiagonalFB(InputOutputMultiTower, FisherBlock):
+class FullyConnectedDiagonalFB(InputOutputMultiTower, DiagonalFB):
   """FisherBlock for fully-connected (dense) layers using a diagonal approx.

   Estimates the Fisher Information matrix's diagonal entries for a fully

@@ -497,32 +565,8 @@ class FullyConnectedDiagonalFB(InputOutputMultiTower, FisherBlock):
     self._damping_func = _package_func(lambda: damping, (damping,))

-  def register_matpower(self, exp):
-    # Not needed for this. Matrix powers are computed on demand in the
-    # diagonal case.
-    pass
-
-  def multiply_matpower(self, vector, exp):
-    """Multiplies the vector by the (damped) matrix-power of the block.
-
-    Args:
-      vector: Tensor or 2-tuple of Tensors. if self._has_bias, Tensor of shape
-        [input_size, output_size] corresponding to layer's weights. If not, a
-        2-tuple of the former and a Tensor of shape [output_size] corresponding
-        to the layer's bias.
-      exp: A scalar representing the power to raise the block before
-        multiplying it by the vector.
-
-    Returns:
-      The vector left-multiplied by the (damped) matrix-power of the block.
-    """
-    reshaped_vec = utils.layer_params_to_mat2d(vector)
-    reshaped_out = self._factor.left_multiply_matpower(
-        reshaped_vec, exp, self._damping_func)
-    return utils.mat2d_to_layer_params(vector, reshaped_out)
-

-class ConvDiagonalFB(InputOutputMultiTower, FisherBlock):
+class ConvDiagonalFB(InputOutputMultiTower, DiagonalFB):
   """FisherBlock for 2-D convolutional layers using a diagonal approx.

   Estimates the Fisher Information matrix's diagonal entries for a convolutional

@@ -621,17 +665,6 @@ class ConvDiagonalFB(InputOutputMultiTower, FisherBlock):
                                                  self._num_locations)
     self._damping_func = _package_func(damping_func, damping_id)

-  def register_matpower(self, exp):
-    # Not needed for this. Matrix powers are computed on demand in the
-    # diagonal case.
-    pass
-
-  def multiply_matpower(self, vector, exp):
-    reshaped_vect = utils.layer_params_to_mat2d(vector)
-    reshaped_out = self._factor.left_multiply_matpower(
-        reshaped_vect, exp, self._damping_func)
-    return utils.mat2d_to_layer_params(vector, reshaped_out)
-

 class KroneckerProductFB(FisherBlock):
   """A base class for blocks with separate input and output Kronecker factors.

@@ -651,9 +684,10 @@ class KroneckerProductFB(FisherBlock):
       else:
         maybe_normalized_damping = damping

-      return compute_pi_adjusted_damping(self._input_factor.get_cov(),
-                                         self._output_factor.get_cov(),
-                                         maybe_normalized_damping**0.5)
+      return compute_pi_adjusted_damping(
+          self._input_factor.get_cov_as_linear_operator(),
+          self._output_factor.get_cov_as_linear_operator(),
+          maybe_normalized_damping**0.5)

     if normalization is not None:
       damping_id = ("compute_pi_adjusted_damping",

@@ -675,6 +709,14 @@ class KroneckerProductFB(FisherBlock):
     self._input_factor.register_matpower(exp, self._input_damping_func)
     self._output_factor.register_matpower(exp, self._output_damping_func)

+  def register_cholesky(self):
+    self._input_factor.register_cholesky(self._input_damping_func)
+    self._output_factor.register_cholesky(self._output_damping_func)
+
+  def register_cholesky_inverse(self):
+    self._input_factor.register_cholesky_inverse(self._input_damping_func)
+    self._output_factor.register_cholesky_inverse(self._output_damping_func)
+
   @property
   def _renorm_coeff(self):
     """Kronecker factor multiplier coefficient.

@@ -687,17 +729,47 @@ class KroneckerProductFB(FisherBlock):
     """
     return 1.0

-  def multiply_matpower(self, vector, exp):
+  def _multiply_factored_matrix(self, left_factor, right_factor, vector,
+                                extra_scale=1.0, transpose_left=False,
+                                transpose_right=False):
     reshaped_vector = utils.layer_params_to_mat2d(vector)
-    reshaped_out = self._output_factor.right_multiply_matpower(
-        reshaped_vector, exp, self._output_damping_func)
-    reshaped_out = self._input_factor.left_multiply_matpower(
-        reshaped_out, exp, self._input_damping_func)
-    if self._renorm_coeff != 1.0:
-      renorm_coeff = math_ops.cast(self._renorm_coeff,
-                                   dtype=reshaped_out.dtype)
-      reshaped_out *= math_ops.cast(renorm_coeff**exp,
-                                    dtype=reshaped_out.dtype)
+    reshaped_out = right_factor.matmul_right(reshaped_vector,
+                                             adjoint=transpose_right)
+    reshaped_out = left_factor.matmul(reshaped_out,
+                                      adjoint=transpose_left)
+    if extra_scale != 1.0:
+      reshaped_out *= math_ops.cast(extra_scale, dtype=reshaped_out.dtype)
     return utils.mat2d_to_layer_params(vector, reshaped_out)

+  def multiply_matpower(self, vector, exp):
+    left_factor = self._input_factor.get_matpower(
+        exp, self._input_damping_func)
+    right_factor = self._output_factor.get_matpower(
+        exp, self._output_damping_func)
+    extra_scale = float(self._renorm_coeff)**exp
+    return self._multiply_factored_matrix(left_factor, right_factor, vector,
+                                          extra_scale=extra_scale)
+
+  def multiply_cholesky(self, vector, transpose=False):
+    left_factor = self._input_factor.get_cholesky(self._input_damping_func)
+    right_factor = self._output_factor.get_cholesky(self._output_damping_func)
+    extra_scale = float(self._renorm_coeff)**0.5
+    return self._multiply_factored_matrix(left_factor, right_factor, vector,
+                                          extra_scale=extra_scale,
+                                          transpose_left=transpose,
+                                          transpose_right=not transpose)
+
+  def multiply_cholesky_inverse(self, vector, transpose=False):
+    left_factor = self._input_factor.get_cholesky_inverse(
+        self._input_damping_func)
+    right_factor = self._output_factor.get_cholesky_inverse(
+        self._output_damping_func)
+    extra_scale = float(self._renorm_coeff)**-0.5
+    return self._multiply_factored_matrix(left_factor, right_factor, vector,
+                                          extra_scale=extra_scale,
+                                          transpose_left=transpose,
+                                          transpose_right=not transpose)
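The factored multiply above computes left_factor @ V @ right_factor on the reshaped parameters instead of ever forming the Kronecker product. A numpy check of the underlying identity, (A kron B) vec(V) = vec(A V B^T) with numpy's row-major flattening (illustrative only, not the kfac API):

```python
import numpy as np

rng = np.random.RandomState(0)
A = rng.randn(3, 3)   # plays the role of the input (left) factor
B = rng.randn(2, 2)   # plays the role of the output (right) factor
V = rng.randn(3, 2)   # plays the role of layer_params_to_mat2d(vector)

# Multiplying the flattened parameters by the full Kronecker product...
lhs = np.kron(A, B) @ V.reshape(-1)
# ...equals the cheap two-sided product on the 2-D parameter matrix.
rhs = (A @ V @ B.T).reshape(-1)
assert np.allclose(lhs, rhs)
```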
   def full_fisher_block(self):
     """Explicitly constructs the full Fisher block.

@@ -706,8 +778,8 @@ class KroneckerProductFB(FisherBlock):
     Returns:
       The full Fisher block.
     """
-    left_factor = self._input_factor.get_cov()
-    right_factor = self._output_factor.get_cov()
+    left_factor = self._input_factor.get_cov_as_linear_operator().to_dense()
+    right_factor = self._output_factor.get_cov_as_linear_operator().to_dense()
     return self._renorm_coeff * utils.kronecker_product(left_factor,
                                                         right_factor)

@@ -796,7 +868,7 @@ class FullyConnectedKFACBasicFB(InputOutputMultiTower, KroneckerProductFB):

 class ConvKFCBasicFB(InputOutputMultiTower, KroneckerProductFB):
-  """FisherBlock for convolutional layers using the basic KFC approx.
+  r"""FisherBlock for convolutional layers using the basic KFC approx.

   Estimates the Fisher Information matrix's block for a convolutional
   layer.

@@ -945,10 +1017,10 @@ class DepthwiseConvDiagonalFB(ConvDiagonalFB):
     self._filter_shape = (filter_height, filter_width, in_channels,
                           in_channels * channel_multiplier)

-  def multiply_matpower(self, vector, exp):
+  def _multiply_matrix(self, matrix, vector):
     conv2d_vector = depthwise_conv2d_filter_to_conv2d_filter(vector)
-    conv2d_result = super(DepthwiseConvDiagonalFB, self).multiply_matpower(
-        conv2d_vector, exp)
+    conv2d_result = super(
+        DepthwiseConvDiagonalFB, self)._multiply_matrix(matrix, conv2d_vector)
     return conv2d_filter_to_depthwise_conv2d_filter(conv2d_result)

@@ -1016,10 +1088,14 @@ class DepthwiseConvKFCBasicFB(ConvKFCBasicFB):
     self._filter_shape = (filter_height, filter_width, in_channels,
                           in_channels * channel_multiplier)

-  def multiply_matpower(self, vector, exp):
+  def _multiply_factored_matrix(self, left_factor, right_factor, vector,
+                                extra_scale=1.0, transpose_left=False,
+                                transpose_right=False):
     conv2d_vector = depthwise_conv2d_filter_to_conv2d_filter(vector)
-    conv2d_result = super(DepthwiseConvKFCBasicFB, self).multiply_matpower(
-        conv2d_vector, exp)
+    conv2d_result = super(
+        DepthwiseConvKFCBasicFB, self)._multiply_factored_matrix(
+            left_factor, right_factor, conv2d_vector, extra_scale=extra_scale,
+            transpose_left=transpose_left, transpose_right=transpose_right)
     return conv2d_filter_to_depthwise_conv2d_filter(conv2d_result)

@@ -1664,3 +1740,12 @@ class FullyConnectedSeriesFB(InputOutputMultiTowerMultiUse,
       return utils.mat2d_to_layer_params(vector, Z)

   # pylint: enable=invalid-name
+
+  def multiply_cholesky(self, vector):
+    raise NotImplementedError("FullyConnectedSeriesFB does not support "
+                              "Cholesky computations.")
+
+  def multiply_cholesky_inverse(self, vector):
+    raise NotImplementedError("FullyConnectedSeriesFB does not support "
+                              "Cholesky computations.")

View File

@@ -24,6 +24,7 @@ import contextlib
 import numpy as np
 import six

+from tensorflow.contrib.kfac.python.ops import linear_operator as lo
 from tensorflow.contrib.kfac.python.ops import utils
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops as tf_ops

@@ -399,7 +400,7 @@ class FisherFactor(object):
         the cov update.

     Returns:
-      Tensor of same shape as self.get_cov_var().
+      Tensor of same shape as self.get_cov().
     """
     pass

@@ -448,78 +449,43 @@ class FisherFactor(object):
     """Create and return update ops corresponding to registered computations."""
     pass

-  @abc.abstractmethod
   def get_cov(self):
-    """Get full covariance matrix.
-
-    Returns:
-      Tensor of shape [n, n]. Represents all parameter-parameter correlations
-      captured by this FisherFactor.
-    """
-    pass
-
-  def get_cov_var(self):
-    """Get variable backing this FisherFactor.
-
-    May or may not be the same as self.get_cov()
-
-    Returns:
-      Variable of shape self._cov_shape.
-    """
     return self._cov

   @abc.abstractmethod
-  def left_multiply_matpower(self, x, exp, damping_func):
-    """Left multiplies 'x' by matrix power of this factor (w/ damping applied).
-
-    This calculation is essentially:
-      (C + damping * I)**exp * x
-    where * is matrix-multiplication, ** is matrix power, I is the identity
-    matrix, and C is the matrix represented by this factor.
-
-    x can represent either a matrix or a vector.  For some factors, 'x' might
-    represent a vector but actually be stored as a 2D matrix for convenience.
-
-    Args:
-      x: Tensor. Represents a single vector. Shape depends on implementation.
-      exp: float. The matrix exponent to use.
-      damping_func: A function that computes a 0-D Tensor or a float which will
-        be the damping value used. i.e. damping = damping_func().
-
-    Returns:
-      Tensor of same shape as 'x' representing the result of the
-      multiplication.
-    """
+  def get_cov_as_linear_operator(self):
     pass

   @abc.abstractmethod
-  def right_multiply_matpower(self, x, exp, damping_func):
-    """Right multiplies 'x' by matrix power of this factor (w/ damping applied).
-
-    This calculation is essentially:
-      x * (C + damping * I)**exp
-    where * is matrix-multiplication, ** is matrix power, I is the identity
-    matrix, and C is the matrix represented by this factor.
-
-    Unlike left_multiply_matpower, x will always be a matrix.
-
-    Args:
-      x: Tensor. Represents a single vector. Shape depends on implementation.
-      exp: float. The matrix exponent to use.
-      damping_func: A function that computes a 0-D Tensor or a float which will
-        be the damping value used. i.e. damping = damping_func().
-
-    Returns:
-      Tensor of same shape as 'x' representing the result of the
-      multiplication.
-    """
+  def register_matpower(self, exp, damping_func):
+    pass
+
+  @abc.abstractmethod
+  def register_cholesky(self, damping_func):
+    pass
+
+  @abc.abstractmethod
+  def register_cholesky_inverse(self, damping_func):
+    pass
+
+  @abc.abstractmethod
+  def get_matpower(self, exp, damping_func):
+    pass
+
+  @abc.abstractmethod
+  def get_cholesky(self, damping_func):
+    pass
+
+  @abc.abstractmethod
+  def get_cholesky_inverse(self, damping_func):
     pass


-class InverseProvidingFactor(FisherFactor):
-  """Base class for FisherFactors that maintain inverses explicitly.
+class DenseSquareMatrixFactor(FisherFactor):
+  """Base class for FisherFactors that are stored as dense square matrices.

-  This class explicitly calculates and stores inverses of covariance matrices
-  provided by the underlying FisherFactor implementation. It is assumed that
-  vectors can be represented as 2-D matrices.
+  This class explicitly calculates and stores inverses of their `cov` matrices,
+  which must be square dense matrices.

   Subclasses must implement the _compute_new_cov method, and the _var_scope and
   _cov_shape properties.

@@ -538,7 +504,19 @@ class DenseSquareMatrixFactor(FisherFactor):
     self._eigendecomp = None
     self._damping_funcs_by_id = {}  # {hashable: lambda}

-    super(InverseProvidingFactor, self).__init__()
+    self._cholesky_registrations = set()  # { hashable }
+    self._cholesky_inverse_registrations = set()  # { hashable }
+
+    self._cholesky_by_damping = {}  # { hashable: variable }
+    self._cholesky_inverse_by_damping = {}  # { hashable: variable }
+
+    super(DenseSquareMatrixFactor, self).__init__()
+
+  def get_cov_as_linear_operator(self):
+    assert self.get_cov().shape.ndims == 2
+    return lo.LinearOperatorFullMatrix(self.get_cov(),
+                                       is_self_adjoint=True,
+                                       is_square=True)

   def _register_damping(self, damping_func):
     damping_id = graph_func_to_id(damping_func)

@@ -563,8 +541,6 @@ class DenseSquareMatrixFactor(FisherFactor):
         be the damping value used. i.e. damping = damping_func().
     """
     if exp == 1.0:
-      # We don't register these. The user shouldn't even be calling this
-      # function with exp = 1.0.
       return

     damping_id = self._register_damping(damping_func)

@@ -572,6 +548,38 @@ class DenseSquareMatrixFactor(FisherFactor):
     if (exp, damping_id) not in self._matpower_registrations:
       self._matpower_registrations.add((exp, damping_id))

+  def register_cholesky(self, damping_func):
+    """Registers a Cholesky factor to be maintained and served on demand.
+
+    This creates a variable and signals make_inverse_update_ops to make the
+    corresponding update op. The variable can be read via the method
+    get_cholesky.
+
+    Args:
+      damping_func: A function that computes a 0-D Tensor or a float which will
+        be the damping value used. i.e. damping = damping_func().
+    """
+    damping_id = self._register_damping(damping_func)
+
+    if damping_id not in self._cholesky_registrations:
+      self._cholesky_registrations.add(damping_id)
+
+  def register_cholesky_inverse(self, damping_func):
+    """Registers an inverse Cholesky factor to be maintained/served on demand.
+
+    This creates a variable and signals make_inverse_update_ops to make the
+    corresponding update op. The variable can be read via the method
+    get_cholesky_inverse.
+
+    Args:
+      damping_func: A function that computes a 0-D Tensor or a float which will
+        be the damping value used. i.e. damping = damping_func().
+    """
+    damping_id = self._register_damping(damping_func)
+
+    if damping_id not in self._cholesky_inverse_registrations:
+      self._cholesky_inverse_registrations.add(damping_id)
+
   def instantiate_inv_variables(self):
     """Makes the internal "inverse" variable(s)."""

@@ -589,6 +597,32 @@ class DenseSquareMatrixFactor(FisherFactor):
         assert (exp, damping_id) not in self._matpower_by_exp_and_damping
         self._matpower_by_exp_and_damping[(exp, damping_id)] = matpower

+    for damping_id in self._cholesky_registrations:
+      damping_func = self._damping_funcs_by_id[damping_id]
+      damping_string = graph_func_to_string(damping_func)
+      with variable_scope.variable_scope(self._var_scope):
+        chol = variable_scope.get_variable(
+            "cholesky_damp{}".format(damping_string),
+            initializer=inverse_initializer,
+            shape=self._cov_shape,
+            trainable=False,
+            dtype=self._dtype)
+      assert damping_id not in self._cholesky_by_damping
+      self._cholesky_by_damping[damping_id] = chol
+
+    for damping_id in self._cholesky_inverse_registrations:
+      damping_func = self._damping_funcs_by_id[damping_id]
+      damping_string = graph_func_to_string(damping_func)
+      with variable_scope.variable_scope(self._var_scope):
+        cholinv = variable_scope.get_variable(
+            "cholesky_inverse_damp{}".format(damping_string),
+            initializer=inverse_initializer,
+            shape=self._cov_shape,
+            trainable=False,
+            dtype=self._dtype)
+      assert damping_id not in self._cholesky_inverse_by_damping
+      self._cholesky_inverse_by_damping[damping_id] = cholinv
+
   def make_inverse_update_ops(self):
     """Create and return update ops corresponding to registered computations."""
     ops = []

@@ -606,7 +640,8 @@ class DenseSquareMatrixFactor(FisherFactor):
     # We precompute these so we don't need to evaluate them multiple times (for
     # each matrix power that uses them)
-    damping_value_by_id = {damping_id: self._damping_funcs_by_id[damping_id]()
+    damping_value_by_id = {damping_id: math_ops.cast(
+        self._damping_funcs_by_id[damping_id](), self._dtype)
                            for damping_id in self._damping_funcs_by_id}

     if use_eig:

@@ -627,29 +662,91 @@ class DenseSquareMatrixFactor(FisherFactor):
           self._matpower_by_exp_and_damping.items()):
         assert exp == -1
         damping = damping_value_by_id[damping_id]
-        ops.append(matpower.assign(utils.posdef_inv(self._cov, damping)))
+        ops.append(matpower.assign(utils.posdef_inv(self.get_cov(), damping)))

+    # TODO(b/77902055): If inverses are being computed with Choleskys
+    # we can share the work. Instead this code currently just computes the
+    # Cholesky a second time. It does at least share work between requests for
+    # Choleskys and Cholesky inverses with the same damping id.
+    for damping_id, cholesky_inv in self._cholesky_inverse_by_damping.items():
+      cholesky_ops = []
+
+      damping = damping_value_by_id[damping_id]
+      cholesky_value = utils.cholesky(self.get_cov(), damping)
+
+      if damping_id in self._cholesky_by_damping:
+        cholesky = self._cholesky_by_damping[damping_id]
+        cholesky_ops.append(cholesky.assign(cholesky_value))
+
+      identity = linalg_ops.eye(cholesky_value.shape.as_list()[0],
+                                dtype=cholesky_value.dtype)
+      cholesky_inv_value = linalg_ops.matrix_triangular_solve(cholesky_value,
+                                                              identity)
+      cholesky_ops.append(cholesky_inv.assign(cholesky_inv_value))
+
+      ops.append(control_flow_ops.group(*cholesky_ops))
+
+    for damping_id, cholesky in self._cholesky_by_damping.items():
+      if damping_id not in self._cholesky_inverse_by_damping:
+        damping = damping_value_by_id[damping_id]
+        cholesky_value = utils.cholesky(self.get_cov(), damping)
+        ops.append(cholesky.assign(cholesky_value))
+
     self._eigendecomp = False
     return ops
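The update op above obtains the Cholesky inverse with a triangular solve against the identity rather than a general matrix inverse, which is cheaper and more stable for a triangular factor. The same computation in numpy/scipy (illustrative only, not the kfac API):

```python
import numpy as np
from scipy.linalg import solve_triangular

cov = np.array([[6., 2.], [2., 4.]])
damping = 0.5

# Cholesky of the damped covariance, as utils.cholesky is used above.
L = np.linalg.cholesky(cov + damping * np.eye(2))

# Solve L @ X = I for X; lower=True matches matrix_triangular_solve's default.
L_inv = solve_triangular(L, np.eye(2), lower=True)
assert np.allclose(L_inv @ L, np.eye(2))
```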
def get_inverse(self, damping_func): def get_inverse(self, damping_func):
# Just for backwards compatibility of some old code and tests # Just for backwards compatibility of some old code and tests
damping_id = graph_func_to_id(damping_func) return self.get_matpower(-1, damping_func)
return self._matpower_by_exp_and_damping[(-1, damping_id)]
def get_matpower(self, exp, damping_func): def get_matpower(self, exp, damping_func):
# Note that this function returns a variable which gets updated by the
# inverse ops. It may be stale / inconsistent with the latest value of
# get_cov().
if exp != 1:
damping_id = graph_func_to_id(damping_func)
matpower = self._matpower_by_exp_and_damping[(exp, damping_id)]
else:
matpower = self.get_cov()
identity = linalg_ops.eye(matpower.shape.as_list()[0],
dtype=matpower.dtype)
matpower += math_ops.cast(damping_func(), dtype=matpower.dtype)*identity
assert matpower.shape.ndims == 2
return lo.LinearOperatorFullMatrix(matpower,
is_non_singular=True,
is_self_adjoint=True,
is_positive_definite=True,
is_square=True)
def get_cholesky(self, damping_func):
# Note that this function returns a variable which gets updated by the # Note that this function returns a variable which gets updated by the
# inverse ops. It may be stale / inconsistent with the latest value of # inverse ops. It may be stale / inconsistent with the latest value of
# get_cov(). # get_cov().
damping_id = graph_func_to_id(damping_func) damping_id = graph_func_to_id(damping_func)
return self._matpower_by_exp_and_damping[(exp, damping_id)] cholesky = self._cholesky_by_damping[damping_id]
assert cholesky.shape.ndims == 2
return lo.LinearOperatorFullMatrix(cholesky,
is_non_singular=True,
is_square=True)
def get_cholesky_inverse(self, damping_func):
# Note that this function returns a variable which gets updated by the
# inverse ops. It may be stale / inconsistent with the latest value of
# get_cov().
damping_id = graph_func_to_id(damping_func)
cholesky_inv = self._cholesky_inverse_by_damping[damping_id]
assert cholesky_inv.shape.ndims == 2
return lo.LinearOperatorFullMatrix(cholesky_inv,
is_non_singular=True,
is_square=True)
def get_eigendecomp(self): def get_eigendecomp(self):
"""Creates or retrieves eigendecomposition of self._cov.""" """Creates or retrieves eigendecomposition of self._cov."""
# Unlike get_matpower this doesn't retrieve a stored variable, but instead # Unlike get_matpower this doesn't retrieve a stored variable, but instead
# always computes a fresh version from the current value of get_cov(). # always computes a fresh version from the current value of get_cov().
if not self._eigendecomp: if not self._eigendecomp:
eigenvalues, eigenvectors = linalg_ops.self_adjoint_eig(self._cov) eigenvalues, eigenvectors = linalg_ops.self_adjoint_eig(self.get_cov())
# The matrix self._cov is positive semidefinite by construction, but the # The matrix self._cov is positive semidefinite by construction, but the
# numerical eigenvalues could be negative due to numerical errors, so here # numerical eigenvalues could be negative due to numerical errors, so here
@ -660,45 +757,8 @@ class InverseProvidingFactor(FisherFactor):
return self._eigendecomp return self._eigendecomp
def get_cov(self):
# Variable contains full covariance matrix.
return self.get_cov_var()
def left_multiply_matpower(self, x, exp, damping_func): class FullFactor(DenseSquareMatrixFactor):
if isinstance(x, tf_ops.IndexedSlices):
raise ValueError("Left-multiply not yet supported for IndexedSlices.")
if x.shape.ndims != 2:
raise ValueError(
"InverseProvidingFactors apply to matrix-shaped vectors. Found: %s."
% (x,))
if exp == 1:
return math_ops.matmul(self.get_cov(), x) + damping_func() * x
return math_ops.matmul(self.get_matpower(exp, damping_func), x)
def right_multiply_matpower(self, x, exp, damping_func):
if isinstance(x, tf_ops.IndexedSlices):
if exp == 1:
n = self.get_cov().shape[0]
damped_cov = self.get_cov() + damping_func() * array_ops.eye(n)
return utils.matmul_sparse_dense(x, damped_cov)
return utils.matmul_sparse_dense(x, self.get_matpower(exp, damping_func))
if x.shape.ndims != 2:
raise ValueError(
"InverseProvidingFactors apply to matrix-shaped vectors. Found: %s."
% (x,))
if exp == 1:
return math_ops.matmul(x, self.get_cov()) + damping_func() * x
return math_ops.matmul(x, self.get_matpower(exp, damping_func))
class FullFactor(DenseSquareMatrixFactor):
"""FisherFactor for a full matrix representation of the Fisher of a parameter.

Note that this uses the naive "square the sum estimator", and so is applicable
Note that this uses the naive "square the sum estimator", and so is applicable Note that this uses the naive "square the sum estimator", and so is applicable
@ -757,42 +817,52 @@ class DiagonalFactor(FisherFactor):
""" """
def __init__(self): def __init__(self):
self._damping_funcs_by_id = {} # { hashable: lambda }
super(DiagonalFactor, self).__init__()
def get_cov_as_linear_operator(self):
assert self._matrix_diagonal.shape.ndims == 1
return lo.LinearOperatorDiag(self._matrix_diagonal,
is_self_adjoint=True,
is_square=True)
@property
def _cov_initializer(self):
return diagonal_covariance_initializer
@property
def _matrix_diagonal(self):
return array_ops.reshape(self.get_cov(), [-1])
def make_inverse_update_ops(self):
return []

def instantiate_inv_variables(self):
pass
def get_cov(self):
# self.get_cov() could be any shape, but it must have one entry per
# parameter. Flatten it into a vector.
cov_diag_vec = array_ops.reshape(self.get_cov_var(), [-1])
return array_ops.diag(cov_diag_vec)
def left_multiply_matpower(self, x, exp, damping_func):
matpower = (self.get_cov_var() + damping_func())**exp
if isinstance(x, tf_ops.IndexedSlices):
return utils.matmul_diag_sparse(array_ops.reshape(matpower, [-1]), x)
if x.shape != matpower.shape:
raise ValueError("x (%s) and cov (%s) must have same shape." %
(x, matpower))
return matpower * x
def right_multiply_matpower(self, x, exp, damping_func):
raise NotImplementedError("Only left-multiply is currently supported.")
def register_matpower(self, exp, damping_func):
pass
def register_cholesky(self, damping_func):
pass
def register_cholesky_inverse(self, damping_func):
pass
def get_matpower(self, exp, damping_func):
matpower_diagonal = (self._matrix_diagonal
+ math_ops.cast(damping_func(), self._dtype))**exp
return lo.LinearOperatorDiag(matpower_diagonal,
is_non_singular=True,
is_self_adjoint=True,
is_positive_definite=True,
is_square=True)
def get_cholesky(self, damping_func):
return self.get_matpower(0.5, damping_func)
def get_cholesky_inverse(self, damping_func):
return self.get_matpower(-0.5, damping_func)
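Since the covariance is stored as a diagonal, every damped matrix power above is elementwise; a tiny numeric illustration (made-up values, assuming numpy is imported as np):

  d = np.array([1., 4.])       # diagonal of C
  inv_diag = (d + 1.) ** -1.   # diag of (C + 1.0 * I)^-1, i.e. [0.5, 0.2]
  # get_matpower(-1, damping_func) wraps exactly this vector in a
  # LinearOperatorDiag; get_cholesky is just the exp = 0.5 special case.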
class NaiveDiagonalFactor(DiagonalFactor):
"""FisherFactor for a diagonal approximation of any type of param's Fisher.
@ -1167,7 +1237,7 @@ class ConvDiagonalFactor(DiagonalFactor):
return self._inputs[tower].device

class FullyConnectedKroneckerFactor(DenseSquareMatrixFactor):
"""Kronecker factor for the input or output side of a fully-connected layer.
"""
@ -1220,7 +1290,7 @@ class FullyConnectedKroneckerFactor(InverseProvidingFactor):
return self._tensors[0][tower].device

class ConvInputKroneckerFactor(DenseSquareMatrixFactor):
r"""Kronecker factor for the input side of a convolutional layer.

Estimates E[ a a^T ] where a is the inputs to a convolutional layer given
@ -1384,7 +1454,7 @@ class ConvInputKroneckerFactor(InverseProvidingFactor):
return self._inputs[tower].device

class ConvOutputKroneckerFactor(DenseSquareMatrixFactor):
r"""Kronecker factor for the output side of a convolutional layer.

Estimates E[ ds ds^T ] where s is the preactivations of a convolutional layer
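For context, these two factors exist because standard K-FAC approximates a layer's Fisher block by a Kronecker product of the two expectations named in their docstrings,

  F \approx \mathbb{E}[a a^\top] \otimes \mathbb{E}[ds \, ds^\top]

with the input-side factor estimating the left operand and the output-side factor the right.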
@ -1674,6 +1744,7 @@ class FullyConnectedMultiKF(FullyConnectedKroneckerFactor):
psi_var) in self._option1quants_by_damping.items():
damping = self._damping_funcs_by_id[damping_id]()
damping = math_ops.cast(damping, self._dtype)
invsqrtC0 = math_ops.matmul(
eigen_V * (eigen_e + damping)**(-0.5), eigen_V, transpose_b=True)
@ -1702,6 +1773,7 @@ class FullyConnectedMultiKF(FullyConnectedKroneckerFactor):
mu_var) in self._option2quants_by_damping.items():
damping = self._damping_funcs_by_id[damping_id]()
damping = math_ops.cast(damping, self._dtype)
# compute C0^(-1/2)
invsqrtC0 = math_ops.matmul(

View File

@ -0,0 +1,95 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""SmartMatrices definitions."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.contrib.kfac.python.ops import utils
from tensorflow.python.framework import ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.linalg import linalg
from tensorflow.python.ops.linalg import linalg_impl
from tensorflow.python.ops.linalg import linear_operator_util as lou
class LinearOperatorExtras(object): # pylint: disable=missing-docstring
def matmul(self, x, adjoint=False, adjoint_arg=False, name="matmul"):
with self._name_scope(name, values=[x]):
if isinstance(x, ops.IndexedSlices):
return self._matmul_sparse(x, adjoint=adjoint, adjoint_arg=adjoint_arg)
x = ops.convert_to_tensor(x, name="x")
self._check_input_dtype(x)
self_dim = -2 if adjoint else -1
arg_dim = -1 if adjoint_arg else -2
self.shape[self_dim].assert_is_compatible_with(x.get_shape()[arg_dim])
return self._matmul(x, adjoint=adjoint, adjoint_arg=adjoint_arg)
def matmul_right(self, x, adjoint=False, adjoint_arg=False, name="matmul"):
with self._name_scope(name, values=[x]):
if isinstance(x, ops.IndexedSlices):
return self._matmul_right_sparse(
x, adjoint=adjoint, adjoint_arg=adjoint_arg)
x = ops.convert_to_tensor(x, name="x")
self._check_input_dtype(x)
self_dim = -1 if adjoint else -2
arg_dim = -2 if adjoint_arg else -1
self.shape[self_dim].assert_is_compatible_with(x.get_shape()[arg_dim])
return self._matmul_right(x, adjoint=adjoint, adjoint_arg=adjoint_arg)
class LinearOperatorFullMatrix(LinearOperatorExtras,
linalg.LinearOperatorFullMatrix):
# TODO(b/78117889) Remove this definition once core LinearOperator
# has _matmul_right.
def _matmul_right(self, x, adjoint=False, adjoint_arg=False):
return lou.matmul_with_broadcast(
x, self._matrix, adjoint_a=adjoint_arg, adjoint_b=adjoint)
def _matmul_sparse(self, x, adjoint=False, adjoint_arg=False):
raise NotImplementedError
def _matmul_right_sparse(self, x, adjoint=False, adjoint_arg=False):
assert not adjoint and not adjoint_arg
return utils.matmul_sparse_dense(x, self._matrix)
class LinearOperatorDiag(LinearOperatorExtras, # pylint: disable=missing-docstring
linalg.LinearOperatorDiag):
def _matmul_right(self, x, adjoint=False, adjoint_arg=False):
diag_mat = math_ops.conj(self._diag) if adjoint else self._diag
x = linalg_impl.adjoint(x) if adjoint_arg else x
return diag_mat * x
def _matmul_sparse(self, x, adjoint=False, adjoint_arg=False):
diag_mat = math_ops.conj(self._diag) if adjoint else self._diag
assert not adjoint_arg
return utils.matmul_diag_sparse(diag_mat, x)
def _matmul_right_sparse(self, x, adjoint=False, adjoint_arg=False):
raise NotImplementedError
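A behavior sketch for the two multiply hooks these classes add (illustrative values only; `lo` is the alias this module is imported under in the tests):

  op = lo.LinearOperatorFullMatrix([[2., 0.], [0., 3.]])
  x = [[1., 1.], [1., 1.]]
  y = op.matmul(x)        # op @ x, the usual left multiply
  z = op.matmul_right(x)  # x @ op, via the new _matmul_right hook
  # tf.IndexedSlices arguments are routed to _matmul_sparse /
  # _matmul_right_sparse instead of being densified.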

View File

@ -35,7 +35,7 @@ def _make_thunk_on_device(func, device):
class RoundRobinPlacementMixin(object):
"""Implements round robin placement strategy for ops and variables."""

def __init__(self, cov_devices=None, inv_devices=None, **kwargs):
"""Initializes the RoundRobinPlacementMixin class.

Args:
@ -45,11 +45,10 @@ class RoundRobinPlacementMixin(object):
inv_devices: Iterable of device strings (e.g. '/gpu:0'). Inversion
computations will be placed on these devices in a round-robin fashion.
Can be None, which means that no devices are specified.
**kwargs: Keyword arguments to forward to the superclass constructor.
"""
super(RoundRobinPlacementMixin, self).__init__(**kwargs)
self._cov_devices = cov_devices
self._inv_devices = inv_devices
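The switch from `*args` to `**kwargs` assumes cooperative multiple inheritance: each class consumes its own keyword arguments and forwards the rest. A schematic with hypothetical class names:

  class _Base(object):
    def __init__(self, name=None):
      self._name = name

  class _Estimator(RoundRobinPlacementMixin, _Base):
    pass

  # cov_devices/inv_devices are consumed by the mixin; `name` travels
  # through **kwargs to _Base.__init__ along the MRO.
  est = _Estimator(cov_devices=["/gpu:0"], inv_devices=["/gpu:1"], name="est")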

View File

@ -235,6 +235,13 @@ posdef_eig_functions = {
}
def cholesky(tensor, damping):
"""Computes the inverse of tensor + damping * identity."""
identity = linalg_ops.eye(tensor.shape.as_list()[0], dtype=tensor.dtype)
damping = math_ops.cast(damping, dtype=tensor.dtype)
return linalg_ops.cholesky(tensor + damping * identity)
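For example (made-up values; assumes array_ops is imported in this module like the other TF op modules):

  cov = array_ops.diag([4., 4.])
  chol = cholesky(cov, 0.25)  # lower-triangular L with L @ L^T == cov + 0.25 * I,
                              # here diag(sqrt(4.25), sqrt(4.25))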
class SubGraph(object):
"""Defines a subgraph given by all the dependencies of a given set of outputs.
"""
@ -553,13 +560,17 @@ def is_data_format_channel_last(data_format):
return data_format.endswith("C")
def matmul_sparse_dense(A, B, name=None, transpose_a=False, transpose_b=False):  # pylint: disable=invalid-name
"""Computes matmul(A, B) where A is sparse, B is dense.

Args:
A: tf.IndexedSlices with dense shape [m, n].
B: tf.Tensor with shape [n, k].
name: str. Name of op.
transpose_a: Bool. If true we transpose A before multiplying it by B.
(Default: False)
transpose_b: Bool. If true we transpose B before multiplying it by A.
(Default: False)
Returns:
tf.IndexedSlices resulting from matmul(A, B).
@ -573,7 +584,8 @@ def matmul_sparse_dense(A, B, name=None): # pylint: disable=invalid-name
raise ValueError("A must represent a matrix. Found: %s." % A) raise ValueError("A must represent a matrix. Found: %s." % A)
if B.shape.ndims != 2: if B.shape.ndims != 2:
raise ValueError("B must be a matrix.") raise ValueError("B must be a matrix.")
new_values = math_ops.matmul(A.values, B) new_values = math_ops.matmul(
A.values, B, transpose_a=transpose_a, transpose_b=transpose_b)
return ops.IndexedSlices(
new_values,
A.indices,
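An illustrative call (tensor values made up; assumes constant_op is available for building constants):

  A = ops.IndexedSlices(
      values=constant_op.constant([[1., 0.], [0., 1.]]),
      indices=constant_op.constant([0, 2]),
      dense_shape=constant_op.constant([3, 2]))
  B = constant_op.constant([[5., 6.], [7., 8.]])
  C = matmul_sparse_dense(A, B)
  # C is an IndexedSlices whose rows 0 and 2 hold [5., 6.] and [7., 8.];
  # all other rows of the [3, 2] result remain implicitly zero.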