Update the variable partitioning section of ps strategy docstring.

PiperOrigin-RevId: 338367544
Change-Id: Iaff150c3b5a5e8179bcd49f02d3f5d9bdec20a02
Chenkai Kuang 2020-10-21 16:51:15 -07:00 committed by TensorFlower Gardener
parent 6d4f1d5c09
commit 6cc0cf5e30

@@ -264,51 +264,75 @@ class ParameterServerStrategyV2(distribute_lib.Strategy):
__Variable partitioning__

Having dedicated servers to store variables means being able to divide up, or
"shard" the variables across the ps. Partitioning a large variable among
parameter servers is a commonly used technique to boost training throughput
and mitigate memory constraints. It enables parallel computations and updates
on different shards of a variable, and often yields better load balancing
across parameter servers. Without sharding, models with large variables (e.g.,
embeddings) that can't fit into one machine's memory would otherwise be unable
to train.

With `tf.distribute.experimental.ParameterServerStrategy`, if a
`variable_partitioner` is provided to `__init__` and certain conditions are
satisfied, the resulting variables created in scope are sharded across the
parameter servers, in a round-robin fashion. The variable reference returned
from `tf.Variable` becomes a type that serves as the container of the sharded
variables. One can access the `variables` attribute of this container for the
actual variable components. If building a model with `tf.Module` or Keras,
the variable components are collected in the `variables`-like attributes.

```python
class Dense(tf.Module):
  def __init__(self, name=None):
    super().__init__(name=name)
    self.w = tf.Variable(tf.random.normal([100, 10]), name='w')

  def __call__(self, x):
    return x * self.w

# Partition the dense layer into 2 shards.
variable_partitioner = (
    tf.distribute.experimental.partitioners.FixedShardsPartitioner(
        num_shards=2))
strategy = tf.distribute.experimental.ParameterServerStrategy(
    cluster_resolver=...,
    variable_partitioner=variable_partitioner)
with strategy.scope():
  dense = Dense()

# `dense.variables` gives the actual variable components.
assert len(dense.variables) == 2
assert isinstance(dense.variables[0], tf.Variable)
assert isinstance(dense.variables[1], tf.Variable)
assert dense.variables[0].name == "w/part_0"
assert dense.variables[1].name == "w/part_1"
```

The sharded variable container can be converted to a `Tensor` via
`tf.convert_to_tensor`. This means the container can be directly used in most
Python ops where such a `Tensor` conversion automatically happens. For example,
in the above code snippet, `x * self.w` would implicitly apply the said tensor
conversion. Note that such a conversion can be expensive, as the variable
components need to be transferred from multiple parameter servers to where
the value is used.

`tf.nn.embedding_lookup`, on the other hand, doesn't apply the tensor
conversion, and instead performs parallel lookups on the variable components.
This is crucial for scaling up embedding lookups when the embedding table
variable is large.
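
A rough sketch of both behaviors, assuming the `strategy` from the snippet
above and a hypothetical sharded `emb` variable created under its scope:

```python
# Minimal sketch, not part of the documented API surface: `strategy` is
# assumed to be configured with a `variable_partitioner` as in the example.
with strategy.scope():
  emb = tf.Variable(tf.random.normal([100, 16]), name='emb')  # sharded

# Explicit conversion: concatenates the shards into one tensor, which may
# require transferring every component from the parameter servers.
full_table = tf.convert_to_tensor(emb)

# Embedding lookup: fetches only the needed rows from the shards that hold
# them, without materializing the full table.
rows = tf.nn.embedding_lookup(emb, tf.constant([0, 3, 7]))
```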

When a partitioned variable is saved to a `SavedModel`, it will be saved as if
it were one single variable. This improves serving efficiency by eliminating
a number of ops that handle the partition aspects.
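
For instance, a sketch assuming the `dense` module from the example above and
a hypothetical export path:

```python
# Sketch only: the export path is hypothetical. The two shards of `w` are
# written out as if they were one single variable.
tf.saved_model.save(dense, "/tmp/sharded_dense")
```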

Known limitations of variable partitioning:

* The number of partitions must not change across checkpoint save/load.

* After saving partitioned variables to a SavedModel, the SavedModel can't be
  loaded via `tf.saved_model.load`.

* Partitioned variables don't directly work with `tf.GradientTape`; please use
  the `variables` attribute to get the actual variable components and use
  them in gradient APIs instead, as shown in the sketch after this list.
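
A minimal sketch of that workaround, assuming the `dense` module from the
example above:

```python
# Sketch only: `dense` holds the sharded `w` from the example above.
x = tf.random.normal([100, 10])
with tf.GradientTape() as tape:
  # `dense(x)` implicitly converts the sharded `w` container to a tensor.
  loss = tf.reduce_sum(dense(x))

# Pass the actual variable components, not the sharded container, to the tape.
grads = tape.gradient(loss, dense.variables)
assert len(grads) == 2
```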
__Dataset preparation__

With `tf.distribute.experimental.ParameterServerStrategy`, a dataset is
@@ -367,37 +391,34 @@ class ParameterServerStrategyV2(distribute_lib.Strategy):
cluster_resolver: a `tf.distribute.cluster_resolver.ClusterResolver`
  object.
variable_partitioner:
  a `tf.distribute.experimental.partitioners.Partitioner` that specifies
  how to partition variables. If `None`, variables will not be
  partitioned.

  * Predefined partitioners in `tf.distribute.experimental.partitioners`
    can be used for this argument. A commonly used partitioner is
    `MinSizePartitioner(min_shard_bytes=256 << 10, max_shards=num_ps)`,
    which allocates at least 256K per shard, and each ps gets at most one
    shard.

  * `variable_partitioner` will be called for each variable created under
    strategy `scope` to instruct how the variable should be partitioned.
    Variables that have only one partition along the partitioning axis
    (i.e., no need for partition) will be created as normal `tf.Variable`s.

  * Only partitioning along the first / outermost axis is supported.

  * The div partition strategy is used to partition variables. Assuming we
    assign consecutive integer ids along the first axis of a variable, then
    ids are assigned to shards in a contiguous manner, while attempting to
    keep each shard size identical. If the ids do not evenly divide the
    number of shards, each of the first several shards will be assigned one
    more id. For instance, a variable whose first dimension is 13 has 13
    ids, and they are split across 5 shards as:
    `[[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10], [11, 12]]`.

  * Variables created under `strategy.extended.colocate_vars_with` will
    not be partitioned.
"""
# pyformat: enable
self._cluster_resolver = cluster_resolver