Make MirroredStrategy throw an error when creating a trainable ReplicaLocalVariable.
PiperOrigin-RevId: 251596345
This commit is contained in:
parent
f57f1ce4cc
commit
d5f641cbb5
@ -844,6 +844,10 @@ class ParameterServerStrategyTest(
|
||||
num_gpus_per_worker=2)
|
||||
self._test_all_reduce_mean_gradient_tape(distribution)
|
||||
|
||||
def testTrainableVariables(self):
|
||||
distribution = parameter_server_strategy.ParameterServerStrategy()
|
||||
self._test_trainable_variable(distribution)
|
||||
|
||||
|
||||
class ParameterServerStrategyWithChiefTest(ParameterServerStrategyTestBase,
|
||||
parameterized.TestCase):
|
||||
|
@ -217,7 +217,6 @@ def _create_mirrored_variable(strategy, device_map, logical_device, # pylint: d
|
||||
elif synchronization == variable_scope.VariableSynchronization.ON_READ:
|
||||
# Variables that are to be synced on read are replica local.
|
||||
is_sync_on_read = True
|
||||
kwargs["trainable"] = False
|
||||
elif (synchronization == variable_scope.VariableSynchronization.ON_WRITE or
|
||||
synchronization == variable_scope.VariableSynchronization.AUTO):
|
||||
# `AUTO` synchronization for `MirroredStrategy` is `ON_WRITE`.
|
||||
|
@ -226,6 +226,9 @@ class MirroredTwoDeviceDistributionTest(
|
||||
def testSummaryForReplicaZeroOnly(self, distribution):
|
||||
self._test_summary_for_replica_zero_only(distribution)
|
||||
|
||||
def testTrainableVariables(self, distribution):
|
||||
self._test_trainable_variable(distribution)
|
||||
|
||||
|
||||
def one_device_combinations():
|
||||
return combinations.combine(
|
||||
|
@ -110,6 +110,9 @@ class OneDeviceStrategyTest(
|
||||
def testAllReduceMeanGradientTape(self, distribution):
|
||||
self._test_all_reduce_mean_gradient_tape(distribution)
|
||||
|
||||
def testTrainableVariables(self, distribution):
|
||||
self._test_trainable_variable(distribution)
|
||||
|
||||
|
||||
@combinations.generate(
|
||||
combinations.combine(
|
||||
|
@ -426,6 +426,23 @@ class DistributionTestBase(test.TestCase):
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
run_and_concatenate(strategy, i)
|
||||
|
||||
def _test_trainable_variable(self, strategy):
|
||||
with strategy.scope():
|
||||
v1 = variables.Variable(1.0)
|
||||
self.assertEqual(True, v1.trainable)
|
||||
|
||||
v2 = variables.Variable(
|
||||
1.0, synchronization=variables.VariableSynchronization.ON_READ)
|
||||
self.assertEqual(False, v2.trainable)
|
||||
|
||||
with self.assertRaisesRegexp(
|
||||
ValueError,
|
||||
"Synchronization value can be set to VariableSynchronization.ON_READ "
|
||||
"only for non-trainable variables"):
|
||||
_ = variables.Variable(
|
||||
1.0, trainable=True,
|
||||
synchronization=variables.VariableSynchronization.ON_READ)
|
||||
|
||||
|
||||
class OneDeviceDistributionTestBase(test.TestCase):
|
||||
"""Some tests that should work with any one-device DistributionStrategy."""
|
||||
|
Loading…
Reference in New Issue
Block a user