Add correctness test for ragged embeddings with the TPUEmbedding mid level API.

PiperOrigin-RevId: 354581897
Change-Id: Ifbbd351e6879d2ac0379520730510866e4bedf1f
This commit is contained in:
Bruce Fontaine 2021-01-29 12:03:54 -08:00 committed by TensorFlower Gardener
parent 597741a69d
commit ef214a5a9d

View File

@ -40,6 +40,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_math_ops
from tensorflow.python.ops import init_ops_v2
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.platform import test
from tensorflow.python.tpu import tpu_embedding_v2
from tensorflow.python.tpu import tpu_embedding_v2_utils
@ -149,12 +150,17 @@ class TPUEmbeddingCorrectness(parameterized.TestCase, test.TestCase):
@parameterized.parameters(
*itertools.product(
['sgd', 'adagrad', 'adam'],
[True, False],
[True, False]))
def test_embedding(self, optimizer_name, training):
def test_embedding(self, optimizer_name, training, sparse):
strategy, mid_level_api, optimizer = (
self._create_strategy_and_mid_level(optimizer_name))
dataset = self._create_sparse_dataset(strategy)
if sparse:
dataset = self._create_sparse_dataset(strategy)
else:
dataset = self._create_ragged_dataset(strategy)
dist = strategy.experimental_distribute_dataset(
dataset,
options=distribute_lib.InputOptions(
@ -209,8 +215,7 @@ class TPUEmbeddingCorrectness(parameterized.TestCase, test.TestCase):
feature_config=self.feature_config,
optimizer=optimizer)
def _create_sparse_dataset(self, strategy, include_weights=False, weight=0.5):
# Create dataset for enqueue operation
def _create_sparse_data(self, include_weights, weight=0.5):
sparse_features = (
sparse_tensor.SparseTensor(
indices=self.feature_watched_indices,
@ -234,6 +239,11 @@ class TPUEmbeddingCorrectness(parameterized.TestCase, test.TestCase):
values=values,
dense_shape=sparse.dense_shape))
sparse_features = (sparse_features, tuple(weights))
return sparse_features
def _create_sparse_dataset(self, strategy, include_weights=False, weight=0.5):
# Create dataset for enqueue operation
sparse_features = self._create_sparse_data(include_weights, weight)
dataset = dataset_ops.DatasetV2.from_tensors(sparse_features)
@ -241,6 +251,18 @@ class TPUEmbeddingCorrectness(parameterized.TestCase, test.TestCase):
return dataset.unbatch().repeat().batch(
self.batch_size * strategy.num_replicas_in_sync, drop_remainder=True)
def _create_ragged_dataset(self, strategy, include_weights=False, weight=0.5):
# Create dataset for enqueue operation
sparse_features = self._create_sparse_data(include_weights, weight)
ragged_features = nest.map_structure(ragged_tensor.RaggedTensor.from_sparse,
sparse_features)
dataset = dataset_ops.DatasetV2.from_tensors(ragged_features)
# Data is batched to self.data_batch_size, rebatch to global batch size.
return dataset.unbatch().repeat().batch(
self.batch_size * strategy.num_replicas_in_sync, drop_remainder=True)
def _create_dense_input_fn(self, strategy, include_weights=False, weight=0.5):
def input_fn(ctx):
@ -448,7 +470,8 @@ class TPUEmbeddingCorrectness(parameterized.TestCase, test.TestCase):
numpy_users[self.feature_friends_values[-2:]]))
self.assertAllClose(shard0, golden)
def test_sequence_embeddings(self):
@parameterized.parameters([True, False])
def test_sequence_embeddings(self, sparse):
feature_config = (
tpu_embedding_v2_utils.FeatureConfig(
table=self.table_video, name='watched',
@ -470,7 +493,10 @@ class TPUEmbeddingCorrectness(parameterized.TestCase, test.TestCase):
# results in data where the shape of the sparse tensor is a tensor which we
# can't tell the shape of at tracing time.
mid_level.build(self.batch_size)
dataset = self._create_sparse_dataset(strategy)
if sparse:
dataset = self._create_sparse_dataset(strategy)
else:
dataset = self._create_ragged_dataset(strategy)
data = next(iter(strategy.experimental_distribute_dataset(
dataset,
options=distribute_lib.InputOptions(