Disallow MEAN non-floating distributed variables.
MEAN aggregation always produces a floating number. We ran into issues when assigning to a MEAN ON_WRITE variable which was caused by the dtype mismatch. The conclusion at that time was that we're going error instead of casting the number to an integer to avoid potentially surprising precision lost. I recently found that SyncOnReadVariable.read_value() has a similar issue. If aggregation=MEAN, it returns a floating number, instead of a value of the dtype of the variable. Based on the same rational, this changes disables MEAN aggregation for non-floating variables completely. PiperOrigin-RevId: 334682155 Change-Id: Ib0f96c2a90f9e5f0b4bb4e255f2622e3dd4670bd
This commit is contained in:
parent
84df34f818
commit
c8d3bd7823
@ -443,6 +443,12 @@ class DistributedVariable(DistributedDelegate, variables_lib.Variable,
|
||||
"""Holds a map from replica to variables."""
|
||||
|
||||
def __init__(self, strategy, values, aggregation, var_policy=None):
|
||||
if (aggregation == variables_lib.VariableAggregation.MEAN and
|
||||
not values[0].dtype.is_floating):
|
||||
raise ValueError(
|
||||
"creating distributed tf.Variable with aggregation=MEAN and a "
|
||||
"non-floating dtype is not supported, please use a different "
|
||||
"aggregation or dtype")
|
||||
self._distribute_strategy = strategy
|
||||
self._aggregation = aggregation
|
||||
super(DistributedVariable, self).__init__(values)
|
||||
|
@ -682,13 +682,13 @@ class DistributedVariableTest(test.TestCase, parameterized.TestCase):
|
||||
with distribution.scope():
|
||||
# We use four variables for convenience reasons. They have no special
|
||||
# meaning.
|
||||
# - v is used whenever possible, and for the methods that require the
|
||||
# dtype to be integer.
|
||||
# - v is used whenever possible.
|
||||
# - w is used for scatter and gather, which require the variable to be
|
||||
# non-scalar.
|
||||
# - y is used when the dtype needs to be float.
|
||||
# - y is used when the dtype needs to be integer. Note that aggregation
|
||||
# cannot be MEAN for integers.
|
||||
v = variables_lib.Variable(
|
||||
0,
|
||||
0.,
|
||||
synchronization=synchronization,
|
||||
aggregation=aggregation,
|
||||
trainable=True)
|
||||
@ -696,10 +696,11 @@ class DistributedVariableTest(test.TestCase, parameterized.TestCase):
|
||||
synchronization=synchronization,
|
||||
aggregation=aggregation,
|
||||
trainable=True)
|
||||
y = variables_lib.Variable(
|
||||
7.,
|
||||
synchronization=synchronization,
|
||||
aggregation=aggregation)
|
||||
if aggregation != variables_lib.VariableAggregation.MEAN:
|
||||
y = variables_lib.Variable(
|
||||
0,
|
||||
synchronization=synchronization,
|
||||
aggregation=aggregation)
|
||||
|
||||
# pylint: disable=g-long-lambda
|
||||
|
||||
@ -708,7 +709,7 @@ class DistributedVariableTest(test.TestCase, parameterized.TestCase):
|
||||
_test(lambda: self.assertIs(v.constraint, None), v)
|
||||
# TODO(crccw): should we raise an error instead?
|
||||
_test(lambda: self.assertEqual(v.device, v._primary.device), v)
|
||||
_test(lambda: self.assertEqual(v.dtype, dtypes.int32), v)
|
||||
_test(lambda: self.assertEqual(v.dtype, dtypes.float32), v)
|
||||
if not context.executing_eagerly():
|
||||
_test(lambda: self.assertIs(v.graph, v._primary.graph), v)
|
||||
if not context.executing_eagerly():
|
||||
@ -722,9 +723,9 @@ class DistributedVariableTest(test.TestCase, parameterized.TestCase):
|
||||
_test(lambda: self.assertTrue(v.trainable, True), v)
|
||||
|
||||
# tf.Variable methods.
|
||||
_test(lambda: check_ops.assert_equal_v2(v.assign(1), 1), v)
|
||||
_test(lambda: check_ops.assert_equal_v2(v.assign_add(1), 2), v)
|
||||
_test(lambda: check_ops.assert_equal_v2(v.assign_sub(1), 1), v)
|
||||
_test(lambda: check_ops.assert_equal_v2(v.assign(1.), 1.), v)
|
||||
_test(lambda: check_ops.assert_equal_v2(v.assign_add(1.), 2.), v)
|
||||
_test(lambda: check_ops.assert_equal_v2(v.assign_sub(1.), 1.), v)
|
||||
# TODO(b/148689177): Implement batch_scatter_update.
|
||||
# count_up_to() is skipped since it's deprecated.
|
||||
# eval() is skipped since it shouldn't called in a tf.function.
|
||||
@ -736,7 +737,7 @@ class DistributedVariableTest(test.TestCase, parameterized.TestCase):
|
||||
tensor_shape.TensorShape(())), v)
|
||||
# initialized_value() is skipped since it shouldn't called in a tf.function.
|
||||
# load() is skipped since it shouldn't called in a tf.function.
|
||||
_test(lambda: check_ops.assert_equal_v2(v.read_value(), 1), v)
|
||||
_test(lambda: check_ops.assert_equal_v2(v.read_value(), 1.), v)
|
||||
# ref() is skipped since it shouldn't called in a tf.function.
|
||||
_test(
|
||||
lambda: check_ops.assert_equal_v2(
|
||||
@ -770,62 +771,65 @@ class DistributedVariableTest(test.TestCase, parameterized.TestCase):
|
||||
[2., 0.5, 1.]), w)
|
||||
# set_shape() is skipped since ResourceVariable doesn't implement it.
|
||||
# to_proto() is skipped since it shouldn't called in a tf.function.
|
||||
_test(lambda: check_ops.assert_equal_v2(v.value(), 1), v)
|
||||
_test(lambda: check_ops.assert_equal_v2(v.value(), 1.), v)
|
||||
|
||||
# DistributedVariable should be treated as ResourceVariable, so it needs to
|
||||
# conform to ResourceVariable interface as well.
|
||||
_test(lambda: self.assertIs(v.handle, v._primary.handle), v)
|
||||
|
||||
# Convert to tensor.
|
||||
_test(lambda: check_ops.assert_equal_v2(ops.convert_to_tensor(v), 1), v)
|
||||
_test(lambda: check_ops.assert_equal_v2(ops.convert_to_tensor(v), 1.), v)
|
||||
|
||||
# Control dependency.
|
||||
def _with_control_dep():
|
||||
with ops.control_dependencies([v.assign(1)]):
|
||||
with ops.control_dependencies([v.assign(1.)]):
|
||||
return array_ops.identity(1)
|
||||
|
||||
_test(_with_control_dep, v)
|
||||
|
||||
# Operator overloads.
|
||||
_test(lambda: check_ops.assert_equal_v2(v.assign(7), 7), v)
|
||||
_test(lambda: check_ops.assert_equal_v2(v + 1, 8), v)
|
||||
_test(lambda: check_ops.assert_equal_v2(3 + v, 10), v)
|
||||
_test(lambda: check_ops.assert_equal_v2(v + v, 14), v)
|
||||
_test(lambda: check_ops.assert_equal_v2(v - 2, 5), v)
|
||||
_test(lambda: check_ops.assert_equal_v2(v - v, 0), v)
|
||||
_test(lambda: check_ops.assert_equal_v2(v * 2, 14), v)
|
||||
_test(lambda: check_ops.assert_equal_v2(3 * v, 21), v)
|
||||
_test(lambda: check_ops.assert_equal_v2(v * v, 49), v)
|
||||
_test(lambda: check_ops.assert_equal_v2(v.assign(7.), 7.), v)
|
||||
_test(lambda: check_ops.assert_equal_v2(v + 1., 8.), v)
|
||||
_test(lambda: check_ops.assert_equal_v2(3 + v, 10.), v)
|
||||
_test(lambda: check_ops.assert_equal_v2(v + v, 14.), v)
|
||||
_test(lambda: check_ops.assert_equal_v2(v - 2., 5.), v)
|
||||
_test(lambda: check_ops.assert_equal_v2(v - v, 0.), v)
|
||||
_test(lambda: check_ops.assert_equal_v2(v * 2., 14.), v)
|
||||
_test(lambda: check_ops.assert_equal_v2(3 * v, 21.), v)
|
||||
_test(lambda: check_ops.assert_equal_v2(v * v, 49.), v)
|
||||
_test(
|
||||
lambda: check_ops.assert_equal_v2(
|
||||
math_ops.cast(v / 2, dtypes.float32), 3.5), v)
|
||||
math_ops.cast(v / 2., dtypes.float32), 3.5), v)
|
||||
_test(
|
||||
lambda: check_ops.assert_equal_v2(
|
||||
math_ops.cast(14 / v, dtypes.float32), 2.), v)
|
||||
_test(lambda: check_ops.assert_equal_v2(v // 2, 3), v)
|
||||
_test(lambda: check_ops.assert_equal_v2(15 // v, 2), v)
|
||||
_test(lambda: check_ops.assert_equal_v2(v % 2, 1), v)
|
||||
_test(lambda: check_ops.assert_equal_v2(16 % v, 2), v)
|
||||
_test(lambda: _assert(v < 12), v)
|
||||
_test(lambda: _assert(v <= 12), v)
|
||||
_test(lambda: _assert(not v > 12), v)
|
||||
_test(lambda: _assert(not v >= 12), v)
|
||||
_test(lambda: _assert(not 12 < v), v)
|
||||
_test(lambda: _assert(not 12 <= v), v)
|
||||
_test(lambda: _assert(12 > v), v)
|
||||
_test(lambda: _assert(12 >= v), v)
|
||||
# XLA doesn't implement pow() with integers.
|
||||
_test(lambda: check_ops.assert_near_v2(pow(y, 3.), 343.), y)
|
||||
_test(lambda: check_ops.assert_near_v2(pow(2., y), 128.), y)
|
||||
_test(lambda: check_ops.assert_equal_v2(abs(v), 7), v)
|
||||
_test(lambda: check_ops.assert_equal_v2(v & 3, 3), v)
|
||||
_test(lambda: check_ops.assert_equal_v2(3 & v, 3), v)
|
||||
_test(lambda: check_ops.assert_equal_v2(v | 8, 15), v)
|
||||
_test(lambda: check_ops.assert_equal_v2(16 | v, 23), v)
|
||||
_test(lambda: check_ops.assert_equal_v2(v ^ 3, 4), v)
|
||||
_test(lambda: check_ops.assert_equal_v2(11 ^ v, 12), v)
|
||||
_test(lambda: check_ops.assert_equal_v2(-v, -7), v)
|
||||
_test(lambda: check_ops.assert_equal_v2(~v, ~7), v)
|
||||
math_ops.cast(14. / v, dtypes.float32), 2.), v)
|
||||
_test(lambda: _assert(v < 12.), v)
|
||||
_test(lambda: _assert(v <= 12.), v)
|
||||
_test(lambda: _assert(not v > 12.), v)
|
||||
_test(lambda: _assert(not v >= 12.), v)
|
||||
_test(lambda: _assert(not 12. < v), v)
|
||||
_test(lambda: _assert(not 12. <= v), v)
|
||||
_test(lambda: _assert(12. > v), v)
|
||||
_test(lambda: _assert(12. >= v), v)
|
||||
_test(lambda: check_ops.assert_near_v2(pow(v, 3.), 343.), v)
|
||||
_test(lambda: check_ops.assert_near_v2(pow(2., v), 128.), v)
|
||||
_test(lambda: check_ops.assert_equal_v2(abs(v), 7.), v)
|
||||
|
||||
# Operator overloads that only works for integers.
|
||||
if aggregation != variables_lib.VariableAggregation.MEAN:
|
||||
_test(lambda: check_ops.assert_equal_v2(y.assign(7), 7), y)
|
||||
_test(lambda: check_ops.assert_equal_v2(y // 2, 3), y)
|
||||
_test(lambda: check_ops.assert_equal_v2(15 // y, 2), y)
|
||||
_test(lambda: check_ops.assert_equal_v2(y % 2, 1), y)
|
||||
_test(lambda: check_ops.assert_equal_v2(16 % y, 2), y)
|
||||
_test(lambda: check_ops.assert_equal_v2(y & 3, 3), y)
|
||||
_test(lambda: check_ops.assert_equal_v2(3 & y, 3), y)
|
||||
_test(lambda: check_ops.assert_equal_v2(y | 8, 15), y)
|
||||
_test(lambda: check_ops.assert_equal_v2(16 | y, 23), y)
|
||||
_test(lambda: check_ops.assert_equal_v2(y ^ 3, 4), y)
|
||||
_test(lambda: check_ops.assert_equal_v2(11 ^ y, 12), y)
|
||||
_test(lambda: check_ops.assert_equal_v2(-y, -7), y)
|
||||
_test(lambda: check_ops.assert_equal_v2(~y, ~7), y)
|
||||
|
||||
# Index.
|
||||
if isinstance(distribution.extended, tpu_strategy.TPUExtended):
|
||||
|
@ -302,63 +302,6 @@ class OnWriteVariableSync(test.TestCase, parameterized.TestCase):
|
||||
for component in mirrored.values:
|
||||
self.assertEqual(self.evaluate(component.read_value()), 3.)
|
||||
|
||||
@combinations.generate(strategy_with_var_policy())
|
||||
def testAssignAggregationMeanDTypeNonFloat(self, distribution):
|
||||
if isinstance(distribution, _TPU_STRATEGIES):
|
||||
self.skipTest("Fix sponge/6e8ab540-4c0f-4da5-aedf-86505ff810c9 before "
|
||||
"reenabling test.")
|
||||
|
||||
with distribution.scope():
|
||||
v = variables_lib.Variable(
|
||||
1,
|
||||
aggregation=variable_scope.VariableAggregation.MEAN,
|
||||
dtype=dtypes.int32)
|
||||
self.evaluate(v.initializer)
|
||||
|
||||
@def_function.function
|
||||
def assign():
|
||||
ctx = ds_context.get_replica_context()
|
||||
return v.assign(ctx.replica_id_in_sync_group)
|
||||
|
||||
# disallow assign() with distributed value in replica context.
|
||||
with self.assertRaisesRegex(ValueError,
|
||||
"Cannot update non-float variables"):
|
||||
self.evaluate(
|
||||
distribution.experimental_local_results(
|
||||
distribution.run(assign)))
|
||||
|
||||
# allow assign() with same value in replica context.
|
||||
@def_function.function
|
||||
def assign_same():
|
||||
return v.assign(2)
|
||||
|
||||
self.evaluate(
|
||||
distribution.experimental_local_results(
|
||||
distribution.run(assign_same)))
|
||||
self.assertEqual(self.evaluate(v.read_value()), 2)
|
||||
|
||||
# allow assign() with mirrored variable in replica context.
|
||||
with distribution.scope():
|
||||
v2 = variables_lib.Variable(
|
||||
3,
|
||||
aggregation=variable_scope.VariableAggregation.SUM,
|
||||
dtype=dtypes.int32)
|
||||
self.evaluate(v2.initializer)
|
||||
|
||||
@def_function.function
|
||||
def assign_mirrored():
|
||||
return v.assign(v2)
|
||||
|
||||
self.evaluate(
|
||||
distribution.experimental_local_results(
|
||||
distribution.run(assign_mirrored)))
|
||||
self.assertEqual(self.evaluate(v.read_value()), 3)
|
||||
|
||||
# allow assign() in cross replica context.
|
||||
with distribution.scope():
|
||||
self.evaluate(v.assign(4))
|
||||
self.assertEqual(self.evaluate(v.read_value()), 4)
|
||||
|
||||
@combinations.generate(strategy_with_var_policy())
|
||||
def testInitializedToSameValueInsideEagerRun(self, distribution):
|
||||
if not context.executing_eagerly(): self.skipTest("eager only test")
|
||||
@ -415,24 +358,24 @@ class OnWriteVariableSync(test.TestCase, parameterized.TestCase):
|
||||
with ops.init_scope():
|
||||
if obj.w is None:
|
||||
obj.w = variables_lib.Variable(
|
||||
0, aggregation=variables_lib.VariableAggregation.MEAN)
|
||||
0., aggregation=variables_lib.VariableAggregation.MEAN)
|
||||
obj.v = variables_lib.Variable(
|
||||
obj.w.read_value(),
|
||||
aggregation=variables_lib.VariableAggregation.MEAN)
|
||||
self.evaluate(variables_lib.global_variables_initializer())
|
||||
|
||||
return obj.v.assign_add(2)
|
||||
return obj.v.assign_add(2.)
|
||||
|
||||
per_replica_results = self.evaluate(
|
||||
distribution.experimental_local_results(distribution.run(assign)))
|
||||
self.assertAllEqual([2, 2], per_replica_results)
|
||||
self.assertAllEqual([2., 2.], per_replica_results)
|
||||
|
||||
@combinations.generate(strategy_with_var_policy())
|
||||
def testOperatorOverride(self, distribution):
|
||||
|
||||
with distribution.scope():
|
||||
v = variable_scope.variable(
|
||||
1, aggregation=variables_lib.VariableAggregation.MEAN)
|
||||
1, aggregation=variables_lib.VariableAggregation.SUM)
|
||||
self.evaluate(variables_lib.global_variables_initializer())
|
||||
|
||||
self.assertEqual(2, self.evaluate(v + 1))
|
||||
|
Loading…
Reference in New Issue
Block a user