Add sequence support for TPU Embeddings class on CPU.
PiperOrigin-RevId: 350838762 Change-Id: I8d8e334ffaa9339012adfb21ecc5ba87fb34c963
This commit is contained in:
parent
64bcff84a4
commit
5f5041fc9c
@ -41,6 +41,7 @@ from tensorflow.python.framework import sparse_tensor
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import embedding_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import sparse_ops
|
||||
from tensorflow.python.ops import variable_scope
|
||||
from tensorflow.python.ops import variables as tf_variables
|
||||
from tensorflow.python.ops.ragged import ragged_tensor
|
||||
@ -1532,8 +1533,6 @@ def cpu_embedding_lookup(inputs, weights, tables, feature_config):
|
||||
for inp, weight, (path, feature) in zip(
|
||||
flat_inputs, flat_weights, flat_features):
|
||||
table = tables[feature.table]
|
||||
if feature.max_sequence_length > 0:
|
||||
raise ValueError("Sequence features unsupported at this time.")
|
||||
|
||||
if weight is not None:
|
||||
if isinstance(inp, ops.Tensor):
|
||||
@ -1543,17 +1542,50 @@ def cpu_embedding_lookup(inputs, weights, tables, feature_config):
|
||||
raise ValueError(
|
||||
"Weight for {} is of type {} but it does not match type of the "
|
||||
"input which is {}.".format(path, type(weight), type(inp)))
|
||||
elif feature.max_sequence_length > 0:
|
||||
raise ValueError("Weight specified for {}, but this is a sequence "
|
||||
"feature.".format(path))
|
||||
|
||||
if isinstance(inp, ops.Tensor):
|
||||
if feature.max_sequence_length > 0:
|
||||
raise ValueError("Feature {} is a sequence feature but a dense tensor "
|
||||
"was passed.".format(path))
|
||||
outputs.append(embedding_ops.embedding_lookup_v2(table, inp))
|
||||
|
||||
elif isinstance(inp, sparse_tensor.SparseTensor):
|
||||
outputs.append(embedding_ops.safe_embedding_lookup_sparse_v2(
|
||||
table, inp, sparse_weights=weight, combiner=feature.table.combiner))
|
||||
if feature.max_sequence_length > 0:
|
||||
batch_size = math_ops.cast(array_ops.shape(inp)[0], dtype=dtypes.int64)
|
||||
sparse_shape = array_ops.concat(
|
||||
[batch_size, feature.max_sequence_length], axis=0)
|
||||
# TPU Embedding truncates sequences to max_sequence_length, and if we
|
||||
# don't truncate, scatter_nd will error out if the index was out of
|
||||
# bounds.
|
||||
truncated_inp = sparse_ops.sparse_slice(inp, start=[0, 0],
|
||||
size=sparse_shape)
|
||||
|
||||
dense_output_shape = array_ops.concat(
|
||||
[batch_size, feature.max_sequence_length, feature.table.dim],
|
||||
axis=0)
|
||||
outputs.append(
|
||||
array_ops.scatter_nd(
|
||||
inp.indices, array_ops.gather(table, truncated_inp.values),
|
||||
dense_output_shape))
|
||||
else:
|
||||
outputs.append(embedding_ops.safe_embedding_lookup_sparse_v2(
|
||||
table, inp, sparse_weights=weight, combiner=feature.table.combiner))
|
||||
|
||||
elif isinstance(inp, ragged_tensor.RaggedTensor):
|
||||
outputs.append(_ragged_embedding_lookup_with_reduce(
|
||||
table, inp, weight, feature.table.combiner))
|
||||
if feature.max_sequence_length > 0:
|
||||
batch_size = inp.shape[0]
|
||||
dense_output_shape = [
|
||||
batch_size, feature.max_sequence_length, feature.table.dim]
|
||||
ragged_lookup = embedding_ops.embedding_lookup_v2(table, inp)
|
||||
# Unlike scatter_nd, RaggedTensor.to_tensor truncates to the given
|
||||
# shape.
|
||||
outputs.append(ragged_lookup.to_tensor(shape=dense_output_shape))
|
||||
else:
|
||||
outputs.append(_ragged_embedding_lookup_with_reduce(
|
||||
table, inp, weight, feature.table.combiner))
|
||||
|
||||
else:
|
||||
raise ValueError("Input {} is type {}. Tensor, SparseTensor or "
|
||||
|
@ -277,7 +277,17 @@ class CPUEmbeddingTest(test.TestCase):
|
||||
tables=mid_level.embedding_tables,
|
||||
feature_config=self.feature_config)
|
||||
|
||||
def test_cpu_sequence_lookup(self):
|
||||
def _numpy_sequence_lookup(
    self, table, indices, values, batch_size, max_sequence_length, dim):
  """Builds the expected dense result of a sequence lookup using NumPy.

  Each sparse entry `(row, position) -> id` places `table[id]` at
  `result[row, position, :]`; every other slot stays zero. Entries whose
  position is at or beyond `max_sequence_length` are dropped, matching the
  truncation that the CPU lookup applies to sequence features.

  Args:
    table: 2-D NumPy array of embedding rows, indexed by id.
    indices: `[N, 2]` integer array of `(row, position)` pairs.
    values: Length-`N` integer array of ids to look up in `table`.
    batch_size: Size of the first dimension of the result.
    max_sequence_length: Size of the second dimension of the result; longer
      sequences are truncated to this length.
    dim: Size of the last dimension of the result (embedding width).

  Returns:
    A `[batch_size, max_sequence_length, dim]` NumPy array.
  """
  scatter_result = np.zeros([batch_size, max_sequence_length, dim])

  indices = np.asarray(indices)
  values = np.asarray(values)
  if indices.size == 0:
    # Nothing to scatter; also avoids indexing `table` with an empty array
    # whose dtype may not be integral.
    return scatter_result

  # Truncate first: positions past max_sequence_length have no slot in the
  # result and would otherwise raise IndexError in the scatter below.
  in_range = indices[:, 1] < max_sequence_length
  indices = indices[in_range]
  values = values[in_range]

  # Gather the embedding rows, then scatter them into the result array.
  lookup = table[values]
  for (row, position), embedding in zip(indices, lookup):
    scatter_result[row, position, :] = embedding

  return scatter_result
