make get_config from config public for feature column
PiperOrigin-RevId: 272227486
This commit is contained in:
parent
9cd59e4549
commit
97757f34d2
@ -2265,8 +2265,7 @@ class FeatureColumn(object):
|
||||
"""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def _get_config(self):
|
||||
def get_config(self):
|
||||
"""Returns the config of the feature column.
|
||||
|
||||
A FeatureColumn config is a Python dictionary (serializable) containing the
|
||||
@ -2283,7 +2282,7 @@ class FeatureColumn(object):
|
||||
'SerializationExampleFeatureColumn',
|
||||
('dimension', 'parent', 'dtype', 'normalizer_fn'))):
|
||||
|
||||
def _get_config(self):
|
||||
def get_config(self):
|
||||
# Create a dict from the namedtuple.
|
||||
# Python attribute literals can be directly copied from / to the config.
|
||||
# For example 'dimension', assuming it is an integer literal.
|
||||
@ -2304,8 +2303,8 @@ class FeatureColumn(object):
|
||||
return config
|
||||
|
||||
@classmethod
|
||||
def _from_config(cls, config, custom_objects=None, columns_by_name=None):
|
||||
# This should do the inverse transform from `_get_config` and construct
|
||||
def from_config(cls, config, custom_objects=None, columns_by_name=None):
|
||||
# This should do the inverse transform from `get_config` and construct
|
||||
# the namedtuple.
|
||||
kwargs = config.copy()
|
||||
kwargs['parent'] = deserialize_feature_column(
|
||||
@ -2320,21 +2319,24 @@ class FeatureColumn(object):
|
||||
A serializable Dict that can be used to deserialize the object with
|
||||
from_config.
|
||||
"""
|
||||
pass
|
||||
return self._get_config()
|
||||
|
||||
def _get_config(self):
|
||||
raise NotImplementedError('Must be implemented in subclasses.')
|
||||
|
||||
@classmethod
|
||||
def _from_config(cls, config, custom_objects=None, columns_by_name=None):
|
||||
def from_config(cls, config, custom_objects=None, columns_by_name=None):
|
||||
"""Creates a FeatureColumn from its config.
|
||||
|
||||
This method should be the reverse of `_get_config`, capable of instantiating
|
||||
the same FeatureColumn from the config dictionary. See `_get_config` for an
|
||||
This method should be the reverse of `get_config`, capable of instantiating
|
||||
the same FeatureColumn from the config dictionary. See `get_config` for an
|
||||
example of common (de)serialization practices followed in this file.
|
||||
|
||||
TODO(b/118939620): This is a private method until consensus is reached on
|
||||
supporting object deserialization deduping within Keras.
|
||||
|
||||
Args:
|
||||
config: A Dict config acquired with `_get_config`.
|
||||
config: A Dict config acquired with `get_config`.
|
||||
custom_objects: Optional dictionary mapping names (strings) to custom
|
||||
classes or functions to be considered during deserialization.
|
||||
columns_by_name: A Dict[String, FeatureColumn] of existing columns in
|
||||
@ -2344,7 +2346,11 @@ class FeatureColumn(object):
|
||||
Returns:
|
||||
A FeatureColumn for the input config.
|
||||
"""
|
||||
pass
|
||||
return cls._from_config(config, custom_objects, columns_by_name)
|
||||
|
||||
@classmethod
|
||||
def _from_config(cls, config, custom_objects=None, columns_by_name=None):
|
||||
raise NotImplementedError('Must be implemented in subclasses.')
|
||||
|
||||
|
||||
class DenseColumn(FeatureColumn):
|
||||
@ -2857,7 +2863,7 @@ class NumericColumn(
|
||||
"""See 'FeatureColumn` base class."""
|
||||
return [self.key]
|
||||
|
||||
def _get_config(self):
|
||||
def get_config(self):
|
||||
"""See 'FeatureColumn` base class."""
|
||||
config = dict(zip(self._fields, self))
|
||||
config['normalizer_fn'] = generic_utils.serialize_keras_object(
|
||||
@ -2866,7 +2872,7 @@ class NumericColumn(
|
||||
return config
|
||||
|
||||
@classmethod
|
||||
def _from_config(cls, config, custom_objects=None, columns_by_name=None):
|
||||
def from_config(cls, config, custom_objects=None, columns_by_name=None):
|
||||
"""See 'FeatureColumn` base class."""
|
||||
_check_config_keys(config, cls._fields)
|
||||
kwargs = _standardize_and_copy_config(config)
|
||||
@ -3014,7 +3020,7 @@ class BucketizedColumn(
|
||||
"""See 'FeatureColumn` base class."""
|
||||
return [self.source_column]
|
||||
|
||||
def _get_config(self):
|
||||
def get_config(self):
|
||||
"""See 'FeatureColumn` base class."""
|
||||
from tensorflow.python.feature_column.serialization import serialize_feature_column # pylint: disable=g-import-not-at-top
|
||||
config = dict(zip(self._fields, self))
|
||||
@ -3022,7 +3028,7 @@ class BucketizedColumn(
|
||||
return config
|
||||
|
||||
@classmethod
|
||||
def _from_config(cls, config, custom_objects=None, columns_by_name=None):
|
||||
def from_config(cls, config, custom_objects=None, columns_by_name=None):
|
||||
"""See 'FeatureColumn` base class."""
|
||||
from tensorflow.python.feature_column.serialization import deserialize_feature_column # pylint: disable=g-import-not-at-top
|
||||
_check_config_keys(config, cls._fields)
|
||||
@ -3247,7 +3253,7 @@ class EmbeddingColumn(
|
||||
"""See 'FeatureColumn` base class."""
|
||||
return [self.categorical_column]
|
||||
|
||||
def _get_config(self):
|
||||
def get_config(self):
|
||||
"""See 'FeatureColumn` base class."""
|
||||
from tensorflow.python.feature_column.serialization import serialize_feature_column # pylint: disable=g-import-not-at-top
|
||||
config = dict(zip(self._fields, self))
|
||||
@ -3257,7 +3263,7 @@ class EmbeddingColumn(
|
||||
return config
|
||||
|
||||
@classmethod
|
||||
def _from_config(cls, config, custom_objects=None, columns_by_name=None):
|
||||
def from_config(cls, config, custom_objects=None, columns_by_name=None):
|
||||
"""See 'FeatureColumn` base class."""
|
||||
from tensorflow.python.feature_column.serialization import deserialize_feature_column # pylint: disable=g-import-not-at-top
|
||||
_check_config_keys(config, cls._fields)
|
||||
@ -3440,15 +3446,6 @@ class SharedEmbeddingColumn(
|
||||
"""See 'FeatureColumn` base class."""
|
||||
return [self.categorical_column]
|
||||
|
||||
def _get_config(self):
|
||||
"""See 'FeatureColumn` base class."""
|
||||
raise NotImplementedError()
|
||||
|
||||
@classmethod
|
||||
def _from_config(cls, config, custom_objects=None, columns_by_name=None):
|
||||
"""See 'FeatureColumn` base class."""
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
def _check_shape(shape, key):
|
||||
"""Returns shape if it's valid, raises error otherwise."""
|
||||
@ -3559,14 +3556,14 @@ class HashedCategoricalColumn(
|
||||
"""See 'FeatureColumn` base class."""
|
||||
return [self.key]
|
||||
|
||||
def _get_config(self):
|
||||
def get_config(self):
|
||||
"""See 'FeatureColumn` base class."""
|
||||
config = dict(zip(self._fields, self))
|
||||
config['dtype'] = self.dtype.name
|
||||
return config
|
||||
|
||||
@classmethod
|
||||
def _from_config(cls, config, custom_objects=None, columns_by_name=None):
|
||||
def from_config(cls, config, custom_objects=None, columns_by_name=None):
|
||||
"""See 'FeatureColumn` base class."""
|
||||
_check_config_keys(config, cls._fields)
|
||||
kwargs = _standardize_and_copy_config(config)
|
||||
@ -3673,14 +3670,14 @@ class VocabularyFileCategoricalColumn(
|
||||
"""See 'FeatureColumn` base class."""
|
||||
return [self.key]
|
||||
|
||||
def _get_config(self):
|
||||
def get_config(self):
|
||||
"""See 'FeatureColumn` base class."""
|
||||
config = dict(zip(self._fields, self))
|
||||
config['dtype'] = self.dtype.name
|
||||
return config
|
||||
|
||||
@classmethod
|
||||
def _from_config(cls, config, custom_objects=None, columns_by_name=None):
|
||||
def from_config(cls, config, custom_objects=None, columns_by_name=None):
|
||||
"""See 'FeatureColumn` base class."""
|
||||
_check_config_keys(config, cls._fields)
|
||||
kwargs = _standardize_and_copy_config(config)
|
||||
@ -3787,14 +3784,14 @@ class VocabularyListCategoricalColumn(
|
||||
"""See 'FeatureColumn` base class."""
|
||||
return [self.key]
|
||||
|
||||
def _get_config(self):
|
||||
def get_config(self):
|
||||
"""See 'FeatureColumn` base class."""
|
||||
config = dict(zip(self._fields, self))
|
||||
config['dtype'] = self.dtype.name
|
||||
return config
|
||||
|
||||
@classmethod
|
||||
def _from_config(cls, config, custom_objects=None, columns_by_name=None):
|
||||
def from_config(cls, config, custom_objects=None, columns_by_name=None):
|
||||
"""See 'FeatureColumn` base class."""
|
||||
_check_config_keys(config, cls._fields)
|
||||
kwargs = _standardize_and_copy_config(config)
|
||||
@ -3899,12 +3896,12 @@ class IdentityCategoricalColumn(
|
||||
"""See 'FeatureColumn` base class."""
|
||||
return [self.key]
|
||||
|
||||
def _get_config(self):
|
||||
def get_config(self):
|
||||
"""See 'FeatureColumn` base class."""
|
||||
return dict(zip(self._fields, self))
|
||||
|
||||
@classmethod
|
||||
def _from_config(cls, config, custom_objects=None, columns_by_name=None):
|
||||
def from_config(cls, config, custom_objects=None, columns_by_name=None):
|
||||
"""See 'FeatureColumn` base class."""
|
||||
_check_config_keys(config, cls._fields)
|
||||
kwargs = _standardize_and_copy_config(config)
|
||||
@ -4013,7 +4010,7 @@ class WeightedCategoricalColumn(
|
||||
"""See 'FeatureColumn` base class."""
|
||||
return [self.categorical_column, self.weight_feature_key]
|
||||
|
||||
def _get_config(self):
|
||||
def get_config(self):
|
||||
"""See 'FeatureColumn` base class."""
|
||||
from tensorflow.python.feature_column.serialization import serialize_feature_column # pylint: disable=g-import-not-at-top
|
||||
config = dict(zip(self._fields, self))
|
||||
@ -4023,7 +4020,7 @@ class WeightedCategoricalColumn(
|
||||
return config
|
||||
|
||||
@classmethod
|
||||
def _from_config(cls, config, custom_objects=None, columns_by_name=None):
|
||||
def from_config(cls, config, custom_objects=None, columns_by_name=None):
|
||||
"""See 'FeatureColumn` base class."""
|
||||
from tensorflow.python.feature_column.serialization import deserialize_feature_column # pylint: disable=g-import-not-at-top
|
||||
_check_config_keys(config, cls._fields)
|
||||
@ -4157,7 +4154,7 @@ class CrossedColumn(
|
||||
"""See 'FeatureColumn` base class."""
|
||||
return list(self.keys)
|
||||
|
||||
def _get_config(self):
|
||||
def get_config(self):
|
||||
"""See 'FeatureColumn` base class."""
|
||||
from tensorflow.python.feature_column.serialization import serialize_feature_column # pylint: disable=g-import-not-at-top
|
||||
config = dict(zip(self._fields, self))
|
||||
@ -4165,7 +4162,7 @@ class CrossedColumn(
|
||||
return config
|
||||
|
||||
@classmethod
|
||||
def _from_config(cls, config, custom_objects=None, columns_by_name=None):
|
||||
def from_config(cls, config, custom_objects=None, columns_by_name=None):
|
||||
"""See 'FeatureColumn` base class."""
|
||||
from tensorflow.python.feature_column.serialization import deserialize_feature_column # pylint: disable=g-import-not-at-top
|
||||
_check_config_keys(config, cls._fields)
|
||||
@ -4427,7 +4424,7 @@ class IndicatorColumn(
|
||||
"""See 'FeatureColumn` base class."""
|
||||
return [self.categorical_column]
|
||||
|
||||
def _get_config(self):
|
||||
def get_config(self):
|
||||
"""See 'FeatureColumn` base class."""
|
||||
from tensorflow.python.feature_column.serialization import serialize_feature_column # pylint: disable=g-import-not-at-top
|
||||
config = dict(zip(self._fields, self))
|
||||
@ -4436,7 +4433,7 @@ class IndicatorColumn(
|
||||
return config
|
||||
|
||||
@classmethod
|
||||
def _from_config(cls, config, custom_objects=None, columns_by_name=None):
|
||||
def from_config(cls, config, custom_objects=None, columns_by_name=None):
|
||||
"""See 'FeatureColumn` base class."""
|
||||
from tensorflow.python.feature_column.serialization import deserialize_feature_column # pylint: disable=g-import-not-at-top
|
||||
_check_config_keys(config, cls._fields)
|
||||
@ -4573,7 +4570,7 @@ class SequenceCategoricalColumn(
|
||||
"""See 'FeatureColumn` base class."""
|
||||
return [self.categorical_column]
|
||||
|
||||
def _get_config(self):
|
||||
def get_config(self):
|
||||
"""See 'FeatureColumn` base class."""
|
||||
from tensorflow.python.feature_column.serialization import serialize_feature_column # pylint: disable=g-import-not-at-top
|
||||
config = dict(zip(self._fields, self))
|
||||
@ -4582,7 +4579,7 @@ class SequenceCategoricalColumn(
|
||||
return config
|
||||
|
||||
@classmethod
|
||||
def _from_config(cls, config, custom_objects=None, columns_by_name=None):
|
||||
def from_config(cls, config, custom_objects=None, columns_by_name=None):
|
||||
"""See 'FeatureColumn` base class."""
|
||||
from tensorflow.python.feature_column.serialization import deserialize_feature_column # pylint: disable=g-import-not-at-top
|
||||
_check_config_keys(config, cls._fields)
|
||||
|
@ -81,10 +81,10 @@ class BaseFeatureColumnForTests(fc.FeatureColumn):
|
||||
raise ValueError('Should not use this method.')
|
||||
|
||||
@classmethod
|
||||
def _from_config(cls, config, custom_objects=None, columns_by_name=None):
|
||||
def from_config(cls, config, custom_objects=None, columns_by_name=None):
|
||||
raise ValueError('Should not use this method.')
|
||||
|
||||
def _get_config(self):
|
||||
def get_config(self):
|
||||
raise ValueError('Should not use this method.')
|
||||
|
||||
|
||||
@ -478,7 +478,7 @@ class NumericColumnTest(test.TestCase):
|
||||
price = fc.numeric_column('price', normalizer_fn=_increment_two)
|
||||
self.assertEqual(['price'], price.parents)
|
||||
|
||||
config = price._get_config()
|
||||
config = price.get_config()
|
||||
self.assertEqual({
|
||||
'key': 'price',
|
||||
'shape': (1,),
|
||||
@ -487,7 +487,7 @@ class NumericColumnTest(test.TestCase):
|
||||
'normalizer_fn': '_increment_two'
|
||||
}, config)
|
||||
|
||||
new_col = fc.NumericColumn._from_config(
|
||||
new_col = fc.NumericColumn.from_config(
|
||||
config, custom_objects={'_increment_two': _increment_two})
|
||||
self.assertEqual(price, new_col)
|
||||
self.assertEqual(new_col.shape, (1,))
|
||||
@ -833,7 +833,7 @@ class BucketizedColumnTest(test.TestCase):
|
||||
bucketized_price = fc.bucketized_column(price, boundaries=[0, 2, 4, 6])
|
||||
self.assertEqual([price], bucketized_price.parents)
|
||||
|
||||
config = bucketized_price._get_config()
|
||||
config = bucketized_price.get_config()
|
||||
self.assertEqual({
|
||||
'source_column': {
|
||||
'class_name': 'NumericColumn',
|
||||
@ -848,11 +848,11 @@ class BucketizedColumnTest(test.TestCase):
|
||||
'boundaries': (0, 2, 4, 6)
|
||||
}, config)
|
||||
|
||||
new_bucketized_price = fc.BucketizedColumn._from_config(config)
|
||||
new_bucketized_price = fc.BucketizedColumn.from_config(config)
|
||||
self.assertEqual(bucketized_price, new_bucketized_price)
|
||||
self.assertIsNot(price, new_bucketized_price.source_column)
|
||||
|
||||
new_bucketized_price = fc.BucketizedColumn._from_config(
|
||||
new_bucketized_price = fc.BucketizedColumn.from_config(
|
||||
config,
|
||||
columns_by_name={
|
||||
serialization._column_name_with_class_name(price): price
|
||||
@ -1106,7 +1106,7 @@ class HashedCategoricalColumnTest(test.TestCase):
|
||||
wire_column = fc.categorical_column_with_hash_bucket('wire', 4)
|
||||
self.assertEqual(['wire'], wire_column.parents)
|
||||
|
||||
config = wire_column._get_config()
|
||||
config = wire_column.get_config()
|
||||
self.assertEqual({
|
||||
'key': 'wire',
|
||||
'hash_bucket_size': 4,
|
||||
@ -1114,7 +1114,7 @@ class HashedCategoricalColumnTest(test.TestCase):
|
||||
}, config)
|
||||
|
||||
self.assertEqual(wire_column,
|
||||
fc.HashedCategoricalColumn._from_config(config))
|
||||
fc.HashedCategoricalColumn.from_config(config))
|
||||
|
||||
|
||||
class CrossedColumnTest(test.TestCase):
|
||||
@ -1588,7 +1588,7 @@ class CrossedColumnTest(test.TestCase):
|
||||
|
||||
self.assertEqual([b, 'c'], crossed.parents)
|
||||
|
||||
config = crossed._get_config()
|
||||
config = crossed.get_config()
|
||||
self.assertEqual({
|
||||
'hash_bucket_size':
|
||||
5,
|
||||
@ -1612,11 +1612,11 @@ class CrossedColumnTest(test.TestCase):
|
||||
}, 'c')
|
||||
}, config)
|
||||
|
||||
new_crossed = fc.CrossedColumn._from_config(config)
|
||||
new_crossed = fc.CrossedColumn.from_config(config)
|
||||
self.assertEqual(crossed, new_crossed)
|
||||
self.assertIsNot(b, new_crossed.keys[0])
|
||||
|
||||
new_crossed = fc.CrossedColumn._from_config(
|
||||
new_crossed = fc.CrossedColumn.from_config(
|
||||
config,
|
||||
columns_by_name={serialization._column_name_with_class_name(b): b})
|
||||
self.assertEqual(crossed, new_crossed)
|
||||
@ -4396,7 +4396,7 @@ class VocabularyFileCategoricalColumnTest(test.TestCase):
|
||||
|
||||
self.assertEqual(['wire'], wire_column.parents)
|
||||
|
||||
config = wire_column._get_config()
|
||||
config = wire_column.get_config()
|
||||
self.assertEqual({
|
||||
'default_value': -1,
|
||||
'dtype': 'string',
|
||||
@ -4407,7 +4407,7 @@ class VocabularyFileCategoricalColumnTest(test.TestCase):
|
||||
}, config)
|
||||
|
||||
self.assertEqual(wire_column,
|
||||
fc.VocabularyFileCategoricalColumn._from_config(config))
|
||||
fc.VocabularyFileCategoricalColumn.from_config(config))
|
||||
|
||||
|
||||
class VocabularyListCategoricalColumnTest(test.TestCase):
|
||||
@ -4859,7 +4859,7 @@ class VocabularyListCategoricalColumnTest(test.TestCase):
|
||||
|
||||
self.assertEqual(['aaa'], wire_column.parents)
|
||||
|
||||
config = wire_column._get_config()
|
||||
config = wire_column.get_config()
|
||||
self.assertEqual({
|
||||
'default_value': -1,
|
||||
'dtype': 'string',
|
||||
@ -4869,7 +4869,7 @@ class VocabularyListCategoricalColumnTest(test.TestCase):
|
||||
}, config)
|
||||
|
||||
self.assertEqual(wire_column,
|
||||
fc.VocabularyListCategoricalColumn._from_config(config))
|
||||
fc.VocabularyListCategoricalColumn.from_config(config))
|
||||
|
||||
|
||||
class IdentityCategoricalColumnTest(test.TestCase):
|
||||
@ -5218,14 +5218,14 @@ class IdentityCategoricalColumnTest(test.TestCase):
|
||||
|
||||
self.assertEqual(['aaa'], column.parents)
|
||||
|
||||
config = column._get_config()
|
||||
config = column.get_config()
|
||||
self.assertEqual({
|
||||
'default_value': None,
|
||||
'key': 'aaa',
|
||||
'number_buckets': 3
|
||||
}, config)
|
||||
|
||||
self.assertEqual(column, fc.IdentityCategoricalColumn._from_config(config))
|
||||
self.assertEqual(column, fc.IdentityCategoricalColumn.from_config(config))
|
||||
|
||||
|
||||
class TransformFeaturesTest(test.TestCase):
|
||||
@ -5600,7 +5600,7 @@ class IndicatorColumnTest(test.TestCase):
|
||||
|
||||
self.assertEqual([parent], animal.parents)
|
||||
|
||||
config = animal._get_config()
|
||||
config = animal.get_config()
|
||||
self.assertEqual({
|
||||
'categorical_column': {
|
||||
'class_name': 'IdentityCategoricalColumn',
|
||||
@ -5612,11 +5612,11 @@ class IndicatorColumnTest(test.TestCase):
|
||||
}
|
||||
}, config)
|
||||
|
||||
new_animal = fc.IndicatorColumn._from_config(config)
|
||||
new_animal = fc.IndicatorColumn.from_config(config)
|
||||
self.assertEqual(animal, new_animal)
|
||||
self.assertIsNot(parent, new_animal.categorical_column)
|
||||
|
||||
new_animal = fc.IndicatorColumn._from_config(
|
||||
new_animal = fc.IndicatorColumn.from_config(
|
||||
config,
|
||||
columns_by_name={
|
||||
serialization._column_name_with_class_name(parent): parent
|
||||
@ -6605,7 +6605,7 @@ class EmbeddingColumnTest(test.TestCase):
|
||||
|
||||
self.assertEqual([categorical_column], embedding_column.parents)
|
||||
|
||||
config = embedding_column._get_config()
|
||||
config = embedding_column.get_config()
|
||||
self.assertEqual({
|
||||
'categorical_column': {
|
||||
'class_name': 'IdentityCategoricalColumn',
|
||||
@ -6633,22 +6633,22 @@ class EmbeddingColumnTest(test.TestCase):
|
||||
}, config)
|
||||
|
||||
custom_objects = {'TruncatedNormal': init_ops.TruncatedNormal}
|
||||
new_embedding_column = fc.EmbeddingColumn._from_config(
|
||||
new_embedding_column = fc.EmbeddingColumn.from_config(
|
||||
config, custom_objects=custom_objects)
|
||||
self.assertEqual(embedding_column._get_config(),
|
||||
new_embedding_column._get_config())
|
||||
self.assertEqual(embedding_column.get_config(),
|
||||
new_embedding_column.get_config())
|
||||
self.assertIsNot(categorical_column,
|
||||
new_embedding_column.categorical_column)
|
||||
|
||||
new_embedding_column = fc.EmbeddingColumn._from_config(
|
||||
new_embedding_column = fc.EmbeddingColumn.from_config(
|
||||
config,
|
||||
custom_objects=custom_objects,
|
||||
columns_by_name={
|
||||
serialization._column_name_with_class_name(categorical_column):
|
||||
categorical_column
|
||||
})
|
||||
self.assertEqual(embedding_column._get_config(),
|
||||
new_embedding_column._get_config())
|
||||
self.assertEqual(embedding_column.get_config(),
|
||||
new_embedding_column.get_config())
|
||||
self.assertIs(categorical_column, new_embedding_column.categorical_column)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
@ -6666,7 +6666,7 @@ class EmbeddingColumnTest(test.TestCase):
|
||||
|
||||
self.assertEqual([categorical_column], embedding_column.parents)
|
||||
|
||||
config = embedding_column._get_config()
|
||||
config = embedding_column.get_config()
|
||||
self.assertEqual({
|
||||
'categorical_column': {
|
||||
'class_name': 'IdentityCategoricalColumn',
|
||||
@ -6689,13 +6689,13 @@ class EmbeddingColumnTest(test.TestCase):
|
||||
'_initializer': _initializer,
|
||||
}
|
||||
|
||||
new_embedding_column = fc.EmbeddingColumn._from_config(
|
||||
new_embedding_column = fc.EmbeddingColumn.from_config(
|
||||
config, custom_objects=custom_objects)
|
||||
self.assertEqual(embedding_column, new_embedding_column)
|
||||
self.assertIsNot(categorical_column,
|
||||
new_embedding_column.categorical_column)
|
||||
|
||||
new_embedding_column = fc.EmbeddingColumn._from_config(
|
||||
new_embedding_column = fc.EmbeddingColumn.from_config(
|
||||
config,
|
||||
custom_objects=custom_objects,
|
||||
columns_by_name={
|
||||
@ -7763,7 +7763,7 @@ class WeightedCategoricalColumnTest(test.TestCase):
|
||||
|
||||
self.assertEqual([categorical_column, 'weight'], column.parents)
|
||||
|
||||
config = column._get_config()
|
||||
config = column.get_config()
|
||||
self.assertEqual({
|
||||
'categorical_column': {
|
||||
'config': {
|
||||
@ -7777,9 +7777,9 @@ class WeightedCategoricalColumnTest(test.TestCase):
|
||||
'weight_feature_key': 'weight'
|
||||
}, config)
|
||||
|
||||
self.assertEqual(column, fc.WeightedCategoricalColumn._from_config(config))
|
||||
self.assertEqual(column, fc.WeightedCategoricalColumn.from_config(config))
|
||||
|
||||
new_column = fc.WeightedCategoricalColumn._from_config(
|
||||
new_column = fc.WeightedCategoricalColumn.from_config(
|
||||
config,
|
||||
columns_by_name={
|
||||
serialization._column_name_with_class_name(categorical_column):
|
||||
|
@ -582,7 +582,7 @@ class SequenceNumericColumn(
|
||||
"""See 'FeatureColumn` base class."""
|
||||
return [self.key]
|
||||
|
||||
def _get_config(self):
|
||||
def get_config(self):
|
||||
"""See 'FeatureColumn` base class."""
|
||||
config = dict(zip(self._fields, self))
|
||||
config['normalizer_fn'] = utils.serialize_keras_object(self.normalizer_fn)
|
||||
@ -590,7 +590,7 @@ class SequenceNumericColumn(
|
||||
return config
|
||||
|
||||
@classmethod
|
||||
def _from_config(cls, config, custom_objects=None, columns_by_name=None):
|
||||
def from_config(cls, config, custom_objects=None, columns_by_name=None):
|
||||
"""See 'FeatureColumn` base class."""
|
||||
fc._check_config_keys(config, cls._fields)
|
||||
kwargs = fc._standardize_and_copy_config(config)
|
||||
|
@ -765,7 +765,7 @@ class SequenceCategoricalColumnWithIdentityTest(
|
||||
'animal', num_buckets=4)
|
||||
animal = fc.indicator_column(parent)
|
||||
|
||||
config = animal._get_config()
|
||||
config = animal.get_config()
|
||||
self.assertEqual(
|
||||
{
|
||||
'categorical_column': {
|
||||
@ -783,11 +783,11 @@ class SequenceCategoricalColumnWithIdentityTest(
|
||||
}
|
||||
}, config)
|
||||
|
||||
new_animal = fc.IndicatorColumn._from_config(config)
|
||||
new_animal = fc.IndicatorColumn.from_config(config)
|
||||
self.assertEqual(animal, new_animal)
|
||||
self.assertIsNot(parent, new_animal.categorical_column)
|
||||
|
||||
new_animal = fc.IndicatorColumn._from_config(
|
||||
new_animal = fc.IndicatorColumn.from_config(
|
||||
config,
|
||||
columns_by_name={
|
||||
serialization._column_name_with_class_name(parent): parent
|
||||
|
@ -45,14 +45,14 @@ def serialize_feature_column(fc):
|
||||
"""Serializes a FeatureColumn or a raw string key.
|
||||
|
||||
This method should only be used to serialize parent FeatureColumns when
|
||||
implementing FeatureColumn._get_config(), else serialize_feature_columns()
|
||||
implementing FeatureColumn.get_config(), else serialize_feature_columns()
|
||||
is preferable.
|
||||
|
||||
This serialization also keeps information of the FeatureColumn class, so
|
||||
deserialization is possible without knowing the class type. For example:
|
||||
|
||||
a = numeric_column('x')
|
||||
a._get_config() gives:
|
||||
a.get_config() gives:
|
||||
{
|
||||
'key': 'price',
|
||||
'shape': (1,),
|
||||
@ -85,7 +85,7 @@ def serialize_feature_column(fc):
|
||||
return fc
|
||||
elif isinstance(fc, fc_lib.FeatureColumn):
|
||||
return generic_utils.serialize_keras_class_and_config(
|
||||
fc.__class__.__name__, fc._get_config()) # pylint: disable=protected-access
|
||||
fc.__class__.__name__, fc.get_config()) # pylint: disable=protected-access
|
||||
else:
|
||||
raise ValueError('Instance: {} is not a FeatureColumn'.format(fc))
|
||||
|
||||
@ -96,7 +96,7 @@ def deserialize_feature_column(config,
|
||||
"""Deserializes a `config` generated with `serialize_feature_column`.
|
||||
|
||||
This method should only be used to deserialize parent FeatureColumns when
|
||||
implementing FeatureColumn._from_config(), else deserialize_feature_columns()
|
||||
implementing FeatureColumn.from_config(), else deserialize_feature_columns()
|
||||
is preferable. Returns a FeatureColumn for this config.
|
||||
TODO(b/118939620): Simplify code if Keras utils support object deduping.
|
||||
|
||||
@ -136,7 +136,7 @@ def deserialize_feature_column(config,
|
||||
'Expected FeatureColumn class, instead found: {}'.format(cls))
|
||||
|
||||
# Always deserialize the FeatureColumn, in order to get the name.
|
||||
new_instance = cls._from_config( # pylint: disable=protected-access
|
||||
new_instance = cls.from_config( # pylint: disable=protected-access
|
||||
cls_config,
|
||||
custom_objects=custom_objects,
|
||||
columns_by_name=columns_by_name)
|
||||
|
Loading…
Reference in New Issue
Block a user