Add sequence support for TPU Embeddings class on CPU.

PiperOrigin-RevId: 350838762
Change-Id: I8d8e334ffaa9339012adfb21ecc5ba87fb34c963
This commit is contained in:
Bruce Fontaine 2021-01-08 14:36:31 -08:00 committed by TensorFlower Gardener
parent 64bcff84a4
commit 5f5041fc9c
2 changed files with 91 additions and 15 deletions

View File

@ -41,6 +41,7 @@ from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import embedding_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import sparse_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables as tf_variables
from tensorflow.python.ops.ragged import ragged_tensor
@ -1532,8 +1533,6 @@ def cpu_embedding_lookup(inputs, weights, tables, feature_config):
for inp, weight, (path, feature) in zip(
flat_inputs, flat_weights, flat_features):
table = tables[feature.table]
if feature.max_sequence_length > 0:
raise ValueError("Sequence features unsupported at this time.")
if weight is not None:
if isinstance(inp, ops.Tensor):
@ -1543,17 +1542,50 @@ def cpu_embedding_lookup(inputs, weights, tables, feature_config):
raise ValueError(
"Weight for {} is of type {} but it does not match type of the "
"input which is {}.".format(path, type(weight), type(inp)))
elif feature.max_sequence_length > 0:
raise ValueError("Weight specified for {}, but this is a sequence "
"feature.".format(path))
if isinstance(inp, ops.Tensor):
if feature.max_sequence_length > 0:
raise ValueError("Feature {} is a sequence feature but a dense tensor "
"was passed.".format(path))
outputs.append(embedding_ops.embedding_lookup_v2(table, inp))
elif isinstance(inp, sparse_tensor.SparseTensor):
outputs.append(embedding_ops.safe_embedding_lookup_sparse_v2(
table, inp, sparse_weights=weight, combiner=feature.table.combiner))
if feature.max_sequence_length > 0:
batch_size = math_ops.cast(array_ops.shape(inp)[0], dtype=dtypes.int64)
sparse_shape = array_ops.concat(
[batch_size, feature.max_sequence_length], axis=0)
# TPU Embedding truncates sequences to max_sequence_length, and if we
# don't truncate, scatter_nd will error out if the index was out of
# bounds.
truncated_inp = sparse_ops.sparse_slice(inp, start=[0, 0],
size=sparse_shape)
dense_output_shape = array_ops.concat(
[batch_size, feature.max_sequence_length, feature.table.dim],
axis=0)
outputs.append(
array_ops.scatter_nd(
inp.indices, array_ops.gather(table, truncated_inp.values),
dense_output_shape))
else:
outputs.append(embedding_ops.safe_embedding_lookup_sparse_v2(
table, inp, sparse_weights=weight, combiner=feature.table.combiner))
elif isinstance(inp, ragged_tensor.RaggedTensor):
outputs.append(_ragged_embedding_lookup_with_reduce(
table, inp, weight, feature.table.combiner))
if feature.max_sequence_length > 0:
batch_size = inp.shape[0]
dense_output_shape = [
batch_size, feature.max_sequence_length, feature.table.dim]
ragged_lookup = embedding_ops.embedding_lookup_v2(table, inp)
# Unlike scatter_nd, RaggedTensor.to_tensor truncates to the given
# shape.
outputs.append(ragged_lookup.to_tensor(shape=dense_output_shape))
else:
outputs.append(_ragged_embedding_lookup_with_reduce(
table, inp, weight, feature.table.combiner))
else:
raise ValueError("Input {} is type {}. Tensor, SparseTensor or "

View File

@ -277,7 +277,17 @@ class CPUEmbeddingTest(test.TestCase):
tables=mid_level.embedding_tables,
feature_config=self.feature_config)
def test_cpu_sequence_lookup(self):
def _numpy_sequence_lookup(
self, table, indices, values, batch_size, max_sequence_length, dim):
# First we gather the values
lookup = table[values]
# Then we scatter them into the result array.
scatter_result = np.zeros([batch_size, max_sequence_length, dim])
for i, index in enumerate(indices):
scatter_result[index[0], index[1], :] = lookup[i]
return scatter_result
def test_cpu_sequence_lookup_sparse(self):
feature_config = (
tpu_embedding_v2_utils.FeatureConfig(
table=self.table_video, name='watched', max_sequence_length=2),)
@ -285,14 +295,48 @@ class CPUEmbeddingTest(test.TestCase):
mid_level = tpu_embedding_v2.TPUEmbedding(
feature_config=feature_config,
optimizer=optimizer)
features = tuple(self._get_sparse_tensors()[:1])
with self.assertRaisesRegex(
ValueError, 'Sequence features unsupported at this time.'):
tpu_embedding_v2.cpu_embedding_lookup(
features,
weights=None,
tables=mid_level.embedding_tables,
feature_config=feature_config)
features = self._get_sparse_tensors()[:1]
result = tpu_embedding_v2.cpu_embedding_lookup(
features,
weights=None,
tables=mid_level.embedding_tables,
feature_config=feature_config)
golden = self._numpy_sequence_lookup(
mid_level.embedding_tables[self.table_video].numpy(),
features[0].indices.numpy(),
features[0].values.numpy(),
self.data_batch_size,
feature_config[0].max_sequence_length,
self.table_video.dim)
self.assertAllClose(result[0], golden)
def test_cpu_sequence_lookup_ragged(self):
feature_config = (
tpu_embedding_v2_utils.FeatureConfig(
table=self.table_video, name='watched', max_sequence_length=2),)
optimizer = tpu_embedding_v2_utils.SGD(learning_rate=0.1)
mid_level = tpu_embedding_v2.TPUEmbedding(
feature_config=feature_config,
optimizer=optimizer)
features = self._get_ragged_tensors()[:1]
result = tpu_embedding_v2.cpu_embedding_lookup(
features,
weights=None,
tables=mid_level.embedding_tables,
feature_config=feature_config)
sparse_ver = features[0].to_sparse()
golden = self._numpy_sequence_lookup(
mid_level.embedding_tables[self.table_video].numpy(),
sparse_ver.indices.numpy(),
sparse_ver.values.numpy(),
self.data_batch_size,
feature_config[0].max_sequence_length,
self.table_video.dim)
self.assertAllClose(result[0], golden)
def test_cpu_no_optimizer(self):
feature_config = (