Support ShardedVariable in `tf.keras.layers.Embedding`.
A typical usage is user / dist strategy can define a variable_strategy_scope that creates ShardedVariable and build embedding layer under that scope. In this way `add_weights` returns ShardedVariable. Note that this CL also switches to use embedding_lookup_v2, which always use "div" partition_strategy whereas embedding_lookup defaults to"mod". I expect this to be a safe change as we don't explicitly support sharded embedding lookup yet. PiperOrigin-RevId: 312701263 Change-Id: Ic76ed454244ed4d77f7ee9ae9a07a8b663956458
This commit is contained in:
parent
31eeaec3b4
commit
808b545c38
|
@ -96,6 +96,10 @@ class ShardedVariable(trackable.Trackable):
|
|||
'to the order of the `Variable`s in the list passed to '
|
||||
'the constructor. Found {}'.format(save_slice_info))
|
||||
|
||||
def __iter__(self):
|
||||
"""Return an iterable for accessing the underlying sharded variables."""
|
||||
return iter(self._variables)
|
||||
|
||||
@property
|
||||
def variables(self):
|
||||
"""The list of `Variable`s that make up the shards of this object."""
|
||||
|
|
|
@ -118,6 +118,7 @@ py_library(
|
|||
"//tensorflow/python/distribute:distribute_lib",
|
||||
"//tensorflow/python/distribute:input_lib",
|
||||
"//tensorflow/python/distribute:reduce_util",
|
||||
"//tensorflow/python/distribute:sharded_variable",
|
||||
"//tensorflow/python/eager:monitoring",
|
||||
"//tensorflow/python/keras:backend",
|
||||
"//tensorflow/python/keras:constraints",
|
||||
|
|
|
@ -34,6 +34,7 @@ from tensorflow.python import tf2
|
|||
from tensorflow.python.autograph.core import ag_ctx
|
||||
from tensorflow.python.autograph.impl import api as autograph
|
||||
from tensorflow.python.distribute import distribution_strategy_context as ds_context
|
||||
from tensorflow.python.distribute import sharded_variable
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.eager import execute
|
||||
from tensorflow.python.eager import function
|
||||
|
@ -590,7 +591,9 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
|
|||
self._handle_weight_regularization(name_in_scope,
|
||||
variable,
|
||||
regularizer)
|
||||
if isinstance(variable, tf_variables.PartitionedVariable):
|
||||
if isinstance(
|
||||
variable,
|
||||
(tf_variables.PartitionedVariable, sharded_variable.ShardedVariable)):
|
||||
for v in variable:
|
||||
backend.track_variable(v)
|
||||
if trainable:
|
||||
|
|
|
@ -213,12 +213,13 @@ py_library(
|
|||
"//tensorflow/python:framework_ops",
|
||||
"//tensorflow/python:math_ops",
|
||||
"//tensorflow/python:util",
|
||||
"//tensorflow/python/distribute:sharded_variable",
|
||||
"//tensorflow/python/eager:context",
|
||||
"//tensorflow/python/keras:backend",
|
||||
"//tensorflow/python/keras:base_layer",
|
||||
"//tensorflow/python/keras:constraints",
|
||||
"//tensorflow/python/keras:initializers",
|
||||
"//tensorflow/python/keras:regularizers",
|
||||
"//tensorflow/python/keras/engine:base_layer",
|
||||
"//tensorflow/python/keras/utils:tf_utils",
|
||||
],
|
||||
)
|
||||
|
@ -593,9 +594,15 @@ cuda_py_test(
|
|||
python_version = "PY3",
|
||||
deps = [
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:dtypes",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
"//tensorflow/python:training_lib",
|
||||
"//tensorflow/python:variables",
|
||||
"//tensorflow/python/eager:backprop",
|
||||
"//tensorflow/python/keras",
|
||||
"//tensorflow/python/keras:combinations",
|
||||
"@absl_py//absl/testing:parameterized",
|
||||
"//tensorflow/python/keras:testing_utils",
|
||||
"//tensorflow/python/ops/ragged:ragged_factory_ops",
|
||||
],
|
||||
)
|
||||
|
||||
|
|
|
@ -18,6 +18,7 @@ from __future__ import absolute_import
|
|||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python.distribute import sharded_variable
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.keras import backend as K
|
||||
|
@ -183,7 +184,10 @@ class Embedding(Layer):
|
|||
dtype = K.dtype(inputs)
|
||||
if dtype != 'int32' and dtype != 'int64':
|
||||
inputs = math_ops.cast(inputs, 'int32')
|
||||
out = embedding_ops.embedding_lookup(self.embeddings, inputs)
|
||||
if isinstance(self.embeddings, sharded_variable.ShardedVariable):
|
||||
out = embedding_ops.embedding_lookup_v2(self.embeddings.variables, inputs)
|
||||
else:
|
||||
out = embedding_ops.embedding_lookup_v2(self.embeddings, inputs)
|
||||
return out
|
||||
|
||||
def get_config(self):
|
||||
|
|
|
@ -21,12 +21,14 @@ from __future__ import print_function
|
|||
import numpy as np
|
||||
|
||||
from tensorflow.python import keras
|
||||
from tensorflow.python.distribute import sharded_variable
|
||||
from tensorflow.python.eager import backprop
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import test_util as tf_test_util
|
||||
from tensorflow.python.keras import combinations
|
||||
from tensorflow.python.keras import keras_parameterized
|
||||
from tensorflow.python.keras import testing_utils
|
||||
from tensorflow.python.ops import variables
|
||||
from tensorflow.python.ops.ragged import ragged_factory_ops
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.training import adagrad
|
||||
|
@ -130,6 +132,20 @@ class EmbeddingTest(keras_parameterized.TestCase):
|
|||
[[[1., 1.], [2., 2.], [2., 2.]], [[0., 0.]], [[1., 1.], [2., 2.]]],
|
||||
ragged_rank=1))
|
||||
|
||||
@keras_parameterized.run_all_keras_modes(always_skip_v1=True)
|
||||
def test_embedding_with_sharded_variable(self):
|
||||
layer = keras.layers.Embedding(input_dim=5, output_dim=2)
|
||||
v = [
|
||||
variables.Variable([[1., 2.], [3., 4.]]),
|
||||
variables.Variable([[5., 6.], [7., 8.]]),
|
||||
variables.Variable([[9., 10.]])
|
||||
]
|
||||
model = keras.models.Sequential([layer])
|
||||
layer.embeddings = sharded_variable.ShardedVariable(v)
|
||||
model.run_eagerly = testing_utils.should_run_eagerly()
|
||||
outputs = model.predict(np.array([[0, 2, 4]], dtype='int32'))
|
||||
self.assertAllClose(outputs, [[[1., 2.], [5., 6.], [9., 10.]]])
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test.main()
|
||||
|
|
Loading…
Reference in New Issue