diff --git a/tensorflow/python/distribute/BUILD b/tensorflow/python/distribute/BUILD
index 96e164b0480..861394347a7 100644
--- a/tensorflow/python/distribute/BUILD
+++ b/tensorflow/python/distribute/BUILD
@@ -1114,11 +1114,15 @@ py_library(
     srcs_version = "PY2AND3",
     deps = [
         "//tensorflow/python:array_ops",
+        "//tensorflow/python:composite_tensor",
+        "//tensorflow/python:embedding_ops",
         "//tensorflow/python:framework_ops",
         "//tensorflow/python:partitioned_variables",
         "//tensorflow/python:resource_variable_ops",
         "//tensorflow/python:tensor_shape",
         "//tensorflow/python:tf_export",
+        "//tensorflow/python:type_spec",
+        "//tensorflow/python:util",
         "//tensorflow/python:variables",
         "//tensorflow/python/saved_model:save_context",
         "//tensorflow/python/training/saving:saveable_object_util",
@@ -1138,11 +1142,13 @@ tf_py_test(
         ":sharded_variable",
         "//tensorflow/python:array_ops",
         "//tensorflow/python:client_testlib",
+        "//tensorflow/python:constant_op",
         "//tensorflow/python:dtypes",
         "//tensorflow/python:embedding_ops",
         "//tensorflow/python:extra_py_tests_deps",
         "//tensorflow/python:framework_ops",
         "//tensorflow/python:session",
+        "//tensorflow/python:tensor_shape",
         "//tensorflow/python:tensor_spec",
         "//tensorflow/python:util",
         "//tensorflow/python:variable_scope",
diff --git a/tensorflow/python/distribute/sharded_variable.py b/tensorflow/python/distribute/sharded_variable.py
index c40d0dd0e89..553d82e4a26 100644
--- a/tensorflow/python/distribute/sharded_variable.py
+++ b/tensorflow/python/distribute/sharded_variable.py
@@ -24,12 +24,14 @@ from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_shape
 from tensorflow.python.framework import type_spec
 from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import embedding_ops
 from tensorflow.python.ops import partitioned_variables
 from tensorflow.python.ops import resource_variable_ops
 from tensorflow.python.ops import variables as variables_lib
 from tensorflow.python.saved_model import save_context
 from tensorflow.python.training.saving import saveable_object_util
 from tensorflow.python.training.tracking import base as trackable
+from tensorflow.python.util import dispatch
 from tensorflow.python.util.tf_export import tf_export


@@ -452,6 +454,7 @@ class ShardedVariable(ShardedVariableMixin, composite_tensor.CompositeTensor):


 def _var_to_tensor(var, dtype=None, name=None, as_ref=False):
+  """Converts a `ShardedVariable` to a `Tensor`."""
   del name
   if dtype is not None and not dtype.is_compatible_with(var.dtype):
     raise ValueError(
@@ -460,9 +463,40 @@ def _var_to_tensor(var, dtype=None, name=None, as_ref=False):
   if as_ref:
     raise NotImplementedError(
         "ShardedVariable doesn't support being used as a reference.")
+  # We use the op dispatch mechanism to override embedding_lookup ops when
+  # called with a ShardedVariable. This requires embedding_lookup ops to
+  # raise TypeError when called with a ShardedVariable. However, since a
+  # ShardedVariable can be converted to a tensor via concat, embedding_lookup
+  # ops would silently do the conversion and never raise a TypeError. To
+  # raise the TypeError properly, the op name scope is used to detect whether
+  # this method is called from within an embedding_lookup op.
+  # NOTE: This doesn't work in eager mode, since the op name scope is always
+  # cleared in eager execution. It also breaks if the user names the
+  # embedding_lookup op something that doesn't contain "embedding_lookup".
+  #
+  # TODO(chenkai): Find a more robust way to do this, one that does not rely
+  # on the name scope.
+  if 'embedding_lookup' in ops.get_name_scope():
+    raise TypeError('Converting a ShardedVariable to a tensor in '
+                    'embedding_lookup ops is disallowed.')
   return array_ops.concat(var.variables, axis=0)


 # Register a conversion function which reads the value of the variable,
 # allowing instances of the class to be used as tensors.
 ops.register_tensor_conversion_function(ShardedVariable, _var_to_tensor)
+
+
+# Override the behavior of embedding_lookup(sharded_variable, ...).
+@dispatch.dispatch_for_types(embedding_ops.embedding_lookup, ShardedVariable)
+def embedding_lookup(params,
+                     ids,
+                     partition_strategy='mod',
+                     name=None,
+                     validate_indices=True,
+                     max_norm=None):
+  if isinstance(params, list):
+    params = params[0]
+  return embedding_ops.embedding_lookup(params.variables, ids,
+                                        partition_strategy, name,
+                                        validate_indices, max_norm)
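For context, a minimal usage sketch (not part of the patch) of what the dispatch override enables: inside a tf.function, the v2 embedding lookup on a ShardedVariable is routed to a per-shard lookup over sv.variables rather than a concat-then-gather. The shard values mirror the fixtures in the new test below; everything else is standard TensorFlow API from this repository.

# Usage sketch only; assumes the patched sharded_variable module above.
from tensorflow.python.compat import v2_compat
from tensorflow.python.distribute import sharded_variable
from tensorflow.python.eager import def_function
from tensorflow.python.framework import constant_op
from tensorflow.python.ops import embedding_ops
from tensorflow.python.ops import variables as variables_lib

v2_compat.enable_v2_behavior()

# A 5x2 embedding table split contiguously across three shards.
sv = sharded_variable.ShardedVariable([
    variables_lib.Variable([[1., 2.], [3., 4.]]),
    variables_lib.Variable([[5., 6.], [7., 8.]]),
    variables_lib.Variable([[9., 10.]]),
])

@def_function.function
def lookup():
  # The dispatch_for_types override routes this call to
  # embedding_lookup(sv.variables, ...), so one ResourceGather is emitted
  # per shard instead of concatenating sv into a single dense tensor.
  ids = constant_op.constant([0, 3, 4])
  return embedding_ops.embedding_lookup_v2(sv, ids)

print(lookup().numpy())  # [[1. 2.] [7. 8.] [9. 10.]]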
diff --git a/tensorflow/python/distribute/sharded_variable_test.py b/tensorflow/python/distribute/sharded_variable_test.py
index 2aa8caea4d8..8b88d7b016e 100644
--- a/tensorflow/python/distribute/sharded_variable_test.py
+++ b/tensorflow/python/distribute/sharded_variable_test.py
@@ -24,8 +24,10 @@ from tensorflow.python.client import session as session_lib
 from tensorflow.python.compat import v2_compat
 from tensorflow.python.distribute import sharded_variable
 from tensorflow.python.eager import def_function
+from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
+from tensorflow.python.framework import sparse_tensor
 from tensorflow.python.framework import tensor_shape
 from tensorflow.python.framework import tensor_spec
 from tensorflow.python.keras.engine import base_layer
@@ -462,6 +464,56 @@ class ShardedVariableTest(test.TestCase):
     checkpoint_deps = set(dep.ref for dep in layer._checkpoint_dependencies)
     self.assertEqual(checkpoint_deps, set([layer.w, layer.b]))

+  def test_embedding_lookup(self):
+    v = [
+        variables_lib.Variable([[1., 2.], [3., 4.]]),
+        variables_lib.Variable([[5., 6.], [7., 8.]]),
+        variables_lib.Variable([[9., 10.]])
+    ]
+    sv = sharded_variable.ShardedVariable(v)
+
+    @def_function.function
+    def lookup():
+      ids = constant_op.constant([0, 3, 4])
+      return embedding_ops.embedding_lookup_v2(sv, ids)
+
+    @def_function.function
+    def sparse_lookup():
+      sp_ids = sparse_tensor.SparseTensor(
+          indices=[[0, 0], [0, 1], [1, 0], [2, 2]],
+          values=[0, 3, 4, 1],
+          dense_shape=[3, 3])
+      return embedding_ops.embedding_lookup_sparse_v2(sv, sp_ids, None)
+
+    @def_function.function
+    def safe_sparse_lookup():
+      sp_ids = sparse_tensor.SparseTensor(
+          indices=[[0, 0], [0, 1], [1, 0], [2, 2]],
+          values=[0, -1, 4, 1],
+          dense_shape=[3, 3])
+      sp_weights = sparse_tensor.SparseTensor(
+          indices=[[0, 0], [0, 1], [1, 0], [2, 2]],
+          values=[1., 1., -1., 1.],
+          dense_shape=[3, 3])
+      return embedding_ops.safe_embedding_lookup_sparse_v2(
+          sv, sp_ids, sp_weights)
+
+    # TODO(chenkai): Add safe_sparse_lookup to the list. Currently
+    # ShardedVariable is converted to a tensor in safe_sparse_lookup.
+    for func in [lookup, sparse_lookup]:
+      num_gather_ops = 0
+      for op in func.get_concrete_function().graph.get_operations():
+        if op.type == 'ResourceGather':
+          num_gather_ops += 1
+      self.assertEqual(
+          num_gather_ops, len(v), 'Number of ResourceGather ops does not '
+          'match the expected count, possibly because ShardedVariable was '
+          'accidentally converted to a tensor in embedding_lookup ops.')
+
+    self.assertAllEqual(lookup(), [[1., 2.], [7., 8.], [9., 10.]])
+    self.assertAllClose(sparse_lookup(), [[4., 5.], [9., 10.], [3., 4.]])
+    self.assertAllClose(safe_sparse_lookup(), [[1., 2.], [0., 0.], [3., 4.]])
+

 if __name__ == '__main__':
   v2_compat.enable_v2_behavior()
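As a companion sketch (again not part of the patch) of the name-scope guard in _var_to_tensor: under ordinary scopes, conversion still concatenates the shards, while a graph name scope containing "embedding_lookup" makes it raise TypeError, which is what lets the op dispatch above take over. The scope name used here is purely illustrative; per the NOTE in the patch, the guard cannot fire in eager mode because the name scope is empty there.

# Sketch of the _var_to_tensor guard; assumes the patched module above.
from tensorflow.python.compat import v2_compat
from tensorflow.python.distribute import sharded_variable
from tensorflow.python.eager import def_function
from tensorflow.python.framework import ops
from tensorflow.python.ops import variables as variables_lib

v2_compat.enable_v2_behavior()

sv = sharded_variable.ShardedVariable(
    [variables_lib.Variable([1., 2.]), variables_lib.Variable([3.])])

@def_function.function
def plain_convert():
  # Outside any embedding_lookup name scope, the registered conversion
  # function concatenates the shards along axis 0.
  return ops.convert_to_tensor(sv)

@def_function.function
def convert_in_lookup_scope():
  # Inside a graph name scope whose name contains "embedding_lookup",
  # the conversion raises TypeError instead of silently concatenating.
  with ops.name_scope('embedding_lookup'):
    return ops.convert_to_tensor(sv)

print(plain_convert().numpy())  # [1. 2. 3.]
# convert_in_lookup_scope()     # raises TypeError at trace time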