make get_config from config public for feature column
PiperOrigin-RevId: 272227486
This commit is contained in:
parent
9cd59e4549
commit
97757f34d2
@ -2265,8 +2265,7 @@ class FeatureColumn(object):
|
|||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abc.abstractmethod
|
def get_config(self):
|
||||||
def _get_config(self):
|
|
||||||
"""Returns the config of the feature column.
|
"""Returns the config of the feature column.
|
||||||
|
|
||||||
A FeatureColumn config is a Python dictionary (serializable) containing the
|
A FeatureColumn config is a Python dictionary (serializable) containing the
|
||||||
@ -2283,7 +2282,7 @@ class FeatureColumn(object):
|
|||||||
'SerializationExampleFeatureColumn',
|
'SerializationExampleFeatureColumn',
|
||||||
('dimension', 'parent', 'dtype', 'normalizer_fn'))):
|
('dimension', 'parent', 'dtype', 'normalizer_fn'))):
|
||||||
|
|
||||||
def _get_config(self):
|
def get_config(self):
|
||||||
# Create a dict from the namedtuple.
|
# Create a dict from the namedtuple.
|
||||||
# Python attribute literals can be directly copied from / to the config.
|
# Python attribute literals can be directly copied from / to the config.
|
||||||
# For example 'dimension', assuming it is an integer literal.
|
# For example 'dimension', assuming it is an integer literal.
|
||||||
@ -2304,8 +2303,8 @@ class FeatureColumn(object):
|
|||||||
return config
|
return config
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _from_config(cls, config, custom_objects=None, columns_by_name=None):
|
def from_config(cls, config, custom_objects=None, columns_by_name=None):
|
||||||
# This should do the inverse transform from `_get_config` and construct
|
# This should do the inverse transform from `get_config` and construct
|
||||||
# the namedtuple.
|
# the namedtuple.
|
||||||
kwargs = config.copy()
|
kwargs = config.copy()
|
||||||
kwargs['parent'] = deserialize_feature_column(
|
kwargs['parent'] = deserialize_feature_column(
|
||||||
@ -2320,21 +2319,24 @@ class FeatureColumn(object):
|
|||||||
A serializable Dict that can be used to deserialize the object with
|
A serializable Dict that can be used to deserialize the object with
|
||||||
from_config.
|
from_config.
|
||||||
"""
|
"""
|
||||||
pass
|
return self._get_config()
|
||||||
|
|
||||||
|
def _get_config(self):
|
||||||
|
raise NotImplementedError('Must be implemented in subclasses.')
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _from_config(cls, config, custom_objects=None, columns_by_name=None):
|
def from_config(cls, config, custom_objects=None, columns_by_name=None):
|
||||||
"""Creates a FeatureColumn from its config.
|
"""Creates a FeatureColumn from its config.
|
||||||
|
|
||||||
This method should be the reverse of `_get_config`, capable of instantiating
|
This method should be the reverse of `get_config`, capable of instantiating
|
||||||
the same FeatureColumn from the config dictionary. See `_get_config` for an
|
the same FeatureColumn from the config dictionary. See `get_config` for an
|
||||||
example of common (de)serialization practices followed in this file.
|
example of common (de)serialization practices followed in this file.
|
||||||
|
|
||||||
TODO(b/118939620): This is a private method until consensus is reached on
|
TODO(b/118939620): This is a private method until consensus is reached on
|
||||||
supporting object deserialization deduping within Keras.
|
supporting object deserialization deduping within Keras.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
config: A Dict config acquired with `_get_config`.
|
config: A Dict config acquired with `get_config`.
|
||||||
custom_objects: Optional dictionary mapping names (strings) to custom
|
custom_objects: Optional dictionary mapping names (strings) to custom
|
||||||
classes or functions to be considered during deserialization.
|
classes or functions to be considered during deserialization.
|
||||||
columns_by_name: A Dict[String, FeatureColumn] of existing columns in
|
columns_by_name: A Dict[String, FeatureColumn] of existing columns in
|
||||||
@ -2344,7 +2346,11 @@ class FeatureColumn(object):
|
|||||||
Returns:
|
Returns:
|
||||||
A FeatureColumn for the input config.
|
A FeatureColumn for the input config.
|
||||||
"""
|
"""
|
||||||
pass
|
return cls._from_config(config, custom_objects, columns_by_name)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _from_config(cls, config, custom_objects=None, columns_by_name=None):
|
||||||
|
raise NotImplementedError('Must be implemented in subclasses.')
|
||||||
|
|
||||||
|
|
||||||
class DenseColumn(FeatureColumn):
|
class DenseColumn(FeatureColumn):
|
||||||
@ -2857,7 +2863,7 @@ class NumericColumn(
|
|||||||
"""See 'FeatureColumn` base class."""
|
"""See 'FeatureColumn` base class."""
|
||||||
return [self.key]
|
return [self.key]
|
||||||
|
|
||||||
def _get_config(self):
|
def get_config(self):
|
||||||
"""See 'FeatureColumn` base class."""
|
"""See 'FeatureColumn` base class."""
|
||||||
config = dict(zip(self._fields, self))
|
config = dict(zip(self._fields, self))
|
||||||
config['normalizer_fn'] = generic_utils.serialize_keras_object(
|
config['normalizer_fn'] = generic_utils.serialize_keras_object(
|
||||||
@ -2866,7 +2872,7 @@ class NumericColumn(
|
|||||||
return config
|
return config
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _from_config(cls, config, custom_objects=None, columns_by_name=None):
|
def from_config(cls, config, custom_objects=None, columns_by_name=None):
|
||||||
"""See 'FeatureColumn` base class."""
|
"""See 'FeatureColumn` base class."""
|
||||||
_check_config_keys(config, cls._fields)
|
_check_config_keys(config, cls._fields)
|
||||||
kwargs = _standardize_and_copy_config(config)
|
kwargs = _standardize_and_copy_config(config)
|
||||||
@ -3014,7 +3020,7 @@ class BucketizedColumn(
|
|||||||
"""See 'FeatureColumn` base class."""
|
"""See 'FeatureColumn` base class."""
|
||||||
return [self.source_column]
|
return [self.source_column]
|
||||||
|
|
||||||
def _get_config(self):
|
def get_config(self):
|
||||||
"""See 'FeatureColumn` base class."""
|
"""See 'FeatureColumn` base class."""
|
||||||
from tensorflow.python.feature_column.serialization import serialize_feature_column # pylint: disable=g-import-not-at-top
|
from tensorflow.python.feature_column.serialization import serialize_feature_column # pylint: disable=g-import-not-at-top
|
||||||
config = dict(zip(self._fields, self))
|
config = dict(zip(self._fields, self))
|
||||||
@ -3022,7 +3028,7 @@ class BucketizedColumn(
|
|||||||
return config
|
return config
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _from_config(cls, config, custom_objects=None, columns_by_name=None):
|
def from_config(cls, config, custom_objects=None, columns_by_name=None):
|
||||||
"""See 'FeatureColumn` base class."""
|
"""See 'FeatureColumn` base class."""
|
||||||
from tensorflow.python.feature_column.serialization import deserialize_feature_column # pylint: disable=g-import-not-at-top
|
from tensorflow.python.feature_column.serialization import deserialize_feature_column # pylint: disable=g-import-not-at-top
|
||||||
_check_config_keys(config, cls._fields)
|
_check_config_keys(config, cls._fields)
|
||||||
@ -3247,7 +3253,7 @@ class EmbeddingColumn(
|
|||||||
"""See 'FeatureColumn` base class."""
|
"""See 'FeatureColumn` base class."""
|
||||||
return [self.categorical_column]
|
return [self.categorical_column]
|
||||||
|
|
||||||
def _get_config(self):
|
def get_config(self):
|
||||||
"""See 'FeatureColumn` base class."""
|
"""See 'FeatureColumn` base class."""
|
||||||
from tensorflow.python.feature_column.serialization import serialize_feature_column # pylint: disable=g-import-not-at-top
|
from tensorflow.python.feature_column.serialization import serialize_feature_column # pylint: disable=g-import-not-at-top
|
||||||
config = dict(zip(self._fields, self))
|
config = dict(zip(self._fields, self))
|
||||||
@ -3257,7 +3263,7 @@ class EmbeddingColumn(
|
|||||||
return config
|
return config
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _from_config(cls, config, custom_objects=None, columns_by_name=None):
|
def from_config(cls, config, custom_objects=None, columns_by_name=None):
|
||||||
"""See 'FeatureColumn` base class."""
|
"""See 'FeatureColumn` base class."""
|
||||||
from tensorflow.python.feature_column.serialization import deserialize_feature_column # pylint: disable=g-import-not-at-top
|
from tensorflow.python.feature_column.serialization import deserialize_feature_column # pylint: disable=g-import-not-at-top
|
||||||
_check_config_keys(config, cls._fields)
|
_check_config_keys(config, cls._fields)
|
||||||
@ -3440,15 +3446,6 @@ class SharedEmbeddingColumn(
|
|||||||
"""See 'FeatureColumn` base class."""
|
"""See 'FeatureColumn` base class."""
|
||||||
return [self.categorical_column]
|
return [self.categorical_column]
|
||||||
|
|
||||||
def _get_config(self):
|
|
||||||
"""See 'FeatureColumn` base class."""
|
|
||||||
raise NotImplementedError()
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def _from_config(cls, config, custom_objects=None, columns_by_name=None):
|
|
||||||
"""See 'FeatureColumn` base class."""
|
|
||||||
raise NotImplementedError()
|
|
||||||
|
|
||||||
|
|
||||||
def _check_shape(shape, key):
|
def _check_shape(shape, key):
|
||||||
"""Returns shape if it's valid, raises error otherwise."""
|
"""Returns shape if it's valid, raises error otherwise."""
|
||||||
@ -3559,14 +3556,14 @@ class HashedCategoricalColumn(
|
|||||||
"""See 'FeatureColumn` base class."""
|
"""See 'FeatureColumn` base class."""
|
||||||
return [self.key]
|
return [self.key]
|
||||||
|
|
||||||
def _get_config(self):
|
def get_config(self):
|
||||||
"""See 'FeatureColumn` base class."""
|
"""See 'FeatureColumn` base class."""
|
||||||
config = dict(zip(self._fields, self))
|
config = dict(zip(self._fields, self))
|
||||||
config['dtype'] = self.dtype.name
|
config['dtype'] = self.dtype.name
|
||||||
return config
|
return config
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _from_config(cls, config, custom_objects=None, columns_by_name=None):
|
def from_config(cls, config, custom_objects=None, columns_by_name=None):
|
||||||
"""See 'FeatureColumn` base class."""
|
"""See 'FeatureColumn` base class."""
|
||||||
_check_config_keys(config, cls._fields)
|
_check_config_keys(config, cls._fields)
|
||||||
kwargs = _standardize_and_copy_config(config)
|
kwargs = _standardize_and_copy_config(config)
|
||||||
@ -3673,14 +3670,14 @@ class VocabularyFileCategoricalColumn(
|
|||||||
"""See 'FeatureColumn` base class."""
|
"""See 'FeatureColumn` base class."""
|
||||||
return [self.key]
|
return [self.key]
|
||||||
|
|
||||||
def _get_config(self):
|
def get_config(self):
|
||||||
"""See 'FeatureColumn` base class."""
|
"""See 'FeatureColumn` base class."""
|
||||||
config = dict(zip(self._fields, self))
|
config = dict(zip(self._fields, self))
|
||||||
config['dtype'] = self.dtype.name
|
config['dtype'] = self.dtype.name
|
||||||
return config
|
return config
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _from_config(cls, config, custom_objects=None, columns_by_name=None):
|
def from_config(cls, config, custom_objects=None, columns_by_name=None):
|
||||||
"""See 'FeatureColumn` base class."""
|
"""See 'FeatureColumn` base class."""
|
||||||
_check_config_keys(config, cls._fields)
|
_check_config_keys(config, cls._fields)
|
||||||
kwargs = _standardize_and_copy_config(config)
|
kwargs = _standardize_and_copy_config(config)
|
||||||
@ -3787,14 +3784,14 @@ class VocabularyListCategoricalColumn(
|
|||||||
"""See 'FeatureColumn` base class."""
|
"""See 'FeatureColumn` base class."""
|
||||||
return [self.key]
|
return [self.key]
|
||||||
|
|
||||||
def _get_config(self):
|
def get_config(self):
|
||||||
"""See 'FeatureColumn` base class."""
|
"""See 'FeatureColumn` base class."""
|
||||||
config = dict(zip(self._fields, self))
|
config = dict(zip(self._fields, self))
|
||||||
config['dtype'] = self.dtype.name
|
config['dtype'] = self.dtype.name
|
||||||
return config
|
return config
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _from_config(cls, config, custom_objects=None, columns_by_name=None):
|
def from_config(cls, config, custom_objects=None, columns_by_name=None):
|
||||||
"""See 'FeatureColumn` base class."""
|
"""See 'FeatureColumn` base class."""
|
||||||
_check_config_keys(config, cls._fields)
|
_check_config_keys(config, cls._fields)
|
||||||
kwargs = _standardize_and_copy_config(config)
|
kwargs = _standardize_and_copy_config(config)
|
||||||
@ -3899,12 +3896,12 @@ class IdentityCategoricalColumn(
|
|||||||
"""See 'FeatureColumn` base class."""
|
"""See 'FeatureColumn` base class."""
|
||||||
return [self.key]
|
return [self.key]
|
||||||
|
|
||||||
def _get_config(self):
|
def get_config(self):
|
||||||
"""See 'FeatureColumn` base class."""
|
"""See 'FeatureColumn` base class."""
|
||||||
return dict(zip(self._fields, self))
|
return dict(zip(self._fields, self))
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _from_config(cls, config, custom_objects=None, columns_by_name=None):
|
def from_config(cls, config, custom_objects=None, columns_by_name=None):
|
||||||
"""See 'FeatureColumn` base class."""
|
"""See 'FeatureColumn` base class."""
|
||||||
_check_config_keys(config, cls._fields)
|
_check_config_keys(config, cls._fields)
|
||||||
kwargs = _standardize_and_copy_config(config)
|
kwargs = _standardize_and_copy_config(config)
|
||||||
@ -4013,7 +4010,7 @@ class WeightedCategoricalColumn(
|
|||||||
"""See 'FeatureColumn` base class."""
|
"""See 'FeatureColumn` base class."""
|
||||||
return [self.categorical_column, self.weight_feature_key]
|
return [self.categorical_column, self.weight_feature_key]
|
||||||
|
|
||||||
def _get_config(self):
|
def get_config(self):
|
||||||
"""See 'FeatureColumn` base class."""
|
"""See 'FeatureColumn` base class."""
|
||||||
from tensorflow.python.feature_column.serialization import serialize_feature_column # pylint: disable=g-import-not-at-top
|
from tensorflow.python.feature_column.serialization import serialize_feature_column # pylint: disable=g-import-not-at-top
|
||||||
config = dict(zip(self._fields, self))
|
config = dict(zip(self._fields, self))
|
||||||
@ -4023,7 +4020,7 @@ class WeightedCategoricalColumn(
|
|||||||
return config
|
return config
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _from_config(cls, config, custom_objects=None, columns_by_name=None):
|
def from_config(cls, config, custom_objects=None, columns_by_name=None):
|
||||||
"""See 'FeatureColumn` base class."""
|
"""See 'FeatureColumn` base class."""
|
||||||
from tensorflow.python.feature_column.serialization import deserialize_feature_column # pylint: disable=g-import-not-at-top
|
from tensorflow.python.feature_column.serialization import deserialize_feature_column # pylint: disable=g-import-not-at-top
|
||||||
_check_config_keys(config, cls._fields)
|
_check_config_keys(config, cls._fields)
|
||||||
@ -4157,7 +4154,7 @@ class CrossedColumn(
|
|||||||
"""See 'FeatureColumn` base class."""
|
"""See 'FeatureColumn` base class."""
|
||||||
return list(self.keys)
|
return list(self.keys)
|
||||||
|
|
||||||
def _get_config(self):
|
def get_config(self):
|
||||||
"""See 'FeatureColumn` base class."""
|
"""See 'FeatureColumn` base class."""
|
||||||
from tensorflow.python.feature_column.serialization import serialize_feature_column # pylint: disable=g-import-not-at-top
|
from tensorflow.python.feature_column.serialization import serialize_feature_column # pylint: disable=g-import-not-at-top
|
||||||
config = dict(zip(self._fields, self))
|
config = dict(zip(self._fields, self))
|
||||||
@ -4165,7 +4162,7 @@ class CrossedColumn(
|
|||||||
return config
|
return config
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _from_config(cls, config, custom_objects=None, columns_by_name=None):
|
def from_config(cls, config, custom_objects=None, columns_by_name=None):
|
||||||
"""See 'FeatureColumn` base class."""
|
"""See 'FeatureColumn` base class."""
|
||||||
from tensorflow.python.feature_column.serialization import deserialize_feature_column # pylint: disable=g-import-not-at-top
|
from tensorflow.python.feature_column.serialization import deserialize_feature_column # pylint: disable=g-import-not-at-top
|
||||||
_check_config_keys(config, cls._fields)
|
_check_config_keys(config, cls._fields)
|
||||||
@ -4427,7 +4424,7 @@ class IndicatorColumn(
|
|||||||
"""See 'FeatureColumn` base class."""
|
"""See 'FeatureColumn` base class."""
|
||||||
return [self.categorical_column]
|
return [self.categorical_column]
|
||||||
|
|
||||||
def _get_config(self):
|
def get_config(self):
|
||||||
"""See 'FeatureColumn` base class."""
|
"""See 'FeatureColumn` base class."""
|
||||||
from tensorflow.python.feature_column.serialization import serialize_feature_column # pylint: disable=g-import-not-at-top
|
from tensorflow.python.feature_column.serialization import serialize_feature_column # pylint: disable=g-import-not-at-top
|
||||||
config = dict(zip(self._fields, self))
|
config = dict(zip(self._fields, self))
|
||||||
@ -4436,7 +4433,7 @@ class IndicatorColumn(
|
|||||||
return config
|
return config
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _from_config(cls, config, custom_objects=None, columns_by_name=None):
|
def from_config(cls, config, custom_objects=None, columns_by_name=None):
|
||||||
"""See 'FeatureColumn` base class."""
|
"""See 'FeatureColumn` base class."""
|
||||||
from tensorflow.python.feature_column.serialization import deserialize_feature_column # pylint: disable=g-import-not-at-top
|
from tensorflow.python.feature_column.serialization import deserialize_feature_column # pylint: disable=g-import-not-at-top
|
||||||
_check_config_keys(config, cls._fields)
|
_check_config_keys(config, cls._fields)
|
||||||
@ -4573,7 +4570,7 @@ class SequenceCategoricalColumn(
|
|||||||
"""See 'FeatureColumn` base class."""
|
"""See 'FeatureColumn` base class."""
|
||||||
return [self.categorical_column]
|
return [self.categorical_column]
|
||||||
|
|
||||||
def _get_config(self):
|
def get_config(self):
|
||||||
"""See 'FeatureColumn` base class."""
|
"""See 'FeatureColumn` base class."""
|
||||||
from tensorflow.python.feature_column.serialization import serialize_feature_column # pylint: disable=g-import-not-at-top
|
from tensorflow.python.feature_column.serialization import serialize_feature_column # pylint: disable=g-import-not-at-top
|
||||||
config = dict(zip(self._fields, self))
|
config = dict(zip(self._fields, self))
|
||||||
@ -4582,7 +4579,7 @@ class SequenceCategoricalColumn(
|
|||||||
return config
|
return config
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _from_config(cls, config, custom_objects=None, columns_by_name=None):
|
def from_config(cls, config, custom_objects=None, columns_by_name=None):
|
||||||
"""See 'FeatureColumn` base class."""
|
"""See 'FeatureColumn` base class."""
|
||||||
from tensorflow.python.feature_column.serialization import deserialize_feature_column # pylint: disable=g-import-not-at-top
|
from tensorflow.python.feature_column.serialization import deserialize_feature_column # pylint: disable=g-import-not-at-top
|
||||||
_check_config_keys(config, cls._fields)
|
_check_config_keys(config, cls._fields)
|
||||||
|
@ -81,10 +81,10 @@ class BaseFeatureColumnForTests(fc.FeatureColumn):
|
|||||||
raise ValueError('Should not use this method.')
|
raise ValueError('Should not use this method.')
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _from_config(cls, config, custom_objects=None, columns_by_name=None):
|
def from_config(cls, config, custom_objects=None, columns_by_name=None):
|
||||||
raise ValueError('Should not use this method.')
|
raise ValueError('Should not use this method.')
|
||||||
|
|
||||||
def _get_config(self):
|
def get_config(self):
|
||||||
raise ValueError('Should not use this method.')
|
raise ValueError('Should not use this method.')
|
||||||
|
|
||||||
|
|
||||||
@ -478,7 +478,7 @@ class NumericColumnTest(test.TestCase):
|
|||||||
price = fc.numeric_column('price', normalizer_fn=_increment_two)
|
price = fc.numeric_column('price', normalizer_fn=_increment_two)
|
||||||
self.assertEqual(['price'], price.parents)
|
self.assertEqual(['price'], price.parents)
|
||||||
|
|
||||||
config = price._get_config()
|
config = price.get_config()
|
||||||
self.assertEqual({
|
self.assertEqual({
|
||||||
'key': 'price',
|
'key': 'price',
|
||||||
'shape': (1,),
|
'shape': (1,),
|
||||||
@ -487,7 +487,7 @@ class NumericColumnTest(test.TestCase):
|
|||||||
'normalizer_fn': '_increment_two'
|
'normalizer_fn': '_increment_two'
|
||||||
}, config)
|
}, config)
|
||||||
|
|
||||||
new_col = fc.NumericColumn._from_config(
|
new_col = fc.NumericColumn.from_config(
|
||||||
config, custom_objects={'_increment_two': _increment_two})
|
config, custom_objects={'_increment_two': _increment_two})
|
||||||
self.assertEqual(price, new_col)
|
self.assertEqual(price, new_col)
|
||||||
self.assertEqual(new_col.shape, (1,))
|
self.assertEqual(new_col.shape, (1,))
|
||||||
@ -833,7 +833,7 @@ class BucketizedColumnTest(test.TestCase):
|
|||||||
bucketized_price = fc.bucketized_column(price, boundaries=[0, 2, 4, 6])
|
bucketized_price = fc.bucketized_column(price, boundaries=[0, 2, 4, 6])
|
||||||
self.assertEqual([price], bucketized_price.parents)
|
self.assertEqual([price], bucketized_price.parents)
|
||||||
|
|
||||||
config = bucketized_price._get_config()
|
config = bucketized_price.get_config()
|
||||||
self.assertEqual({
|
self.assertEqual({
|
||||||
'source_column': {
|
'source_column': {
|
||||||
'class_name': 'NumericColumn',
|
'class_name': 'NumericColumn',
|
||||||
@ -848,11 +848,11 @@ class BucketizedColumnTest(test.TestCase):
|
|||||||
'boundaries': (0, 2, 4, 6)
|
'boundaries': (0, 2, 4, 6)
|
||||||
}, config)
|
}, config)
|
||||||
|
|
||||||
new_bucketized_price = fc.BucketizedColumn._from_config(config)
|
new_bucketized_price = fc.BucketizedColumn.from_config(config)
|
||||||
self.assertEqual(bucketized_price, new_bucketized_price)
|
self.assertEqual(bucketized_price, new_bucketized_price)
|
||||||
self.assertIsNot(price, new_bucketized_price.source_column)
|
self.assertIsNot(price, new_bucketized_price.source_column)
|
||||||
|
|
||||||
new_bucketized_price = fc.BucketizedColumn._from_config(
|
new_bucketized_price = fc.BucketizedColumn.from_config(
|
||||||
config,
|
config,
|
||||||
columns_by_name={
|
columns_by_name={
|
||||||
serialization._column_name_with_class_name(price): price
|
serialization._column_name_with_class_name(price): price
|
||||||
@ -1106,7 +1106,7 @@ class HashedCategoricalColumnTest(test.TestCase):
|
|||||||
wire_column = fc.categorical_column_with_hash_bucket('wire', 4)
|
wire_column = fc.categorical_column_with_hash_bucket('wire', 4)
|
||||||
self.assertEqual(['wire'], wire_column.parents)
|
self.assertEqual(['wire'], wire_column.parents)
|
||||||
|
|
||||||
config = wire_column._get_config()
|
config = wire_column.get_config()
|
||||||
self.assertEqual({
|
self.assertEqual({
|
||||||
'key': 'wire',
|
'key': 'wire',
|
||||||
'hash_bucket_size': 4,
|
'hash_bucket_size': 4,
|
||||||
@ -1114,7 +1114,7 @@ class HashedCategoricalColumnTest(test.TestCase):
|
|||||||
}, config)
|
}, config)
|
||||||
|
|
||||||
self.assertEqual(wire_column,
|
self.assertEqual(wire_column,
|
||||||
fc.HashedCategoricalColumn._from_config(config))
|
fc.HashedCategoricalColumn.from_config(config))
|
||||||
|
|
||||||
|
|
||||||
class CrossedColumnTest(test.TestCase):
|
class CrossedColumnTest(test.TestCase):
|
||||||
@ -1588,7 +1588,7 @@ class CrossedColumnTest(test.TestCase):
|
|||||||
|
|
||||||
self.assertEqual([b, 'c'], crossed.parents)
|
self.assertEqual([b, 'c'], crossed.parents)
|
||||||
|
|
||||||
config = crossed._get_config()
|
config = crossed.get_config()
|
||||||
self.assertEqual({
|
self.assertEqual({
|
||||||
'hash_bucket_size':
|
'hash_bucket_size':
|
||||||
5,
|
5,
|
||||||
@ -1612,11 +1612,11 @@ class CrossedColumnTest(test.TestCase):
|
|||||||
}, 'c')
|
}, 'c')
|
||||||
}, config)
|
}, config)
|
||||||
|
|
||||||
new_crossed = fc.CrossedColumn._from_config(config)
|
new_crossed = fc.CrossedColumn.from_config(config)
|
||||||
self.assertEqual(crossed, new_crossed)
|
self.assertEqual(crossed, new_crossed)
|
||||||
self.assertIsNot(b, new_crossed.keys[0])
|
self.assertIsNot(b, new_crossed.keys[0])
|
||||||
|
|
||||||
new_crossed = fc.CrossedColumn._from_config(
|
new_crossed = fc.CrossedColumn.from_config(
|
||||||
config,
|
config,
|
||||||
columns_by_name={serialization._column_name_with_class_name(b): b})
|
columns_by_name={serialization._column_name_with_class_name(b): b})
|
||||||
self.assertEqual(crossed, new_crossed)
|
self.assertEqual(crossed, new_crossed)
|
||||||
@ -4396,7 +4396,7 @@ class VocabularyFileCategoricalColumnTest(test.TestCase):
|
|||||||
|
|
||||||
self.assertEqual(['wire'], wire_column.parents)
|
self.assertEqual(['wire'], wire_column.parents)
|
||||||
|
|
||||||
config = wire_column._get_config()
|
config = wire_column.get_config()
|
||||||
self.assertEqual({
|
self.assertEqual({
|
||||||
'default_value': -1,
|
'default_value': -1,
|
||||||
'dtype': 'string',
|
'dtype': 'string',
|
||||||
@ -4407,7 +4407,7 @@ class VocabularyFileCategoricalColumnTest(test.TestCase):
|
|||||||
}, config)
|
}, config)
|
||||||
|
|
||||||
self.assertEqual(wire_column,
|
self.assertEqual(wire_column,
|
||||||
fc.VocabularyFileCategoricalColumn._from_config(config))
|
fc.VocabularyFileCategoricalColumn.from_config(config))
|
||||||
|
|
||||||
|
|
||||||
class VocabularyListCategoricalColumnTest(test.TestCase):
|
class VocabularyListCategoricalColumnTest(test.TestCase):
|
||||||
@ -4859,7 +4859,7 @@ class VocabularyListCategoricalColumnTest(test.TestCase):
|
|||||||
|
|
||||||
self.assertEqual(['aaa'], wire_column.parents)
|
self.assertEqual(['aaa'], wire_column.parents)
|
||||||
|
|
||||||
config = wire_column._get_config()
|
config = wire_column.get_config()
|
||||||
self.assertEqual({
|
self.assertEqual({
|
||||||
'default_value': -1,
|
'default_value': -1,
|
||||||
'dtype': 'string',
|
'dtype': 'string',
|
||||||
@ -4869,7 +4869,7 @@ class VocabularyListCategoricalColumnTest(test.TestCase):
|
|||||||
}, config)
|
}, config)
|
||||||
|
|
||||||
self.assertEqual(wire_column,
|
self.assertEqual(wire_column,
|
||||||
fc.VocabularyListCategoricalColumn._from_config(config))
|
fc.VocabularyListCategoricalColumn.from_config(config))
|
||||||
|
|
||||||
|
|
||||||
class IdentityCategoricalColumnTest(test.TestCase):
|
class IdentityCategoricalColumnTest(test.TestCase):
|
||||||
@ -5218,14 +5218,14 @@ class IdentityCategoricalColumnTest(test.TestCase):
|
|||||||
|
|
||||||
self.assertEqual(['aaa'], column.parents)
|
self.assertEqual(['aaa'], column.parents)
|
||||||
|
|
||||||
config = column._get_config()
|
config = column.get_config()
|
||||||
self.assertEqual({
|
self.assertEqual({
|
||||||
'default_value': None,
|
'default_value': None,
|
||||||
'key': 'aaa',
|
'key': 'aaa',
|
||||||
'number_buckets': 3
|
'number_buckets': 3
|
||||||
}, config)
|
}, config)
|
||||||
|
|
||||||
self.assertEqual(column, fc.IdentityCategoricalColumn._from_config(config))
|
self.assertEqual(column, fc.IdentityCategoricalColumn.from_config(config))
|
||||||
|
|
||||||
|
|
||||||
class TransformFeaturesTest(test.TestCase):
|
class TransformFeaturesTest(test.TestCase):
|
||||||
@ -5600,7 +5600,7 @@ class IndicatorColumnTest(test.TestCase):
|
|||||||
|
|
||||||
self.assertEqual([parent], animal.parents)
|
self.assertEqual([parent], animal.parents)
|
||||||
|
|
||||||
config = animal._get_config()
|
config = animal.get_config()
|
||||||
self.assertEqual({
|
self.assertEqual({
|
||||||
'categorical_column': {
|
'categorical_column': {
|
||||||
'class_name': 'IdentityCategoricalColumn',
|
'class_name': 'IdentityCategoricalColumn',
|
||||||
@ -5612,11 +5612,11 @@ class IndicatorColumnTest(test.TestCase):
|
|||||||
}
|
}
|
||||||
}, config)
|
}, config)
|
||||||
|
|
||||||
new_animal = fc.IndicatorColumn._from_config(config)
|
new_animal = fc.IndicatorColumn.from_config(config)
|
||||||
self.assertEqual(animal, new_animal)
|
self.assertEqual(animal, new_animal)
|
||||||
self.assertIsNot(parent, new_animal.categorical_column)
|
self.assertIsNot(parent, new_animal.categorical_column)
|
||||||
|
|
||||||
new_animal = fc.IndicatorColumn._from_config(
|
new_animal = fc.IndicatorColumn.from_config(
|
||||||
config,
|
config,
|
||||||
columns_by_name={
|
columns_by_name={
|
||||||
serialization._column_name_with_class_name(parent): parent
|
serialization._column_name_with_class_name(parent): parent
|
||||||
@ -6605,7 +6605,7 @@ class EmbeddingColumnTest(test.TestCase):
|
|||||||
|
|
||||||
self.assertEqual([categorical_column], embedding_column.parents)
|
self.assertEqual([categorical_column], embedding_column.parents)
|
||||||
|
|
||||||
config = embedding_column._get_config()
|
config = embedding_column.get_config()
|
||||||
self.assertEqual({
|
self.assertEqual({
|
||||||
'categorical_column': {
|
'categorical_column': {
|
||||||
'class_name': 'IdentityCategoricalColumn',
|
'class_name': 'IdentityCategoricalColumn',
|
||||||
@ -6633,22 +6633,22 @@ class EmbeddingColumnTest(test.TestCase):
|
|||||||
}, config)
|
}, config)
|
||||||
|
|
||||||
custom_objects = {'TruncatedNormal': init_ops.TruncatedNormal}
|
custom_objects = {'TruncatedNormal': init_ops.TruncatedNormal}
|
||||||
new_embedding_column = fc.EmbeddingColumn._from_config(
|
new_embedding_column = fc.EmbeddingColumn.from_config(
|
||||||
config, custom_objects=custom_objects)
|
config, custom_objects=custom_objects)
|
||||||
self.assertEqual(embedding_column._get_config(),
|
self.assertEqual(embedding_column.get_config(),
|
||||||
new_embedding_column._get_config())
|
new_embedding_column.get_config())
|
||||||
self.assertIsNot(categorical_column,
|
self.assertIsNot(categorical_column,
|
||||||
new_embedding_column.categorical_column)
|
new_embedding_column.categorical_column)
|
||||||
|
|
||||||
new_embedding_column = fc.EmbeddingColumn._from_config(
|
new_embedding_column = fc.EmbeddingColumn.from_config(
|
||||||
config,
|
config,
|
||||||
custom_objects=custom_objects,
|
custom_objects=custom_objects,
|
||||||
columns_by_name={
|
columns_by_name={
|
||||||
serialization._column_name_with_class_name(categorical_column):
|
serialization._column_name_with_class_name(categorical_column):
|
||||||
categorical_column
|
categorical_column
|
||||||
})
|
})
|
||||||
self.assertEqual(embedding_column._get_config(),
|
self.assertEqual(embedding_column.get_config(),
|
||||||
new_embedding_column._get_config())
|
new_embedding_column.get_config())
|
||||||
self.assertIs(categorical_column, new_embedding_column.categorical_column)
|
self.assertIs(categorical_column, new_embedding_column.categorical_column)
|
||||||
|
|
||||||
@test_util.run_deprecated_v1
|
@test_util.run_deprecated_v1
|
||||||
@ -6666,7 +6666,7 @@ class EmbeddingColumnTest(test.TestCase):
|
|||||||
|
|
||||||
self.assertEqual([categorical_column], embedding_column.parents)
|
self.assertEqual([categorical_column], embedding_column.parents)
|
||||||
|
|
||||||
config = embedding_column._get_config()
|
config = embedding_column.get_config()
|
||||||
self.assertEqual({
|
self.assertEqual({
|
||||||
'categorical_column': {
|
'categorical_column': {
|
||||||
'class_name': 'IdentityCategoricalColumn',
|
'class_name': 'IdentityCategoricalColumn',
|
||||||
@ -6689,13 +6689,13 @@ class EmbeddingColumnTest(test.TestCase):
|
|||||||
'_initializer': _initializer,
|
'_initializer': _initializer,
|
||||||
}
|
}
|
||||||
|
|
||||||
new_embedding_column = fc.EmbeddingColumn._from_config(
|
new_embedding_column = fc.EmbeddingColumn.from_config(
|
||||||
config, custom_objects=custom_objects)
|
config, custom_objects=custom_objects)
|
||||||
self.assertEqual(embedding_column, new_embedding_column)
|
self.assertEqual(embedding_column, new_embedding_column)
|
||||||
self.assertIsNot(categorical_column,
|
self.assertIsNot(categorical_column,
|
||||||
new_embedding_column.categorical_column)
|
new_embedding_column.categorical_column)
|
||||||
|
|
||||||
new_embedding_column = fc.EmbeddingColumn._from_config(
|
new_embedding_column = fc.EmbeddingColumn.from_config(
|
||||||
config,
|
config,
|
||||||
custom_objects=custom_objects,
|
custom_objects=custom_objects,
|
||||||
columns_by_name={
|
columns_by_name={
|
||||||
@ -7763,7 +7763,7 @@ class WeightedCategoricalColumnTest(test.TestCase):
|
|||||||
|
|
||||||
self.assertEqual([categorical_column, 'weight'], column.parents)
|
self.assertEqual([categorical_column, 'weight'], column.parents)
|
||||||
|
|
||||||
config = column._get_config()
|
config = column.get_config()
|
||||||
self.assertEqual({
|
self.assertEqual({
|
||||||
'categorical_column': {
|
'categorical_column': {
|
||||||
'config': {
|
'config': {
|
||||||
@ -7777,9 +7777,9 @@ class WeightedCategoricalColumnTest(test.TestCase):
|
|||||||
'weight_feature_key': 'weight'
|
'weight_feature_key': 'weight'
|
||||||
}, config)
|
}, config)
|
||||||
|
|
||||||
self.assertEqual(column, fc.WeightedCategoricalColumn._from_config(config))
|
self.assertEqual(column, fc.WeightedCategoricalColumn.from_config(config))
|
||||||
|
|
||||||
new_column = fc.WeightedCategoricalColumn._from_config(
|
new_column = fc.WeightedCategoricalColumn.from_config(
|
||||||
config,
|
config,
|
||||||
columns_by_name={
|
columns_by_name={
|
||||||
serialization._column_name_with_class_name(categorical_column):
|
serialization._column_name_with_class_name(categorical_column):
|
||||||
|
@ -582,7 +582,7 @@ class SequenceNumericColumn(
|
|||||||
"""See 'FeatureColumn` base class."""
|
"""See 'FeatureColumn` base class."""
|
||||||
return [self.key]
|
return [self.key]
|
||||||
|
|
||||||
def _get_config(self):
|
def get_config(self):
|
||||||
"""See 'FeatureColumn` base class."""
|
"""See 'FeatureColumn` base class."""
|
||||||
config = dict(zip(self._fields, self))
|
config = dict(zip(self._fields, self))
|
||||||
config['normalizer_fn'] = utils.serialize_keras_object(self.normalizer_fn)
|
config['normalizer_fn'] = utils.serialize_keras_object(self.normalizer_fn)
|
||||||
@ -590,7 +590,7 @@ class SequenceNumericColumn(
|
|||||||
return config
|
return config
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _from_config(cls, config, custom_objects=None, columns_by_name=None):
|
def from_config(cls, config, custom_objects=None, columns_by_name=None):
|
||||||
"""See 'FeatureColumn` base class."""
|
"""See 'FeatureColumn` base class."""
|
||||||
fc._check_config_keys(config, cls._fields)
|
fc._check_config_keys(config, cls._fields)
|
||||||
kwargs = fc._standardize_and_copy_config(config)
|
kwargs = fc._standardize_and_copy_config(config)
|
||||||
|
@ -765,7 +765,7 @@ class SequenceCategoricalColumnWithIdentityTest(
|
|||||||
'animal', num_buckets=4)
|
'animal', num_buckets=4)
|
||||||
animal = fc.indicator_column(parent)
|
animal = fc.indicator_column(parent)
|
||||||
|
|
||||||
config = animal._get_config()
|
config = animal.get_config()
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
{
|
{
|
||||||
'categorical_column': {
|
'categorical_column': {
|
||||||
@ -783,11 +783,11 @@ class SequenceCategoricalColumnWithIdentityTest(
|
|||||||
}
|
}
|
||||||
}, config)
|
}, config)
|
||||||
|
|
||||||
new_animal = fc.IndicatorColumn._from_config(config)
|
new_animal = fc.IndicatorColumn.from_config(config)
|
||||||
self.assertEqual(animal, new_animal)
|
self.assertEqual(animal, new_animal)
|
||||||
self.assertIsNot(parent, new_animal.categorical_column)
|
self.assertIsNot(parent, new_animal.categorical_column)
|
||||||
|
|
||||||
new_animal = fc.IndicatorColumn._from_config(
|
new_animal = fc.IndicatorColumn.from_config(
|
||||||
config,
|
config,
|
||||||
columns_by_name={
|
columns_by_name={
|
||||||
serialization._column_name_with_class_name(parent): parent
|
serialization._column_name_with_class_name(parent): parent
|
||||||
|
@ -45,14 +45,14 @@ def serialize_feature_column(fc):
|
|||||||
"""Serializes a FeatureColumn or a raw string key.
|
"""Serializes a FeatureColumn or a raw string key.
|
||||||
|
|
||||||
This method should only be used to serialize parent FeatureColumns when
|
This method should only be used to serialize parent FeatureColumns when
|
||||||
implementing FeatureColumn._get_config(), else serialize_feature_columns()
|
implementing FeatureColumn.get_config(), else serialize_feature_columns()
|
||||||
is preferable.
|
is preferable.
|
||||||
|
|
||||||
This serialization also keeps information of the FeatureColumn class, so
|
This serialization also keeps information of the FeatureColumn class, so
|
||||||
deserialization is possible without knowing the class type. For example:
|
deserialization is possible without knowing the class type. For example:
|
||||||
|
|
||||||
a = numeric_column('x')
|
a = numeric_column('x')
|
||||||
a._get_config() gives:
|
a.get_config() gives:
|
||||||
{
|
{
|
||||||
'key': 'price',
|
'key': 'price',
|
||||||
'shape': (1,),
|
'shape': (1,),
|
||||||
@ -85,7 +85,7 @@ def serialize_feature_column(fc):
|
|||||||
return fc
|
return fc
|
||||||
elif isinstance(fc, fc_lib.FeatureColumn):
|
elif isinstance(fc, fc_lib.FeatureColumn):
|
||||||
return generic_utils.serialize_keras_class_and_config(
|
return generic_utils.serialize_keras_class_and_config(
|
||||||
fc.__class__.__name__, fc._get_config()) # pylint: disable=protected-access
|
fc.__class__.__name__, fc.get_config()) # pylint: disable=protected-access
|
||||||
else:
|
else:
|
||||||
raise ValueError('Instance: {} is not a FeatureColumn'.format(fc))
|
raise ValueError('Instance: {} is not a FeatureColumn'.format(fc))
|
||||||
|
|
||||||
@ -96,7 +96,7 @@ def deserialize_feature_column(config,
|
|||||||
"""Deserializes a `config` generated with `serialize_feature_column`.
|
"""Deserializes a `config` generated with `serialize_feature_column`.
|
||||||
|
|
||||||
This method should only be used to deserialize parent FeatureColumns when
|
This method should only be used to deserialize parent FeatureColumns when
|
||||||
implementing FeatureColumn._from_config(), else deserialize_feature_columns()
|
implementing FeatureColumn.from_config(), else deserialize_feature_columns()
|
||||||
is preferable. Returns a FeatureColumn for this config.
|
is preferable. Returns a FeatureColumn for this config.
|
||||||
TODO(b/118939620): Simplify code if Keras utils support object deduping.
|
TODO(b/118939620): Simplify code if Keras utils support object deduping.
|
||||||
|
|
||||||
@ -136,7 +136,7 @@ def deserialize_feature_column(config,
|
|||||||
'Expected FeatureColumn class, instead found: {}'.format(cls))
|
'Expected FeatureColumn class, instead found: {}'.format(cls))
|
||||||
|
|
||||||
# Always deserialize the FeatureColumn, in order to get the name.
|
# Always deserialize the FeatureColumn, in order to get the name.
|
||||||
new_instance = cls._from_config( # pylint: disable=protected-access
|
new_instance = cls.from_config( # pylint: disable=protected-access
|
||||||
cls_config,
|
cls_config,
|
||||||
custom_objects=custom_objects,
|
custom_objects=custom_objects,
|
||||||
columns_by_name=columns_by_name)
|
columns_by_name=columns_by_name)
|
||||||
|
Loading…
Reference in New Issue
Block a user