Makes slot_creator support copying sharding from the primary variable.

When creating EMA variables, copy the XLA sharding annotation from the corresponding primary variables.

PiperOrigin-RevId: 352102947
Change-Id: Ib29ae4dc1be7642924da8e386fb21d1a7e1a42e1
Authored by A. Unique TensorFlower on 2021-01-15 16:10:35 -08:00; committed by TensorFlower Gardener
parent 186d9b18da
commit b754340bed
6 changed files with 154 additions and 18 deletions
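
In practice, this means an ExponentialMovingAverage shadow variable created for an XLA-sharded primary now carries the same `_XlaSharding` annotation as the primary. A minimal sketch of the intended usage, mirroring the new test in the moving-averages test file below; it assumes TF1-style graph mode, since `slot_creator` backs the V1 optimizers and the EMA implementation:

```python
import numpy as np
import tensorflow.compat.v1 as tf
from tensorflow.compiler.xla.experimental.xla_sharding import xla_sharding

tf.disable_eager_execution()

# Primary variable, annotated with a mesh-split sharding over two devices.
v = tf.Variable([10.0, 10.0], name="v")
v = xla_sharding.mesh_split(v, np.array([0, 1]), [0], use_sharding_op=False)

# The EMA shadow variable is created through slot_creator with
# copy_xla_sharding=True, so it inherits the primary's sharding annotation.
ema = tf.train.ExponentialMovingAverage(0.25, name="foo_avg")
ema.apply([v])
avg = ema.average(v)

assert (xla_sharding.get_tensor_sharding(avg) ==
        xla_sharding.get_tensor_sharding(v))
```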

View File

@@ -228,7 +228,7 @@ def copy_sharding(from_tensor, to_tensor, use_sharding_op=False):
Returns:
A tensor with sharding annotation copied from `from_tensor`.
"""
sharding = get_op_sharding(from_tensor.op)
sharding = get_tensor_sharding(from_tensor)
if sharding is None:
return to_tensor
@@ -343,6 +343,25 @@ def get_op_sharding(op):
return op.get_attr('_XlaSharding')
except ValueError:
return None
except AttributeError:
# AttributeError: 'DistributedVarOp' object has no attribute 'get_attr'.
return None
def get_tensor_sharding(tensor):
"""Returns sharding attribute of a Tensor.
Args:
tensor: a Tensor.
Returns:
The attribute representing XLA sharding on tensor's op.
"""
try:
return get_op_sharding(tensor.op)
except AttributeError:
# AttributeError: Tensor.op is meaningless when eager execution is enabled.
return None
def auto_to_manual_spmd_partition(tensor, manual_sharding):
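
The new `get_tensor_sharding` helper is what lets `copy_sharding` degrade gracefully instead of raising: eager tensors have no usable `.op`, and `DistributedVarOp` has no `get_attr`, so in both cases the sharding is reported as None and the target tensor is returned unannotated. A small sketch of the eager-mode behavior with this change applied (the import path is the internal one the tests use):

```python
import tensorflow as tf
from tensorflow.compiler.xla.experimental.xla_sharding import xla_sharding

t = tf.constant([1.0, 2.0])
# In eager mode, t.op raises AttributeError, so get_tensor_sharding
# swallows it and reports that no sharding is available.
print(xla_sharding.get_tensor_sharding(t))  # None

# copy_sharding therefore becomes a no-op: the target comes back unannotated.
u = xla_sharding.copy_sharding(t, tf.constant([3.0, 4.0]), use_sharding_op=False)
print(xla_sharding.get_tensor_sharding(u))  # None
```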

View File

@@ -499,6 +499,7 @@ py_library(
srcs = ["slot_creator.py"],
srcs_version = "PY2AND3",
deps = [
"//tensorflow/compiler/xla/experimental/xla_sharding",
"//tensorflow/python:array_ops",
"//tensorflow/python:init_ops",
"//tensorflow/python:resource_variable_ops",

View File

@@ -457,7 +457,8 @@ class ExponentialMovingAverage(object):
var,
var.initialized_value(),
self.name,
colocate_with_primary=True)
colocate_with_primary=True,
copy_xla_sharding=True)
# NOTE(mrry): We only add `tf.Variable` objects to the
# `MOVING_AVERAGE_VARIABLES` collection.
ops.add_to_collection(ops.GraphKeys.MOVING_AVERAGE_VARIABLES, var)
@@ -467,7 +468,8 @@ class ExponentialMovingAverage(object):
self.name,
colocate_with_primary=(var.op.type in [
"Variable", "VariableV2", "VarHandleOp"
]))
]),
copy_xla_sharding=True)
if self._zero_debias:
zero_debias_true.add(avg.ref())
self._averages[var.ref()] = avg
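
Note that `copy_xla_sharding=True` is passed unconditionally when the shadow variables are created. For primaries without a sharding annotation this is a no-op: `get_tensor_sharding` returns None and `copy_sharding` then returns the slot unchanged, so unsharded models behave exactly as before.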

View File

@@ -18,6 +18,9 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from tensorflow.compiler.xla.experimental.xla_sharding import xla_sharding
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
@@ -491,6 +494,20 @@ class ExponentialMovingAverageTest(test.TestCase):
self.assertEqual(len(vars_to_restore), 1)
self.assertIn("v/foo_avg", vars_to_restore)
@test_util.deprecated_graph_mode_only
def testCopyXlaSharding(self):
ema = moving_averages.ExponentialMovingAverage(0.25, name="foo_avg")
v = variables.Variable(_Repeat(10.0, 2), name="v")
self.assertIsNone(xla_sharding.get_tensor_sharding(v))
v = xla_sharding.mesh_split(v, np.array([0, 1]), [0], use_sharding_op=False)
self.assertIsNotNone(xla_sharding.get_tensor_sharding(v))
self.evaluate(variables.global_variables_initializer())
ema.apply([v])
avg = ema.average(v)
self.assertEqual(
xla_sharding.get_tensor_sharding(v),
xla_sharding.get_tensor_sharding(avg))
if __name__ == "__main__":
test.main()

View File

@@ -39,6 +39,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.compiler.xla.experimental.xla_sharding import xla_sharding
from tensorflow.python.distribute import distribution_strategy_context
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
@@ -47,7 +48,14 @@ from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
def _create_slot_var(primary, val, scope, validate_shape, shape, dtype):
def _create_slot_var(primary,
val,
scope,
validate_shape,
shape,
dtype,
*,
copy_xla_sharding=False):
"""Helper function for creating a slot variable."""
# TODO(lukaszkaiser): Consider allowing partitioners to be set in the current
@@ -98,10 +106,19 @@ def _create_slot_var(primary, val, scope, validate_shape, shape, dtype):
slice_info.full_shape[:n], slice_info.var_offset[:n],
slice_info.var_shape[:n]))
# pylint: enable=protected-access
# Copy XLA sharding attributes from primary.
if copy_xla_sharding:
slot = xla_sharding.copy_sharding(primary, slot, use_sharding_op=False)
return slot
def create_slot(primary, val, name, colocate_with_primary=True):
def create_slot(primary,
val,
name,
colocate_with_primary=True,
*,
copy_xla_sharding=False):
"""Create a slot initialized to the given value.
The type of the slot is determined by the given value.
@@ -112,6 +129,8 @@ def create_slot(primary, val, name, colocate_with_primary=True):
name: Name to use for the slot variable.
colocate_with_primary: Boolean. If True the slot is located
on the same device as `primary`.
copy_xla_sharding: Boolean. If True also copies XLA sharding
from primary.
Returns:
A `Variable` object.
@@ -130,13 +149,33 @@ def create_slot(primary, val, name, colocate_with_primary=True):
if colocate_with_primary:
distribution_strategy = distribution_strategy_context.get_strategy()
with distribution_strategy.extended.colocate_vars_with(primary):
return _create_slot_var(primary, val, "", validate_shape, None, None)
return _create_slot_var(
primary,
val,
"",
validate_shape,
None,
None,
copy_xla_sharding=copy_xla_sharding)
else:
return _create_slot_var(primary, val, "", validate_shape, None, None)
return _create_slot_var(
primary,
val,
"",
validate_shape,
None,
None,
copy_xla_sharding=copy_xla_sharding)
def create_slot_with_initializer(primary, initializer, shape, dtype, name,
colocate_with_primary=True):
def create_slot_with_initializer(primary,
initializer,
shape,
dtype,
name,
colocate_with_primary=True,
*,
copy_xla_sharding=False):
"""Creates a slot initialized using an `Initializer`.
The type of the slot is determined by the given value.
@@ -149,6 +188,8 @@ def create_slot_with_initializer(primary, initializer, shape, dtype, name,
name: Name to use for the slot variable.
colocate_with_primary: Boolean. If True the slot is located
on the same device as `primary`.
copy_xla_sharding: Boolean. If True also copies XLA sharding
from primary.
Returns:
A `Variable` object.
@@ -167,14 +208,31 @@ def create_slot_with_initializer(primary, initializer, shape, dtype, name,
if colocate_with_primary:
distribution_strategy = distribution_strategy_context.get_strategy()
with distribution_strategy.extended.colocate_vars_with(primary):
return _create_slot_var(primary, initializer, "", validate_shape, shape,
dtype)
return _create_slot_var(
primary,
initializer,
"",
validate_shape,
shape,
dtype,
copy_xla_sharding=copy_xla_sharding)
else:
return _create_slot_var(primary, initializer, "", validate_shape, shape,
dtype)
return _create_slot_var(
primary,
initializer,
"",
validate_shape,
shape,
dtype,
copy_xla_sharding=copy_xla_sharding)
def create_zeros_slot(primary, name, dtype=None, colocate_with_primary=True):
def create_zeros_slot(primary,
name,
dtype=None,
colocate_with_primary=True,
*,
copy_xla_sharding=False):
"""Create a slot initialized to 0 with same shape as the primary object.
Args:
@@ -183,6 +241,8 @@ def create_zeros_slot(primary, name, dtype=None, colocate_with_primary=True):
dtype: Type of the slot variable. Defaults to the type of `primary`.
colocate_with_primary: Boolean. If True the slot is located
on the same device as `primary`.
copy_xla_sharding: Boolean. If True also copies XLA sharding
from primary.
Returns:
A `Variable` object.
@@ -193,13 +253,22 @@ def create_zeros_slot(primary, name, dtype=None, colocate_with_primary=True):
if slot_shape.is_fully_defined():
initializer = init_ops.zeros_initializer()
return create_slot_with_initializer(
primary, initializer, slot_shape, dtype, name,
colocate_with_primary=colocate_with_primary)
primary,
initializer,
slot_shape,
dtype,
name,
colocate_with_primary=colocate_with_primary,
copy_xla_sharding=copy_xla_sharding)
else:
if isinstance(primary, variables.Variable):
slot_shape = array_ops.shape(primary.initialized_value())
else:
slot_shape = array_ops.shape(primary)
val = array_ops.zeros(slot_shape, dtype=dtype)
return create_slot(primary, val, name,
colocate_with_primary=colocate_with_primary)
return create_slot(
primary,
val,
name,
colocate_with_primary=colocate_with_primary,
copy_xla_sharding=copy_xla_sharding)
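
Since `copy_xla_sharding` is keyword-only in every new signature (note the bare `*`), existing positional call sites are unaffected and opting in requires naming the flag. A hedged sketch of calling `slot_creator` directly, closely following the new test that ends this change; graph mode is assumed because `slot_creator` backs the V1 optimizers:

```python
import numpy as np
import tensorflow.compat.v1 as tf
from tensorflow.compiler.xla.experimental.xla_sharding import xla_sharding
from tensorflow.python.training import slot_creator

tf.disable_eager_execution()

# A primary variable carrying a mesh-split sharding annotation.
primary = tf.Variable([1.0, 2.5], name="var")
primary = xla_sharding.mesh_split(
    primary, np.array([0, 1]), [0], use_sharding_op=False)

# copy_xla_sharding must be passed by name; leaving it out keeps the old
# behavior of creating an unannotated slot.
slot = slot_creator.create_zeros_slot(primary, name="slot",
                                      copy_xla_sharding=True)

assert (xla_sharding.get_tensor_sharding(slot) ==
        xla_sharding.get_tensor_sharding(primary))
```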

View File

@@ -18,6 +18,9 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from tensorflow.compiler.xla.experimental.xla_sharding import xla_sharding
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
@@ -174,6 +177,31 @@ class SlotCreatorTest(test.TestCase):
self.assertEqual(dtypes.float32, slot.dtype.base_dtype)
self.assertAllEqual(1.0, slot)
def testCreateSlotFromVariableCopyXlaSharding(self):
# slot_creator is used only in optimizer V1.
with ops.Graph().as_default(), self.cached_session():
v = variables.Variable([1.0, 2.5], name="var")
v = xla_sharding.mesh_split(
v, np.array([0, 1]), [0], use_sharding_op=False)
slot = slot_creator.create_slot(
v, v.initialized_value(), name="slot", copy_xla_sharding=True)
self.assertEqual(
xla_sharding.get_tensor_sharding(v),
xla_sharding.get_tensor_sharding(slot))
def testCreateZerosSlotFromVariableCopyXlaSharding(self):
# slot_creator is used only in optimizer V1.
with ops.Graph().as_default(), self.cached_session():
v = variables.Variable([1.0, 2.5], name="var")
v = xla_sharding.mesh_split(
v, np.array([0, 1]), [0], use_sharding_op=False)
with ops.control_dependencies(None):
slot = slot_creator.create_zeros_slot(
v, name="slot", dtype=dtypes.float64, copy_xla_sharding=True)
self.assertEqual(
xla_sharding.get_tensor_sharding(v),
xla_sharding.get_tensor_sharding(slot))
if __name__ == "__main__":
test.main()