Make MirroredStrategy throw an error when creating a trainable ReplicaLocalVariable.

PiperOrigin-RevId: 251596345
This commit is contained in:
Peter Buchlovsky 2019-06-05 01:11:17 -07:00 committed by TensorFlower Gardener
parent f57f1ce4cc
commit d5f641cbb5
5 changed files with 27 additions and 1 deletions

View File

@ -844,6 +844,10 @@ class ParameterServerStrategyTest(
num_gpus_per_worker=2)
self._test_all_reduce_mean_gradient_tape(distribution)
def testTrainableVariables(self):
distribution = parameter_server_strategy.ParameterServerStrategy()
self._test_trainable_variable(distribution)
class ParameterServerStrategyWithChiefTest(ParameterServerStrategyTestBase,
parameterized.TestCase):

View File

@ -217,7 +217,6 @@ def _create_mirrored_variable(strategy, device_map, logical_device, # pylint: d
elif synchronization == variable_scope.VariableSynchronization.ON_READ:
# Variables that are to be synced on read are replica local.
is_sync_on_read = True
kwargs["trainable"] = False
elif (synchronization == variable_scope.VariableSynchronization.ON_WRITE or
synchronization == variable_scope.VariableSynchronization.AUTO):
# `AUTO` synchronization for `MirroredStrategy` is `ON_WRITE`.

View File

@ -226,6 +226,9 @@ class MirroredTwoDeviceDistributionTest(
def testSummaryForReplicaZeroOnly(self, distribution):
self._test_summary_for_replica_zero_only(distribution)
def testTrainableVariables(self, distribution):
self._test_trainable_variable(distribution)
def one_device_combinations():
return combinations.combine(

View File

@ -110,6 +110,9 @@ class OneDeviceStrategyTest(
def testAllReduceMeanGradientTape(self, distribution):
self._test_all_reduce_mean_gradient_tape(distribution)
def testTrainableVariables(self, distribution):
self._test_trainable_variable(distribution)
@combinations.generate(
combinations.combine(

View File

@ -426,6 +426,23 @@ class DistributionTestBase(test.TestCase):
with self.assertRaises(errors.OutOfRangeError):
run_and_concatenate(strategy, i)
def _test_trainable_variable(self, strategy):
with strategy.scope():
v1 = variables.Variable(1.0)
self.assertEqual(True, v1.trainable)
v2 = variables.Variable(
1.0, synchronization=variables.VariableSynchronization.ON_READ)
self.assertEqual(False, v2.trainable)
with self.assertRaisesRegexp(
ValueError,
"Synchronization value can be set to VariableSynchronization.ON_READ "
"only for non-trainable variables"):
_ = variables.Variable(
1.0, trainable=True,
synchronization=variables.VariableSynchronization.ON_READ)
class OneDeviceDistributionTestBase(test.TestCase):
"""Some tests that should work with any one-device DistributionStrategy."""