Update the TPUEmbedding API to ensure that the table order depends only on the order of the tables in feature_config, and not on the iteration order of the table objects in a Python set.
PiperOrigin-RevId: 357015292
Change-Id: I2e61ef99b5c8d4e1c9e692c0484f0f56e1b08c88
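The problem being fixed can be reproduced outside TensorFlow. The sketch below is a minimal illustration, not TensorFlow code: the Table class is a hypothetical stand-in for TableConfig, which hashes by object identity, so list(set(...)) yields an order that can differ between processes, while first-occurrence deduplication is deterministic.

class Table(object):  # hypothetical stand-in for TableConfig; hashes by id()
  def __init__(self, name):
    self.name = name

tables = [Table('a'), Table('b'), Table('c')]
features = tables + tables  # several features may share the same table

# Old behavior: set iteration order follows each object's hash (its id()),
# which varies between runs, so the resulting table order is nondeterministic.
nondeterministic = list({t for t in features})

# New behavior: tables are ordered by first occurrence in the user-provided
# feature config, which is identical on every run.
deterministic = []
for t in features:
  if t not in deterministic:
    deterministic.append(t)
assert [t.name for t in deterministic] == ['a', 'b', 'c']

Since the serialized embedding configuration depends on this order, the new test below compares the config protos from two separate instantiations of the same config to verify determinism.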
@@ -297,9 +297,14 @@ class TPUEmbedding(tracking.AutoTrackable):
     # Thus we must fix a common order to tables and ensure they have unique
     # names.
 
-    # Set table order here
-    self._table_config = list(
-        {feature.table for feature in nest.flatten(feature_config)})
+    # Set table order here to the order of the first occurrence of the table
+    # in a feature provided by the user. The order of this struct must be
+    # fixed to provide the user with deterministic behavior over multiple
+    # instantiations.
+    self._table_config = []
+    for feature in nest.flatten(feature_config):
+      if feature.table not in self._table_config:
+        self._table_config.append(feature.table)
 
     # Ensure tables have unique names. Also error check the optimizer as we
     # specifically don't do that in the TableConfig class to allow high level
@@ -1273,6 +1273,32 @@ class TPUEmbeddingTest(parameterized.TestCase, test.TestCase):
     # not matter.
     mid_level_api.build(self.batch_size)
 
+  def test_same_config_different_instantiations(self):
+    num_tables = 30
+    table_dim = np.random.randint(1, 128, size=[num_tables])
+    table_vocab_size = np.random.randint(100, 1000, size=[num_tables])
+    table_names = ['table{}'.format(i) for i in range(num_tables)]
+    table_data = list(zip(table_dim, table_vocab_size, table_names))
+    strategy = self._get_strategy()
+
+    def tpu_embedding_config():
+      feature_configs = []
+      for dim, vocab, name in table_data:
+        feature_configs.append(tpu_embedding_v2_utils.FeatureConfig(
+            table=tpu_embedding_v2_utils.TableConfig(
+                vocabulary_size=int(vocab), dim=int(dim),
+                initializer=init_ops_v2.Zeros(), name=name)))
+      optimizer = tpu_embedding_v2_utils.Adagrad(
+          learning_rate=0.1)
+      with strategy.scope():
+        mid_level_api = tpu_embedding_v2.TPUEmbedding(
+            feature_config=feature_configs,
+            optimizer=optimizer)
+      mid_level_api._batch_size = 128
+      return mid_level_api._create_config_proto()
+
+    self.assertProtoEquals(tpu_embedding_config(), tpu_embedding_config())
+
+
 def _unpack(strategy, per_replica_output):
   per_replica_output = strategy.experimental_local_results(per_replica_output)