Rename ReplicaContext.replica_id to replica_id_in_sync_group.

PiperOrigin-RevId: 221602399
This commit is contained in:
Chris Jones 2018-11-15 04:27:12 -08:00 committed by TensorFlower Gardener
parent 9a608da836
commit cfdfcc311c
12 changed files with 53 additions and 69 deletions

View File

@@ -47,7 +47,7 @@ from tensorflow.python.ops import variables
from tensorflow.python.platform import gfile
from tensorflow.python.platform import test
from tensorflow.python.summary.writer import writer_cache
from tensorflow.python.training import distribution_strategy_context
from tensorflow.python.training import distribution_strategy_context as ds_context
class KerasOptimizerV2IntegrationTest(test.TestCase, parameterized.TestCase):
@@ -272,9 +272,9 @@ class MirroredStrategyOptimizerV2Test(test.TestCase):
def _replica_id():
# TODO(cjfj): Return `replica_id` directly, once it is a `Tensor`.
# TODO(cjfj): Return `replica_id_...` directly, once it is a `Tensor`.
return constant_op.constant(
distribution_strategy_context.get_replica_context().replica_id)
ds_context.get_replica_context().replica_id_in_sync_group)
if __name__ == '__main__':

View File

@@ -829,4 +829,5 @@ class MirroredReplicaContext(distribute_lib.ReplicaContext):
@property
def devices(self):
distribute_lib.require_replica_context(self)
return [self._distribution_strategy.worker_devices[self._replica_id]]
ds = self._distribution_strategy
return [ds.worker_devices[self._replica_id_in_sync_group]]

View File

