From 5f5041fc9c8c616eb8a14de173f5ec3fc27dfef6 Mon Sep 17 00:00:00 2001 From: Bruce Fontaine Date: Fri, 8 Jan 2021 14:36:31 -0800 Subject: [PATCH] Add sequence support for TPU Embeddings class on CPU. PiperOrigin-RevId: 350838762 Change-Id: I8d8e334ffaa9339012adfb21ecc5ba87fb34c963 --- tensorflow/python/tpu/tpu_embedding_v2.py | 44 +++++++++++-- .../python/tpu/tpu_embedding_v2_cpu_test.py | 62 ++++++++++++++++--- 2 files changed, 91 insertions(+), 15 deletions(-) diff --git a/tensorflow/python/tpu/tpu_embedding_v2.py b/tensorflow/python/tpu/tpu_embedding_v2.py index ec2ad9bb2dc..155f865a14c 100644 --- a/tensorflow/python/tpu/tpu_embedding_v2.py +++ b/tensorflow/python/tpu/tpu_embedding_v2.py @@ -41,6 +41,7 @@ from tensorflow.python.framework import sparse_tensor from tensorflow.python.ops import array_ops from tensorflow.python.ops import embedding_ops from tensorflow.python.ops import math_ops +from tensorflow.python.ops import sparse_ops from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables as tf_variables from tensorflow.python.ops.ragged import ragged_tensor @@ -1532,8 +1533,6 @@ def cpu_embedding_lookup(inputs, weights, tables, feature_config): for inp, weight, (path, feature) in zip( flat_inputs, flat_weights, flat_features): table = tables[feature.table] - if feature.max_sequence_length > 0: - raise ValueError("Sequence features unsupported at this time.") if weight is not None: if isinstance(inp, ops.Tensor): @@ -1543,17 +1542,50 @@ def cpu_embedding_lookup(inputs, weights, tables, feature_config): raise ValueError( "Weight for {} is of type {} but it does not match type of the " "input which is {}.".format(path, type(weight), type(inp))) + elif feature.max_sequence_length > 0: + raise ValueError("Weight specified for {}, but this is a sequence " + "feature.".format(path)) if isinstance(inp, ops.Tensor): + if feature.max_sequence_length > 0: + raise ValueError("Feature {} is a sequence feature but a dense tensor " + "was passed.".format(path)) outputs.append(embedding_ops.embedding_lookup_v2(table, inp)) elif isinstance(inp, sparse_tensor.SparseTensor): - outputs.append(embedding_ops.safe_embedding_lookup_sparse_v2( - table, inp, sparse_weights=weight, combiner=feature.table.combiner)) + if feature.max_sequence_length > 0: + batch_size = math_ops.cast(array_ops.shape(inp)[0], dtype=dtypes.int64) + sparse_shape = array_ops.concat( + [batch_size, feature.max_sequence_length], axis=0) + # TPU Embedding truncates sequences to max_sequence_length, and if we + # don't truncate, scatter_nd will error out if the index was out of + # bounds. + truncated_inp = sparse_ops.sparse_slice(inp, start=[0, 0], + size=sparse_shape) + + dense_output_shape = array_ops.concat( + [batch_size, feature.max_sequence_length, feature.table.dim], + axis=0) + outputs.append( + array_ops.scatter_nd( + inp.indices, array_ops.gather(table, truncated_inp.values), + dense_output_shape)) + else: + outputs.append(embedding_ops.safe_embedding_lookup_sparse_v2( + table, inp, sparse_weights=weight, combiner=feature.table.combiner)) elif isinstance(inp, ragged_tensor.RaggedTensor): - outputs.append(_ragged_embedding_lookup_with_reduce( - table, inp, weight, feature.table.combiner)) + if feature.max_sequence_length > 0: + batch_size = inp.shape[0] + dense_output_shape = [ + batch_size, feature.max_sequence_length, feature.table.dim] + ragged_lookup = embedding_ops.embedding_lookup_v2(table, inp) + # Unlike scatter_nd, RaggedTensor.to_tensor truncates to the given + # shape. + outputs.append(ragged_lookup.to_tensor(shape=dense_output_shape)) + else: + outputs.append(_ragged_embedding_lookup_with_reduce( + table, inp, weight, feature.table.combiner)) else: raise ValueError("Input {} is type {}. Tensor, SparseTensor or " diff --git a/tensorflow/python/tpu/tpu_embedding_v2_cpu_test.py b/tensorflow/python/tpu/tpu_embedding_v2_cpu_test.py index fa1e843179f..cc98b0d6b09 100644 --- a/tensorflow/python/tpu/tpu_embedding_v2_cpu_test.py +++ b/tensorflow/python/tpu/tpu_embedding_v2_cpu_test.py @@ -277,7 +277,17 @@ class CPUEmbeddingTest(test.TestCase): tables=mid_level.embedding_tables, feature_config=self.feature_config) - def test_cpu_sequence_lookup(self): + def _numpy_sequence_lookup( + self, table, indices, values, batch_size, max_sequence_length, dim): + # First we gather the values + lookup = table[values] + # Then we scatter them into the result array. + scatter_result = np.zeros([batch_size, max_sequence_length, dim]) + for i, index in enumerate(indices): + scatter_result[index[0], index[1], :] = lookup[i] + return scatter_result + + def test_cpu_sequence_lookup_sparse(self): feature_config = ( tpu_embedding_v2_utils.FeatureConfig( table=self.table_video, name='watched', max_sequence_length=2),) @@ -285,14 +295,48 @@ class CPUEmbeddingTest(test.TestCase): mid_level = tpu_embedding_v2.TPUEmbedding( feature_config=feature_config, optimizer=optimizer) - features = tuple(self._get_sparse_tensors()[:1]) - with self.assertRaisesRegex( - ValueError, 'Sequence features unsupported at this time.'): - tpu_embedding_v2.cpu_embedding_lookup( - features, - weights=None, - tables=mid_level.embedding_tables, - feature_config=feature_config) + features = self._get_sparse_tensors()[:1] + result = tpu_embedding_v2.cpu_embedding_lookup( + features, + weights=None, + tables=mid_level.embedding_tables, + feature_config=feature_config) + + golden = self._numpy_sequence_lookup( + mid_level.embedding_tables[self.table_video].numpy(), + features[0].indices.numpy(), + features[0].values.numpy(), + self.data_batch_size, + feature_config[0].max_sequence_length, + self.table_video.dim) + + self.assertAllClose(result[0], golden) + + def test_cpu_sequence_lookup_ragged(self): + feature_config = ( + tpu_embedding_v2_utils.FeatureConfig( + table=self.table_video, name='watched', max_sequence_length=2),) + optimizer = tpu_embedding_v2_utils.SGD(learning_rate=0.1) + mid_level = tpu_embedding_v2.TPUEmbedding( + feature_config=feature_config, + optimizer=optimizer) + features = self._get_ragged_tensors()[:1] + result = tpu_embedding_v2.cpu_embedding_lookup( + features, + weights=None, + tables=mid_level.embedding_tables, + feature_config=feature_config) + + sparse_ver = features[0].to_sparse() + golden = self._numpy_sequence_lookup( + mid_level.embedding_tables[self.table_video].numpy(), + sparse_ver.indices.numpy(), + sparse_ver.values.numpy(), + self.data_batch_size, + feature_config[0].max_sequence_length, + self.table_video.dim) + + self.assertAllClose(result[0], golden) def test_cpu_no_optimizer(self): feature_config = (