Assign to all component TPUMirroredVariables when assigning in replica context and aggregation=NONE.
PiperOrigin-RevId: 316754219 Change-Id: I791f392b892886404cb80868368ae4a167d8b3d8
This commit is contained in:
parent
23910c191f
commit
295ee8ab72
|
@ -635,8 +635,6 @@ class MirroredVariableUpdateTest(test.TestCase):
|
||||||
|
|
||||||
def testAssignMirroredVarReplicaContextWithoutAggregationType(self,
|
def testAssignMirroredVarReplicaContextWithoutAggregationType(self,
|
||||||
distribution):
|
distribution):
|
||||||
# Test that we always have an aggregation type set on the mirrored variable
|
|
||||||
# if we assign to it in replica mode.
|
|
||||||
def var_fn():
|
def var_fn():
|
||||||
v = variable_scope.variable(1.0, name="foo")
|
v = variable_scope.variable(1.0, name="foo")
|
||||||
return v
|
return v
|
||||||
|
|
|
@ -30,6 +30,7 @@ from tensorflow.python.eager import tape
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
from tensorflow.python.ops import gen_resource_variable_ops
|
from tensorflow.python.ops import gen_resource_variable_ops
|
||||||
from tensorflow.python.ops import math_ops
|
from tensorflow.python.ops import math_ops
|
||||||
|
from tensorflow.python.ops import variable_scope
|
||||||
from tensorflow.python.tpu import tpu
|
from tensorflow.python.tpu import tpu
|
||||||
|
|
||||||
|
|
||||||
|
@ -173,6 +174,16 @@ class TPUMirroredVariable(TPUVariableMixin, values.MirroredVariable):
|
||||||
"""Holds a map from replica to TPU variables whose values are kept in sync."""
|
"""Holds a map from replica to TPU variables whose values are kept in sync."""
|
||||||
|
|
||||||
def assign_sub(self, value, use_locking=False, name=None, read_value=True):
|
def assign_sub(self, value, use_locking=False, name=None, read_value=True):
|
||||||
|
if (enclosing_tpu_context() and
|
||||||
|
self.aggregation == variable_scope.VariableAggregation.NONE):
|
||||||
|
return _make_raw_assign_fn(
|
||||||
|
gen_resource_variable_ops.assign_sub_variable_op)(
|
||||||
|
self,
|
||||||
|
value=value,
|
||||||
|
use_locking=use_locking,
|
||||||
|
name=name,
|
||||||
|
read_value=read_value)
|
||||||
|
|
||||||
assign_sub_fn = _make_raw_assign_fn(
|
assign_sub_fn = _make_raw_assign_fn(
|
||||||
gen_resource_variable_ops.assign_sub_variable_op)
|
gen_resource_variable_ops.assign_sub_variable_op)
|
||||||
return self._update(
|
return self._update(
|
||||||
|
@ -183,6 +194,16 @@ class TPUMirroredVariable(TPUVariableMixin, values.MirroredVariable):
|
||||||
read_value=read_value)
|
read_value=read_value)
|
||||||
|
|
||||||
def assign_add(self, value, use_locking=False, name=None, read_value=True):
|
def assign_add(self, value, use_locking=False, name=None, read_value=True):
|
||||||
|
if (enclosing_tpu_context() and
|
||||||
|
self.aggregation == variable_scope.VariableAggregation.NONE):
|
||||||
|
return _make_raw_assign_fn(
|
||||||
|
gen_resource_variable_ops.assign_add_variable_op)(
|
||||||
|
self,
|
||||||
|
value=value,
|
||||||
|
use_locking=use_locking,
|
||||||
|
name=name,
|
||||||
|
read_value=read_value)
|
||||||
|
|
||||||
assign_add_fn = _make_raw_assign_fn(
|
assign_add_fn = _make_raw_assign_fn(
|
||||||
gen_resource_variable_ops.assign_add_variable_op)
|
gen_resource_variable_ops.assign_add_variable_op)
|
||||||
return self._update(
|
return self._update(
|
||||||
|
@ -193,6 +214,15 @@ class TPUMirroredVariable(TPUVariableMixin, values.MirroredVariable):
|
||||||
read_value=read_value)
|
read_value=read_value)
|
||||||
|
|
||||||
def assign(self, value, use_locking=False, name=None, read_value=True):
|
def assign(self, value, use_locking=False, name=None, read_value=True):
|
||||||
|
if (enclosing_tpu_context() and
|
||||||
|
self.aggregation == variable_scope.VariableAggregation.NONE):
|
||||||
|
return _make_raw_assign_fn(gen_resource_variable_ops.assign_variable_op)(
|
||||||
|
self,
|
||||||
|
value=value,
|
||||||
|
use_locking=use_locking,
|
||||||
|
name=name,
|
||||||
|
read_value=read_value)
|
||||||
|
|
||||||
assign_fn = _make_raw_assign_fn(
|
assign_fn = _make_raw_assign_fn(
|
||||||
gen_resource_variable_ops.assign_variable_op)
|
gen_resource_variable_ops.assign_variable_op)
|
||||||
return self._update(
|
return self._update(
|
||||||
|
|
|
@ -915,6 +915,28 @@ class MirroredVariableTest(test.TestCase, parameterized.TestCase):
|
||||||
sess.run(variables_lib.global_variables_initializer())
|
sess.run(variables_lib.global_variables_initializer())
|
||||||
sess.run({"complicated": mirrored})
|
sess.run({"complicated": mirrored})
|
||||||
|
|
||||||
|
@combinations.generate(
|
||||||
|
combinations.combine(
|
||||||
|
distribution=[
|
||||||
|
strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
|
||||||
|
strategy_combinations.tpu_strategy,
|
||||||
|
],
|
||||||
|
mode=["eager"]))
|
||||||
|
def testAssignValueInReplicaContextWithoutAggregation(self, distribution):
|
||||||
|
with distribution.scope():
|
||||||
|
v = variables_lib.Variable(1.0, name="foo")
|
||||||
|
|
||||||
|
@def_function.function
|
||||||
|
def mytest():
|
||||||
|
def model_fn():
|
||||||
|
v.assign(5.0)
|
||||||
|
return v.read_value()
|
||||||
|
|
||||||
|
return distribution.run(model_fn)
|
||||||
|
|
||||||
|
mytest()
|
||||||
|
self.assertAllEqual([5.0, 5.0], self.evaluate(v.values))
|
||||||
|
|
||||||
@combinations.generate(
|
@combinations.generate(
|
||||||
combinations.combine(
|
combinations.combine(
|
||||||
distribution=[
|
distribution=[
|
||||||
|
|
Loading…
Reference in New Issue