Assign to all component TPUMirroredVariables when assigning in replica context and aggregation=NONE.

PiperOrigin-RevId: 316754219
Change-Id: I791f392b892886404cb80868368ae4a167d8b3d8
This commit is contained in:
Anjali Sridhar 2020-06-16 14:07:19 -07:00 committed by TensorFlower Gardener
parent 23910c191f
commit 295ee8ab72
3 changed files with 52 additions and 2 deletions

View File

@ -635,8 +635,6 @@ class MirroredVariableUpdateTest(test.TestCase):
def testAssignMirroredVarReplicaContextWithoutAggregationType(self,
distribution):
# Test that we always have an aggregation type set on the mirrored variable
# if we assign to it in replica mode.
def var_fn():
v = variable_scope.variable(1.0, name="foo")
return v

View File

@ -30,6 +30,7 @@ from tensorflow.python.eager import tape
from tensorflow.python.framework import ops
from tensorflow.python.ops import gen_resource_variable_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.tpu import tpu
@ -173,6 +174,16 @@ class TPUMirroredVariable(TPUVariableMixin, values.MirroredVariable):
"""Holds a map from replica to TPU variables whose values are kept in sync."""
def assign_sub(self, value, use_locking=False, name=None, read_value=True):
if (enclosing_tpu_context() and
self.aggregation == variable_scope.VariableAggregation.NONE):
return _make_raw_assign_fn(
gen_resource_variable_ops.assign_sub_variable_op)(
self,
value=value,
use_locking=use_locking,
name=name,
read_value=read_value)
assign_sub_fn = _make_raw_assign_fn(
gen_resource_variable_ops.assign_sub_variable_op)
return self._update(
@ -183,6 +194,16 @@ class TPUMirroredVariable(TPUVariableMixin, values.MirroredVariable):
read_value=read_value)
def assign_add(self, value, use_locking=False, name=None, read_value=True):
if (enclosing_tpu_context() and
self.aggregation == variable_scope.VariableAggregation.NONE):
return _make_raw_assign_fn(
gen_resource_variable_ops.assign_add_variable_op)(
self,
value=value,
use_locking=use_locking,
name=name,
read_value=read_value)
assign_add_fn = _make_raw_assign_fn(
gen_resource_variable_ops.assign_add_variable_op)
return self._update(
@ -193,6 +214,15 @@ class TPUMirroredVariable(TPUVariableMixin, values.MirroredVariable):
read_value=read_value)
def assign(self, value, use_locking=False, name=None, read_value=True):
if (enclosing_tpu_context() and
self.aggregation == variable_scope.VariableAggregation.NONE):
return _make_raw_assign_fn(gen_resource_variable_ops.assign_variable_op)(
self,
value=value,
use_locking=use_locking,
name=name,
read_value=read_value)
assign_fn = _make_raw_assign_fn(
gen_resource_variable_ops.assign_variable_op)
return self._update(

View File

@ -915,6 +915,28 @@ class MirroredVariableTest(test.TestCase, parameterized.TestCase):
sess.run(variables_lib.global_variables_initializer())
sess.run({"complicated": mirrored})
@combinations.generate(
combinations.combine(
distribution=[
strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
strategy_combinations.tpu_strategy,
],
mode=["eager"]))
def testAssignValueInReplicaContextWithoutAggregation(self, distribution):
with distribution.scope():
v = variables_lib.Variable(1.0, name="foo")
@def_function.function
def mytest():
def model_fn():
v.assign(5.0)
return v.read_value()
return distribution.run(model_fn)
mytest()
self.assertAllEqual([5.0, 5.0], self.evaluate(v.values))
@combinations.generate(
combinations.combine(
distribution=[