Make get_config and from_config public for feature columns

PiperOrigin-RevId: 272227486
This commit is contained in:
Zhenyu Tan 2019-10-01 09:27:03 -07:00 committed by TensorFlower Gardener
parent 9cd59e4549
commit 97757f34d2
5 changed files with 83 additions and 86 deletions

View File

@ -2265,8 +2265,7 @@ class FeatureColumn(object):
"""
pass
@abc.abstractmethod
def _get_config(self):
def get_config(self):
"""Returns the config of the feature column.
A FeatureColumn config is a Python dictionary (serializable) containing the
@ -2283,7 +2282,7 @@ class FeatureColumn(object):
'SerializationExampleFeatureColumn',
('dimension', 'parent', 'dtype', 'normalizer_fn'))):
def _get_config(self):
def get_config(self):
# Create a dict from the namedtuple.
# Python attribute literals can be directly copied from / to the config.
# For example 'dimension', assuming it is an integer literal.
@ -2304,8 +2303,8 @@ class FeatureColumn(object):
return config
@classmethod
def _from_config(cls, config, custom_objects=None, columns_by_name=None):
# This should do the inverse transform from `_get_config` and construct
def from_config(cls, config, custom_objects=None, columns_by_name=None):
# This should do the inverse transform from `get_config` and construct
# the namedtuple.
kwargs = config.copy()
kwargs['parent'] = deserialize_feature_column(
@ -2320,21 +2319,24 @@ class FeatureColumn(object):
A serializable Dict that can be used to deserialize the object with
from_config.
"""
pass
return self._get_config()
def _get_config(self):
raise NotImplementedError('Must be implemented in subclasses.')
@classmethod
def _from_config(cls, config, custom_objects=None, columns_by_name=None):
def from_config(cls, config, custom_objects=None, columns_by_name=None):
"""Creates a FeatureColumn from its config.
This method should be the reverse of `_get_config`, capable of instantiating
the same FeatureColumn from the config dictionary. See `_get_config` for an
This method should be the reverse of `get_config`, capable of instantiating
the same FeatureColumn from the config dictionary. See `get_config` for an
example of common (de)serialization practices followed in this file.
TODO(b/118939620): Revisit this API once consensus is reached on
supporting object deserialization deduping within Keras.
Args:
config: A Dict config acquired with `_get_config`.
config: A Dict config acquired with `get_config`.
custom_objects: Optional dictionary mapping names (strings) to custom
classes or functions to be considered during deserialization.
columns_by_name: A Dict[String, FeatureColumn] of existing columns in
@ -2344,7 +2346,11 @@ class FeatureColumn(object):
Returns:
A FeatureColumn for the input config.
"""
pass
return cls._from_config(config, custom_objects, columns_by_name)
@classmethod
def _from_config(cls, config, custom_objects=None, columns_by_name=None):
raise NotImplementedError('Must be implemented in subclasses.')
class DenseColumn(FeatureColumn):
@ -2857,7 +2863,7 @@ class NumericColumn(
"""See 'FeatureColumn` base class."""
return [self.key]
def _get_config(self):
def get_config(self):
"""See 'FeatureColumn` base class."""
config = dict(zip(self._fields, self))
config['normalizer_fn'] = generic_utils.serialize_keras_object(
@ -2866,7 +2872,7 @@ class NumericColumn(
return config
@classmethod
def _from_config(cls, config, custom_objects=None, columns_by_name=None):
def from_config(cls, config, custom_objects=None, columns_by_name=None):
"""See 'FeatureColumn` base class."""
_check_config_keys(config, cls._fields)
kwargs = _standardize_and_copy_config(config)
@ -3014,7 +3020,7 @@ class BucketizedColumn(
"""See 'FeatureColumn` base class."""
return [self.source_column]
def _get_config(self):
def get_config(self):
"""See 'FeatureColumn` base class."""
from tensorflow.python.feature_column.serialization import serialize_feature_column # pylint: disable=g-import-not-at-top
config = dict(zip(self._fields, self))
@ -3022,7 +3028,7 @@ class BucketizedColumn(
return config
@classmethod
def _from_config(cls, config, custom_objects=None, columns_by_name=None):
def from_config(cls, config, custom_objects=None, columns_by_name=None):
"""See 'FeatureColumn` base class."""
from tensorflow.python.feature_column.serialization import deserialize_feature_column # pylint: disable=g-import-not-at-top
_check_config_keys(config, cls._fields)
@ -3247,7 +3253,7 @@ class EmbeddingColumn(
"""See 'FeatureColumn` base class."""
return [self.categorical_column]
def _get_config(self):
def get_config(self):
"""See 'FeatureColumn` base class."""
from tensorflow.python.feature_column.serialization import serialize_feature_column # pylint: disable=g-import-not-at-top
config = dict(zip(self._fields, self))
@ -3257,7 +3263,7 @@ class EmbeddingColumn(
return config
@classmethod
def _from_config(cls, config, custom_objects=None, columns_by_name=None):
def from_config(cls, config, custom_objects=None, columns_by_name=None):
"""See 'FeatureColumn` base class."""
from tensorflow.python.feature_column.serialization import deserialize_feature_column # pylint: disable=g-import-not-at-top
_check_config_keys(config, cls._fields)
@ -3440,15 +3446,6 @@ class SharedEmbeddingColumn(
"""See 'FeatureColumn` base class."""
return [self.categorical_column]
def _get_config(self):
"""See 'FeatureColumn` base class."""
raise NotImplementedError()
@classmethod
def _from_config(cls, config, custom_objects=None, columns_by_name=None):
"""See 'FeatureColumn` base class."""
raise NotImplementedError()
def _check_shape(shape, key):
"""Returns shape if it's valid, raises error otherwise."""
@ -3559,14 +3556,14 @@ class HashedCategoricalColumn(
"""See 'FeatureColumn` base class."""
return [self.key]
def _get_config(self):
def get_config(self):
"""See 'FeatureColumn` base class."""
config = dict(zip(self._fields, self))
config['dtype'] = self.dtype.name
return config
@classmethod
def _from_config(cls, config, custom_objects=None, columns_by_name=None):
def from_config(cls, config, custom_objects=None, columns_by_name=None):
"""See 'FeatureColumn` base class."""
_check_config_keys(config, cls._fields)
kwargs = _standardize_and_copy_config(config)
@ -3673,14 +3670,14 @@ class VocabularyFileCategoricalColumn(
"""See 'FeatureColumn` base class."""
return [self.key]
def _get_config(self):
def get_config(self):
"""See 'FeatureColumn` base class."""
config = dict(zip(self._fields, self))
config['dtype'] = self.dtype.name
return config
@classmethod
def _from_config(cls, config, custom_objects=None, columns_by_name=None):
def from_config(cls, config, custom_objects=None, columns_by_name=None):
"""See 'FeatureColumn` base class."""
_check_config_keys(config, cls._fields)
kwargs = _standardize_and_copy_config(config)
@ -3787,14 +3784,14 @@ class VocabularyListCategoricalColumn(
"""See 'FeatureColumn` base class."""
return [self.key]
def _get_config(self):
def get_config(self):
"""See 'FeatureColumn` base class."""
config = dict(zip(self._fields, self))
config['dtype'] = self.dtype.name
return config
@classmethod
def _from_config(cls, config, custom_objects=None, columns_by_name=None):
def from_config(cls, config, custom_objects=None, columns_by_name=None):
"""See 'FeatureColumn` base class."""
_check_config_keys(config, cls._fields)
kwargs = _standardize_and_copy_config(config)
@ -3899,12 +3896,12 @@ class IdentityCategoricalColumn(
"""See 'FeatureColumn` base class."""
return [self.key]
def _get_config(self):
def get_config(self):
"""See 'FeatureColumn` base class."""
return dict(zip(self._fields, self))
@classmethod
def _from_config(cls, config, custom_objects=None, columns_by_name=None):
def from_config(cls, config, custom_objects=None, columns_by_name=None):
"""See 'FeatureColumn` base class."""
_check_config_keys(config, cls._fields)
kwargs = _standardize_and_copy_config(config)
@ -4013,7 +4010,7 @@ class WeightedCategoricalColumn(
"""See 'FeatureColumn` base class."""
return [self.categorical_column, self.weight_feature_key]
def _get_config(self):
def get_config(self):
"""See 'FeatureColumn` base class."""
from tensorflow.python.feature_column.serialization import serialize_feature_column # pylint: disable=g-import-not-at-top
config = dict(zip(self._fields, self))
@ -4023,7 +4020,7 @@ class WeightedCategoricalColumn(
return config
@classmethod
def _from_config(cls, config, custom_objects=None, columns_by_name=None):
def from_config(cls, config, custom_objects=None, columns_by_name=None):
"""See 'FeatureColumn` base class."""
from tensorflow.python.feature_column.serialization import deserialize_feature_column # pylint: disable=g-import-not-at-top
_check_config_keys(config, cls._fields)
@ -4157,7 +4154,7 @@ class CrossedColumn(
"""See 'FeatureColumn` base class."""
return list(self.keys)
def _get_config(self):
def get_config(self):
"""See 'FeatureColumn` base class."""
from tensorflow.python.feature_column.serialization import serialize_feature_column # pylint: disable=g-import-not-at-top
config = dict(zip(self._fields, self))
@ -4165,7 +4162,7 @@ class CrossedColumn(
return config
@classmethod
def _from_config(cls, config, custom_objects=None, columns_by_name=None):
def from_config(cls, config, custom_objects=None, columns_by_name=None):
"""See 'FeatureColumn` base class."""
from tensorflow.python.feature_column.serialization import deserialize_feature_column # pylint: disable=g-import-not-at-top
_check_config_keys(config, cls._fields)
@ -4427,7 +4424,7 @@ class IndicatorColumn(
"""See 'FeatureColumn` base class."""
return [self.categorical_column]
def _get_config(self):
def get_config(self):
"""See 'FeatureColumn` base class."""
from tensorflow.python.feature_column.serialization import serialize_feature_column # pylint: disable=g-import-not-at-top
config = dict(zip(self._fields, self))
@ -4436,7 +4433,7 @@ class IndicatorColumn(
return config
@classmethod
def _from_config(cls, config, custom_objects=None, columns_by_name=None):
def from_config(cls, config, custom_objects=None, columns_by_name=None):
"""See 'FeatureColumn` base class."""
from tensorflow.python.feature_column.serialization import deserialize_feature_column # pylint: disable=g-import-not-at-top
_check_config_keys(config, cls._fields)
@ -4573,7 +4570,7 @@ class SequenceCategoricalColumn(
"""See 'FeatureColumn` base class."""
return [self.categorical_column]
def _get_config(self):
def get_config(self):
"""See 'FeatureColumn` base class."""
from tensorflow.python.feature_column.serialization import serialize_feature_column # pylint: disable=g-import-not-at-top
config = dict(zip(self._fields, self))
@ -4582,7 +4579,7 @@ class SequenceCategoricalColumn(
return config
@classmethod
def _from_config(cls, config, custom_objects=None, columns_by_name=None):
def from_config(cls, config, custom_objects=None, columns_by_name=None):
"""See 'FeatureColumn` base class."""
from tensorflow.python.feature_column.serialization import deserialize_feature_column # pylint: disable=g-import-not-at-top
_check_config_keys(config, cls._fields)

View File

@ -81,10 +81,10 @@ class BaseFeatureColumnForTests(fc.FeatureColumn):
raise ValueError('Should not use this method.')
@classmethod
def _from_config(cls, config, custom_objects=None, columns_by_name=None):
def from_config(cls, config, custom_objects=None, columns_by_name=None):
raise ValueError('Should not use this method.')
def _get_config(self):
def get_config(self):
raise ValueError('Should not use this method.')
@ -478,7 +478,7 @@ class NumericColumnTest(test.TestCase):
price = fc.numeric_column('price', normalizer_fn=_increment_two)
self.assertEqual(['price'], price.parents)
config = price._get_config()
config = price.get_config()
self.assertEqual({
'key': 'price',
'shape': (1,),
@ -487,7 +487,7 @@ class NumericColumnTest(test.TestCase):
'normalizer_fn': '_increment_two'
}, config)
new_col = fc.NumericColumn._from_config(
new_col = fc.NumericColumn.from_config(
config, custom_objects={'_increment_two': _increment_two})
self.assertEqual(price, new_col)
self.assertEqual(new_col.shape, (1,))
@ -833,7 +833,7 @@ class BucketizedColumnTest(test.TestCase):
bucketized_price = fc.bucketized_column(price, boundaries=[0, 2, 4, 6])
self.assertEqual([price], bucketized_price.parents)
config = bucketized_price._get_config()
config = bucketized_price.get_config()
self.assertEqual({
'source_column': {
'class_name': 'NumericColumn',
@ -848,11 +848,11 @@ class BucketizedColumnTest(test.TestCase):
'boundaries': (0, 2, 4, 6)
}, config)
new_bucketized_price = fc.BucketizedColumn._from_config(config)
new_bucketized_price = fc.BucketizedColumn.from_config(config)
self.assertEqual(bucketized_price, new_bucketized_price)
self.assertIsNot(price, new_bucketized_price.source_column)
new_bucketized_price = fc.BucketizedColumn._from_config(
new_bucketized_price = fc.BucketizedColumn.from_config(
config,
columns_by_name={
serialization._column_name_with_class_name(price): price
@ -1106,7 +1106,7 @@ class HashedCategoricalColumnTest(test.TestCase):
wire_column = fc.categorical_column_with_hash_bucket('wire', 4)
self.assertEqual(['wire'], wire_column.parents)
config = wire_column._get_config()
config = wire_column.get_config()
self.assertEqual({
'key': 'wire',
'hash_bucket_size': 4,
@ -1114,7 +1114,7 @@ class HashedCategoricalColumnTest(test.TestCase):
}, config)
self.assertEqual(wire_column,
fc.HashedCategoricalColumn._from_config(config))
fc.HashedCategoricalColumn.from_config(config))
class CrossedColumnTest(test.TestCase):
@ -1588,7 +1588,7 @@ class CrossedColumnTest(test.TestCase):
self.assertEqual([b, 'c'], crossed.parents)
config = crossed._get_config()
config = crossed.get_config()
self.assertEqual({
'hash_bucket_size':
5,
@ -1612,11 +1612,11 @@ class CrossedColumnTest(test.TestCase):
}, 'c')
}, config)
new_crossed = fc.CrossedColumn._from_config(config)
new_crossed = fc.CrossedColumn.from_config(config)
self.assertEqual(crossed, new_crossed)
self.assertIsNot(b, new_crossed.keys[0])
new_crossed = fc.CrossedColumn._from_config(
new_crossed = fc.CrossedColumn.from_config(
config,
columns_by_name={serialization._column_name_with_class_name(b): b})
self.assertEqual(crossed, new_crossed)
@ -4396,7 +4396,7 @@ class VocabularyFileCategoricalColumnTest(test.TestCase):
self.assertEqual(['wire'], wire_column.parents)
config = wire_column._get_config()
config = wire_column.get_config()
self.assertEqual({
'default_value': -1,
'dtype': 'string',
@ -4407,7 +4407,7 @@ class VocabularyFileCategoricalColumnTest(test.TestCase):
}, config)
self.assertEqual(wire_column,
fc.VocabularyFileCategoricalColumn._from_config(config))
fc.VocabularyFileCategoricalColumn.from_config(config))
class VocabularyListCategoricalColumnTest(test.TestCase):
@ -4859,7 +4859,7 @@ class VocabularyListCategoricalColumnTest(test.TestCase):
self.assertEqual(['aaa'], wire_column.parents)
config = wire_column._get_config()
config = wire_column.get_config()
self.assertEqual({
'default_value': -1,
'dtype': 'string',
@ -4869,7 +4869,7 @@ class VocabularyListCategoricalColumnTest(test.TestCase):
}, config)
self.assertEqual(wire_column,
fc.VocabularyListCategoricalColumn._from_config(config))
fc.VocabularyListCategoricalColumn.from_config(config))
class IdentityCategoricalColumnTest(test.TestCase):
@ -5218,14 +5218,14 @@ class IdentityCategoricalColumnTest(test.TestCase):
self.assertEqual(['aaa'], column.parents)
config = column._get_config()
config = column.get_config()
self.assertEqual({
'default_value': None,
'key': 'aaa',
'number_buckets': 3
}, config)
self.assertEqual(column, fc.IdentityCategoricalColumn._from_config(config))
self.assertEqual(column, fc.IdentityCategoricalColumn.from_config(config))
class TransformFeaturesTest(test.TestCase):
@ -5600,7 +5600,7 @@ class IndicatorColumnTest(test.TestCase):
self.assertEqual([parent], animal.parents)
config = animal._get_config()
config = animal.get_config()
self.assertEqual({
'categorical_column': {
'class_name': 'IdentityCategoricalColumn',
@ -5612,11 +5612,11 @@ class IndicatorColumnTest(test.TestCase):
}
}, config)
new_animal = fc.IndicatorColumn._from_config(config)
new_animal = fc.IndicatorColumn.from_config(config)
self.assertEqual(animal, new_animal)
self.assertIsNot(parent, new_animal.categorical_column)
new_animal = fc.IndicatorColumn._from_config(
new_animal = fc.IndicatorColumn.from_config(
config,
columns_by_name={
serialization._column_name_with_class_name(parent): parent
@ -6605,7 +6605,7 @@ class EmbeddingColumnTest(test.TestCase):
self.assertEqual([categorical_column], embedding_column.parents)
config = embedding_column._get_config()
config = embedding_column.get_config()
self.assertEqual({
'categorical_column': {
'class_name': 'IdentityCategoricalColumn',
@ -6633,22 +6633,22 @@ class EmbeddingColumnTest(test.TestCase):
}, config)
custom_objects = {'TruncatedNormal': init_ops.TruncatedNormal}
new_embedding_column = fc.EmbeddingColumn._from_config(
new_embedding_column = fc.EmbeddingColumn.from_config(
config, custom_objects=custom_objects)
self.assertEqual(embedding_column._get_config(),
new_embedding_column._get_config())
self.assertEqual(embedding_column.get_config(),
new_embedding_column.get_config())
self.assertIsNot(categorical_column,
new_embedding_column.categorical_column)
new_embedding_column = fc.EmbeddingColumn._from_config(
new_embedding_column = fc.EmbeddingColumn.from_config(
config,
custom_objects=custom_objects,
columns_by_name={
serialization._column_name_with_class_name(categorical_column):
categorical_column
})
self.assertEqual(embedding_column._get_config(),
new_embedding_column._get_config())
self.assertEqual(embedding_column.get_config(),
new_embedding_column.get_config())
self.assertIs(categorical_column, new_embedding_column.categorical_column)
@test_util.run_deprecated_v1
@ -6666,7 +6666,7 @@ class EmbeddingColumnTest(test.TestCase):
self.assertEqual([categorical_column], embedding_column.parents)
config = embedding_column._get_config()
config = embedding_column.get_config()
self.assertEqual({
'categorical_column': {
'class_name': 'IdentityCategoricalColumn',
@ -6689,13 +6689,13 @@ class EmbeddingColumnTest(test.TestCase):
'_initializer': _initializer,
}
new_embedding_column = fc.EmbeddingColumn._from_config(
new_embedding_column = fc.EmbeddingColumn.from_config(
config, custom_objects=custom_objects)
self.assertEqual(embedding_column, new_embedding_column)
self.assertIsNot(categorical_column,
new_embedding_column.categorical_column)
new_embedding_column = fc.EmbeddingColumn._from_config(
new_embedding_column = fc.EmbeddingColumn.from_config(
config,
custom_objects=custom_objects,
columns_by_name={
@ -7763,7 +7763,7 @@ class WeightedCategoricalColumnTest(test.TestCase):
self.assertEqual([categorical_column, 'weight'], column.parents)
config = column._get_config()
config = column.get_config()
self.assertEqual({
'categorical_column': {
'config': {
@ -7777,9 +7777,9 @@ class WeightedCategoricalColumnTest(test.TestCase):
'weight_feature_key': 'weight'
}, config)
self.assertEqual(column, fc.WeightedCategoricalColumn._from_config(config))
self.assertEqual(column, fc.WeightedCategoricalColumn.from_config(config))
new_column = fc.WeightedCategoricalColumn._from_config(
new_column = fc.WeightedCategoricalColumn.from_config(
config,
columns_by_name={
serialization._column_name_with_class_name(categorical_column):

View File

@ -582,7 +582,7 @@ class SequenceNumericColumn(
"""See 'FeatureColumn` base class."""
return [self.key]
def _get_config(self):
def get_config(self):
"""See 'FeatureColumn` base class."""
config = dict(zip(self._fields, self))
config['normalizer_fn'] = utils.serialize_keras_object(self.normalizer_fn)
@ -590,7 +590,7 @@ class SequenceNumericColumn(
return config
@classmethod
def _from_config(cls, config, custom_objects=None, columns_by_name=None):
def from_config(cls, config, custom_objects=None, columns_by_name=None):
"""See 'FeatureColumn` base class."""
fc._check_config_keys(config, cls._fields)
kwargs = fc._standardize_and_copy_config(config)

View File

@ -765,7 +765,7 @@ class SequenceCategoricalColumnWithIdentityTest(
'animal', num_buckets=4)
animal = fc.indicator_column(parent)
config = animal._get_config()
config = animal.get_config()
self.assertEqual(
{
'categorical_column': {
@ -783,11 +783,11 @@ class SequenceCategoricalColumnWithIdentityTest(
}
}, config)
new_animal = fc.IndicatorColumn._from_config(config)
new_animal = fc.IndicatorColumn.from_config(config)
self.assertEqual(animal, new_animal)
self.assertIsNot(parent, new_animal.categorical_column)
new_animal = fc.IndicatorColumn._from_config(
new_animal = fc.IndicatorColumn.from_config(
config,
columns_by_name={
serialization._column_name_with_class_name(parent): parent

View File

@ -45,14 +45,14 @@ def serialize_feature_column(fc):
"""Serializes a FeatureColumn or a raw string key.
This method should only be used to serialize parent FeatureColumns when
implementing FeatureColumn._get_config(), else serialize_feature_columns()
implementing FeatureColumn.get_config(), else serialize_feature_columns()
is preferable.
This serialization also keeps information of the FeatureColumn class, so
deserialization is possible without knowing the class type. For example:
a = numeric_column('x')
a._get_config() gives:
a.get_config() gives:
{
'key': 'price',
'shape': (1,),
@ -85,7 +85,7 @@ def serialize_feature_column(fc):
return fc
elif isinstance(fc, fc_lib.FeatureColumn):
return generic_utils.serialize_keras_class_and_config(
fc.__class__.__name__, fc._get_config()) # pylint: disable=protected-access
fc.__class__.__name__, fc.get_config()) # pylint: disable=protected-access
else:
raise ValueError('Instance: {} is not a FeatureColumn'.format(fc))
@ -96,7 +96,7 @@ def deserialize_feature_column(config,
"""Deserializes a `config` generated with `serialize_feature_column`.
This method should only be used to deserialize parent FeatureColumns when
implementing FeatureColumn._from_config(), else deserialize_feature_columns()
implementing FeatureColumn.from_config(), else deserialize_feature_columns()
is preferable. Returns a FeatureColumn for this config.
TODO(b/118939620): Simplify code if Keras utils support object deduping.
@ -136,7 +136,7 @@ def deserialize_feature_column(config,
'Expected FeatureColumn class, instead found: {}'.format(cls))
# Always deserialize the FeatureColumn, in order to get the name.
new_instance = cls._from_config( # pylint: disable=protected-access
new_instance = cls.from_config( # pylint: disable=protected-access
cls_config,
custom_objects=custom_objects,
columns_by_name=columns_by_name)