Makes Optimizer._zeros_slot() copy XLA sharding from variable.
Adds a test to confirm that Adam's slots are sharded the same way as the corresponding variables.

PiperOrigin-RevId: 355176017
Change-Id: I8776bd8e0a70d812230b6de6041d752387706456
commit abb251c299
parent 7d7c5a29a7
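A minimal sketch of the behavior this change enables, assuming the xla_sharding module exposes get_tensor_sharding() and copy_sharding(): a zeros slot is created with the variable's shape and dtype, and the variable's sharding annotation, if any, is copied onto it. The helper name zeros_slot_with_sharding is hypothetical; this is not the slot_creator implementation, only an illustration of the intent.

# Minimal sketch only -- not the actual slot_creator code. Assumes the
# xla_sharding module provides get_tensor_sharding() and copy_sharding().
import numpy as np

from tensorflow.compiler.xla.experimental.xla_sharding import xla_sharding
from tensorflow.python.ops import resource_variable_ops


def zeros_slot_with_sharding(var, name):
  """Hypothetical helper: a zeros slot that mirrors `var`'s XLA sharding."""
  slot = resource_variable_ops.ResourceVariable(
      np.zeros(var.shape.as_list(), dtype=var.dtype.as_numpy_dtype),
      name=name)
  if xla_sharding.get_tensor_sharding(var) is not None:
    # Attach the variable's sharding annotation to the slot's op.
    slot = xla_sharding.copy_sharding(var, slot, use_sharding_op=False)
  return slot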
@@ -20,6 +20,7 @@ from __future__ import print_function
 
 import numpy as np
 
+from tensorflow.compiler.xla.experimental.xla_sharding import xla_sharding
 from tensorflow.python.client import session
 from tensorflow.python.eager import context
 from tensorflow.python.framework import constant_op
@@ -377,6 +378,45 @@ class AdamOptimizerTest(test.TestCase):
       # for v1 and v2 respectively.
       self.assertEqual(6, len({id(v) for v in opt.variables()}))
 
+  @test_util.deprecated_graph_mode_only
+  def testXlaSharding(self):
+    dtype = dtypes.float32
+    with self.session(graph=ops.Graph()):
+      # Initialize variables for numpy implementation.
+      var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
+      grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype)
+      var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype)
+      grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype)
+
+      var0 = resource_variable_ops.ResourceVariable(var0_np, name="var0")
+      var1 = resource_variable_ops.ResourceVariable(var1_np, name="var1")
+      var0, var1 = [
+          xla_sharding.mesh_split(
+              v, np.array([0, 1]), [0], use_sharding_op=False)
+          for v in (var0, var1)
+      ]
+      grads0 = constant_op.constant(grads0_np)
+      grads1 = constant_op.constant(grads1_np)
+
+      learning_rate = lambda: 0.001
+
+      opt = adam.AdamOptimizer(learning_rate=learning_rate)
+      update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
+
+      self.evaluate(variables.global_variables_initializer())
+      self.evaluate(update)
+      # The beta accumulators are not sharded.
+      beta1_power, beta2_power = opt._get_beta_accumulators()
+      self.assertIsNone(xla_sharding.get_tensor_sharding(beta1_power))
+      self.assertIsNone(xla_sharding.get_tensor_sharding(beta2_power))
+
+      # Variables and slots are sharded.
+      for v in (var0, var1):
+        self.assertIsNotNone(xla_sharding.get_tensor_sharding(v))
+        for slot_name in ("m", "v"):
+          slot = opt.get_slot(v, slot_name)
+          self.assertIsNotNone(xla_sharding.get_tensor_sharding(slot))
+
 
 if __name__ == "__main__":
   test.main()
@@ -1153,7 +1153,8 @@ class Optimizer(
     """
     named_slots = self._slot_dict(slot_name)
     if _var_key(var) not in named_slots:
-      new_slot_variable = slot_creator.create_zeros_slot(var, op_name)
+      new_slot_variable = slot_creator.create_zeros_slot(
+          var, op_name, copy_xla_sharding=True)
       self._restore_slot_variable(
           slot_name=slot_name, variable=var,
           slot_variable=new_slot_variable)
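The test above only asserts that a sharding annotation exists on each slot. A short follow-on fragment, assuming the var0, var1, and opt set up in testXlaSharding above, could additionally check that each slot is sharded the same way as its variable by comparing the serialized annotations returned by xla_sharding.get_tensor_sharding():

# Fragment extending the body of testXlaSharding (assumes var0, var1, opt).
# get_tensor_sharding returns the serialized sharding attribute attached to a
# tensor's op, or None, so identical annotations compare equal directly.
for v in (var0, var1):
  var_sharding = xla_sharding.get_tensor_sharding(v)
  for slot_name in ("m", "v"):
    slot = opt.get_slot(v, slot_name)
    self.assertEqual(var_sharding, xla_sharding.get_tensor_sharding(slot))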