# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""ShardedVariable class."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import copy

from tensorflow.python.framework import composite_tensor
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import type_spec
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import embedding_ops
from tensorflow.python.ops import partitioned_variables
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variables as variables_lib
from tensorflow.python.saved_model import revived_types
from tensorflow.python.saved_model import save_context
from tensorflow.python.training.saving import saveable_object_util
from tensorflow.python.training.tracking import base as trackable
from tensorflow.python.util import dispatch
from tensorflow.python.util.tf_export import tf_export


@tf_export('distribute.experimental.partitioners.Partitioner', v1=[])
class Partitioner(object):
  """Partitioner base class: all partitioners inherit from this class.

  Partitioners should implement a `__call__` method with the following
  signature:

  ```python
  def __call__(self, shape, dtype, axis=0):
    # Partitions the given `shape` and returns the partition results.
    # See docstring of `__call__` method for the format of partition results.
  ```
  """

  def __call__(self, shape, dtype, axis=0):
    """Partitions the given `shape` and returns the partition results.

    Example of a partitioner that allocates a fixed number of shards:

    ```python
    partitioner = FixedShardsPartitioner(num_shards=2)
    partitions = partitioner(tf.TensorShape([10, 3]), tf.float32, axis=0)
    print(partitions)  # [2, 1]
    ```

    Args:
      shape: a `tf.TensorShape`, the shape to partition.
      dtype: a `tf.dtypes.Dtype` indicating the type of the partition value.
      axis: The axis to partition along. Default: outermost axis.

    Returns:
      A list of integers representing the number of partitions on each axis,
      where the i-th value corresponds to the i-th axis.
    """
    raise NotImplementedError


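# A minimal sketch of a custom partitioner (illustrative only; the class name
# `RowChunkPartitioner` is hypothetical, not part of the exported API). It
# follows the `__call__` contract above, creating one shard per chunk of at
# most `rows_per_shard` rows along the chosen axis:
#
#   class RowChunkPartitioner(Partitioner):
#
#     def __init__(self, rows_per_shard):
#       self._rows_per_shard = rows_per_shard
#
#     def __call__(self, shape, dtype, axis=0):
#       del dtype  # Unused: partitioning here depends only on the shape.
#       result = [1] * len(shape)
#       num_rows = shape.dims[axis].value
#       result[axis] = -(-num_rows // self._rows_per_shard)  # ceiling division
#       return result
#
#   RowChunkPartitioner(rows_per_shard=4)(tf.TensorShape([10, 3]), tf.float32)
#   # -> [3, 1]
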
@tf_export('distribute.experimental.partitioners.FixedShardsPartitioner', v1=[])
class FixedShardsPartitioner(Partitioner):
  """Partitioner that allocates a fixed number of shards.

  Examples:

  >>> # standalone usage:
  >>> partitioner = FixedShardsPartitioner(num_shards=2)
  >>> partitions = partitioner(tf.TensorShape([10, 3]), tf.float32)
  >>> partitions
  [2, 1]
  >>>
  >>> # use in ParameterServerStrategy
  >>> # strategy = tf.distribute.experimental.ParameterServerStrategy(
  >>> #   cluster_resolver=cluster_resolver, variable_partitioner=partitioner)
  """

  def __init__(self, num_shards):
    """Creates a new `FixedShardsPartitioner`.

    Args:
      num_shards: `int`, number of shards to partition.
    """
    self._num_shards = num_shards

  def __call__(self, shape, dtype, axis=0):
    del dtype
    result = [1] * len(shape)
    result[axis] = min(self._num_shards, shape.dims[axis].value)
    return result


@tf_export('distribute.experimental.partitioners.MinSizePartitioner', v1=[])
class MinSizePartitioner(Partitioner):
  """Partitioner that allocates a minimum size per shard.

  This partitioner ensures each shard has at least `min_shard_bytes`, and tries
  to allocate as many shards as possible, i.e., keeping shard size as small as
  possible. The maximum number of such shards (upper bound) is given by
  `max_shards`.

  Examples:

  >>> partitioner = MinSizePartitioner(min_shard_bytes=4, max_shards=2)
  >>> partitions = partitioner(tf.TensorShape([6, 1]), tf.float32)
  >>> partitions
  [2, 1]
  >>> partitioner = MinSizePartitioner(min_shard_bytes=4, max_shards=10)
  >>> partitions = partitioner(tf.TensorShape([6, 1]), tf.float32)
  >>> partitions
  [6, 1]
  >>>
  >>> # use in ParameterServerStrategy
  >>> # strategy = tf.distribute.experimental.ParameterServerStrategy(
  >>> #   cluster_resolver=cluster_resolver, variable_partitioner=partitioner)
  """

  def __init__(self,
               min_shard_bytes=256 << 10,
               max_shards=1,
               bytes_per_string=16):
    """Creates a new `MinSizePartitioner`.

    Args:
      min_shard_bytes: Minimum bytes of each shard. Defaults to 256K.
      max_shards: Upper bound on the number of shards. Defaults to 1.
      bytes_per_string: If the partition value is of type string, this provides
        an estimate of how large each string is.
    """
    if min_shard_bytes < 1:
      raise ValueError('min_shard_bytes must be positive, got: %r' %
                       min_shard_bytes)
    if max_shards < 1:
      raise ValueError('max_shards must be positive, got: %r' % max_shards)
    if bytes_per_string < 1:
      raise ValueError('bytes_per_string must be positive, got: %r' %
                       bytes_per_string)
    self._min_shard_bytes = min_shard_bytes
    self._max_shards = max_shards
    self._bytes_per_string = bytes_per_string

  def __call__(self, shape, dtype, axis=0):
    return partitioned_variables.min_max_variable_partitioner(
        max_partitions=self._max_shards,
        axis=axis,
        min_slice_size=self._min_shard_bytes,
        bytes_per_string_element=self._bytes_per_string)(shape, dtype)


@tf_export('distribute.experimental.partitioners.MaxSizePartitioner', v1=[])
class MaxSizePartitioner(Partitioner):
  """Partitioner that keeps shards below `max_shard_bytes`.

  This partitioner ensures each shard has at most `max_shard_bytes`, and tries
  to allocate as few shards as possible, i.e., keeping shard size as large
  as possible.

  If the partitioner hits the `max_shards` limit, then each shard may end up
  larger than `max_shard_bytes`. By default `max_shards` equals `None` and no
  limit on the number of shards is enforced.

  Examples:

  >>> partitioner = MaxSizePartitioner(max_shard_bytes=4)
  >>> partitions = partitioner(tf.TensorShape([6, 1]), tf.float32)
  >>> partitions
  [6, 1]
  >>> partitioner = MaxSizePartitioner(max_shard_bytes=4, max_shards=2)
  >>> partitions = partitioner(tf.TensorShape([6, 1]), tf.float32)
  >>> partitions
  [2, 1]
  >>> partitioner = MaxSizePartitioner(max_shard_bytes=1024)
  >>> partitions = partitioner(tf.TensorShape([6, 1]), tf.float32)
  >>> partitions
  [1, 1]
  >>>
  >>> # use in ParameterServerStrategy
  >>> # strategy = tf.distribute.experimental.ParameterServerStrategy(
  >>> #   cluster_resolver=cluster_resolver, variable_partitioner=partitioner)
  """

  def __init__(self, max_shard_bytes, max_shards=None, bytes_per_string=16):
    """Creates a new `MaxSizePartitioner`.

    Args:
      max_shard_bytes: The maximum size any given shard is allowed to be.
      max_shards: `int`, the maximum number of shards to create; takes
        precedence over `max_shard_bytes`.
      bytes_per_string: If the partition value is of type string, this provides
        an estimate of how large each string is.
    """
    if max_shard_bytes < 1:
      raise ValueError('max_shard_bytes must be positive, got: %r' %
                       max_shard_bytes)
    if max_shards and max_shards < 1:
      raise ValueError('max_shards must be positive, got: %r' % max_shards)
    if bytes_per_string < 1:
      raise ValueError('bytes_per_string must be positive, got: %r' %
                       bytes_per_string)

    self._max_shard_bytes = max_shard_bytes
    self._max_shards = max_shards
    self._bytes_per_string = bytes_per_string

  def __call__(self, shape, dtype, axis=0):
    return partitioned_variables.variable_axis_size_partitioner(
        max_shard_bytes=self._max_shard_bytes,
        max_shards=self._max_shards,
        bytes_per_string_element=self._bytes_per_string,
        axis=axis)(shape, dtype)


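# For reference, the three built-in partitioners give different shardings for
# the same variable (values taken from the doctests above):
#
#   shape, dtype = tf.TensorShape([6, 1]), tf.float32
#   FixedShardsPartitioner(num_shards=2)(shape, dtype)                  # [2, 1]
#   MinSizePartitioner(min_shard_bytes=4, max_shards=10)(shape, dtype)  # [6, 1]
#   MaxSizePartitioner(max_shard_bytes=1024)(shape, dtype)              # [1, 1]
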
class ShardedVariableSpec(type_spec.TypeSpec):
  """Type specification for a `ShardedVariable`."""

  __slots__ = ['_variable_specs']

  value_type = property(lambda self: ShardedVariable)

  def __init__(self, *variable_specs):
    self._variable_specs = tuple(variable_specs)

  def _serialize(self):
    return self._variable_specs

  @property
  def _component_specs(self):
    return self._variable_specs

  def _to_components(self, value):
    return value.variables

  def _from_components(self, variables):
    return ShardedVariable(variables)


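# For illustration (a sketch, not executed anywhere in this module): a
# `ShardedVariable` with float32 shards of shapes [2, 3] and [1, 3] has the
# spec
#
#   ShardedVariableSpec(
#       resource_variable_ops.VariableSpec([2, 3], tf.float32),
#       resource_variable_ops.VariableSpec([1, 3], tf.float32))
#
# whose `_component_specs` are the per-shard specs and whose
# `_from_components` rebuilds the `ShardedVariable` from the shard list.
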
class ShardedVariableMixin(trackable.Trackable):
  """Mixin for ShardedVariable."""

  # TODO(b/170877138): Remove this mixin once fixed. This mixin is required
  # since TPUShardedVariable can't be a CompositeTensor.

  def __init__(self, variables, name='ShardedVariable'):
    """Treats `variables` as shards of a larger Variable.

    Example:

    ```
    variables = [
      tf.Variable(..., shape=(10, 100), dtype=tf.float32),
      tf.Variable(..., shape=(15, 100), dtype=tf.float32),
      tf.Variable(..., shape=(5, 100), dtype=tf.float32)
    ]
    sharded_variable = ShardedVariableMixin(variables)
    assert sharded_variable.shape.as_list() == [30, 100]
    ```

    Args:
      variables: A list of `ResourceVariable`s that comprise this sharded
        variable. Variables should not be shared between different
        `ShardedVariableMixin` objects.
      name: String. Name of this container. Defaults to "ShardedVariable".
    """
    super(ShardedVariableMixin, self).__init__()
    self._variables = variables
    self._name = name

    first_var = variables[0]

    if any(not isinstance(v, variables_lib.Variable) for v in variables):
      raise ValueError(
          'Expected a list of `Variable`s, found: {}'.format(variables))

    dtypes = {v.dtype for v in variables}
    if len(dtypes) > 1:
      raise ValueError(
          'All `Variable`s must have the same dtype, found: {}'.format(
              [v.dtype for v in variables]))
    self._dtype = first_var.dtype

    # All variables must have the same shape for axes > 0.
    higher_dim_shapes = {tuple(v.shape.as_list()[1:]) for v in variables}
    if len(higher_dim_shapes) > 1:
      raise ValueError(
          'All `Variable`s must have the same shape except for the first '
          'axis, found {}'.format([v.shape for v in variables]))
    first_dim = sum(int(v.shape[0]) for v in variables)
    self._shape = tensor_shape.TensorShape([first_dim] + first_var.shape[1:])

    # Offset of each shard within the full variable. Sharding is only along
    # the first axis, so offsets on the other axes stay 0. E.g., for shards
    # with first dims [10, 15, 5] the first-axis offsets are [0, 10, 25].
    self._var_offsets = [
        [0 for _ in range(len(first_var.shape))] for _ in range(len(variables))
    ]
    for i in range(1, len(variables)):
      # Always partition on the first axis. Offsets on other axes are 0.
      self._var_offsets[i][0] += (
          self._var_offsets[i - 1][0] + variables[i - 1].shape[0])

    save_slice_info = [v._get_save_slice_info() for v in variables]  # pylint: disable=protected-access
    if any(slice_info is not None for slice_info in save_slice_info):
      raise ValueError('`SaveSliceInfo` should not be set for `Variable`s. '
                       '`ShardedVariable` will infer `SaveSliceInfo` according '
                       'to the order of the `Variable`s in the list passed to '
                       'the constructor. Found {}'.format(save_slice_info))

    # We create an uninitialized saving_variable with the full shape, which can
    # be later captured in signatures so that the signatures can treat this
    # ShardedVariable as one single variable.
    self._saving_variable = resource_variable_ops.UninitializedVariable(
        shape=self._shape, dtype=self._dtype, name=self._name)

  def __iter__(self):
    """Return an iterable for accessing the underlying sharded variables."""
    return iter(self._variables)

  @property
  def _type_spec(self):
    return ShardedVariableSpec(*(
        resource_variable_ops.VariableSpec(v.shape, v.dtype)
        for v in self._variables))

  @property
  def variables(self):
    """The list of `Variable`s that make up the shards of this object."""
    if save_context.in_save_context():
      return [self._saving_variable]
    return self._variables

  @property
  def name(self):
    """The name of this object. Used for checkpointing."""
    return self._name

  @property
  def dtype(self):
    """The dtype of all `Variable`s in this object."""
    return self._dtype

  @property
  def shape(self):
    """The overall shape, combining all shards along axis `0`."""
    return self._shape

  def assign(self, value, use_locking=None, name=None, read_value=True):
    """Assigns to each shard its slice of `value`, per the shard offsets."""
    for i, v in enumerate(self._variables):
      v.assign(array_ops.slice(value, self._var_offsets[i], v.shape.as_list()))

  def assign_add(self, delta, use_locking=False, name=None, read_value=True):
    """Adds to each shard its slice of `delta`, per the shard offsets."""
    for i, v in enumerate(self._variables):
      v.assign_add(
          array_ops.slice(delta, self._var_offsets[i], v.shape.as_list()))

  def assign_sub(self, delta, use_locking=False, name=None, read_value=True):
    """Subtracts from each shard its slice of `delta`, per the shard offsets."""
    for i, v in enumerate(self._variables):
      v.assign_sub(
          array_ops.slice(delta, self._var_offsets[i], v.shape.as_list()))

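  # A minimal usage sketch of the assign methods (illustrative values; assumes
  # eager execution):
  #
  #   sv = ShardedVariable([tf.Variable([1., 2.]), tf.Variable([3.])])
  #   sv.assign(tf.constant([10., 20., 30.]))
  #   # Shard 0 now holds [10., 20.]; shard 1 holds [30.].
  #   sv.assign_add(tf.constant([1., 1., 1.]))
  #   # Shard 0 now holds [11., 21.]; shard 1 holds [31.].
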
  def _gather_saveables_for_checkpoint(self):
    """Return a `Saveable` for each shard. See `Trackable`."""

    def _saveable_factory(name=self.name):
      """Creates `SaveableObject`s for this `ShardedVariable`."""
      saveables = []
      dims = len(self._variables[0].shape)
      var_offset = [0 for _ in range(dims)]
      for v in self._variables:
        save_slice_info = variables_lib.Variable.SaveSliceInfo(
            full_name=self.name,
            full_shape=self.shape.as_list(),
            var_offset=copy.copy(var_offset),
            var_shape=v.shape.as_list())
        saveables.append(
            saveable_object_util.ResourceVariableSaveable(
                v, save_slice_info.spec, name))
        var_offset[0] += int(v.shape[0])
      return saveables

    return {trackable.VARIABLE_VALUE_KEY: _saveable_factory}

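  # Each shard is thus saved as a slice of one logical variable. For example
  # (illustrative), two float32 shards of shapes [2, 3] and [1, 3] under full
  # shape [3, 3] produce the slice specs "3 3 0,2:0,3" and "3 3 2,1:0,3"
  # (full shape, then "offset,length" per axis), which is what allows a
  # checkpoint to be restored into a different number of shards.
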
  def _map_resources(self, save_options):
    """For implementing `Trackable`."""
    obj_map, resource_map = {}, {}
    for v in self._variables + [self._saving_variable]:
      v_obj_map, v_resource_map = v._map_resources(save_options)  # pylint:disable=protected-access
      obj_map.update(v_obj_map)
      resource_map.update(v_resource_map)
    obj_map[self] = ShardedVariable([obj_map[self._saving_variable]],
                                    name=self.name)

    return obj_map, resource_map


class ShardedVariable(ShardedVariableMixin, composite_tensor.CompositeTensor):
  """A container for `Variable`s that should be treated as shards.

  Variables that are too large to fit on a single device (e.g., large
  embeddings) may need to be sharded over multiple devices. This class
  maintains a list of smaller variables that can be independently stored on
  separate devices (e.g., multiple parameter servers), and saves and restores
  those variables as if they were a single larger variable.

  Objects of this class can be saved with a given number of shards and then
  restored from a checkpoint into a different number of shards.

  Objects of this class can be saved to SavedModel format using
  `tf.saved_model.save`. The SavedModel can be used by programs like TF serving
  APIs. Loading the SavedModel with `tf.saved_model.load` is not yet supported.

  Since a `ShardedVariable` saved in one environment can be restored with a
  different number of shards in another (for example, TF serving APIs restore
  to one shard for serving efficiency), when using `ShardedVariable` in a
  tf.function one should generally not assume it has the same number of shards
  across save and load.

  Sharding is only supported along the first dimension.

  >>> class Model(tf.Module):
  ...   def __init__(self):
  ...     self.sharded_variable = ShardedVariable([
  ...       tf.Variable([3.0], dtype=tf.float32),
  ...       tf.Variable([2.0], dtype=tf.float32)
  ...     ])
  ...
  ...   @tf.function(input_signature=[tf.TensorSpec([], dtype=tf.int32)])
  ...   def fn(self, x):
  ...     return tf.nn.embedding_lookup(self.sharded_variable.variables, x)
  ...
  ...   @tf.function(input_signature=[tf.TensorSpec([], dtype=tf.int32)])
  ...   def serve_fn(self, x):
  ...     return tf.nn.embedding_lookup(self.sharded_variable.variables, x)
  >>>
  >>> model = Model()
  >>> model.fn(1).numpy()
  2.0
  >>> tf.saved_model.save(model, export_dir='/tmp/saved_model',
  ...   signatures=model.serve_fn)
  """

  @property
  def _type_spec(self):
    return ShardedVariableSpec(*(
        resource_variable_ops.VariableSpec(v.shape, v.dtype)
        for v in self._variables))


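# A checkpointing sketch (illustrative; the path and values are hypothetical):
# save two shards, then restore the same logical value into a single shard.
#
#   sv = ShardedVariable([tf.Variable([1., 2.]), tf.Variable([3., 4.])])
#   path = tf.train.Checkpoint(sv=sv).save('/tmp/sv_ckpt')
#
#   restored = ShardedVariable([tf.Variable([0., 0., 0., 0.])])
#   tf.train.Checkpoint(sv=restored).restore(path)
#   # restored.variables[0] now holds [1., 2., 3., 4.].
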
def _var_to_tensor(var, dtype=None, name=None, as_ref=False):
  """Converts a `ShardedVariable` to a `Tensor`."""
  del name
  if dtype is not None and not dtype.is_compatible_with(var.dtype):
    raise ValueError(
        'Incompatible type conversion requested to type {!r} for variable '
        'of type {!r}'.format(dtype.name, var.dtype.name))
  if as_ref:
    raise NotImplementedError(
        "ShardedVariable doesn't support being used as a reference.")
  # We use the op dispatch mechanism to override embedding_lookup ops when
  # called with ShardedVariable. This requires embedding_lookup ops to raise
  # TypeError when called with ShardedVariable. However, since ShardedVariable
  # can be converted to a tensor via concat, embedding_lookup ops would
  # silently do the conversion and never raise a TypeError. To be able to
  # properly raise a TypeError, the name scope is used to detect whether this
  # method is called within an embedding_lookup op.
  # NOTE: This doesn't work in eager mode since the op name scope is always
  # cleared in eager. This also breaks if the user sets the name of the
  # embedding_lookup op to something that doesn't contain the string
  # "embedding_lookup".
  #
  # TODO(chenkai): Find a more robust way to do this, which should not rely
  # on the name scope.
  if 'embedding_lookup' in ops.get_name_scope():
    raise TypeError('Converting ShardedVariable to tensor in embedding lookup'
                    ' ops is disallowed.')
  return array_ops.concat(var.variables, axis=0)


# Register a conversion function which reads the value of the variable,
# allowing instances of the class to be used as tensors.
ops.register_tensor_conversion_function(ShardedVariable, _var_to_tensor)

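# Once registered, any op that accepts a tensor implicitly concatenates the
# shards along axis 0 (a sketch; values are illustrative):
#
#   sv = ShardedVariable([tf.Variable([1., 2.]), tf.Variable([3.])])
#   tf.convert_to_tensor(sv)  # => tf.Tensor([1., 2., 3.])
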
# Override the behavior of embedding_lookup(sharded_variable, ...)
@dispatch.dispatch_for_types(embedding_ops.embedding_lookup, ShardedVariable)
def embedding_lookup(params,
                     ids,
                     partition_strategy='mod',
                     name=None,
                     validate_indices=True,
                     max_norm=None):
  # `embedding_lookup` accepts a list of tensors, so `params` may be a
  # single-element list containing the `ShardedVariable`; unwrap it.
  if isinstance(params, list):
    params = params[0]
  return embedding_ops.embedding_lookup(params.variables, ids,
                                        partition_strategy, name,
                                        validate_indices, max_norm)


def _raise_when_load(_):
  # We don't have serialization and deserialization mechanisms for
  # `ShardedVariable` in 2.x style save/load yet.
  raise ValueError('Loading `ShardedVariable` is not supported')


revived_types.register_revived_type(
    '_tf_distribute_sharded_variable',
    lambda obj: isinstance(obj, ShardedVariable),
    versions=[
        revived_types.VersionedTypeRegistration(
            object_factory=_raise_when_load,
            version=0,
            min_producer_version=0,
            min_consumer_version=0)
    ])