STT-tensorflow/tensorflow/python/distribute/sharded_variable.py
Chenkai Kuang 12c67c0d47 Raise meaningful error message when loading a ShardedVariable.
PiperOrigin-RevId: 348539354
Change-Id: I2c4a8466c3d1355ec8e5984ed039194c18c4305c
2020-12-21 15:59:59 -08:00

522 lines
20 KiB
Python

# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""ShardedVariable class."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import copy
from tensorflow.python.framework import composite_tensor
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import type_spec
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import embedding_ops
from tensorflow.python.ops import partitioned_variables
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variables as variables_lib
from tensorflow.python.saved_model import revived_types
from tensorflow.python.saved_model import save_context
from tensorflow.python.training.saving import saveable_object_util
from tensorflow.python.training.tracking import base as trackable
from tensorflow.python.util import dispatch
from tensorflow.python.util.tf_export import tf_export
@tf_export('distribute.experimental.partitioners.Partitioner', v1=[])
class Partitioner(object):
  """Partitioner base class: all partitioners inherit from this class.

  Partitioners should implement a `__call__` method with the following
  signature:

  ```python
  def __call__(self, shape, dtype, axis=0):
    # Partitions the given `shape` and returns the partition results.
    # See docstring of `__call__` method for the format of partition results.
  ```
  """

  def __call__(self, shape, dtype, axis=0):
    """Partitions the given `shape` and returns the partition results.

    Examples of a partitioner that allocates a fixed number of shards:

    ```python
    partitioner = FixedShardsPartitioner(num_shards=2)
    partitions = partitioner(tf.TensorShape([10, 3]), tf.float32, axis=0)
    print(partitions) # [2, 1]
    ```

    Args:
      shape: a `tf.TensorShape`, the shape to partition.
      dtype: a `tf.dtypes.Dtype` indicating the type of the partition value.
      axis: The axis to partition along. Default: outermost axis.

    Returns:
      A list of integers representing the number of partitions on each axis,
      where i-th value corresponds to i-th axis.

    Raises:
      NotImplementedError: always; subclasses must override this method.
    """
    raise NotImplementedError
@tf_export('distribute.experimental.partitioners.FixedShardsPartitioner', v1=[])
class FixedShardsPartitioner(Partitioner):
  """Partitioner that allocates a fixed number of shards.

  Examples:

  >>> # standalone usage:
  >>> partitioner = FixedShardsPartitioner(num_shards=2)
  >>> partitions = partitioner(tf.TensorShape([10, 3]), tf.float32)
  >>> partitions
  [2, 1]
  >>>
  >>> # use in ParameterServerStrategy
  >>> # strategy = tf.distribute.experimental.ParameterServerStrategy(
  >>> #   cluster_resolver=cluster_resolver, variable_partitioner=partitioner)
  """

  def __init__(self, num_shards):
    """Creates a new `FixedShardsPartitioner`.

    Args:
      num_shards: `int`, number of shards to partition.

    Raises:
      ValueError: if `num_shards` is not positive.
    """
    # Validate eagerly, like the other partitioners in this module, instead of
    # producing a zero/negative partition count at call time.
    if num_shards < 1:
      raise ValueError('num_shards must be positive, got: %r' % num_shards)
    self._num_shards = num_shards

  def __call__(self, shape, dtype, axis=0):
    del dtype  # The shard count does not depend on the element type.
    result = [1] * len(shape)
    # Never request more shards than there are elements along `axis`.
    result[axis] = min(self._num_shards, shape.dims[axis].value)
    return result
