Use ops dispatch to override the behavior of embedding_lookup ops when they are called with a ShardedVariable. Otherwise the ShardedVariable is silently converted to a dense tensor when passed to embedding_lookup.

Ops like `tf.nn.nce_loss` and `tf.nn.sampled_softmax_loss` also benefit from this as they use embedding_lookup internally.

PiperOrigin-RevId: 338369985
Change-Id: I89ebe2a452fc1d599567cb80e80ee9b023e5aa1c
Author: Chenkai Kuang
Date: 2020-10-21 17:04:20 -07:00
Committed by: TensorFlower Gardener
Parent: 3aaf894c46
Commit: 785353f8d4
3 changed files with 92 additions and 0 deletions
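For context, a minimal sketch of the dispatch mechanism this change builds on: `dispatch.dispatch_for_types` registers a handler that fires when the decorated op raises TypeError for the given type (which is why the change below makes tensor conversion raise TypeError inside embedding_lookup name scopes). The toy `TwoShards` type here is hypothetical; only the decorator and the `embedding_ops` calls are the ones this commit actually touches.

from tensorflow.python.ops import embedding_ops
from tensorflow.python.util import dispatch


class TwoShards(object):
  """Hypothetical stand-in for ShardedVariable: just holds shard variables."""

  def __init__(self, shards):
    self.variables = shards


# When embedding_lookup raises TypeError for a TwoShards argument, the
# dispatch machinery retries the call with this handler instead.
@dispatch.dispatch_for_types(embedding_ops.embedding_lookup, TwoShards)
def _lookup_two_shards(params, ids, partition_strategy='mod', name=None,
                       validate_indices=True, max_norm=None):
  # Forward the shard list so the gather happens per shard, never via concat.
  return embedding_ops.embedding_lookup(params.variables, ids,
                                        partition_strategy, name,
                                        validate_indices, max_norm)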


@@ -1114,11 +1114,15 @@ py_library(
    srcs_version = "PY2AND3",
    deps = [
        "//tensorflow/python:array_ops",
        "//tensorflow/python:composite_tensor",
        "//tensorflow/python:embedding_ops",
        "//tensorflow/python:framework_ops",
        "//tensorflow/python:partitioned_variables",
        "//tensorflow/python:resource_variable_ops",
        "//tensorflow/python:tensor_shape",
        "//tensorflow/python:tf_export",
        "//tensorflow/python:type_spec",
        "//tensorflow/python:util",
        "//tensorflow/python:variables",
        "//tensorflow/python/saved_model:save_context",
        "//tensorflow/python/training/saving:saveable_object_util",
@@ -1138,11 +1142,13 @@ tf_py_test(
        ":sharded_variable",
        "//tensorflow/python:array_ops",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python:constant_op",
        "//tensorflow/python:dtypes",
        "//tensorflow/python:embedding_ops",
        "//tensorflow/python:extra_py_tests_deps",
        "//tensorflow/python:framework_ops",
        "//tensorflow/python:session",
        "//tensorflow/python:tensor_shape",
        "//tensorflow/python:tensor_spec",
        "//tensorflow/python:util",
        "//tensorflow/python:variable_scope",


@@ -24,12 +24,14 @@ from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import type_spec
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import embedding_ops
from tensorflow.python.ops import partitioned_variables
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variables as variables_lib
from tensorflow.python.saved_model import save_context
from tensorflow.python.training.saving import saveable_object_util
from tensorflow.python.training.tracking import base as trackable
from tensorflow.python.util import dispatch
from tensorflow.python.util.tf_export import tf_export
@@ -452,6 +454,7 @@ class ShardedVariable(ShardedVariableMixin, composite_tensor.CompositeTensor):
def _var_to_tensor(var, dtype=None, name=None, as_ref=False):
  """Converts a `ShardedVariable` to a `Tensor`."""
  del name
  if dtype is not None and not dtype.is_compatible_with(var.dtype):
    raise ValueError(
@@ -460,9 +463,40 @@ def _var_to_tensor(var, dtype=None, name=None, as_ref=False):
  if as_ref:
    raise NotImplementedError(
        "ShardedVariable doesn't support being used as a reference.")
  # We use the op dispatch mechanism to override embedding_lookup ops when
  # called with ShardedVariable. This requires embedding_lookup ops to raise
  # TypeError when called with ShardedVariable. However, since ShardedVariable
  # can be converted to a tensor via concat, embedding_lookup ops would
  # silently do the conversion and never raise a TypeError. To be able to
  # properly raise a TypeError, the name scope is used to detect whether this
  # method is called within an embedding_lookup op.
  # NOTE: This doesn't work in eager mode since the op name scope is always
  # cleared in eager. This also breaks if the user sets the name of the
  # embedding_lookup op to something that doesn't contain the string
  # "embedding_lookup".
  #
  # TODO(chenkai): Find a more robust way to do this, which should not rely
  # on the name scope.
  if 'embedding_lookup' in ops.get_name_scope():
    raise TypeError('Converting ShardedVariable to tensor in embedding lookup'
                    ' ops is disallowed.')
  return array_ops.concat(var.variables, axis=0)


# Register a conversion function which reads the value of the variable,
# allowing instances of the class to be used as tensors.
ops.register_tensor_conversion_function(ShardedVariable, _var_to_tensor)


# Override the behavior of embedding_lookup(sharded_variable, ...)
@dispatch.dispatch_for_types(embedding_ops.embedding_lookup, ShardedVariable)
def embedding_lookup(params,
                     ids,
                     partition_strategy='mod',
                     name=None,
                     validate_indices=True,
                     max_norm=None):
  if isinstance(params, list):
    params = params[0]
  return embedding_ops.embedding_lookup(params.variables, ids,
                                        partition_strategy, name,
                                        validate_indices, max_norm)
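A quick usage sketch of what this override enables. The shard values here are hypothetical, and it assumes the public `tf.nn.embedding_lookup` entry point routes through `embedding_ops.embedding_lookup`, which is what the test below exercises via `embedding_lookup_v2`:

import tensorflow as tf
from tensorflow.python.distribute import sharded_variable

sv = sharded_variable.ShardedVariable([
    tf.Variable([[1., 2.], [3., 4.]]),
    tf.Variable([[5., 6.], [7., 8.]]),
])


@tf.function  # the name-scope check only works in graph mode
def lookup(ids):
  # Dispatches to the override above: one gather per shard, no dense concat.
  return tf.nn.embedding_lookup(sv, ids)


print(lookup(tf.constant([0, 3])))  # expected rows: [[1., 2.], [7., 8.]]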


@@ -24,8 +24,10 @@ from tensorflow.python.client import session as session_lib
from tensorflow.python.compat import v2_compat
from tensorflow.python.distribute import sharded_variable
from tensorflow.python.eager import def_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.keras.engine import base_layer
@@ -462,6 +464,56 @@ class ShardedVariableTest(test.TestCase):
    checkpoint_deps = set(dep.ref for dep in layer._checkpoint_dependencies)
    self.assertEqual(checkpoint_deps, set([layer.w, layer.b]))

  def test_embedding_lookup(self):
    v = [
        variables_lib.Variable([[1., 2.], [3., 4.]]),
        variables_lib.Variable([[5., 6.], [7., 8.]]),
        variables_lib.Variable([[9., 10.]])
    ]
    sv = sharded_variable.ShardedVariable(v)

    @def_function.function
    def lookup():
      ids = constant_op.constant([0, 3, 4])
      return embedding_ops.embedding_lookup_v2(sv, ids)

    @def_function.function
    def sparse_lookup():
      sp_ids = sparse_tensor.SparseTensor(
          indices=[[0, 0], [0, 1], [1, 0], [2, 2]],
          values=[0, 3, 4, 1],
          dense_shape=[3, 3])
      return embedding_ops.embedding_lookup_sparse_v2(sv, sp_ids, None)

    @def_function.function
    def safe_sparse_lookup():
      sp_ids = sparse_tensor.SparseTensor(
          indices=[[0, 0], [0, 1], [1, 0], [2, 2]],
          values=[0, -1, 4, 1],
          dense_shape=[3, 3])
      sp_weights = sparse_tensor.SparseTensor(
          indices=[[0, 0], [0, 1], [1, 0], [2, 2]],
          values=[1., 1., -1., 1.],
          dense_shape=[3, 3])
      return embedding_ops.safe_embedding_lookup_sparse_v2(
          sv, sp_ids, sp_weights)

    # TODO(chenkai): Add safe_sparse_lookup to the list. Currently
    # ShardedVariable is converted to a tensor in safe_sparse_lookup.
    for func in [lookup, sparse_lookup]:
      num_gather_ops = 0
      for op in func.get_concrete_function().graph.get_operations():
        if op.type == 'ResourceGather':
          num_gather_ops += 1
      self.assertEqual(
          num_gather_ops, len(v), 'Number of ResourceGather ops does not match'
          ' expected, possibly due to ShardedVariable accidentally being'
          ' converted to tensor in embedding_lookup ops.')

    self.assertAllEqual(lookup(), [[1., 2.], [7., 8.], [9., 10.]])
    self.assertAllClose(sparse_lookup(), [[4., 5.], [9., 10.], [3., 4.]])
    self.assertAllClose(safe_sparse_lookup(), [[1., 2.], [0., 0.], [3., 4.]])
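    # How the safe_sparse_lookup expectation is derived (a hedged reading of
    # safe_embedding_lookup_sparse semantics: ids < 0 and weights <= 0 are
    # pruned, and rows left empty fall back to the zero vector):
    #   row 0: id 0 (w=1.) kept, id -1 pruned       -> full row 0 = [1., 2.]
    #   row 1: only id 4, but its weight is -1., so -> empty row  = [0., 0.]
    #   row 2: id 1 (w=1.) kept                     -> full row 1 = [3., 4.]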

if __name__ == '__main__':
  v2_compat.enable_v2_behavior()