Add sensible __repr__ implementation for TableConfig, FeatureConfig.

PiperOrigin-RevId: 335060379
Change-Id: I456fae3f07fd74b08db8f069ab817ac18f458fb7
This commit is contained in:
Revan Sopher 2020-10-02 11:02:37 -07:00 committed by TensorFlower Gardener
parent 4cfe0b124f
commit 12833639d5
2 changed files with 62 additions and 0 deletions

View File

@ -21,6 +21,7 @@ from __future__ import unicode_literals
import abc
import math
import typing
from typing import Any, Dict, Callable, List, Optional, Text, Tuple, TypeVar, Union
import six
@ -620,6 +621,29 @@ class TableConfig(object):
self.combiner = combiner
self.name = name
def __repr__(self):
# If using the default initializer, just print "None" for clarity.
initializer = self.initializer
if isinstance(initializer, init_ops_v2.TruncatedNormal):
# PY2 type checking can't infer type of initializer even after if.
initializer = typing.cast(init_ops_v2.TruncatedNormal, initializer)
if (initializer.mean == 0.0
and math.isclose(initializer.stddev, 1/math.sqrt(self.dim))): # pytype: disable=module-attr (math.isclose not in PY2)
initializer = None
return (
"TableConfig(vocabulary_size={vocabulary_size!r}, dim={dim!r}, "
"initializer={initializer!r}, optimizer={optimizer!r}, "
"combiner={combiner!r}, name={name!r})".format(
vocabulary_size=self.vocabulary_size,
dim=self.dim,
initializer=initializer,
optimizer=self.optimizer,
combiner=self.combiner,
name=self.name,)
)
@tf_export("tpu.experimental.embedding.FeatureConfig")
class FeatureConfig(object):
@ -697,3 +721,13 @@ class FeatureConfig(object):
self.table = table
self.max_sequence_length = max_sequence_length
self.name = name
def __repr__(self):
return (
"FeatureConfig(table={table!r}, "
"max_sequence_length={max_sequence_length!r}, name={name!r})"
.format(
table=self.table,
max_sequence_length=self.max_sequence_length,
name=self.name)
)

View File

@ -60,6 +60,34 @@ class TPUEmbeddingOptimizerTest(parameterized.TestCase, test.TestCase):
self.assertEqual(1., opt.clip_gradient_max)
class ConfigTest(test.TestCase):
def test_table_config_repr(self):
table = tpu_embedding_v2_utils.TableConfig(
vocabulary_size=2, dim=4, initializer=None,
combiner='sum', name='table')
self.assertEqual(
repr(table),
'TableConfig(vocabulary_size=2, dim=4, initializer=None, '
'optimizer=None, combiner=\'sum\', name=\'table\')')
def test_feature_config_repr(self):
table = tpu_embedding_v2_utils.TableConfig(
vocabulary_size=2, dim=4, initializer=None,
combiner='sum', name='table')
feature_config = tpu_embedding_v2_utils.FeatureConfig(
table=table, name='feature')
self.assertEqual(
repr(feature_config),
'FeatureConfig(table=TableConfig(vocabulary_size=2, dim=4, '
'initializer=None, optimizer=None, combiner=\'sum\', name=\'table\'), '
'max_sequence_length=0, name=\'feature\')'
)
if __name__ == '__main__':
v2_compat.enable_v2_behavior()
test.main()