@tf_export('distribute.experimental.partitioners.MinSizePartitioner', v1=[])
class MinSizePartitioner(Partitioner):
  """Partitioner that allocates a minimum size per shard.

  This partitioner ensures each shard has at least `min_shard_bytes`, and tries
  to allocate as many shards as possible, i.e., keeping shard size as small as
  possible. The maximum number of such shards (upper bound) is given by
  `max_shards`.

  Examples:

  >>> partitioner = MinSizePartitioner(min_shard_bytes=4, max_shards=2)
  >>> partitions = partitioner(tf.TensorShape([6, 1]), tf.float32)
  >>> partitions
  [2, 1]
  >>> partitioner = MinSizePartitioner(min_shard_bytes=4, max_shards=10)
  >>> partitions = partitioner(tf.TensorShape([6, 1]), tf.float32)
  >>> partitions
  [6, 1]
  >>>
  >>> # use in ParameterServerStrategy
  >>> # strategy = tf.distribute.experimental.ParameterServerStrategy(
  >>> #   cluster_resolver=cluster_resolver, variable_partitioner=partitioner)
  """

  def __init__(self,
               min_shard_bytes=256 << 10,
               max_shards=1,
               bytes_per_string=16):
    """Creates a new `MinSizePartitioner`.

    Args:
      min_shard_bytes: Minimum bytes of each shard. Defaults to 256K.
      max_shards: Upper bound on the number of shards. Defaults to 1.
      bytes_per_string: If the partition value is of type string, this provides
        an estimate of how large each string is.

    Raises:
      ValueError: if any argument is not positive.
    """
    if min_shard_bytes < 1:
      raise ValueError('min_shard_bytes must be positive, got: %r' %
                       min_shard_bytes)
    if max_shards < 1:
      raise ValueError('max_shards must be positive, got: %r' % max_shards)
    if bytes_per_string < 1:
      raise ValueError('bytes_per_string must be positive, got: %r' %
                       bytes_per_string)
    self._min_shard_bytes = min_shard_bytes
    self._max_shards = max_shards
    self._bytes_per_string = bytes_per_string

  def __call__(self, shape, dtype, axis=0):
    # Delegate to the v1 partitioner factory and apply it immediately.
    return partitioned_variables.min_max_variable_partitioner(
        max_partitions=self._max_shards,
        axis=axis,
        min_slice_size=self._min_shard_bytes,
        bytes_per_string_element=self._bytes_per_string)(shape, dtype)
@tf_export('distribute.experimental.partitioners.MaxSizePartitioner', v1=[])
class MaxSizePartitioner(Partitioner):
  """Partitioner that keeps shards below `max_shard_bytes`.

  This partitioner ensures each shard has at most `max_shard_bytes`, and tries
  to allocate as few shards as possible, i.e., keeping shard size as large
  as possible.

  If the partitioner hits the `max_shards` limit, then each shard may end up
  larger than `max_shard_bytes`. By default `max_shards` equals `None` and no
  limit on the number of shards is enforced.

  Examples:

  >>> partitioner = MaxSizePartitioner(max_shard_bytes=4)
  >>> partitions = partitioner(tf.TensorShape([6, 1]), tf.float32)
  >>> partitions
  [6, 1]
  >>> partitioner = MaxSizePartitioner(max_shard_bytes=4, max_shards=2)
  >>> partitions = partitioner(tf.TensorShape([6, 1]), tf.float32)
  >>> partitions
  [2, 1]
  >>> partitioner = MaxSizePartitioner(max_shard_bytes=1024)
  >>> partitions = partitioner(tf.TensorShape([6, 1]), tf.float32)
  >>> partitions
  [1, 1]
  >>>
  >>> # use in ParameterServerStrategy
  >>> # strategy = tf.distribute.experimental.ParameterServerStrategy(
  >>> #   cluster_resolver=cluster_resolver, variable_partitioner=partitioner)
  """

  def __init__(self, max_shard_bytes, max_shards=None, bytes_per_string=16):
    """Creates a new `MaxSizePartitioner`.

    Args:
      max_shard_bytes: The maximum size any given shard is allowed to be.
      max_shards: The maximum number of shards in `int` created taking
        precedence over `max_shard_bytes`. `None` means no limit.
      bytes_per_string: If the partition value is of type string, this provides
        an estimate of how large each string is.

    Raises:
      ValueError: if `max_shard_bytes` or `bytes_per_string` is not positive,
        or `max_shards` is given and not positive.
    """
    if max_shard_bytes < 1:
      raise ValueError('max_shard_bytes must be positive, got: %r' %
                       max_shard_bytes)
    # Compare against None explicitly: `max_shards=0` is falsy and would
    # otherwise slip past this check.
    if max_shards is not None and max_shards < 1:
      raise ValueError('max_shards must be positive, got: %r' % max_shards)
    if bytes_per_string < 1:
      raise ValueError('bytes_per_string must be positive, got: %r' %
                       bytes_per_string)
    self._max_shard_bytes = max_shard_bytes
    self._max_shards = max_shards
    self._bytes_per_string = bytes_per_string

  def __call__(self, shape, dtype, axis=0):
    # Delegate to the v1 partitioner factory and apply it immediately.
    return partitioned_variables.variable_axis_size_partitioner(
        max_shard_bytes=self._max_shard_bytes,
        max_shards=self._max_shards,
        bytes_per_string_element=self._bytes_per_string,
        axis=axis)(shape, dtype)
