Use op dispatch to override the behavior of embedding_lookup ops when they are called with a ShardedVariable. Without this, a ShardedVariable is converted to a dense tensor when passed to embedding_lookup.

Ops like `tf.nn.nce_loss` and `tf.nn.sampled_softmax_loss` also benefit from this, as they use embedding_lookup internally.

PiperOrigin-RevId: 338369985
Change-Id: I89ebe2a452fc1d599567cb80e80ee9b023e5aa1c
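For context, a minimal usage sketch of what the dispatch enables (not part of the commit; it assumes TF 2.x and the private `tensorflow.python.distribute.sharded_variable` module path): an embedding lookup on a ShardedVariable inside a tf.function now gathers from each shard instead of concatenating the shards into one dense tensor.

import tensorflow as tf
from tensorflow.python.distribute import sharded_variable

# Two shards of a [3, 2] embedding table.
shards = [
    tf.Variable([[1., 2.], [3., 4.]]),  # rows 0-1
    tf.Variable([[5., 6.]]),            # row 2
]
sv = sharded_variable.ShardedVariable(shards)

@tf.function  # dispatch takes effect when the lookup is traced into a graph
def lookup(ids):
  return tf.nn.embedding_lookup(sv, ids)

print(lookup(tf.constant([0, 2])))  # [[1. 2.] [5. 6.]] -- one gather per shard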
tensorflow/python/distribute/BUILD

@@ -1114,11 +1114,15 @@ py_library(
    srcs_version = "PY2AND3",
    deps = [
        "//tensorflow/python:array_ops",
        "//tensorflow/python:composite_tensor",
        "//tensorflow/python:embedding_ops",
        "//tensorflow/python:framework_ops",
        "//tensorflow/python:partitioned_variables",
        "//tensorflow/python:resource_variable_ops",
        "//tensorflow/python:tensor_shape",
        "//tensorflow/python:tf_export",
        "//tensorflow/python:type_spec",
        "//tensorflow/python:util",
        "//tensorflow/python:variables",
        "//tensorflow/python/saved_model:save_context",
        "//tensorflow/python/training/saving:saveable_object_util",
@@ -1138,11 +1142,13 @@ tf_py_test(
        ":sharded_variable",
        "//tensorflow/python:array_ops",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python:constant_op",
        "//tensorflow/python:dtypes",
        "//tensorflow/python:embedding_ops",
        "//tensorflow/python:extra_py_tests_deps",
        "//tensorflow/python:framework_ops",
        "//tensorflow/python:session",
        "//tensorflow/python:tensor_shape",
        "//tensorflow/python:tensor_spec",
        "//tensorflow/python:util",
        "//tensorflow/python:variable_scope",
tensorflow/python/distribute/sharded_variable.py

@@ -24,12 +24,14 @@ from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import type_spec
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import embedding_ops
from tensorflow.python.ops import partitioned_variables
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variables as variables_lib
from tensorflow.python.saved_model import save_context
from tensorflow.python.training.saving import saveable_object_util
from tensorflow.python.training.tracking import base as trackable
from tensorflow.python.util import dispatch
from tensorflow.python.util.tf_export import tf_export
@@ -452,6 +454,7 @@ class ShardedVariable(ShardedVariableMixin, composite_tensor.CompositeTensor):


def _var_to_tensor(var, dtype=None, name=None, as_ref=False):
  """Converts a `ShardedVariable` to a `Tensor`."""
  del name
  if dtype is not None and not dtype.is_compatible_with(var.dtype):
    raise ValueError(
@@ -460,9 +463,40 @@ def _var_to_tensor(var, dtype=None, name=None, as_ref=False):
  if as_ref:
    raise NotImplementedError(
        "ShardedVariable doesn't support being used as a reference.")
  # We use the op dispatch mechanism to override embedding_lookup ops when
  # called with a ShardedVariable. This requires embedding_lookup ops to raise
  # a TypeError when called with a ShardedVariable. However, since
  # ShardedVariable can be converted to a tensor via concat, embedding_lookup
  # ops would silently do the conversion and never raise a TypeError. To be
  # able to properly raise a TypeError, the name scope is used to detect if
  # this method is called within an embedding_lookup op.
  # NOTE: This doesn't work in eager mode, since the op name scope is always
  # cleared in eager. This also breaks if the user sets the name of the
  # embedding_lookup op to something that doesn't contain the string
  # "embedding_lookup".
  #
  # TODO(chenkai): Find a more robust way to do this that does not rely on
  # the name scope.
  if 'embedding_lookup' in ops.get_name_scope():
    raise TypeError('Converting ShardedVariable to tensor in embedding lookup'
                    ' ops is disallowed.')
  return array_ops.concat(var.variables, axis=0)


# Register a conversion function which reads the value of the variable,
# allowing instances of the class to be used as tensors.
ops.register_tensor_conversion_function(ShardedVariable, _var_to_tensor)


# Override the behavior of embedding_lookup(sharded_variable, ...)
@dispatch.dispatch_for_types(embedding_ops.embedding_lookup, ShardedVariable)
def embedding_lookup(params,
                     ids,
                     partition_strategy='mod',
                     name=None,
                     validate_indices=True,
                     max_norm=None):
  if isinstance(params, list):
    params = params[0]
  return embedding_ops.embedding_lookup(params.variables, ids,
                                        partition_strategy, name,
                                        validate_indices, max_norm)
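A tiny standalone sketch (an illustration, not from the commit) of the name-scope trick `_var_to_tensor` relies on: in graph mode, code running under a name scope can see it via `ops.get_name_scope()`, which is how the conversion function detects that it is being called from within an embedding_lookup op.

from tensorflow.python.framework import ops

with ops.Graph().as_default():
  with ops.name_scope('embedding_lookup'):
    # In graph mode, the current name scope is visible here.
    assert 'embedding_lookup' in ops.get_name_scope()

# In eager mode the op name scope is always cleared, so the check above
# would never fire -- the limitation called out in the NOTE.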
tensorflow/python/distribute/sharded_variable_test.py

@@ -24,8 +24,10 @@ from tensorflow.python.client import session as session_lib
from tensorflow.python.compat import v2_compat
from tensorflow.python.distribute import sharded_variable
from tensorflow.python.eager import def_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.keras.engine import base_layer
@@ -462,6 +464,56 @@ class ShardedVariableTest(test.TestCase):
    checkpoint_deps = set(dep.ref for dep in layer._checkpoint_dependencies)
    self.assertEqual(checkpoint_deps, set([layer.w, layer.b]))

  def test_embedding_lookup(self):
    v = [
        variables_lib.Variable([[1., 2.], [3., 4.]]),
        variables_lib.Variable([[5., 6.], [7., 8.]]),
        variables_lib.Variable([[9., 10.]])
    ]
    sv = sharded_variable.ShardedVariable(v)

    @def_function.function
    def lookup():
      ids = constant_op.constant([0, 3, 4])
      return embedding_ops.embedding_lookup_v2(sv, ids)

    @def_function.function
    def sparse_lookup():
      sp_ids = sparse_tensor.SparseTensor(
          indices=[[0, 0], [0, 1], [1, 0], [2, 2]],
          values=[0, 3, 4, 1],
          dense_shape=[3, 3])
      return embedding_ops.embedding_lookup_sparse_v2(sv, sp_ids, None)

    @def_function.function
    def safe_sparse_lookup():
      sp_ids = sparse_tensor.SparseTensor(
          indices=[[0, 0], [0, 1], [1, 0], [2, 2]],
          values=[0, -1, 4, 1],
          dense_shape=[3, 3])
      sp_weights = sparse_tensor.SparseTensor(
          indices=[[0, 0], [0, 1], [1, 0], [2, 2]],
          values=[1., 1., -1., 1.],
          dense_shape=[3, 3])
      return embedding_ops.safe_embedding_lookup_sparse_v2(
          sv, sp_ids, sp_weights)

    # TODO(chenkai): Add safe_sparse_lookup to the list. Currently
    # ShardedVariable is converted to a tensor in safe_sparse_lookup.
    for func in [lookup, sparse_lookup]:
      num_gather_ops = 0
      for op in func.get_concrete_function().graph.get_operations():
        if op.type == 'ResourceGather':
          num_gather_ops += 1
      self.assertEqual(
          num_gather_ops, len(v), 'Number of ResourceGather ops does not'
          ' match expected, possibly due to ShardedVariable accidentally'
          ' being converted to a tensor in embedding_lookup ops.')

    self.assertAllEqual(lookup(), [[1., 2.], [7., 8.], [9., 10.]])
    self.assertAllClose(sparse_lookup(), [[4., 5.], [9., 10.], [3., 4.]])
    self.assertAllClose(safe_sparse_lookup(), [[1., 2.], [0., 0.], [3., 4.]])


if __name__ == '__main__':
  v2_compat.enable_v2_behavior()
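For reference, the graph-inspection technique the test relies on, in isolation (a generic sketch, not part of the commit): trace a tf.function and count ops of a given type in its concrete graph.

import tensorflow as tf

@tf.function
def f(x):
  return tf.gather(x, [0, 1])

graph = f.get_concrete_function(tf.zeros([4, 2])).graph
num_gathers = sum(
    1 for op in graph.get_operations() if op.type == 'GatherV2')
print(num_gathers)  # tf.gather on a dense tensor lowers to one GatherV2 op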