Rename `ReplicaContext.replica_id` to `replica_id_in_sync_group`.

PiperOrigin-RevId: 221602399
parent 9a608da836
commit cfdfcc311c
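For reference, the calling pattern before and after this rename, sketched from the `_replica_id()` test helper updated below. The `ds_context` alias and `constant_op` import follow the test files in this diff; treat this as an illustration of the rename, not new API surface.

    from tensorflow.python.framework import constant_op
    from tensorflow.python.training import distribution_strategy_context as ds_context

    def _replica_id():
      # Before this change: ds_context.get_replica_context().replica_id
      # After this change:  ds_context.get_replica_context().replica_id_in_sync_group
      # TODO(cjfj): Return `replica_id_...` directly, once it is a `Tensor`.
      return constant_op.constant(
          ds_context.get_replica_context().replica_id_in_sync_group)
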
@@ -47,7 +47,7 @@ from tensorflow.python.ops import variables
 from tensorflow.python.platform import gfile
 from tensorflow.python.platform import test
 from tensorflow.python.summary.writer import writer_cache
-from tensorflow.python.training import distribution_strategy_context
+from tensorflow.python.training import distribution_strategy_context as ds_context


 class KerasOptimizerV2IntegrationTest(test.TestCase, parameterized.TestCase):
@@ -272,9 +272,9 @@ class MirroredStrategyOptimizerV2Test(test.TestCase):


 def _replica_id():
-  # TODO(cjfj): Return `replica_id` directly, once it is a `Tensor`.
+  # TODO(cjfj): Return `replica_id_...` directly, once it is a `Tensor`.
   return constant_op.constant(
-      distribution_strategy_context.get_replica_context().replica_id)
+      ds_context.get_replica_context().replica_id_in_sync_group)


 if __name__ == '__main__':
@@ -829,4 +829,5 @@ class MirroredReplicaContext(distribute_lib.ReplicaContext):
   @property
   def devices(self):
     distribute_lib.require_replica_context(self)
-    return [self._distribution_strategy.worker_devices[self._replica_id]]
+    ds = self._distribution_strategy
+    return [ds.worker_devices[self._replica_id_in_sync_group]]
@@ -48,7 +48,7 @@ from tensorflow.python.ops import state_ops
 from tensorflow.python.ops import variable_scope
 from tensorflow.python.ops import variables
 from tensorflow.python.training import device_util
-from tensorflow.python.training import distribution_strategy_context
+from tensorflow.python.training import distribution_strategy_context as ds_context
 from tensorflow.python.training import gradient_descent
 from tensorflow.python.training import optimizer as optimizer_lib
 from tensorflow.python.training import server_lib
@@ -183,8 +183,7 @@ class MirroredStrategyVariableCreationTest(test.TestCase):
       # This variable should be created only once across the threads because of
       # special variable_creator functions used by `dist.call_for_each_replica`.
       v = variable_scope.variable(1.0, name="foo")
-      distribution_strategy_context.get_replica_context().merge_call(
-          lambda _: _)
+      ds_context.get_replica_context().merge_call(lambda _: _)
       return v

     dist = mirrored_strategy.MirroredStrategy(
@@ -201,8 +200,7 @@ class MirroredStrategyVariableCreationTest(test.TestCase):

     def model_fn():
       v = variable_scope.variable(1.0)
-      distribution_strategy_context.get_replica_context().merge_call(
-          lambda _: _)
+      ds_context.get_replica_context().merge_call(lambda _: _)
       return v

     dist = mirrored_strategy.MirroredStrategy(
@@ -222,8 +220,7 @@ class MirroredStrategyVariableCreationTest(test.TestCase):
       vs = []
       for i in range(5):
         vs.append(variable_scope.variable(1.0, name="foo" + str(i)))
-      distribution_strategy_context.get_replica_context().merge_call(
-          lambda _: _)
+      ds_context.get_replica_context().merge_call(lambda _: _)
       return vs

     dist = mirrored_strategy.MirroredStrategy(
@@ -245,8 +242,7 @@ class MirroredStrategyVariableCreationTest(test.TestCase):
       vs.append(variable_scope.variable(1.0, name="foo_1/bar"))
       vs.append(variable_scope.variable(1.0, name="foo_1/bar_1"))
       vs.append(variable_scope.variable(1.0, name="foo/bar_1"))
-      distribution_strategy_context.get_replica_context().merge_call(
-          lambda _: _)
+      ds_context.get_replica_context().merge_call(lambda _: _)
       return vs

     dist = mirrored_strategy.MirroredStrategy(
@@ -269,8 +265,7 @@ class MirroredStrategyVariableCreationTest(test.TestCase):
     def model_fn():
       replica_id = self.evaluate(_replica_id())
       v = variable_scope.variable(1.0, name="foo_" + str(replica_id))
-      distribution_strategy_context.get_replica_context().merge_call(
-          lambda _: _)
+      ds_context.get_replica_context().merge_call(lambda _: _)
       return v

     dist = mirrored_strategy.MirroredStrategy(
@@ -292,8 +287,7 @@ class MirroredStrategyVariableCreationTest(test.TestCase):
       layer2 = core.Dense(1)
       layer2(features)
       # This will pause the current thread, and execute the other thread.
-      distribution_strategy_context.get_replica_context().merge_call(
-          lambda _: _)
+      ds_context.get_replica_context().merge_call(lambda _: _)
       layer3 = core.Dense(1)
       layer3(features)
       return [(layer1.kernel, layer1.bias),
@@ -330,8 +324,7 @@ class MirroredStrategyVariableCreationTest(test.TestCase):
       with variable_scope.variable_scope("common"):
         v1 = variable_scope.variable(1.0, name="var1")
         # This will pause the current thread, and execute the other thread.
-        distribution_strategy_context.get_replica_context().merge_call(
-            lambda _: _)
+        ds_context.get_replica_context().merge_call(lambda _: _)
         v2 = variable_scope.variable(
             1.0,
             name="var2",
@@ -374,8 +367,7 @@ class MirroredStrategyVariableCreationTest(test.TestCase):
       with variable_scope.variable_scope("common"):
         v1 = variable_scope.get_variable("var1", [1])
         # This will pause the current thread, and execute the other thread.
-        distribution_strategy_context.get_replica_context().merge_call(
-            lambda _: _)
+        ds_context.get_replica_context().merge_call(lambda _: _)
         v2 = variable_scope.get_variable(
             "var2", [1],
             synchronization=variable_scope.VariableSynchronization.ON_READ,
@@ -564,8 +556,7 @@ class MirroredStrategyVariableCreationTest(test.TestCase):

     def model_fn():
       v = variable_scope.variable(1.0, name="foo")
-      distribution_strategy_context.get_replica_context().merge_call(
-          lambda _: _)
+      ds_context.get_replica_context().merge_call(lambda _: _)
       return v

     dist = mirrored_strategy.MirroredStrategy(
@@ -582,8 +573,7 @@ class MirroredStrategyVariableCreationTest(test.TestCase):

     def model_fn(name):
       v = variable_scope.variable(1.0, name=name)
-      distribution_strategy_context.get_replica_context().merge_call(
-          lambda _: _)
+      ds_context.get_replica_context().merge_call(lambda _: _)
       return v

     dist = mirrored_strategy.MirroredStrategy(
@@ -683,8 +673,7 @@ class MirroredStrategyVariableCreationTest(test.TestCase):
     def model_fn():
       with ops.name_scope("foo"):
         a = constant_op.constant(1.0, name="a")
-        distribution_strategy_context.get_replica_context().merge_call(
-            lambda _: _)
+        ds_context.get_replica_context().merge_call(lambda _: _)
         b = constant_op.constant(1.0, name="b")
       return a, b

@@ -705,8 +694,7 @@ class MirroredStrategyVariableCreationTest(test.TestCase):
     def model_fn():
       with ops.name_scope(None, "foo"):
         a = constant_op.constant(1.0, name="a")
-        distribution_strategy_context.get_replica_context().merge_call(
-            lambda _: _)
+        ds_context.get_replica_context().merge_call(lambda _: _)
         b = constant_op.constant(2.0, name="b")
       return a, b

@@ -734,8 +722,7 @@ class MirroredStrategyVariableCreationTest(test.TestCase):
     def model_fn():
       b = variable_scope.variable(1.0, name="b")
       with ops.name_scope("foo"):
-        c = distribution_strategy_context.get_replica_context().merge_call(
-            in_cross_replica)
+        c = ds_context.get_replica_context().merge_call(in_cross_replica)
       return b, c

     dist = mirrored_strategy.MirroredStrategy(
@@ -767,8 +754,7 @@ class MirroredStrategyVariableCreationTest(test.TestCase):
     def model_fn():
       b = variable_scope.get_variable("b", [1])
       with ops.name_scope("foo"):
-        c = distribution_strategy_context.get_replica_context().merge_call(
-            in_cross_replica)
+        c = ds_context.get_replica_context().merge_call(in_cross_replica)
       return b, c

     dist = mirrored_strategy.MirroredStrategy(
@@ -951,7 +937,7 @@ class MirroredVariableUpdateTest(test.TestCase):

     def model_fn():
       value = math_ops.cast(
-          distribution_strategy_context.get_replica_context().replica_id,
+          ds_context.get_replica_context().replica_id_in_sync_group,
           mirrored_var.dtype)
       return mirrored_var.assign(value)

@@ -1025,7 +1011,7 @@ class MirroredVariableUpdateTest(test.TestCase):

     def model_fn():
       value = math_ops.cast(
-          distribution_strategy_context.get_replica_context().replica_id,
+          ds_context.get_replica_context().replica_id_in_sync_group,
           mirrored_var.dtype)
       return mirrored_var.assign_add(value)

@@ -1091,7 +1077,7 @@ class MirroredVariableUpdateTest(test.TestCase):

     def model_fn():
       value = math_ops.cast(
-          distribution_strategy_context.get_replica_context().replica_id,
+          ds_context.get_replica_context().replica_id_in_sync_group,
           mirrored_var.dtype)
       return mirrored_var.assign_sub(value)

@@ -1463,9 +1449,9 @@ class MultiWorkerMirroredStrategyTestWithChief(


 def _replica_id():
-  # TODO(cjfj): Return `replica_id` directly, once it is a `Tensor`.
+  # TODO(cjfj): Return `replica_id_...` directly, once it is a `Tensor`.
   return constant_op.constant(
-      distribution_strategy_context.get_replica_context().replica_id)
+      ds_context.get_replica_context().replica_id_in_sync_group)


 if __name__ == "__main__":
@@ -26,7 +26,7 @@ from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import test_util
 from tensorflow.python.ops import variable_scope
-from tensorflow.python.training import distribution_strategy_context
+from tensorflow.python.training import distribution_strategy_context as ds_context


 class MirroredOneCPUDistributionTest(strategy_test_lib.DistributionTestBase):
@@ -87,8 +87,7 @@ class VariableCreatorStackTest(test.TestCase):
        v = variable_scope.variable(1.0)

        # This will pause the current thread, and execute the other thread.
-       distribution_strategy_context.get_replica_context().merge_call(
-           lambda _: _)
+       ds_context.get_replica_context().merge_call(lambda _: _)
      return v

    def main_thread_creator(next_creator, *args, **kwargs):
@@ -106,9 +105,9 @@ class VariableCreatorStackTest(test.TestCase):


 def _replica_id():
-  # TODO(cjfj): Return `replica_id` directly, once it is a `Tensor`.
+  # TODO(cjfj): Return `replica_id_...` directly, once it is a `Tensor`.
   return constant_op.constant(
-      distribution_strategy_context.get_replica_context().replica_id)
+      ds_context.get_replica_context().replica_id_in_sync_group)


 class MultiWorkerMirroredStrategyTest(test.TestCase):
@@ -180,7 +180,7 @@ class _OneDeviceReplicaContext(distribute_lib.ReplicaContext):

   def __init__(self, distribution_strategy):
     distribute_lib.ReplicaContext.__init__(
-        self, distribution_strategy, replica_id=0)
+        self, distribution_strategy, replica_id_in_sync_group=0)

   @property
   def device(self):
@@ -44,7 +44,7 @@ from tensorflow.python.ops import variable_scope
 from tensorflow.python.ops import variables
 from tensorflow.python.platform import test
 from tensorflow.python.training import device_util
-from tensorflow.python.training import distribution_strategy_context
+from tensorflow.python.training import distribution_strategy_context as ds_context
 from tensorflow.python.training import training_util

 CHIEF = run_config.TaskType.CHIEF
@@ -98,7 +98,7 @@ class ParameterServerStrategyTestBase(
       else:
         last_part_device = (
             'device:GPU:%d' %
-            distribution_strategy_context.get_replica_context().replica_id)
+            ds_context.get_replica_context().replica_id_in_sync_group)

       a = constant_op.constant(1.0)
       b = constant_op.constant(2.0)
@@ -265,7 +265,7 @@ class ParameterServerStrategyTestBase(
       else:
         replica_compute_device = (
             '/device:GPU:%d' %
-            distribution_strategy_context.get_replica_context().replica_id)
+            ds_context.get_replica_context().replica_id_in_sync_group)
       replica_compute_device = device_util.canonicalize(
           replica_compute_device)

@@ -274,7 +274,7 @@ class ParameterServerStrategyTestBase(
       else:
         replica_variable_device = (
             '/device:GPU:%d' %
-            distribution_strategy_context.get_replica_context().replica_id)
+            ds_context.get_replica_context().replica_id_in_sync_group)
       replica_variable_device = device_util.canonicalize(
           replica_variable_device)

@@ -29,7 +29,7 @@ from tensorflow.python.framework import ops
 from tensorflow.python.layers import core
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import variables
-from tensorflow.python.training import distribution_strategy_context
+from tensorflow.python.training import distribution_strategy_context as ds_context
 from tensorflow.python.training import optimizer


@@ -46,8 +46,7 @@ def _raise_exception_fn(_=None):
 # Must be the argument to a distribution.call_for_each_replica() call, calls a
 # get_replica_context().merge_call() that raises an exception.
 def _merge_raises_fn():
-  distribution_strategy_context.get_replica_context().merge_call(
-      _raise_exception_fn)
+  ds_context.get_replica_context().merge_call(_raise_exception_fn)


 # Must be the argument to a get_replica_context().merge_call() call, calls
@@ -60,8 +59,7 @@ def _call_raises_fn(dist):
 # calls a get_replica_context().merge_call() that calls a
 # call_for_each_replica() that raises an exception.
 def _merge_call_raises_fn():
-  distribution_strategy_context.get_replica_context().merge_call(
-      _call_raises_fn)
+  ds_context.get_replica_context().merge_call(_call_raises_fn)


 # Must be the argument to a get_replica_context().merge_call() call, calls
@@ -75,8 +73,7 @@ def _call_merge_raises_fn(dist):
 # get_replica_context().merge_call() that calls a call_for_each_replica() that
 # calls a get_replica_context().merge_call() that raises an exception.
 def _merge_call_merge_raises_fn():
-  distribution_strategy_context.get_replica_context().merge_call(
-      _call_merge_raises_fn)
+  ds_context.get_replica_context().merge_call(_call_merge_raises_fn)


 class DistributionTestBase(test.TestCase):
@@ -193,8 +190,7 @@ class DistributionTestBase(test.TestCase):
       expected_devices = [False] * len(d.worker_devices)

       def mark_devices_fn():
-        replica_id = (
-            distribution_strategy_context.get_replica_context().replica_id)
+        replica_id = ds_context.get_replica_context().replica_id_in_sync_group
         self.assertLess(replica_id, len(d.worker_devices))
         self.assertFalse(expected_devices[replica_id])
         expected_devices[replica_id] = True
@@ -587,7 +587,7 @@ class _TPUReplicaContext(distribute_lib.ReplicaContext):
   # TODO(sourabhbajaj): Call for each tower should be updating this.
   def __init__(self, distribution_strategy):
     distribute_lib.ReplicaContext.__init__(
-        self, distribution_strategy, replica_id=0)
+        self, distribution_strategy, replica_id_in_sync_group=0)

   @property
   def device(self):
@@ -596,4 +596,5 @@ class _TPUReplicaContext(distribute_lib.ReplicaContext):
   @property
   def devices(self):
     distribute_lib.require_replica_context(self)
-    return [self._distribution_strategy.worker_devices[self._replica_id]]
+    ds = self._distribution_strategy
+    return [ds.worker_devices[self._replica_id_in_sync_group]]
@@ -50,7 +50,8 @@ def skip_summary():
   # alternatives to override default behavior. (e.g. run on last replica,
   # compute sum or mean across replicas).
   replica_context = distribution_strategy_context.get_replica_context()
-  return replica_context and replica_context.replica_id > 0
+  # TODO(cjfj): Also check is sync group ID > 0?
+  return replica_context and replica_context.replica_id_in_sync_group > 0


 def clean_tag(name):
@@ -814,7 +814,7 @@ class DistributionStrategy(object):
     """Run `fn` once per replica.

     `fn` may call `tf.get_replica_context()` to access methods such as
-    `replica_id()` and `merge_call()`.
+    `replica_id_in_sync_group` and `merge_call()`.

     `merge_call()` is used to communicate between the replicas and
     re-enter the cross-replica context. All replicas pause their execution
@@ -832,7 +832,7 @@ class DistributionStrategy(object):
     # Called once per replica in `distribution`, in a "replica" context.
     def fn(three):
       replica_ctx = tf.get_replica_context()
-      v = three + replica_ctx.replica_id
+      v = three + replica_ctx.replica_id_in_sync_group
       # Computes the sum of the `v` values across all replicas.
       s = replica_ctx.merge_call(merge_fn, args=(v,))
       return s + v
@@ -1182,11 +1182,11 @@
 class ReplicaContext(object):
   """DistributionStrategy API inside a `call_for_each_replica()` call."""

-  def __init__(self, distribution_strategy, replica_id):
+  def __init__(self, distribution_strategy, replica_id_in_sync_group):
     self._distribution_strategy = distribution_strategy
     self._thread_context = distribution_strategy_context._InReplicaThreadMode(  # pylint: disable=protected-access
         self)
-    self._replica_id = replica_id
+    self._replica_id_in_sync_group = replica_id_in_sync_group

   def __enter__(self):
     _push_per_thread_mode(self._thread_context)
@@ -1255,10 +1255,10 @@ class ReplicaContext(object):
     return self._distribution_strategy.num_replicas_in_sync

   @property
-  def replica_id(self):
+  def replica_id_in_sync_group(self):
     """Which replica is being defined, a number from 0 to `num_replicas - 1`."""
     require_replica_context(self)
-    return self._replica_id
+    return self._replica_id_in_sync_group

   @property
   def distribution_strategy(self):
@@ -1327,7 +1327,7 @@ class _DefaultDistributionStrategy(DistributionStrategy):
     raise NotImplementedError("TODO")

   def _call_for_each_replica(self, fn, args, kwargs):
-    with ReplicaContext(self, replica_id=0):
+    with ReplicaContext(self, replica_id_in_sync_group=0):
       return fn(*args, **kwargs)

   def _reduce(self, reduce_op, value, destinations):
@@ -41,7 +41,7 @@ def _get_test_variable(name, synchronization, aggregation):
 class _TestStrategy(distribute_lib.DistributionStrategy):

   def _call_for_each_replica(self, fn, args, kwargs):
-    with _TestReplicaContext(self, replica_id=0):
+    with _TestReplicaContext(self, replica_id_in_sync_group=0):
       return fn(*args, **kwargs)

   def _create_variable(self, next_creator, *args, **kwargs):
@@ -196,7 +196,7 @@ def _get_default_distribution_strategy():
 def _get_default_replica_context():
   if _defaults["replica_context"] is None:
     _defaults["replica_context"] = distribute_lib.ReplicaContext(
-        _get_default_distribution_strategy(), replica_id=0)
+        _get_default_distribution_strategy(), replica_id_in_sync_group=0)
   return _defaults["replica_context"]

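As a usage note, a minimal sketch adapted from the `call_for_each_replica` docstring updated above, showing the renamed property inside a replica function. Here `distribution`, `merge_fn`, and the exact `call_for_each_replica` invocation are assumptions drawn from that docstring pattern, not verified against this revision:

    # Called once per replica in `distribution`, in a "replica" context.
    def fn(three):
      replica_ctx = tf.get_replica_context()
      # Formerly `replica_ctx.replica_id`; renamed in this change.
      v = three + replica_ctx.replica_id_in_sync_group
      # Computes the sum of the `v` values across all replicas.
      s = replica_ctx.merge_call(merge_fn, args=(v,))
      return s + v

    # Assumed invocation, following the docstring's example.
    merge_result = distribution.call_for_each_replica(fn, args=(3,))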