class ShardedVariableSpec(type_spec.TypeSpec):
  """Type specification for a `ShardedVariable`: one spec per shard."""

  __slots__ = ['_variable_specs']

  @property
  def value_type(self):
    return ShardedVariable

  def __init__(self, *variable_specs):
    self._variable_specs = tuple(variable_specs)

  @property
  def _component_specs(self):
    # Every shard is its own component of the composite value.
    return self._variable_specs

  def _serialize(self):
    return self._variable_specs

  def _to_components(self, value):
    return value.variables

  def _from_components(self, variables):
    return ShardedVariable(variables)
class ShardedVariableMixin(trackable.Trackable):
  """Mixin for ShardedVariable."""

  # TODO(b/170877138): Remove this mixin once fixed. This mixin is required
  # since TPUShardedVariable can't be a CompositeTensor.

  def __init__(self, variables, name='ShardedVariable'):
    """Treats `variables` as shards of a larger Variable.

    Example:

    ```
    variables = [
      tf.Variable(..., shape=(10, 100), dtype=tf.float32),
      tf.Variable(..., shape=(15, 100), dtype=tf.float32),
      tf.Variable(..., shape=(5, 100), dtype=tf.float32)
    ]
    sharded_variable = ShardedVariableMixin(variables)
    assert sharded_variable.shape.as_list() == [30, 100]
    ```

    Args:
      variables: A list of `ResourceVariable`s that comprise this sharded
        variable. Variables should not be shared between different
        `ShardedVariableMixin` objects.
      name: String. Name of this container. Defaults to "ShardedVariable".

    Raises:
      ValueError: if `variables` is empty or not a sequence of `Variable`s, if
        the shards have different dtypes, if any shard is a scalar, if shard
        shapes differ on any axis other than the first, or if any shard already
        has a `SaveSliceInfo`.
    """
    super(ShardedVariableMixin, self).__init__()
    # A bare Variable is not a list of shards; reject it before `list()` would
    # try to iterate over its elements.
    if isinstance(variables, variables_lib.Variable):
      raise ValueError(
          'Expected a list of `Variable`s, found: {}'.format(variables))
    # Store a private list: tuples (e.g. from `_from_components`) must work
    # with the list concatenation in `_map_resources`, and later mutation of
    # the caller's list must not affect this object.
    variables = list(variables)
    if not variables:
      raise ValueError('Expected a non-empty list of `Variable`s, found: []')
    self._variables = variables
    self._name = name

    if any(not isinstance(v, variables_lib.Variable) for v in variables):
      raise ValueError(
          'Expected a list of `Variable`s, found: {}'.format(variables))
    first_var = variables[0]

    dtypes = {v.dtype for v in variables}
    if len(dtypes) > 1:
      raise ValueError(
          'All `Variable`s must have the same dtype, found: {}'.format(
              [v.dtype for v in variables]))
    self._dtype = first_var.dtype

    # Sharding is along the first axis, so every shard needs at least one.
    if any(v.shape.rank == 0 for v in variables):
      raise ValueError(
          'Scalar `Variable`s cannot be sharded, found shapes: {}'.format(
              [v.shape for v in variables]))

    # All variables must have the same shape for axes > 0.
    higher_dim_shapes = {tuple(v.shape.as_list()[1:]) for v in variables}
    if len(higher_dim_shapes) > 1:
      raise ValueError(
          'All `Variables`s must have the same shapes except for the first '
          'axis, found {}'.format([v.shape for v in variables]))
    first_dim = sum(int(v.shape[0]) for v in variables)
    self._shape = tensor_shape.TensorShape([first_dim] + first_var.shape[1:])

    # Offset of each shard inside the full variable. Sharding is always along
    # the first axis, so the offsets on every other axis are 0.
    rank = len(first_var.shape)
    self._var_offsets = []
    running_offset = 0
    for v in variables:
      self._var_offsets.append([running_offset] + [0] * (rank - 1))
      running_offset += int(v.shape[0])

    save_slice_info = [v._get_save_slice_info() for v in variables]  # pylint: disable=protected-access
    if any(slice_info is not None for slice_info in save_slice_info):
      raise ValueError('`SaveSliceInfo` should not be set for `Variable`s. '
                       '`ShardedVariable` will infer `SaveSliceInfo` according '
                       'to the order of the `Variable`s in the list passed to '
                       'the constructor. Found {}'.format(save_slice_info))

    # We create an uninitialized saving_variable with the full shape, which can
    # be later captured in signatures so that the signatures can treat this
    # ShardedVariable as one single variable.
    self._saving_variable = resource_variable_ops.UninitializedVariable(
        shape=self._shape, dtype=self._dtype, name=self._name)

  def __iter__(self):
    """Return an iterable for accessing the underlying sharded variables."""
    return iter(self._variables)

  @property
  def _type_spec(self):
    return ShardedVariableSpec(*(
        resource_variable_ops.VariableSpec(v.shape, v.dtype)
        for v in self._variables))

  @property
  def variables(self):
    """The list of `Variable`s that make up the shards of this object.

    While saving (inside a save context), this is a single full-shape
    placeholder variable instead of the real shards.
    """
    if save_context.in_save_context():
      return [self._saving_variable]
    return self._variables

  @property
  def name(self):
    """The name of this object. Used for checkpointing."""
    return self._name

  @property
  def dtype(self):
    """The dtype of all `Variable`s in this object."""
    return self._dtype

  @property
  def shape(self):
    """The overall shape, combining all shards along axis `0`."""
    return self._shape

  def assign(self, value, use_locking=None, name=None, read_value=True):
    """Assigns `value` (full shape) by slicing it into each shard.

    Args:
      value: A tensor-like value with the full shape of this variable.
      use_locking: Forwarded to each shard's `assign`.
      name: Unused; kept for `tf.Variable` API compatibility.
      read_value: Unused; kept for `tf.Variable` API compatibility.

    Returns:
      This object, so calls can be chained like on a `tf.Variable`.
    """
    for v, offset in zip(self._variables, self._var_offsets):
      v.assign(
          array_ops.slice(value, offset, v.shape.as_list()),
          use_locking=use_locking)
    return self

  def assign_add(self, delta, use_locking=False, name=None, read_value=True):
    """Adds `delta` (full shape) by slicing it into each shard.

    Args:
      delta: A tensor-like value with the full shape of this variable.
      use_locking: Forwarded to each shard's `assign_add`.
      name: Unused; kept for `tf.Variable` API compatibility.
      read_value: Unused; kept for `tf.Variable` API compatibility.

    Returns:
      This object, so calls can be chained like on a `tf.Variable`.
    """
    for v, offset in zip(self._variables, self._var_offsets):
      v.assign_add(
          array_ops.slice(delta, offset, v.shape.as_list()),
          use_locking=use_locking)
    return self

  def assign_sub(self, delta, use_locking=False, name=None, read_value=True):
    """Subtracts `delta` (full shape) by slicing it into each shard.

    Args:
      delta: A tensor-like value with the full shape of this variable.
      use_locking: Forwarded to each shard's `assign_sub`.
      name: Unused; kept for `tf.Variable` API compatibility.
      read_value: Unused; kept for `tf.Variable` API compatibility.

    Returns:
      This object, so calls can be chained like on a `tf.Variable`.
    """
    for v, offset in zip(self._variables, self._var_offsets):
      v.assign_sub(
          array_ops.slice(delta, offset, v.shape.as_list()),
          use_locking=use_locking)
    return self

  def _gather_saveables_for_checkpoint(self):
    """Return a `Saveable` for each shard. See `Trackable`."""

    def _saveable_factory(name=self.name):
      """Creates `SaveableObject`s for this `ShardedVariable`."""
      saveables = []
      for v, offset in zip(self._variables, self._var_offsets):
        # Each shard is saved as a slice of one full-shape checkpoint entry.
        save_slice_info = variables_lib.Variable.SaveSliceInfo(
            full_name=self.name,
            full_shape=self.shape.as_list(),
            var_offset=copy.copy(offset),
            var_shape=v.shape.as_list())
        saveables.append(
            saveable_object_util.ResourceVariableSaveable(
                v, save_slice_info.spec, name))
      return saveables

    return {trackable.VARIABLE_VALUE_KEY: _saveable_factory}

  def _map_resources(self, save_options):
    """For implementing `Trackable`."""
    obj_map, resource_map = {}, {}
    for v in self._variables + [self._saving_variable]:
      v_obj_map, v_resource_map = v._map_resources(save_options)  # pylint:disable=protected-access
      obj_map.update(v_obj_map)
      resource_map.update(v_resource_map)
    # In the saved object graph this container maps to a single-shard
    # ShardedVariable wrapping the mapped full-shape saving variable.
    obj_map[self] = ShardedVariable([obj_map[self._saving_variable]],
                                    name=self.name)
    return obj_map, resource_map
