diff --git a/tensorflow/python/distribute/sharded_variable.py b/tensorflow/python/distribute/sharded_variable.py index 9886e42a8b3..7accc066d8a 100644 --- a/tensorflow/python/distribute/sharded_variable.py +++ b/tensorflow/python/distribute/sharded_variable.py @@ -96,6 +96,10 @@ class ShardedVariable(trackable.Trackable): 'to the order of the `Variable`s in the list passed to ' 'the constructor. Found {}'.format(save_slice_info)) + def __iter__(self): + """Return an iterable for accessing the underlying sharded variables.""" + return iter(self._variables) + @property def variables(self): """The list of `Variable`s that make up the shards of this object.""" diff --git a/tensorflow/python/keras/engine/BUILD b/tensorflow/python/keras/engine/BUILD index 1ff15d7e2e1..231ab7661f0 100644 --- a/tensorflow/python/keras/engine/BUILD +++ b/tensorflow/python/keras/engine/BUILD @@ -118,6 +118,7 @@ py_library( "//tensorflow/python/distribute:distribute_lib", "//tensorflow/python/distribute:input_lib", "//tensorflow/python/distribute:reduce_util", + "//tensorflow/python/distribute:sharded_variable", "//tensorflow/python/eager:monitoring", "//tensorflow/python/keras:backend", "//tensorflow/python/keras:constraints", diff --git a/tensorflow/python/keras/engine/base_layer.py b/tensorflow/python/keras/engine/base_layer.py index 0f4bec92e39..0421772a75a 100644 --- a/tensorflow/python/keras/engine/base_layer.py +++ b/tensorflow/python/keras/engine/base_layer.py @@ -34,6 +34,7 @@ from tensorflow.python import tf2 from tensorflow.python.autograph.core import ag_ctx from tensorflow.python.autograph.impl import api as autograph from tensorflow.python.distribute import distribution_strategy_context as ds_context +from tensorflow.python.distribute import sharded_variable from tensorflow.python.eager import context from tensorflow.python.eager import execute from tensorflow.python.eager import function @@ -590,7 +591,9 @@ class Layer(module.Module, version_utils.LayerVersionSelector): self._handle_weight_regularization(name_in_scope, variable, regularizer) - if isinstance(variable, tf_variables.PartitionedVariable): + if isinstance( + variable, + (tf_variables.PartitionedVariable, sharded_variable.ShardedVariable)): for v in variable: backend.track_variable(v) if trainable: diff --git a/tensorflow/python/keras/layers/BUILD b/tensorflow/python/keras/layers/BUILD index 46ac88754a8..10a9fe088ab 100644 --- a/tensorflow/python/keras/layers/BUILD +++ b/tensorflow/python/keras/layers/BUILD @@ -213,12 +213,13 @@ py_library( "//tensorflow/python:framework_ops", "//tensorflow/python:math_ops", "//tensorflow/python:util", + "//tensorflow/python/distribute:sharded_variable", "//tensorflow/python/eager:context", "//tensorflow/python/keras:backend", - "//tensorflow/python/keras:base_layer", "//tensorflow/python/keras:constraints", "//tensorflow/python/keras:initializers", "//tensorflow/python/keras:regularizers", + "//tensorflow/python/keras/engine:base_layer", "//tensorflow/python/keras/utils:tf_utils", ], ) @@ -593,9 +594,15 @@ cuda_py_test( python_version = "PY3", deps = [ "//tensorflow/python:client_testlib", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:training_lib", + "//tensorflow/python:variables", + "//tensorflow/python/eager:backprop", "//tensorflow/python/keras", "//tensorflow/python/keras:combinations", - "@absl_py//absl/testing:parameterized", + "//tensorflow/python/keras:testing_utils", + "//tensorflow/python/ops/ragged:ragged_factory_ops", ], ) diff --git a/tensorflow/python/keras/layers/embeddings.py b/tensorflow/python/keras/layers/embeddings.py index e30e93f02dc..3444b3a7665 100644 --- a/tensorflow/python/keras/layers/embeddings.py +++ b/tensorflow/python/keras/layers/embeddings.py @@ -18,6 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from tensorflow.python.distribute import sharded_variable from tensorflow.python.eager import context from tensorflow.python.framework import ops from tensorflow.python.keras import backend as K @@ -183,7 +184,10 @@ class Embedding(Layer): dtype = K.dtype(inputs) if dtype != 'int32' and dtype != 'int64': inputs = math_ops.cast(inputs, 'int32') - out = embedding_ops.embedding_lookup(self.embeddings, inputs) + if isinstance(self.embeddings, sharded_variable.ShardedVariable): + out = embedding_ops.embedding_lookup_v2(self.embeddings.variables, inputs) + else: + out = embedding_ops.embedding_lookup_v2(self.embeddings, inputs) return out def get_config(self): diff --git a/tensorflow/python/keras/layers/embeddings_test.py b/tensorflow/python/keras/layers/embeddings_test.py index 661b29cd7bf..6aa873b2bd7 100644 --- a/tensorflow/python/keras/layers/embeddings_test.py +++ b/tensorflow/python/keras/layers/embeddings_test.py @@ -21,12 +21,14 @@ from __future__ import print_function import numpy as np from tensorflow.python import keras +from tensorflow.python.distribute import sharded_variable from tensorflow.python.eager import backprop from tensorflow.python.framework import dtypes from tensorflow.python.framework import test_util as tf_test_util from tensorflow.python.keras import combinations from tensorflow.python.keras import keras_parameterized from tensorflow.python.keras import testing_utils +from tensorflow.python.ops import variables from tensorflow.python.ops.ragged import ragged_factory_ops from tensorflow.python.platform import test from tensorflow.python.training import adagrad @@ -130,6 +132,20 @@ class EmbeddingTest(keras_parameterized.TestCase): [[[1., 1.], [2., 2.], [2., 2.]], [[0., 0.]], [[1., 1.], [2., 2.]]], ragged_rank=1)) + @keras_parameterized.run_all_keras_modes(always_skip_v1=True) + def test_embedding_with_sharded_variable(self): + layer = keras.layers.Embedding(input_dim=5, output_dim=2) + v = [ + variables.Variable([[1., 2.], [3., 4.]]), + variables.Variable([[5., 6.], [7., 8.]]), + variables.Variable([[9., 10.]]) + ] + model = keras.models.Sequential([layer]) + layer.embeddings = sharded_variable.ShardedVariable(v) + model.run_eagerly = testing_utils.should_run_eagerly() + outputs = model.predict(np.array([[0, 2, 4]], dtype='int32')) + self.assertAllClose(outputs, [[[1., 2.], [5., 6.], [9., 10.]]]) + if __name__ == '__main__': test.main()