In tests, instead of using def_function we can parameterize over these two
objects to cover both tf.function and eager execution.
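For illustration, a minimal sketch of the pattern using absl's parameterized tests; the concrete combination helpers in the TF test suite may differ:

from absl.testing import parameterized
import tensorflow as tf

def run_eagerly(fn):
  # Leave the function as-is so it executes eagerly.
  return fn

class AddTest(parameterized.TestCase):

  @parameterized.named_parameters(
      ("eager", run_eagerly),
      ("tf_function", tf.function),
  )
  def test_add(self, wrap):
    fn = wrap(lambda x, y: x + y)
    self.assertEqual(3, fn(tf.constant(1), tf.constant(2)).numpy())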
PiperOrigin-RevId: 351420316
Change-Id: I037d1678ca843f6df88694981efd4519c2947cd3
All strategies are supported except for CentralStorageStrategy and ParameterServerStrategy.
This CL also removes the CompositeTensor superclass from Generator. Generator is a wrapper around tf.Variable, and because tf.Variable is not a CompositeTensor, Generator can't be a CompositeTensor in theory. Previously we made it a CompositeTensor by returning Variable.handle, but that breaks down when the variable is a DistributedVariable (in cross-replica context).
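A minimal sketch of the supported usage, assuming a locally constructible MirroredStrategy; each replica is expected to get its own random stream:

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
  # The generator's state variable is created under the strategy, so it
  # becomes a distributed variable.
  gen = tf.random.Generator.from_seed(1234)

@tf.function
def sample():
  return gen.normal(shape=[2])

per_replica = strategy.run(sample)
print(strategy.experimental_local_results(per_replica))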
PiperOrigin-RevId: 350851648
Change-Id: I5f4d77ddb990557fcc9c7336987203ecdaec5b9a
Auto control dependencies will chain operations that share the same resource input. We'll do the same for all-gather after some refactoring is done.
PiperOrigin-RevId: 341868107
Change-Id: I5570a28c2e1c638980e3509088c0525e957c463b
We used to use one worker pool per strategy combination, but that's not
necessary: if the cluster topology is the same, combinations can share the same
worker pool. This reduces the overhead of initializing worker pools, which can
take O(10s) for GPU builds.
PiperOrigin-RevId: 339117353
Change-Id: I1f631f79597b07991528c77482c44c201a01abe4
Ops like `tf.nn.nce_loss` and `tf.nn.sampled_softmax_loss` also benefit from this as they use embedding_lookup internally.
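For context, a small hypothetical sketch of a sharded lookup: tf.nn.embedding_lookup accepts a list of variables and treats them as one partitioned table, and the sampled-loss ops above route their lookups through the same path.

import tensorflow as tf

# Hypothetical 8x8 embedding table split into two 4x8 shards.
shards = [tf.Variable(tf.random.normal([4, 8])),
          tf.Variable(tf.random.normal([4, 8]))]
ids = tf.constant([0, 5, 3])

# embedding_lookup gathers rows across the shards as if they were one table.
embeddings = tf.nn.embedding_lookup(shards, ids)
print(embeddings.shape)  # (3, 8)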
PiperOrigin-RevId: 338369985
Change-Id: I89ebe2a452fc1d599567cb80e80ee9b023e5aa1c
tf.function tracing depends on the inputs to the function. For a typical training loop:
x, y = next(iterator)
train_fn(x, y)
it may retrace when it gets a partial batch. This is problematic for multi-client training, since different clients may retrace at different times. We assign collective instance keys when tracing a function, so retracing results in different sets of instance keys.
With this change we override the PerReplica type spec, which is used to calculate the function cache key. This avoids retracing in common cases, but it doesn't guarantee that retracing won't happen.
Note that after this change the traced function only gets partial shape information, which is why we only do it for multi-client strategies (MWMS), to avoid a performance penalty for e.g. TPU.
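A runnable single-worker sketch of the situation; MirroredStrategy stands in for a multi-client strategy so the snippet runs standalone, and the last, smaller batch is what would trigger a retrace without a relaxed PerReplica spec:

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()

with strategy.scope():
  dense = tf.keras.layers.Dense(1)
  dense.build((None, 4))
  optimizer = tf.keras.optimizers.SGD(0.1)

def step_fn(x, y):
  with tf.GradientTape() as tape:
    loss = tf.reduce_mean(tf.square(dense(x) - y))
  grads = tape.gradient(loss, dense.trainable_variables)
  optimizer.apply_gradients(zip(grads, dense.trainable_variables))
  return loss

@tf.function
def train_fn(x, y):
  # Collective instance keys are assigned while tracing, so a retrace on only
  # one client would leave clients with mismatched keys.
  return strategy.run(step_fn, args=(x, y))

dataset = tf.data.Dataset.from_tensor_slices(
    (tf.random.normal([10, 4]), tf.random.normal([10, 1]))).batch(4)
for x, y in strategy.experimental_distribute_dataset(dataset):
  train_fn(x, y)  # The final batch of 2 changes the input shape.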
PiperOrigin-RevId: 338203534
Change-Id: Iae9d6c3c82113d623707e19142fbebe5597d7898
Over the past months we've made several improvements:
- Test coverage is now on par with other strategies.
- Peer failure will no longer cause the cluster to hang.
- Major issues with saving are fixed.
- gather() API is added.
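An illustrative sketch of the new gather() API; MirroredStrategy is used so the snippet runs standalone, and the strategy.gather(value, axis) entry point shown is assumed to be where the API landed:

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()

@tf.function
def value_fn():
  ctx = tf.distribute.get_replica_context()
  # Each replica contributes one row tagged with its replica id.
  return tf.fill([1, 2], tf.cast(ctx.replica_id_in_sync_group, tf.float32))

per_replica = strategy.run(value_fn)
# gather() concatenates the per-replica tensors along the given axis.
print(strategy.gather(per_replica, axis=0))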
PiperOrigin-RevId: 338175223
Change-Id: I3c52a4d53d1c487558f1caaae7d094fe2245183b
compat.v1 partitioners are left unchanged. V2 partitioners are exported in the tf.distribute namespace, since they are supposed to work with sharded variables, which are a tf.distribute concept. The implementations of the partitioners are reused.
While at it, we also took the opportunity to refine the naming:
- variable_axis_size_partitioner -> MaxSizePartitioner (partitioner that keeps shards under a maximum size)
- min_max_variable_partitioner -> MinSizePartitioner (partitioner that allocates shards above a minimum size)
- fixed_size_partitioner -> FixedShardsPartitioner (partitioner that allocates a fixed number of shards).
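A small sketch of the renamed partitioners under the assumed tf.distribute.experimental.partitioners export path; a partitioner is a callable mapping a variable's shape and dtype to the number of partitions per axis:

import tensorflow as tf

partitioner = tf.distribute.experimental.partitioners.FixedShardsPartitioner(
    num_shards=2)
# Number of partitions along each axis of a [10, 3] float32 variable.
print(partitioner(shape=tf.TensorShape([10, 3]), dtype=tf.float32))  # [2, 1]

# Typical use is with ParameterServerStrategy, which shards variables created
# under its scope (cluster_resolver setup elided):
# strategy = tf.distribute.experimental.ParameterServerStrategy(
#     cluster_resolver, variable_partitioner=partitioner)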
PiperOrigin-RevId: 338157380
Change-Id: I19f517e38f20e4e9c85745863e764da0aad6eeeb
This change exports the following class symbols, and adds relevant documentation and example code to them:
tf.distribute.experimental.ParameterServerStrategy
tf.distribute.experimental.coordinator.ClusterCoordinator
tf.distribute.experimental.coordinator.PerWorkerValues
tf.distribute.experimental.coordinator.RemoteValue
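An illustrative sketch of how the exported pieces fit together; it assumes a TF_CONFIG-based cluster with worker and ps tasks, so it will not run standalone:

import tensorflow as tf

cluster_resolver = tf.distribute.cluster_resolver.TFConfigClusterResolver()
strategy = tf.distribute.experimental.ParameterServerStrategy(cluster_resolver)
coordinator = tf.distribute.experimental.coordinator.ClusterCoordinator(strategy)

@tf.function
def step_fn():
  return tf.constant(1.0)

# schedule() returns a RemoteValue immediately; the function runs on a worker.
remote_value = coordinator.schedule(step_fn)
coordinator.join()           # Block until all scheduled functions finish.
print(remote_value.fetch())  # Fetch the result held by the RemoteValue.

# Per-worker datasets produce PerWorkerValues, one iterator per worker:
# per_worker_dataset = coordinator.create_per_worker_dataset(dataset_fn)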
PiperOrigin-RevId: 338151262
Change-Id: If2d1c513d30a999c728cecc2e73b75adda1948c2
Over the past months we've made several improvements:
- Test coverage is now on par with other strategies.
- Peer failure will no longer cause the cluster to hang.
- Major issues with saving are fixed.
- gather() API is added.
PiperOrigin-RevId: 338132035
Change-Id: I384c084717cd5f2b6167668ebe96af0f7b371530
Over the past months we've made several improvements:
- Test coverage is now on par with other strategies.
- Peer failure will no longer cause the cluster to hang.
- Major issues with saving are fixed.
- gather() API is added.
PiperOrigin-RevId: 338110984
Change-Id: I92eeb981c67acb0c44f658316b6ad564162508bc
tf.function tracing depends on the inputs to the function. For a typical training loop:
x, y = next(iterator)
train_fn(x, y)
it may retrace when it gets a partial batch. This is problematic for multi-client training, since different clients may retrace at different times. We assign collective instance keys when tracing a function, so retracing results in different sets of instance keys.
With this change we override the PerReplica type spec, which is used to calculate the function cache key. This avoids retracing in common cases, but it doesn't guarantee that retracing won't happen.
Note that after this change the traced function only gets partial shape information, which is why we only do it for multi-client strategies (MWMS), to avoid a performance penalty for e.g. TPU.
PiperOrigin-RevId: 337792983
Change-Id: Ib029d61cd360d6a25e38e894913e4d78af20d1dd
The CollectiveHints class is also renamed to CommunicationOptions, and the communication enum is added to it.
CommunicationOptions stays experimental since the detailed options may change, but it's clear that we need an options argument for these cross-device communications.
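A minimal sketch of the renamed options object, assuming the TF 2.4-era surface; constructing the strategy without a cluster falls back to single-worker behavior:

import tensorflow as tf

options = tf.distribute.experimental.CommunicationOptions(
    implementation=tf.distribute.experimental.CommunicationImplementation.AUTO)

# The options object is passed when constructing the strategy and can also be
# supplied to individual cross-device calls that accept an options argument.
strategy = tf.distribute.MultiWorkerMirroredStrategy(
    communication_options=options)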
PiperOrigin-RevId: 337547832
Change-Id: I376171672698d5923b4e52f2567d4a584c8e21b6
With a distribution strategy, traced ConcreteFunctions may contain training-specific logic that assumes the variable is a distributed variable. Such functions cannot be used for inference. Since we do not know whether such a ConcreteFunction will be saved for inference or not, we always mark it as unsaveable unless it is traced under a save context.
The user can pass a tf.function instead, which can be retraced during saving (see the sketch below).
Impacted usages:
- MultiWorkerMirroredStrategy
  - Reading a synchronization=ON_READ variable. E.g. a batch norm layer.
- MultiWorkerMirroredStrategy, MirroredStrategy, TPUStrategy
  - Updating a variable.
  - Reading a synchronization=ON_READ aggregation=SUM variable.
It's TBD whether we also need to mark functions that use a packed handle as unsaveable. They do contain TPU:0 device annotations, but with soft placement that may not be a problem.
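A runnable sketch of the recommended pattern: export a tf.function with an input signature so it is retraced under the save context rather than reusing a strategy-traced ConcreteFunction (the model and export path are illustrative):

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
  model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])

@tf.function(input_signature=[tf.TensorSpec([None, 4], tf.float32)])
def serve(x):
  return model(x)

# Passing the tf.function (not a pre-traced ConcreteFunction) lets the saver
# trace it under the save context, where distributed variables are handled.
tf.saved_model.save(model, "/tmp/exported_model", signatures=serve)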
PiperOrigin-RevId: 337438256
Change-Id: Ie89d0d6beb3e71d3ebbb867d1f91f2953468840c
This allows us to:
1. Pass `ShardedVariable` to tf.function inputs while avoiding retracing if the spec of the ShardedVariable doesn't change.
2. Use `nest.flatten(sharded_variable, expand_composites=True)` to retrieve the list of component variables. This is used by `tf.Module` and Keras Layer to collect variables from attributes that are nested structures, so this change enables them to collect component variables when a ShardedVariable is assigned to one of their attributes.
`layer.add_weight` already works; this change adds a test for that.
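A minimal sketch of point 2; the ShardedVariable import path below is an internal module and is an assumption for illustration:

import tensorflow as tf
from tensorflow.python.distribute.sharded_variable import ShardedVariable

sv = ShardedVariable([tf.Variable([0., 1.]), tf.Variable([2., 3.])])

# With ShardedVariable acting as a composite, expand_composites=True recovers
# the component variables, which is how tf.Module / Keras layers track them.
print(tf.nest.flatten(sv, expand_composites=True))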
PiperOrigin-RevId: 337382403
Change-Id: I4c7e490cdc8fd772ed57c4074894637147986dac
It appears that the polarity of the use of has_tensor_list_arg was
inadvertently flipped.
Disable any MLIR-bridge-enabled tests that were passing only because, due to
this issue, they weren't actually using the MLIR bridge.
PiperOrigin-RevId: 337125651
Change-Id: I93e9e61acda9a2aeffaee5cce13e93635d33f5a4
Collective v2 doesn't support scoped allocator. While it's possible to make scoped allocator work with it, concat/split is much simpler.
PiperOrigin-RevId: 337109439
Change-Id: I3535f5e0b090696f3bb620617f2b57f1f4b78b22
Collective v2 doesn't support scoped allocator. While it's possible to make scoped allocator work with it, concat/split is much simpler.
PiperOrigin-RevId: 337021023
Change-Id: I6e6e2fdc3c94ffbc59a52c20a451dcd74fd864e4