K-FAC: Cross Replica Mean for TPU

Adds an op for taking the average of a Tensor across all TPU cores, and uses it
before updating covariance statistics. When TPUs aren't in use, the averaging
step is skipped entirely, so this change is a no-op for non-TPU training.

PiperOrigin-RevId: 179620193
commit 1ddb053077
parent d2a78f9199
Author: A. Unique TensorFlower
Date: 2017-12-19 16:31:31 -08:00
Committed by: TensorFlower Gardener

5 changed files with 57 additions and 0 deletions
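To make the intended usage concrete before the diffs, here is a minimal sketch (not part of the commit) of calling the new op from inside a TPU shard context; the 8-core count is illustrative, and the module paths are the ones touched below:

from tensorflow.contrib.kfac.python.ops import utils
from tensorflow.contrib.tpu.python.tpu import tpu_function
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops

with tpu_function.tpu_shard_context(8):  # Building a computation for 8 cores.
  per_core_stat = array_ops.ones([], dtype=dtypes.float32)
  # Every core receives the mean of the 8 per-core values.
  synced_stat = utils.cross_replica_mean(per_core_stat)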

tensorflow/contrib/kfac/python/kernel_tests/BUILD

@@ -110,6 +110,7 @@ py_test(
     srcs_version = "PY2AND3",
     deps = [
         "//tensorflow/contrib/kfac/python/ops:utils",
+        "//tensorflow/contrib/tpu",
         "//tensorflow/python:array_ops",
         "//tensorflow/python:client_testlib",
         "//tensorflow/python:dtypes",
tensorflow/contrib/kfac/python/kernel_tests/utils_test.py

@@ -22,6 +22,7 @@ import numpy as np
 import numpy.random as npr
 
 from tensorflow.contrib.kfac.python.ops import utils
+from tensorflow.contrib.tpu.python.tpu import tpu_function
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import random_seed
@@ -267,6 +268,25 @@ class UtilsTest(test.TestCase):
       np_inv = np.linalg.inv(x + damp * np.eye(size))
       self.assertAllClose(sess.run(tf_inv), np_inv)
 
+  def testCrossReplicaMean(self):
+    """Ensures that cross_replica_mean() executes only when num_shards > 1."""
+    with ops.Graph().as_default():
+      with tpu_function.tpu_shard_context(4):
+        tensor = array_ops.zeros([], dtype=dtypes.float32)
+        mean = utils.cross_replica_mean(tensor)
+    self.assertNotEqual(mean, tensor)
+
+    with ops.Graph().as_default():
+      with tpu_function.tpu_shard_context(1):
+        tensor = array_ops.zeros([], dtype=dtypes.float32)
+        mean = utils.cross_replica_mean(tensor)
+    self.assertEqual(mean, tensor)
+
+    with ops.Graph().as_default():
+      with self.assertRaises(ValueError):  # Outside of TPU context.
+        tensor = array_ops.zeros([], dtype=dtypes.float32)
+        mean = utils.cross_replica_mean(tensor)
+
 
 if __name__ == '__main__':
   test.main()

tensorflow/contrib/kfac/python/ops/BUILD

@@ -196,6 +196,7 @@ py_library(
     srcs = ["utils.py"],
     srcs_version = "PY2AND3",
     deps = [
+        "//tensorflow/contrib/tpu",
         "//tensorflow/python:array_ops",
         "//tensorflow/python:dtypes",
         "//tensorflow/python:framework_ops",
tensorflow/contrib/kfac/python/ops/fisher_factors.py

@@ -267,6 +267,10 @@ class FisherFactor(object):
     new_cov = math_ops.add_n(
         tuple(self._compute_new_cov(idx) for idx in range(self._num_sources)))
+
+    # Synchronize value across all TPU cores.
+    if utils.on_tpu():
+      new_cov = utils.cross_replica_mean(new_cov)
     return moving_averages.assign_moving_average(
         self._cov, new_cov, ema_decay, zero_debias=ZERO_DEBIAS)
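Why the division happens inside cross_replica_mean rather than as a separate step (see utils.py below): the TPU all-reduce primitive is a sum, so summing tensor / num_shards across cores yields the arithmetic mean of the per-core covariance estimates, which is what the moving average above should track. A quick numpy sanity check of that identity (a sketch only; numpy stands in for tpu_ops.cross_replica_sum):

import numpy as np

num_shards = 4
per_core_covs = [np.eye(2) * (i + 1.0) for i in range(num_shards)]

# What the all-reduce computes when each core contributes cov / num_shards:
all_reduced = sum(cov / num_shards for cov in per_core_covs)

# ...which equals the arithmetic mean of the per-core statistics.
assert np.allclose(all_reduced, np.mean(per_core_covs, axis=0))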

tensorflow/contrib/kfac/python/ops/utils.py

@@ -20,6 +20,8 @@ from __future__ import print_function
 
 import numpy as np
 
+from tensorflow.contrib.tpu.python.ops import tpu_ops
+from tensorflow.contrib.tpu.python.tpu import tpu_function
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.ops import array_ops
@@ -313,5 +315,34 @@ def fwd_gradients(ys, xs, grad_xs=None, stop_gradients=None):
 
   return dysdx
 
+
+def on_tpu():
+  """Returns True when building a TPU computation."""
+  return tpu_function.get_tpu_context().number_of_shards is not None
+
+
+def cross_replica_mean(tensor, name=None):
+  """Takes mean value of a Tensor across all TPU cores.
+
+  Args:
+    tensor: Tensor to be synchronized.
+    name: None or string. Name of Op.
+
+  Returns:
+    Average of Tensor across all TPU cores.
+
+  Raises:
+    ValueError: If called outside of TPU context.
+  """
+  with ops.name_scope(name, "cross_replica_mean", [tensor]):
+    num_shards = tpu_function.get_tpu_context().number_of_shards
+    if num_shards is None:
+      raise ValueError(
+          "Cannot take cross_replica_mean() outside of TPU Context.")
+    if num_shards == 1:
+      return tensor
+    return tpu_ops.cross_replica_sum(tensor / num_shards)
+
+
 # TODO(b/69623235): Add a function for finding tensors that share gradients
 # to eliminate redundant fisher factor computations.
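The on_tpu() guard in fisher_factors.py above is the intended calling convention: check for a TPU context first, so cross_replica_mean's ValueError is reserved for genuine misuse. A sketch of that pattern as a reusable helper (maybe_cross_replica_mean is hypothetical, not part of this commit):

from tensorflow.contrib.kfac.python.ops import utils

def maybe_cross_replica_mean(tensor):
  # Hypothetical wrapper: average across TPU cores when building a TPU
  # computation, otherwise pass the tensor through unchanged.
  if utils.on_tpu():
    return utils.cross_replica_mean(tensor)
  return tensor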