- Adding support for Cholesky (inverse) factor multiplications.

- Refactored FisherFactor to use LinearOperator classes that know how to multiply themselves, compute their own trace, etc. This addresses the feature request: b/73356352
- Fixed some problems with FisherEstimator construction
- More careful casting of damping constants before they are used

PiperOrigin-RevId: 194379298
This commit is contained in:
James Martens 2018-04-26 04:37:28 -07:00 committed by TensorFlower Gardener
parent 8148895adc
commit 481f229881
11 changed files with 637 additions and 282 deletions

View File

@ -58,6 +58,7 @@ py_test(
deps = [
"//tensorflow/contrib/kfac/python/ops:fisher_blocks",
"//tensorflow/contrib/kfac/python/ops:layer_collection",
"//tensorflow/contrib/kfac/python/ops:linear_operator",
"//tensorflow/contrib/kfac/python/ops:utils",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",

View File

@ -22,6 +22,7 @@ import numpy as np
from tensorflow.contrib.kfac.python.ops import fisher_blocks as fb
from tensorflow.contrib.kfac.python.ops import layer_collection as lc
from tensorflow.contrib.kfac.python.ops import linear_operator as lo
from tensorflow.contrib.kfac.python.ops import utils
from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed
@ -46,8 +47,9 @@ class UtilsTest(test.TestCase):
def testComputePiTracenorm(self):
with ops.Graph().as_default(), self.test_session() as sess:
random_seed.set_random_seed(200)
left_factor = array_ops.diag([1., 2., 0., 1.])
right_factor = array_ops.ones([2., 2.])
diag = ops.convert_to_tensor([1., 2., 0., 1.])
left_factor = lo.LinearOperatorDiag(diag)
right_factor = lo.LinearOperatorFullMatrix(array_ops.ones([2, 2]))
# pi is the sqrt of the left trace norm divided by the right trace norm
pi = fb.compute_pi_tracenorm(left_factor, right_factor)
@ -245,7 +247,6 @@ class NaiveDiagonalFBTest(test.TestCase):
full = sess.run(block.full_fisher_block())
explicit = np.dot(np.linalg.inv(full + damping * np.eye(3)), v_flat)
self.assertAllClose(output_flat, explicit)

View File

@ -70,18 +70,6 @@ class FisherFactorTestingDummy(ff.FisherFactor):
def get_cov(self):
  """Stub: this testing dummy does not provide a covariance.

  Raises:
    NotImplementedError: Always.
  """
  # Raise instead of returning the exception class: a returned class would be
  # handed to the caller as if it were a valid result.
  raise NotImplementedError
def left_multiply(self, x, damping):
return NotImplementedError
def right_multiply(self, x, damping):
return NotImplementedError
def left_multiply_matpower(self, x, exp, damping):
return NotImplementedError
def right_multiply_matpower(self, x, exp, damping):
return NotImplementedError
def instantiate_inv_variables(self):
  """Stub: this testing dummy has no inverse variables to create.

  Raises:
    NotImplementedError: Always.
  """
  # Raise instead of returning the exception class, so an accidental call
  # fails loudly rather than appearing to succeed.
  raise NotImplementedError
@ -91,14 +79,35 @@ class FisherFactorTestingDummy(ff.FisherFactor):
def _get_data_device(self):
raise NotImplementedError
def register_matpower(self, exp, damping_func):
raise NotImplementedError
class InverseProvidingFactorTestingDummy(ff.InverseProvidingFactor):
"""Dummy class to test the non-abstract methods on ff.InverseProvidingFactor.
def register_cholesky(self, damping_func):
raise NotImplementedError
def register_cholesky_inverse(self, damping_func):
raise NotImplementedError
def get_matpower(self, exp, damping_func):
raise NotImplementedError
def get_cholesky(self, damping_func):
raise NotImplementedError
def get_cholesky_inverse(self, damping_func):
raise NotImplementedError
def get_cov_as_linear_operator(self):
raise NotImplementedError
class DenseSquareMatrixFactorTestingDummy(ff.DenseSquareMatrixFactor):
"""Dummy class to test the non-abstract methods on ff.DenseSquareMatrixFactor.
"""
def __init__(self, shape):
self._shape = shape
super(InverseProvidingFactorTestingDummy, self).__init__()
super(DenseSquareMatrixFactorTestingDummy, self).__init__()
@property
def _var_scope(self):
@ -230,13 +239,13 @@ class FisherFactorTest(test.TestCase):
self.assertEqual(0, len(factor.make_inverse_update_ops()))
class InverseProvidingFactorTest(test.TestCase):
class DenseSquareMatrixFactorTest(test.TestCase):
def testRegisterDampedInverse(self):
with tf_ops.Graph().as_default():
random_seed.set_random_seed(200)
shape = [2, 2]
factor = InverseProvidingFactorTestingDummy(shape)
factor = DenseSquareMatrixFactorTestingDummy(shape)
factor_var_scope = 'dummy/a_b_c'
damping_funcs = [make_damping_func(0.1),
@ -248,22 +257,25 @@ class InverseProvidingFactorTest(test.TestCase):
factor.instantiate_inv_variables()
inv = factor.get_inverse(damping_funcs[0])
self.assertEqual(inv, factor.get_inverse(damping_funcs[1]))
self.assertNotEqual(inv, factor.get_inverse(damping_funcs[2]))
self.assertEqual(factor.get_inverse(damping_funcs[2]),
factor.get_inverse(damping_funcs[3]))
inv = factor.get_inverse(damping_funcs[0]).to_dense()
self.assertEqual(inv, factor.get_inverse(damping_funcs[1]).to_dense())
self.assertNotEqual(inv, factor.get_inverse(damping_funcs[2]).to_dense())
self.assertEqual(factor.get_inverse(damping_funcs[2]).to_dense(),
factor.get_inverse(damping_funcs[3]).to_dense())
factor_vars = tf_ops.get_collection(tf_ops.GraphKeys.GLOBAL_VARIABLES,
factor_var_scope)
self.assertEqual(set([inv, factor.get_inverse(damping_funcs[2])]),
set(factor_vars))
factor_tensors = (tf_ops.convert_to_tensor(var) for var in factor_vars)
self.assertEqual(set([inv,
factor.get_inverse(damping_funcs[2]).to_dense()]),
set(factor_tensors))
self.assertEqual(shape, inv.get_shape())
def testRegisterMatpower(self):
with tf_ops.Graph().as_default():
random_seed.set_random_seed(200)
shape = [3, 3]
factor = InverseProvidingFactorTestingDummy(shape)
factor = DenseSquareMatrixFactorTestingDummy(shape)
factor_var_scope = 'dummy/a_b_c'
# TODO(b/74201126): Change to using the same func for both once
@ -278,10 +290,13 @@ class InverseProvidingFactorTest(test.TestCase):
factor_vars = tf_ops.get_collection(tf_ops.GraphKeys.GLOBAL_VARIABLES,
factor_var_scope)
matpower1 = factor.get_matpower(-0.5, damping_func_1)
matpower2 = factor.get_matpower(2, damping_func_2)
self.assertEqual(set([matpower1, matpower2]), set(factor_vars))
factor_tensors = (tf_ops.convert_to_tensor(var) for var in factor_vars)
matpower1 = factor.get_matpower(-0.5, damping_func_1).to_dense()
matpower2 = factor.get_matpower(2, damping_func_2).to_dense()
self.assertEqual(set([matpower1, matpower2]), set(factor_tensors))
self.assertEqual(shape, matpower1.get_shape())
self.assertEqual(shape, matpower2.get_shape())
@ -297,7 +312,7 @@ class InverseProvidingFactorTest(test.TestCase):
with tf_ops.Graph().as_default(), self.test_session() as sess:
random_seed.set_random_seed(200)
cov = np.array([[1., 2.], [3., 4.]])
factor = InverseProvidingFactorTestingDummy(cov.shape)
factor = DenseSquareMatrixFactorTestingDummy(cov.shape)
factor._cov = array_ops.constant(cov, dtype=dtypes.float32)
damping_funcs = []
@ -316,7 +331,8 @@ class InverseProvidingFactorTest(test.TestCase):
sess.run(ops)
for i in range(ff.EIGENVALUE_DECOMPOSITION_THRESHOLD):
# The inverse op will assign the damped inverse of cov to the inv var.
new_invs.append(sess.run(factor.get_inverse(damping_funcs[i])))
new_invs.append(
sess.run(factor.get_inverse(damping_funcs[i]).to_dense()))
# We want to see that the new invs are all different from each other.
for i in range(len(new_invs)):
@ -328,7 +344,7 @@ class InverseProvidingFactorTest(test.TestCase):
with tf_ops.Graph().as_default(), self.test_session() as sess:
random_seed.set_random_seed(200)
cov = np.array([[6., 2.], [2., 4.]])
factor = InverseProvidingFactorTestingDummy(cov.shape)
factor = DenseSquareMatrixFactorTestingDummy(cov.shape)
factor._cov = array_ops.constant(cov, dtype=dtypes.float32)
exp = 2 # NOTE(mattjj): must be int to test with np.linalg.matrix_power
damping = 0.5
@ -341,7 +357,7 @@ class InverseProvidingFactorTest(test.TestCase):
sess.run(tf_variables.global_variables_initializer())
sess.run(ops[0])
matpower = sess.run(factor.get_matpower(exp, damping_func))
matpower = sess.run(factor.get_matpower(exp, damping_func).to_dense())
matpower_np = np.linalg.matrix_power(cov + np.eye(2) * damping, exp)
self.assertAllClose(matpower, matpower_np)
@ -349,7 +365,7 @@ class InverseProvidingFactorTest(test.TestCase):
with tf_ops.Graph().as_default(), self.test_session() as sess:
random_seed.set_random_seed(200)
cov = np.array([[5., 2.], [2., 4.]]) # NOTE(mattjj): must be symmetric
factor = InverseProvidingFactorTestingDummy(cov.shape)
factor = DenseSquareMatrixFactorTestingDummy(cov.shape)
factor._cov = array_ops.constant(cov, dtype=dtypes.float32)
damping_func = make_damping_func(0)
@ -361,12 +377,12 @@ class InverseProvidingFactorTest(test.TestCase):
sess.run(tf_variables.global_variables_initializer())
# The inverse op will assign the damped inverse of cov to the inv var.
old_inv = sess.run(factor.get_inverse(damping_func))
old_inv = sess.run(factor.get_inverse(damping_func).to_dense())
self.assertAllClose(
sess.run(ff.inverse_initializer(cov.shape, dtypes.float32)), old_inv)
sess.run(ops)
new_inv = sess.run(factor.get_inverse(damping_func))
new_inv = sess.run(factor.get_inverse(damping_func).to_dense())
self.assertAllClose(new_inv, np.linalg.inv(cov))
@ -411,7 +427,7 @@ class NaiveDiagonalFactorTest(test.TestCase):
tensor = array_ops.ones((2, 3), name='a/b/c')
factor = ff.NaiveDiagonalFactor((tensor,), 32)
factor.instantiate_cov_variables()
self.assertEqual([6, 1], factor.get_cov_var().get_shape().as_list())
self.assertEqual([6, 1], factor.get_cov().get_shape().as_list())
def testNaiveDiagonalFactorInitFloat64(self):
with tf_ops.Graph().as_default():
@ -420,7 +436,7 @@ class NaiveDiagonalFactorTest(test.TestCase):
tensor = array_ops.ones((2, 3), dtype=dtype, name='a/b/c')
factor = ff.NaiveDiagonalFactor((tensor,), 32)
factor.instantiate_cov_variables()
cov = factor.get_cov_var()
cov = factor.get_cov()
self.assertEqual(cov.dtype, dtype)
self.assertEqual([6, 1], cov.get_shape().as_list())
@ -444,7 +460,7 @@ class EmbeddingInputKroneckerFactorTest(test.TestCase):
vocab_size = 5
factor = ff.EmbeddingInputKroneckerFactor((input_ids,), vocab_size)
factor.instantiate_cov_variables()
cov = factor.get_cov_var()
cov = factor.get_cov()
self.assertEqual(cov.shape.as_list(), [vocab_size])
def testCovarianceUpdateOp(self):
@ -502,7 +518,7 @@ class ConvDiagonalFactorTest(test.TestCase):
self.kernel_height * self.kernel_width * self.in_channels,
self.out_channels
],
factor.get_cov_var().shape.as_list())
factor.get_cov().shape.as_list())
def testMakeCovarianceUpdateOp(self):
with tf_ops.Graph().as_default():
@ -564,7 +580,7 @@ class ConvDiagonalFactorTest(test.TestCase):
self.kernel_height * self.kernel_width * self.in_channels + 1,
self.out_channels
],
factor.get_cov_var().shape.as_list())
factor.get_cov().shape.as_list())
# Ensure update op doesn't crash.
cov_update_op = factor.make_covariance_update_op(0.0)
@ -654,13 +670,13 @@ class ConvInputKroneckerFactorTest(ConvFactorTestCase):
# Ensure shape of covariance matches input size of filter.
input_size = in_channels * (width**3)
self.assertEqual([input_size, input_size],
factor.get_cov_var().shape.as_list())
factor.get_cov().shape.as_list())
# Ensure cov_update_op doesn't crash.
with self.test_session() as sess:
sess.run(tf_variables.global_variables_initializer())
sess.run(factor.make_covariance_update_op(0.0))
cov = sess.run(factor.get_cov_var())
cov = sess.run(factor.get_cov())
# Cov should be rank-8, as the filter will be applied at each corner of
# the 4-D cube.
@ -685,13 +701,13 @@ class ConvInputKroneckerFactorTest(ConvFactorTestCase):
# Ensure shape of covariance matches input size of filter.
self.assertEqual([in_channels, in_channels],
factor.get_cov_var().shape.as_list())
factor.get_cov().shape.as_list())
# Ensure cov_update_op doesn't crash.
with self.test_session() as sess:
sess.run(tf_variables.global_variables_initializer())
sess.run(factor.make_covariance_update_op(0.0))
cov = sess.run(factor.get_cov_var())
cov = sess.run(factor.get_cov())
# Cov should be rank-9, as the filter will be applied at each location.
self.assertMatrixRank(9, cov)
@ -716,7 +732,7 @@ class ConvInputKroneckerFactorTest(ConvFactorTestCase):
with self.test_session() as sess:
sess.run(tf_variables.global_variables_initializer())
sess.run(factor.make_covariance_update_op(0.0))
cov = sess.run(factor.get_cov_var())
cov = sess.run(factor.get_cov())
# Cov should be the sum of 3 * 2 = 6 outer products.
self.assertMatrixRank(6, cov)
@ -742,7 +758,7 @@ class ConvInputKroneckerFactorTest(ConvFactorTestCase):
with self.test_session() as sess:
sess.run(tf_variables.global_variables_initializer())
sess.run(factor.make_covariance_update_op(0.0))
cov = sess.run(factor.get_cov_var())
cov = sess.run(factor.get_cov())
# Cov should be rank = in_channels, as only the center of the filter
# receives non-zero input for each input channel.

View File

@ -35,6 +35,7 @@ py_library(
srcs = ["fisher_factors.py"],
srcs_version = "PY2AND3",
deps = [
":linear_operator",
":utils",
"//tensorflow/python:array_ops",
"//tensorflow/python:control_flow_ops",
@ -63,6 +64,19 @@ py_library(
],
)
py_library(
name = "linear_operator",
srcs = ["linear_operator.py"],
srcs_version = "PY2AND3",
deps = [
":utils",
"//tensorflow/python:framework_ops",
"//tensorflow/python:math_ops",
"//tensorflow/python/ops/linalg",
"@six_archive//:six",
],
)
py_library(
name = "loss_functions",
srcs = ["loss_functions.py"],

View File

@ -57,8 +57,8 @@ def make_fisher_estimator(placement_strategy=None, **kwargs):
if placement_strategy in [None, "round_robin"]:
return FisherEstimatorRoundRobin(**kwargs)
else:
raise ValueError("Unimplemented vars and ops placement strategy : %s",
placement_strategy)
raise ValueError("Unimplemented vars and ops "
"placement strategy : {}".format(placement_strategy))
# pylint: enable=abstract-class-instantiated
@ -81,7 +81,9 @@ class FisherEstimator(object):
exps=(-1,),
estimation_mode="gradients",
colocate_gradients_with_ops=True,
name="FisherEstimator"):
name="FisherEstimator",
compute_cholesky=False,
compute_cholesky_inverse=False):
"""Create a FisherEstimator object.
Args:
@ -124,6 +126,12 @@ class FisherEstimator(object):
name: A string. A name given to this estimator, which is added to the
variable scope when constructing variables and ops.
(Default: "FisherEstimator")
compute_cholesky: Bool. Whether or not the FisherEstimator will be
able to multiply vectors by the Cholesky factor.
(Default: False)
compute_cholesky_inverse: Bool. Whether or not the FisherEstimator
will be able to multiply vectors by the Cholesky factor inverse.
(Default: False)
Raises:
ValueError: If no losses have been registered with layer_collection.
"""
@ -142,6 +150,8 @@ class FisherEstimator(object):
self._made_vars = False
self._exps = exps
self._compute_cholesky = compute_cholesky
self._compute_cholesky_inverse = compute_cholesky_inverse
self._name = name
@ -300,9 +310,54 @@ class FisherEstimator(object):
A list of (transformed vector, var) pairs in the same order as
vecs_and_vars.
"""
assert exp in self._exps
fcn = lambda fb, vec: fb.multiply_matpower(vec, exp)
return self._apply_transformation(vecs_and_vars, fcn)
def multiply_cholesky(self, vecs_and_vars, transpose=False):
"""Multiplies the vecs by the corresponding Cholesky factors.
Args:
vecs_and_vars: List of (vector, variable) pairs.
transpose: Bool. If true the Cholesky factors are transposed before
multiplying the vecs. (Default: False)
Returns:
A list of (transformed vector, var) pairs in the same order as
vecs_and_vars.
"""
assert self._compute_cholesky
fcn = lambda fb, vec: fb.multiply_cholesky(vec, transpose=transpose)
return self._apply_transformation(vecs_and_vars, fcn)
def multiply_cholesky_inverse(self, vecs_and_vars, transpose=False):
"""Mults the vecs by the inverses of the corresponding Cholesky factors.
Note: if you are using Cholesky inverse multiplication to sample from
a matrix-variate Gaussian you will want to multiply by the transpose.
Let L be the Cholesky factor of F and observe that
L^-T * L^-1 = (L * L^T)^-1 = F^-1 .
Thus we want to multiply by L^-T in order to sample from Gaussian with
covariance F^-1.
Args:
vecs_and_vars: List of (vector, variable) pairs.
transpose: Bool. If true the Cholesky factor inverses are transposed
before multiplying the vecs. (Default: False)
Returns:
A list of (transformed vector, var) pairs in the same order as
vecs_and_vars.
"""
assert self._compute_cholesky_inverse
fcn = lambda fb, vec: fb.multiply_cholesky_inverse(vec, transpose=transpose)
return self._apply_transformation(vecs_and_vars, fcn)
def _instantiate_factors(self):
"""Instantiates FisherFactors' variables.
@ -333,9 +388,13 @@ class FisherEstimator(object):
return self._made_vars
def _register_matrix_functions(self):
for exp in self._exps:
for block in self.blocks:
for block in self.blocks:
for exp in self._exps:
block.register_matpower(exp)
if self._compute_cholesky:
block.register_cholesky()
if self._compute_cholesky_inverse:
block.register_cholesky_inverse()
def _finalize_layer_collection(self):
self._layers.create_subgraph()

View File

@ -25,6 +25,7 @@ from tensorflow.python.util.all_util import remove_undocumented
_allowed_symbols = [
'FisherEstimator',
'make_fisher_estimator',
]
remove_undocumented(__name__, allowed_exception_list=_allowed_symbols)

View File

@ -83,34 +83,22 @@ def normalize_damping(damping, num_replications):
def compute_pi_tracenorm(left_cov, right_cov):
"""Computes the scalar constant pi for Tikhonov regularization/damping.
r"""Computes the scalar constant pi for Tikhonov regularization/damping.
$$\pi = \sqrt{ (trace(A) / dim(A)) / (trace(B) / dim(B)) }$$
See section 6.3 of https://arxiv.org/pdf/1503.05671.pdf for details.
Args:
left_cov: The left Kronecker factor "covariance".
right_cov: The right Kronecker factor "covariance".
left_cov: A LinearOperator object. The left Kronecker factor "covariance".
right_cov: A LinearOperator object. The right Kronecker factor "covariance".
Returns:
The computed scalar constant pi for these Kronecker Factors (as a Tensor).
"""
def _trace(cov):
if len(cov.shape) == 1:
# Diagonal matrix.
return math_ops.reduce_sum(cov)
elif len(cov.shape) == 2:
# Full matrix.
return math_ops.trace(cov)
else:
raise ValueError(
"What's the trace of a Tensor of rank %d?" % len(cov.shape))
# Instead of dividing by the dim of the norm, we multiply by the dim of the
# other norm. This works out the same in the ratio.
left_norm = _trace(left_cov) * right_cov.shape.as_list()[0]
right_norm = _trace(right_cov) * left_cov.shape.as_list()[0]
left_norm = left_cov.trace() * int(right_cov.domain_dimension)
right_norm = right_cov.trace() * int(left_cov.domain_dimension)
return math_ops.sqrt(left_norm / right_norm)
@ -188,6 +176,16 @@ class FisherBlock(object):
"""
pass
@abc.abstractmethod
def register_cholesky(self):
"""Registers a Cholesky factor to be computed by the block."""
pass
@abc.abstractmethod
def register_cholesky_inverse(self):
"""Registers an inverse Cholesky factor to be computed by the block."""
pass
def register_inverse(self):
"""Registers a matrix inverse to be computed by the block."""
self.register_matpower(-1)
@ -228,6 +226,33 @@ class FisherBlock(object):
"""
return self.multiply_matpower(vector, 1)
@abc.abstractmethod
def multiply_cholesky(self, vector, transpose=False):
"""Multiplies the vector by the (damped) Cholesky-factor of the block.
Args:
vector: The vector (a Tensor or tuple of Tensors) to be multiplied.
transpose: Bool. If true the Cholesky factor is transposed before
multiplying the vector. (Default: False)
Returns:
The vector left-multiplied by the (damped) Cholesky-factor of the block.
"""
pass
@abc.abstractmethod
def multiply_cholesky_inverse(self, vector, transpose=False):
"""Multiplies vector by the (damped) inverse Cholesky-factor of the block.
Args:
vector: The vector (a Tensor or tuple of Tensors) to be multiplied.
transpose: Bool. If true the Cholesky factor inverse is transposed
before multiplying the vector. (Default: False)
Returns:
Vector left-multiplied by (damped) inverse Cholesky-factor of the block.
"""
pass
@abc.abstractmethod
def tensors_to_compute_grads(self):
"""Returns the Tensor(s) with respect to which this FisherBlock needs grads.
@ -275,15 +300,32 @@ class FullFB(FisherBlock):
def register_matpower(self, exp):
self._factor.register_matpower(exp, self._damping_func)
def multiply_matpower(self, vector, exp):
def register_cholesky(self):
self._factor.register_cholesky(self._damping_func)
def register_cholesky_inverse(self):
self._factor.register_cholesky_inverse(self._damping_func)
def _multiply_matrix(self, matrix, vector, transpose=False):
vector_flat = utils.tensors_to_column(vector)
out_flat = self._factor.left_multiply_matpower(
vector_flat, exp, self._damping_func)
out_flat = matrix.matmul(vector_flat, adjoint=transpose)
return utils.column_to_tensors(vector, out_flat)
def multiply_matpower(self, vector, exp):
matrix = self._factor.get_matpower(exp, self._damping_func)
return self._multiply_matrix(matrix, vector)
def multiply_cholesky(self, vector, transpose=False):
matrix = self._factor.get_cholesky(self._damping_func)
return self._multiply_matrix(matrix, vector, transpose=transpose)
def multiply_cholesky_inverse(self, vector, transpose=False):
matrix = self._factor.get_cholesky_inverse(self._damping_func)
return self._multiply_matrix(matrix, vector, transpose=transpose)
def full_fisher_block(self):
"""Explicitly constructs the full Fisher block."""
return self._factor.get_cov()
return self._factor.get_cov_as_linear_operator().to_dense()
def tensors_to_compute_grads(self):
return self._params
@ -305,7 +347,47 @@ class FullFB(FisherBlock):
return math_ops.reduce_sum(self._batch_sizes)
class NaiveDiagonalFB(FisherBlock):
@six.add_metaclass(abc.ABCMeta)
class DiagonalFB(FisherBlock):
"""A base class for FisherBlocks that use diagonal approximations."""
def register_matpower(self, exp):
# Not needed for this. Matrix powers are computed on demand in the
# diagonal case
pass
def register_cholesky(self):
# Not needed for this. Cholesky factors are computed on demand in the
# diagonal case.
pass
def register_cholesky_inverse(self):
# Not needed for this. Cholesky factor inverses are computed on demand in
# the diagonal case.
pass
def _multiply_matrix(self, matrix, vector):
vector_flat = utils.tensors_to_column(vector)
out_flat = matrix.matmul(vector_flat)
return utils.column_to_tensors(vector, out_flat)
def multiply_matpower(self, vector, exp):
matrix = self._factor.get_matpower(exp, self._damping_func)
return self._multiply_matrix(matrix, vector)
def multiply_cholesky(self, vector, transpose=False):
matrix = self._factor.get_cholesky(self._damping_func)
return self._multiply_matrix(matrix, vector)
def multiply_cholesky_inverse(self, vector, transpose=False):
matrix = self._factor.get_cholesky_inverse(self._damping_func)
return self._multiply_matrix(matrix, vector)
def full_fisher_block(self):
return self._factor.get_cov_as_linear_operator().to_dense()
class NaiveDiagonalFB(DiagonalFB):
"""FisherBlock using a diagonal matrix approximation.
This type of approximation is generically applicable but quite primitive.
@ -333,20 +415,6 @@ class NaiveDiagonalFB(FisherBlock):
self._factor = self._layer_collection.make_or_get_factor(
fisher_factors.NaiveDiagonalFactor, (grads_list, self._batch_size))
def register_matpower(self, exp):
# Not needed for this. Matrix powers are computed on demand in the
# diagonal case
pass
def multiply_matpower(self, vector, exp):
vector_flat = utils.tensors_to_column(vector)
out_flat = self._factor.left_multiply_matpower(
vector_flat, exp, self._damping_func)
return utils.column_to_tensors(vector, out_flat)
def full_fisher_block(self):
return self._factor.get_cov()
def tensors_to_compute_grads(self):
return self._params
@ -452,7 +520,7 @@ class InputOutputMultiTower(object):
return self.__outputs
class FullyConnectedDiagonalFB(InputOutputMultiTower, FisherBlock):
class FullyConnectedDiagonalFB(InputOutputMultiTower, DiagonalFB):
"""FisherBlock for fully-connected (dense) layers using a diagonal approx.
Estimates the Fisher Information matrix's diagonal entries for a fully
@ -497,32 +565,8 @@ class FullyConnectedDiagonalFB(InputOutputMultiTower, FisherBlock):
self._damping_func = _package_func(lambda: damping, (damping,))
def register_matpower(self, exp):
# Not needed for this. Matrix powers are computed on demand in the
# diagonal case
pass
def multiply_matpower(self, vector, exp):
"""Multiplies the vector by the (damped) matrix-power of the block.
Args:
vector: Tensor or 2-tuple of Tensors. if self._has_bias, Tensor of shape
[input_size, output_size] corresponding to layer's weights. If not, a
2-tuple of the former and a Tensor of shape [output_size] corresponding
to the layer's bias.
exp: A scalar representing the power to raise the block before multiplying
it by the vector.
Returns:
The vector left-multiplied by the (damped) matrix-power of the block.
"""
reshaped_vec = utils.layer_params_to_mat2d(vector)
reshaped_out = self._factor.left_multiply_matpower(
reshaped_vec, exp, self._damping_func)
return utils.mat2d_to_layer_params(vector, reshaped_out)
class ConvDiagonalFB(InputOutputMultiTower, FisherBlock):
class ConvDiagonalFB(InputOutputMultiTower, DiagonalFB):
"""FisherBlock for 2-D convolutional layers using a diagonal approx.
Estimates the Fisher Information matrix's diagonal entries for a convolutional
@ -621,17 +665,6 @@ class ConvDiagonalFB(InputOutputMultiTower, FisherBlock):
self._num_locations)
self._damping_func = _package_func(damping_func, damping_id)
def register_matpower(self, exp):
# Not needed for this. Matrix powers are computed on demand in the
# diagonal case
pass
def multiply_matpower(self, vector, exp):
reshaped_vect = utils.layer_params_to_mat2d(vector)
reshaped_out = self._factor.left_multiply_matpower(
reshaped_vect, exp, self._damping_func)
return utils.mat2d_to_layer_params(vector, reshaped_out)
class KroneckerProductFB(FisherBlock):
"""A base class for blocks with separate input and output Kronecker factors.
@ -651,9 +684,10 @@ class KroneckerProductFB(FisherBlock):
else:
maybe_normalized_damping = damping
return compute_pi_adjusted_damping(self._input_factor.get_cov(),
self._output_factor.get_cov(),
maybe_normalized_damping**0.5)
return compute_pi_adjusted_damping(
self._input_factor.get_cov_as_linear_operator(),
self._output_factor.get_cov_as_linear_operator(),
maybe_normalized_damping**0.5)
if normalization is not None:
damping_id = ("compute_pi_adjusted_damping",
@ -675,6 +709,14 @@ class KroneckerProductFB(FisherBlock):
self._input_factor.register_matpower(exp, self._input_damping_func)
self._output_factor.register_matpower(exp, self._output_damping_func)
def register_cholesky(self):
self._input_factor.register_cholesky(self._input_damping_func)
self._output_factor.register_cholesky(self._output_damping_func)
def register_cholesky_inverse(self):
self._input_factor.register_cholesky_inverse(self._input_damping_func)
self._output_factor.register_cholesky_inverse(self._output_damping_func)
@property
def _renorm_coeff(self):
"""Kronecker factor multiplier coefficient.
@ -687,17 +729,47 @@ class KroneckerProductFB(FisherBlock):
"""
return 1.0
def multiply_matpower(self, vector, exp):
def _multiply_factored_matrix(self, left_factor, right_factor, vector,
extra_scale=1.0, transpose_left=False,
transpose_right=False):
reshaped_vector = utils.layer_params_to_mat2d(vector)
reshaped_out = self._output_factor.right_multiply_matpower(
reshaped_vector, exp, self._output_damping_func)
reshaped_out = self._input_factor.left_multiply_matpower(
reshaped_out, exp, self._input_damping_func)
if self._renorm_coeff != 1.0:
renorm_coeff = math_ops.cast(self._renorm_coeff, dtype=reshaped_out.dtype)
reshaped_out *= math_ops.cast(renorm_coeff**exp, dtype=reshaped_out.dtype)
reshaped_out = right_factor.matmul_right(reshaped_vector,
adjoint=transpose_right)
reshaped_out = left_factor.matmul(reshaped_out,
adjoint=transpose_left)
if extra_scale != 1.0:
reshaped_out *= math_ops.cast(extra_scale, dtype=reshaped_out.dtype)
return utils.mat2d_to_layer_params(vector, reshaped_out)
def multiply_matpower(self, vector, exp):
left_factor = self._input_factor.get_matpower(
exp, self._input_damping_func)
right_factor = self._output_factor.get_matpower(
exp, self._output_damping_func)
extra_scale = float(self._renorm_coeff)**exp
return self._multiply_factored_matrix(left_factor, right_factor, vector,
extra_scale=extra_scale)
def multiply_cholesky(self, vector, transpose=False):
left_factor = self._input_factor.get_cholesky(self._input_damping_func)
right_factor = self._output_factor.get_cholesky(self._output_damping_func)
extra_scale = float(self._renorm_coeff)**0.5
return self._multiply_factored_matrix(left_factor, right_factor, vector,
extra_scale=extra_scale,
transpose_left=transpose,
transpose_right=not transpose)
def multiply_cholesky_inverse(self, vector, transpose=False):
left_factor = self._input_factor.get_cholesky_inverse(
self._input_damping_func)
right_factor = self._output_factor.get_cholesky_inverse(
self._output_damping_func)
extra_scale = float(self._renorm_coeff)**-0.5
return self._multiply_factored_matrix(left_factor, right_factor, vector,
extra_scale=extra_scale,
transpose_left=transpose,
transpose_right=not transpose)
def full_fisher_block(self):
"""Explicitly constructs the full Fisher block.
@ -706,8 +778,8 @@ class KroneckerProductFB(FisherBlock):
Returns:
The full Fisher block.
"""
left_factor = self._input_factor.get_cov()
right_factor = self._output_factor.get_cov()
left_factor = self._input_factor.get_cov_as_linear_operator().to_dense()
right_factor = self._output_factor.get_cov_as_linear_operator().to_dense()
return self._renorm_coeff * utils.kronecker_product(left_factor,
right_factor)
@ -796,7 +868,7 @@ class FullyConnectedKFACBasicFB(InputOutputMultiTower, KroneckerProductFB):
class ConvKFCBasicFB(InputOutputMultiTower, KroneckerProductFB):
"""FisherBlock for convolutional layers using the basic KFC approx.
r"""FisherBlock for convolutional layers using the basic KFC approx.
Estimates the Fisher Information matrix's block for a convolutional
layer.
@ -945,10 +1017,10 @@ class DepthwiseConvDiagonalFB(ConvDiagonalFB):
self._filter_shape = (filter_height, filter_width, in_channels,
in_channels * channel_multiplier)
def multiply_matpower(self, vector, exp):
def _multiply_matrix(self, matrix, vector):
conv2d_vector = depthwise_conv2d_filter_to_conv2d_filter(vector)
conv2d_result = super(DepthwiseConvDiagonalFB, self).multiply_matpower(
conv2d_vector, exp)
conv2d_result = super(
DepthwiseConvDiagonalFB, self)._multiply_matrix(matrix, conv2d_vector)
return conv2d_filter_to_depthwise_conv2d_filter(conv2d_result)
@ -1016,10 +1088,14 @@ class DepthwiseConvKFCBasicFB(ConvKFCBasicFB):
self._filter_shape = (filter_height, filter_width, in_channels,
in_channels * channel_multiplier)
def multiply_matpower(self, vector, exp):
def _multiply_factored_matrix(self, left_factor, right_factor, vector,
extra_scale=1.0, transpose_left=False,
transpose_right=False):
conv2d_vector = depthwise_conv2d_filter_to_conv2d_filter(vector)
conv2d_result = super(DepthwiseConvKFCBasicFB, self).multiply_matpower(
conv2d_vector, exp)
conv2d_result = super(
DepthwiseConvKFCBasicFB, self)._multiply_factored_matrix(
left_factor, right_factor, conv2d_vector, extra_scale=extra_scale,
transpose_left=transpose_left, transpose_right=transpose_right)
return conv2d_filter_to_depthwise_conv2d_filter(conv2d_result)
@ -1664,3 +1740,12 @@ class FullyConnectedSeriesFB(InputOutputMultiTowerMultiUse,
return utils.mat2d_to_layer_params(vector, Z)
# pylint: enable=invalid-name
def multiply_cholesky(self, vector, transpose=False):
  """Unsupported: this block cannot multiply by a Cholesky factor.

  The signature mirrors the abstract FisherBlock.multiply_cholesky so that
  callers passing `transpose` by keyword receive the intended
  NotImplementedError rather than a TypeError about an unexpected argument.

  Args:
    vector: Unused.
    transpose: Unused. Accepted only for interface compatibility.

  Raises:
    NotImplementedError: Always.
  """
  raise NotImplementedError("FullyConnectedSeriesFB does not support "
                            "Cholesky computations.")
def multiply_cholesky_inverse(self, vector, transpose=False):
  """Unsupported: this block cannot multiply by an inverse Cholesky factor.

  The signature mirrors the abstract FisherBlock.multiply_cholesky_inverse so
  that callers passing `transpose` by keyword receive the intended
  NotImplementedError rather than a TypeError about an unexpected argument.

  Args:
    vector: Unused.
    transpose: Unused. Accepted only for interface compatibility.

  Raises:
    NotImplementedError: Always.
  """
  raise NotImplementedError("FullyConnectedSeriesFB does not support "
                            "Cholesky computations.")

View File

@ -24,6 +24,7 @@ import contextlib
import numpy as np
import six
from tensorflow.contrib.kfac.python.ops import linear_operator as lo
from tensorflow.contrib.kfac.python.ops import utils
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops as tf_ops
@ -399,7 +400,7 @@ class FisherFactor(object):
the cov update.
Returns:
Tensor of same shape as self.get_cov_var().
Tensor of same shape as self.get_cov().
"""
pass
@ -448,78 +449,43 @@ class FisherFactor(object):
"""Create and return update ops corresponding to registered computations."""
pass
@abc.abstractmethod
def get_cov(self):
"""Get full covariance matrix.
Returns:
Tensor of shape [n, n]. Represents all parameter-parameter correlations
captured by this FisherFactor.
"""
pass
def get_cov_var(self):
"""Get variable backing this FisherFactor.
May or may not be the same as self.get_cov()
Returns:
Variable of shape self._cov_shape.
"""
return self._cov
@abc.abstractmethod
def left_multiply_matpower(self, x, exp, damping_func):
"""Left multiplies 'x' by matrix power of this factor (w/ damping applied).
This calculation is essentially:
(C + damping * I)**exp * x
where * is matrix-multiplication, ** is matrix power, I is the identity
matrix, and C is the matrix represented by this factor.
x can represent either a matrix or a vector. For some factors, 'x' might
represent a vector but actually be stored as a 2D matrix for convenience.
Args:
x: Tensor. Represents a single vector. Shape depends on implementation.
exp: float. The matrix exponent to use.
damping_func: A function that computes a 0-D Tensor or a float which will
be the damping value used. i.e. damping = damping_func().
Returns:
Tensor of same shape as 'x' representing the result of the multiplication.
"""
def get_cov_as_linear_operator(self):
pass
@abc.abstractmethod
def right_multiply_matpower(self, x, exp, damping_func):
"""Right multiplies 'x' by matrix power of this factor (w/ damping applied).
def register_matpower(self, exp, damping_func):
pass
This calculation is essentially:
x * (C + damping * I)**exp
where * is matrix-multiplication, ** is matrix power, I is the identity
matrix, and C is the matrix represented by this factor.
@abc.abstractmethod
def register_cholesky(self, damping_func):
pass
Unlike left_multiply_matpower, x will always be a matrix.
@abc.abstractmethod
def register_cholesky_inverse(self, damping_func):
pass
Args:
x: Tensor. Represents a single vector. Shape depends on implementation.
exp: float. The matrix exponent to use.
damping_func: A function that computes a 0-D Tensor or a float which will
be the damping value used. i.e. damping = damping_func().
@abc.abstractmethod
def get_matpower(self, exp, damping_func):
pass
Returns:
Tensor of same shape as 'x' representing the result of the multiplication.
"""
@abc.abstractmethod
def get_cholesky(self, damping_func):
pass
@abc.abstractmethod
def get_cholesky_inverse(self, damping_func):
pass
class InverseProvidingFactor(FisherFactor):
"""Base class for FisherFactors that maintain inverses explicitly.
class DenseSquareMatrixFactor(FisherFactor):
"""Base class for FisherFactors that are stored as dense square matrices.
This class explicitly calculates and stores inverses of covariance matrices
provided by the underlying FisherFactor implementation. It is assumed that
vectors can be represented as 2-D matrices.
This class explicitly calculates and stores inverses of their `cov` matrices,
which must be square dense matrices.
Subclasses must implement the _compute_new_cov method, and the _var_scope and
_cov_shape properties.
@ -538,7 +504,19 @@ class InverseProvidingFactor(FisherFactor):
self._eigendecomp = None
self._damping_funcs_by_id = {} # {hashable: lambda}
super(InverseProvidingFactor, self).__init__()
self._cholesky_registrations = set() # { hashable }
self._cholesky_inverse_registrations = set() # { hashable }
self._cholesky_by_damping = {} # { hashable: variable }
self._cholesky_inverse_by_damping = {} # { hashable: variable }
super(DenseSquareMatrixFactor, self).__init__()
def get_cov_as_linear_operator(self):
  """Wraps the value of get_cov() in a LinearOperatorFullMatrix.

  Returns:
    A lo.LinearOperatorFullMatrix over self.get_cov(), flagged as
    self-adjoint and square.
  """
  rank = self.get_cov().shape.ndims
  assert rank == 2
  operator_flags = {"is_self_adjoint": True, "is_square": True}
  return lo.LinearOperatorFullMatrix(self.get_cov(), **operator_flags)
def _register_damping(self, damping_func):
damping_id = graph_func_to_id(damping_func)
@ -563,8 +541,6 @@ class InverseProvidingFactor(FisherFactor):
be the damping value used. i.e. damping = damping_func().
"""
if exp == 1.0:
# We don't register these. The user shouldn't even be calling this
# function with exp = 1.0.
return
damping_id = self._register_damping(damping_func)
@ -572,6 +548,38 @@ class InverseProvidingFactor(FisherFactor):
if (exp, damping_id) not in self._matpower_registrations:
self._matpower_registrations.add((exp, damping_id))
def register_cholesky(self, damping_func):
  """Records that a Cholesky factor is wanted for the given damping.

  The damping function is registered via _register_damping and the id it
  returns is added to the set of Cholesky registrations. Registering the
  same id more than once has no further effect.

  Args:
    damping_func: A function that computes a 0-D Tensor or a float which will
      be the damping value used. i.e. damping = damping_func().
  """
  # Set insertion is idempotent, so no membership test is needed.
  self._cholesky_registrations.add(self._register_damping(damping_func))
def register_cholesky_inverse(self, damping_func):
  """Records that an inverse Cholesky factor is wanted for the given damping.

  The damping function is registered via _register_damping and the id it
  returns is added to the set of inverse-Cholesky registrations. Registering
  the same id more than once has no further effect.

  Args:
    damping_func: A function that computes a 0-D Tensor or a float which will
      be the damping value used. i.e. damping = damping_func().
  """
  # Set insertion is idempotent, so no membership test is needed.
  self._cholesky_inverse_registrations.add(
      self._register_damping(damping_func))
def instantiate_inv_variables(self):
"""Makes the internal "inverse" variable(s)."""
@ -589,6 +597,32 @@ class InverseProvidingFactor(FisherFactor):
assert (exp, damping_id) not in self._matpower_by_exp_and_damping
self._matpower_by_exp_and_damping[(exp, damping_id)] = matpower
for damping_id in self._cholesky_registrations:
damping_func = self._damping_funcs_by_id[damping_id]
damping_string = graph_func_to_string(damping_func)
with variable_scope.variable_scope(self._var_scope):
chol = variable_scope.get_variable(
"cholesky_damp{}".format(damping_string),
initializer=inverse_initializer,
shape=self._cov_shape,
trainable=False,
dtype=self._dtype)
assert damping_id not in self._cholesky_by_damping
self._cholesky_by_damping[damping_id] = chol
for damping_id in self._cholesky_inverse_registrations:
damping_func = self._damping_funcs_by_id[damping_id]
damping_string = graph_func_to_string(damping_func)
with variable_scope.variable_scope(self._var_scope):
cholinv = variable_scope.get_variable(
"cholesky_inverse_damp{}".format(damping_string),
initializer=inverse_initializer,
shape=self._cov_shape,
trainable=False,
dtype=self._dtype)
assert damping_id not in self._cholesky_inverse_by_damping
self._cholesky_inverse_by_damping[damping_id] = cholinv
def make_inverse_update_ops(self):
"""Create and return update ops corresponding to registered computations."""
ops = []
@ -606,7 +640,8 @@ class InverseProvidingFactor(FisherFactor):
# We precompute these so we don't need to evaluate them multiple times (for
# each matrix power that uses them)
damping_value_by_id = {damping_id: self._damping_funcs_by_id[damping_id]()
damping_value_by_id = {damping_id: math_ops.cast(
self._damping_funcs_by_id[damping_id](), self._dtype)
for damping_id in self._damping_funcs_by_id}
if use_eig:
@ -627,29 +662,91 @@ class InverseProvidingFactor(FisherFactor):
self._matpower_by_exp_and_damping.items()):
assert exp == -1
damping = damping_value_by_id[damping_id]
ops.append(matpower.assign(utils.posdef_inv(self._cov, damping)))
ops.append(matpower.assign(utils.posdef_inv(self.get_cov(), damping)))
# TODO(b/77902055): If inverses are being computed with Cholesky's
# we can share the work. Instead this code currently just computes the
# Cholesky a second time. It does at least share work between requests for
# Cholesky's and Cholesky inverses with the same damping id.
for damping_id, cholesky_inv in self._cholesky_inverse_by_damping.items():
cholesky_ops = []
damping = damping_value_by_id[damping_id]
cholesky_value = utils.cholesky(self.get_cov(), damping)
if damping_id in self._cholesky_by_damping:
cholesky = self._cholesky_by_damping[damping_id]
cholesky_ops.append(cholesky.assign(cholesky_value))
identity = linalg_ops.eye(cholesky_value.shape.as_list()[0],
dtype=cholesky_value.dtype)
cholesky_inv_value = linalg_ops.matrix_triangular_solve(cholesky_value,
identity)
cholesky_ops.append(cholesky_inv.assign(cholesky_inv_value))
ops.append(control_flow_ops.group(*cholesky_ops))
for damping_id, cholesky in self._cholesky_by_damping.items():
if damping_id not in self._cholesky_inverse_by_damping:
damping = damping_value_by_id[damping_id]
cholesky_value = utils.cholesky(self.get_cov(), damping)
ops.append(cholesky.assign(cholesky_value))
self._eigendecomp = False
return ops
def get_inverse(self, damping_func):
# Just for backwards compatibility of some old code and tests
damping_id = graph_func_to_id(damping_func)
return self._matpower_by_exp_and_damping[(-1, damping_id)]
return self.get_matpower(-1, damping_func)
def get_matpower(self, exp, damping_func):
  """Returns the damped matrix power as a LinearOperatorFullMatrix.

  For exp != 1 the stored entry keyed by (exp, damping id) is wrapped. That
  entry is a variable updated by the inverse ops, so it may be stale or
  inconsistent with the latest value of get_cov(). For exp == 1 nothing is
  looked up: the result is built from get_cov() plus damping times the
  identity.

  Args:
    exp: The matrix exponent.
    damping_func: A function that computes a 0-D Tensor or a float which will
      be the damping value used. i.e. damping = damping_func().

  Returns:
    A lo.LinearOperatorFullMatrix flagged non-singular, self-adjoint,
    positive-definite and square.
  """
  if exp == 1:
    matpower = self.get_cov()
    eye = linalg_ops.eye(matpower.shape.as_list()[0], dtype=matpower.dtype)
    matpower += math_ops.cast(damping_func(), dtype=matpower.dtype)*eye
  else:
    lookup_key = (exp, graph_func_to_id(damping_func))
    matpower = self._matpower_by_exp_and_damping[lookup_key]

  assert matpower.shape.ndims == 2
  operator_flags = dict(is_non_singular=True,
                        is_self_adjoint=True,
                        is_positive_definite=True,
                        is_square=True)
  return lo.LinearOperatorFullMatrix(matpower, **operator_flags)
def get_cholesky(self, damping_func):
# Note that this function returns a variable which gets updated by the
# inverse ops. It may be stale / inconsistent with the latest value of
# get_cov().
damping_id = graph_func_to_id(damping_func)
return self._matpower_by_exp_and_damping[(exp, damping_id)]
cholesky = self._cholesky_by_damping[damping_id]
assert cholesky.shape.ndims == 2
return lo.LinearOperatorFullMatrix(cholesky,
is_non_singular=True,
is_square=True)
def get_cholesky_inverse(self, damping_func):
  """Returns the stored inverse Cholesky factor as a LinearOperatorFullMatrix.

  The wrapped value is a variable which gets updated by the inverse ops, so
  it may be stale / inconsistent with the latest value of get_cov().

  Args:
    damping_func: The damping function whose id keys the stored factor.

  Returns:
    A lo.LinearOperatorFullMatrix flagged non-singular and square.
  """
  stored = self._cholesky_inverse_by_damping[graph_func_to_id(damping_func)]
  assert stored.shape.ndims == 2
  operator_flags = dict(is_non_singular=True, is_square=True)
  return lo.LinearOperatorFullMatrix(stored, **operator_flags)
def get_eigendecomp(self):
"""Creates or retrieves eigendecomposition of self._cov."""
# Unlike get_matpower this doesn't retrieve a stored variable, but instead
# always computes a fresh version from the current value of get_cov().
if not self._eigendecomp:
eigenvalues, eigenvectors = linalg_ops.self_adjoint_eig(self._cov)
eigenvalues, eigenvectors = linalg_ops.self_adjoint_eig(self.get_cov())
# The matrix self._cov is positive semidefinite by construction, but the
# numerical eigenvalues could be negative due to numerical errors, so here
@ -660,45 +757,8 @@ class InverseProvidingFactor(FisherFactor):
return self._eigendecomp
def get_cov(self):
  """Returns the full covariance matrix, which is the backing variable."""
  cov_var = self.get_cov_var()
  return cov_var
def left_multiply_matpower(self, x, exp, damping_func):
  """Left-multiplies 2-D `x` by the damped matrix power of this factor.

  Args:
    x: Tensor with exactly two dimensions. IndexedSlices are rejected.
    exp: The matrix exponent. For exp == 1 the product is formed from
      get_cov() and the damping directly; otherwise get_matpower is used.
    damping_func: A function returning the damping value.

  Returns:
    The product, computed with math_ops.matmul.

  Raises:
    ValueError: If `x` is an IndexedSlices or is not 2-D.
  """
  if isinstance(x, tf_ops.IndexedSlices):
    raise ValueError("Left-multiply not yet supported for IndexedSlices.")

  if x.shape.ndims != 2:
    message = (
        "InverseProvidingFactors apply to matrix-shaped vectors. Found: %s."
        % (x,))
    raise ValueError(message)

  if exp != 1:
    return math_ops.matmul(self.get_matpower(exp, damping_func), x)

  # (C + damping * I) x, expanded so the identity is never materialized.
  cov_times_x = math_ops.matmul(self.get_cov(), x)
  return cov_times_x + damping_func() * x
def right_multiply_matpower(self, x, exp, damping_func):
  """Right-multiplies `x` by the damped matrix power of this factor.

  Args:
    x: Either an IndexedSlices, handled via utils.matmul_sparse_dense, or a
      Tensor with exactly two dimensions.
    exp: The matrix exponent. For exp == 1 the damped covariance is formed
      directly from get_cov(); otherwise get_matpower is used.
    damping_func: A function returning the damping value.

  Returns:
    The product of `x` with the damped matrix power.

  Raises:
    ValueError: If `x` is a dense Tensor that is not 2-D.
  """
  plain_cov = (exp == 1)

  if isinstance(x, tf_ops.IndexedSlices):
    if plain_cov:
      dim = self.get_cov().shape[0]
      right_operand = self.get_cov() + damping_func() * array_ops.eye(dim)
    else:
      right_operand = self.get_matpower(exp, damping_func)
    return utils.matmul_sparse_dense(x, right_operand)

  if x.shape.ndims != 2:
    message = (
        "InverseProvidingFactors apply to matrix-shaped vectors. Found: %s."
        % (x,))
    raise ValueError(message)

  if plain_cov:
    # x (C + damping * I), expanded so the identity is never materialized.
    return math_ops.matmul(x, self.get_cov()) + damping_func() * x
  return math_ops.matmul(x, self.get_matpower(exp, damping_func))
class FullFactor(InverseProvidingFactor):
class FullFactor(DenseSquareMatrixFactor):
"""FisherFactor for a full matrix representation of the Fisher of a parameter.
Note that this uses the naive "square the sum estimator", and so is applicable
@ -757,42 +817,52 @@ class DiagonalFactor(FisherFactor):
"""
def __init__(self):
  """Creates the empty damping-function registry, then runs the base init."""
  # Maps { hashable: lambda }.
  self._damping_funcs_by_id = dict()
  super(DiagonalFactor, self).__init__()
def get_cov_as_linear_operator(self):
  """Wraps the flattened diagonal in a LinearOperatorDiag.

  Returns:
    A lo.LinearOperatorDiag over self._matrix_diagonal, flagged as
    self-adjoint and square.
  """
  rank = self._matrix_diagonal.shape.ndims
  assert rank == 1
  operator_flags = {"is_self_adjoint": True, "is_square": True}
  return lo.LinearOperatorDiag(self._matrix_diagonal, **operator_flags)
@property
def _cov_initializer(self):
  """The initializer used for this factor's covariance."""
  initializer = diagonal_covariance_initializer
  return initializer
@property
def _matrix_diagonal(self):
  """The value of get_cov() reshaped into a 1-D vector."""
  flat_shape = [-1]
  return array_ops.reshape(self.get_cov(), flat_shape)
def make_inverse_update_ops(self):
  """Returns an empty list: this factor creates no inverse update ops."""
  no_ops = []
  return no_ops
def instantiate_inv_variables(self):
  """Does nothing: this factor creates no inverse variables."""
  return None
def get_cov(self):
  """Returns the covariance as a diagonal matrix.

  The backing variable is flattened into a vector, then expanded with
  array_ops.diag.
  """
  flat_shape = [-1]
  return array_ops.diag(array_ops.reshape(self.get_cov_var(), flat_shape))
def left_multiply_matpower(self, x, exp, damping_func):
  """Multiplies `x` elementwise by the damped diagonal raised to `exp`.

  Args:
    x: Either an IndexedSlices, handled via utils.matmul_diag_sparse on the
      flattened diagonal, or a Tensor with the same shape as the covariance
      variable.
    exp: The exponent applied elementwise to (cov + damping).
    damping_func: A function returning the damping value.

  Returns:
    The product of the damped diagonal power with `x`.

  Raises:
    ValueError: If `x` is dense and its shape differs from the diagonal's.
  """
  damped_power = (self.get_cov_var() + damping_func())**exp

  if isinstance(x, tf_ops.IndexedSlices):
    flat_power = array_ops.reshape(damped_power, [-1])
    return utils.matmul_diag_sparse(flat_power, x)

  if x.shape != damped_power.shape:
    message = ("x (%s) and cov (%s) must have same shape." %
               (x, damped_power))
    raise ValueError(message)

  return damped_power * x
def right_multiply_matpower(self, x, exp, damping_func):
  """Unsupported for this factor; always raises.

  Args:
    x: Unused.
    exp: Unused.
    damping_func: Unused.

  Raises:
    NotImplementedError: Always.
  """
  message = "Only left-multiply is currently supported."
  raise NotImplementedError(message)
def register_matpower(self, exp, damping_func):
  """Does nothing: this factor keeps no matrix-power registrations."""
  return None
def register_cholesky(self, damping_func):
  """Does nothing: this factor keeps no Cholesky registrations."""
  return None
def register_cholesky_inverse(self, damping_func):
  """Does nothing: this factor keeps no inverse-Cholesky registrations."""
  return None
def get_matpower(self, exp, damping_func):
  """Returns the damped diagonal raised elementwise to `exp`.

  Args:
    exp: The exponent applied to each entry of (diagonal + damping).
    damping_func: A function returning the damping value, which is cast to
      self._dtype before being added.

  Returns:
    A lo.LinearOperatorDiag flagged non-singular, self-adjoint,
    positive-definite and square.
  """
  diagonal = self._matrix_diagonal
  damping = math_ops.cast(damping_func(), self._dtype)
  damped_power = (diagonal + damping)**exp
  operator_flags = dict(is_non_singular=True,
                        is_self_adjoint=True,
                        is_positive_definite=True,
                        is_square=True)
  return lo.LinearOperatorDiag(damped_power, **operator_flags)
def get_cholesky(self, damping_func):
  """Returns get_matpower with exponent 0.5 for the given damping."""
  half_power = 0.5
  return self.get_matpower(half_power, damping_func)
def get_cholesky_inverse(self, damping_func):
  """Returns get_matpower with exponent -0.5 for the given damping."""
  inverse_half_power = -0.5
  return self.get_matpower(inverse_half_power, damping_func)
class NaiveDiagonalFactor(DiagonalFactor):
"""FisherFactor for a diagonal approximation of any type of param's Fisher.
@ -1167,7 +1237,7 @@ class ConvDiagonalFactor(DiagonalFactor):
return self._inputs[tower].device
class FullyConnectedKroneckerFactor(InverseProvidingFactor):
class FullyConnectedKroneckerFactor(DenseSquareMatrixFactor):
"""Kronecker factor for the input or output side of a fully-connected layer.
"""
@ -1220,7 +1290,7 @@ class FullyConnectedKroneckerFactor(InverseProvidingFactor):
return self._tensors[0][tower].device
class ConvInputKroneckerFactor(InverseProvidingFactor):
class ConvInputKroneckerFactor(DenseSquareMatrixFactor):
r"""Kronecker factor for the input side of a convolutional layer.
Estimates E[ a a^T ] where a is the inputs to a convolutional layer given
@ -1384,7 +1454,7 @@ class ConvInputKroneckerFactor(InverseProvidingFactor):
return self._inputs[tower].device
class ConvOutputKroneckerFactor(InverseProvidingFactor):
class ConvOutputKroneckerFactor(DenseSquareMatrixFactor):
r"""Kronecker factor for the output side of a convolutional layer.
Estimates E[ ds ds^T ] where s is the preactivations of a convolutional layer
@ -1674,6 +1744,7 @@ class FullyConnectedMultiKF(FullyConnectedKroneckerFactor):
psi_var) in self._option1quants_by_damping.items():
damping = self._damping_funcs_by_id[damping_id]()
damping = math_ops.cast(damping, self._dtype)
invsqrtC0 = math_ops.matmul(
eigen_V * (eigen_e + damping)**(-0.5), eigen_V, transpose_b=True)
@ -1702,6 +1773,7 @@ class FullyConnectedMultiKF(FullyConnectedKroneckerFactor):
mu_var) in self._option2quants_by_damping.items():
damping = self._damping_funcs_by_id[damping_id]()
damping = math_ops.cast(damping, self._dtype)
# compute C0^(-1/2)
invsqrtC0 = math_ops.matmul(

View File

@ -0,0 +1,95 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""SmartMatrices definitions."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.contrib.kfac.python.ops import utils
from tensorflow.python.framework import ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.linalg import linalg
from tensorflow.python.ops.linalg import linalg_impl
from tensorflow.python.ops.linalg import linear_operator_util as lou
class LinearOperatorExtras(object):  # pylint: disable=missing-docstring
  """Mixin adding IndexedSlices dispatch and right-multiplication.

  The class this is mixed into must supply `_name_scope`,
  `_check_input_dtype`, `shape`, `_matmul`, `_matmul_right`, `_matmul_sparse`
  and `_matmul_right_sparse`.
  """

  def matmul(self, x, adjoint=False, adjoint_arg=False, name="matmul"):
    """Multiplies this operator by `x`, with `x` on the right.

    Args:
      x: An IndexedSlices, forwarded unchanged to `_matmul_sparse`, or a value
        convertible to a Tensor.
      adjoint: If True, the operator's second-to-last dimension is the one
        checked against `x`.
      adjoint_arg: If True, `x`'s last dimension is the one checked.
      name: Name scope for the created ops.

    Returns:
      The result of `_matmul_sparse` or `_matmul`.
    """
    with self._name_scope(name, values=[x]):
      if not isinstance(x, ops.IndexedSlices):
        dense_x = ops.convert_to_tensor(x, name="x")
        self._check_input_dtype(dense_x)

        # Pick the pair of dimensions that get contracted.
        operator_axis, arg_axis = -1, -2
        if adjoint:
          operator_axis = -2
        if adjoint_arg:
          arg_axis = -1
        self.shape[operator_axis].assert_is_compatible_with(
            dense_x.get_shape()[arg_axis])

        return self._matmul(dense_x, adjoint=adjoint, adjoint_arg=adjoint_arg)

      return self._matmul_sparse(x, adjoint=adjoint, adjoint_arg=adjoint_arg)

  def matmul_right(self, x, adjoint=False, adjoint_arg=False, name="matmul"):
    """Multiplies `x` by this operator, with `x` on the left.

    Args:
      x: An IndexedSlices, forwarded unchanged to `_matmul_right_sparse`, or a
        value convertible to a Tensor.
      adjoint: If True, the operator's last dimension is the one checked
        against `x`.
      adjoint_arg: If True, `x`'s second-to-last dimension is the one checked.
      name: Name scope for the created ops.

    Returns:
      The result of `_matmul_right_sparse` or `_matmul_right`.
    """
    with self._name_scope(name, values=[x]):
      if not isinstance(x, ops.IndexedSlices):
        dense_x = ops.convert_to_tensor(x, name="x")
        self._check_input_dtype(dense_x)

        # Mirror image of matmul: the contracted dimensions swap roles.
        operator_axis, arg_axis = -2, -1
        if adjoint:
          operator_axis = -1
        if adjoint_arg:
          arg_axis = -2
        self.shape[operator_axis].assert_is_compatible_with(
            dense_x.get_shape()[arg_axis])

        return self._matmul_right(
            dense_x, adjoint=adjoint, adjoint_arg=adjoint_arg)

      return self._matmul_right_sparse(
          x, adjoint=adjoint, adjoint_arg=adjoint_arg)
class LinearOperatorFullMatrix(LinearOperatorExtras,
                               linalg.LinearOperatorFullMatrix):
  """Full-matrix operator extended with right-multiplication hooks."""

  # TODO(b/78117889) Remove this definition once core LinearOperator
  # has _matmul_right.
  def _matmul_right(self, x, adjoint=False, adjoint_arg=False):
    """Computes `x` times the wrapped matrix with broadcasting matmul."""
    product = lou.matmul_with_broadcast(
        x, self._matrix, adjoint_a=adjoint_arg, adjoint_b=adjoint)
    return product

  def _matmul_sparse(self, x, adjoint=False, adjoint_arg=False):
    """Not implemented for full matrices; always raises."""
    raise NotImplementedError

  def _matmul_right_sparse(self, x, adjoint=False, adjoint_arg=False):
    """Computes IndexedSlices `x` times the wrapped matrix (no adjoints)."""
    assert not adjoint and not adjoint_arg
    dense_matrix = self._matrix
    return utils.matmul_sparse_dense(x, dense_matrix)
class LinearOperatorDiag(LinearOperatorExtras,  # pylint: disable=missing-docstring
                         linalg.LinearOperatorDiag):
  """Diagonal operator extended with right- and sparse-multiplication hooks."""

  def _matmul_right(self, x, adjoint=False, adjoint_arg=False):
    """Multiplies `x` elementwise by the (optionally conjugated) diagonal."""
    if adjoint:
      diagonal = math_ops.conj(self._diag)
    else:
      diagonal = self._diag
    if adjoint_arg:
      x = linalg_impl.adjoint(x)
    return diagonal * x

  def _matmul_sparse(self, x, adjoint=False, adjoint_arg=False):
    """Multiplies the diagonal by IndexedSlices `x` (no adjoint_arg)."""
    if adjoint:
      diagonal = math_ops.conj(self._diag)
    else:
      diagonal = self._diag
    assert not adjoint_arg
    return utils.matmul_diag_sparse(diagonal, x)

  def _matmul_right_sparse(self, x, adjoint=False, adjoint_arg=False):
    """Not implemented for diagonal operators; always raises."""
    raise NotImplementedError

View File

@ -35,7 +35,7 @@ def _make_thunk_on_device(func, device):
class RoundRobinPlacementMixin(object):
"""Implements round robin placement strategy for ops and variables."""
def __init__(self, cov_devices=None, inv_devices=None, *args, **kwargs):
def __init__(self, cov_devices=None, inv_devices=None, **kwargs):
"""Initializes the RoundRobinPlacementMixin class.
Args:
@ -45,11 +45,10 @@ class RoundRobinPlacementMixin(object):
inv_devices: Iterable of device strings (e.g. '/gpu:0'). Inversion
computations will be placed on these devices in a round-robin fashion.
Can be None, which means that no devices are specified.
*args:
**kwargs:
**kwargs: Need something here?
"""
super(RoundRobinPlacementMixin, self).__init__(*args, **kwargs)
super(RoundRobinPlacementMixin, self).__init__(**kwargs)
self._cov_devices = cov_devices
self._inv_devices = inv_devices

View File

@ -235,6 +235,13 @@ posdef_eig_functions = {
}
def cholesky(tensor, damping):
  """Computes the Cholesky factor of tensor + damping * identity.

  Args:
    tensor: Matrix-shaped Tensor. Its first dimension must be statically
      known, because it determines the size of the identity matrix.
    damping: The damping value. It is cast to the dtype of `tensor` before
      being used.

  Returns:
    The result of linalg_ops.cholesky applied to the damped matrix.
  """
  identity = linalg_ops.eye(tensor.shape.as_list()[0], dtype=tensor.dtype)
  damping = math_ops.cast(damping, dtype=tensor.dtype)
  return linalg_ops.cholesky(tensor + damping * identity)
class SubGraph(object):
"""Defines a subgraph given by all the dependencies of a given set of outputs.
"""
@ -553,13 +560,17 @@ def is_data_format_channel_last(data_format):
return data_format.endswith("C")
def matmul_sparse_dense(A, B, name=None): # pylint: disable=invalid-name
def matmul_sparse_dense(A, B, name=None, transpose_a=False, transpose_b=False): # pylint: disable=invalid-name
"""Computes matmul(A, B) where A is sparse, B is dense.
Args:
A: tf.IndexedSlices with dense shape [m, n].
B: tf.Tensor with shape [n, k].
name: str. Name of op.
transpose_a: Bool. If true we transpose A before multiplying it by B.
(Default: False)
transpose_b: Bool. If true we transpose B before multiplying it by A.
(Default: False)
Returns:
tf.IndexedSlices resulting from matmul(A, B).
@ -573,7 +584,8 @@ def matmul_sparse_dense(A, B, name=None): # pylint: disable=invalid-name
raise ValueError("A must represent a matrix. Found: %s." % A)
if B.shape.ndims != 2:
raise ValueError("B must be a matrix.")
new_values = math_ops.matmul(A.values, B)
new_values = math_ops.matmul(
A.values, B, transpose_a=transpose_a, transpose_b=transpose_b)
return ops.IndexedSlices(
new_values,
A.indices,