Makes slot_creator support copying sharding from the primary variable.
When creating EMA variables, copy sharding from the primary variables.

PiperOrigin-RevId: 352102947
Change-Id: Ib29ae4dc1be7642924da8e386fb21d1a7e1a42e1
commit b754340bed (parent 186d9b18da)
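Example (a minimal sketch adapted from the new SlotCreatorTest case added below; the variable values, names, and graph-mode setup are illustrative, not part of the change):

    import numpy as np

    from tensorflow.compiler.xla.experimental.xla_sharding import xla_sharding
    from tensorflow.python.framework import ops
    from tensorflow.python.ops import variables
    from tensorflow.python.training import slot_creator

    with ops.Graph().as_default():
      # Annotate the primary variable with an XLA mesh-split sharding.
      v = variables.Variable([1.0, 2.5], name="var")
      v = xla_sharding.mesh_split(v, np.array([0, 1]), [0], use_sharding_op=False)

      # With copy_xla_sharding=True (new in this change, default False), the
      # slot variable inherits the primary variable's sharding annotation.
      slot = slot_creator.create_slot(
          v, v.initialized_value(), name="slot", copy_xla_sharding=True)
      assert (xla_sharding.get_tensor_sharding(v) ==
              xla_sharding.get_tensor_sharding(slot))

ExponentialMovingAverage.apply now passes copy_xla_sharding=True when creating its shadow variables, so EMA variables pick up the primary variable's sharding automatically (see the moving_averages hunks below).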
@@ -228,7 +228,7 @@ def copy_sharding(from_tensor, to_tensor, use_sharding_op=False):
   Returns:
     A tensor with sharding annotation copied from `from_tensor`.
   """
-  sharding = get_op_sharding(from_tensor.op)
+  sharding = get_tensor_sharding(from_tensor)
   if sharding is None:
     return to_tensor
 
@@ -343,6 +343,25 @@ def get_op_sharding(op):
     return op.get_attr('_XlaSharding')
   except ValueError:
     return None
+  except AttributeError:
+    # AttributeError: 'DistributedVarOp' object has no attribute 'get_attr'.
+    return None
+
+
+def get_tensor_sharding(tensor):
+  """Returns sharding attribute of a Tensor.
+
+  Args:
+    tensor: a Tensor.
+
+  Returns:
+    The attribute representing XLA sharding on tensor's op.
+  """
+  try:
+    return get_op_sharding(tensor.op)
+  except AttributeError:
+    # AttributeError: Tensor.op is meaningless when eager execution is enabled.
+    return None
 
 
 def auto_to_manual_spmd_partition(tensor, manual_sharding):
@@ -499,6 +499,7 @@ py_library(
     srcs = ["slot_creator.py"],
     srcs_version = "PY2AND3",
     deps = [
+        "//tensorflow/compiler/xla/experimental/xla_sharding",
         "//tensorflow/python:array_ops",
         "//tensorflow/python:init_ops",
         "//tensorflow/python:resource_variable_ops",
@@ -457,7 +457,8 @@ class ExponentialMovingAverage(object):
                 var,
                 var.initialized_value(),
                 self.name,
-                colocate_with_primary=True)
+                colocate_with_primary=True,
+                copy_xla_sharding=True)
             # NOTE(mrry): We only add `tf.Variable` objects to the
             # `MOVING_AVERAGE_VARIABLES` collection.
             ops.add_to_collection(ops.GraphKeys.MOVING_AVERAGE_VARIABLES, var)
@@ -467,7 +468,8 @@ class ExponentialMovingAverage(object):
                 self.name,
                 colocate_with_primary=(var.op.type in [
                     "Variable", "VariableV2", "VarHandleOp"
-                ]))
+                ]),
+                copy_xla_sharding=True)
       if self._zero_debias:
         zero_debias_true.add(avg.ref())
       self._averages[var.ref()] = avg
@@ -18,6 +18,9 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import numpy as np
+
+from tensorflow.compiler.xla.experimental.xla_sharding import xla_sharding
 from tensorflow.python.eager import context
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
@@ -491,6 +494,20 @@ class ExponentialMovingAverageTest(test.TestCase):
     self.assertEqual(len(vars_to_restore), 1)
     self.assertIn("v/foo_avg", vars_to_restore)
 
+  @test_util.deprecated_graph_mode_only
+  def testCopyXlaSharding(self):
+    ema = moving_averages.ExponentialMovingAverage(0.25, name="foo_avg")
+    v = variables.Variable(_Repeat(10.0, 2), name="v")
+    self.assertIsNone(xla_sharding.get_tensor_sharding(v))
+    v = xla_sharding.mesh_split(v, np.array([0, 1]), [0], use_sharding_op=False)
+    self.assertIsNotNone(xla_sharding.get_tensor_sharding(v))
+    self.evaluate(variables.global_variables_initializer())
+    ema.apply([v])
+    avg = ema.average(v)
+    self.assertEqual(
+        xla_sharding.get_tensor_sharding(v),
+        xla_sharding.get_tensor_sharding(avg))
+
 
 if __name__ == "__main__":
   test.main()
@@ -39,6 +39,7 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+from tensorflow.compiler.xla.experimental.xla_sharding import xla_sharding
 from tensorflow.python.distribute import distribution_strategy_context
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import init_ops
@@ -47,7 +48,14 @@ from tensorflow.python.ops import variable_scope
 from tensorflow.python.ops import variables
 
 
-def _create_slot_var(primary, val, scope, validate_shape, shape, dtype):
+def _create_slot_var(primary,
+                     val,
+                     scope,
+                     validate_shape,
+                     shape,
+                     dtype,
+                     *,
+                     copy_xla_sharding=False):
   """Helper function for creating a slot variable."""
 
   # TODO(lukaszkaiser): Consider allowing partitioners to be set in the current
@@ -98,10 +106,19 @@ def _create_slot_var(primary, val, scope, validate_shape, shape, dtype):
             slice_info.full_shape[:n], slice_info.var_offset[:n],
             slice_info.var_shape[:n]))
   # pylint: enable=protected-access
+
+  # Copy XLA sharding attributes from primary.
+  if copy_xla_sharding:
+    slot = xla_sharding.copy_sharding(primary, slot, use_sharding_op=False)
   return slot
 
 
-def create_slot(primary, val, name, colocate_with_primary=True):
+def create_slot(primary,
+                val,
+                name,
+                colocate_with_primary=True,
+                *,
+                copy_xla_sharding=False):
   """Create a slot initialized to the given value.
 
   The type of the slot is determined by the given value.
@@ -112,6 +129,8 @@ def create_slot(primary, val, name, colocate_with_primary=True):
     name: Name to use for the slot variable.
     colocate_with_primary: Boolean.  If True the slot is located
       on the same device as `primary`.
+    copy_xla_sharding: Boolean. If True also copies XLA sharding
+      from primary.
 
   Returns:
     A `Variable` object.
@@ -130,13 +149,33 @@ def create_slot(primary, val, name, colocate_with_primary=True):
   if colocate_with_primary:
     distribution_strategy = distribution_strategy_context.get_strategy()
     with distribution_strategy.extended.colocate_vars_with(primary):
-      return _create_slot_var(primary, val, "", validate_shape, None, None)
+      return _create_slot_var(
+          primary,
+          val,
+          "",
+          validate_shape,
+          None,
+          None,
+          copy_xla_sharding=copy_xla_sharding)
   else:
-    return _create_slot_var(primary, val, "", validate_shape, None, None)
+    return _create_slot_var(
+        primary,
+        val,
+        "",
+        validate_shape,
+        None,
+        None,
+        copy_xla_sharding=copy_xla_sharding)
 
 
-def create_slot_with_initializer(primary, initializer, shape, dtype, name,
-                                 colocate_with_primary=True):
+def create_slot_with_initializer(primary,
+                                 initializer,
+                                 shape,
+                                 dtype,
+                                 name,
+                                 colocate_with_primary=True,
+                                 *,
+                                 copy_xla_sharding=False):
   """Creates a slot initialized using an `Initializer`.
 
   The type of the slot is determined by the given value.
@@ -149,6 +188,8 @@ def create_slot_with_initializer(primary, initializer, shape, dtype, name,
     name: Name to use for the slot variable.
     colocate_with_primary: Boolean.  If True the slot is located
       on the same device as `primary`.
+    copy_xla_sharding: Boolean. If True also copies XLA sharding
+      from primary.
 
   Returns:
     A `Variable` object.
@@ -167,14 +208,31 @@ def create_slot_with_initializer(primary, initializer, shape, dtype, name,
   if colocate_with_primary:
     distribution_strategy = distribution_strategy_context.get_strategy()
     with distribution_strategy.extended.colocate_vars_with(primary):
-      return _create_slot_var(primary, initializer, "", validate_shape, shape,
-                              dtype)
+      return _create_slot_var(
+          primary,
+          initializer,
+          "",
+          validate_shape,
+          shape,
+          dtype,
+          copy_xla_sharding=copy_xla_sharding)
   else:
-    return _create_slot_var(primary, initializer, "", validate_shape, shape,
-                            dtype)
+    return _create_slot_var(
+        primary,
+        initializer,
+        "",
+        validate_shape,
+        shape,
+        dtype,
+        copy_xla_sharding=copy_xla_sharding)
 
 
-def create_zeros_slot(primary, name, dtype=None, colocate_with_primary=True):
+def create_zeros_slot(primary,
+                      name,
+                      dtype=None,
+                      colocate_with_primary=True,
+                      *,
+                      copy_xla_sharding=False):
   """Create a slot initialized to 0 with same shape as the primary object.
 
   Args:
@@ -183,6 +241,8 @@ def create_zeros_slot(primary, name, dtype=None, colocate_with_primary=True):
     dtype: Type of the slot variable.  Defaults to the type of `primary`.
     colocate_with_primary: Boolean.  If True the slot is located
       on the same device as `primary`.
+    copy_xla_sharding: Boolean. If True also copies XLA sharding
+      from primary.
 
   Returns:
     A `Variable` object.
@@ -193,13 +253,22 @@ def create_zeros_slot(primary, name, dtype=None, colocate_with_primary=True):
   if slot_shape.is_fully_defined():
     initializer = init_ops.zeros_initializer()
     return create_slot_with_initializer(
-        primary, initializer, slot_shape, dtype, name,
-        colocate_with_primary=colocate_with_primary)
+        primary,
+        initializer,
+        slot_shape,
+        dtype,
+        name,
+        colocate_with_primary=colocate_with_primary,
+        copy_xla_sharding=copy_xla_sharding)
   else:
     if isinstance(primary, variables.Variable):
       slot_shape = array_ops.shape(primary.initialized_value())
     else:
       slot_shape = array_ops.shape(primary)
     val = array_ops.zeros(slot_shape, dtype=dtype)
-    return create_slot(primary, val, name,
-                       colocate_with_primary=colocate_with_primary)
+    return create_slot(
+        primary,
+        val,
+        name,
+        colocate_with_primary=colocate_with_primary,
+        copy_xla_sharding=copy_xla_sharding)
@@ -18,6 +18,9 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import numpy as np
+
+from tensorflow.compiler.xla.experimental.xla_sharding import xla_sharding
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
@@ -174,6 +177,31 @@ class SlotCreatorTest(test.TestCase):
       self.assertEqual(dtypes.float32, slot.dtype.base_dtype)
       self.assertAllEqual(1.0, slot)
 
+  def testCreateSlotFromVariableCopyXlaSharding(self):
+    # slot_creator is used only in optimizer V1.
+    with ops.Graph().as_default(), self.cached_session():
+      v = variables.Variable([1.0, 2.5], name="var")
+      v = xla_sharding.mesh_split(
+          v, np.array([0, 1]), [0], use_sharding_op=False)
+      slot = slot_creator.create_slot(
+          v, v.initialized_value(), name="slot", copy_xla_sharding=True)
+      self.assertEqual(
+          xla_sharding.get_tensor_sharding(v),
+          xla_sharding.get_tensor_sharding(slot))
+
+  def testCreateZerosSlotFromVariableCopyXlaSharding(self):
+    # slot_creator is used only in optimizer V1.
+    with ops.Graph().as_default(), self.cached_session():
+      v = variables.Variable([1.0, 2.5], name="var")
+      v = xla_sharding.mesh_split(
+          v, np.array([0, 1]), [0], use_sharding_op=False)
+      with ops.control_dependencies(None):
+        slot = slot_creator.create_zeros_slot(
+            v, name="slot", dtype=dtypes.float64, copy_xla_sharding=True)
+      self.assertEqual(
+          xla_sharding.get_tensor_sharding(v),
+          xla_sharding.get_tensor_sharding(slot))
+
 
 if __name__ == "__main__":
   test.main()
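For the EMA path, a sketch mirroring the new ExponentialMovingAverageTest.testCopyXlaSharding above (graph mode; the decay value, variable values, and names are arbitrary):

    import numpy as np

    from tensorflow.compiler.xla.experimental.xla_sharding import xla_sharding
    from tensorflow.python.framework import ops
    from tensorflow.python.ops import variables
    from tensorflow.python.training import moving_averages

    with ops.Graph().as_default():
      v = variables.Variable([10.0, 10.0], name="v")
      v = xla_sharding.mesh_split(v, np.array([0, 1]), [0], use_sharding_op=False)
      ema = moving_averages.ExponentialMovingAverage(0.25, name="foo_avg")
      ema.apply([v])
      avg = ema.average(v)
      # The shadow (average) variable carries the same sharding annotation as v.
      assert (xla_sharding.get_tensor_sharding(v) ==
              xla_sharding.get_tensor_sharding(avg))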