Add sensible __repr__ implementation for TableConfig, FeatureConfig.
PiperOrigin-RevId: 335060379 Change-Id: I456fae3f07fd74b08db8f069ab817ac18f458fb7
This commit is contained in:
parent
4cfe0b124f
commit
12833639d5
tensorflow/python/tpu
@ -21,6 +21,7 @@ from __future__ import unicode_literals
|
||||
|
||||
import abc
|
||||
import math
|
||||
import typing
|
||||
from typing import Any, Dict, Callable, List, Optional, Text, Tuple, TypeVar, Union
|
||||
import six
|
||||
|
||||
@ -620,6 +621,29 @@ class TableConfig(object):
|
||||
self.combiner = combiner
|
||||
self.name = name
|
||||
|
||||
def __repr__(self):
|
||||
# If using the default initializer, just print "None" for clarity.
|
||||
initializer = self.initializer
|
||||
|
||||
if isinstance(initializer, init_ops_v2.TruncatedNormal):
|
||||
# PY2 type checking can't infer type of initializer even after if.
|
||||
initializer = typing.cast(init_ops_v2.TruncatedNormal, initializer)
|
||||
if (initializer.mean == 0.0
|
||||
and math.isclose(initializer.stddev, 1/math.sqrt(self.dim))): # pytype: disable=module-attr (math.isclose not in PY2)
|
||||
initializer = None
|
||||
|
||||
return (
|
||||
"TableConfig(vocabulary_size={vocabulary_size!r}, dim={dim!r}, "
|
||||
"initializer={initializer!r}, optimizer={optimizer!r}, "
|
||||
"combiner={combiner!r}, name={name!r})".format(
|
||||
vocabulary_size=self.vocabulary_size,
|
||||
dim=self.dim,
|
||||
initializer=initializer,
|
||||
optimizer=self.optimizer,
|
||||
combiner=self.combiner,
|
||||
name=self.name,)
|
||||
)
|
||||
|
||||
|
||||
@tf_export("tpu.experimental.embedding.FeatureConfig")
|
||||
class FeatureConfig(object):
|
||||
@ -697,3 +721,13 @@ class FeatureConfig(object):
|
||||
self.table = table
|
||||
self.max_sequence_length = max_sequence_length
|
||||
self.name = name
|
||||
|
||||
def __repr__(self):
|
||||
return (
|
||||
"FeatureConfig(table={table!r}, "
|
||||
"max_sequence_length={max_sequence_length!r}, name={name!r})"
|
||||
.format(
|
||||
table=self.table,
|
||||
max_sequence_length=self.max_sequence_length,
|
||||
name=self.name)
|
||||
)
|
||||
|
@ -60,6 +60,34 @@ class TPUEmbeddingOptimizerTest(parameterized.TestCase, test.TestCase):
|
||||
self.assertEqual(1., opt.clip_gradient_max)
|
||||
|
||||
|
||||
class ConfigTest(test.TestCase):
|
||||
|
||||
def test_table_config_repr(self):
|
||||
table = tpu_embedding_v2_utils.TableConfig(
|
||||
vocabulary_size=2, dim=4, initializer=None,
|
||||
combiner='sum', name='table')
|
||||
|
||||
self.assertEqual(
|
||||
repr(table),
|
||||
'TableConfig(vocabulary_size=2, dim=4, initializer=None, '
|
||||
'optimizer=None, combiner=\'sum\', name=\'table\')')
|
||||
|
||||
def test_feature_config_repr(self):
|
||||
table = tpu_embedding_v2_utils.TableConfig(
|
||||
vocabulary_size=2, dim=4, initializer=None,
|
||||
combiner='sum', name='table')
|
||||
|
||||
feature_config = tpu_embedding_v2_utils.FeatureConfig(
|
||||
table=table, name='feature')
|
||||
|
||||
self.assertEqual(
|
||||
repr(feature_config),
|
||||
'FeatureConfig(table=TableConfig(vocabulary_size=2, dim=4, '
|
||||
'initializer=None, optimizer=None, combiner=\'sum\', name=\'table\'), '
|
||||
'max_sequence_length=0, name=\'feature\')'
|
||||
)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
v2_compat.enable_v2_behavior()
|
||||
test.main()
|
||||
|
Loading…
Reference in New Issue
Block a user