diff --git a/tensorflow/python/feature_column/feature_column_v2.py b/tensorflow/python/feature_column/feature_column_v2.py index 6179658bfc8..48c97e3b69e 100644 --- a/tensorflow/python/feature_column/feature_column_v2.py +++ b/tensorflow/python/feature_column/feature_column_v2.py @@ -2265,8 +2265,7 @@ class FeatureColumn(object): """ pass - @abc.abstractmethod - def _get_config(self): + def get_config(self): """Returns the config of the feature column. A FeatureColumn config is a Python dictionary (serializable) containing the @@ -2283,7 +2282,7 @@ class FeatureColumn(object): 'SerializationExampleFeatureColumn', ('dimension', 'parent', 'dtype', 'normalizer_fn'))): - def _get_config(self): + def get_config(self): # Create a dict from the namedtuple. # Python attribute literals can be directly copied from / to the config. # For example 'dimension', assuming it is an integer literal. @@ -2304,8 +2303,8 @@ class FeatureColumn(object): return config @classmethod - def _from_config(cls, config, custom_objects=None, columns_by_name=None): - # This should do the inverse transform from `_get_config` and construct + def from_config(cls, config, custom_objects=None, columns_by_name=None): + # This should do the inverse transform from `get_config` and construct # the namedtuple. kwargs = config.copy() kwargs['parent'] = deserialize_feature_column( @@ -2320,21 +2319,24 @@ class FeatureColumn(object): A serializable Dict that can be used to deserialize the object with from_config. """ - pass + return self._get_config() + + def _get_config(self): + raise NotImplementedError('Must be implemented in subclasses.') @classmethod - def _from_config(cls, config, custom_objects=None, columns_by_name=None): + def from_config(cls, config, custom_objects=None, columns_by_name=None): """Creates a FeatureColumn from its config. - This method should be the reverse of `_get_config`, capable of instantiating - the same FeatureColumn from the config dictionary. See `_get_config` for an + This method should be the reverse of `get_config`, capable of instantiating + the same FeatureColumn from the config dictionary. See `get_config` for an example of common (de)serialization practices followed in this file. TODO(b/118939620): This is a private method until consensus is reached on supporting object deserialization deduping within Keras. Args: - config: A Dict config acquired with `_get_config`. + config: A Dict config acquired with `get_config`. custom_objects: Optional dictionary mapping names (strings) to custom classes or functions to be considered during deserialization. columns_by_name: A Dict[String, FeatureColumn] of existing columns in @@ -2344,7 +2346,11 @@ class FeatureColumn(object): Returns: A FeatureColumn for the input config. """ - pass + return cls._from_config(config, custom_objects, columns_by_name) + + @classmethod + def _from_config(cls, config, custom_objects=None, columns_by_name=None): + raise NotImplementedError('Must be implemented in subclasses.') class DenseColumn(FeatureColumn): @@ -2857,7 +2863,7 @@ class NumericColumn( """See 'FeatureColumn` base class.""" return [self.key] - def _get_config(self): + def get_config(self): """See 'FeatureColumn` base class.""" config = dict(zip(self._fields, self)) config['normalizer_fn'] = generic_utils.serialize_keras_object( @@ -2866,7 +2872,7 @@ class NumericColumn( return config @classmethod - def _from_config(cls, config, custom_objects=None, columns_by_name=None): + def from_config(cls, config, custom_objects=None, columns_by_name=None): """See 'FeatureColumn` base class.""" _check_config_keys(config, cls._fields) kwargs = _standardize_and_copy_config(config) @@ -3014,7 +3020,7 @@ class BucketizedColumn( """See 'FeatureColumn` base class.""" return [self.source_column] - def _get_config(self): + def get_config(self): """See 'FeatureColumn` base class.""" from tensorflow.python.feature_column.serialization import serialize_feature_column # pylint: disable=g-import-not-at-top config = dict(zip(self._fields, self)) @@ -3022,7 +3028,7 @@ class BucketizedColumn( return config @classmethod - def _from_config(cls, config, custom_objects=None, columns_by_name=None): + def from_config(cls, config, custom_objects=None, columns_by_name=None): """See 'FeatureColumn` base class.""" from tensorflow.python.feature_column.serialization import deserialize_feature_column # pylint: disable=g-import-not-at-top _check_config_keys(config, cls._fields) @@ -3247,7 +3253,7 @@ class EmbeddingColumn( """See 'FeatureColumn` base class.""" return [self.categorical_column] - def _get_config(self): + def get_config(self): """See 'FeatureColumn` base class.""" from tensorflow.python.feature_column.serialization import serialize_feature_column # pylint: disable=g-import-not-at-top config = dict(zip(self._fields, self)) @@ -3257,7 +3263,7 @@ class EmbeddingColumn( return config @classmethod - def _from_config(cls, config, custom_objects=None, columns_by_name=None): + def from_config(cls, config, custom_objects=None, columns_by_name=None): """See 'FeatureColumn` base class.""" from tensorflow.python.feature_column.serialization import deserialize_feature_column # pylint: disable=g-import-not-at-top _check_config_keys(config, cls._fields) @@ -3440,15 +3446,6 @@ class SharedEmbeddingColumn( """See 'FeatureColumn` base class.""" return [self.categorical_column] - def _get_config(self): - """See 'FeatureColumn` base class.""" - raise NotImplementedError() - - @classmethod - def _from_config(cls, config, custom_objects=None, columns_by_name=None): - """See 'FeatureColumn` base class.""" - raise NotImplementedError() - def _check_shape(shape, key): """Returns shape if it's valid, raises error otherwise.""" @@ -3559,14 +3556,14 @@ class HashedCategoricalColumn( """See 'FeatureColumn` base class.""" return [self.key] - def _get_config(self): + def get_config(self): """See 'FeatureColumn` base class.""" config = dict(zip(self._fields, self)) config['dtype'] = self.dtype.name return config @classmethod - def _from_config(cls, config, custom_objects=None, columns_by_name=None): + def from_config(cls, config, custom_objects=None, columns_by_name=None): """See 'FeatureColumn` base class.""" _check_config_keys(config, cls._fields) kwargs = _standardize_and_copy_config(config) @@ -3673,14 +3670,14 @@ class VocabularyFileCategoricalColumn( """See 'FeatureColumn` base class.""" return [self.key] - def _get_config(self): + def get_config(self): """See 'FeatureColumn` base class.""" config = dict(zip(self._fields, self)) config['dtype'] = self.dtype.name return config @classmethod - def _from_config(cls, config, custom_objects=None, columns_by_name=None): + def from_config(cls, config, custom_objects=None, columns_by_name=None): """See 'FeatureColumn` base class.""" _check_config_keys(config, cls._fields) kwargs = _standardize_and_copy_config(config) @@ -3787,14 +3784,14 @@ class VocabularyListCategoricalColumn( """See 'FeatureColumn` base class.""" return [self.key] - def _get_config(self): + def get_config(self): """See 'FeatureColumn` base class.""" config = dict(zip(self._fields, self)) config['dtype'] = self.dtype.name return config @classmethod - def _from_config(cls, config, custom_objects=None, columns_by_name=None): + def from_config(cls, config, custom_objects=None, columns_by_name=None): """See 'FeatureColumn` base class.""" _check_config_keys(config, cls._fields) kwargs = _standardize_and_copy_config(config) @@ -3899,12 +3896,12 @@ class IdentityCategoricalColumn( """See 'FeatureColumn` base class.""" return [self.key] - def _get_config(self): + def get_config(self): """See 'FeatureColumn` base class.""" return dict(zip(self._fields, self)) @classmethod - def _from_config(cls, config, custom_objects=None, columns_by_name=None): + def from_config(cls, config, custom_objects=None, columns_by_name=None): """See 'FeatureColumn` base class.""" _check_config_keys(config, cls._fields) kwargs = _standardize_and_copy_config(config) @@ -4013,7 +4010,7 @@ class WeightedCategoricalColumn( """See 'FeatureColumn` base class.""" return [self.categorical_column, self.weight_feature_key] - def _get_config(self): + def get_config(self): """See 'FeatureColumn` base class.""" from tensorflow.python.feature_column.serialization import serialize_feature_column # pylint: disable=g-import-not-at-top config = dict(zip(self._fields, self)) @@ -4023,7 +4020,7 @@ class WeightedCategoricalColumn( return config @classmethod - def _from_config(cls, config, custom_objects=None, columns_by_name=None): + def from_config(cls, config, custom_objects=None, columns_by_name=None): """See 'FeatureColumn` base class.""" from tensorflow.python.feature_column.serialization import deserialize_feature_column # pylint: disable=g-import-not-at-top _check_config_keys(config, cls._fields) @@ -4157,7 +4154,7 @@ class CrossedColumn( """See 'FeatureColumn` base class.""" return list(self.keys) - def _get_config(self): + def get_config(self): """See 'FeatureColumn` base class.""" from tensorflow.python.feature_column.serialization import serialize_feature_column # pylint: disable=g-import-not-at-top config = dict(zip(self._fields, self)) @@ -4165,7 +4162,7 @@ class CrossedColumn( return config @classmethod - def _from_config(cls, config, custom_objects=None, columns_by_name=None): + def from_config(cls, config, custom_objects=None, columns_by_name=None): """See 'FeatureColumn` base class.""" from tensorflow.python.feature_column.serialization import deserialize_feature_column # pylint: disable=g-import-not-at-top _check_config_keys(config, cls._fields) @@ -4427,7 +4424,7 @@ class IndicatorColumn( """See 'FeatureColumn` base class.""" return [self.categorical_column] - def _get_config(self): + def get_config(self): """See 'FeatureColumn` base class.""" from tensorflow.python.feature_column.serialization import serialize_feature_column # pylint: disable=g-import-not-at-top config = dict(zip(self._fields, self)) @@ -4436,7 +4433,7 @@ class IndicatorColumn( return config @classmethod - def _from_config(cls, config, custom_objects=None, columns_by_name=None): + def from_config(cls, config, custom_objects=None, columns_by_name=None): """See 'FeatureColumn` base class.""" from tensorflow.python.feature_column.serialization import deserialize_feature_column # pylint: disable=g-import-not-at-top _check_config_keys(config, cls._fields) @@ -4573,7 +4570,7 @@ class SequenceCategoricalColumn( """See 'FeatureColumn` base class.""" return [self.categorical_column] - def _get_config(self): + def get_config(self): """See 'FeatureColumn` base class.""" from tensorflow.python.feature_column.serialization import serialize_feature_column # pylint: disable=g-import-not-at-top config = dict(zip(self._fields, self)) @@ -4582,7 +4579,7 @@ class SequenceCategoricalColumn( return config @classmethod - def _from_config(cls, config, custom_objects=None, columns_by_name=None): + def from_config(cls, config, custom_objects=None, columns_by_name=None): """See 'FeatureColumn` base class.""" from tensorflow.python.feature_column.serialization import deserialize_feature_column # pylint: disable=g-import-not-at-top _check_config_keys(config, cls._fields) diff --git a/tensorflow/python/feature_column/feature_column_v2_test.py b/tensorflow/python/feature_column/feature_column_v2_test.py index 8e46356ec55..85c21c0d38e 100644 --- a/tensorflow/python/feature_column/feature_column_v2_test.py +++ b/tensorflow/python/feature_column/feature_column_v2_test.py @@ -81,10 +81,10 @@ class BaseFeatureColumnForTests(fc.FeatureColumn): raise ValueError('Should not use this method.') @classmethod - def _from_config(cls, config, custom_objects=None, columns_by_name=None): + def from_config(cls, config, custom_objects=None, columns_by_name=None): raise ValueError('Should not use this method.') - def _get_config(self): + def get_config(self): raise ValueError('Should not use this method.') @@ -478,7 +478,7 @@ class NumericColumnTest(test.TestCase): price = fc.numeric_column('price', normalizer_fn=_increment_two) self.assertEqual(['price'], price.parents) - config = price._get_config() + config = price.get_config() self.assertEqual({ 'key': 'price', 'shape': (1,), @@ -487,7 +487,7 @@ class NumericColumnTest(test.TestCase): 'normalizer_fn': '_increment_two' }, config) - new_col = fc.NumericColumn._from_config( + new_col = fc.NumericColumn.from_config( config, custom_objects={'_increment_two': _increment_two}) self.assertEqual(price, new_col) self.assertEqual(new_col.shape, (1,)) @@ -833,7 +833,7 @@ class BucketizedColumnTest(test.TestCase): bucketized_price = fc.bucketized_column(price, boundaries=[0, 2, 4, 6]) self.assertEqual([price], bucketized_price.parents) - config = bucketized_price._get_config() + config = bucketized_price.get_config() self.assertEqual({ 'source_column': { 'class_name': 'NumericColumn', @@ -848,11 +848,11 @@ class BucketizedColumnTest(test.TestCase): 'boundaries': (0, 2, 4, 6) }, config) - new_bucketized_price = fc.BucketizedColumn._from_config(config) + new_bucketized_price = fc.BucketizedColumn.from_config(config) self.assertEqual(bucketized_price, new_bucketized_price) self.assertIsNot(price, new_bucketized_price.source_column) - new_bucketized_price = fc.BucketizedColumn._from_config( + new_bucketized_price = fc.BucketizedColumn.from_config( config, columns_by_name={ serialization._column_name_with_class_name(price): price @@ -1106,7 +1106,7 @@ class HashedCategoricalColumnTest(test.TestCase): wire_column = fc.categorical_column_with_hash_bucket('wire', 4) self.assertEqual(['wire'], wire_column.parents) - config = wire_column._get_config() + config = wire_column.get_config() self.assertEqual({ 'key': 'wire', 'hash_bucket_size': 4, @@ -1114,7 +1114,7 @@ class HashedCategoricalColumnTest(test.TestCase): }, config) self.assertEqual(wire_column, - fc.HashedCategoricalColumn._from_config(config)) + fc.HashedCategoricalColumn.from_config(config)) class CrossedColumnTest(test.TestCase): @@ -1588,7 +1588,7 @@ class CrossedColumnTest(test.TestCase): self.assertEqual([b, 'c'], crossed.parents) - config = crossed._get_config() + config = crossed.get_config() self.assertEqual({ 'hash_bucket_size': 5, @@ -1612,11 +1612,11 @@ class CrossedColumnTest(test.TestCase): }, 'c') }, config) - new_crossed = fc.CrossedColumn._from_config(config) + new_crossed = fc.CrossedColumn.from_config(config) self.assertEqual(crossed, new_crossed) self.assertIsNot(b, new_crossed.keys[0]) - new_crossed = fc.CrossedColumn._from_config( + new_crossed = fc.CrossedColumn.from_config( config, columns_by_name={serialization._column_name_with_class_name(b): b}) self.assertEqual(crossed, new_crossed) @@ -4396,7 +4396,7 @@ class VocabularyFileCategoricalColumnTest(test.TestCase): self.assertEqual(['wire'], wire_column.parents) - config = wire_column._get_config() + config = wire_column.get_config() self.assertEqual({ 'default_value': -1, 'dtype': 'string', @@ -4407,7 +4407,7 @@ class VocabularyFileCategoricalColumnTest(test.TestCase): }, config) self.assertEqual(wire_column, - fc.VocabularyFileCategoricalColumn._from_config(config)) + fc.VocabularyFileCategoricalColumn.from_config(config)) class VocabularyListCategoricalColumnTest(test.TestCase): @@ -4859,7 +4859,7 @@ class VocabularyListCategoricalColumnTest(test.TestCase): self.assertEqual(['aaa'], wire_column.parents) - config = wire_column._get_config() + config = wire_column.get_config() self.assertEqual({ 'default_value': -1, 'dtype': 'string', @@ -4869,7 +4869,7 @@ class VocabularyListCategoricalColumnTest(test.TestCase): }, config) self.assertEqual(wire_column, - fc.VocabularyListCategoricalColumn._from_config(config)) + fc.VocabularyListCategoricalColumn.from_config(config)) class IdentityCategoricalColumnTest(test.TestCase): @@ -5218,14 +5218,14 @@ class IdentityCategoricalColumnTest(test.TestCase): self.assertEqual(['aaa'], column.parents) - config = column._get_config() + config = column.get_config() self.assertEqual({ 'default_value': None, 'key': 'aaa', 'number_buckets': 3 }, config) - self.assertEqual(column, fc.IdentityCategoricalColumn._from_config(config)) + self.assertEqual(column, fc.IdentityCategoricalColumn.from_config(config)) class TransformFeaturesTest(test.TestCase): @@ -5600,7 +5600,7 @@ class IndicatorColumnTest(test.TestCase): self.assertEqual([parent], animal.parents) - config = animal._get_config() + config = animal.get_config() self.assertEqual({ 'categorical_column': { 'class_name': 'IdentityCategoricalColumn', @@ -5612,11 +5612,11 @@ class IndicatorColumnTest(test.TestCase): } }, config) - new_animal = fc.IndicatorColumn._from_config(config) + new_animal = fc.IndicatorColumn.from_config(config) self.assertEqual(animal, new_animal) self.assertIsNot(parent, new_animal.categorical_column) - new_animal = fc.IndicatorColumn._from_config( + new_animal = fc.IndicatorColumn.from_config( config, columns_by_name={ serialization._column_name_with_class_name(parent): parent @@ -6605,7 +6605,7 @@ class EmbeddingColumnTest(test.TestCase): self.assertEqual([categorical_column], embedding_column.parents) - config = embedding_column._get_config() + config = embedding_column.get_config() self.assertEqual({ 'categorical_column': { 'class_name': 'IdentityCategoricalColumn', @@ -6633,22 +6633,22 @@ class EmbeddingColumnTest(test.TestCase): }, config) custom_objects = {'TruncatedNormal': init_ops.TruncatedNormal} - new_embedding_column = fc.EmbeddingColumn._from_config( + new_embedding_column = fc.EmbeddingColumn.from_config( config, custom_objects=custom_objects) - self.assertEqual(embedding_column._get_config(), - new_embedding_column._get_config()) + self.assertEqual(embedding_column.get_config(), + new_embedding_column.get_config()) self.assertIsNot(categorical_column, new_embedding_column.categorical_column) - new_embedding_column = fc.EmbeddingColumn._from_config( + new_embedding_column = fc.EmbeddingColumn.from_config( config, custom_objects=custom_objects, columns_by_name={ serialization._column_name_with_class_name(categorical_column): categorical_column }) - self.assertEqual(embedding_column._get_config(), - new_embedding_column._get_config()) + self.assertEqual(embedding_column.get_config(), + new_embedding_column.get_config()) self.assertIs(categorical_column, new_embedding_column.categorical_column) @test_util.run_deprecated_v1 @@ -6666,7 +6666,7 @@ class EmbeddingColumnTest(test.TestCase): self.assertEqual([categorical_column], embedding_column.parents) - config = embedding_column._get_config() + config = embedding_column.get_config() self.assertEqual({ 'categorical_column': { 'class_name': 'IdentityCategoricalColumn', @@ -6689,13 +6689,13 @@ class EmbeddingColumnTest(test.TestCase): '_initializer': _initializer, } - new_embedding_column = fc.EmbeddingColumn._from_config( + new_embedding_column = fc.EmbeddingColumn.from_config( config, custom_objects=custom_objects) self.assertEqual(embedding_column, new_embedding_column) self.assertIsNot(categorical_column, new_embedding_column.categorical_column) - new_embedding_column = fc.EmbeddingColumn._from_config( + new_embedding_column = fc.EmbeddingColumn.from_config( config, custom_objects=custom_objects, columns_by_name={ @@ -7763,7 +7763,7 @@ class WeightedCategoricalColumnTest(test.TestCase): self.assertEqual([categorical_column, 'weight'], column.parents) - config = column._get_config() + config = column.get_config() self.assertEqual({ 'categorical_column': { 'config': { @@ -7777,9 +7777,9 @@ class WeightedCategoricalColumnTest(test.TestCase): 'weight_feature_key': 'weight' }, config) - self.assertEqual(column, fc.WeightedCategoricalColumn._from_config(config)) + self.assertEqual(column, fc.WeightedCategoricalColumn.from_config(config)) - new_column = fc.WeightedCategoricalColumn._from_config( + new_column = fc.WeightedCategoricalColumn.from_config( config, columns_by_name={ serialization._column_name_with_class_name(categorical_column): diff --git a/tensorflow/python/feature_column/sequence_feature_column.py b/tensorflow/python/feature_column/sequence_feature_column.py index 53f2d3e85e5..8dce0926a23 100644 --- a/tensorflow/python/feature_column/sequence_feature_column.py +++ b/tensorflow/python/feature_column/sequence_feature_column.py @@ -582,7 +582,7 @@ class SequenceNumericColumn( """See 'FeatureColumn` base class.""" return [self.key] - def _get_config(self): + def get_config(self): """See 'FeatureColumn` base class.""" config = dict(zip(self._fields, self)) config['normalizer_fn'] = utils.serialize_keras_object(self.normalizer_fn) @@ -590,7 +590,7 @@ class SequenceNumericColumn( return config @classmethod - def _from_config(cls, config, custom_objects=None, columns_by_name=None): + def from_config(cls, config, custom_objects=None, columns_by_name=None): """See 'FeatureColumn` base class.""" fc._check_config_keys(config, cls._fields) kwargs = fc._standardize_and_copy_config(config) diff --git a/tensorflow/python/feature_column/sequence_feature_column_test.py b/tensorflow/python/feature_column/sequence_feature_column_test.py index 8c269a0b800..662a826bd29 100644 --- a/tensorflow/python/feature_column/sequence_feature_column_test.py +++ b/tensorflow/python/feature_column/sequence_feature_column_test.py @@ -765,7 +765,7 @@ class SequenceCategoricalColumnWithIdentityTest( 'animal', num_buckets=4) animal = fc.indicator_column(parent) - config = animal._get_config() + config = animal.get_config() self.assertEqual( { 'categorical_column': { @@ -783,11 +783,11 @@ class SequenceCategoricalColumnWithIdentityTest( } }, config) - new_animal = fc.IndicatorColumn._from_config(config) + new_animal = fc.IndicatorColumn.from_config(config) self.assertEqual(animal, new_animal) self.assertIsNot(parent, new_animal.categorical_column) - new_animal = fc.IndicatorColumn._from_config( + new_animal = fc.IndicatorColumn.from_config( config, columns_by_name={ serialization._column_name_with_class_name(parent): parent diff --git a/tensorflow/python/feature_column/serialization.py b/tensorflow/python/feature_column/serialization.py index 1bec4cba6bf..81058d7b657 100644 --- a/tensorflow/python/feature_column/serialization.py +++ b/tensorflow/python/feature_column/serialization.py @@ -45,14 +45,14 @@ def serialize_feature_column(fc): """Serializes a FeatureColumn or a raw string key. This method should only be used to serialize parent FeatureColumns when - implementing FeatureColumn._get_config(), else serialize_feature_columns() + implementing FeatureColumn.get_config(), else serialize_feature_columns() is preferable. This serialization also keeps information of the FeatureColumn class, so deserialization is possible without knowing the class type. For example: a = numeric_column('x') - a._get_config() gives: + a.get_config() gives: { 'key': 'price', 'shape': (1,), @@ -85,7 +85,7 @@ def serialize_feature_column(fc): return fc elif isinstance(fc, fc_lib.FeatureColumn): return generic_utils.serialize_keras_class_and_config( - fc.__class__.__name__, fc._get_config()) # pylint: disable=protected-access + fc.__class__.__name__, fc.get_config()) # pylint: disable=protected-access else: raise ValueError('Instance: {} is not a FeatureColumn'.format(fc)) @@ -96,7 +96,7 @@ def deserialize_feature_column(config, """Deserializes a `config` generated with `serialize_feature_column`. This method should only be used to deserialize parent FeatureColumns when - implementing FeatureColumn._from_config(), else deserialize_feature_columns() + implementing FeatureColumn.from_config(), else deserialize_feature_columns() is preferable. Returns a FeatureColumn for this config. TODO(b/118939620): Simplify code if Keras utils support object deduping. @@ -136,7 +136,7 @@ def deserialize_feature_column(config, 'Expected FeatureColumn class, instead found: {}'.format(cls)) # Always deserialize the FeatureColumn, in order to get the name. - new_instance = cls._from_config( # pylint: disable=protected-access + new_instance = cls.from_config( # pylint: disable=protected-access cls_config, custom_objects=custom_objects, columns_by_name=columns_by_name)