- Adding support for Cholesky (inverse) factor multiplications.
- Refactored FisherFactor to use LinearOperator classes that know how to multiply themselves, compute their own trace, etc. This addresses the feature request b/73356352.
- Fixed some problems with FisherEstimator construction.
- More careful casting of damping constants before they are used.

PiperOrigin-RevId: 194379298
Parent: 8148895adc
Commit: 481f229881
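At the core of the refactor, each factor now exposes its covariance as a LinearOperator-style object that carries its own matmul, trace, and dense conversion. A minimal sketch of the pattern, written against the public tf.linalg operators on the assumption that the new kfac-local linear_operator module wraps the same classes:

    import tensorflow as tf

    # A diagonal factor and a dense factor, mirroring the updated UtilsTest below.
    left = tf.linalg.LinearOperatorDiag(tf.constant([1., 2., 0., 1.]))
    right = tf.linalg.LinearOperatorFullMatrix(tf.ones([2, 2]))

    # Each operator knows how to multiply itself, compute its own trace, etc.,
    # so callers no longer need to special-case diagonal vs. full storage.
    product = right.matmul(tf.ones([2, 1]))  # matrix-vector product
    trace = left.trace()                     # 1 + 2 + 0 + 1 = 4
    dense = right.to_dense()                 # explicit [2, 2] matrix
    dim = right.domain_dimension             # 2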
@@ -58,6 +58,7 @@ py_test(
     deps = [
         "//tensorflow/contrib/kfac/python/ops:fisher_blocks",
         "//tensorflow/contrib/kfac/python/ops:layer_collection",
+        "//tensorflow/contrib/kfac/python/ops:linear_operator",
         "//tensorflow/contrib/kfac/python/ops:utils",
         "//tensorflow/python:array_ops",
         "//tensorflow/python:client_testlib",
@@ -22,6 +22,7 @@ import numpy as np
 
 from tensorflow.contrib.kfac.python.ops import fisher_blocks as fb
 from tensorflow.contrib.kfac.python.ops import layer_collection as lc
+from tensorflow.contrib.kfac.python.ops import linear_operator as lo
 from tensorflow.contrib.kfac.python.ops import utils
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import random_seed
@@ -46,8 +47,9 @@ class UtilsTest(test.TestCase):
   def testComputePiTracenorm(self):
     with ops.Graph().as_default(), self.test_session() as sess:
       random_seed.set_random_seed(200)
-      left_factor = array_ops.diag([1., 2., 0., 1.])
-      right_factor = array_ops.ones([2., 2.])
+      diag = ops.convert_to_tensor([1., 2., 0., 1.])
+      left_factor = lo.LinearOperatorDiag(diag)
+      right_factor = lo.LinearOperatorFullMatrix(array_ops.ones([2, 2]))
 
       # pi is the sqrt of the left trace norm divided by the right trace norm
       pi = fb.compute_pi_tracenorm(left_factor, right_factor)
@@ -245,7 +247,6 @@ class NaiveDiagonalFBTest(test.TestCase):
 
       full = sess.run(block.full_fisher_block())
       explicit = np.dot(np.linalg.inv(full + damping * np.eye(3)), v_flat)
 
       self.assertAllClose(output_flat, explicit)
 
-
@@ -70,18 +70,6 @@ class FisherFactorTestingDummy(ff.FisherFactor):
   def get_cov(self):
     return NotImplementedError
 
-  def left_multiply(self, x, damping):
-    return NotImplementedError
-
-  def right_multiply(self, x, damping):
-    return NotImplementedError
-
-  def left_multiply_matpower(self, x, exp, damping):
-    return NotImplementedError
-
-  def right_multiply_matpower(self, x, exp, damping):
-    return NotImplementedError
-
   def instantiate_inv_variables(self):
     return NotImplementedError
 
@@ -91,14 +79,35 @@ class FisherFactorTestingDummy(ff.FisherFactor):
   def _get_data_device(self):
     raise NotImplementedError
 
+  def register_matpower(self, exp, damping_func):
+    raise NotImplementedError
+
+  def register_cholesky(self, damping_func):
+    raise NotImplementedError
+
+  def register_cholesky_inverse(self, damping_func):
+    raise NotImplementedError
+
+  def get_matpower(self, exp, damping_func):
+    raise NotImplementedError
+
+  def get_cholesky(self, damping_func):
+    raise NotImplementedError
+
+  def get_cholesky_inverse(self, damping_func):
+    raise NotImplementedError
+
+  def get_cov_as_linear_operator(self):
+    raise NotImplementedError
+
 
-class InverseProvidingFactorTestingDummy(ff.InverseProvidingFactor):
-  """Dummy class to test the non-abstract methods on ff.InverseProvidingFactor.
+class DenseSquareMatrixFactorTestingDummy(ff.DenseSquareMatrixFactor):
+  """Dummy class to test the non-abstract methods on ff.DenseSquareMatrixFactor.
   """
 
   def __init__(self, shape):
     self._shape = shape
-    super(InverseProvidingFactorTestingDummy, self).__init__()
+    super(DenseSquareMatrixFactorTestingDummy, self).__init__()
 
   @property
   def _var_scope(self):
@@ -230,13 +239,13 @@ class FisherFactorTest(test.TestCase):
     self.assertEqual(0, len(factor.make_inverse_update_ops()))
 
 
-class InverseProvidingFactorTest(test.TestCase):
+class DenseSquareMatrixFactorTest(test.TestCase):
 
   def testRegisterDampedInverse(self):
     with tf_ops.Graph().as_default():
       random_seed.set_random_seed(200)
       shape = [2, 2]
-      factor = InverseProvidingFactorTestingDummy(shape)
+      factor = DenseSquareMatrixFactorTestingDummy(shape)
       factor_var_scope = 'dummy/a_b_c'
 
       damping_funcs = [make_damping_func(0.1),
@@ -248,22 +257,25 @@ class InverseProvidingFactorTest(test.TestCase):
 
       factor.instantiate_inv_variables()
 
-      inv = factor.get_inverse(damping_funcs[0])
-      self.assertEqual(inv, factor.get_inverse(damping_funcs[1]))
-      self.assertNotEqual(inv, factor.get_inverse(damping_funcs[2]))
-      self.assertEqual(factor.get_inverse(damping_funcs[2]),
-                       factor.get_inverse(damping_funcs[3]))
+      inv = factor.get_inverse(damping_funcs[0]).to_dense()
+      self.assertEqual(inv, factor.get_inverse(damping_funcs[1]).to_dense())
+      self.assertNotEqual(inv, factor.get_inverse(damping_funcs[2]).to_dense())
+      self.assertEqual(factor.get_inverse(damping_funcs[2]).to_dense(),
+                       factor.get_inverse(damping_funcs[3]).to_dense())
       factor_vars = tf_ops.get_collection(tf_ops.GraphKeys.GLOBAL_VARIABLES,
                                           factor_var_scope)
-      self.assertEqual(set([inv, factor.get_inverse(damping_funcs[2])]),
-                       set(factor_vars))
+      factor_tensors = (tf_ops.convert_to_tensor(var) for var in factor_vars)
+
+      self.assertEqual(set([inv,
+                            factor.get_inverse(damping_funcs[2]).to_dense()]),
+                       set(factor_tensors))
       self.assertEqual(shape, inv.get_shape())
 
   def testRegisterMatpower(self):
     with tf_ops.Graph().as_default():
       random_seed.set_random_seed(200)
       shape = [3, 3]
-      factor = InverseProvidingFactorTestingDummy(shape)
+      factor = DenseSquareMatrixFactorTestingDummy(shape)
       factor_var_scope = 'dummy/a_b_c'
 
       # TODO(b/74201126): Change to using the same func for both once
@@ -278,10 +290,13 @@ class InverseProvidingFactorTest(test.TestCase):
 
       factor_vars = tf_ops.get_collection(tf_ops.GraphKeys.GLOBAL_VARIABLES,
                                           factor_var_scope)
-      matpower1 = factor.get_matpower(-0.5, damping_func_1)
-      matpower2 = factor.get_matpower(2, damping_func_2)
-
-      self.assertEqual(set([matpower1, matpower2]), set(factor_vars))
+      factor_tensors = (tf_ops.convert_to_tensor(var) for var in factor_vars)
+
+      matpower1 = factor.get_matpower(-0.5, damping_func_1).to_dense()
+      matpower2 = factor.get_matpower(2, damping_func_2).to_dense()
+
+      self.assertEqual(set([matpower1, matpower2]), set(factor_tensors))
 
       self.assertEqual(shape, matpower1.get_shape())
       self.assertEqual(shape, matpower2.get_shape())
@@ -297,7 +312,7 @@ class InverseProvidingFactorTest(test.TestCase):
     with tf_ops.Graph().as_default(), self.test_session() as sess:
       random_seed.set_random_seed(200)
       cov = np.array([[1., 2.], [3., 4.]])
-      factor = InverseProvidingFactorTestingDummy(cov.shape)
+      factor = DenseSquareMatrixFactorTestingDummy(cov.shape)
      factor._cov = array_ops.constant(cov, dtype=dtypes.float32)
 
      damping_funcs = []
@@ -316,7 +331,8 @@ class InverseProvidingFactorTest(test.TestCase):
       sess.run(ops)
       for i in range(ff.EIGENVALUE_DECOMPOSITION_THRESHOLD):
         # The inverse op will assign the damped inverse of cov to the inv var.
-        new_invs.append(sess.run(factor.get_inverse(damping_funcs[i])))
+        new_invs.append(
+            sess.run(factor.get_inverse(damping_funcs[i]).to_dense()))
 
       # We want to see that the new invs are all different from each other.
       for i in range(len(new_invs)):
@@ -328,7 +344,7 @@ class InverseProvidingFactorTest(test.TestCase):
     with tf_ops.Graph().as_default(), self.test_session() as sess:
       random_seed.set_random_seed(200)
       cov = np.array([[6., 2.], [2., 4.]])
-      factor = InverseProvidingFactorTestingDummy(cov.shape)
+      factor = DenseSquareMatrixFactorTestingDummy(cov.shape)
       factor._cov = array_ops.constant(cov, dtype=dtypes.float32)
       exp = 2  # NOTE(mattjj): must be int to test with np.linalg.matrix_power
       damping = 0.5
@@ -341,7 +357,7 @@ class InverseProvidingFactorTest(test.TestCase):
 
       sess.run(tf_variables.global_variables_initializer())
       sess.run(ops[0])
-      matpower = sess.run(factor.get_matpower(exp, damping_func))
+      matpower = sess.run(factor.get_matpower(exp, damping_func).to_dense())
       matpower_np = np.linalg.matrix_power(cov + np.eye(2) * damping, exp)
       self.assertAllClose(matpower, matpower_np)
 
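For a symmetric cov, the damped matrix power being tested here can be reproduced directly in NumPy through an eigendecomposition; a small standalone check using the same values as the test (the eigendecomposition route is illustrative, not necessarily how the factor computes it):

    import numpy as np

    cov = np.array([[6., 2.], [2., 4.]])
    damping, exp = 0.5, 2

    # (cov + damping * I)**exp via eigendecomposition of the damped matrix.
    evals, evecs = np.linalg.eigh(cov + damping * np.eye(2))
    matpower = (evecs * evals**exp).dot(evecs.T)

    expected = np.linalg.matrix_power(cov + np.eye(2) * damping, exp)
    assert np.allclose(matpower, expected)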
@@ -349,7 +365,7 @@ class InverseProvidingFactorTest(test.TestCase):
    with tf_ops.Graph().as_default(), self.test_session() as sess:
       random_seed.set_random_seed(200)
       cov = np.array([[5., 2.], [2., 4.]])  # NOTE(mattjj): must be symmetric
-      factor = InverseProvidingFactorTestingDummy(cov.shape)
+      factor = DenseSquareMatrixFactorTestingDummy(cov.shape)
       factor._cov = array_ops.constant(cov, dtype=dtypes.float32)
 
       damping_func = make_damping_func(0)
@@ -361,12 +377,12 @@ class InverseProvidingFactorTest(test.TestCase):
 
       sess.run(tf_variables.global_variables_initializer())
       # The inverse op will assign the damped inverse of cov to the inv var.
-      old_inv = sess.run(factor.get_inverse(damping_func))
+      old_inv = sess.run(factor.get_inverse(damping_func).to_dense())
       self.assertAllClose(
           sess.run(ff.inverse_initializer(cov.shape, dtypes.float32)), old_inv)
 
       sess.run(ops)
-      new_inv = sess.run(factor.get_inverse(damping_func))
+      new_inv = sess.run(factor.get_inverse(damping_func).to_dense())
       self.assertAllClose(new_inv, np.linalg.inv(cov))
 
 
@@ -411,7 +427,7 @@ class NaiveDiagonalFactorTest(test.TestCase):
       tensor = array_ops.ones((2, 3), name='a/b/c')
       factor = ff.NaiveDiagonalFactor((tensor,), 32)
       factor.instantiate_cov_variables()
-      self.assertEqual([6, 1], factor.get_cov_var().get_shape().as_list())
+      self.assertEqual([6, 1], factor.get_cov().get_shape().as_list())
 
   def testNaiveDiagonalFactorInitFloat64(self):
     with tf_ops.Graph().as_default():
@@ -420,7 +436,7 @@ class NaiveDiagonalFactorTest(test.TestCase):
       tensor = array_ops.ones((2, 3), dtype=dtype, name='a/b/c')
       factor = ff.NaiveDiagonalFactor((tensor,), 32)
       factor.instantiate_cov_variables()
-      cov = factor.get_cov_var()
+      cov = factor.get_cov()
       self.assertEqual(cov.dtype, dtype)
       self.assertEqual([6, 1], cov.get_shape().as_list())
 
@@ -444,7 +460,7 @@ class EmbeddingInputKroneckerFactorTest(test.TestCase):
       vocab_size = 5
       factor = ff.EmbeddingInputKroneckerFactor((input_ids,), vocab_size)
       factor.instantiate_cov_variables()
-      cov = factor.get_cov_var()
+      cov = factor.get_cov()
       self.assertEqual(cov.shape.as_list(), [vocab_size])
 
   def testCovarianceUpdateOp(self):
@@ -502,7 +518,7 @@ class ConvDiagonalFactorTest(test.TestCase):
           self.kernel_height * self.kernel_width * self.in_channels,
           self.out_channels
       ],
-                       factor.get_cov_var().shape.as_list())
+                       factor.get_cov().shape.as_list())
 
   def testMakeCovarianceUpdateOp(self):
     with tf_ops.Graph().as_default():
@@ -564,7 +580,7 @@ class ConvDiagonalFactorTest(test.TestCase):
           self.kernel_height * self.kernel_width * self.in_channels + 1,
           self.out_channels
       ],
-                       factor.get_cov_var().shape.as_list())
+                       factor.get_cov().shape.as_list())
 
      # Ensure update op doesn't crash.
      cov_update_op = factor.make_covariance_update_op(0.0)
@@ -654,13 +670,13 @@ class ConvInputKroneckerFactorTest(ConvFactorTestCase):
     # Ensure shape of covariance matches input size of filter.
     input_size = in_channels * (width**3)
     self.assertEqual([input_size, input_size],
-                     factor.get_cov_var().shape.as_list())
+                     factor.get_cov().shape.as_list())
 
     # Ensure cov_update_op doesn't crash.
     with self.test_session() as sess:
       sess.run(tf_variables.global_variables_initializer())
       sess.run(factor.make_covariance_update_op(0.0))
-      cov = sess.run(factor.get_cov_var())
+      cov = sess.run(factor.get_cov())
 
     # Cov should be rank-8, as the filter will be applied at each corner of
     # the 4-D cube.
@@ -685,13 +701,13 @@ class ConvInputKroneckerFactorTest(ConvFactorTestCase):
 
     # Ensure shape of covariance matches input size of filter.
     self.assertEqual([in_channels, in_channels],
-                     factor.get_cov_var().shape.as_list())
+                     factor.get_cov().shape.as_list())
 
     # Ensure cov_update_op doesn't crash.
     with self.test_session() as sess:
       sess.run(tf_variables.global_variables_initializer())
       sess.run(factor.make_covariance_update_op(0.0))
-      cov = sess.run(factor.get_cov_var())
+      cov = sess.run(factor.get_cov())
 
     # Cov should be rank-9, as the filter will be applied at each location.
     self.assertMatrixRank(9, cov)
@@ -716,7 +732,7 @@ class ConvInputKroneckerFactorTest(ConvFactorTestCase):
     with self.test_session() as sess:
       sess.run(tf_variables.global_variables_initializer())
       sess.run(factor.make_covariance_update_op(0.0))
-      cov = sess.run(factor.get_cov_var())
+      cov = sess.run(factor.get_cov())
 
     # Cov should be the sum of 3 * 2 = 6 outer products.
     self.assertMatrixRank(6, cov)
@@ -742,7 +758,7 @@ class ConvInputKroneckerFactorTest(ConvFactorTestCase):
     with self.test_session() as sess:
       sess.run(tf_variables.global_variables_initializer())
       sess.run(factor.make_covariance_update_op(0.0))
-      cov = sess.run(factor.get_cov_var())
+      cov = sess.run(factor.get_cov())
 
     # Cov should be rank = in_channels, as only the center of the filter
     # receives non-zero input for each input channel.
@@ -35,6 +35,7 @@ py_library(
     srcs = ["fisher_factors.py"],
     srcs_version = "PY2AND3",
     deps = [
+        ":linear_operator",
        ":utils",
         "//tensorflow/python:array_ops",
         "//tensorflow/python:control_flow_ops",
@@ -63,6 +64,19 @@ py_library(
     ],
 )
 
+py_library(
+    name = "linear_operator",
+    srcs = ["linear_operator.py"],
+    srcs_version = "PY2AND3",
+    deps = [
+        ":utils",
+        "//tensorflow/python:framework_ops",
+        "//tensorflow/python:math_ops",
+        "//tensorflow/python/ops/linalg",
+        "@six_archive//:six",
+    ],
+)
+
 py_library(
     name = "loss_functions",
     srcs = ["loss_functions.py"],
@@ -57,8 +57,8 @@ def make_fisher_estimator(placement_strategy=None, **kwargs):
   if placement_strategy in [None, "round_robin"]:
     return FisherEstimatorRoundRobin(**kwargs)
   else:
-    raise ValueError("Unimplemented vars and ops placement strategy : %s",
-                     placement_strategy)
+    raise ValueError("Unimplemented vars and ops "
+                     "placement strategy : {}".format(placement_strategy))
 # pylint: enable=abstract-class-instantiated
 
 
@@ -81,7 +81,9 @@ class FisherEstimator(object):
                exps=(-1,),
                estimation_mode="gradients",
                colocate_gradients_with_ops=True,
-               name="FisherEstimator"):
+               name="FisherEstimator",
+               compute_cholesky=False,
+               compute_cholesky_inverse=False):
     """Create a FisherEstimator object.
 
     Args:
@@ -124,6 +126,12 @@ class FisherEstimator(object):
       name: A string. A name given to this estimator, which is added to the
         variable scope when constructing variables and ops.
         (Default: "FisherEstimator")
+      compute_cholesky: Bool. Whether or not the FisherEstimator will be
+        able to multiply vectors by the Cholesky factor.
+        (Default: False)
+      compute_cholesky_inverse: Bool. Whether or not the FisherEstimator
+        will be able to multiply vectors by the Cholesky factor inverse.
+        (Default: False)
     Raises:
       ValueError: If no losses have been registered with layer_collection.
     """
@@ -142,6 +150,8 @@ class FisherEstimator(object):
 
     self._made_vars = False
     self._exps = exps
+    self._compute_cholesky = compute_cholesky
+    self._compute_cholesky_inverse = compute_cholesky_inverse
 
     self._name = name
 
@@ -300,9 +310,54 @@ class FisherEstimator(object):
       A list of (transformed vector, var) pairs in the same order as
       vecs_and_vars.
     """
+    assert exp in self._exps
+
     fcn = lambda fb, vec: fb.multiply_matpower(vec, exp)
     return self._apply_transformation(vecs_and_vars, fcn)
 
+  def multiply_cholesky(self, vecs_and_vars, transpose=False):
+    """Multiplies the vecs by the corresponding Cholesky factors.
+
+    Args:
+      vecs_and_vars: List of (vector, variable) pairs.
+      transpose: Bool. If true the Cholesky factors are transposed before
+        multiplying the vecs. (Default: False)
+
+    Returns:
+      A list of (transformed vector, var) pairs in the same order as
+      vecs_and_vars.
+    """
+    assert self._compute_cholesky
+
+    fcn = lambda fb, vec: fb.multiply_cholesky(vec, transpose=transpose)
+    return self._apply_transformation(vecs_and_vars, fcn)
+
+  def multiply_cholesky_inverse(self, vecs_and_vars, transpose=False):
+    """Mults the vecs by the inverses of the corresponding Cholesky factors.
+
+    Note: if you are using Cholesky inverse multiplication to sample from
+    a matrix-variate Gaussian you will want to multiply by the transpose.
+    Let L be the Cholesky factor of F and observe that
+
+      L^-T * L^-1 = (L * L^T)^-1 = F^-1 .
+
+    Thus we want to multiply by L^-T in order to sample from Gaussian with
+    covariance F^-1.
+
+    Args:
+      vecs_and_vars: List of (vector, variable) pairs.
+      transpose: Bool. If true the Cholesky factor inverses are transposed
+        before multiplying the vecs. (Default: False)
+
+    Returns:
+      A list of (transformed vector, var) pairs in the same order as
+      vecs_and_vars.
+    """
+    assert self._compute_cholesky_inverse
+
+    fcn = lambda fb, vec: fb.multiply_cholesky_inverse(vec, transpose=transpose)
+    return self._apply_transformation(vecs_and_vars, fcn)
+
   def _instantiate_factors(self):
     """Instantiates FisherFactors' variables.
 
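The identity quoted in the multiply_cholesky_inverse docstring is easy to verify numerically; a NumPy sketch of drawing samples with covariance F^-1 by multiplying noise by L^-T:

    import numpy as np

    F = np.array([[6., 2.], [2., 4.]])  # stand-in for a damped Fisher block
    L = np.linalg.cholesky(F)           # F = L L^T

    # x = L^-T z has covariance L^-T I L^-1 = (L L^T)^-1 = F^-1,
    # which is why sampling uses transpose=True.
    z = np.random.randn(2, 200000)
    x = np.linalg.solve(L.T, z)

    empirical = x.dot(x.T) / z.shape[1]
    assert np.allclose(empirical, np.linalg.inv(F), atol=0.05)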
@@ -333,9 +388,13 @@ class FisherEstimator(object):
     return self._made_vars
 
   def _register_matrix_functions(self):
-    for exp in self._exps:
-      for block in self.blocks:
+    for block in self.blocks:
+      for exp in self._exps:
         block.register_matpower(exp)
+      if self._compute_cholesky:
+        block.register_cholesky()
+      if self._compute_cholesky_inverse:
+        block.register_cholesky_inverse()
 
   def _finalize_layer_collection(self):
     self._layers.create_subgraph()
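Taken together with the new constructor flags, the intended call pattern looks roughly like the following hypothetical sketch (the constructor arguments before exps, and the grads_and_vars list, are assumed from the surrounding code rather than shown in this diff):

    # Hypothetical usage; layer_collection is a LayerCollection with
    # registered layers and losses.
    est = make_fisher_estimator(
        variables=variables,
        cov_ema_decay=0.95,
        damping=1e-3,
        layer_collection=layer_collection,
        compute_cholesky=True,
        compute_cholesky_inverse=True)

    # The asserts in the methods above require the matching flag to be set.
    scaled = est.multiply_cholesky(grads_and_vars)
    whitened = est.multiply_cholesky_inverse(grads_and_vars, transpose=True)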
@@ -25,6 +25,7 @@ from tensorflow.python.util.all_util import remove_undocumented
 
 _allowed_symbols = [
     'FisherEstimator',
+    'make_fisher_estimator',
 ]
 
 remove_undocumented(__name__, allowed_exception_list=_allowed_symbols)
@@ -83,34 +83,22 @@ def normalize_damping(damping, num_replications):
 
 
 def compute_pi_tracenorm(left_cov, right_cov):
-  """Computes the scalar constant pi for Tikhonov regularization/damping.
+  r"""Computes the scalar constant pi for Tikhonov regularization/damping.
 
   $$\pi = \sqrt{ (trace(A) / dim(A)) / (trace(B) / dim(B)) }$$
   See section 6.3 of https://arxiv.org/pdf/1503.05671.pdf for details.
 
   Args:
-    left_cov: The left Kronecker factor "covariance".
-    right_cov: The right Kronecker factor "covariance".
+    left_cov: A LinearOperator object. The left Kronecker factor "covariance".
+    right_cov: A LinearOperator object. The right Kronecker factor "covariance".
 
   Returns:
     The computed scalar constant pi for these Kronecker Factors (as a Tensor).
   """
-
-  def _trace(cov):
-    if len(cov.shape) == 1:
-      # Diagonal matrix.
-      return math_ops.reduce_sum(cov)
-    elif len(cov.shape) == 2:
-      # Full matrix.
-      return math_ops.trace(cov)
-    else:
-      raise ValueError(
-          "What's the trace of a Tensor of rank %d?" % len(cov.shape))
-
   # Instead of dividing by the dim of the norm, we multiply by the dim of the
   # other norm. This works out the same in the ratio.
-  left_norm = _trace(left_cov) * right_cov.shape.as_list()[0]
-  right_norm = _trace(right_cov) * left_cov.shape.as_list()[0]
+  left_norm = left_cov.trace() * int(right_cov.domain_dimension)
+  right_norm = right_cov.trace() * int(left_cov.domain_dimension)
   return math_ops.sqrt(left_norm / right_norm)
 
 
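For the operators in the updated UtilsTest near the top of this diff, the computation works out by hand; a NumPy restatement:

    import numpy as np

    left = np.diag([1., 2., 0., 1.])  # LinearOperatorDiag: trace 4, dim 4
    right = np.ones([2, 2])           # LinearOperatorFullMatrix: trace 2, dim 2

    # Multiply each trace by the *other* factor's dimension rather than
    # dividing by its own; the ratio is unchanged.
    left_norm = np.trace(left) * right.shape[0]   # 4 * 2 = 8
    right_norm = np.trace(right) * left.shape[0]  # 2 * 4 = 8
    pi = np.sqrt(left_norm / right_norm)          # 1.0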
@@ -188,6 +176,16 @@ class FisherBlock(object):
     """
     pass
 
+  @abc.abstractmethod
+  def register_cholesky(self):
+    """Registers a Cholesky factor to be computed by the block."""
+    pass
+
+  @abc.abstractmethod
+  def register_cholesky_inverse(self):
+    """Registers an inverse Cholesky factor to be computed by the block."""
+    pass
+
   def register_inverse(self):
     """Registers a matrix inverse to be computed by the block."""
     self.register_matpower(-1)
@@ -228,6 +226,33 @@ class FisherBlock(object):
     """
     return self.multiply_matpower(vector, 1)
 
+  @abc.abstractmethod
+  def multiply_cholesky(self, vector, transpose=False):
+    """Multiplies the vector by the (damped) Cholesky-factor of the block.
+
+    Args:
+      vector: The vector (a Tensor or tuple of Tensors) to be multiplied.
+      transpose: Bool. If true the Cholesky factor is transposed before
+        multiplying the vector. (Default: False)
+
+    Returns:
+      The vector left-multiplied by the (damped) Cholesky-factor of the block.
+    """
+    pass
+
+  @abc.abstractmethod
+  def multiply_cholesky_inverse(self, vector, transpose=False):
+    """Multiplies vector by the (damped) inverse Cholesky-factor of the block.
+
+    Args:
+      vector: The vector (a Tensor or tuple of Tensors) to be multiplied.
+      transpose: Bool. If true the Cholesky factor inverse is transposed
+        before multiplying the vector. (Default: False)
+
+    Returns:
+      Vector left-multiplied by (damped) inverse Cholesky-factor of the block.
+    """
+    pass
+
   @abc.abstractmethod
   def tensors_to_compute_grads(self):
     """Returns the Tensor(s) with respect to which this FisherBlock needs grads.
@@ -275,15 +300,32 @@ class FullFB(FisherBlock):
   def register_matpower(self, exp):
     self._factor.register_matpower(exp, self._damping_func)
 
-  def multiply_matpower(self, vector, exp):
+  def register_cholesky(self):
+    self._factor.register_cholesky(self._damping_func)
+
+  def register_cholesky_inverse(self):
+    self._factor.register_cholesky_inverse(self._damping_func)
+
+  def _multiply_matrix(self, matrix, vector, transpose=False):
     vector_flat = utils.tensors_to_column(vector)
-    out_flat = self._factor.left_multiply_matpower(
-        vector_flat, exp, self._damping_func)
+    out_flat = matrix.matmul(vector_flat, adjoint=transpose)
     return utils.column_to_tensors(vector, out_flat)
 
+  def multiply_matpower(self, vector, exp):
+    matrix = self._factor.get_matpower(exp, self._damping_func)
+    return self._multiply_matrix(matrix, vector)
+
+  def multiply_cholesky(self, vector, transpose=False):
+    matrix = self._factor.get_cholesky(self._damping_func)
+    return self._multiply_matrix(matrix, vector, transpose=transpose)
+
+  def multiply_cholesky_inverse(self, vector, transpose=False):
+    matrix = self._factor.get_cholesky_inverse(self._damping_func)
+    return self._multiply_matrix(matrix, vector, transpose=transpose)
+
   def full_fisher_block(self):
     """Explicitly constructs the full Fisher block."""
-    return self._factor.get_cov()
+    return self._factor.get_cov_as_linear_operator().to_dense()
 
   def tensors_to_compute_grads(self):
     return self._params
@@ -305,7 +347,47 @@ class FullFB(FisherBlock):
     return math_ops.reduce_sum(self._batch_sizes)
 
 
-class NaiveDiagonalFB(FisherBlock):
+@six.add_metaclass(abc.ABCMeta)
+class DiagonalFB(FisherBlock):
+  """A base class for FisherBlocks that use diagonal approximations."""
+
+  def register_matpower(self, exp):
+    # Not needed for this. Matrix powers are computed on demand in the
+    # diagonal case
+    pass
+
+  def register_cholesky(self):
+    # Not needed for this. Cholesky's are computed on demand in the
+    # diagonal case
+    pass
+
+  def register_cholesky_inverse(self):
+    # Not needed for this. Cholesky inverses's are computed on demand in the
+    # diagonal case
+    pass
+
+  def _multiply_matrix(self, matrix, vector):
+    vector_flat = utils.tensors_to_column(vector)
+    out_flat = matrix.matmul(vector_flat)
+    return utils.column_to_tensors(vector, out_flat)
+
+  def multiply_matpower(self, vector, exp):
+    matrix = self._factor.get_matpower(exp, self._damping_func)
+    return self._multiply_matrix(matrix, vector)
+
+  def multiply_cholesky(self, vector, transpose=False):
+    matrix = self._factor.get_cholesky(self._damping_func)
+    return self._multiply_matrix(matrix, vector)
+
+  def multiply_cholesky_inverse(self, vector, transpose=False):
+    matrix = self._factor.get_cholesky_inverse(self._damping_func)
+    return self._multiply_matrix(matrix, vector)
+
+  def full_fisher_block(self):
+    return self._factor.get_cov_as_linear_operator().to_dense()
+
+
+class NaiveDiagonalFB(DiagonalFB):
   """FisherBlock using a diagonal matrix approximation.
 
   This type of approximation is generically applicable but quite primitive.
@@ -333,20 +415,6 @@ class NaiveDiagonalFB(FisherBlock):
     self._factor = self._layer_collection.make_or_get_factor(
         fisher_factors.NaiveDiagonalFactor, (grads_list, self._batch_size))
 
-  def register_matpower(self, exp):
-    # Not needed for this. Matrix powers are computed on demand in the
-    # diagonal case
-    pass
-
-  def multiply_matpower(self, vector, exp):
-    vector_flat = utils.tensors_to_column(vector)
-    out_flat = self._factor.left_multiply_matpower(
-        vector_flat, exp, self._damping_func)
-    return utils.column_to_tensors(vector, out_flat)
-
-  def full_fisher_block(self):
-    return self._factor.get_cov()
-
   def tensors_to_compute_grads(self):
     return self._params
 
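Nothing needs to be registered for the diagonal blocks above because every matrix function of a diagonal-plus-damping matrix is an elementwise operation on the diagonal; an illustrative NumPy sketch, assuming the factor stores a vector of diagonal covariance entries:

    import numpy as np

    diag = np.array([0.5, 2.0, 1.0])  # hypothetical diagonal cov entries
    damping = 0.1
    d = diag + damping

    matpower = d ** -1.0              # e.g. exp = -1: elementwise inverse
    cholesky = np.sqrt(d)             # Cholesky of a diagonal is its sqrt
    cholesky_inverse = 1.0 / np.sqrt(d)

    v = np.ones(3)
    out = matpower * v                # "matmul" reduces to elementwise scaling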
@@ -452,7 +520,7 @@ class InputOutputMultiTower(object):
     return self.__outputs
 
 
-class FullyConnectedDiagonalFB(InputOutputMultiTower, FisherBlock):
+class FullyConnectedDiagonalFB(InputOutputMultiTower, DiagonalFB):
   """FisherBlock for fully-connected (dense) layers using a diagonal approx.
 
   Estimates the Fisher Information matrix's diagonal entries for a fully
@@ -497,32 +565,8 @@ class FullyConnectedDiagonalFB(InputOutputMultiTower, FisherBlock):
 
     self._damping_func = _package_func(lambda: damping, (damping,))
 
-  def register_matpower(self, exp):
-    # Not needed for this. Matrix powers are computed on demand in the
-    # diagonal case
-    pass
-
-  def multiply_matpower(self, vector, exp):
-    """Multiplies the vector by the (damped) matrix-power of the block.
-
-    Args:
-      vector: Tensor or 2-tuple of Tensors. if self._has_bias, Tensor of shape
-        [input_size, output_size] corresponding to layer's weights. If not, a
-        2-tuple of the former and a Tensor of shape [output_size] corresponding
-        to the layer's bias.
-      exp: A scalar representing the power to raise the block before multiplying
-        it by the vector.
-
-    Returns:
-      The vector left-multiplied by the (damped) matrix-power of the block.
-    """
-    reshaped_vec = utils.layer_params_to_mat2d(vector)
-    reshaped_out = self._factor.left_multiply_matpower(
-        reshaped_vec, exp, self._damping_func)
-    return utils.mat2d_to_layer_params(vector, reshaped_out)
-
 
-class ConvDiagonalFB(InputOutputMultiTower, FisherBlock):
+class ConvDiagonalFB(InputOutputMultiTower, DiagonalFB):
   """FisherBlock for 2-D convolutional layers using a diagonal approx.
 
   Estimates the Fisher Information matrix's diagonal entries for a convolutional
@@ -621,17 +665,6 @@ class ConvDiagonalFB(InputOutputMultiTower, FisherBlock):
                                self._num_locations)
     self._damping_func = _package_func(damping_func, damping_id)
 
-  def register_matpower(self, exp):
-    # Not needed for this. Matrix powers are computed on demand in the
-    # diagonal case
-    pass
-
-  def multiply_matpower(self, vector, exp):
-    reshaped_vect = utils.layer_params_to_mat2d(vector)
-    reshaped_out = self._factor.left_multiply_matpower(
-        reshaped_vect, exp, self._damping_func)
-    return utils.mat2d_to_layer_params(vector, reshaped_out)
-
 
 class KroneckerProductFB(FisherBlock):
   """A base class for blocks with separate input and output Kronecker factors.
@@ -651,9 +684,10 @@ class KroneckerProductFB(FisherBlock):
       else:
         maybe_normalized_damping = damping
 
-      return compute_pi_adjusted_damping(self._input_factor.get_cov(),
-                                         self._output_factor.get_cov(),
-                                         maybe_normalized_damping**0.5)
+      return compute_pi_adjusted_damping(
+          self._input_factor.get_cov_as_linear_operator(),
+          self._output_factor.get_cov_as_linear_operator(),
+          maybe_normalized_damping**0.5)
 
     if normalization is not None:
       damping_id = ("compute_pi_adjusted_damping",
@@ -675,6 +709,14 @@ class KroneckerProductFB(FisherBlock):
     self._input_factor.register_matpower(exp, self._input_damping_func)
     self._output_factor.register_matpower(exp, self._output_damping_func)
 
+  def register_cholesky(self):
+    self._input_factor.register_cholesky(self._input_damping_func)
+    self._output_factor.register_cholesky(self._output_damping_func)
+
+  def register_cholesky_inverse(self):
+    self._input_factor.register_cholesky_inverse(self._input_damping_func)
+    self._output_factor.register_cholesky_inverse(self._output_damping_func)
+
   @property
   def _renorm_coeff(self):
     """Kronecker factor multiplier coefficient.
@@ -687,17 +729,47 @@ class KroneckerProductFB(FisherBlock):
     """
     return 1.0
 
-  def multiply_matpower(self, vector, exp):
+  def _multiply_factored_matrix(self, left_factor, right_factor, vector,
+                                extra_scale=1.0, transpose_left=False,
+                                transpose_right=False):
     reshaped_vector = utils.layer_params_to_mat2d(vector)
-    reshaped_out = self._output_factor.right_multiply_matpower(
-        reshaped_vector, exp, self._output_damping_func)
-    reshaped_out = self._input_factor.left_multiply_matpower(
-        reshaped_out, exp, self._input_damping_func)
-    if self._renorm_coeff != 1.0:
-      renorm_coeff = math_ops.cast(self._renorm_coeff, dtype=reshaped_out.dtype)
-      reshaped_out *= math_ops.cast(renorm_coeff**exp, dtype=reshaped_out.dtype)
+    reshaped_out = right_factor.matmul_right(reshaped_vector,
+                                             adjoint=transpose_right)
+    reshaped_out = left_factor.matmul(reshaped_out,
+                                      adjoint=transpose_left)
+    if extra_scale != 1.0:
+      reshaped_out *= math_ops.cast(extra_scale, dtype=reshaped_out.dtype)
     return utils.mat2d_to_layer_params(vector, reshaped_out)
 
+  def multiply_matpower(self, vector, exp):
+    left_factor = self._input_factor.get_matpower(
+        exp, self._input_damping_func)
+    right_factor = self._output_factor.get_matpower(
+        exp, self._output_damping_func)
+    extra_scale = float(self._renorm_coeff)**exp
+    return self._multiply_factored_matrix(left_factor, right_factor, vector,
+                                          extra_scale=extra_scale)
+
+  def multiply_cholesky(self, vector, transpose=False):
+    left_factor = self._input_factor.get_cholesky(self._input_damping_func)
+    right_factor = self._output_factor.get_cholesky(self._output_damping_func)
+    extra_scale = float(self._renorm_coeff)**0.5
+    return self._multiply_factored_matrix(left_factor, right_factor, vector,
+                                          extra_scale=extra_scale,
+                                          transpose_left=transpose,
+                                          transpose_right=not transpose)
+
+  def multiply_cholesky_inverse(self, vector, transpose=False):
+    left_factor = self._input_factor.get_cholesky_inverse(
+        self._input_damping_func)
+    right_factor = self._output_factor.get_cholesky_inverse(
+        self._output_damping_func)
+    extra_scale = float(self._renorm_coeff)**-0.5
+    return self._multiply_factored_matrix(left_factor, right_factor, vector,
+                                          extra_scale=extra_scale,
+                                          transpose_left=transpose,
+                                          transpose_right=not transpose)
+
   def full_fisher_block(self):
     """Explicitly constructs the full Fisher block.
 
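The factored multiply is an instance of the Kronecker-product identity (A kron B) vec(X) = vec(B X A^T): the reshaped vector is multiplied by the output factor on one side and the input factor on the other, without ever forming the Kronecker product. A NumPy check of the identity itself (kfac's exact vectorization and transpose conventions may differ):

    import numpy as np

    A = np.array([[2., 1.], [1., 3.]])  # input (left) Kronecker factor
    B = np.array([[4., 0.], [0., 5.]])  # output (right) Kronecker factor
    X = np.arange(4.).reshape(2, 2)     # the "vector", reshaped to a matrix

    vec = lambda M: M.reshape(-1, order='F')  # column-stacking vec
    lhs = np.kron(A, B).dot(vec(X))
    rhs = vec(B.dot(X).dot(A.T))
    assert np.allclose(lhs, rhs)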
@@ -706,8 +778,8 @@ class KroneckerProductFB(FisherBlock):
     Returns:
       The full Fisher block.
     """
-    left_factor = self._input_factor.get_cov()
-    right_factor = self._output_factor.get_cov()
+    left_factor = self._input_factor.get_cov_as_linear_operator().to_dense()
+    right_factor = self._output_factor.get_cov_as_linear_operator().to_dense()
     return self._renorm_coeff * utils.kronecker_product(left_factor,
                                                         right_factor)
 
@@ -796,7 +868,7 @@ class FullyConnectedKFACBasicFB(InputOutputMultiTower, KroneckerProductFB):
 
 
 class ConvKFCBasicFB(InputOutputMultiTower, KroneckerProductFB):
-  """FisherBlock for convolutional layers using the basic KFC approx.
+  r"""FisherBlock for convolutional layers using the basic KFC approx.
 
   Estimates the Fisher Information matrix's blog for a convolutional
   layer.
@@ -945,10 +1017,10 @@ class DepthwiseConvDiagonalFB(ConvDiagonalFB):
     self._filter_shape = (filter_height, filter_width, in_channels,
                           in_channels * channel_multiplier)
 
-  def multiply_matpower(self, vector, exp):
+  def _multiply_matrix(self, matrix, vector):
     conv2d_vector = depthwise_conv2d_filter_to_conv2d_filter(vector)
-    conv2d_result = super(DepthwiseConvDiagonalFB, self).multiply_matpower(
-        conv2d_vector, exp)
+    conv2d_result = super(
+        DepthwiseConvDiagonalFB, self)._multiply_matrix(matrix, conv2d_vector)
     return conv2d_filter_to_depthwise_conv2d_filter(conv2d_result)
 
 
@@ -1016,10 +1088,14 @@ class DepthwiseConvKFCBasicFB(ConvKFCBasicFB):
     self._filter_shape = (filter_height, filter_width, in_channels,
                           in_channels * channel_multiplier)
 
-  def multiply_matpower(self, vector, exp):
+  def _multiply_factored_matrix(self, left_factor, right_factor, vector,
+                                extra_scale=1.0, transpose_left=False,
+                                transpose_right=False):
     conv2d_vector = depthwise_conv2d_filter_to_conv2d_filter(vector)
-    conv2d_result = super(DepthwiseConvKFCBasicFB, self).multiply_matpower(
-        conv2d_vector, exp)
+    conv2d_result = super(
+        DepthwiseConvKFCBasicFB, self)._multiply_factored_matrix(
+            left_factor, right_factor, conv2d_vector, extra_scale=extra_scale,
+            transpose_left=transpose_left, transpose_right=transpose_right)
     return conv2d_filter_to_depthwise_conv2d_filter(conv2d_result)
 
 
@@ -1664,3 +1740,12 @@ class FullyConnectedSeriesFB(InputOutputMultiTowerMultiUse,
     return utils.mat2d_to_layer_params(vector, Z)
 
   # pylint: enable=invalid-name
+
+  def multiply_cholesky(self, vector):
+    raise NotImplementedError("FullyConnectedSeriesFB does not support "
+                              "Cholesky computations.")
+
+  def multiply_cholesky_inverse(self, vector):
+    raise NotImplementedError("FullyConnectedSeriesFB does not support "
+                              "Cholesky computations.")
@@ -24,6 +24,7 @@ import contextlib
 import numpy as np
 import six
 
+from tensorflow.contrib.kfac.python.ops import linear_operator as lo
 from tensorflow.contrib.kfac.python.ops import utils
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops as tf_ops
@@ -399,7 +400,7 @@ class FisherFactor(object):
       the cov update.
 
     Returns:
-      Tensor of same shape as self.get_cov_var().
+      Tensor of same shape as self.get_cov().
     """
     pass
 
@@ -448,78 +449,43 @@ class FisherFactor(object):
     """Create and return update ops corresponding to registered computations."""
     pass
 
-  @abc.abstractmethod
   def get_cov(self):
-    """Get full covariance matrix.
-
-    Returns:
-      Tensor of shape [n, n]. Represents all parameter-parameter correlations
-      captured by this FisherFactor.
-    """
-    pass
-
-  def get_cov_var(self):
-    """Get variable backing this FisherFactor.
-
-    May or may not be the same as self.get_cov()
-
-    Returns:
-      Variable of shape self._cov_shape.
-    """
     return self._cov
 
   @abc.abstractmethod
-  def left_multiply_matpower(self, x, exp, damping_func):
-    """Left multiplies 'x' by matrix power of this factor (w/ damping applied).
-
-    This calculation is essentially:
-      (C + damping * I)**exp * x
-    where * is matrix-multiplication, ** is matrix power, I is the identity
-    matrix, and C is the matrix represented by this factor.
-
-    x can represent either a matrix or a vector. For some factors, 'x' might
-    represent a vector but actually be stored as a 2D matrix for convenience.
-
-    Args:
-      x: Tensor. Represents a single vector. Shape depends on implementation.
-      exp: float. The matrix exponent to use.
-      damping_func: A function that computes a 0-D Tensor or a float which will
-        be the damping value used. i.e. damping = damping_func().
-
-    Returns:
-      Tensor of same shape as 'x' representing the result of the multiplication.
-    """
+  def get_cov_as_linear_operator(self):
     pass
 
   @abc.abstractmethod
-  def right_multiply_matpower(self, x, exp, damping_func):
-    """Right multiplies 'x' by matrix power of this factor (w/ damping applied).
-
-    This calculation is essentially:
-      x * (C + damping * I)**exp
-    where * is matrix-multiplication, ** is matrix power, I is the identity
-    matrix, and C is the matrix represented by this factor.
-
-    Unlike left_multiply_matpower, x will always be a matrix.
-
-    Args:
-      x: Tensor. Represents a single vector. Shape depends on implementation.
-      exp: float. The matrix exponent to use.
-      damping_func: A function that computes a 0-D Tensor or a float which will
-        be the damping value used. i.e. damping = damping_func().
-
-    Returns:
-      Tensor of same shape as 'x' representing the result of the multiplication.
-    """
+  def register_matpower(self, exp, damping_func):
+    pass
+
+  @abc.abstractmethod
+  def register_cholesky(self, damping_func):
+    pass
+
+  @abc.abstractmethod
+  def register_cholesky_inverse(self, damping_func):
+    pass
+
+  @abc.abstractmethod
+  def get_matpower(self, exp, damping_func):
+    pass
+
+  @abc.abstractmethod
+  def get_cholesky(self, damping_func):
+    pass
+
+  @abc.abstractmethod
+  def get_cholesky_inverse(self, damping_func):
    pass
 
 
-class InverseProvidingFactor(FisherFactor):
-  """Base class for FisherFactors that maintain inverses explicitly.
+class DenseSquareMatrixFactor(FisherFactor):
+  """Base class for FisherFactors that are stored as dense square matrices.
 
-  This class explicitly calculates and stores inverses of covariance matrices
-  provided by the underlying FisherFactor implementation. It is assumed that
-  vectors can be represented as 2-D matrices.
+  This class explicitly calculates and stores inverses of their `cov` matrices,
+  which must be square dense matrices.
 
   Subclasses must implement the _compute_new_cov method, and the _var_scope and
   _cov_shape properties.
@@ -538,7 +504,19 @@ class InverseProvidingFactor(FisherFactor):
     self._eigendecomp = None
     self._damping_funcs_by_id = {}  # {hashable: lambda}
 
-    super(InverseProvidingFactor, self).__init__()
+    self._cholesky_registrations = set()  # { hashable }
+    self._cholesky_inverse_registrations = set()  # { hashable }
+
+    self._cholesky_by_damping = {}  # { hashable: variable }
+    self._cholesky_inverse_by_damping = {}  # { hashable: variable }
+
+    super(DenseSquareMatrixFactor, self).__init__()
+
+  def get_cov_as_linear_operator(self):
+    assert self.get_cov().shape.ndims == 2
+    return lo.LinearOperatorFullMatrix(self.get_cov(),
+                                       is_self_adjoint=True,
+                                       is_square=True)
 
   def _register_damping(self, damping_func):
     damping_id = graph_func_to_id(damping_func)
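get_cov_as_linear_operator is the heart of the refactor: callers receive an object that can apply itself to a vector and report quantities like its trace, rather than a raw dense tensor. A minimal stand-in in NumPy (hypothetical `FullMatrixOperator`, illustration only):

    import numpy as np

    class FullMatrixOperator(object):
      """Minimal stand-in for the object get_cov_as_linear_operator returns."""

      def __init__(self, mat):
        self._mat = np.asarray(mat)

      def matmul(self, x):    # the operator knows how to multiply itself
        return self._mat.dot(x)

      def trace(self):        # ...and how to compute its own trace
        return np.trace(self._mat)

    cov = np.array([[2.0, 1.0], [1.0, 3.0]])
    op = FullMatrixOperator(cov)
    np.testing.assert_allclose(op.matmul(np.eye(2)), cov)
    assert op.trace() == 5.0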
@@ -563,8 +541,6 @@ class InverseProvidingFactor(FisherFactor):
         be the damping value used. i.e. damping = damping_func().
     """
     if exp == 1.0:
-      # We don't register these. The user shouldn't even be calling this
-      # function with exp = 1.0.
       return
 
     damping_id = self._register_damping(damping_func)
@@ -572,6 +548,38 @@ class InverseProvidingFactor(FisherFactor):
     if (exp, damping_id) not in self._matpower_registrations:
       self._matpower_registrations.add((exp, damping_id))
 
+  def register_cholesky(self, damping_func):
+    """Registers a Cholesky factor to be maintained and served on demand.
+
+    This creates a variable and signals make_inverse_update_ops to make the
+    corresponding update op. The variable can be read via the method
+    get_cholesky.
+
+    Args:
+      damping_func: A function that computes a 0-D Tensor or a float which will
+        be the damping value used. i.e. damping = damping_func().
+    """
+    damping_id = self._register_damping(damping_func)
+
+    if damping_id not in self._cholesky_registrations:
+      self._cholesky_registrations.add(damping_id)
+
+  def register_cholesky_inverse(self, damping_func):
+    """Registers an inverse Cholesky factor to be maintained/served on demand.
+
+    This creates a variable and signals make_inverse_update_ops to make the
+    corresponding update op. The variable can be read via the method
+    get_cholesky_inverse.
+
+    Args:
+      damping_func: A function that computes a 0-D Tensor or a float which will
+        be the damping value used. i.e. damping = damping_func().
+    """
+    damping_id = self._register_damping(damping_func)
+
+    if damping_id not in self._cholesky_inverse_registrations:
+      self._cholesky_inverse_registrations.add(damping_id)
+
   def instantiate_inv_variables(self):
     """Makes the internal "inverse" variable(s)."""
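The register/instantiate/get split mirrors the existing matpower pattern: registration records a request at graph-construction time, instantiate_inv_variables later creates one variable per distinct damping function, and the get_* methods serve the cached value. A toy sketch of that lifecycle with plain dictionaries (hypothetical `CholeskyCache`, not code from this commit):

    import numpy as np

    class CholeskyCache(object):
      """Toy model of the register/instantiate/get lifecycle."""

      def __init__(self):
        self._registrations = set()
        self._values = {}

      def register(self, damping_id):
        self._registrations.add(damping_id)      # request, made ahead of time

      def instantiate(self, cov, damping_by_id):
        for damping_id in self._registrations:   # one cached factor per damping
          damped = cov + damping_by_id[damping_id] * np.eye(cov.shape[0])
          self._values[damping_id] = np.linalg.cholesky(damped)

      def get(self, damping_id):
        return self._values[damping_id]          # KeyError if never registered

    cache = CholeskyCache()
    cache.register("damp_a")
    cache.instantiate(np.eye(2), {"damp_a": 0.1})
    L = cache.get("damp_a")
    np.testing.assert_allclose(L.dot(L.T), 1.1 * np.eye(2))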
@@ -589,6 +597,32 @@ class InverseProvidingFactor(FisherFactor):
       assert (exp, damping_id) not in self._matpower_by_exp_and_damping
       self._matpower_by_exp_and_damping[(exp, damping_id)] = matpower
 
+    for damping_id in self._cholesky_registrations:
+      damping_func = self._damping_funcs_by_id[damping_id]
+      damping_string = graph_func_to_string(damping_func)
+      with variable_scope.variable_scope(self._var_scope):
+        chol = variable_scope.get_variable(
+            "cholesky_damp{}".format(damping_string),
+            initializer=inverse_initializer,
+            shape=self._cov_shape,
+            trainable=False,
+            dtype=self._dtype)
+      assert damping_id not in self._cholesky_by_damping
+      self._cholesky_by_damping[damping_id] = chol
+
+    for damping_id in self._cholesky_inverse_registrations:
+      damping_func = self._damping_funcs_by_id[damping_id]
+      damping_string = graph_func_to_string(damping_func)
+      with variable_scope.variable_scope(self._var_scope):
+        cholinv = variable_scope.get_variable(
+            "cholesky_inverse_damp{}".format(damping_string),
+            initializer=inverse_initializer,
+            shape=self._cov_shape,
+            trainable=False,
+            dtype=self._dtype)
+      assert damping_id not in self._cholesky_inverse_by_damping
+      self._cholesky_inverse_by_damping[damping_id] = cholinv
+
   def make_inverse_update_ops(self):
     """Create and return update ops corresponding to registered computations."""
     ops = []
@@ -606,7 +640,8 @@ class InverseProvidingFactor(FisherFactor):
 
     # We precompute these so we don't need to evaluate them multiple times (for
     # each matrix power that uses them)
-    damping_value_by_id = {damping_id: self._damping_funcs_by_id[damping_id]()
+    damping_value_by_id = {damping_id: math_ops.cast(
+        self._damping_funcs_by_id[damping_id](), self._dtype)
                            for damping_id in self._damping_funcs_by_id}
 
     if use_eig:
@@ -627,29 +662,91 @@ class InverseProvidingFactor(FisherFactor):
           self._matpower_by_exp_and_damping.items()):
         assert exp == -1
         damping = damping_value_by_id[damping_id]
-        ops.append(matpower.assign(utils.posdef_inv(self._cov, damping)))
+        ops.append(matpower.assign(utils.posdef_inv(self.get_cov(), damping)))
+
+    # TODO(b/77902055): If inverses are being computed with Cholesky's
+    # we can share the work. Instead this code currently just computes the
+    # Cholesky a second time. It does at least share work between requests for
+    # Cholesky's and Cholesky inverses with the same damping id.
+    for damping_id, cholesky_inv in self._cholesky_inverse_by_damping.items():
+      cholesky_ops = []
+
+      damping = damping_value_by_id[damping_id]
+      cholesky_value = utils.cholesky(self.get_cov(), damping)
+
+      if damping_id in self._cholesky_by_damping:
+        cholesky = self._cholesky_by_damping[damping_id]
+        cholesky_ops.append(cholesky.assign(cholesky_value))
+
+      identity = linalg_ops.eye(cholesky_value.shape.as_list()[0],
+                                dtype=cholesky_value.dtype)
+      cholesky_inv_value = linalg_ops.matrix_triangular_solve(cholesky_value,
+                                                              identity)
+      cholesky_ops.append(cholesky_inv.assign(cholesky_inv_value))
+
+      ops.append(control_flow_ops.group(*cholesky_ops))
+
+    for damping_id, cholesky in self._cholesky_by_damping.items():
+      if damping_id not in self._cholesky_inverse_by_damping:
+        damping = damping_value_by_id[damping_id]
+        cholesky_value = utils.cholesky(self.get_cov(), damping)
+        ops.append(cholesky.assign(cholesky_value))
 
     self._eigendecomp = False
     return ops
 
   def get_inverse(self, damping_func):
     # Just for backwards compatibility of some old code and tests
-    damping_id = graph_func_to_id(damping_func)
-    return self._matpower_by_exp_and_damping[(-1, damping_id)]
+    return self.get_matpower(-1, damping_func)
 
   def get_matpower(self, exp, damping_func):
+    # Note that this function returns a variable which gets updated by the
+    # inverse ops. It may be stale / inconsistent with the latest value of
+    # get_cov().
+    if exp != 1:
+      damping_id = graph_func_to_id(damping_func)
+      matpower = self._matpower_by_exp_and_damping[(exp, damping_id)]
+    else:
+      matpower = self.get_cov()
+      identity = linalg_ops.eye(matpower.shape.as_list()[0],
+                                dtype=matpower.dtype)
+      matpower += math_ops.cast(damping_func(), dtype=matpower.dtype)*identity
+
+    assert matpower.shape.ndims == 2
+    return lo.LinearOperatorFullMatrix(matpower,
+                                       is_non_singular=True,
+                                       is_self_adjoint=True,
+                                       is_positive_definite=True,
+                                       is_square=True)
+
+  def get_cholesky(self, damping_func):
     # Note that this function returns a variable which gets updated by the
     # inverse ops. It may be stale / inconsistent with the latest value of
     # get_cov().
     damping_id = graph_func_to_id(damping_func)
-    return self._matpower_by_exp_and_damping[(exp, damping_id)]
+    cholesky = self._cholesky_by_damping[damping_id]
+    assert cholesky.shape.ndims == 2
+    return lo.LinearOperatorFullMatrix(cholesky,
+                                       is_non_singular=True,
+                                       is_square=True)
+
+  def get_cholesky_inverse(self, damping_func):
+    # Note that this function returns a variable which gets updated by the
+    # inverse ops. It may be stale / inconsistent with the latest value of
+    # get_cov().
+    damping_id = graph_func_to_id(damping_func)
+    cholesky_inv = self._cholesky_inverse_by_damping[damping_id]
+    assert cholesky_inv.shape.ndims == 2
+    return lo.LinearOperatorFullMatrix(cholesky_inv,
+                                       is_non_singular=True,
+                                       is_square=True)
 
   def get_eigendecomp(self):
     """Creates or retrieves eigendecomposition of self._cov."""
     # Unlike get_matpower this doesn't retrieve a stored variable, but instead
     # always computes a fresh version from the current value of get_cov().
     if not self._eigendecomp:
-      eigenvalues, eigenvectors = linalg_ops.self_adjoint_eig(self._cov)
+      eigenvalues, eigenvectors = linalg_ops.self_adjoint_eig(self.get_cov())
 
       # The matrix self._cov is positive semidefinite by construction, but the
       # numerical eigenvalues could be negative due to numerical errors, so here
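make_inverse_update_ops obtains the inverse Cholesky factor by solving L X = I with matrix_triangular_solve rather than forming an explicit matrix inverse. The same computation in NumPy/SciPy terms, with a check that whitening recovers the damped inverse (illustrative only):

    import numpy as np
    from scipy.linalg import solve_triangular

    C = np.array([[4.0, 2.0], [2.0, 3.0]])
    damping = 0.5
    damped = C + damping * np.eye(2)
    L = np.linalg.cholesky(damped)                      # damped = L L^T
    L_inv = solve_triangular(L, np.eye(2), lower=True)  # solves L X = I
    np.testing.assert_allclose(L_inv.dot(L), np.eye(2), atol=1e-12)
    # Whitening with L_inv recovers the damped inverse: L^-T L^-1 = damped^-1
    np.testing.assert_allclose(L_inv.T.dot(L_inv), np.linalg.inv(damped),
                               atol=1e-12)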
@@ -660,45 +757,8 @@ class InverseProvidingFactor(FisherFactor):
 
     return self._eigendecomp
 
-  def get_cov(self):
-    # Variable contains full covariance matrix.
-    return self.get_cov_var()
-
-  def left_multiply_matpower(self, x, exp, damping_func):
-    if isinstance(x, tf_ops.IndexedSlices):
-      raise ValueError("Left-multiply not yet supported for IndexedSlices.")
-
-    if x.shape.ndims != 2:
-      raise ValueError(
-          "InverseProvidingFactors apply to matrix-shaped vectors. Found: %s."
-          % (x,))
-
-    if exp == 1:
-      return math_ops.matmul(self.get_cov(), x) + damping_func() * x
-
-    return math_ops.matmul(self.get_matpower(exp, damping_func), x)
-
-  def right_multiply_matpower(self, x, exp, damping_func):
-    if isinstance(x, tf_ops.IndexedSlices):
-      if exp == 1:
-        n = self.get_cov().shape[0]
-        damped_cov = self.get_cov() + damping_func() * array_ops.eye(n)
-        return utils.matmul_sparse_dense(x, damped_cov)
-
-      return utils.matmul_sparse_dense(x, self.get_matpower(exp, damping_func))
-
-    if x.shape.ndims != 2:
-      raise ValueError(
-          "InverseProvidingFactors apply to matrix-shaped vectors. Found: %s."
-          % (x,))
-
-    if exp == 1:
-      return math_ops.matmul(x, self.get_cov()) + damping_func() * x
-
-    return math_ops.matmul(x, self.get_matpower(exp, damping_func))
-
 
-class FullFactor(InverseProvidingFactor):
+class FullFactor(DenseSquareMatrixFactor):
   """FisherFactor for a full matrix representation of the Fisher of a parameter.
 
   Note that this uses the naive "square the sum estimator", and so is applicable
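With left_multiply_matpower and right_multiply_matpower deleted above, multiplication now happens through the operator returned by get_matpower. A small NumPy sketch of the two orientations the new matmul/matmul_right methods distinguish (illustrative only; the diagonal matrix stands in for a damped matrix power):

    import numpy as np

    m = np.diag([2.0, 4.0])                 # stand-in for a damped matpower
    x = np.array([[1.0, 1.0], [1.0, 1.0]])
    left = m.dot(x)    # what get_matpower(...).matmul(x) computes
    right = x.dot(m)   # what get_matpower(...).matmul_right(x) computes
    np.testing.assert_allclose(left, [[2.0, 2.0], [4.0, 4.0]])
    np.testing.assert_allclose(right, [[2.0, 4.0], [2.0, 4.0]])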
@@ -757,42 +817,52 @@ class DiagonalFactor(FisherFactor):
   """
 
   def __init__(self):
-    self._damping_funcs_by_id = {}  # { hashable: lambda }
     super(DiagonalFactor, self).__init__()
 
+  def get_cov_as_linear_operator(self):
+    assert self._matrix_diagonal.shape.ndims == 1
+    return lo.LinearOperatorDiag(self._matrix_diagonal,
+                                 is_self_adjoint=True,
+                                 is_square=True)
+
   @property
   def _cov_initializer(self):
     return diagonal_covariance_initializer
 
+  @property
+  def _matrix_diagonal(self):
+    return array_ops.reshape(self.get_cov(), [-1])
+
   def make_inverse_update_ops(self):
     return []
 
   def instantiate_inv_variables(self):
     pass
 
-  def get_cov(self):
-    # self.get_cov() could be any shape, but it must have one entry per
-    # parameter. Flatten it into a vector.
-    cov_diag_vec = array_ops.reshape(self.get_cov_var(), [-1])
-    return array_ops.diag(cov_diag_vec)
-
-  def left_multiply_matpower(self, x, exp, damping_func):
-    matpower = (self.get_cov_var() + damping_func())**exp
-
-    if isinstance(x, tf_ops.IndexedSlices):
-      return utils.matmul_diag_sparse(array_ops.reshape(matpower, [-1]), x)
-
-    if x.shape != matpower.shape:
-      raise ValueError("x (%s) and cov (%s) must have same shape." %
-                       (x, matpower))
-    return matpower * x
-
-  def right_multiply_matpower(self, x, exp, damping_func):
-    raise NotImplementedError("Only left-multiply is currently supported.")
-
   def register_matpower(self, exp, damping_func):
     pass
 
+  def register_cholesky(self, damping_func):
+    pass
+
+  def register_cholesky_inverse(self, damping_func):
+    pass
+
+  def get_matpower(self, exp, damping_func):
+    matpower_diagonal = (self._matrix_diagonal
+                         + math_ops.cast(damping_func(), self._dtype))**exp
+    return lo.LinearOperatorDiag(matpower_diagonal,
+                                 is_non_singular=True,
+                                 is_self_adjoint=True,
+                                 is_positive_definite=True,
+                                 is_square=True)
+
+  def get_cholesky(self, damping_func):
+    return self.get_matpower(0.5, damping_func)
+
+  def get_cholesky_inverse(self, damping_func):
+    return self.get_matpower(-0.5, damping_func)
+
 
 class NaiveDiagonalFactor(DiagonalFactor):
   """FisherFactor for a diagonal approximation of any type of param's Fisher.
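For DiagonalFactor every served operator is diagonal, so a damped matrix power is an elementwise power of the diagonal, and the Cholesky factor and its inverse are exactly the 0.5 and -0.5 powers. That is why get_cholesky and get_cholesky_inverse above can delegate to get_matpower. In NumPy terms (illustrative only):

    import numpy as np

    d = np.array([1.0, 2.0, 0.5])    # diagonal of the covariance
    damping = 0.1

    inv = (d + damping) ** -1.0      # get_matpower(-1, ...), elementwise
    chol = (d + damping) ** 0.5      # get_cholesky(...) == matpower 0.5
    np.testing.assert_allclose(chol * chol, d + damping)
    np.testing.assert_allclose(inv * (d + damping), np.ones(3))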
@@ -1167,7 +1237,7 @@ class ConvDiagonalFactor(DiagonalFactor):
     return self._inputs[tower].device
 
 
-class FullyConnectedKroneckerFactor(InverseProvidingFactor):
+class FullyConnectedKroneckerFactor(DenseSquareMatrixFactor):
   """Kronecker factor for the input or output side of a fully-connected layer.
   """
 
@@ -1220,7 +1290,7 @@ class FullyConnectedKroneckerFactor(InverseProvidingFactor):
     return self._tensors[0][tower].device
 
 
-class ConvInputKroneckerFactor(InverseProvidingFactor):
+class ConvInputKroneckerFactor(DenseSquareMatrixFactor):
   r"""Kronecker factor for the input side of a convolutional layer.
 
   Estimates E[ a a^T ] where a is the inputs to a convolutional layer given
@@ -1384,7 +1454,7 @@ class ConvInputKroneckerFactor(InverseProvidingFactor):
     return self._inputs[tower].device
 
 
-class ConvOutputKroneckerFactor(InverseProvidingFactor):
+class ConvOutputKroneckerFactor(DenseSquareMatrixFactor):
   r"""Kronecker factor for the output side of a convolutional layer.
 
   Estimates E[ ds ds^T ] where s is the preactivations of a convolutional layer
@@ -1674,6 +1744,7 @@ class FullyConnectedMultiKF(FullyConnectedKroneckerFactor):
       psi_var) in self._option1quants_by_damping.items():
 
       damping = self._damping_funcs_by_id[damping_id]()
+      damping = math_ops.cast(damping, self._dtype)
 
       invsqrtC0 = math_ops.matmul(
           eigen_V * (eigen_e + damping)**(-0.5), eigen_V, transpose_b=True)
@@ -1702,6 +1773,7 @@ class FullyConnectedMultiKF(FullyConnectedKroneckerFactor):
       mu_var) in self._option2quants_by_damping.items():
 
       damping = self._damping_funcs_by_id[damping_id]()
+      damping = math_ops.cast(damping, self._dtype)
 
       # compute C0^(-1/2)
       invsqrtC0 = math_ops.matmul(
tensorflow/contrib/kfac/python/ops/linear_operator.py (new file, 95 lines)
@@ -0,0 +1,95 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""SmartMatrices definitions."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.kfac.python.ops import utils
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops.linalg import linalg
+from tensorflow.python.ops.linalg import linalg_impl
+from tensorflow.python.ops.linalg import linear_operator_util as lou
+
+
+class LinearOperatorExtras(object):  # pylint: disable=missing-docstring
+
+  def matmul(self, x, adjoint=False, adjoint_arg=False, name="matmul"):
+
+    with self._name_scope(name, values=[x]):
+      if isinstance(x, ops.IndexedSlices):
+        return self._matmul_sparse(x, adjoint=adjoint, adjoint_arg=adjoint_arg)
+
+      x = ops.convert_to_tensor(x, name="x")
+      self._check_input_dtype(x)
+
+      self_dim = -2 if adjoint else -1
+      arg_dim = -1 if adjoint_arg else -2
+      self.shape[self_dim].assert_is_compatible_with(x.get_shape()[arg_dim])
+
+      return self._matmul(x, adjoint=adjoint, adjoint_arg=adjoint_arg)
+
+  def matmul_right(self, x, adjoint=False, adjoint_arg=False, name="matmul"):
+
+    with self._name_scope(name, values=[x]):
+
+      if isinstance(x, ops.IndexedSlices):
+        return self._matmul_right_sparse(
+            x, adjoint=adjoint, adjoint_arg=adjoint_arg)
+
+      x = ops.convert_to_tensor(x, name="x")
+      self._check_input_dtype(x)
+
+      self_dim = -1 if adjoint else -2
+      arg_dim = -2 if adjoint_arg else -1
+      self.shape[self_dim].assert_is_compatible_with(x.get_shape()[arg_dim])
+
+      return self._matmul_right(x, adjoint=adjoint, adjoint_arg=adjoint_arg)
+
+
+class LinearOperatorFullMatrix(LinearOperatorExtras,
+                               linalg.LinearOperatorFullMatrix):
+
+  # TODO(b/78117889) Remove this definition once core LinearOperator
+  # has _matmul_right.
+  def _matmul_right(self, x, adjoint=False, adjoint_arg=False):
+    return lou.matmul_with_broadcast(
+        x, self._matrix, adjoint_a=adjoint_arg, adjoint_b=adjoint)
+
+  def _matmul_sparse(self, x, adjoint=False, adjoint_arg=False):
+    raise NotImplementedError
+
+  def _matmul_right_sparse(self, x, adjoint=False, adjoint_arg=False):
+    assert not adjoint and not adjoint_arg
+    return utils.matmul_sparse_dense(x, self._matrix)
+
+
+class LinearOperatorDiag(LinearOperatorExtras,  # pylint: disable=missing-docstring
+                         linalg.LinearOperatorDiag):
+
+  def _matmul_right(self, x, adjoint=False, adjoint_arg=False):
+    diag_mat = math_ops.conj(self._diag) if adjoint else self._diag
+    x = linalg_impl.adjoint(x) if adjoint_arg else x
+    return diag_mat * x
+
+  def _matmul_sparse(self, x, adjoint=False, adjoint_arg=False):
+    diag_mat = math_ops.conj(self._diag) if adjoint else self._diag
+    assert not adjoint_arg
+    return utils.matmul_diag_sparse(diag_mat, x)
+
+  def _matmul_right_sparse(self, x, adjoint=False, adjoint_arg=False):
+    raise NotImplementedError
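matmul_right is the piece the core LinearOperator classes lack (see the TODO above): it computes x * A instead of A * x. For the diagonal case the implementation reduces to a broadcast multiply, since right-multiplying by diag(d) scales the columns of x by d. A quick NumPy check of that identity (illustrative only):

    import numpy as np

    d = np.array([2.0, 3.0])
    x = np.array([[1.0, 1.0], [2.0, 0.0]])
    # x.dot(diag(d)) scales column j of x by d[j], which is exactly d * x
    np.testing.assert_allclose(x.dot(np.diag(d)), d * x)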
@@ -35,7 +35,7 @@ def _make_thunk_on_device(func, device):
 class RoundRobinPlacementMixin(object):
   """Implements round robin placement strategy for ops and variables."""
 
-  def __init__(self, cov_devices=None, inv_devices=None, *args, **kwargs):
+  def __init__(self, cov_devices=None, inv_devices=None, **kwargs):
     """Initializes the RoundRobinPlacementMixin class.
 
     Args:
@@ -45,11 +45,10 @@ class RoundRobinPlacementMixin(object):
       inv_devices: Iterable of device strings (e.g. '/gpu:0'). Inversion
         computations will be placed on these devices in a round-robin fashion.
         Can be None, which means that no devices are specified.
-      *args:
-      **kwargs:
+      **kwargs: Passed through to the superclass constructor.
     """
-    super(RoundRobinPlacementMixin, self).__init__(*args, **kwargs)
+    super(RoundRobinPlacementMixin, self).__init__(**kwargs)
     self._cov_devices = cov_devices
     self._inv_devices = inv_devices
@@ -235,6 +235,13 @@ posdef_eig_functions = {
 }
 
 
+def cholesky(tensor, damping):
+  """Computes the Cholesky factor of tensor + damping * identity."""
+  identity = linalg_ops.eye(tensor.shape.as_list()[0], dtype=tensor.dtype)
+  damping = math_ops.cast(damping, dtype=tensor.dtype)
+  return linalg_ops.cholesky(tensor + damping * identity)
+
+
 class SubGraph(object):
   """Defines a subgraph given by all the dependencies of a given set of outputs.
   """
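The new helper adds the damping before factorizing, which also guards against covariance estimates that are only positive semidefinite. A NumPy analogue (hypothetical `cholesky_damped`, illustrative only):

    import numpy as np

    def cholesky_damped(mat, damping):
      """NumPy analogue of the helper above: chol(mat + damping * I)."""
      return np.linalg.cholesky(mat + damping * np.eye(mat.shape[0]))

    mat = np.array([[1.0, 0.0], [0.0, 0.0]])   # singular without damping
    L = cholesky_damped(mat, 1e-2)             # damping makes it factorizable
    np.testing.assert_allclose(L.dot(L.T), mat + 1e-2 * np.eye(2))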
@@ -553,13 +560,17 @@ def is_data_format_channel_last(data_format):
   return data_format.endswith("C")
 
 
-def matmul_sparse_dense(A, B, name=None):  # pylint: disable=invalid-name
+def matmul_sparse_dense(A, B, name=None, transpose_a=False, transpose_b=False):  # pylint: disable=invalid-name
   """Computes matmul(A, B) where A is sparse, B is dense.
 
   Args:
     A: tf.IndexedSlices with dense shape [m, n].
     B: tf.Tensor with shape [n, k].
     name: str. Name of op.
+    transpose_a: Bool. If true we transpose A before multiplying it by B.
+      (Default: False)
+    transpose_b: Bool. If true we transpose B before multiplying it by A.
+      (Default: False)
 
   Returns:
     tf.IndexedSlices resulting from matmul(A, B).
@@ -573,7 +584,8 @@ def matmul_sparse_dense(A, B, name=None):  # pylint: disable=invalid-name
     raise ValueError("A must represent a matrix. Found: %s." % A)
   if B.shape.ndims != 2:
     raise ValueError("B must be a matrix.")
-  new_values = math_ops.matmul(A.values, B)
+  new_values = math_ops.matmul(
+      A.values, B, transpose_a=transpose_a, transpose_b=transpose_b)
   return ops.IndexedSlices(
       new_values,
       A.indices,
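matmul_sparse_dense only multiplies the populated rows of the IndexedSlices and leaves the indices untouched, which is equivalent to densifying first. A NumPy sketch of that equivalence (illustrative only):

    import numpy as np

    # An IndexedSlices (values, indices, dense_shape) stands for a mostly-zero
    # matrix with a few populated rows; multiplying by a dense B only needs to
    # touch those rows.
    values = np.array([[1.0, 2.0], [3.0, 4.0]])  # rows 0 and 2 of a [3, 2] matrix
    indices = np.array([0, 2])
    B = np.array([[1.0, 0.0, 1.0], [0.0, 1.0, 1.0]])  # shape [2, 3]

    new_values = values.dot(B)   # only the populated rows are multiplied
    dense = np.zeros((3, 2))
    dense[indices] = values
    np.testing.assert_allclose(dense.dot(B)[indices], new_values)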