Support ShardedVariable in `tf.keras.layers.Embedding`.

A typical usage: a user or a distribution strategy defines a variable-creator scope that creates `ShardedVariable`s, and builds the embedding layer under that scope. In this way `add_weight` returns a `ShardedVariable`.
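As a minimal sketch of that flow (the `sharding_creator` helper and `NUM_SHARDS` constant are hypothetical, for illustration only, not part of this CL):

    import numpy as np
    import tensorflow as tf

    from tensorflow.python.distribute import sharded_variable

    NUM_SHARDS = 2  # hypothetical shard count, for illustration only

    def sharding_creator(next_creator, **kwargs):
      """Splits each requested variable row-wise into NUM_SHARDS shards."""
      initial_value = kwargs.pop('initial_value')
      value = initial_value() if callable(initial_value) else initial_value
      kwargs.pop('shape', None)  # each shard infers its own, smaller shape
      row_ids = np.array_split(np.arange(value.shape[0]), NUM_SHARDS)
      return sharded_variable.ShardedVariable(
          [next_creator(initial_value=tf.gather(value, ids), **kwargs)
           for ids in row_ids])

    with tf.variable_creator_scope(sharding_creator):
      layer = tf.keras.layers.Embedding(input_dim=5, output_dim=2)
      layer.build((None,))  # `add_weight` runs under the creator scope

    # The embedding table is now a ShardedVariable, not a single tf.Variable.
    assert isinstance(layer.embeddings, sharded_variable.ShardedVariable)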

Note that this CL also switches to `embedding_lookup_v2`, which always uses the "div" `partition_strategy`, whereas `embedding_lookup` defaults to "mod". I expect this to be a safe change, as we don't explicitly support sharded embedding lookup yet.
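Concretely, the two strategies assign embedding rows to shards as follows (a plain-Python illustration, not code from this CL):

    def shard_index(idx, num_ids, num_shards, strategy):
      """Which shard holds embedding row `idx` under each partition_strategy."""
      if strategy == 'mod':
        return idx % num_shards
      # 'div': contiguous id ranges; the first (num_ids % num_shards) shards
      # hold ceil(num_ids / num_shards) ids each, the remainder hold floor().
      ids_per_shard, extras = divmod(num_ids, num_shards)
      threshold = extras * (ids_per_shard + 1)
      if idx < threshold:
        return idx // (ids_per_shard + 1)
      return extras + (idx - threshold) // ids_per_shard

    # 5 ids across 2 shards:
    print([shard_index(i, 5, 2, 'div') for i in range(5)])  # [0, 0, 0, 1, 1]
    print([shard_index(i, 5, 2, 'mod') for i in range(5)])  # [0, 1, 0, 1, 0]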

PiperOrigin-RevId: 312701263
Change-Id: Ic76ed454244ed4d77f7ee9ae9a07a8b663956458
Chenkai Kuang 2020-05-21 10:58:55 -07:00 committed by TensorFlower Gardener
parent 31eeaec3b4
commit 808b545c38
6 changed files with 39 additions and 4 deletions

tensorflow/python/distribute/sharded_variable.py

@@ -96,6 +96,10 @@ class ShardedVariable(trackable.Trackable):
             'to the order of the `Variable`s in the list passed to '
             'the constructor. Found {}'.format(save_slice_info))
 
+  def __iter__(self):
+    """Return an iterable for accessing the underlying sharded variables."""
+    return iter(self._variables)
+
   @property
   def variables(self):
     """The list of `Variable`s that make up the shards of this object."""

tensorflow/python/keras/engine/BUILD

@@ -118,6 +118,7 @@ py_library(
         "//tensorflow/python/distribute:distribute_lib",
         "//tensorflow/python/distribute:input_lib",
         "//tensorflow/python/distribute:reduce_util",
+        "//tensorflow/python/distribute:sharded_variable",
         "//tensorflow/python/eager:monitoring",
         "//tensorflow/python/keras:backend",
         "//tensorflow/python/keras:constraints",

tensorflow/python/keras/engine/base_layer.py

@@ -34,6 +34,7 @@ from tensorflow.python import tf2
 from tensorflow.python.autograph.core import ag_ctx
 from tensorflow.python.autograph.impl import api as autograph
 from tensorflow.python.distribute import distribution_strategy_context as ds_context
+from tensorflow.python.distribute import sharded_variable
 from tensorflow.python.eager import context
 from tensorflow.python.eager import execute
 from tensorflow.python.eager import function
@@ -590,7 +591,9 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
     self._handle_weight_regularization(name_in_scope,
                                        variable,
                                        regularizer)
-    if isinstance(variable, tf_variables.PartitionedVariable):
+    if isinstance(
+        variable,
+        (tf_variables.PartitionedVariable, sharded_variable.ShardedVariable)):
       for v in variable:
         backend.track_variable(v)
     if trainable:

tensorflow/python/keras/layers/BUILD

@@ -213,12 +213,13 @@ py_library(
         "//tensorflow/python:framework_ops",
         "//tensorflow/python:math_ops",
         "//tensorflow/python:util",
+        "//tensorflow/python/distribute:sharded_variable",
         "//tensorflow/python/eager:context",
         "//tensorflow/python/keras:backend",
-        "//tensorflow/python/keras:base_layer",
         "//tensorflow/python/keras:constraints",
         "//tensorflow/python/keras:initializers",
         "//tensorflow/python/keras:regularizers",
+        "//tensorflow/python/keras/engine:base_layer",
         "//tensorflow/python/keras/utils:tf_utils",
     ],
 )
@@ -593,9 +594,15 @@ cuda_py_test(
     python_version = "PY3",
     deps = [
         "//tensorflow/python:client_testlib",
+        "//tensorflow/python:dtypes",
         "//tensorflow/python:framework_test_lib",
         "//tensorflow/python:training_lib",
+        "//tensorflow/python:variables",
+        "//tensorflow/python/eager:backprop",
         "//tensorflow/python/keras",
+        "//tensorflow/python/keras:combinations",
         "@absl_py//absl/testing:parameterized",
+        "//tensorflow/python/keras:testing_utils",
+        "//tensorflow/python/ops/ragged:ragged_factory_ops",
     ],
 )

tensorflow/python/keras/layers/embeddings.py

@@ -18,6 +18,7 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+from tensorflow.python.distribute import sharded_variable
 from tensorflow.python.eager import context
 from tensorflow.python.framework import ops
 from tensorflow.python.keras import backend as K
@@ -183,7 +184,10 @@ class Embedding(Layer):
     dtype = K.dtype(inputs)
     if dtype != 'int32' and dtype != 'int64':
       inputs = math_ops.cast(inputs, 'int32')
-    out = embedding_ops.embedding_lookup(self.embeddings, inputs)
+    if isinstance(self.embeddings, sharded_variable.ShardedVariable):
+      out = embedding_ops.embedding_lookup_v2(self.embeddings.variables, inputs)
+    else:
+      out = embedding_ops.embedding_lookup_v2(self.embeddings, inputs)
     return out
 
   def get_config(self):
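When `self.embeddings` is a `ShardedVariable`, the branch above passes the shard list to `embedding_lookup_v2` (exported as `tf.nn.embedding_lookup` in TF2), which gathers across shards with the "div" strategy. A minimal eager-mode sketch of what `call()` computes:

    import tensorflow as tf

    shards = [tf.Variable([[1., 2.], [3., 4.]]),
              tf.Variable([[5., 6.], [7., 8.]]),
              tf.Variable([[9., 10.]])]
    ids = tf.constant([0, 2, 4])
    # "div" strategy: rows 0-1 live in shards[0], rows 2-3 in shards[1],
    # and row 4 in shards[2].
    out = tf.nn.embedding_lookup(shards, ids)
    # out == [[1., 2.], [5., 6.], [9., 10.]], matching the new test below.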

tensorflow/python/keras/layers/embeddings_test.py

@@ -21,12 +21,14 @@ from __future__ import print_function
 import numpy as np
 
 from tensorflow.python import keras
+from tensorflow.python.distribute import sharded_variable
 from tensorflow.python.eager import backprop
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import test_util as tf_test_util
 from tensorflow.python.keras import combinations
 from tensorflow.python.keras import keras_parameterized
 from tensorflow.python.keras import testing_utils
+from tensorflow.python.ops import variables
 from tensorflow.python.ops.ragged import ragged_factory_ops
 from tensorflow.python.platform import test
 from tensorflow.python.training import adagrad
@@ -130,6 +132,20 @@ class EmbeddingTest(keras_parameterized.TestCase):
         [[[1., 1.], [2., 2.], [2., 2.]], [[0., 0.]], [[1., 1.], [2., 2.]]],
         ragged_rank=1))
 
+  @keras_parameterized.run_all_keras_modes(always_skip_v1=True)
+  def test_embedding_with_sharded_variable(self):
+    layer = keras.layers.Embedding(input_dim=5, output_dim=2)
+    v = [
+        variables.Variable([[1., 2.], [3., 4.]]),
+        variables.Variable([[5., 6.], [7., 8.]]),
+        variables.Variable([[9., 10.]])
+    ]
+    model = keras.models.Sequential([layer])
+    layer.embeddings = sharded_variable.ShardedVariable(v)
+    model.run_eagerly = testing_utils.should_run_eagerly()
+    outputs = model.predict(np.array([[0, 2, 4]], dtype='int32'))
+    self.assertAllClose(outputs, [[[1., 2.], [5., 6.], [9., 10.]]])
+
 
 if __name__ == '__main__':
   test.main()