Add correctness test for ragged embeddings with the TPUEmbedding mid level API.
PiperOrigin-RevId: 354581897 Change-Id: Ifbbd351e6879d2ac0379520730510866e4bedf1f
This commit is contained in:
parent
597741a69d
commit
ef214a5a9d
@ -40,6 +40,7 @@ from tensorflow.python.ops import array_ops
|
|||||||
from tensorflow.python.ops import gen_math_ops
|
from tensorflow.python.ops import gen_math_ops
|
||||||
from tensorflow.python.ops import init_ops_v2
|
from tensorflow.python.ops import init_ops_v2
|
||||||
from tensorflow.python.ops import math_ops
|
from tensorflow.python.ops import math_ops
|
||||||
|
from tensorflow.python.ops.ragged import ragged_tensor
|
||||||
from tensorflow.python.platform import test
|
from tensorflow.python.platform import test
|
||||||
from tensorflow.python.tpu import tpu_embedding_v2
|
from tensorflow.python.tpu import tpu_embedding_v2
|
||||||
from tensorflow.python.tpu import tpu_embedding_v2_utils
|
from tensorflow.python.tpu import tpu_embedding_v2_utils
|
||||||
@ -149,12 +150,17 @@ class TPUEmbeddingCorrectness(parameterized.TestCase, test.TestCase):
|
|||||||
@parameterized.parameters(
|
@parameterized.parameters(
|
||||||
*itertools.product(
|
*itertools.product(
|
||||||
['sgd', 'adagrad', 'adam'],
|
['sgd', 'adagrad', 'adam'],
|
||||||
|
[True, False],
|
||||||
[True, False]))
|
[True, False]))
|
||||||
def test_embedding(self, optimizer_name, training):
|
def test_embedding(self, optimizer_name, training, sparse):
|
||||||
strategy, mid_level_api, optimizer = (
|
strategy, mid_level_api, optimizer = (
|
||||||
self._create_strategy_and_mid_level(optimizer_name))
|
self._create_strategy_and_mid_level(optimizer_name))
|
||||||
|
|
||||||
|
if sparse:
|
||||||
dataset = self._create_sparse_dataset(strategy)
|
dataset = self._create_sparse_dataset(strategy)
|
||||||
|
else:
|
||||||
|
dataset = self._create_ragged_dataset(strategy)
|
||||||
|
|
||||||
dist = strategy.experimental_distribute_dataset(
|
dist = strategy.experimental_distribute_dataset(
|
||||||
dataset,
|
dataset,
|
||||||
options=distribute_lib.InputOptions(
|
options=distribute_lib.InputOptions(
|
||||||
@ -209,8 +215,7 @@ class TPUEmbeddingCorrectness(parameterized.TestCase, test.TestCase):
|
|||||||
feature_config=self.feature_config,
|
feature_config=self.feature_config,
|
||||||
optimizer=optimizer)
|
optimizer=optimizer)
|
||||||
|
|
||||||
def _create_sparse_dataset(self, strategy, include_weights=False, weight=0.5):
|
def _create_sparse_data(self, include_weights, weight=0.5):
|
||||||
# Create dataset for enqueue operation
|
|
||||||
sparse_features = (
|
sparse_features = (
|
||||||
sparse_tensor.SparseTensor(
|
sparse_tensor.SparseTensor(
|
||||||
indices=self.feature_watched_indices,
|
indices=self.feature_watched_indices,
|
||||||
@ -234,6 +239,11 @@ class TPUEmbeddingCorrectness(parameterized.TestCase, test.TestCase):
|
|||||||
values=values,
|
values=values,
|
||||||
dense_shape=sparse.dense_shape))
|
dense_shape=sparse.dense_shape))
|
||||||
sparse_features = (sparse_features, tuple(weights))
|
sparse_features = (sparse_features, tuple(weights))
|
||||||
|
return sparse_features
|
||||||
|
|
||||||
|
def _create_sparse_dataset(self, strategy, include_weights=False, weight=0.5):
|
||||||
|
# Create dataset for enqueue operation
|
||||||
|
sparse_features = self._create_sparse_data(include_weights, weight)
|
||||||
|
|
||||||
dataset = dataset_ops.DatasetV2.from_tensors(sparse_features)
|
dataset = dataset_ops.DatasetV2.from_tensors(sparse_features)
|
||||||
|
|
||||||
@ -241,6 +251,18 @@ class TPUEmbeddingCorrectness(parameterized.TestCase, test.TestCase):
|
|||||||
return dataset.unbatch().repeat().batch(
|
return dataset.unbatch().repeat().batch(
|
||||||
self.batch_size * strategy.num_replicas_in_sync, drop_remainder=True)
|
self.batch_size * strategy.num_replicas_in_sync, drop_remainder=True)
|
||||||
|
|
||||||
|
def _create_ragged_dataset(self, strategy, include_weights=False, weight=0.5):
|
||||||
|
# Create dataset for enqueue operation
|
||||||
|
sparse_features = self._create_sparse_data(include_weights, weight)
|
||||||
|
ragged_features = nest.map_structure(ragged_tensor.RaggedTensor.from_sparse,
|
||||||
|
sparse_features)
|
||||||
|
|
||||||
|
dataset = dataset_ops.DatasetV2.from_tensors(ragged_features)
|
||||||
|
|
||||||
|
# Data is batched to self.data_batch_size, rebatch to global batch size.
|
||||||
|
return dataset.unbatch().repeat().batch(
|
||||||
|
self.batch_size * strategy.num_replicas_in_sync, drop_remainder=True)
|
||||||
|
|
||||||
def _create_dense_input_fn(self, strategy, include_weights=False, weight=0.5):
|
def _create_dense_input_fn(self, strategy, include_weights=False, weight=0.5):
|
||||||
|
|
||||||
def input_fn(ctx):
|
def input_fn(ctx):
|
||||||
@ -448,7 +470,8 @@ class TPUEmbeddingCorrectness(parameterized.TestCase, test.TestCase):
|
|||||||
numpy_users[self.feature_friends_values[-2:]]))
|
numpy_users[self.feature_friends_values[-2:]]))
|
||||||
self.assertAllClose(shard0, golden)
|
self.assertAllClose(shard0, golden)
|
||||||
|
|
||||||
def test_sequence_embeddings(self):
|
@parameterized.parameters([True, False])
|
||||||
|
def test_sequence_embeddings(self, sparse):
|
||||||
feature_config = (
|
feature_config = (
|
||||||
tpu_embedding_v2_utils.FeatureConfig(
|
tpu_embedding_v2_utils.FeatureConfig(
|
||||||
table=self.table_video, name='watched',
|
table=self.table_video, name='watched',
|
||||||
@ -470,7 +493,10 @@ class TPUEmbeddingCorrectness(parameterized.TestCase, test.TestCase):
|
|||||||
# results in data where the shape of the sparse tensor is a tensor which we
|
# results in data where the shape of the sparse tensor is a tensor which we
|
||||||
# can't tell the shape of at tracing time.
|
# can't tell the shape of at tracing time.
|
||||||
mid_level.build(self.batch_size)
|
mid_level.build(self.batch_size)
|
||||||
|
if sparse:
|
||||||
dataset = self._create_sparse_dataset(strategy)
|
dataset = self._create_sparse_dataset(strategy)
|
||||||
|
else:
|
||||||
|
dataset = self._create_ragged_dataset(strategy)
|
||||||
data = next(iter(strategy.experimental_distribute_dataset(
|
data = next(iter(strategy.experimental_distribute_dataset(
|
||||||
dataset,
|
dataset,
|
||||||
options=distribute_lib.InputOptions(
|
options=distribute_lib.InputOptions(
|
||||||
|
Loading…
x
Reference in New Issue
Block a user