Makes Optimizer._zeros_slot() copy the variable's XLA sharding onto the newly created slot.

Adds a test confirming that Adam's slots are sharded the same way as their corresponding variables.

PiperOrigin-RevId: 355176017
Change-Id: I8776bd8e0a70d812230b6de6041d752387706456
Authored by A. Unique TensorFlower on 2021-02-02 08:52:49 -08:00; committed by TensorFlower Gardener
parent 7d7c5a29a7
commit abb251c299
2 changed files with 42 additions and 1 deletion
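For context, the sketch below illustrates the behavior this change enables: when a zeros slot is created for a variable that carries an XLA sharding annotation, the slot inherits the same sharding. It is a minimal, hypothetical sketch rather than the code in this commit; the helper zeros_slot_like is made up for illustration, and it assumes xla_sharding exposes a copy_sharding(from_tensor, to_tensor) helper alongside the mesh_split and get_tensor_sharding functions used in the test below.

# Illustrative sketch only; not part of this commit's diff. Assumes
# xla_sharding.copy_sharding(from_tensor, to_tensor) is available; the
# hypothetical helper zeros_slot_like stands in for the slot-creation path.
import numpy as np

from tensorflow.compiler.xla.experimental.xla_sharding import xla_sharding
from tensorflow.python.framework import ops
from tensorflow.python.ops import resource_variable_ops


def zeros_slot_like(primary, name):
  """Creates a zeros variable shaped like `primary`, mirroring its sharding."""
  slot = resource_variable_ops.ResourceVariable(
      np.zeros(primary.shape.as_list(), dtype=primary.dtype.as_numpy_dtype),
      name=name)
  if xla_sharding.get_tensor_sharding(primary) is not None:
    # Copy the primary variable's XLA sharding annotation onto the slot
    # (assumed helper; the actual change routes this through slot_creator).
    slot = xla_sharding.copy_sharding(primary, slot)
  return slot


with ops.Graph().as_default():
  # A variable split along dimension 0 across a two-device mesh.
  var = resource_variable_ops.ResourceVariable([1.0, 2.0], name="var")
  var = xla_sharding.mesh_split(
      var, np.array([0, 1]), [0], use_sharding_op=False)
  # The zeros slot ends up with the same serialized sharding annotation.
  slot = zeros_slot_like(var, name="var_m")
  assert (xla_sharding.get_tensor_sharding(slot) ==
          xla_sharding.get_tensor_sharding(var))

In the actual change, the same effect is achieved by passing copy_xla_sharding=True to slot_creator.create_zeros_slot, as shown in the second changed file below.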

Changed file 1 of 2 (the Adam optimizer test):

@@ -20,6 +20,7 @@ from __future__ import print_function
 import numpy as np
 
+from tensorflow.compiler.xla.experimental.xla_sharding import xla_sharding
 from tensorflow.python.client import session
 from tensorflow.python.eager import context
 from tensorflow.python.framework import constant_op
@@ -377,6 +378,45 @@ class AdamOptimizerTest(test.TestCase):
       # for v1 and v2 respectively.
       self.assertEqual(6, len({id(v) for v in opt.variables()}))
+
+  @test_util.deprecated_graph_mode_only
+  def testXlaSharding(self):
+    dtype = dtypes.float32
+    with self.session(graph=ops.Graph()):
+      # Initialize variables for numpy implementation.
+      var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
+      grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype)
+      var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype)
+      grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype)
+
+      var0 = resource_variable_ops.ResourceVariable(var0_np, name="var0")
+      var1 = resource_variable_ops.ResourceVariable(var1_np, name="var1")
+      var0, var1 = [
+          xla_sharding.mesh_split(
+              v, np.array([0, 1]), [0], use_sharding_op=False)
+          for v in (var0, var1)
+      ]
+      grads0 = constant_op.constant(grads0_np)
+      grads1 = constant_op.constant(grads1_np)
+
+      learning_rate = lambda: 0.001
+      opt = adam.AdamOptimizer(learning_rate=learning_rate)
+      update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
+      self.evaluate(variables.global_variables_initializer())
+      self.evaluate(update)
+
+      # The beta accumulators are not sharded.
+      beta1_power, beta2_power = opt._get_beta_accumulators()
+      self.assertIsNone(xla_sharding.get_tensor_sharding(beta1_power))
+      self.assertIsNone(xla_sharding.get_tensor_sharding(beta2_power))
+
+      # Variables and slots are sharded.
+      for v in (var0, var1):
+        self.assertIsNotNone(xla_sharding.get_tensor_sharding(v))
+
+        for slot_name in ("m", "v"):
+          slot = opt.get_slot(v, slot_name)
+          self.assertIsNotNone(xla_sharding.get_tensor_sharding(slot))
 
 
 if __name__ == "__main__":
   test.main()

Changed file 2 of 2 (Optimizer._zeros_slot):

@@ -1153,7 +1153,8 @@ class Optimizer(
     """
     named_slots = self._slot_dict(slot_name)
     if _var_key(var) not in named_slots:
-      new_slot_variable = slot_creator.create_zeros_slot(var, op_name)
+      new_slot_variable = slot_creator.create_zeros_slot(
+          var, op_name, copy_xla_sharding=True)
       self._restore_slot_variable(
           slot_name=slot_name, variable=var,
           slot_variable=new_slot_variable)