Commit Graph

94 Commits

Author SHA1 Message Date
Ran Chen
380478ff5f Disallow saving if the function cannot be used for inference
With distribution strategy, traced ConcreteFunctions may contain training-specific logic that assumes the variable is a distributed variable. Such functions cannot be used for inference. Since we do not know whether such a ConcreteFunction will be saved for inference, we always mark them as unsaveable unless they are traced under a save context.

The user can use tf.function instead, which can be retraced when saving.

Impacted usages:
- MultiWorkerMirroredStrategy
  - Reading a synchronization=ON_READ variable. E.g. a batch norm layer.
- MultiWorkerMirroredStrategy, MirroredStrategy, TPUStrategy
  - Updating a variable.
  - Reading a synchronization=ON_READ aggregation=SUM variable.

It's TBD whether we also need to mark functions that use a packed handle as unsaveable. They do contain TPU:0 device annotations, but with soft placement that may not be a problem.

PiperOrigin-RevId: 337438256
Change-Id: Ie89d0d6beb3e71d3ebbb867d1f91f2953468840c
2020-10-15 21:08:51 -07:00
Ran Chen
4ba6a1dc99 Add test_util.main() and test_util.set_logical_devices_to_at_least()
test_util.main() replaces combinations.main()
test_util.set_logical_devices_to_at_least() replaces strategy_combinations.set_virtual_cpus_to_at_least()

PiperOrigin-RevId: 335742598
Change-Id: Ie9967ed1f1fe866a83472319137aeb23a521c943
2020-10-06 16:30:51 -07:00
Anjali Sridhar
8a77ace2ac Remove tests that use a replica context scope with no strategy specified. This does not seem like a usage that we need to test.
PiperOrigin-RevId: 335587250
Change-Id: I34a8d12474b2e399ef9a07881cdf7104a7af52f0
2020-10-06 00:53:14 -07:00
Ran Chen
c8d3bd7823 Disallow MEAN non-floating distributed variables.
MEAN aggregation always produces a floating number. We ran into issues when assigning to a MEAN ON_WRITE variable, which were caused by the dtype mismatch. The conclusion at that time was that we're going to raise an error instead of casting the number to an integer, to avoid a potentially surprising loss of precision.

I recently found that SyncOnReadVariable.read_value() has a similar issue. If aggregation=MEAN, it returns a floating number instead of a value of the variable's dtype. Based on the same rationale, this change disables MEAN aggregation for non-floating variables completely.
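
The dtype mismatch described above can be illustrated with a minimal pure-Python sketch. It does not use TensorFlow; `mean_aggregate` and `check_mean_dtype` are hypothetical names introduced only for illustration and are not the library's actual implementation.

```python
def mean_aggregate(replica_values):
    """Average per-replica values, as a MEAN reduction conceptually does."""
    return sum(replica_values) / len(replica_values)


def check_mean_dtype(dtype_is_floating, aggregation):
    """Hypothetical guard: reject MEAN aggregation on non-floating variables."""
    if aggregation == "MEAN" and not dtype_is_floating:
        raise ValueError(
            "MEAN aggregation is not supported for non-floating variables")


# Two replicas hold the integer values 1 and 2.
result = mean_aggregate([1, 2])

# The mean is a float even though every input was an integer, so storing it
# back into an integer variable would require a lossy cast.
truncated = int(result)
```

Raising an error up front, as the guard does, avoids silently dropping the fractional part.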

PiperOrigin-RevId: 334682155
Change-Id: Ib0f96c2a90f9e5f0b4bb4e255f2622e3dd4670bd
2020-09-30 14:45:47 -07:00
Anjali Sridhar
aecc216071 Refactor DistributedVariable saveable object to extend from SaveableObject instead of ResourceVariableSaveable. This allows us to move to a single type of DistributedVariable with attached policies.
PiperOrigin-RevId: 327884662
Change-Id: I4f3030f4a19248dfd7e9d3281a971e637735bc5f
2020-08-21 15:32:25 -07:00
Anjali Sridhar
7a26346ab6 Use a 1CPU-1GPU test combination to test CentralStorageStrategy so that we can verify tests locally without a multiGPU guitar run.
PiperOrigin-RevId: 327659073
Change-Id: Ic54c83b43c37995040674a0dbbeede92b7d215a7
2020-08-20 11:08:10 -07:00
Eugene Brevdo
654b45cd56 [TF DistStrat] Add support for deepcopy on AggregatingVariable (PS)
Tests passing on a multi-GPU system:
[       OK ] DistributedVariableTest.testAssignReturnValueIsTensorLike_test_aggregation_VariableAggregationMEAN_distribution_CentralStorage2GPUs_mode_eager_synchronization_VariableSynchronizationONREAD
[       OK ] DistributedVariableTest.testAssignReturnValueIsTensorLike_test_aggregation_VariableAggregationMEAN_distribution_CentralStorage2GPUs_mode_eager_synchronization_VariableSynchronizationONWRITE
[       OK ] DistributedVariableTest.testAssignReturnValueIsTensorLike_test_aggregation_VariableAggregationMEAN_distribution_CentralStorage2GPUs_mode_graph_synchronization_VariableSynchronizationONREAD
[       OK ] DistributedVariableTest.testAssignReturnValueIsTensorLike_test_aggregation_VariableAggregationMEAN_distribution_CentralStorage2GPUs_mode_graph_synchronization_VariableSynchronizationONWRITE
[       OK ] DistributedVariableTest.testAssignReturnValueIsTensorLike_test_aggregation_VariableAggregationONLYFIRSTREPLICA_distribution_CentralStorage2GPUs_mode_eager_synchronization_VariableSynchronizationONREAD
[       OK ] DistributedVariableTest.testAssignReturnValueIsTensorLike_test_aggregation_VariableAggregationONLYFIRSTREPLICA_distribution_CentralStorage2GPUs_mode_eager_synchronization_VariableSynchronizationONWRITE
[       OK ] DistributedVariableTest.testAssignReturnValueIsTensorLike_test_aggregation_VariableAggregationONLYFIRSTREPLICA_distribution_CentralStorage2GPUs_mode_graph_synchronization_VariableSynchronizationONREAD
[       OK ] DistributedVariableTest.testAssignReturnValueIsTensorLike_test_aggregation_VariableAggregationONLYFIRSTREPLICA_distribution_CentralStorage2GPUs_mode_graph_synchronization_VariableSynchronizationONWRITE
[       OK ] DistributedVariableTest.testAssignReturnValueIsTensorLike_test_aggregation_VariableAggregationSUM_distribution_CentralStorage2GPUs_mode_eager_synchronization_VariableSynchronizationONREAD
[       OK ] DistributedVariableTest.testAssignReturnValueIsTensorLike_test_aggregation_VariableAggregationSUM_distribution_CentralStorage2GPUs_mode_eager_synchronization_VariableSynchronizationONWRITE
[       OK ] DistributedVariableTest.testAssignReturnValueIsTensorLike_test_aggregation_VariableAggregationSUM_distribution_CentralStorage2GPUs_mode_graph_synchronization_VariableSynchronizationONREAD
[       OK ] DistributedVariableTest.testAssignReturnValueIsTensorLike_test_aggregation_VariableAggregationSUM_distribution_CentralStorage2GPUs_mode_graph_synchronization_VariableSynchronizationONWRITE
[       OK ] DistributedVariableTest.testAssignSignature_test_aggregation_VariableAggregationMEAN_distribution_CentralStorage2GPUs_mode_eager_synchronization_VariableSynchronizationONREAD
[       OK ] DistributedVariableTest.testAssignSignature_test_aggregation_VariableAggregationMEAN_distribution_CentralStorage2GPUs_mode_eager_synchronization_VariableSynchronizationONWRITE
[       OK ] DistributedVariableTest.testAssignSignature_test_aggregation_VariableAggregationMEAN_distribution_CentralStorage2GPUs_mode_graph_synchronization_VariableSynchronizationONREAD
[       OK ] DistributedVariableTest.testAssignSignature_test_aggregation_VariableAggregationMEAN_distribution_CentralStorage2GPUs_mode_graph_synchronization_VariableSynchronizationONWRITE
[       OK ] DistributedVariableTest.testAssignSignature_test_aggregation_VariableAggregationONLYFIRSTREPLICA_distribution_CentralStorage2GPUs_mode_eager_synchronization_VariableSynchronizationONREAD
[       OK ] DistributedVariableTest.testAssignSignature_test_aggregation_VariableAggregationONLYFIRSTREPLICA_distribution_CentralStorage2GPUs_mode_eager_synchronization_VariableSynchronizationONWRITE
[       OK ] DistributedVariableTest.testAssignSignature_test_aggregation_VariableAggregationONLYFIRSTREPLICA_distribution_CentralStorage2GPUs_mode_graph_synchronization_VariableSynchronizationONREAD
[       OK ] DistributedVariableTest.testAssignSignature_test_aggregation_VariableAggregationONLYFIRSTREPLICA_distribution_CentralStorage2GPUs_mode_graph_synchronization_VariableSynchronizationONWRITE
[       OK ] DistributedVariableTest.testAssignSignature_test_aggregation_VariableAggregationSUM_distribution_CentralStorage2GPUs_mode_eager_synchronization_VariableSynchronizationONREAD
[       OK ] DistributedVariableTest.testAssignSignature_test_aggregation_VariableAggregationSUM_distribution_CentralStorage2GPUs_mode_eager_synchronization_VariableSynchronizationONWRITE
[       OK ] DistributedVariableTest.testAssignSignature_test_aggregation_VariableAggregationSUM_distribution_CentralStorage2GPUs_mode_graph_synchronization_VariableSynchronizationONREAD
[       OK ] DistributedVariableTest.testAssignSignature_test_aggregation_VariableAggregationSUM_distribution_CentralStorage2GPUs_mode_graph_synchronization_VariableSynchronizationONWRITE
[       OK ] DistributedVariableTest.testCheckpointing_test_aggregation_VariableAggregationMEAN_distribution_CentralStorage2GPUs_mode_eager_synchronization_VariableSynchronizationONREAD
[       OK ] DistributedVariableTest.testCheckpointing_test_aggregation_VariableAggregationMEAN_distribution_CentralStorage2GPUs_mode_eager_synchronization_VariableSynchronizationONWRITE
[       OK ] DistributedVariableTest.testCheckpointing_test_aggregation_VariableAggregationMEAN_distribution_CentralStorage2GPUs_mode_graph_synchronization_VariableSynchronizationONREAD
[       OK ] DistributedVariableTest.testCheckpointing_test_aggregation_VariableAggregationMEAN_distribution_CentralStorage2GPUs_mode_graph_synchronization_VariableSynchronizationONWRITE
[       OK ] DistributedVariableTest.testCheckpointing_test_aggregation_VariableAggregationONLYFIRSTREPLICA_distribution_CentralStorage2GPUs_mode_eager_synchronization_VariableSynchronizationONREAD
[       OK ] DistributedVariableTest.testCheckpointing_test_aggregation_VariableAggregationONLYFIRSTREPLICA_distribution_CentralStorage2GPUs_mode_eager_synchronization_VariableSynchronizationONWRITE
[       OK ] DistributedVariableTest.testCheckpointing_test_aggregation_VariableAggregationONLYFIRSTREPLICA_distribution_CentralStorage2GPUs_mode_graph_synchronization_VariableSynchronizationONREAD
[       OK ] DistributedVariableTest.testCheckpointing_test_aggregation_VariableAggregationONLYFIRSTREPLICA_distribution_CentralStorage2GPUs_mode_graph_synchronization_VariableSynchronizationONWRITE
[       OK ] DistributedVariableTest.testCheckpointing_test_aggregation_VariableAggregationSUM_distribution_CentralStorage2GPUs_mode_eager_synchronization_VariableSynchronizationONREAD
[       OK ] DistributedVariableTest.testCheckpointing_test_aggregation_VariableAggregationSUM_distribution_CentralStorage2GPUs_mode_eager_synchronization_VariableSynchronizationONWRITE
[       OK ] DistributedVariableTest.testCheckpointing_test_aggregation_VariableAggregationSUM_distribution_CentralStorage2GPUs_mode_graph_synchronization_VariableSynchronizationONREAD
[       OK ] DistributedVariableTest.testCheckpointing_test_aggregation_VariableAggregationSUM_distribution_CentralStorage2GPUs_mode_graph_synchronization_VariableSynchronizationONWRITE
[       OK ] DistributedVariableTest.testDeepCopy_test_aggregation_VariableAggregationMEAN_distribution_CentralStorage2GPUs_mode_eager_synchronization_VariableSynchronizationONREAD
[       OK ] DistributedVariableTest.testDeepCopy_test_aggregation_VariableAggregationMEAN_distribution_CentralStorage2GPUs_mode_eager_synchronization_VariableSynchronizationONWRITE
[       OK ] DistributedVariableTest.testDeepCopy_test_aggregation_VariableAggregationONLYFIRSTREPLICA_distribution_CentralStorage2GPUs_mode_eager_synchronization_VariableSynchronizationONREAD
[       OK ] DistributedVariableTest.testDeepCopy_test_aggregation_VariableAggregationONLYFIRSTREPLICA_distribution_CentralStorage2GPUs_mode_eager_synchronization_VariableSynchronizationONWRITE
[       OK ] DistributedVariableTest.testDeepCopy_test_aggregation_VariableAggregationSUM_distribution_CentralStorage2GPUs_mode_eager_synchronization_VariableSynchronizationONREAD
[       OK ] DistributedVariableTest.testDeepCopy_test_aggregation_VariableAggregationSUM_distribution_CentralStorage2GPUs_mode_eager_synchronization_VariableSynchronizationONWRITE
[       OK ] DistributedVariableTest.testExtendsVariable_test_aggregation_VariableAggregationMEAN_distribution_CentralStorage2GPUs_mode_eager_synchronization_VariableSynchronizationONREAD
[       OK ] DistributedVariableTest.testExtendsVariable_test_aggregation_VariableAggregationMEAN_distribution_CentralStorage2GPUs_mode_eager_synchronization_VariableSynchronizationONWRITE
[       OK ] DistributedVariableTest.testExtendsVariable_test_aggregation_VariableAggregationMEAN_distribution_CentralStorage2GPUs_mode_graph_synchronization_VariableSynchronizationONREAD
[       OK ] DistributedVariableTest.testExtendsVariable_test_aggregation_VariableAggregationMEAN_distribution_CentralStorage2GPUs_mode_graph_synchronization_VariableSynchronizationONWRITE
[       OK ] DistributedVariableTest.testExtendsVariable_test_aggregation_VariableAggregationONLYFIRSTREPLICA_distribution_CentralStorage2GPUs_mode_eager_synchronization_VariableSynchronizationONREAD
[       OK ] DistributedVariableTest.testExtendsVariable_test_aggregation_VariableAggregationONLYFIRSTREPLICA_distribution_CentralStorage2GPUs_mode_eager_synchronization_VariableSynchronizationONWRITE
[       OK ] DistributedVariableTest.testExtendsVariable_test_aggregation_VariableAggregationONLYFIRSTREPLICA_distribution_CentralStorage2GPUs_mode_graph_synchronization_VariableSynchronizationONREAD
[       OK ] DistributedVariableTest.testExtendsVariable_test_aggregation_VariableAggregationONLYFIRSTREPLICA_distribution_CentralStorage2GPUs_mode_graph_synchronization_VariableSynchronizationONWRITE
[       OK ] DistributedVariableTest.testExtendsVariable_test_aggregation_VariableAggregationSUM_distribution_CentralStorage2GPUs_mode_eager_synchronization_VariableSynchronizationONREAD
[       OK ] DistributedVariableTest.testExtendsVariable_test_aggregation_VariableAggregationSUM_distribution_CentralStorage2GPUs_mode_eager_synchronization_VariableSynchronizationONWRITE
[       OK ] DistributedVariableTest.testExtendsVariable_test_aggregation_VariableAggregationSUM_distribution_CentralStorage2GPUs_mode_graph_synchronization_VariableSynchronizationONREAD
[       OK ] DistributedVariableTest.testExtendsVariable_test_aggregation_VariableAggregationSUM_distribution_CentralStorage2GPUs_mode_graph_synchronization_VariableSynchronizationONWRITE
[       OK ] DistributedVariableTest.testIsTensorLike_test_aggregation_VariableAggregationMEAN_distribution_CentralStorage2GPUs_mode_eager_synchronization_VariableSynchronizationONREAD
[       OK ] DistributedVariableTest.testIsTensorLike_test_aggregation_VariableAggregationMEAN_distribution_CentralStorage2GPUs_mode_eager_synchronization_VariableSynchronizationONWRITE
[       OK ] DistributedVariableTest.testIsTensorLike_test_aggregation_VariableAggregationMEAN_distribution_CentralStorage2GPUs_mode_graph_synchronization_VariableSynchronizationONREAD
[       OK ] DistributedVariableTest.testIsTensorLike_test_aggregation_VariableAggregationMEAN_distribution_CentralStorage2GPUs_mode_graph_synchronization_VariableSynchronizationONWRITE
[       OK ] DistributedVariableTest.testIsTensorLike_test_aggregation_VariableAggregationONLYFIRSTREPLICA_distribution_CentralStorage2GPUs_mode_eager_synchronization_VariableSynchronizationONREAD
[       OK ] DistributedVariableTest.testIsTensorLike_test_aggregation_VariableAggregationONLYFIRSTREPLICA_distribution_CentralStorage2GPUs_mode_eager_synchronization_VariableSynchronizationONWRITE
[       OK ] DistributedVariableTest.testIsTensorLike_test_aggregation_VariableAggregationONLYFIRSTREPLICA_distribution_CentralStorage2GPUs_mode_graph_synchronization_VariableSynchronizationONREAD
[       OK ] DistributedVariableTest.testIsTensorLike_test_aggregation_VariableAggregationONLYFIRSTREPLICA_distribution_CentralStorage2GPUs_mode_graph_synchronization_VariableSynchronizationONWRITE
[       OK ] DistributedVariableTest.testIsTensorLike_test_aggregation_VariableAggregationSUM_distribution_CentralStorage2GPUs_mode_eager_synchronization_VariableSynchronizationONREAD
[       OK ] DistributedVariableTest.testIsTensorLike_test_aggregation_VariableAggregationSUM_distribution_CentralStorage2GPUs_mode_eager_synchronization_VariableSynchronizationONWRITE
[       OK ] DistributedVariableTest.testIsTensorLike_test_aggregation_VariableAggregationSUM_distribution_CentralStorage2GPUs_mode_graph_synchronization_VariableSynchronizationONREAD
[       OK ] DistributedVariableTest.testIsTensorLike_test_aggregation_VariableAggregationSUM_distribution_CentralStorage2GPUs_mode_graph_synchronization_VariableSynchronizationONWRITE
[       OK ] DistributedVariableTest.testSelectReplica_test_aggregation_VariableAggregationMEAN_distribution_CentralStorage2GPUs_mode_eager_synchronization_VariableSynchronizationONREAD
[       OK ] DistributedVariableTest.testSelectReplica_test_aggregation_VariableAggregationMEAN_distribution_CentralStorage2GPUs_mode_eager_synchronization_VariableSynchronizationONWRITE
[       OK ] DistributedVariableTest.testSelectReplica_test_aggregation_VariableAggregationMEAN_distribution_CentralStorage2GPUs_mode_graph_synchronization_VariableSynchronizationONREAD
[       OK ] DistributedVariableTest.testSelectReplica_test_aggregation_VariableAggregationMEAN_distribution_CentralStorage2GPUs_mode_graph_synchronization_VariableSynchronizationONWRITE
[       OK ] DistributedVariableTest.testSelectReplica_test_aggregation_VariableAggregationONLYFIRSTREPLICA_distribution_CentralStorage2GPUs_mode_eager_synchronization_VariableSynchronizationONREAD
[       OK ] DistributedVariableTest.testSelectReplica_test_aggregation_VariableAggregationONLYFIRSTREPLICA_distribution_CentralStorage2GPUs_mode_eager_synchronization_VariableSynchronizationONWRITE
[       OK ] DistributedVariableTest.testSelectReplica_test_aggregation_VariableAggregationONLYFIRSTREPLICA_distribution_CentralStorage2GPUs_mode_graph_synchronization_VariableSynchronizationONREAD
[       OK ] DistributedVariableTest.testSelectReplica_test_aggregation_VariableAggregationONLYFIRSTREPLICA_distribution_CentralStorage2GPUs_mode_graph_synchronization_VariableSynchronizationONWRITE
[       OK ] DistributedVariableTest.testSelectReplica_test_aggregation_VariableAggregationSUM_distribution_CentralStorage2GPUs_mode_eager_synchronization_VariableSynchronizationONREAD
[       OK ] DistributedVariableTest.testSelectReplica_test_aggregation_VariableAggregationSUM_distribution_CentralStorage2GPUs_mode_eager_synchronization_VariableSynchronizationONWRITE
[       OK ] DistributedVariableTest.testSelectReplica_test_aggregation_VariableAggregationSUM_distribution_CentralStorage2GPUs_mode_graph_synchronization_VariableSynchronizationONREAD
[       OK ] DistributedVariableTest.testSelectReplica_test_aggregation_VariableAggregationSUM_distribution_CentralStorage2GPUs_mode_graph_synchronization_VariableSynchronizationONWRITE
[       OK ] DistributedVariableTest.testTraceback_test_aggregation_VariableAggregationMEAN_distribution_CentralStorage2GPUs_mode_graph_synchronization_VariableSynchronizationONREAD
[       OK ] DistributedVariableTest.testTraceback_test_aggregation_VariableAggregationMEAN_distribution_CentralStorage2GPUs_mode_graph_synchronization_VariableSynchronizationONWRITE
[       OK ] DistributedVariableTest.testTraceback_test_aggregation_VariableAggregationONLYFIRSTREPLICA_distribution_CentralStorage2GPUs_mode_graph_synchronization_VariableSynchronizationONREAD
[       OK ] DistributedVariableTest.testTraceback_test_aggregation_VariableAggregationONLYFIRSTREPLICA_distribution_CentralStorage2GPUs_mode_graph_synchronization_VariableSynchronizationONWRITE
[       OK ] DistributedVariableTest.testTraceback_test_aggregation_VariableAggregationSUM_distribution_CentralStorage2GPUs_mode_graph_synchronization_VariableSynchronizationONREAD
[       OK ] DistributedVariableTest.testTraceback_test_aggregation_VariableAggregationSUM_distribution_CentralStorage2GPUs_mode_graph_synchronization_VariableSynchronizationONWRITE

PiperOrigin-RevId: 327440841
Change-Id: I86b33681b5ad187f5d3f5e8a0d6d374edfafc8a6
2020-08-19 08:57:16 -07:00
Eugene Brevdo
f6b3dec1b0 [TF DistStrat] Add proper __deepcopy__ support for all DistributedVariable objects (eager mode only).
PiperOrigin-RevId: 326669797
Change-Id: I3f7ec350efdd2c25456fce920ea43a2aa8ba373d
2020-08-14 09:31:56 -07:00
Anjali Sridhar
4a56b567d5 Restore SyncOnReadVariables by dividing the value across all the replicas in sync. The variable that is aggregated and assigned should be across all the replicas in the cluster and not just the local host.
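
A minimal arithmetic sketch of why the divisor matters, assuming SUM aggregation and a cluster of two hosts with two replicas each. The replica counts and the helper `restore_component` are assumptions made for illustration, not the library's code.

```python
def restore_component(saved_value, num_replicas):
    """Value each replica receives so that a SUM read returns saved_value."""
    return saved_value / num_replicas


saved_value = 8.0
local_replicas = 2    # replicas on this host (assumed for illustration)
replicas_in_sync = 4  # replicas in the whole cluster (assumed)

# Dividing by the cluster-wide count: summing over all replicas recovers
# the saved value.
correct = [restore_component(saved_value, replicas_in_sync)] * replicas_in_sync

# Dividing by the local count only: the cluster-wide sum is inflated.
wrong = [restore_component(saved_value, local_replicas)] * replicas_in_sync
```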
PiperOrigin-RevId: 324131978
Change-Id: Ib112f2622918917e56704ce2557ae033900fb50f
2020-07-30 19:22:48 -07:00
Ran Chen
1cc4dbe2b2 Check is_saving_non_distributed() before entering strategy
Entering strategy may enter a device scope, which is something we would like to
avoid when saving a non distributed version of the model.

PiperOrigin-RevId: 324100191
Change-Id: I3f51c457e03710fbcd706bddf8ba1c4266667075
2020-07-30 16:02:52 -07:00
Priya Gupta
26dc0e3ee2 Add multi worker mirrored strategy to DistributedVariable test. Some cases that are broken are currently skipped and being fixed in separate changes.
PiperOrigin-RevId: 323948947
Change-Id: I3bd22d5309fb7be491b6036134b328883957e15c
2020-07-30 00:08:23 -07:00
Priya Gupta
640cdad89f Remove unused helper in values_test.py
PiperOrigin-RevId: 323927733
Change-Id: I4ff7e83d9f7419430259e488456a61f440a061ac
2020-07-29 20:53:01 -07:00
Ran Chen
8a0e1711f4 Save a non-distributed model correctly
We used to save in either the default replica context or the cross replica
context. In either case the variables behave in the desired way, which is as if
there's no distribute strategy.

This change implements this behavior explicitly. Note that now you will be
able to optionally save a distributed version of the model by setting
experimental_variable_policy to EXPAND_DISTRIBUTED_VARIABLES in SaveOptions.

This change is somewhat messy due to the ongoing refactoring to
DistributedVariable, but the fix is important and we can clean up later.

PiperOrigin-RevId: 323746817
Change-Id: I5ec5db232d86be97c93a2d54c9d3b1ceb344b3df
2020-07-29 02:23:52 -07:00
Anjali Sridhar
765c642088 Add support for variable policy to be used by MirroredStrategy and TPUStrategy. Refactor existing values test and add an option to test variable policy.
PiperOrigin-RevId: 323391056
Change-Id: I4585e3a8e0a09300c09de95191c034a10255a58c
2020-07-27 10:46:05 -07:00
Anna R
03d890dfae Add support for variable policy to be used by MirroredStrategy and TPUStrategy. Refactor existing values test and add an option to test variable policy.
PiperOrigin-RevId: 323304491
Change-Id: I5c00791bc62a930274c254b33f4a47d671d0b7bf
2020-07-26 23:10:19 -07:00
Anjali Sridhar
f391524226 Add support for variable policy to be used by MirroredStrategy and TPUStrategy. Refactor existing values test and add an option to test variable policy.
PiperOrigin-RevId: 323190899
Change-Id: I8d97a625b33b0e97c0e0a3d76d96253a074d3359
2020-07-25 16:18:57 -07:00
Ran Chen
1c29b13764 Handle UpdateContext correctly in SyncOnReadVariable
Inside strategy.extended.update, in_cross_replica_context() returns True, but we
should check for UpdateContext as well. Inside UpdateContext, reads and writes
should behave as reading/writing the replica local variable.

Before this fix, updating synchronization=ON_READ variables from keras
optimizers likely yields incorrect results. All replicas are updated with the
value from the last replica.

PiperOrigin-RevId: 323188475
Change-Id: I2278332868857b0bc97563b311e0c059d7644720
2020-07-25 15:39:41 -07:00
Ran Chen
d378ee85b6 Implement experimental_distribute_values_from_function for CentralStorage
PiperOrigin-RevId: 322468547
Change-Id: I596d0d33f378299a74b853ac322d44e21df4b023
2020-07-21 16:57:10 -07:00
Ran Chen
32bb13dace Allow accessing SaveOptions through SaveContext
Distributed variables need to behave differently when tracing functions with
different SaveOptions.

We need to access SaveContext in tf.distribute code instead of the other way
around because we may have a handle to the strategy in saving code.

PiperOrigin-RevId: 321815477
Change-Id: Ib69f6d42c60e198c0e8e174f76bc9424e21df5b5
2020-07-17 11:17:43 -07:00
Anjali Sridhar
4e73f5f852 Remove unit tests from values_test that test functions in distribute_utils.
PiperOrigin-RevId: 319609162
Change-Id: I8b81cc472742470c6b7705e562a8188760f560ad
2020-07-04 11:07:19 -07:00
Anjali Sridhar
12d209e54c Remove commented out test options for aggregation values.
PiperOrigin-RevId: 319476017
Change-Id: I7ecddf8ed454cd6afb89bc17cb7d5e38b8548320
2020-07-02 23:28:21 -07:00
Priya Gupta
593d153ad8 tf.distribute: Implement MultiWorkerMirroredStrategy.experimental_distribute_values_from_function and add tests.
Also fixes a small bug in the regroup utility.

PiperOrigin-RevId: 319336188
Change-Id: I7eb7bcd0bc6543d6cc797f97169f797803a52282
2020-07-01 19:04:58 -07:00
Yujing Zhang
f904a86441 Fix SyncOnReadVariable.value() to always return a tensor.
PiperOrigin-RevId: 319065089
Change-Id: Ie93caeaf6e24872eaacd11b7d316cb06cc9f3211
2020-06-30 11:55:05 -07:00
Anjali Sridhar
dc30240f53 Rename import to avoid conflict with parameters.
PiperOrigin-RevId: 318113928
Change-Id: I06da2eff0eacc8e5bedfd20cf7334c17bcab008f
2020-06-24 12:22:01 -07:00
Yujing Zhang
465ca119b2 Introduce a SaveContext to detect whether we are building a graph for a SavedModel. And don't use packed variables under a SaveContext.
PiperOrigin-RevId: 317914296
Change-Id: I92cc6043484d642a1919cb5ab238d5e5cacc4c2a
2020-06-23 12:21:25 -07:00
Yujing Zhang
7e6e549c46 Support packed variable in DistributedVariable. Add an option to enable packed variable in TPUStrategy.
PiperOrigin-RevId: 317234665
Change-Id: I09e806cb8261815cd87a6d98817556dd8f7e8ed7
2020-06-18 20:12:02 -07:00
Anjali Sridhar
295ee8ab72 Assign to all component TPUMirroredVariables when assigning in replica context and aggregation=NONE.
PiperOrigin-RevId: 316754219
Change-Id: I791f392b892886404cb80868368ae4a167d8b3d8
2020-06-16 14:12:04 -07:00
Ruoxin Sang
d4dac5d7bd Clean up tpu_values.py and add some uncovered tests.
PiperOrigin-RevId: 316198971
Change-Id: Id0b1b46be39fd303183e43caac3058260ec23f63
2020-06-12 16:33:38 -07:00
Anjali Sridhar
0f3562ba77 Another round of refactoring of values.py to split utility functions that use distributed Variable types defined in values.py.
PiperOrigin-RevId: 316147517
Change-Id: I72e17b02e8f41c9cee40f4ec7f56fec2f7d860a9
2020-06-12 12:04:39 -07:00
Yujing Zhang
813698a99a Skip Strategies not using DistributedVariables in testPackedVariable.
PiperOrigin-RevId: 316141304
Change-Id: I22bf040a1fba75ce70aa6c1c2fa92c7fa287cbc8
2020-06-12 11:32:50 -07:00
Yujing Zhang
60bcab7f1d Temporarily disable testPackedVariable since it breaks TensorflowMultiGpu.
PiperOrigin-RevId: 315950708
Change-Id: Ie25132f8786f6e02ed85cc91e6cd6fb4bc821b6a
2020-06-11 12:17:26 -07:00
Yujing Zhang
f5547e8125 Introduce PackedDistributedVariable which packs multiple variables distributed across devices.
Introduce PackedVarAndDevice which represents a packed variable in a given device.

PiperOrigin-RevId: 315769635
Change-Id: Ia63b72610afeb7139bd8370bc47067a1fb165307
2020-06-10 14:46:14 -07:00
Chenkai Kuang
e362bc5542 Allow shallow copy and deep copy of DistributeDelegate using Python's copy module. Currently it leads to an infinite recursion.
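
The recursion can arise in any class that forwards unknown attributes through `__getattr__`. The sketch below uses a hypothetical `Delegate` class (not the actual DistributeDelegate code) to show one common way of avoiding it: refusing to forward the lookup of the backing attribute itself.

```python
import copy


class Delegate:
    """Hypothetical wrapper that forwards unknown attributes to a wrapped object."""

    def __init__(self, wrapped):
        self._wrapped = wrapped

    def __getattr__(self, name):
        # __getattr__ only runs when normal lookup fails. The copy module
        # creates an empty instance without calling __init__ and then probes
        # it for hooks such as __setstate__. At that point _wrapped is missing
        # too, so forwarding blindly would re-enter __getattr__ forever.
        if name == "_wrapped":
            raise AttributeError(name)
        return getattr(self._wrapped, name)


original = Delegate([1, 2, 3])
shallow = copy.copy(original)    # shares the wrapped list
deep = copy.deepcopy(original)   # owns an independent copy of the list
```

Without the guard on `_wrapped`, the probe on the empty instance would typically end in a RecursionError rather than the AttributeError that the copy module expects.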
PiperOrigin-RevId: 314996586
Change-Id: I985a1bc7a264e6d0b7e9c75baac488ee382e5213
2020-06-05 14:34:57 -07:00
Anjali Sridhar
8d31fb4b76 Refactor values.py into a utility file and a PS values file.
PiperOrigin-RevId: 313617630
Change-Id: Ie51b0f69af65b3f85701f58190da2c7eb46e1d29
2020-05-28 10:58:13 -07:00
Ran Chen
94ef9a2a9c SyncOnReadVariable.assign() should return Tensor
Currently it returns a tf.Operation in cross replica context regardless of the read_value argument.

PiperOrigin-RevId: 312511470
Change-Id: Ia5b47cc2d4fbe4f80fa73d2649adb6b5e96a7bed
2020-05-20 10:51:54 -07:00
Anjali Sridhar
bfd3788101 Reorder functions in an effort to group utility functions that use symbols defined in values.py and are used by classes defined in values.py.
PiperOrigin-RevId: 312234995
Change-Id: I3ec7fbc1d35935da54e61d991a44bc81b0b61d67
2020-05-19 01:06:27 -07:00
Dan Moldovan
c5caa29b5e Make core.Tensor the base type for Tensor and replace the register_dense_tensor_like with direct subclassing.
PiperOrigin-RevId: 311206817
Change-Id: Id8ae234516d5409d6b70612a99f9f0b3ed53dc7e
2020-05-12 14:57:45 -07:00
A. Unique TensorFlower
62856c9366 Internal change
PiperOrigin-RevId: 308926386
Change-Id: I3185dfa09e948883d21ba7e3e16020e30639728c
2020-04-28 17:05:20 -07:00
Ran Chen
f7f727388c Refactor SyncOnReadVariable update methods to be consistent with MirroredVariable.
This makes the behavior of SyncOnReadVariable consistent with MirroredVariable.
E.g. previously it always returned a tf.Operation in cross replica context.

After the last refactoring SyncOnReadVariable only needs to override _update_replica to have the desired behavior.

This change also moves assign* and scatter* override to DistributedVariable
level.

This is part of the effort to make DistributedVariable.assign* return a variable.

PiperOrigin-RevId: 308894562
Change-Id: I1a58352e2af2ff57402d8fc744fcfc9610a48d8b
2020-04-28 14:14:47 -07:00
Ran Chen
5fd8e05a5b Add tests for SyncOnReadVariable.scatter*
These tests seem overly complicated because they're retrofit from a future version where scatter* is supported by SyncOnReadVariable.

PiperOrigin-RevId: 308869807
Change-Id: I4ceb98e8369de6c7f7acb2798ee59acecdcd44cd
2020-04-28 12:09:12 -07:00
Ran Chen
79abfee5c3 Install _distributed_container only at variable creation
With a following change we're going to create new DistributedVariable as return
value of assign*() and scatter*(). Installing _distributed_container multiple
times will be messy.

PiperOrigin-RevId: 307753201
Change-Id: I3c87abc301ea32b0169034324a108d6967229889
2020-04-21 23:06:33 -07:00
Ran Chen
9b027623d7 Use assertLen to avoid linter complaining
PiperOrigin-RevId: 307437879
Change-Id: I3dc9435a7635e0a57cbe7c32bd2beb3178343311
2020-04-20 11:21:08 -07:00
Ran Chen
69776d9377 DistributedVariable update methods always pass keyword arguments to
_mirrored_update again

In this way it's easier to modify the arguments, which is needed to make the
return type another DistributedVariable.

PiperOrigin-RevId: 306957879
Change-Id: I590543c47bf93abdc180ea9aec5eba12a3f1a888
2020-04-16 17:59:10 -07:00
Ran Chen
c6456bb63b Skip testing make_distributed_value_from_function for CentralStorage
It's not supported and the test is failing. CentralStorage is essentially
ParameterServer.

PiperOrigin-RevId: 306455178
Change-Id: Ia201b4da4962067660e877bce01115619c08e57b
2020-04-14 09:38:05 -07:00
Ran Chen
b16d24a342 Separate mirrored call_for_each_replica to its own file
Both ParameterServer and Mirrored use this, and it is complicated enough on its own.

This also fixes an issue where you can't call strategy.run(tf.function) under CentralStorageStrategy, by applying the same workaround we have in MirroredStrategy.

PiperOrigin-RevId: 304437119
Change-Id: I6a7a67b88e7a5b7217aa9ffe05882d0ef4097896
2020-04-02 11:11:09 -07:00
Ran Chen
f0d9ae52dd Remove init_scope test of AggregateVariable
PiperOrigin-RevId: 303443594
Change-Id: I6eaa0e54e401c938fcccc31c731a45b5da0e9d22
2020-03-27 18:09:43 -07:00
Ruoxin Sang
b70b14a462 Only create TPU replicated variable handle in graph mode.
PiperOrigin-RevId: 302461289
Change-Id: I4923d3db3e59db45e95a7a52c0c60fb42b3ee911
2020-03-23 10:22:50 -07:00
Ran Chen
c984ec0b36 Re-organize values_test
This change separates the common parts into DistributedVariableTest and moves the AggregateVariable tests into their own test class as well.

PiperOrigin-RevId: 302125354
Change-Id: I1cfba4d5956a70b7b743913eea4d0301c4c8d1ce
2020-03-20 16:50:02 -07:00
Ken Franko
0b8f0a5b84 Drop experimental and v2 qualifiers from Strategy experimental_run_v2 method.
- experimental_run_v2 -> run

PiperOrigin-RevId: 300574367
Change-Id: I5d82ea5450a4d32aea6d05ed3db4f02b8edb2eea
2020-03-12 10:26:26 -07:00
Ken Franko
3f35d6d8b0 Add experimental_make_distributed_values_from_function method to distribution strategy.
PiperOrigin-RevId: 298445422
Change-Id: I907cf7808e5bedaf45adccf3b6355ccf219e4116
2020-03-02 14:23:56 -08:00