diff --git a/tensorflow/python/tpu/tpu_embedding_v2_utils.py b/tensorflow/python/tpu/tpu_embedding_v2_utils.py index 8487581346b..e04f1f0281a 100644 --- a/tensorflow/python/tpu/tpu_embedding_v2_utils.py +++ b/tensorflow/python/tpu/tpu_embedding_v2_utils.py @@ -21,6 +21,7 @@ from __future__ import unicode_literals import abc import math +import typing from typing import Any, Dict, Callable, List, Optional, Text, Tuple, TypeVar, Union import six @@ -620,6 +621,29 @@ class TableConfig(object): self.combiner = combiner self.name = name + def __repr__(self): + # If using the default initializer, just print "None" for clarity. + initializer = self.initializer + + if isinstance(initializer, init_ops_v2.TruncatedNormal): + # PY2 type checking can't infer type of initializer even after if. + initializer = typing.cast(init_ops_v2.TruncatedNormal, initializer) + if (initializer.mean == 0.0 + and math.isclose(initializer.stddev, 1/math.sqrt(self.dim))): # pytype: disable=module-attr (math.isclose not in PY2) + initializer = None + + return ( + "TableConfig(vocabulary_size={vocabulary_size!r}, dim={dim!r}, " + "initializer={initializer!r}, optimizer={optimizer!r}, " + "combiner={combiner!r}, name={name!r})".format( + vocabulary_size=self.vocabulary_size, + dim=self.dim, + initializer=initializer, + optimizer=self.optimizer, + combiner=self.combiner, + name=self.name,) + ) + @tf_export("tpu.experimental.embedding.FeatureConfig") class FeatureConfig(object): @@ -697,3 +721,13 @@ class FeatureConfig(object): self.table = table self.max_sequence_length = max_sequence_length self.name = name + + def __repr__(self): + return ( + "FeatureConfig(table={table!r}, " + "max_sequence_length={max_sequence_length!r}, name={name!r})" + .format( + table=self.table, + max_sequence_length=self.max_sequence_length, + name=self.name) + ) diff --git a/tensorflow/python/tpu/tpu_embedding_v2_utils_test.py b/tensorflow/python/tpu/tpu_embedding_v2_utils_test.py index 14dfb32e075..48797b00009 100644 --- a/tensorflow/python/tpu/tpu_embedding_v2_utils_test.py +++ b/tensorflow/python/tpu/tpu_embedding_v2_utils_test.py @@ -60,6 +60,34 @@ class TPUEmbeddingOptimizerTest(parameterized.TestCase, test.TestCase): self.assertEqual(1., opt.clip_gradient_max) +class ConfigTest(test.TestCase): + + def test_table_config_repr(self): + table = tpu_embedding_v2_utils.TableConfig( + vocabulary_size=2, dim=4, initializer=None, + combiner='sum', name='table') + + self.assertEqual( + repr(table), + 'TableConfig(vocabulary_size=2, dim=4, initializer=None, ' + 'optimizer=None, combiner=\'sum\', name=\'table\')') + + def test_feature_config_repr(self): + table = tpu_embedding_v2_utils.TableConfig( + vocabulary_size=2, dim=4, initializer=None, + combiner='sum', name='table') + + feature_config = tpu_embedding_v2_utils.FeatureConfig( + table=table, name='feature') + + self.assertEqual( + repr(feature_config), + 'FeatureConfig(table=TableConfig(vocabulary_size=2, dim=4, ' + 'initializer=None, optimizer=None, combiner=\'sum\', name=\'table\'), ' + 'max_sequence_length=0, name=\'feature\')' + ) + + if __name__ == '__main__': v2_compat.enable_v2_behavior() test.main()