With distribution strategy, traced ConcreteFunctions may contain training specific logics that assumes the variable is a distributed variable. Such functions cannot be used for inference. Since we do not know if such ConcreteFunction will be saved for inference or not, we always mark them as unsaveable unless it's traced under a save context.
The user can tf.function instead, which can be retraced in saving.
Impacted usages:
- MultiWorkerMirroredStrategy
- Reading a synchronization=ON_READ variable. E.g. a batch norm layer.
- MultiWorkerMirroredStrategy, MirroredStrategy, TPUStrategy
- Updating a variable.
- Reading a synchronization=ON_READ aggregation=SUM variable.
It's TBD if we also need to mark functions that use packed handle as unsaveable. They do contain TPU:0 device annotations but with soft placement it may not be a problem.
PiperOrigin-RevId: 337438256
Change-Id: Ie89d0d6beb3e71d3ebbb867d1f91f2953468840c
MEAN aggregation always produces a floating number. We ran into issues when assigning to a MEAN ON_WRITE variable which was caused by the dtype mismatch. The conclusion at that time was that we're going error instead of casting the number to an integer to avoid potentially surprising precision lost.
I recently found that SyncOnReadVariable.read_value() has a similar issue. If aggregation=MEAN, it returns a floating number, instead of a value of the dtype of the variable. Based on the same rational, this changes disables MEAN aggregation for non-floating variables completely.
PiperOrigin-RevId: 334682155
Change-Id: Ib0f96c2a90f9e5f0b4bb4e255f2622e3dd4670bd
Tests passing on a multi-GPU system:
[ OK ] DistributedVariableTest.testAssignReturnValueIsTensorLike_test_aggregation_VariableAggregationMEAN_distribution_CentralStorage2GPUs_mode_eager_synchronization_VariableSynchronizationONREAD
[ OK ] DistributedVariableTest.testAssignReturnValueIsTensorLike_test_aggregation_VariableAggregationMEAN_distribution_CentralStorage2GPUs_mode_eager_synchronization_VariableSynchronizationONWRITE
[ OK ] DistributedVariableTest.testAssignReturnValueIsTensorLike_test_aggregation_VariableAggregationMEAN_distribution_CentralStorage2GPUs_mode_graph_synchronization_VariableSynchronizationONREAD
[ OK ] DistributedVariableTest.testAssignReturnValueIsTensorLike_test_aggregation_VariableAggregationMEAN_distribution_CentralStorage2GPUs_mode_graph_synchronization_VariableSynchronizationONWRITE
[ OK ] DistributedVariableTest.testAssignReturnValueIsTensorLike_test_aggregation_VariableAggregationONLYFIRSTREPLICA_distribution_CentralStorage2GPUs_mode_eager_synchronization_VariableSynchronizationONREAD
[ OK ] DistributedVariableTest.testAssignReturnValueIsTensorLike_test_aggregation_VariableAggregationONLYFIRSTREPLICA_distribution_CentralStorage2GPUs_mode_eager_synchronization_VariableSynchronizationONWRITE
[ OK ] DistributedVariableTest.testAssignReturnValueIsTensorLike_test_aggregation_VariableAggregationONLYFIRSTREPLICA_distribution_CentralStorage2GPUs_mode_graph_synchronization_VariableSynchronizationONREAD
[ OK ] DistributedVariableTest.testAssignReturnValueIsTensorLike_test_aggregation_VariableAggregationONLYFIRSTREPLICA_distribution_CentralStorage2GPUs_mode_graph_synchronization_VariableSynchronizationONWRITE
[ OK ] DistributedVariableTest.testAssignReturnValueIsTensorLike_test_aggregation_VariableAggregationSUM_distribution_CentralStorage2GPUs_mode_eager_synchronization_VariableSynchronizationONREAD
[ OK ] DistributedVariableTest.testAssignReturnValueIsTensorLike_test_aggregation_VariableAggregationSUM_distribution_CentralStorage2GPUs_mode_eager_synchronization_VariableSynchronizationONWRITE
[ OK ] DistributedVariableTest.testAssignReturnValueIsTensorLike_test_aggregation_VariableAggregationSUM_distribution_CentralStorage2GPUs_mode_graph_synchronization_VariableSynchronizationONREAD
[ OK ] DistributedVariableTest.testAssignReturnValueIsTensorLike_test_aggregation_VariableAggregationSUM_distribution_CentralStorage2GPUs_mode_graph_synchronization_VariableSynchronizationONWRITE
[ OK ] DistributedVariableTest.testAssignSignature_test_aggregation_VariableAggregationMEAN_distribution_CentralStorage2GPUs_mode_eager_synchronization_VariableSynchronizationONREAD
[ OK ] DistributedVariableTest.testAssignSignature_test_aggregation_VariableAggregationMEAN_distribution_CentralStorage2GPUs_mode_eager_synchronization_VariableSynchronizationONWRITE
[ OK ] DistributedVariableTest.testAssignSignature_test_aggregation_VariableAggregationMEAN_distribution_CentralStorage2GPUs_mode_graph_synchronization_VariableSynchronizationONREAD
[ OK ] DistributedVariableTest.testAssignSignature_test_aggregation_VariableAggregationMEAN_distribution_CentralStorage2GPUs_mode_graph_synchronization_VariableSynchronizationONWRITE
[ OK ] DistributedVariableTest.testAssignSignature_test_aggregation_VariableAggregationONLYFIRSTREPLICA_distribution_CentralStorage2GPUs_mode_eager_synchronization_VariableSynchronizationONREAD
[ OK ] DistributedVariableTest.testAssignSignature_test_aggregation_VariableAggregationONLYFIRSTREPLICA_distribution_CentralStorage2GPUs_mode_eager_synchronization_VariableSynchronizationONWRITE
[ OK ] DistributedVariableTest.testAssignSignature_test_aggregation_VariableAggregationONLYFIRSTREPLICA_distribution_CentralStorage2GPUs_mode_graph_synchronization_VariableSynchronizationONREAD
[ OK ] DistributedVariableTest.testAssignSignature_test_aggregation_VariableAggregationONLYFIRSTREPLICA_distribution_CentralStorage2GPUs_mode_graph_synchronization_VariableSynchronizationONWRITE
[ OK ] DistributedVariableTest.testAssignSignature_test_aggregation_VariableAggregationSUM_distribution_CentralStorage2GPUs_mode_eager_synchronization_VariableSynchronizationONREAD
[ OK ] DistributedVariableTest.testAssignSignature_test_aggregation_VariableAggregationSUM_distribution_CentralStorage2GPUs_mode_eager_synchronization_VariableSynchronizationONWRITE
[ OK ] DistributedVariableTest.testAssignSignature_test_aggregation_VariableAggregationSUM_distribution_CentralStorage2GPUs_mode_graph_synchronization_VariableSynchronizationONREAD
[ OK ] DistributedVariableTest.testAssignSignature_test_aggregation_VariableAggregationSUM_distribution_CentralStorage2GPUs_mode_graph_synchronization_VariableSynchronizationONWRITE
[ OK ] DistributedVariableTest.testCheckpointing_test_aggregation_VariableAggregationMEAN_distribution_CentralStorage2GPUs_mode_eager_synchronization_VariableSynchronizationONREAD
[ OK ] DistributedVariableTest.testCheckpointing_test_aggregation_VariableAggregationMEAN_distribution_CentralStorage2GPUs_mode_eager_synchronization_VariableSynchronizationONWRITE
[ OK ] DistributedVariableTest.testCheckpointing_test_aggregation_VariableAggregationMEAN_distribution_CentralStorage2GPUs_mode_graph_synchronization_VariableSynchronizationONREAD
[ OK ] DistributedVariableTest.testCheckpointing_test_aggregation_VariableAggregationMEAN_distribution_CentralStorage2GPUs_mode_graph_synchronization_VariableSynchronizationONWRITE
[ OK ] DistributedVariableTest.testCheckpointing_test_aggregation_VariableAggregationONLYFIRSTREPLICA_distribution_CentralStorage2GPUs_mode_eager_synchronization_VariableSynchronizationONREAD
[ OK ] DistributedVariableTest.testCheckpointing_test_aggregation_VariableAggregationONLYFIRSTREPLICA_distribution_CentralStorage2GPUs_mode_eager_synchronization_VariableSynchronizationONWRITE
[ OK ] DistributedVariableTest.testCheckpointing_test_aggregation_VariableAggregationONLYFIRSTREPLICA_distribution_CentralStorage2GPUs_mode_graph_synchronization_VariableSynchronizationONREAD
[ OK ] DistributedVariableTest.testCheckpointing_test_aggregation_VariableAggregationONLYFIRSTREPLICA_distribution_CentralStorage2GPUs_mode_graph_synchronization_VariableSynchronizationONWRITE
[ OK ] DistributedVariableTest.testCheckpointing_test_aggregation_VariableAggregationSUM_distribution_CentralStorage2GPUs_mode_eager_synchronization_VariableSynchronizationONREAD
[ OK ] DistributedVariableTest.testCheckpointing_test_aggregation_VariableAggregationSUM_distribution_CentralStorage2GPUs_mode_eager_synchronization_VariableSynchronizationONWRITE
[ OK ] DistributedVariableTest.testCheckpointing_test_aggregation_VariableAggregationSUM_distribution_CentralStorage2GPUs_mode_graph_synchronization_VariableSynchronizationONREAD
[ OK ] DistributedVariableTest.testCheckpointing_test_aggregation_VariableAggregationSUM_distribution_CentralStorage2GPUs_mode_graph_synchronization_VariableSynchronizationONWRITE
[ OK ] DistributedVariableTest.testDeepCopy_test_aggregation_VariableAggregationMEAN_distribution_CentralStorage2GPUs_mode_eager_synchronization_VariableSynchronizationONREAD
[ OK ] DistributedVariableTest.testDeepCopy_test_aggregation_VariableAggregationMEAN_distribution_CentralStorage2GPUs_mode_eager_synchronization_VariableSynchronizationONWRITE
[ OK ] DistributedVariableTest.testDeepCopy_test_aggregation_VariableAggregationONLYFIRSTREPLICA_distribution_CentralStorage2GPUs_mode_eager_synchronization_VariableSynchronizationONREAD
[ OK ] DistributedVariableTest.testDeepCopy_test_aggregation_VariableAggregationONLYFIRSTREPLICA_distribution_CentralStorage2GPUs_mode_eager_synchronization_VariableSynchronizationONWRITE
[ OK ] DistributedVariableTest.testDeepCopy_test_aggregation_VariableAggregationSUM_distribution_CentralStorage2GPUs_mode_eager_synchronization_VariableSynchronizationONREAD
[ OK ] DistributedVariableTest.testDeepCopy_test_aggregation_VariableAggregationSUM_distribution_CentralStorage2GPUs_mode_eager_synchronization_VariableSynchronizationONWRITE
[ OK ] DistributedVariableTest.testExtendsVariable_test_aggregation_VariableAggregationMEAN_distribution_CentralStorage2GPUs_mode_eager_synchronization_VariableSynchronizationONREAD
[ OK ] DistributedVariableTest.testExtendsVariable_test_aggregation_VariableAggregationMEAN_distribution_CentralStorage2GPUs_mode_eager_synchronization_VariableSynchronizationONWRITE
[ OK ] DistributedVariableTest.testExtendsVariable_test_aggregation_VariableAggregationMEAN_distribution_CentralStorage2GPUs_mode_graph_synchronization_VariableSynchronizationONREAD
[ OK ] DistributedVariableTest.testExtendsVariable_test_aggregation_VariableAggregationMEAN_distribution_CentralStorage2GPUs_mode_graph_synchronization_VariableSynchronizationONWRITE
[ OK ] DistributedVariableTest.testExtendsVariable_test_aggregation_VariableAggregationONLYFIRSTREPLICA_distribution_CentralStorage2GPUs_mode_eager_synchronization_VariableSynchronizationONREAD
[ OK ] DistributedVariableTest.testExtendsVariable_test_aggregation_VariableAggregationONLYFIRSTREPLICA_distribution_CentralStorage2GPUs_mode_eager_synchronization_VariableSynchronizationONWRITE
[ OK ] DistributedVariableTest.testExtendsVariable_test_aggregation_VariableAggregationONLYFIRSTREPLICA_distribution_CentralStorage2GPUs_mode_graph_synchronization_VariableSynchronizationONREAD
[ OK ] DistributedVariableTest.testExtendsVariable_test_aggregation_VariableAggregationONLYFIRSTREPLICA_distribution_CentralStorage2GPUs_mode_graph_synchronization_VariableSynchronizationONWRITE
[ OK ] DistributedVariableTest.testExtendsVariable_test_aggregation_VariableAggregationSUM_distribution_CentralStorage2GPUs_mode_eager_synchronization_VariableSynchronizationONREAD
[ OK ] DistributedVariableTest.testExtendsVariable_test_aggregation_VariableAggregationSUM_distribution_CentralStorage2GPUs_mode_eager_synchronization_VariableSynchronizationONWRITE
[ OK ] DistributedVariableTest.testExtendsVariable_test_aggregation_VariableAggregationSUM_distribution_CentralStorage2GPUs_mode_graph_synchronization_VariableSynchronizationONREAD
[ OK ] DistributedVariableTest.testExtendsVariable_test_aggregation_VariableAggregationSUM_distribution_CentralStorage2GPUs_mode_graph_synchronization_VariableSynchronizationONWRITE
[ OK ] DistributedVariableTest.testIsTensorLike_test_aggregation_VariableAggregationMEAN_distribution_CentralStorage2GPUs_mode_eager_synchronization_VariableSynchronizationONREAD
[ OK ] DistributedVariableTest.testIsTensorLike_test_aggregation_VariableAggregationMEAN_distribution_CentralStorage2GPUs_mode_eager_synchronization_VariableSynchronizationONWRITE
[ OK ] DistributedVariableTest.testIsTensorLike_test_aggregation_VariableAggregationMEAN_distribution_CentralStorage2GPUs_mode_graph_synchronization_VariableSynchronizationONREAD
[ OK ] DistributedVariableTest.testIsTensorLike_test_aggregation_VariableAggregationMEAN_distribution_CentralStorage2GPUs_mode_graph_synchronization_VariableSynchronizationONWRITE
[ OK ] DistributedVariableTest.testIsTensorLike_test_aggregation_VariableAggregationONLYFIRSTREPLICA_distribution_CentralStorage2GPUs_mode_eager_synchronization_VariableSynchronizationONREAD
[ OK ] DistributedVariableTest.testIsTensorLike_test_aggregation_VariableAggregationONLYFIRSTREPLICA_distribution_CentralStorage2GPUs_mode_eager_synchronization_VariableSynchronizationONWRITE
[ OK ] DistributedVariableTest.testIsTensorLike_test_aggregation_VariableAggregationONLYFIRSTREPLICA_distribution_CentralStorage2GPUs_mode_graph_synchronization_VariableSynchronizationONREAD
[ OK ] DistributedVariableTest.testIsTensorLike_test_aggregation_VariableAggregationONLYFIRSTREPLICA_distribution_CentralStorage2GPUs_mode_graph_synchronization_VariableSynchronizationONWRITE
[ OK ] DistributedVariableTest.testIsTensorLike_test_aggregation_VariableAggregationSUM_distribution_CentralStorage2GPUs_mode_eager_synchronization_VariableSynchronizationONREAD
[ OK ] DistributedVariableTest.testIsTensorLike_test_aggregation_VariableAggregationSUM_distribution_CentralStorage2GPUs_mode_eager_synchronization_VariableSynchronizationONWRITE
[ OK ] DistributedVariableTest.testIsTensorLike_test_aggregation_VariableAggregationSUM_distribution_CentralStorage2GPUs_mode_graph_synchronization_VariableSynchronizationONREAD
[ OK ] DistributedVariableTest.testIsTensorLike_test_aggregation_VariableAggregationSUM_distribution_CentralStorage2GPUs_mode_graph_synchronization_VariableSynchronizationONWRITE
[ OK ] DistributedVariableTest.testSelectReplica_test_aggregation_VariableAggregationMEAN_distribution_CentralStorage2GPUs_mode_eager_synchronization_VariableSynchronizationONREAD
[ OK ] DistributedVariableTest.testSelectReplica_test_aggregation_VariableAggregationMEAN_distribution_CentralStorage2GPUs_mode_eager_synchronization_VariableSynchronizationONWRITE
[ OK ] DistributedVariableTest.testSelectReplica_test_aggregation_VariableAggregationMEAN_distribution_CentralStorage2GPUs_mode_graph_synchronization_VariableSynchronizationONREAD
[ OK ] DistributedVariableTest.testSelectReplica_test_aggregation_VariableAggregationMEAN_distribution_CentralStorage2GPUs_mode_graph_synchronization_VariableSynchronizationONWRITE
[ OK ] DistributedVariableTest.testSelectReplica_test_aggregation_VariableAggregationONLYFIRSTREPLICA_distribution_CentralStorage2GPUs_mode_eager_synchronization_VariableSynchronizationONREAD
[ OK ] DistributedVariableTest.testSelectReplica_test_aggregation_VariableAggregationONLYFIRSTREPLICA_distribution_CentralStorage2GPUs_mode_eager_synchronization_VariableSynchronizationONWRITE
[ OK ] DistributedVariableTest.testSelectReplica_test_aggregation_VariableAggregationONLYFIRSTREPLICA_distribution_CentralStorage2GPUs_mode_graph_synchronization_VariableSynchronizationONREAD
[ OK ] DistributedVariableTest.testSelectReplica_test_aggregation_VariableAggregationONLYFIRSTREPLICA_distribution_CentralStorage2GPUs_mode_graph_synchronization_VariableSynchronizationONWRITE
[ OK ] DistributedVariableTest.testSelectReplica_test_aggregation_VariableAggregationSUM_distribution_CentralStorage2GPUs_mode_eager_synchronization_VariableSynchronizationONREAD
[ OK ] DistributedVariableTest.testSelectReplica_test_aggregation_VariableAggregationSUM_distribution_CentralStorage2GPUs_mode_eager_synchronization_VariableSynchronizationONWRITE
[ OK ] DistributedVariableTest.testSelectReplica_test_aggregation_VariableAggregationSUM_distribution_CentralStorage2GPUs_mode_graph_synchronization_VariableSynchronizationONREAD
[ OK ] DistributedVariableTest.testSelectReplica_test_aggregation_VariableAggregationSUM_distribution_CentralStorage2GPUs_mode_graph_synchronization_VariableSynchronizationONWRITE
[ OK ] DistributedVariableTest.testTraceback_test_aggregation_VariableAggregationMEAN_distribution_CentralStorage2GPUs_mode_graph_synchronization_VariableSynchronizationONREAD
[ OK ] DistributedVariableTest.testTraceback_test_aggregation_VariableAggregationMEAN_distribution_CentralStorage2GPUs_mode_graph_synchronization_VariableSynchronizationONWRITE
[ OK ] DistributedVariableTest.testTraceback_test_aggregation_VariableAggregationONLYFIRSTREPLICA_distribution_CentralStorage2GPUs_mode_graph_synchronization_VariableSynchronizationONREAD
[ OK ] DistributedVariableTest.testTraceback_test_aggregation_VariableAggregationONLYFIRSTREPLICA_distribution_CentralStorage2GPUs_mode_graph_synchronization_VariableSynchronizationONWRITE
[ OK ] DistributedVariableTest.testTraceback_test_aggregation_VariableAggregationSUM_distribution_CentralStorage2GPUs_mode_graph_synchronization_VariableSynchronizationONREAD
[ OK ] DistributedVariableTest.testTraceback_test_aggregation_VariableAggregationSUM_distribution_CentralStorage2GPUs_mode_graph_synchronization_VariableSynchronizationONWRITE
PiperOrigin-RevId: 327440841
Change-Id: I86b33681b5ad187f5d3f5e8a0d6d374edfafc8a6
Entering strategy may enter a device scope, which is something we would like to
avoid when saving a non distributed version of the model.
PiperOrigin-RevId: 324100191
Change-Id: I3f51c457e03710fbcd706bddf8ba1c4266667075
We used to save in either the default replica context or the cross replica
context. In either the variables behave in the desired way, which is as if
there's no distribute strategy.
This change make this behavior explicitly implemented. Note that now you will be
able to optionally save a distributed version of the model by setting
experimental_variable_policy to EXPAND_DISTRIBUTED_VARIABLES in SaveOptions.
This change is somewhat messy due to the ongoing refactoring to
DistributedVariable, but the fix is important and we can clean up later.
PiperOrigin-RevId: 323746817
Change-Id: I5ec5db232d86be97c93a2d54c9d3b1ceb344b3df
Inside strategy.extended.update, in_cross_replica_context() returns True, but we
should check for UpdateContext as well. Inside UpdateContext, reads and writes
should behave as reading/writing the replica local variable.
Before this fix, updating synchronization=ON_READ variables from keras
optimizers likely yields incorrect results. All replicas are updated with the
value from the last replica.
PiperOrigin-RevId: 323188475
Change-Id: I2278332868857b0bc97563b311e0c059d7644720
Distributed variables needs to behave differently when tracing functions with
different SaveOptions.
We need to access SaveContext in tf.distribute code instead of the another way
around because we may have a handle to the strategy in saving code.
PiperOrigin-RevId: 321815477
Change-Id: Ib69f6d42c60e198c0e8e174f76bc9424e21df5b5
Introduce PackedVarAndDevice which represents a packed variable in a given device.
PiperOrigin-RevId: 315769635
Change-Id: Ia63b72610afeb7139bd8370bc47067a1fb165307
now it returns tf.Operation in cross replica context regardless of the read_value argument.
PiperOrigin-RevId: 312511470
Change-Id: Ia5b47cc2d4fbe4f80fa73d2649adb6b5e96a7bed
This brings the behaivor of SyncOnReadVariable consistent with MirroredVariable.
E.g. previously it always returns a tf.Operation in cross replica context.
After the last refactoring SyncOnReadVariable only needs to override _update_replica to have the desired behavior.
This change also moves assign* and scatter* override to DistributedVariable
level.
This is part of the effort to make DistributedVariable.assign* returns a variable.
PiperOrigin-RevId: 308894562
Change-Id: I1a58352e2af2ff57402d8fc744fcfc9610a48d8b
These tests seem overly complicated because they're retrofit from a future version where scatter* is supported by SyncOnReadVariable.
PiperOrigin-RevId: 308869807
Change-Id: I4ceb98e8369de6c7f7acb2798ee59acecdcd44cd
With a following change we're going to create new DistributedVariable as return
value of assign*() and scatter*(). Installing _distributed_container multiple
times will be messy.
PiperOrigin-RevId: 307753201
Change-Id: I3c87abc301ea32b0169034324a108d6967229889
_mirrored_update again
In this way it's easier to modify the arugments, which is needed to make the
return type another DistributedVariable.
PiperOrigin-RevId: 306957879
Change-Id: I590543c47bf93abdc180ea9aec5eba12a3f1a888
It's not supported and the test is failing. CentralStorage is essentially
ParameterServer.
PiperOrigin-RevId: 306455178
Change-Id: Ia201b4da4962067660e877bce01115619c08e57b
Both ParameterServer and Mirrored uses this, and itself is complicated enough.
This also fix a issue that you can't strategy.run(tf.function) under CentralStorageStrategy, by applying the same workaround we have in MirroredStrategy.
PiperOrigin-RevId: 304437119
Change-Id: I6a7a67b88e7a5b7217aa9ffe05882d0ef4097896
This change separated common parts into DistributedVariableTest, and AggregateVaraible tests into its own ones as well.
PiperOrigin-RevId: 302125354
Change-Id: I1cfba4d5956a70b7b743913eea4d0301c4c8d1ce