Enable assign, assign_add and assign_sub to be called on Mirrored Variables in cross tower and tower context.
PiperOrigin-RevId: 202162272
parent d10213099d
commit bfda539bef
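In effect, the change allows code like the following (a minimal sketch against the contrib-era API used in this diff, not committed code; the device list mirrors the tests below, and the private _aggregation_method attribute is set directly because a public API for it is still a TODO in the tests):

    from tensorflow.contrib.distribute.python import mirrored_strategy
    from tensorflow.python.ops import variable_scope

    def var_fn():
      return variable_scope.variable(1.0, name="foo")

    dist = mirrored_strategy.MirroredStrategy(["/device:GPU:0", "/device:CPU:0"])
    with dist.scope():
      mirrored_var = dist.call_for_each_tower(var_fn, run_concurrently=False)

      # Cross tower context: the update is applied to every component variable.
      mirrored_var.assign(6.0)

      # Tower context: an aggregation method must be set first so that the
      # per-tower values can be reduced before each component is updated.
      mirrored_var._aggregation_method = "mean"
      dist.call_for_each_tower(lambda: mirrored_var.assign_add(1.0),
                               run_concurrently=False)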
tensorflow/contrib/distribute/python/cross_tower_ops.py:

@@ -32,7 +32,7 @@ from tensorflow.python.platform import tf_logging as logging
 from tensorflow.python.training import device_util
 
 
-def _validate_destinations(destinations):
+def validate_destinations(destinations):
   if not isinstance(destinations,
                     (value_lib.DistributedValues, six.string_types, list)):
     raise ValueError("destinations must be one of a `DistributedValues` object,"
@@ -55,7 +55,7 @@ def _validate_value_destination_pairs(value_destination_pairs):
 
 
 # TODO(yuefengz): consider calling this function in the caller of CrossTowerOps.
-def _get_devices_from(destinations):
+def get_devices_from(destinations):
   if isinstance(destinations, value_lib.DistributedValues):
     return list(destinations.devices)
   elif isinstance(destinations, six.string_types):
@@ -65,7 +65,7 @@ def _get_devices_from(destinations):
 
 
 def _devices_match(left, right):
-  return set(_get_devices_from(left)) == set(_get_devices_from(right))
+  return set(get_devices_from(left)) == set(get_devices_from(right))
 
 
 def _all_devices_match(value_destination_pairs):
@@ -80,7 +80,7 @@ def _all_devices_match(value_destination_pairs):
 
 def _simple_broadcast(value, destinations):
   index = {}
-  devices = _get_devices_from(destinations)
+  devices = get_devices_from(destinations)
   for d in devices:
     index[d] = cross_tower_utils.copy_tensor_or_indexed_slices_to_device(
         value, d)
@@ -146,7 +146,7 @@ class CrossTowerOps(object):
     if not isinstance(per_device_value, value_lib.PerDevice):
       raise ValueError("`per_device_value` must be a `PerDevice` object.")
     if destinations is not None:
-      _validate_destinations(destinations)
+      validate_destinations(destinations)
     return self._reduce(method_string, per_device_value, destinations)
 
   def batch_reduce(self, method_string, value_destination_pairs):
@@ -173,7 +173,7 @@ class CrossTowerOps(object):
           "tuples of PerDevice objects and destinations")
     for _, d in value_destination_pairs:
       if d is not None:
-        _validate_destinations(d)
+        validate_destinations(d)
 
     return self._batch_reduce(method_string, value_destination_pairs)
 
@@ -187,7 +187,7 @@ class CrossTowerOps(object):
     Returns:
       a Mirrored object.
     """
-    _validate_destinations(destinations)
+    validate_destinations(destinations)
     return self._broadcast(tensor, destinations)
 
   def _reduce(self, method_string, per_device_value, destinations):
@@ -221,7 +221,7 @@ class ReductionToOneDeviceCrossTowerOps(CrossTowerOps):
     super(ReductionToOneDeviceCrossTowerOps, self).__init__()
 
   def _reduce(self, method_string, per_device_value, destinations):
-    devices = _get_devices_from(destinations or per_device_value)
+    devices = get_devices_from(destinations or per_device_value)
     reduce_to_device = self.reduce_to_device or devices[0]
     reduced = _simple_reduce(per_device_value, reduce_to_device,
                              self.accumulation_fn, method_string)
@@ -501,7 +501,7 @@ class AllReduceCrossTowerOps(CrossTowerOps):
           logging.WARN,
           "Efficient allreduce is not supported for IndexedSlices.", 10)
 
-      devices = _get_devices_from(destinations or per_device_value)
+      devices = get_devices_from(destinations or per_device_value)
      reduce_to_device = devices[0]
      reduced = _simple_reduce(per_device_value, reduce_to_device,
                               math_ops.add_n, method_string)
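The dropped leading underscores are the point of this file's changes: validate_destinations and get_devices_from are no longer module-private, since the reworked MirroredStrategy._reduce in mirrored_strategy.py below calls them across the module boundary, as do the updated tests.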
tensorflow/contrib/distribute/python/cross_tower_ops_test.py:

@@ -36,7 +36,7 @@ from tensorflow.python.training import device_util
 
 
 def _make_per_device(values, devices):
-  devices = cross_tower_ops_lib._get_devices_from(devices)
+  devices = cross_tower_ops_lib.get_devices_from(devices)
   assert len(values) == len(devices)
   index = {}
   for d, v in zip(devices, values):
@@ -53,7 +53,7 @@ def _fake_mirrored(value, devices):
   All components of the returned Mirrored have the same objects, which is not
   true in reality.
   """
-  devices = cross_tower_ops_lib._get_devices_from(devices)
+  devices = cross_tower_ops_lib.get_devices_from(devices)
   return value_lib.Mirrored(
       {d: v for d, v in zip(devices, [value] * len(devices))})
 
tensorflow/contrib/distribute/python/mirrored_strategy.py:

@@ -309,9 +309,29 @@ class MirroredStrategy(distribute_lib.DistributionStrategy):
     return self._cross_tower_ops
 
   def _reduce(self, method_string, value, destinations):
-    if len(self._devices) == 1 and not isinstance(value, values.PerDevice):
-      value = values.PerDevice({self._devices[0]: value})
-    assert isinstance(value, values.PerDevice)
+    assert not isinstance(value, values.Mirrored)
+    if not isinstance(value, values.PerDevice):
+      if value == 0:
+        return 0
+      if method_string == "mean":
+        return self._broadcast(value, destinations)
+
+      cross_tower_ops_lib.validate_destinations(destinations)
+      if len(self._devices) == 1:
+        if destinations:
+          # TODO(anjalisridhar): Move these methods to a device utility file?
+          devices = cross_tower_ops_lib.get_devices_from(destinations)
+          if len(devices) == 1:
+            with ops.device(devices[0]):
+              return array_ops.identity(value)
+          else:
+            value_updates = {}
+            for d in devices:
+              with ops.device(d):
+                value_updates[d] = array_ops.identity(value)
+            return values.Mirrored(value_updates)
+      raise ValueError("A non PerDevice value cannot be reduced with the given "
+                       "method_string.")
 
     return self._get_cross_tower_ops().reduce(
         method_string, value, destinations=destinations)
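The shape of the new non-PerDevice branch, summarized (an informal reading of the code above, not additional committed code): a bare Python or tensor value is identical on every tower, so a "mean" reduction is the value itself and is simply broadcast; with a single tower the same holds for "sum", which is why the one-device path returns identity copies wrapped as a Mirrored value; with multiple towers a bare value cannot be meaningfully summed, so the method raises. For example:

    # Single-device strategy, as in testReduceToMultipleDestinations below:
    dist = mirrored_strategy.MirroredStrategy(["/device:GPU:0"])
    with dist.scope():
      # A Mirrored value holding identity copies of 1.0 on both destinations.
      reduced = dist.reduce("sum", 1.0,
                            destinations=["/device:CPU:0", "/device:GPU:0"])

    # With a two-device strategy, "sum" of a bare value instead raises
    # "A non PerDevice value cannot be reduced with the given method_string."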
tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py:

@@ -32,12 +32,14 @@ from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import test_util
 from tensorflow.python.layers import core
+from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import rnn
 from tensorflow.python.ops import rnn_cell_impl
 from tensorflow.python.ops import variable_scope
 from tensorflow.python.ops import variables
+from tensorflow.python.training import distribute as distribute_lib
 
 
 GPU_TEST = "test_gpu" in sys.argv[0]
 
 
@@ -118,6 +120,24 @@ class MirroredTwoDeviceDistributionTest(strategy_test_lib.DistributionTestBase):
     expected = sum(range(len(dist.worker_devices)))
     self.assertEqual(expected, self.evaluate(unwrapped[0]))
 
+  @test_util.run_in_graph_and_eager_modes()
+  def testReduceToMultipleDestinations(self):
+    if not GPU_TEST:
+      self.skipTest("Not GPU test")
+
+    devices = ["/device:GPU:0"]
+    if GPU_TEST:
+      self.assertGreater(context.num_gpus(), 0)
+    print(self.id().split(".")[-1], "devices:", ", ".join(devices))
+
+    dist = mirrored_strategy.MirroredStrategy(devices)
+    with dist.scope():
+      reduced = dist.reduce("sum", 1.0, destinations=["/device:CPU:0",
+                                                      "/device:GPU:0"])
+      unwrapped = dist.unwrap(reduced)
+      self.assertEqual(2, len(unwrapped))
+      self.assertEqual(1.0, self.evaluate(unwrapped[0]))
+
 
 class MirroredStrategyVariableCreationTest(test.TestCase):
 
@@ -581,5 +601,201 @@ class MirroredStrategyVariableCreationTest(test.TestCase):
     self.assertEquals(10.0, self.evaluate(ret_v_sum))
 
 
+class MirroredVariableUpdateTest(test.TestCase):
+  # The following tests check assign, assign_add and assign_sub on Mirrored
+  # variables in tower and cross tower context.
+  config = config_pb2.ConfigProto()
+  config.allow_soft_placement = True
+
+  def _skip_eager_if_gpus_less_than(self, num_gpus):
+    if context.num_gpus() < num_gpus and context.executing_eagerly():
+      self.skipTest("Enough GPUs not available for this test in eager mode.")
+
+  @test_util.run_in_graph_and_eager_modes(config=config)
+  def testAssignMirroredVarTowerContextWithoutAggregationType(self):
+    # Test that we always have an aggregation type set on the mirrored variable
+    # if we assign to it in tower mode.
+    self._skip_eager_if_gpus_less_than(1)
+    def var_fn():
+      v = variable_scope.variable(1.0, name="foo")
+      return v
+
+    dist = mirrored_strategy.MirroredStrategy(
+        ["/device:GPU:0", "/device:CPU:0"])
+
+    with dist.scope():
+      mirrored_var = dist.call_for_each_tower(var_fn, run_concurrently=False)
+      self.assertIsInstance(mirrored_var, values.MirroredVariable)
+      self.evaluate(variables.global_variables_initializer())
+
+      def model_fn():
+        return mirrored_var.assign(5.0)
+
+      with self.assertRaisesRegexp(
+          ValueError, "You must specify an aggregation method to update a "
+          "MirroredVariable in Tower Context."):
+        self.evaluate(dist.unwrap(dist.call_for_each_tower(model_fn)))
+
+  @test_util.run_in_graph_and_eager_modes(config=config)
+  def testAssignMirroredVarTowerContextWithSum(self):
+    # Test that we don't reduce a non-per-device value with the "sum"
+    # aggregation type.
+    self._skip_eager_if_gpus_less_than(1)
+    def var_fn():
+      v = variable_scope.variable(1.0, name="foo")
+      return v
+
+    dist = mirrored_strategy.MirroredStrategy(
+        ["/device:GPU:0", "/device:CPU:0"])
+
+    with dist.scope():
+      mirrored_var = dist.call_for_each_tower(var_fn, run_concurrently=False)
+      # TODO(anjalisridhar): Use API introduced in cr/201463945 to set the
+      # aggregation method.
+      mirrored_var._aggregation_method = "sum"
+      self.assertIsInstance(mirrored_var, values.MirroredVariable)
+      self.evaluate(variables.global_variables_initializer())
+
+      def model_fn():
+        return mirrored_var.assign(5.0)
+
+      with self.assertRaisesRegexp(
+          ValueError, "A non PerDevice value cannot be reduced with the given "
+          "method_string."):
+        self.evaluate(dist.unwrap(dist.call_for_each_tower(model_fn)))
+
+  @test_util.run_in_graph_and_eager_modes(config=config)
+  def testAssignMirroredVarCrossTowerContext(self):
+    self._skip_eager_if_gpus_less_than(1)
+    def var_fn():
+      return variable_scope.variable(1.0, name="foo")
+
+    dist = mirrored_strategy.MirroredStrategy(
+        ["/device:GPU:0", "/device:CPU:0"])
+
+    with dist.scope():
+      mirrored_var = dist.call_for_each_tower(var_fn, run_concurrently=False)
+      self.assertIsInstance(mirrored_var, values.MirroredVariable)
+      self.evaluate(variables.global_variables_initializer())
+      self.assertEquals(1.0, self.evaluate(mirrored_var))
+      mirrored_var_result = self.evaluate(mirrored_var.assign(6.0))
+      self.assertEquals(6.0, mirrored_var_result)
+
+  @test_util.run_in_graph_and_eager_modes(config=config)
+  def testAssignMirroredVarTowerContext(self):
+    self._skip_eager_if_gpus_less_than(1)
+    def var_fn():
+      return variable_scope.variable(1.0, name="foo")
+
+    dist = mirrored_strategy.MirroredStrategy(
+        ["/device:GPU:0", "/device:CPU:0"])
+
+    with dist.scope():
+      mirrored_var = dist.call_for_each_tower(var_fn, run_concurrently=False)
+      # TODO(anjalisridhar): Use API introduced in cr/201463945 to set the
+      # aggregation method.
+      mirrored_var._aggregation_method = "mean"
+      self.assertIsInstance(mirrored_var, values.MirroredVariable)
+      self.evaluate(variables.global_variables_initializer())
+      self.assertEquals(1.0, self.evaluate(mirrored_var))
+
+      def model_fn():
+        value = math_ops.cast(distribute_lib.get_tower_context().tower_id,
+                              mirrored_var.dtype)
+        return mirrored_var.assign(value)
+
+      self.evaluate(dist.unwrap(dist.call_for_each_tower(
+          model_fn, run_concurrently=False)))
+      self.assertEquals(0.5, self.evaluate(mirrored_var))
+
+  @test_util.run_in_graph_and_eager_modes(config=config)
+  def testAssignAddMirroredVarCrossTowerContext(self):
+    self._skip_eager_if_gpus_less_than(1)
+    def var_fn():
+      return variable_scope.variable(1.0, name="foo")
+
+    dist = mirrored_strategy.MirroredStrategy(
+        ["/device:GPU:0", "/device:CPU:0"])
+
+    with dist.scope():
+      mirrored_var = dist.call_for_each_tower(var_fn, run_concurrently=False)
+      self.assertIsInstance(mirrored_var, values.MirroredVariable)
+      self.evaluate(variables.global_variables_initializer())
+      self.assertEquals(1.0, self.evaluate(mirrored_var))
+      mirrored_var_result = self.evaluate(mirrored_var.assign_add(6.0))
+      self.assertEquals(7.0, mirrored_var_result)
+
+  @test_util.run_in_graph_and_eager_modes(config=config)
+  def testAssignAddMirroredVarTowerContext(self):
+    self._skip_eager_if_gpus_less_than(1)
+    def var_fn():
+      return variable_scope.variable(1.0, name="foo")
+
+    dist = mirrored_strategy.MirroredStrategy(
+        ["/device:GPU:0", "/device:CPU:0"])
+
+    with dist.scope():
+      mirrored_var = dist.call_for_each_tower(var_fn, run_concurrently=False)
+      # TODO(anjalisridhar): Use API introduced in cr/201463945 to set the
+      # aggregation method.
+      mirrored_var._aggregation_method = "mean"
+      self.assertIsInstance(mirrored_var, values.MirroredVariable)
+      self.evaluate(variables.global_variables_initializer())
+      self.assertEquals(1.0, self.evaluate(mirrored_var))
+
+      def model_fn():
+        value = math_ops.cast(distribute_lib.get_tower_context().tower_id,
+                              mirrored_var.dtype)
+        return mirrored_var.assign_add(value)
+
+      self.evaluate(dist.unwrap(dist.call_for_each_tower(
+          model_fn, run_concurrently=False)))
+      self.assertEquals(1.5, self.evaluate(mirrored_var))
+
+  @test_util.run_in_graph_and_eager_modes(config=config)
+  def testAssignSubMirroredVarCrossTowerContext(self):
+    self._skip_eager_if_gpus_less_than(1)
+    def var_fn():
+      return variable_scope.variable(5.0, name="foo")
+
+    dist = mirrored_strategy.MirroredStrategy(
+        ["/device:GPU:0", "/device:CPU:0"])
+
+    with dist.scope():
+      mirrored_var = dist.call_for_each_tower(var_fn, run_concurrently=False)
+      self.assertIsInstance(mirrored_var, values.MirroredVariable)
+      self.evaluate(variables.global_variables_initializer())
+      self.assertEquals(5.0, self.evaluate(mirrored_var))
+      mirrored_var_result = self.evaluate(mirrored_var.assign_sub(2.0))
+      self.assertEquals(3.0, mirrored_var_result)
+
+  @test_util.run_in_graph_and_eager_modes(config=config)
+  def testAssignSubMirroredVarTowerContext(self):
+    self._skip_eager_if_gpus_less_than(1)
+    def var_fn():
+      return variable_scope.variable(5.0, name="foo")
+
+    dist = mirrored_strategy.MirroredStrategy(
+        ["/device:GPU:0", "/device:CPU:0"])
+
+    with dist.scope():
+      mirrored_var = dist.call_for_each_tower(var_fn, run_concurrently=False)
+      # TODO(anjalisridhar): Use API introduced in cr/201463945 to set the
+      # aggregation method.
+      mirrored_var._aggregation_method = "mean"
+      self.assertIsInstance(mirrored_var, values.MirroredVariable)
+      self.evaluate(variables.global_variables_initializer())
+      self.assertEquals(5.0, self.evaluate(mirrored_var))
+
+      def model_fn():
+        value = math_ops.cast(distribute_lib.get_tower_context().tower_id,
+                              mirrored_var.dtype)
+        return mirrored_var.assign_sub(value)
+
+      self.evaluate(dist.unwrap(dist.call_for_each_tower(
+          model_fn, run_concurrently=False)))
+      self.assertEquals(4.5, self.evaluate(mirrored_var))
+
+
 if __name__ == "__main__":
   test.main()
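A note on the tower-context expectations above: with two towers, tower_id is 0 on the first and 1 on the second, so the "mean"-aggregated update value is (0.0 + 1.0) / 2 = 0.5. assign therefore leaves the variable at 0.5, assign_add at 1.0 + 0.5 = 1.5, and assign_sub at 5.0 - 0.5 = 4.5, which is exactly what the assertions check.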
tensorflow/contrib/distribute/python/values.py:

@@ -23,7 +23,6 @@ from __future__ import print_function
 
 import collections
 import weakref
-
 import six
 
 from tensorflow.contrib.distribute.python import input_ops
@@ -34,6 +33,7 @@ from tensorflow.python.framework import ops
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import state_ops
 from tensorflow.python.training import device_util
 from tensorflow.python.training import distribute as distribute_lib
 from tensorflow.python.training import saver
@@ -251,21 +251,6 @@ class DistributedVariable(DistributedDelegate):
 ops.register_dense_tensor_like_type(DistributedVariable)
 
 
-class _MirroredSaveable(saver.BaseSaverBuilder.ResourceVariableSaveable):
-  """Class for defining how to restore a MirroredVariable."""
-
-  def __init__(self, mirrored_variable, primary_variable, name):
-    self._mirrored_variable = mirrored_variable
-    super(_MirroredSaveable, self).__init__(primary_variable, "", name)
-
-  def restore(self, restored_tensors, restored_shapes):
-    """Restore the same value into all variables."""
-    tensor, = restored_tensors
-    return control_flow_ops.group([
-        _assign_on_device(d, v, tensor)
-        for d, v in six.iteritems(self._mirrored_variable._index)])  # pylint: disable=protected-access
-
-
 def _get_update_device():
   """Validate we are in update/update_non_slot() and return current device.
 
@@ -286,34 +271,82 @@ def _get_update_device():
   return device
 
 
+class _MirroredSaveable(saver.BaseSaverBuilder.ResourceVariableSaveable):
+  """Class for defining how to restore a MirroredVariable."""
+
+  def __init__(self, mirrored_variable, primary_variable, name):
+    self._mirrored_variable = mirrored_variable
+    super(_MirroredSaveable, self).__init__(primary_variable, "", name)
+
+  def restore(self, restored_tensors, restored_shapes):
+    """Restore the same value into all variables."""
+    tensor, = restored_tensors
+    return control_flow_ops.group([
+        _assign_on_device(d, v, tensor)
+        for d, v in six.iteritems(self._mirrored_variable._index)])  # pylint: disable=protected-access
+
+
 class MirroredVariable(DistributedVariable, Mirrored,
                        checkpointable.CheckpointableBase):
   """Holds a map from device to variables whose values are kept in sync."""
 
-  def __init__(self, index, primary_var):
+  def __init__(self, index, primary_var, aggregation_method=None):
     # Use a weakref to make it easy to map from the contained values
     # to the container without introducing a reference cycle.
     for v in six.itervalues(index):
       v._mirrored_container = weakref.ref(self)  # pylint: disable=protected-access
     self._primary_var = primary_var
+    self._aggregation_method = aggregation_method
     super(MirroredVariable, self).__init__(index)
 
-  # We use _get_update_device() for the assign* methods to enforce
-  # that we are in an update() function. The arguments to update() are
-  # automatically unwrapped so the update() function would normally
-  # see regular variables, not MirroredVariables. However, the update
-  # function can still operate on wrapped MirroredVariables through
-  # object members, captured arguments, etc. This is more likely in an
+  # The arguments to update() are automatically unwrapped so the update()
+  # function would normally see regular variables, not MirroredVariables.
+  # However, the update function can still operate on wrapped MirroredVariables
+  # through object members, captured arguments, etc. This is more likely in an
   # update_non_slot() function (like OptimizerV2._finish), which can
   # update several non-slot variables in one call.
+  def _assign_func(self, *args, **kwargs):
+    f = kwargs.pop("f")
+    if distribute_lib.get_cross_tower_context():
+      update_device = distribute_lib.get_update_device()
+      # We are calling update on the mirrored variable in cross tower context.
+      if update_device is not None:
+        # We are calling an assign function on the mirrored variable in cross
+        # tower context.
+        v = self.get(device=update_device)
+        return f(v, *args, **kwargs)
+
+      return distribute_lib.get_distribution_strategy().update(
+          self, f, *args, **kwargs)
+    else:
+      # We are calling an assign function on the mirrored variable in tower
+      # context.
+      # We reduce the value we want to assign/add/sub. More details about how we
+      # handle the different use cases can be found in the _reduce method.
+      # We call the function on each of the mirrored variables with the reduced
+      # value.
+      if not self._aggregation_method:
+        raise ValueError("You must specify an aggregation method to update a "
+                         "MirroredVariable in Tower Context.")
+
+      def merge_fn(strategy, value):
+        return strategy.update(self,
+                               f,
+                               strategy.reduce(
+                                   method_string=self._aggregation_method,
+                                   value=value,
+                                   destinations=self))
+      return distribute_lib.get_tower_context().merge_call(merge_fn, *args,
+                                                           **kwargs)
 
   def assign_sub(self, *args, **kwargs):
-    return self.get(device=_get_update_device()).assign_sub(*args, **kwargs)
+    return self._assign_func(f=state_ops.assign_sub, *args, **kwargs)
 
   def assign_add(self, *args, **kwargs):
-    return self.get(device=_get_update_device()).assign_add(*args, **kwargs)
+    return self._assign_func(f=state_ops.assign_add, *args, **kwargs)
 
   def assign(self, *args, **kwargs):
-    return self.get(device=_get_update_device()).assign(*args, **kwargs)
+    return self._assign_func(f=state_ops.assign, *args, **kwargs)
 
   def _get_cross_tower(self):
     device = device_util.canonicalize(device_util.current())
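The dispatch implemented by _assign_func, condensed (an informal summary of the code above, not additional committed code):

    # Inside update()/update_non_slot() in cross tower context:
    #   f is applied directly to the component variable on the update device.
    # Elsewhere in cross tower context:
    #   strategy.update(mirrored_var, f, ...) runs f on every device.
    # In tower context:
    #   merge_call -> strategy.reduce(method_string=self._aggregation_method,
    #                                 value=value, destinations=self),
    #   then strategy.update applies f with the reduced value on each
    #   component variable.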