@@ -48,7 +48,7 @@ from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.training import device_util
from tensorflow.python.training import distribution_strategy_context
from tensorflow.python.training import distribution_strategy_context as ds_context
from tensorflow.python.training import gradient_descent
from tensorflow.python.training import optimizer as optimizer_lib
from tensorflow.python.training import server_lib
@@ -183,8 +183,7 @@ class MirroredStrategyVariableCreationTest(test.TestCase):
# This variable should be created only once across the threads because of
# special variable_creator functions used by `dist.call_for_each_replica`.
v = variable_scope.variable(1.0, name="foo")
distribution_strategy_context.get_replica_context().merge_call(
lambda _: _)
ds_context.get_replica_context().merge_call(lambda _: _)
return v
dist = mirrored_strategy.MirroredStrategy(
@@ -201,8 +200,7 @@ class MirroredStrategyVariableCreationTest(test.TestCase):
def model_fn():
v = variable_scope.variable(1.0)
distribution_strategy_context.get_replica_context().merge_call(
lambda _: _)
ds_context.get_replica_context().merge_call(lambda _: _)
return v
dist = mirrored_strategy.MirroredStrategy(
@@ -222,8 +220,7 @@ class MirroredStrategyVariableCreationTest(test.TestCase):
vs = []
for i in range(5):
vs.append(variable_scope.variable(1.0, name="foo" + str(i)))
distribution_strategy_context.get_replica_context().merge_call(
lambda _: _)
ds_context.get_replica_context().merge_call(lambda _: _)
return vs
dist = mirrored_strategy.MirroredStrategy(
@@ -245,8 +242,7 @@ class MirroredStrategyVariableCreationTest(test.TestCase):
vs.append(variable_scope.variable(1.0, name="foo_1/bar"))
vs.append(variable_scope.variable(1.0, name="foo_1/bar_1"))
vs.append(variable_scope.variable(1.0, name="foo/bar_1"))
distribution_strategy_context.get_replica_context().merge_call(
lambda _: _)
ds_context.get_replica_context().merge_call(lambda _: _)
return vs
dist = mirrored_strategy.MirroredStrategy(
@@ -269,8 +265,7 @@ class MirroredStrategyVariableCreationTest(test.TestCase):
def model_fn():
replica_id = self.evaluate(_replica_id())
v = variable_scope.variable(1.0, name="foo_" + str(replica_id))
distribution_strategy_context.get_replica_context().merge_call(
lambda _: _)
ds_context.get_replica_context().merge_call(lambda _: _)
return v
dist = mirrored_strategy.MirroredStrategy(
@@ -292,8 +287,7 @@ class MirroredStrategyVariableCreationTest(test.TestCase):
layer2 = core.Dense(1)
layer2(features)
# This will pause the current thread, and execute the other thread.
distribution_strategy_context.get_replica_context().merge_call(
lambda _: _)
ds_context.get_replica_context().merge_call(lambda _: _)
layer3 = core.Dense(1)
layer3(features)
return [(layer1.kernel, layer1.bias),
@@ -330,8 +324,7 @@ class MirroredStrategyVariableCreationTest(test.TestCase):
with variable_scope.variable_scope("common"):
v1 = variable_scope.variable(1.0, name="var1")
# This will pause the current thread, and execute the other thread.
distribution_strategy_context.get_replica_context().merge_call(
lambda _: _)
ds_context.get_replica_context().merge_call(lambda _: _)
v2 = variable_scope.variable(
1.0,
name="var2",
@@ -374,8 +367,7 @@ class MirroredStrategyVariableCreationTest(test.TestCase):
with variable_scope.variable_scope("common"):
v1 = variable_scope.get_variable("var1", [1])
# This will pause the current thread, and execute the other thread.
distribution_strategy_context.get_replica_context().merge_call(
lambda _: _)
ds_context.get_replica_context().merge_call(lambda _: _)
v2 = variable_scope.get_variable(
"var2", [1],
synchronization=variable_scope.VariableSynchronization.ON_READ,
@@ -564,8 +556,7 @@ class MirroredStrategyVariableCreationTest(test.TestCase):
def model_fn():
v = variable_scope.variable(1.0, name="foo")
distribution_strategy_context.get_replica_context().merge_call(
lambda _: _)
ds_context.get_replica_context().merge_call(lambda _: _)
return v
dist = mirrored_strategy.MirroredStrategy(
@@ -582,8 +573,7 @@ class MirroredStrategyVariableCreationTest(test.TestCase):
def model_fn(name):
v = variable_scope.variable(1.0, name=name)
distribution_strategy_context.get_replica_context().merge_call(
lambda _: _)
ds_context.get_replica_context().merge_call(lambda _: _)
return v
dist = mirrored_strategy.MirroredStrategy(
@@ -683,8 +673,7 @@ class MirroredStrategyVariableCreationTest(test.TestCase):
def model_fn():
with ops.name_scope("foo"):
a = constant_op.constant(1.0, name="a")
distribution_strategy_context.get_replica_context().merge_call(
lambda _: _)
ds_context.get_replica_context().merge_call(lambda _: _)
b = constant_op.constant(1.0, name="b")
return a, b
@@ -705,8 +694,7 @@ class MirroredStrategyVariableCreationTest(test.TestCase):
def model_fn():
with ops.name_scope(None, "foo"):
a = constant_op.constant(1.0, name="a")
distribution_strategy_context.get_replica_context().merge_call(
lambda _: _)
ds_context.get_replica_context().merge_call(lambda _: _)
b = constant_op.constant(2.0, name="b")
return a, b
@@ -734,8 +722,7 @@ class MirroredStrategyVariableCreationTest(test.TestCase):
def model_fn():
b = variable_scope.variable(1.0, name="b")
with ops.name_scope("foo"):
c = distribution_strategy_context.get_replica_context().merge_call(
in_cross_replica)
c = ds_context.get_replica_context().merge_call(in_cross_replica)
return b, c
dist = mirrored_strategy.MirroredStrategy(
@@ -767,8 +754,7 @@ class MirroredStrategyVariableCreationTest(test.TestCase):
def model_fn():
b = variable_scope.get_variable("b", [1])
with ops.name_scope("foo"):
c = distribution_strategy_context.get_replica_context().merge_call(
in_cross_replica)
c = ds_context.get_replica_context().merge_call(in_cross_replica)
return b, c
dist = mirrored_strategy.MirroredStrategy(
@@ -951,7 +937,7 @@ class MirroredVariableUpdateTest(test.TestCase):
def model_fn():
value = math_ops.cast(
distribution_strategy_context.get_replica_context().replica_id,
ds_context.get_replica_context().replica_id_in_sync_group,
mirrored_var.dtype)
return mirrored_var.assign(value)
@@ -1025,7 +1011,7 @@ class MirroredVariableUpdateTest(test.TestCase):
def model_fn():
value = math_ops.cast(
distribution_strategy_context.get_replica_context().replica_id,
ds_context.get_replica_context().replica_id_in_sync_group,
mirrored_var.dtype)
return mirrored_var.assign_add(value)
@@ -1091,7 +1077,7 @@ class MirroredVariableUpdateTest(test.TestCase):
def model_fn():
value = math_ops.cast(
distribution_strategy_context.get_replica_context().replica_id,
ds_context.get_replica_context().replica_id_in_sync_group,
mirrored_var.dtype)
return mirrored_var.assign_sub(value)
@@ -1463,9 +1449,9 @@ class MultiWorkerMirroredStrategyTestWithChief(
def _replica_id():
# TODO(cjfj): Return `replica_id` directly, once it is a `Tensor`.
# TODO(cjfj): Return `replica_id_...` directly, once it is a `Tensor`.
return constant_op.constant(
distribution_strategy_context.get_replica_context().replica_id)
ds_context.get_replica_context().replica_id_in_sync_group)
if __name__ == "__main__":

View File

@@ -26,7 +26,7 @@ from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import variable_scope
from tensorflow.python.training import distribution_strategy_context
from tensorflow.python.training import distribution_strategy_context as ds_context
class MirroredOneCPUDistributionTest(strategy_test_lib.DistributionTestBase):
@@ -87,8 +87,7 @@ class VariableCreatorStackTest(test.TestCase):
v = variable_scope.variable(1.0)
# This will pause the current thread, and execute the other thread.
distribution_strategy_context.get_replica_context().merge_call(
lambda _: _)
ds_context.get_replica_context().merge_call(lambda _: _)
return v
def main_thread_creator(next_creator, *args, **kwargs):
@@ -106,9 +105,9 @@ class VariableCreatorStackTest(test.TestCase):
def _replica_id():
# TODO(cjfj): Return `replica_id` directly, once it is a `Tensor`.
# TODO(cjfj): Return `replica_id_...` directly, once it is a `Tensor`.
return constant_op.constant(
distribution_strategy_context.get_replica_context().replica_id)
ds_context.get_replica_context().replica_id_in_sync_group)
class MultiWorkerMirroredStrategyTest(test.TestCase):

View File

@@ -180,7 +180,7 @@ class _OneDeviceReplicaContext(distribute_lib.ReplicaContext):
def __init__(self, distribution_strategy):
distribute_lib.ReplicaContext.__init__(
self, distribution_strategy, replica_id=0)
self, distribution_strategy, replica_id_in_sync_group=0)
@property
def device(self):

View File

@@ -44,7 +44,7 @@ from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.training import device_util
from tensorflow.python.training import distribution_strategy_context
from tensorflow.python.training import distribution_strategy_context as ds_context
from tensorflow.python.training import training_util
CHIEF = run_config.TaskType.CHIEF
@@ -98,7 +98,7 @@ class ParameterServerStrategyTestBase(
else:
last_part_device = (
'device:GPU:%d' %
distribution_strategy_context.get_replica_context().replica_id)
ds_context.get_replica_context().replica_id_in_sync_group)
a = constant_op.constant(1.0)
b = constant_op.constant(2.0)
@@ -265,7 +265,7 @@ class ParameterServerStrategyTestBase(
else:
replica_compute_device = (
'/device:GPU:%d' %
distribution_strategy_context.get_replica_context().replica_id)
ds_context.get_replica_context().replica_id_in_sync_group)
replica_compute_device = device_util.canonicalize(
replica_compute_device)
@@ -274,7 +274,7 @@ class ParameterServerStrategyTestBase(
else:
replica_variable_device = (
'/device:GPU:%d' %
distribution_strategy_context.get_replica_context().replica_id)
ds_context.get_replica_context().replica_id_in_sync_group)
replica_variable_device = device_util.canonicalize(
replica_variable_device)

View File

@@ -29,7 +29,7 @@ from tensorflow.python.framework import ops
from tensorflow.python.layers import core
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import variables
from tensorflow.python.training import distribution_strategy_context
from tensorflow.python.training import distribution_strategy_context as ds_context
from tensorflow.python.training import optimizer
@@ -46,8 +46,7 @@ def _raise_exception_fn(_=None):
# Must be the argument to a distribution.call_for_each_replica() call, calls a
# get_replica_context().merge_call() that raises an exception.
def _merge_raises_fn():
distribution_strategy_context.get_replica_context().merge_call(
_raise_exception_fn)
ds_context.get_replica_context().merge_call(_raise_exception_fn)
# Must be the argument to a get_replica_context().merge_call() call, calls
@@ -60,8 +59,7 @@ def _call_raises_fn(dist):
# calls a get_replica_context().merge_call() that calls a
# call_for_each_replica() that raises an exception.
def _merge_call_raises_fn():
distribution_strategy_context.get_replica_context().merge_call(
_call_raises_fn)
ds_context.get_replica_context().merge_call(_call_raises_fn)
# Must be the argument to a get_replica_context().merge_call() call, calls
@@ -75,8 +73,7 @@ def _call_merge_raises_fn(dist):
# get_replica_context().merge_call() that calls a call_for_each_replica() that
# calls a get_replica_context().merge_call() that raises an exception.
def _merge_call_merge_raises_fn():
distribution_strategy_context.get_replica_context().merge_call(
_call_merge_raises_fn)
ds_context.get_replica_context().merge_call(_call_merge_raises_fn)
class DistributionTestBase(test.TestCase):
@@ -193,8 +190,7 @@ class DistributionTestBase(test.TestCase):
expected_devices = [False] * len(d.worker_devices)
def mark_devices_fn():
replica_id = (
distribution_strategy_context.get_replica_context().replica_id)
replica_id = ds_context.get_replica_context().replica_id_in_sync_group
self.assertLess(replica_id, len(d.worker_devices))
self.assertFalse(expected_devices[replica_id])
expected_devices[replica_id] = True

View File

@@ -587,7 +587,7 @@ class _TPUReplicaContext(distribute_lib.ReplicaContext):
# TODO(sourabhbajaj): Call for each tower should be updating this.
def __init__(self, distribution_strategy):
distribute_lib.ReplicaContext.__init__(
self, distribution_strategy, replica_id=0)
self, distribution_strategy, replica_id_in_sync_group=0)
@property
def device(self):
@@ -596,4 +596,5 @@ class _TPUReplicaContext(distribute_lib.ReplicaContext):
@property
def devices(self):
distribute_lib.require_replica_context(self)
return [self._distribution_strategy.worker_devices[self._replica_id]]
ds = self._distribution_strategy
return [ds.worker_devices[self._replica_id_in_sync_group]]

View File

@@ -50,7 +50,8 @@ def skip_summary():
# alternatives to override default behavior. (e.g. run on last replica,
# compute sum or mean across replicas).
replica_context = distribution_strategy_context.get_replica_context()
return replica_context and replica_context.replica_id > 0
# TODO(cjfj): Also check is sync group ID > 0?
return replica_context and replica_context.replica_id_in_sync_group > 0
def clean_tag(name):

View File

@@ -814,7 +814,7 @@ class DistributionStrategy(object):
"""Run `fn` once per replica.
`fn` may call `tf.get_replica_context()` to access methods such as
`replica_id()` and `merge_call()`.
`replica_id_in_sync_group` and `merge_call()`.
`merge_call()` is used to communicate between the replicas and
re-enter the cross-replica context. All replicas pause their execution
@@ -832,7 +832,7 @@ class DistributionStrategy(object):
# Called once per replica in `distribution`, in a "replica" context.
def fn(three):
replica_ctx = tf.get_replica_context()
v = three + replica_ctx.replica_id
v = three + replica_ctx.replica_id_in_sync_group
# Computes the sum of the `v` values across all replicas.
s = replica_ctx.merge_call(merge_fn, args=(v,))
return s + v
@@ -1182,11 +1182,11 @@ class DistributionStrategy(object):
class ReplicaContext(object):
"""DistributionStrategy API inside a `call_for_each_replica()` call."""
def __init__(self, distribution_strategy, replica_id):
def __init__(self, distribution_strategy, replica_id_in_sync_group):
self._distribution_strategy = distribution_strategy
self._thread_context = distribution_strategy_context._InReplicaThreadMode( # pylint: disable=protected-access
self)
self._replica_id = replica_id
self._replica_id_in_sync_group = replica_id_in_sync_group
def __enter__(self):
_push_per_thread_mode(self._thread_context)
@@ -1255,10 +1255,10 @@ class ReplicaContext(object):
return self._distribution_strategy.num_replicas_in_sync
@property
def replica_id(self):
def replica_id_in_sync_group(self):
"""Which replica is being defined, a number from 0 to `num_replicas - 1`."""
require_replica_context(self)
return self._replica_id
return self._replica_id_in_sync_group
@property
def distribution_strategy(self):
@@ -1327,7 +1327,7 @@ class _DefaultDistributionStrategy(DistributionStrategy):
raise NotImplementedError("TODO")
def _call_for_each_replica(self, fn, args, kwargs):
with ReplicaContext(self, replica_id=0):
with ReplicaContext(self, replica_id_in_sync_group=0):
return fn(*args, **kwargs)
def _reduce(self, reduce_op, value, destinations):

View File

@@ -41,7 +41,7 @@ def _get_test_variable(name, synchronization, aggregation):
class _TestStrategy(distribute_lib.DistributionStrategy):
def _call_for_each_replica(self, fn, args, kwargs):
with _TestReplicaContext(self, replica_id=0):
with _TestReplicaContext(self, replica_id_in_sync_group=0):
return fn(*args, **kwargs)
def _create_variable(self, next_creator, *args, **kwargs):

View File

@@ -196,7 +196,7 @@ def _get_default_distribution_strategy():
def _get_default_replica_context():
if _defaults["replica_context"] is None:
_defaults["replica_context"] = distribute_lib.ReplicaContext(
_get_default_distribution_strategy(), replica_id=0)
_get_default_distribution_strategy(), replica_id_in_sync_group=0)
return _defaults["replica_context"]