From 5cc7b86a0df3c2b1f26e81ea4c8ad290025aa0ce Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 21 Sep 2017 09:26:57 -0700 Subject: [PATCH] Correct an issue in calculating Frechet Inception Distance. PiperOrigin-RevId: 169553227 --- .../eval/python/classifier_metrics_impl.py | 15 ++++++-- .../eval/python/classifier_metrics_test.py | 36 +++++++++---------- 2 files changed, 30 insertions(+), 21 deletions(-) diff --git a/tensorflow/contrib/gan/python/eval/python/classifier_metrics_impl.py b/tensorflow/contrib/gan/python/eval/python/classifier_metrics_impl.py index 7ff9a3a51d0..151fecdca0c 100644 --- a/tensorflow/contrib/gan/python/eval/python/classifier_metrics_impl.py +++ b/tensorflow/contrib/gan/python/eval/python/classifier_metrics_impl.py @@ -90,9 +90,13 @@ def _matrix_square_root(mat, eps=1e-10): Returns: Matrix square root of mat. """ + # Unlike numpy, tensorflow's return order is (s, u, v) s, u, v = linalg_ops.svd(mat) # sqrt is unstable around 0, just use 0 in such case si = array_ops.where(math_ops.less(s, eps), s, math_ops.sqrt(s)) + # Note that the v returned by Tensorflow is v = V + # (when referencing the equation A = U S V^T) + # This is unlike Numpy which returns v = V^T return math_ops.matmul( math_ops.matmul(u, array_ops.diag(si)), v, transpose_b=True) @@ -388,9 +392,14 @@ def frechet_classifier_distance(real_images, # Compute mean and covariance matrices of activations. m = math_ops.reduce_mean(real_a, 0) m_v = math_ops.reduce_mean(gen_a, 0) - dim = math_ops.to_float(array_ops.shape(m)[0]) - sigma = math_ops.matmul(real_a - m, real_a - m, transpose_b=True) / dim - sigma_v = math_ops.matmul(gen_a - m, gen_a - m, transpose_b=True) / dim + num_examples = math_ops.to_float(array_ops.shape(real_a)[0]) + + # sigma = (1 / (n - 1)) * (X - mu) (X - mu)^T + sigma = math_ops.matmul( + real_a - m, real_a - m, transpose_a=True) / (num_examples - 1) + + sigma_v = math_ops.matmul( + gen_a - m_v, gen_a - m_v, transpose_a=True) / (num_examples - 1) # Take matrix square root of the product of covariance matrices. sqcc = _matrix_square_root(math_ops.matmul(sigma, sigma_v)) diff --git a/tensorflow/contrib/gan/python/eval/python/classifier_metrics_test.py b/tensorflow/contrib/gan/python/eval/python/classifier_metrics_test.py index d7bfa1ae28b..9e8776f3a4c 100644 --- a/tensorflow/contrib/gan/python/eval/python/classifier_metrics_test.py +++ b/tensorflow/contrib/gan/python/eval/python/classifier_metrics_test.py @@ -50,21 +50,22 @@ def _expected_inception_score(logits): def _approximate_matrix_sqrt(mat, eps=1e-8): - s, u, v = np.linalg.svd(mat) + # Unlike tensorflow, numpy's return order is (u, s, v) + u, s, v = np.linalg.svd(mat) si = np.where(s < eps, s, np.sqrt(s)) - return np.dot(np.dot(u, np.diag(si)), v.T) + # Note the "v" returned by numpy is actually v = V^T + # (when referencing the SVD equation A = U S V^T) + # This is unlike Tensorflow which returns v = V + return np.dot(np.dot(u, np.diag(si)), v) def _expected_fid(real_imgs, gen_imgs): - real_imgs = np.asarray(real_imgs) - gen_imgs = np.asarray(gen_imgs) m = np.mean(real_imgs, axis=0) m_v = np.mean(gen_imgs, axis=0) - dim = float(m.shape[0]) - sigma = np.dot((real_imgs - m), (real_imgs - m).T) / dim - sigma_v = np.dot((gen_imgs - m), (gen_imgs - m).T) / dim + sigma = np.cov(real_imgs, rowvar=False) + sigma_v = np.cov(gen_imgs, rowvar=False) sqcc = _approximate_matrix_sqrt(np.dot(sigma, sigma_v)) - mean = np.square(np.linalg.norm(m - m_v)) + mean = np.square(m - m_v).sum() trace = np.trace(sigma + sigma_v - 2 * sqcc) fid = mean + trace return fid @@ -264,22 +265,21 @@ class ClassifierMetricsTest(test.TestCase): self.assertAllClose(_expected_inception_score(logits), incscore_np) - def test_frechet_inception_distance_value(self): - """Test that `frechet_inception_distance` gives the correct value.""" + def test_frechet_classifier_distance_value(self): + """Test that `frechet_classifier_distance` gives the correct value.""" np.random.seed(0) - test_pool_real_a = np.random.randn(5, 2048) - test_pool_gen_a = np.random.randn(5, 2048) - unused_image = array_ops.zeros([5, 299, 299, 3]) + test_pool_real_a = np.float32(np.random.randn(64, 256)) + test_pool_gen_a = np.float32(np.random.randn(64, 256)) - pool_a = np.stack((test_pool_real_a, test_pool_gen_a)) - fid_op = _run_with_mock(classifier_metrics.frechet_inception_distance, - unused_image, unused_image) - activations_tensor = 'RunClassifier/TensorArrayStack/TensorArrayGatherV3:0' + fid_op = _run_with_mock(classifier_metrics.frechet_classifier_distance, + test_pool_real_a, test_pool_gen_a, + classifier_fn=lambda x: x) with self.test_session() as sess: - actual_fid = sess.run(fid_op, {activations_tensor: pool_a}) + actual_fid = sess.run(fid_op) expected_fid = _expected_fid(test_pool_real_a, test_pool_gen_a) + self.assertAllClose(expected_fid, actual_fid, 0.01) def test_preprocess_image_graph(self):