class ShardedVariable(ShardedVariableMixin, composite_tensor.CompositeTensor):
  """A container for `Variables` that should be treated as shards.

  Variables that are too large to fit on a single device (e.g., large
  embeddings) may need to be sharded over multiple devices. This class
  maintains a list of smaller variables that can be independently stored on
  separate devices (eg, multiple parameter servers), and saves and restores
  those variables as if they were a single larger variable.

  Objects of this class can be saved with a given number of shards and then
  restored from a checkpoint into a different number of shards.

  Objects of this class can be saved to SavedModel format using
  `tf.saved_model.save`. The SavedModel can be used by programs like TF serving
  APIs. It is not yet supported to load the SavedModel with
  `tf.saved_model.load`.

  Since `ShardedVariable` can be saved and then restored to different number of
  shards depending on the restore environments, for example, TF serving APIs
  would restore to one shard for serving efficiency, when using
  `ShardedVariable` in a tf.function, one should generally not assume it has the
  same number of shards across save and load.

  Sharding is only supported along the first dimension.

  >>> class Model(tf.Module):
  ...   def __init__(self):
  ...     self.sharded_variable = ShardedVariable([
  ...       tf.Variable([3.0], dtype=tf.float32),
  ...       tf.Variable([2.0], dtype=tf.float32)
  ...     ])
  ...
  ...   @tf.function(input_signature=[tf.TensorSpec([], dtype=tf.int32)])
  ...   def fn(self, x):
  ...     return tf.nn.embedding_lookup(self.sharded_variable.variables, x)
  ...
  ...   @tf.function(input_signature=[tf.TensorSpec([], dtype=tf.int32)])
  ...   def serve_fn(self, x):
  ...     return tf.nn.embedding_lookup(self.sharded_variable.variables, x)
  >>>
  >>> model = Model()
  >>> model.fn(1).numpy()
  2.0
  >>> tf.saved_model.save(model, export_dir='/tmp/saved_model',
  ...   signatures=model.serve_fn)
  """

  @property
  def _type_spec(self):
    """A `ShardedVariableSpec` holding one `VariableSpec` per shard."""
    return ShardedVariableSpec(*(
        resource_variable_ops.VariableSpec(v.shape, v.dtype)
        for v in self._variables))
