Correct an issue in calculating Frechet Inception Distance.
PiperOrigin-RevId: 169553227
This commit is contained in:
parent
bdf4dc38a2
commit
5cc7b86a0d
@ -90,9 +90,13 @@ def _matrix_square_root(mat, eps=1e-10):
|
||||
Returns:
|
||||
Matrix square root of mat.
|
||||
"""
|
||||
# Unlike numpy, tensorflow's return order is (s, u, v)
|
||||
s, u, v = linalg_ops.svd(mat)
|
||||
# sqrt is unstable around 0, just use 0 in such case
|
||||
si = array_ops.where(math_ops.less(s, eps), s, math_ops.sqrt(s))
|
||||
# Note that the v returned by Tensorflow is v = V
|
||||
# (when referencing the equation A = U S V^T)
|
||||
# This is unlike Numpy which returns v = V^T
|
||||
return math_ops.matmul(
|
||||
math_ops.matmul(u, array_ops.diag(si)), v, transpose_b=True)
|
||||
|
||||
@ -388,9 +392,14 @@ def frechet_classifier_distance(real_images,
|
||||
# Compute mean and covariance matrices of activations.
|
||||
m = math_ops.reduce_mean(real_a, 0)
|
||||
m_v = math_ops.reduce_mean(gen_a, 0)
|
||||
dim = math_ops.to_float(array_ops.shape(m)[0])
|
||||
sigma = math_ops.matmul(real_a - m, real_a - m, transpose_b=True) / dim
|
||||
sigma_v = math_ops.matmul(gen_a - m, gen_a - m, transpose_b=True) / dim
|
||||
num_examples = math_ops.to_float(array_ops.shape(real_a)[0])
|
||||
|
||||
# sigma = (1 / (n - 1)) * (X - mu) (X - mu)^T
|
||||
sigma = math_ops.matmul(
|
||||
real_a - m, real_a - m, transpose_a=True) / (num_examples - 1)
|
||||
|
||||
sigma_v = math_ops.matmul(
|
||||
gen_a - m_v, gen_a - m_v, transpose_a=True) / (num_examples - 1)
|
||||
|
||||
# Take matrix square root of the product of covariance matrices.
|
||||
sqcc = _matrix_square_root(math_ops.matmul(sigma, sigma_v))
|
||||
|
||||
@ -50,21 +50,22 @@ def _expected_inception_score(logits):
|
||||
|
||||
|
||||
def _approximate_matrix_sqrt(mat, eps=1e-8):
|
||||
s, u, v = np.linalg.svd(mat)
|
||||
# Unlike tensorflow, numpy's return order is (u, s, v)
|
||||
u, s, v = np.linalg.svd(mat)
|
||||
si = np.where(s < eps, s, np.sqrt(s))
|
||||
return np.dot(np.dot(u, np.diag(si)), v.T)
|
||||
# Note the "v" returned by numpy is actually v = V^T
|
||||
# (when referencing the SVD equation A = U S V^T)
|
||||
# This is unlike Tensorflow which returns v = V
|
||||
return np.dot(np.dot(u, np.diag(si)), v)
|
||||
|
||||
|
||||
def _expected_fid(real_imgs, gen_imgs):
|
||||
real_imgs = np.asarray(real_imgs)
|
||||
gen_imgs = np.asarray(gen_imgs)
|
||||
m = np.mean(real_imgs, axis=0)
|
||||
m_v = np.mean(gen_imgs, axis=0)
|
||||
dim = float(m.shape[0])
|
||||
sigma = np.dot((real_imgs - m), (real_imgs - m).T) / dim
|
||||
sigma_v = np.dot((gen_imgs - m), (gen_imgs - m).T) / dim
|
||||
sigma = np.cov(real_imgs, rowvar=False)
|
||||
sigma_v = np.cov(gen_imgs, rowvar=False)
|
||||
sqcc = _approximate_matrix_sqrt(np.dot(sigma, sigma_v))
|
||||
mean = np.square(np.linalg.norm(m - m_v))
|
||||
mean = np.square(m - m_v).sum()
|
||||
trace = np.trace(sigma + sigma_v - 2 * sqcc)
|
||||
fid = mean + trace
|
||||
return fid
|
||||
@ -264,22 +265,21 @@ class ClassifierMetricsTest(test.TestCase):
|
||||
|
||||
self.assertAllClose(_expected_inception_score(logits), incscore_np)
|
||||
|
||||
def test_frechet_inception_distance_value(self):
|
||||
"""Test that `frechet_inception_distance` gives the correct value."""
|
||||
def test_frechet_classifier_distance_value(self):
|
||||
"""Test that `frechet_classifier_distance` gives the correct value."""
|
||||
np.random.seed(0)
|
||||
test_pool_real_a = np.random.randn(5, 2048)
|
||||
test_pool_gen_a = np.random.randn(5, 2048)
|
||||
unused_image = array_ops.zeros([5, 299, 299, 3])
|
||||
test_pool_real_a = np.float32(np.random.randn(64, 256))
|
||||
test_pool_gen_a = np.float32(np.random.randn(64, 256))
|
||||
|
||||
pool_a = np.stack((test_pool_real_a, test_pool_gen_a))
|
||||
fid_op = _run_with_mock(classifier_metrics.frechet_inception_distance,
|
||||
unused_image, unused_image)
|
||||
activations_tensor = 'RunClassifier/TensorArrayStack/TensorArrayGatherV3:0'
|
||||
fid_op = _run_with_mock(classifier_metrics.frechet_classifier_distance,
|
||||
test_pool_real_a, test_pool_gen_a,
|
||||
classifier_fn=lambda x: x)
|
||||
|
||||
with self.test_session() as sess:
|
||||
actual_fid = sess.run(fid_op, {activations_tensor: pool_a})
|
||||
actual_fid = sess.run(fid_op)
|
||||
|
||||
expected_fid = _expected_fid(test_pool_real_a, test_pool_gen_a)
|
||||
|
||||
self.assertAllClose(expected_fid, actual_fid, 0.01)
|
||||
|
||||
def test_preprocess_image_graph(self):
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user