From abb251c299db5c6318767cd90e008a63e0741808 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower"
Date: Tue, 2 Feb 2021 08:52:49 -0800
Subject: [PATCH] Makes Optimizer._zeros_slot() copy XLA sharding from variable.

Adds a test to confirm that Adam's slots are sharded the same way as
corresponding variables.

PiperOrigin-RevId: 355176017
Change-Id: I8776bd8e0a70d812230b6de6041d752387706456
---
 tensorflow/python/training/adam_test.py | 40 +++++++++++++++++++++++++
 tensorflow/python/training/optimizer.py |  3 +-
 2 files changed, 42 insertions(+), 1 deletion(-)

diff --git a/tensorflow/python/training/adam_test.py b/tensorflow/python/training/adam_test.py
index 7e817a18690..06cbc386e53 100644
--- a/tensorflow/python/training/adam_test.py
+++ b/tensorflow/python/training/adam_test.py
@@ -20,6 +20,7 @@ from __future__ import print_function
 
 import numpy as np
 
+from tensorflow.compiler.xla.experimental.xla_sharding import xla_sharding
 from tensorflow.python.client import session
 from tensorflow.python.eager import context
 from tensorflow.python.framework import constant_op
@@ -377,6 +378,45 @@ class AdamOptimizerTest(test.TestCase):
       # for v1 and v2 respectively.
       self.assertEqual(6, len({id(v) for v in opt.variables()}))
 
+  @test_util.deprecated_graph_mode_only
+  def testXlaSharding(self):
+    dtype = dtypes.float32
+    with self.session(graph=ops.Graph()):
+      # Initialize variables for numpy implementation.
+      var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
+      grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype)
+      var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype)
+      grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype)
+
+      var0 = resource_variable_ops.ResourceVariable(var0_np, name="var0")
+      var1 = resource_variable_ops.ResourceVariable(var1_np, name="var1")
+      var0, var1 = [
+          xla_sharding.mesh_split(
+              v, np.array([0, 1]), [0], use_sharding_op=False)
+          for v in (var0, var1)
+      ]
+      grads0 = constant_op.constant(grads0_np)
+      grads1 = constant_op.constant(grads1_np)
+
+      learning_rate = lambda: 0.001
+
+      opt = adam.AdamOptimizer(learning_rate=learning_rate)
+      update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
+
+      self.evaluate(variables.global_variables_initializer())
+      self.evaluate(update)
+      # The beta accumulators are not sharded.
+      beta1_power, beta2_power = opt._get_beta_accumulators()
+      self.assertIsNone(xla_sharding.get_tensor_sharding(beta1_power))
+      self.assertIsNone(xla_sharding.get_tensor_sharding(beta2_power))
+
+      # Variables and slots are sharded.
+      for v in (var0, var1):
+        self.assertIsNotNone(xla_sharding.get_tensor_sharding(v))
+        for slot_name in ("m", "v"):
+          slot = opt.get_slot(v, slot_name)
+          self.assertIsNotNone(xla_sharding.get_tensor_sharding(slot))
+
 
 if __name__ == "__main__":
   test.main()
diff --git a/tensorflow/python/training/optimizer.py b/tensorflow/python/training/optimizer.py
index 42beafac708..90722ab8e0b 100644
--- a/tensorflow/python/training/optimizer.py
+++ b/tensorflow/python/training/optimizer.py
@@ -1153,7 +1153,8 @@ class Optimizer(
     """
     named_slots = self._slot_dict(slot_name)
     if _var_key(var) not in named_slots:
-      new_slot_variable = slot_creator.create_zeros_slot(var, op_name)
+      new_slot_variable = slot_creator.create_zeros_slot(
+          var, op_name, copy_xla_sharding=True)
       self._restore_slot_variable(
           slot_name=slot_name, variable=var,
           slot_variable=new_slot_variable)