Update the TPUEmbedding API so that table order depends only on the order in which tables first appear in feature_config, and not on the iteration order of a Python set of table objects.

PiperOrigin-RevId: 357015292
Change-Id: I2e61ef99b5c8d4e1c9e692c0484f0f56e1b08c88
Bruce Fontaine 2021-02-11 11:07:16 -08:00 committed by TensorFlower Gardener
parent 8c81fcdae9
commit c5fe250254
2 changed files with 34 additions and 3 deletions
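
Why the old ordering was unstable: TableConfig objects use the default object hash, which is derived from id(), so iterating a set of them follows memory layout and can change between instantiations. A minimal standalone sketch of the old and new behavior, using hypothetical Table/Feature stand-ins rather than the TF classes:

    # Hypothetical stand-ins for TableConfig/FeatureConfig, used only to
    # illustrate the ordering change; not TensorFlow code.
    class Table:
      def __init__(self, name):
        self.name = name  # default __hash__ is id()-based, so set order
                          # follows memory layout, not a stable key

    class Feature:
      def __init__(self, table):
        self.table = table

    t0, t1 = Table('t0'), Table('t1')
    features = [Feature(t0), Feature(t1), Feature(t0)]

    # Old behavior: dedupe via a set; iteration order can vary between
    # instantiations because it depends on the objects' id()-based hashes.
    unstable = list({f.table for f in features})

    # New behavior: dedupe preserving first occurrence in feature_config;
    # the order is always [t0, t1] regardless of hash values.
    stable = []
    for f in features:
      if f.table not in stable:
        stable.append(f.table)
    assert [t.name for t in stable] == ['t0', 't1']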


@@ -297,9 +297,14 @@ class TPUEmbedding(tracking.AutoTrackable):
     # Thus we must fix a common order to tables and ensure they have unique
     # names.
-    # Set table order here
-    self._table_config = list(
-        {feature.table for feature in nest.flatten(feature_config)})
+    # Set table order here to the order of the first occurrence of the table
+    # in a feature provided by the user. The order of this struct must be
+    # fixed to provide the user with deterministic behavior over multiple
+    # instantiations.
+    self._table_config = []
+    for feature in nest.flatten(feature_config):
+      if feature.table not in self._table_config:
+        self._table_config.append(feature.table)
     # Ensure tables have unique names. Also error check the optimizer as we
     # specifically don't do that in the TableConfig class to allow high level
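
An aside on the design choice (editor's note, not part of the commit): the `not in` membership test scans the list, so the dedupe is quadratic in the number of features, which is harmless for a handful of tables. Since Python 3.7, dicts preserve insertion order regardless of hash values, so an equivalent deterministic linear-time formulation would be the sketch below, where flat_features stands in for nest.flatten(feature_config):

    # Equivalent first-occurrence dedupe using insertion-ordered dicts
    # (Python 3.7+). Dict order follows insertion, not hash values, so
    # this is just as deterministic as the explicit loop above.
    table_config = list(dict.fromkeys(f.table for f in flat_features))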


@@ -1273,6 +1273,32 @@ class TPUEmbeddingTest(parameterized.TestCase, test.TestCase):
     # not matter.
     mid_level_api.build(self.batch_size)
+
+  def test_same_config_different_instantiations(self):
+    num_tables = 30
+    table_dim = np.random.randint(1, 128, size=[num_tables])
+    table_vocab_size = np.random.randint(100, 1000, size=[num_tables])
+    table_names = ['table{}'.format(i) for i in range(num_tables)]
+    table_data = list(zip(table_dim, table_vocab_size, table_names))
+    strategy = self._get_strategy()
+
+    def tpu_embedding_config():
+      feature_configs = []
+      for dim, vocab, name in table_data:
+        feature_configs.append(tpu_embedding_v2_utils.FeatureConfig(
+            table=tpu_embedding_v2_utils.TableConfig(
+                vocabulary_size=int(vocab), dim=int(dim),
+                initializer=init_ops_v2.Zeros(), name=name)))
+      optimizer = tpu_embedding_v2_utils.Adagrad(
+          learning_rate=0.1)
+      with strategy.scope():
+        mid_level_api = tpu_embedding_v2.TPUEmbedding(
+            feature_config=feature_configs,
+            optimizer=optimizer)
+      mid_level_api._batch_size = 128
+      return mid_level_api._create_config_proto()
+
+    self.assertProtoEquals(tpu_embedding_config(), tpu_embedding_config())
 
 
 def _unpack(strategy, per_replica_output):
   per_replica_output = strategy.experimental_local_results(per_replica_output)
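
Why the test calls tpu_embedding_config() twice (editor's note): each call rebuilds the TableConfig objects from scratch, so fresh objects get fresh ids, and under the old set-based ordering the two config protos could list the 30 tables in different orders; with first-occurrence ordering they always match. A rough illustration of that failure mode, using a hypothetical stand-in class rather than TF code:

    # Hypothetical illustration: fresh objects get fresh id()-based
    # hashes, so set iteration order offers no cross-instantiation
    # guarantee.
    class T:
      def __init__(self, name):
        self.name = name

    def table_names_old_style():
      tables = [T('table{}'.format(i)) for i in range(30)]
      return [t.name for t in {*tables}]  # order tied to id() hashes

    # The two name sequences are not guaranteed to be equal; the fixed
    # first-occurrence ordering always yields table0..table29.
    print(table_names_old_style() == table_names_old_style())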