K-FAC: Cross Replica Mean for TPU
Adds an op for taking the average of a Tensor across all TPU cores, and uses it before updating covariance statistics. This is a no-op if TPUs aren't used. PiperOrigin-RevId: 179620193
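For orientation, a minimal usage sketch of the new op (tpu_function.tpu_shard_context and utils.cross_replica_mean are taken from the diff below; the surrounding setup is illustrative only):

    from tensorflow.contrib.kfac.python.ops import utils
    from tensorflow.contrib.tpu.python.tpu import tpu_function
    from tensorflow.python.framework import dtypes
    from tensorflow.python.ops import array_ops

    # Inside a TPU shard context with more than one shard,
    # cross_replica_mean() emits a real cross-replica reduction.
    with tpu_function.tpu_shard_context(4):
      per_core_stat = array_ops.ones([], dtype=dtypes.float32)
      synced_stat = utils.cross_replica_mean(per_core_stat)

With a single shard the call returns its input unchanged, and outside any TPU context it raises ValueError, as the tests below exercise.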
This commit is contained in:
parent d2a78f9199
commit 1ddb053077
@@ -110,6 +110,7 @@ py_test(
     srcs_version = "PY2AND3",
     deps = [
         "//tensorflow/contrib/kfac/python/ops:utils",
+        "//tensorflow/contrib/tpu",
         "//tensorflow/python:array_ops",
         "//tensorflow/python:client_testlib",
         "//tensorflow/python:dtypes",
@@ -22,6 +22,7 @@ import numpy as np
 import numpy.random as npr
 
 from tensorflow.contrib.kfac.python.ops import utils
+from tensorflow.contrib.tpu.python.tpu import tpu_function
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import random_seed
@@ -267,6 +268,25 @@ class UtilsTest(test.TestCase):
       np_inv = np.linalg.inv(x + damp * np.eye(size))
       self.assertAllClose(sess.run(tf_inv), np_inv)
 
+  def testCrossReplicaMean(self):
+    """Ensures that cross_replica_mean() executes only when num_shards > 1."""
+    with ops.Graph().as_default():
+      with tpu_function.tpu_shard_context(4):
+        tensor = array_ops.zeros([], dtype=dtypes.float32)
+        mean = utils.cross_replica_mean(tensor)
+      self.assertNotEqual(mean, tensor)
+
+    with ops.Graph().as_default():
+      with tpu_function.tpu_shard_context(1):
+        tensor = array_ops.zeros([], dtype=dtypes.float32)
+        mean = utils.cross_replica_mean(tensor)
+      self.assertEqual(mean, tensor)
+
+    with ops.Graph().as_default():
+      with self.assertRaises(ValueError):  # Outside of TPU context.
+        tensor = array_ops.zeros([], dtype=dtypes.float32)
+        mean = utils.cross_replica_mean(tensor)
+
 
 if __name__ == '__main__':
   test.main()
@@ -196,6 +196,7 @@ py_library(
     srcs = ["utils.py"],
     srcs_version = "PY2AND3",
     deps = [
+        "//tensorflow/contrib/tpu",
         "//tensorflow/python:array_ops",
         "//tensorflow/python:dtypes",
         "//tensorflow/python:framework_ops",
@@ -267,6 +267,10 @@ class FisherFactor(object):
     new_cov = math_ops.add_n(
         tuple(self._compute_new_cov(idx) for idx in range(self._num_sources)))
 
+    # Synchronize value across all TPU cores.
+    if utils.on_tpu():
+      new_cov = utils.cross_replica_mean(new_cov)
+
     return moving_averages.assign_moving_average(
         self._cov, new_cov, ema_decay, zero_debias=ZERO_DEBIAS)
 
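With this change, every core folds the same cross-replica average into its moving average, so the covariance estimates stay synchronized. A sketch of the arithmetic in plain numpy (a stand-in for assign_moving_average, ignoring zero-debias; decay corresponds to ema_decay):

    import numpy as np

    def updated_cov(cov, per_core_new_covs, decay):
      # First average the freshly computed statistic across all cores,
      # mirroring utils.cross_replica_mean()...
      new_cov = np.mean(per_core_new_covs, axis=0)
      # ...then fold it into the exponential moving average, as
      # moving_averages.assign_moving_average does.
      return decay * cov + (1.0 - decay) * new_cov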
@@ -20,6 +20,8 @@ from __future__ import print_function
 
 import numpy as np
 
+from tensorflow.contrib.tpu.python.ops import tpu_ops
+from tensorflow.contrib.tpu.python.tpu import tpu_function
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.ops import array_ops
@@ -313,5 +315,34 @@ def fwd_gradients(ys, xs, grad_xs=None, stop_gradients=None):
 
   return dysdx
 
+
+def on_tpu():
+  """Returns True when building a TPU computation."""
+  return tpu_function.get_tpu_context().number_of_shards is not None
+
+
+def cross_replica_mean(tensor, name=None):
+  """Takes mean value of a Tensor across all TPU cores.
+
+  Args:
+    tensor: Tensor to be synchronized.
+    name: None or string. Name of Op.
+
+  Returns:
+    Average of Tensor across all TPU cores.
+
+  Raises:
+    ValueError: If called outside of TPU context.
+  """
+  with ops.name_scope(name, "cross_replica_mean", [tensor]):
+    num_shards = tpu_function.get_tpu_context().number_of_shards
+    if num_shards is None:
+      raise ValueError(
+          "Cannot take cross_replica_mean() outside of TPU Context.")
+    if num_shards == 1:
+      return tensor
+    return tpu_ops.cross_replica_sum(tensor / num_shards)
+
+
 # TODO(b/69623235): Add a function for finding tensors that share gradients
 # to eliminate redundant fisher factor computations.
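A note on the reduction: the mean is computed as a cross-replica sum of the tensor pre-divided by the shard count, so a single collective op suffices. A toy numpy simulation of what the collective computes (hypothetical per-core values):

    import numpy as np

    # Values held by each of a hypothetical 4 shards.
    per_core = np.array([1.0, 2.0, 3.0, 6.0])
    num_shards = per_core.size

    # Each core contributes tensor / num_shards; cross_replica_sum adds
    # the contributions, so every core ends up with the mean, 3.0.
    result = np.sum(per_core / num_shards)
    assert result == per_core.mean() == 3.0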