Assign to all component TPUMirroredVariables when assigning in replica context and aggregation=NONE.
PiperOrigin-RevId: 316754219 Change-Id: I791f392b892886404cb80868368ae4a167d8b3d8
This commit is contained in:
parent
23910c191f
commit
295ee8ab72
|
@ -635,8 +635,6 @@ class MirroredVariableUpdateTest(test.TestCase):
|
|||
|
||||
def testAssignMirroredVarReplicaContextWithoutAggregationType(self,
|
||||
distribution):
|
||||
# Test that we always have an aggregation type set on the mirrored variable
|
||||
# if we assign to it in replica mode.
|
||||
def var_fn():
|
||||
v = variable_scope.variable(1.0, name="foo")
|
||||
return v
|
||||
|
|
|
@ -30,6 +30,7 @@ from tensorflow.python.eager import tape
|
|||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import gen_resource_variable_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import variable_scope
|
||||
from tensorflow.python.tpu import tpu
|
||||
|
||||
|
||||
|
@ -173,6 +174,16 @@ class TPUMirroredVariable(TPUVariableMixin, values.MirroredVariable):
|
|||
"""Holds a map from replica to TPU variables whose values are kept in sync."""
|
||||
|
||||
def assign_sub(self, value, use_locking=False, name=None, read_value=True):
|
||||
if (enclosing_tpu_context() and
|
||||
self.aggregation == variable_scope.VariableAggregation.NONE):
|
||||
return _make_raw_assign_fn(
|
||||
gen_resource_variable_ops.assign_sub_variable_op)(
|
||||
self,
|
||||
value=value,
|
||||
use_locking=use_locking,
|
||||
name=name,
|
||||
read_value=read_value)
|
||||
|
||||
assign_sub_fn = _make_raw_assign_fn(
|
||||
gen_resource_variable_ops.assign_sub_variable_op)
|
||||
return self._update(
|
||||
|
@ -183,6 +194,16 @@ class TPUMirroredVariable(TPUVariableMixin, values.MirroredVariable):
|
|||
read_value=read_value)
|
||||
|
||||
def assign_add(self, value, use_locking=False, name=None, read_value=True):
|
||||
if (enclosing_tpu_context() and
|
||||
self.aggregation == variable_scope.VariableAggregation.NONE):
|
||||
return _make_raw_assign_fn(
|
||||
gen_resource_variable_ops.assign_add_variable_op)(
|
||||
self,
|
||||
value=value,
|
||||
use_locking=use_locking,
|
||||
name=name,
|
||||
read_value=read_value)
|
||||
|
||||
assign_add_fn = _make_raw_assign_fn(
|
||||
gen_resource_variable_ops.assign_add_variable_op)
|
||||
return self._update(
|
||||
|
@ -193,6 +214,15 @@ class TPUMirroredVariable(TPUVariableMixin, values.MirroredVariable):
|
|||
read_value=read_value)
|
||||
|
||||
def assign(self, value, use_locking=False, name=None, read_value=True):
|
||||
if (enclosing_tpu_context() and
|
||||
self.aggregation == variable_scope.VariableAggregation.NONE):
|
||||
return _make_raw_assign_fn(gen_resource_variable_ops.assign_variable_op)(
|
||||
self,
|
||||
value=value,
|
||||
use_locking=use_locking,
|
||||
name=name,
|
||||
read_value=read_value)
|
||||
|
||||
assign_fn = _make_raw_assign_fn(
|
||||
gen_resource_variable_ops.assign_variable_op)
|
||||
return self._update(
|
||||
|
|
|
@ -915,6 +915,28 @@ class MirroredVariableTest(test.TestCase, parameterized.TestCase):
|
|||
sess.run(variables_lib.global_variables_initializer())
|
||||
sess.run({"complicated": mirrored})
|
||||
|
||||
@combinations.generate(
|
||||
combinations.combine(
|
||||
distribution=[
|
||||
strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
|
||||
strategy_combinations.tpu_strategy,
|
||||
],
|
||||
mode=["eager"]))
|
||||
def testAssignValueInReplicaContextWithoutAggregation(self, distribution):
|
||||
with distribution.scope():
|
||||
v = variables_lib.Variable(1.0, name="foo")
|
||||
|
||||
@def_function.function
|
||||
def mytest():
|
||||
def model_fn():
|
||||
v.assign(5.0)
|
||||
return v.read_value()
|
||||
|
||||
return distribution.run(model_fn)
|
||||
|
||||
mytest()
|
||||
self.assertAllEqual([5.0, 5.0], self.evaluate(v.values))
|
||||
|
||||
@combinations.generate(
|
||||
combinations.combine(
|
||||
distribution=[
|
||||
|
|
Loading…
Reference in New Issue