Add correctness test for ragged embeddings with the TPUEmbedding mid level API.
PiperOrigin-RevId: 354581897 Change-Id: Ifbbd351e6879d2ac0379520730510866e4bedf1f
This commit is contained in:
parent
597741a69d
commit
ef214a5a9d
@ -40,6 +40,7 @@ from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import gen_math_ops
|
||||
from tensorflow.python.ops import init_ops_v2
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops.ragged import ragged_tensor
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.tpu import tpu_embedding_v2
|
||||
from tensorflow.python.tpu import tpu_embedding_v2_utils
|
||||
@ -149,12 +150,17 @@ class TPUEmbeddingCorrectness(parameterized.TestCase, test.TestCase):
|
||||
@parameterized.parameters(
|
||||
*itertools.product(
|
||||
['sgd', 'adagrad', 'adam'],
|
||||
[True, False],
|
||||
[True, False]))
|
||||
def test_embedding(self, optimizer_name, training):
|
||||
def test_embedding(self, optimizer_name, training, sparse):
|
||||
strategy, mid_level_api, optimizer = (
|
||||
self._create_strategy_and_mid_level(optimizer_name))
|
||||
|
||||
dataset = self._create_sparse_dataset(strategy)
|
||||
if sparse:
|
||||
dataset = self._create_sparse_dataset(strategy)
|
||||
else:
|
||||
dataset = self._create_ragged_dataset(strategy)
|
||||
|
||||
dist = strategy.experimental_distribute_dataset(
|
||||
dataset,
|
||||
options=distribute_lib.InputOptions(
|
||||
@ -209,8 +215,7 @@ class TPUEmbeddingCorrectness(parameterized.TestCase, test.TestCase):
|
||||
feature_config=self.feature_config,
|
||||
optimizer=optimizer)
|
||||
|
||||
def _create_sparse_dataset(self, strategy, include_weights=False, weight=0.5):
|
||||
# Create dataset for enqueue operation
|
||||
def _create_sparse_data(self, include_weights, weight=0.5):
|
||||
sparse_features = (
|
||||
sparse_tensor.SparseTensor(
|
||||
indices=self.feature_watched_indices,
|
||||
@ -234,6 +239,11 @@ class TPUEmbeddingCorrectness(parameterized.TestCase, test.TestCase):
|
||||
values=values,
|
||||
dense_shape=sparse.dense_shape))
|
||||
sparse_features = (sparse_features, tuple(weights))
|
||||
return sparse_features
|
||||
|
||||
def _create_sparse_dataset(self, strategy, include_weights=False, weight=0.5):
|
||||
# Create dataset for enqueue operation
|
||||
sparse_features = self._create_sparse_data(include_weights, weight)
|
||||
|
||||
dataset = dataset_ops.DatasetV2.from_tensors(sparse_features)
|
||||
|
||||
@ -241,6 +251,18 @@ class TPUEmbeddingCorrectness(parameterized.TestCase, test.TestCase):
|
||||
return dataset.unbatch().repeat().batch(
|
||||
self.batch_size * strategy.num_replicas_in_sync, drop_remainder=True)
|
||||
|
||||
def _create_ragged_dataset(self, strategy, include_weights=False, weight=0.5):
|
||||
# Create dataset for enqueue operation
|
||||
sparse_features = self._create_sparse_data(include_weights, weight)
|
||||
ragged_features = nest.map_structure(ragged_tensor.RaggedTensor.from_sparse,
|
||||
sparse_features)
|
||||
|
||||
dataset = dataset_ops.DatasetV2.from_tensors(ragged_features)
|
||||
|
||||
# Data is batched to self.data_batch_size, rebatch to global batch size.
|
||||
return dataset.unbatch().repeat().batch(
|
||||
self.batch_size * strategy.num_replicas_in_sync, drop_remainder=True)
|
||||
|
||||
def _create_dense_input_fn(self, strategy, include_weights=False, weight=0.5):
|
||||
|
||||
def input_fn(ctx):
|
||||
@ -448,7 +470,8 @@ class TPUEmbeddingCorrectness(parameterized.TestCase, test.TestCase):
|
||||
numpy_users[self.feature_friends_values[-2:]]))
|
||||
self.assertAllClose(shard0, golden)
|
||||
|
||||
def test_sequence_embeddings(self):
|
||||
@parameterized.parameters([True, False])
|
||||
def test_sequence_embeddings(self, sparse):
|
||||
feature_config = (
|
||||
tpu_embedding_v2_utils.FeatureConfig(
|
||||
table=self.table_video, name='watched',
|
||||
@ -470,7 +493,10 @@ class TPUEmbeddingCorrectness(parameterized.TestCase, test.TestCase):
|
||||
# results in data where the shape of the sparse tensor is a tensor which we
|
||||
# can't tell the shape of at tracing time.
|
||||
mid_level.build(self.batch_size)
|
||||
dataset = self._create_sparse_dataset(strategy)
|
||||
if sparse:
|
||||
dataset = self._create_sparse_dataset(strategy)
|
||||
else:
|
||||
dataset = self._create_ragged_dataset(strategy)
|
||||
data = next(iter(strategy.experimental_distribute_dataset(
|
||||
dataset,
|
||||
options=distribute_lib.InputOptions(
|
||||
|
Loading…
x
Reference in New Issue
Block a user