def _var_to_tensor(var, dtype=None, name=None, as_ref=False):
  """Converts a `ShardedVariable` to a `Tensor` by concatenating its shards.

  Args:
    var: The `ShardedVariable` to convert.
    dtype: Optional requested dtype; must be compatible with `var.dtype`.
    name: Unused.
    as_ref: Must be False; reference conversion is not supported.

  Returns:
    A `Tensor` equal to `var.variables` concatenated along axis 0.

  Raises:
    ValueError: if `dtype` is incompatible with `var.dtype`.
    NotImplementedError: if `as_ref` is True.
    TypeError: if called while the current name scope contains
      "embedding_lookup".
  """
  del name
  if dtype is not None and not dtype.is_compatible_with(var.dtype):
    raise ValueError(
        'Incompatible type conversion requested to type {!r} for variable '
        'of type {!r}'.format(dtype.name, var.dtype.name))
  if as_ref:
    raise NotImplementedError(
        "ShardedVariable doesn't support being used as a reference.")
  # We use the op dispatch mechanism to override embedding_lookup ops when
  # called with ShardedVariable. This requires embedding_lookup ops to raise
  # TypeError when called with ShardedVariable. However, since ShardedVariable
  # can be converted to a tensor via concat, embedding_lookup ops would
  # silently do the conversion and never raise a TypeError. To be able to
  # properly raise a TypeError, the name scope is used to detect whether this
  # function is called within an embedding_lookup op.
  # NOTE: This doesn't work in eager mode since the op name scope is always
  # cleared in eager. This also breaks if the user names the embedding_lookup
  # op with something that doesn't contain the str "embedding_lookup".
  #
  # TODO(chenkai): Find a more robust way to do this, which should not rely
  # on namescope.
  if 'embedding_lookup' in ops.get_name_scope():
    raise TypeError('Converting ShardedVariable to tensor in embedding lookup'
                    ' ops is disallowed.')
  return array_ops.concat(var.variables, axis=0)
