Makes slot_creator support copying sharding from the primary variable.

When creating EMA variables, copy sharding from the primary variables.

PiperOrigin-RevId: 352102947
Change-Id: Ib29ae4dc1be7642924da8e386fb21d1a7e1a42e1
commit b754340bed (parent 186d9b18da)
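For context, here is a minimal sketch of the behavior this change enables, closely modeled on the test added later in this commit. The decay value, the variable contents, and the two-device mesh assignment are illustrative only, and graph mode is assumed (as in the test):

```python
import numpy as np

from tensorflow.compiler.xla.experimental.xla_sharding import xla_sharding
from tensorflow.python.framework import ops
from tensorflow.python.ops import variables
from tensorflow.python.training import moving_averages

with ops.Graph().as_default():
  # Annotate the primary variable with an XLA sharding over a 2-device mesh.
  v = variables.Variable([10.0, 10.0], name="v")
  v = xla_sharding.mesh_split(v, np.array([0, 1]), [0], use_sharding_op=False)

  # With this change, the EMA slot variable created by apply() carries the
  # same sharding annotation as the primary variable.
  ema = moving_averages.ExponentialMovingAverage(0.25, name="foo_avg")
  ema.apply([v])
  avg = ema.average(v)
  assert (xla_sharding.get_tensor_sharding(v) ==
          xla_sharding.get_tensor_sharding(avg))
```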
@@ -228,7 +228,7 @@ def copy_sharding(from_tensor, to_tensor, use_sharding_op=False):
   Returns:
     A tensor with sharding annotation copied from `from_tensor`.
   """
-  sharding = get_op_sharding(from_tensor.op)
+  sharding = get_tensor_sharding(from_tensor)
   if sharding is None:
     return to_tensor
@@ -343,6 +343,25 @@ def get_op_sharding(op):
     return op.get_attr('_XlaSharding')
   except ValueError:
     return None
+  except AttributeError:
+    # AttributeError: 'DistributedVarOp' object has no attribute 'get_attr'.
+    return None
+
+
+def get_tensor_sharding(tensor):
+  """Returns sharding attribute of a Tensor.
+
+  Args:
+    tensor: a Tensor.
+
+  Returns:
+    The attribute representing XLA sharding on tensor's op.
+  """
+  try:
+    return get_op_sharding(tensor.op)
+  except AttributeError:
+    # AttributeError: Tensor.op is meaningless when eager execution is enabled.
+    return None
 
 
 def auto_to_manual_spmd_partition(tensor, manual_sharding):
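As a usage sketch for the helpers above (which belong to the xla_sharding module imported elsewhere in this commit): `get_tensor_sharding` reads the `_XlaSharding` attribute through the tensor's op and returns `None` instead of raising when the op or attribute is unavailable, which is what lets `copy_sharding` accept plain tensors, eager tensors, and distributed variables. The tensors below are placeholders chosen only for illustration:

```python
import numpy as np

from tensorflow.compiler.xla.experimental.xla_sharding import xla_sharding
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops

with ops.Graph().as_default():
  src = array_ops.zeros([4, 2])
  dst = array_ops.ones([4, 2])

  # No annotation yet: the new helper returns None rather than raising.
  assert xla_sharding.get_tensor_sharding(dst) is None

  # Annotate the source, then mirror its sharding onto the destination.
  src = xla_sharding.mesh_split(
      src, np.array([0, 1]), [0], use_sharding_op=False)
  dst = xla_sharding.copy_sharding(src, dst, use_sharding_op=False)

  assert (xla_sharding.get_tensor_sharding(src) ==
          xla_sharding.get_tensor_sharding(dst))
```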
@@ -499,6 +499,7 @@ py_library(
     srcs = ["slot_creator.py"],
     srcs_version = "PY2AND3",
     deps = [
+        "//tensorflow/compiler/xla/experimental/xla_sharding",
         "//tensorflow/python:array_ops",
         "//tensorflow/python:init_ops",
         "//tensorflow/python:resource_variable_ops",
@@ -457,7 +457,8 @@ class ExponentialMovingAverage(object):
               var,
               var.initialized_value(),
               self.name,
-              colocate_with_primary=True)
+              colocate_with_primary=True,
+              copy_xla_sharding=True)
           # NOTE(mrry): We only add `tf.Variable` objects to the
           # `MOVING_AVERAGE_VARIABLES` collection.
           ops.add_to_collection(ops.GraphKeys.MOVING_AVERAGE_VARIABLES, var)
@@ -467,7 +468,8 @@ class ExponentialMovingAverage(object):
               self.name,
               colocate_with_primary=(var.op.type in [
                   "Variable", "VariableV2", "VarHandleOp"
-              ]))
+              ]),
+              copy_xla_sharding=True)
         if self._zero_debias:
           zero_debias_true.add(avg.ref())
       self._averages[var.ref()] = avg
@@ -18,6 +18,9 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import numpy as np
+
+from tensorflow.compiler.xla.experimental.xla_sharding import xla_sharding
 from tensorflow.python.eager import context
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
@@ -491,6 +494,20 @@ class ExponentialMovingAverageTest(test.TestCase):
     self.assertEqual(len(vars_to_restore), 1)
     self.assertIn("v/foo_avg", vars_to_restore)
 
+  @test_util.deprecated_graph_mode_only
+  def testCopyXlaSharding(self):
+    ema = moving_averages.ExponentialMovingAverage(0.25, name="foo_avg")
+    v = variables.Variable(_Repeat(10.0, 2), name="v")
+    self.assertIsNone(xla_sharding.get_tensor_sharding(v))
+    v = xla_sharding.mesh_split(v, np.array([0, 1]), [0], use_sharding_op=False)
+    self.assertIsNotNone(xla_sharding.get_tensor_sharding(v))
+    self.evaluate(variables.global_variables_initializer())
+    ema.apply([v])
+    avg = ema.average(v)
+    self.assertEqual(
+        xla_sharding.get_tensor_sharding(v),
+        xla_sharding.get_tensor_sharding(avg))
+
 
 if __name__ == "__main__":
   test.main()
@@ -39,6 +39,7 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+from tensorflow.compiler.xla.experimental.xla_sharding import xla_sharding
 from tensorflow.python.distribute import distribution_strategy_context
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import init_ops
@@ -47,7 +48,14 @@ from tensorflow.python.ops import variable_scope
 from tensorflow.python.ops import variables
 
 
-def _create_slot_var(primary, val, scope, validate_shape, shape, dtype):
+def _create_slot_var(primary,
+                     val,
+                     scope,
+                     validate_shape,
+                     shape,
+                     dtype,
+                     *,
+                     copy_xla_sharding=False):
   """Helper function for creating a slot variable."""
 
   # TODO(lukaszkaiser): Consider allowing partitioners to be set in the current
@@ -98,10 +106,19 @@ def _create_slot_var(primary, val, scope, validate_shape, shape, dtype):
             slice_info.full_shape[:n], slice_info.var_offset[:n],
             slice_info.var_shape[:n]))
   # pylint: enable=protected-access
+
+  # Copy XLA sharding attributes from primary.
+  if copy_xla_sharding:
+    slot = xla_sharding.copy_sharding(primary, slot, use_sharding_op=False)
   return slot
 
 
-def create_slot(primary, val, name, colocate_with_primary=True):
+def create_slot(primary,
+                val,
+                name,
+                colocate_with_primary=True,
+                *,
+                copy_xla_sharding=False):
   """Create a slot initialized to the given value.
 
   The type of the slot is determined by the given value.
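Two small notes on the change above: the sharding copy happens at the very end of `_create_slot_var`, after any partitioner handling, and `copy_xla_sharding` is keyword-only (the bare `*` in each signature), so existing positional call sites are unaffected. A minimal sketch of calling the updated `create_slot` directly, mirroring the test added at the end of this commit (variable contents and mesh assignment are illustrative; graph mode assumed):

```python
import numpy as np

from tensorflow.compiler.xla.experimental.xla_sharding import xla_sharding
from tensorflow.python.framework import ops
from tensorflow.python.ops import variables
from tensorflow.python.training import slot_creator

with ops.Graph().as_default():
  primary = variables.Variable([1.0, 2.5], name="var")
  primary = xla_sharding.mesh_split(
      primary, np.array([0, 1]), [0], use_sharding_op=False)

  # The new flag must be passed by keyword; omitting it keeps old behavior.
  slot = slot_creator.create_slot(
      primary, primary.initialized_value(), "slot", copy_xla_sharding=True)

  assert (xla_sharding.get_tensor_sharding(primary) ==
          xla_sharding.get_tensor_sharding(slot))
```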
@@ -112,6 +129,8 @@ def create_slot(primary, val, name, colocate_with_primary=True):
     name: Name to use for the slot variable.
     colocate_with_primary: Boolean. If True the slot is located
       on the same device as `primary`.
+    copy_xla_sharding: Boolean. If True also copies XLA sharding
+      from primary.
 
   Returns:
     A `Variable` object.
@@ -130,13 +149,33 @@ def create_slot(primary, val, name, colocate_with_primary=True):
   if colocate_with_primary:
     distribution_strategy = distribution_strategy_context.get_strategy()
     with distribution_strategy.extended.colocate_vars_with(primary):
-      return _create_slot_var(primary, val, "", validate_shape, None, None)
+      return _create_slot_var(
+          primary,
+          val,
+          "",
+          validate_shape,
+          None,
+          None,
+          copy_xla_sharding=copy_xla_sharding)
   else:
-    return _create_slot_var(primary, val, "", validate_shape, None, None)
+    return _create_slot_var(
+        primary,
+        val,
+        "",
+        validate_shape,
+        None,
+        None,
+        copy_xla_sharding=copy_xla_sharding)
 
 
-def create_slot_with_initializer(primary, initializer, shape, dtype, name,
-                                 colocate_with_primary=True):
+def create_slot_with_initializer(primary,
+                                 initializer,
+                                 shape,
+                                 dtype,
+                                 name,
+                                 colocate_with_primary=True,
+                                 *,
+                                 copy_xla_sharding=False):
   """Creates a slot initialized using an `Initializer`.
 
   The type of the slot is determined by the given value.
@@ -149,6 +188,8 @@ def create_slot_with_initializer(primary, initializer, shape, dtype, name,
     name: Name to use for the slot variable.
     colocate_with_primary: Boolean. If True the slot is located
       on the same device as `primary`.
+    copy_xla_sharding: Boolean. If True also copies XLA sharding
+      from primary.
 
   Returns:
     A `Variable` object.
@@ -167,14 +208,31 @@ def create_slot_with_initializer(primary, initializer, shape, dtype, name,
   if colocate_with_primary:
     distribution_strategy = distribution_strategy_context.get_strategy()
     with distribution_strategy.extended.colocate_vars_with(primary):
-      return _create_slot_var(primary, initializer, "", validate_shape, shape,
-                              dtype)
+      return _create_slot_var(
+          primary,
+          initializer,
+          "",
+          validate_shape,
+          shape,
+          dtype,
+          copy_xla_sharding=copy_xla_sharding)
   else:
-    return _create_slot_var(primary, initializer, "", validate_shape, shape,
-                            dtype)
+    return _create_slot_var(
+        primary,
+        initializer,
+        "",
+        validate_shape,
+        shape,
+        dtype,
+        copy_xla_sharding=copy_xla_sharding)
 
 
-def create_zeros_slot(primary, name, dtype=None, colocate_with_primary=True):
+def create_zeros_slot(primary,
+                      name,
+                      dtype=None,
+                      colocate_with_primary=True,
+                      *,
+                      copy_xla_sharding=False):
   """Create a slot initialized to 0 with same shape as the primary object.
 
   Args:
@@ -183,6 +241,8 @@ def create_zeros_slot(primary, name, dtype=None, colocate_with_primary=True):
     dtype: Type of the slot variable. Defaults to the type of `primary`.
     colocate_with_primary: Boolean. If True the slot is located
       on the same device as `primary`.
+    copy_xla_sharding: Boolean. If True also copies XLA sharding
+      from primary.
 
   Returns:
     A `Variable` object.
@@ -193,13 +253,22 @@ def create_zeros_slot(primary, name, dtype=None, colocate_with_primary=True):
   if slot_shape.is_fully_defined():
     initializer = init_ops.zeros_initializer()
     return create_slot_with_initializer(
-        primary, initializer, slot_shape, dtype, name,
-        colocate_with_primary=colocate_with_primary)
+        primary,
+        initializer,
+        slot_shape,
+        dtype,
+        name,
+        colocate_with_primary=colocate_with_primary,
+        copy_xla_sharding=copy_xla_sharding)
   else:
     if isinstance(primary, variables.Variable):
       slot_shape = array_ops.shape(primary.initialized_value())
     else:
       slot_shape = array_ops.shape(primary)
     val = array_ops.zeros(slot_shape, dtype=dtype)
-    return create_slot(primary, val, name,
-                       colocate_with_primary=colocate_with_primary)
+    return create_slot(
+        primary,
+        val,
+        name,
+        colocate_with_primary=colocate_with_primary,
+        copy_xla_sharding=copy_xla_sharding)
@@ -18,6 +18,9 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import numpy as np
+
+from tensorflow.compiler.xla.experimental.xla_sharding import xla_sharding
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
@@ -174,6 +177,31 @@ class SlotCreatorTest(test.TestCase):
       self.assertEqual(dtypes.float32, slot.dtype.base_dtype)
       self.assertAllEqual(1.0, slot)
 
+  def testCreateSlotFromVariableCopyXlaSharding(self):
+    # slot_creator is used only in optimizer V1.
+    with ops.Graph().as_default(), self.cached_session():
+      v = variables.Variable([1.0, 2.5], name="var")
+      v = xla_sharding.mesh_split(
+          v, np.array([0, 1]), [0], use_sharding_op=False)
+      slot = slot_creator.create_slot(
+          v, v.initialized_value(), name="slot", copy_xla_sharding=True)
+      self.assertEqual(
+          xla_sharding.get_tensor_sharding(v),
+          xla_sharding.get_tensor_sharding(slot))
+
+  def testCreateZerosSlotFromVariableCopyXlaSharding(self):
+    # slot_creator is used only in optimizer V1.
+    with ops.Graph().as_default(), self.cached_session():
+      v = variables.Variable([1.0, 2.5], name="var")
+      v = xla_sharding.mesh_split(
+          v, np.array([0, 1]), [0], use_sharding_op=False)
+      with ops.control_dependencies(None):
+        slot = slot_creator.create_zeros_slot(
+            v, name="slot", dtype=dtypes.float64, copy_xla_sharding=True)
+      self.assertEqual(
+          xla_sharding.get_tensor_sharding(v),
+          xla_sharding.get_tensor_sharding(slot))
+
 
 if __name__ == "__main__":
   test.main()