|
||||
|
||||
def test_cpu_sequence_lookup_sparse(self):
|
||||
feature_config = (
|
||||
tpu_embedding_v2_utils.FeatureConfig(
|
||||
table=self.table_video, name='watched', max_sequence_length=2),)
|
||||
@ -285,14 +295,48 @@ class CPUEmbeddingTest(test.TestCase):
|
||||
mid_level = tpu_embedding_v2.TPUEmbedding(
|
||||
feature_config=feature_config,
|
||||
optimizer=optimizer)
|
||||
features = tuple(self._get_sparse_tensors()[:1])
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, 'Sequence features unsupported at this time.'):
|
||||
tpu_embedding_v2.cpu_embedding_lookup(
|
||||
features,
|
||||
weights=None,
|
||||
tables=mid_level.embedding_tables,
|
||||
feature_config=feature_config)
|
||||
features = self._get_sparse_tensors()[:1]
|
||||
result = tpu_embedding_v2.cpu_embedding_lookup(
|
||||
features,
|
||||
weights=None,
|
||||
tables=mid_level.embedding_tables,
|
||||
feature_config=feature_config)
|
||||
|
||||
golden = self._numpy_sequence_lookup(
|
||||
mid_level.embedding_tables[self.table_video].numpy(),
|
||||
features[0].indices.numpy(),
|
||||
features[0].values.numpy(),
|
||||
self.data_batch_size,
|
||||
feature_config[0].max_sequence_length,
|
||||
self.table_video.dim)
|
||||
|
||||
self.assertAllClose(result[0], golden)
|
||||
|
||||
def test_cpu_sequence_lookup_ragged(self):
  """Compares the CPU sequence lookup of a ragged feature to a NumPy golden."""
  # A single sequence feature over the video table, capped at two steps.
  sequence_config = (
      tpu_embedding_v2_utils.FeatureConfig(
          table=self.table_video, name='watched', max_sequence_length=2),)
  embedding = tpu_embedding_v2.TPUEmbedding(
      feature_config=sequence_config,
      optimizer=tpu_embedding_v2_utils.SGD(learning_rate=0.1))

  # Only the first ragged feature is needed for the single-entry config.
  ragged_inputs = self._get_ragged_tensors()[:1]
  activations = tpu_embedding_v2.cpu_embedding_lookup(
      ragged_inputs,
      weights=None,
      tables=embedding.embedding_tables,
      feature_config=sequence_config)

  # The golden helper works on (indices, values), so convert to sparse form.
  as_sparse = ragged_inputs[0].to_sparse()
  expected = self._numpy_sequence_lookup(
      embedding.embedding_tables[self.table_video].numpy(),
      as_sparse.indices.numpy(),
      as_sparse.values.numpy(),
      self.data_batch_size,
      sequence_config[0].max_sequence_length,
      self.table_video.dim)

  self.assertAllClose(activations[0], expected)
|
||||
|
||||
def test_cpu_no_optimizer(self):
|
||||
feature_config = (
|
||||
|
Loading…
Reference in New Issue
Block a user