# Register a conversion function that concatenates the shards along axis 0,
# allowing instances of the class to be used as tensors.
ops.register_tensor_conversion_function(ShardedVariable, _var_to_tensor)
# Override the behavior of embedding_lookup(sharded_variable, ...)
@dispatch.dispatch_for_types(embedding_ops.embedding_lookup, ShardedVariable)
def embedding_lookup(params,
                     ids,
                     partition_strategy='mod',
                     name=None,
                     validate_indices=True,
                     max_norm=None):
  """Runs `embedding_lookup` over the shards of a `ShardedVariable`.

  Args:
    params: A `ShardedVariable`, or a list/tuple containing exactly one
      `ShardedVariable`.
    ids: The ids to look up, forwarded unchanged.
    partition_strategy: Forwarded to `embedding_ops.embedding_lookup`.
    name: Forwarded to `embedding_ops.embedding_lookup`.
    validate_indices: Forwarded to `embedding_ops.embedding_lookup`.
    max_norm: Forwarded to `embedding_ops.embedding_lookup`.

  Returns:
    The result of `embedding_ops.embedding_lookup` applied to
    `params.variables`.

  Raises:
    ValueError: if `params` is a list/tuple that does not consist of exactly
      one `ShardedVariable`.
  """
  # The type dispatcher also fires for lists/tuples that contain a
  # ShardedVariable. Only a lone ShardedVariable can be handled here; silently
  # dropping the remaining elements would produce wrong results.
  if isinstance(params, (list, tuple)):
    if len(params) != 1 or not isinstance(params[0], ShardedVariable):
      raise ValueError(
          'When `params` is a list or tuple containing a `ShardedVariable`, '
          'it must contain exactly one `ShardedVariable`, found: {}'.format(
              params))
    params = params[0]
  return embedding_ops.embedding_lookup(params.variables, ids,
                                        partition_strategy, name,
                                        validate_indices, max_norm)
def _raise_when_load(_):
# We don't have serialization and deserialization mechanisms for
# `ShardedVariable` in 2.x style save/load yet.
raise ValueError('Loading `ShardedVariable` is not supported')
# Register `ShardedVariable` under this identifier so that saving records it as
# a revivable type; the object factory makes any attempt to revive (load) it
# raise a ValueError.
revived_types.register_revived_type(
    '_tf_distribute_sharded_variable',
    lambda obj: isinstance(obj, ShardedVariable),
    versions=[
        revived_types.VersionedTypeRegistration(
            object_factory=_raise_when_load,
            version=0,
            min_producer_version=0,
            min_consumer_version=0)
    ])