Add serialization support for feature columns in DenseFeatures layers in Keras.
PiperOrigin-RevId: 251656649
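Illustrative usage enabled by this change (not part of the diff; mirrors the new save_test below, names are illustrative): a Keras functional model built on a DenseFeatures layer can now round-trip through to_json/model_from_json because the layer serializes its feature columns.

    # Sketch only, based on the test added in this commit.
    import numpy as np
    from tensorflow.python import keras
    from tensorflow.python.feature_column import feature_column_v2 as fc
    from tensorflow.python.keras.saving import model_config

    cols = [fc.numeric_column('a')]
    inputs = keras.layers.Input(shape=(1,), name='a')
    dense_features = fc.DenseFeatures(cols)({'a': inputs})
    outputs = keras.layers.Dense(10)(dense_features)
    model = keras.models.Model(inputs, outputs)

    # Feature columns are serialized as part of the layer config.
    json_config = model.to_json()
    restored = model_config.model_from_json(json_config)
    restored.predict({'a': np.arange(10).reshape(10, 1)})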
parent 14d492d7b4
commit cd09510f92
Changed paths: tensorflow/python/feature_column, tensorflow/python/keras, tensorflow/tools/api/golden
@@ -46,7 +46,7 @@ py_library(
        "//tensorflow/python:util",
        "//tensorflow/python:variable_scope",
        "//tensorflow/python:variables",
        "//tensorflow/python/keras",
        "//tensorflow/python/keras:engine",
        "@six_archive//:six",
    ],
)
@@ -84,7 +84,9 @@ py_library(
        "//tensorflow/python:util",
        "//tensorflow/python:variable_scope",
        "//tensorflow/python:variables",
        "//tensorflow/python/keras",
        "//tensorflow/python/keras:engine",
        "//tensorflow/python/keras:generic_utils",
        "//tensorflow/python/keras:layers_base",
        "//third_party/py/numpy",
        "@six_archive//:six",
    ],
)
@@ -233,3 +235,16 @@ py_test(
        "//tensorflow/python/keras:layers",
    ],
)

py_test(
    name = "serialization_test",
    srcs = ["serialization_test.py"],
    python_version = "PY2",
    srcs_version = "PY2AND3",
    deps = [
        ":feature_column_v2",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python:util",
        "@absl_py//absl/testing:parameterized",
    ],
)
@@ -144,7 +144,8 @@ from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib
from tensorflow.python.framework import tensor_shape
# TODO(b/118385027): Dependency on keras can be problematic if Keras moves out
# of the main repo.
from tensorflow.python.keras import utils
from tensorflow.python.keras import initializers
from tensorflow.python.keras.utils import generic_utils
from tensorflow.python.keras.engine import training
from tensorflow.python.keras.engine.base_layer import Layer
from tensorflow.python.ops import array_ops
@@ -478,6 +479,27 @@ class DenseFeatures(_BaseFeaturesLayer):
      output_tensors.append(processed_tensors)
    return self._verify_and_concat_tensors(output_tensors)

  def get_config(self):
    # Import here to avoid circular imports.
    from tensorflow.python.feature_column import serialization  # pylint: disable=g-import-not-at-top
    column_configs = serialization.serialize_feature_columns(
        self._feature_columns)
    config = {'feature_columns': column_configs}

    base_config = super(  # pylint: disable=bad-super-call
        DenseFeatures, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))

  @classmethod
  def from_config(cls, config, custom_objects=None):
    # Import here to avoid circular imports.
    from tensorflow.python.feature_column import serialization  # pylint: disable=g-import-not-at-top
    config_cp = config.copy()
    config_cp['feature_columns'] = serialization.deserialize_feature_columns(
        config['feature_columns'], custom_objects=custom_objects)

    return cls(**config_cp)


class _LinearModelLayer(Layer):
  """Layer that contains logic for `LinearModel`."""
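For context (not part of the diff): a minimal round-trip sketch of the methods added above, based on the DenseFeaturesSerializationTest introduced in this commit; the column names are illustrative.

    from tensorflow.python.feature_column import feature_column_v2 as fc

    cols = [fc.numeric_column('a')]                 # illustrative column
    layer = fc.DenseFeatures(cols)
    config = layer.get_config()                     # feature columns serialized to dicts
    new_layer = fc.DenseFeatures.from_config(config)  # equivalent layer, freshly built columns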
@@ -572,6 +594,32 @@ class _LinearModelLayer(Layer):
        predictions_no_bias, self.bias, name='weighted_sum')
    return predictions

  def get_config(self):
    # Import here to avoid circular imports.
    from tensorflow.python.feature_column import serialization  # pylint: disable=g-import-not-at-top
    column_configs = serialization.serialize_feature_columns(
        self._feature_columns)
    config = {
        'feature_columns': column_configs,
        'units': self._units,
        'sparse_combiner': self._sparse_combiner
    }

    base_config = super(  # pylint: disable=bad-super-call
        _LinearModelLayer, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))

  @classmethod
  def from_config(cls, config, custom_objects=None):
    # Import here to avoid circular imports.
    from tensorflow.python.feature_column import serialization  # pylint: disable=g-import-not-at-top
    config_cp = config.copy()
    columns = serialization.deserialize_feature_columns(
        config_cp['feature_columns'], custom_objects=custom_objects)

    del config_cp['feature_columns']
    return cls(feature_columns=columns, **config_cp)


class LinearModel(training.Model):
  """Produces a linear prediction `Tensor` based on given `feature_columns`.
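Similarly (illustrative, not part of the diff): the linear-model layer also records `units` and `sparse_combiner` in its config, as exercised by the LinearModelLayerSerializationTest added below.

    layer = fc._LinearModelLayer(
        [fc.numeric_column('a')], units=1, sparse_combiner='sum')
    config = layer.get_config()      # includes 'units' and 'sparse_combiner'
    new_layer = fc._LinearModelLayer.from_config(config)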
@@ -2251,7 +2299,7 @@ class FeatureColumn(object):
    config['dtype'] = self.dtype.name

    # Non-trivial dependencies should be Keras-(de)serializable.
    config['normalizer_fn'] = utils.serialize_keras_object(
    config['normalizer_fn'] = generic_utils.serialize_keras_object(
        self.normalizer_fn)

    return config
@@ -2264,7 +2312,7 @@ class FeatureColumn(object):
    kwargs['parent'] = deserialize_feature_column(
        config['parent'], custom_objects, columns_by_name)
    kwargs['dtype'] = dtypes.as_dtype(config['dtype'])
    kwargs['normalizer_fn'] = utils.deserialize_keras_object(
    kwargs['normalizer_fn'] = generic_utils.deserialize_keras_object(
        config['normalizer_fn'], custom_objects=custom_objects)
    return cls(**kwargs)

@@ -2813,7 +2861,8 @@ class NumericColumn(
  def _get_config(self):
    """See 'FeatureColumn` base class."""
    config = dict(zip(self._fields, self))
    config['normalizer_fn'] = utils.serialize_keras_object(self.normalizer_fn)
    config['normalizer_fn'] = generic_utils.serialize_keras_object(
        self.normalizer_fn)
    config['dtype'] = self.dtype.name
    return config

@@ -2822,9 +2871,18 @@ class NumericColumn(
    """See 'FeatureColumn` base class."""
    _check_config_keys(config, cls._fields)
    kwargs = config.copy()
    kwargs['normalizer_fn'] = utils.deserialize_keras_object(
    kwargs['normalizer_fn'] = generic_utils.deserialize_keras_object(
        config['normalizer_fn'], custom_objects=custom_objects)
    kwargs['dtype'] = dtypes.as_dtype(config['dtype'])

    # Keras serialization uses nest to listify everything.
    # This causes problems with the NumericColumn shape, which becomes
    # unhashable. We could try to solve this on the Keras side, but that
    # would require lots of tracking to avoid changing existing behavior.
    # Instead, we ensure here that we revive correctly.
    if isinstance(config['shape'], list):
      kwargs['shape'] = tuple(config['shape'])

    return cls(**kwargs)


@@ -3199,7 +3257,7 @@ class EmbeddingColumn(
    config = dict(zip(self._fields, self))
    config['categorical_column'] = serialize_feature_column(
        self.categorical_column)
    config['initializer'] = utils.serialize_keras_object(self.initializer)
    config['initializer'] = initializers.serialize(self.initializer)
    return config

  @classmethod
@@ -3210,7 +3268,7 @@ class EmbeddingColumn(
    kwargs = config.copy()
    kwargs['categorical_column'] = deserialize_feature_column(
        config['categorical_column'], custom_objects, columns_by_name)
    kwargs['initializer'] = utils.deserialize_keras_object(
    kwargs['initializer'] = initializers.deserialize(
        config['initializer'], custom_objects=custom_objects)
    return cls(**kwargs)
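For context (not part of the diff): a custom normalizer_fn now round-trips through the column config by name, via the generic_utils serialization used above. This sketch mirrors the NumericColumnTest and custom_objects tests in this commit; `_increment_two` is the hypothetical custom function from those tests.

    from tensorflow.python.feature_column import feature_column_v2 as fc

    def _increment_two(input_tensor):   # custom normalizer from the tests
      return input_tensor + 2.

    price = fc.numeric_column('price', normalizer_fn=_increment_two)
    config = price._get_config()        # normalizer_fn stored by name

    new_price = fc.NumericColumn._from_config(
        config, custom_objects={'_increment_two': _increment_two})
    assert new_price.normalizer_fn is _increment_two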
@@ -33,7 +33,6 @@ from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.feature_column import feature_column as fc_old
from tensorflow.python.feature_column import feature_column_v2 as fc
from tensorflow.python.feature_column import serialization
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
@@ -463,10 +462,10 @@ class NumericColumnTest(test.TestCase):
        'normalizer_fn': '_increment_two'
    }, config)

    self.assertEqual(
        price,
        fc.NumericColumn._from_config(
            config, custom_objects={'_increment_two': _increment_two}))
    new_col = fc.NumericColumn._from_config(
        config, custom_objects={'_increment_two': _increment_two})
    self.assertEqual(price, new_col)
    self.assertEqual(new_col.shape, (1,))


class BucketizedColumnTest(test.TestCase):
@@ -8347,109 +8346,5 @@ class WeightedCategoricalColumnTest(test.TestCase):
    self.assertEqual(column, new_column)
    self.assertIs(categorical_column, new_column.categorical_column)


class FeatureColumnForSerializationTest(BaseFeatureColumnForTests):

  @property
  def _is_v2_column(self):
    return True

  @property
  def name(self):
    return 'BadParentsFeatureColumn'

  def transform_feature(self, transformation_cache, state_manager):
    return 'Output'

  @property
  def parse_example_spec(self):
    pass


class SerializationTest(test.TestCase):
  """Tests for serialization, deserialization helpers."""

  def test_serialize_non_feature_column(self):

    class NotAFeatureColumn(object):
      pass

    with self.assertRaisesRegexp(ValueError, 'is not a FeatureColumn'):
      serialization.serialize_feature_column(NotAFeatureColumn())

  def test_deserialize_invalid_config(self):
    with self.assertRaisesRegexp(ValueError, 'Improper config format: {}'):
      serialization.deserialize_feature_column({})

  def test_deserialize_config_missing_key(self):
    config_missing_key = {
        'config': {
            # Dtype is missing and should cause a failure.
            # 'dtype': 'int32',
            'default_value': None,
            'key': 'a',
            'normalizer_fn': None,
            'shape': (2,)
        },
        'class_name': 'NumericColumn'
    }
    with self.assertRaisesRegexp(ValueError, 'Invalid config:'):
      serialization.deserialize_feature_column(config_missing_key)

  def test_deserialize_invalid_class(self):
    with self.assertRaisesRegexp(
        ValueError, 'Unknown feature_column_v2: NotExistingFeatureColumnClass'):
      serialization.deserialize_feature_column({
          'class_name': 'NotExistingFeatureColumnClass',
          'config': {}
      })

  def test_deserialization_deduping(self):
    price = fc.numeric_column('price')
    bucketized_price = fc.bucketized_column(price, boundaries=[0, 1])

    configs = serialization.serialize_feature_columns([price, bucketized_price])

    deserialized_feature_columns = serialization.deserialize_feature_columns(
        configs)
    self.assertEqual(2, len(deserialized_feature_columns))
    new_price = deserialized_feature_columns[0]
    new_bucketized_price = deserialized_feature_columns[1]

    # Ensure these are not the original objects:
    self.assertIsNot(price, new_price)
    self.assertIsNot(bucketized_price, new_bucketized_price)
    # But they are equivalent:
    self.assertEquals(price, new_price)
    self.assertEquals(bucketized_price, new_bucketized_price)

    # Check that deduping worked:
    self.assertIs(new_bucketized_price.source_column, new_price)

  def deserialization_custom_objects(self):
    # Note that custom_objects is also tested extensively above per class, this
    # test ensures that the public wrappers also handle it correctly.
    def _custom_fn(input_tensor):
      return input_tensor + 42.

    price = fc.numeric_column('price', normalizer_fn=_custom_fn)

    configs = serialization.serialize_feature_columns([price])

    deserialized_feature_columns = serialization.deserialize_feature_columns(
        configs)

    self.assertEqual(1, len(deserialized_feature_columns))
    new_price = deserialized_feature_columns[0]

    # Ensure these are not the original objects:
    self.assertIsNot(price, new_price)
    # But they are equivalent:
    self.assertEquals(price, new_price)

    # Check that normalizer_fn points to the correct function.
    self.assertIs(new_price.normalizer_fn, _custom_fn)


if __name__ == '__main__':
  test.main()
@@ -22,7 +22,6 @@ import six

from tensorflow.python.feature_column import feature_column_v2 as fc_lib
from tensorflow.python.feature_column import sequence_feature_column as sfc_lib
from tensorflow.python.keras import utils
from tensorflow.python.ops import init_ops


@@ -77,11 +76,14 @@ def serialize_feature_column(fc):
  Raises:
    ValueError if called with input that is not string or FeatureColumn.
  """
  # Import here to avoid circular imports.
  from tensorflow.python.keras.utils import generic_utils  # pylint: disable=g-import-not-at-top

  if isinstance(fc, six.string_types):
    return fc
  elif isinstance(fc, fc_lib.FeatureColumn):
    return utils.serialize_keras_class_and_config(fc.__class__.__name__,
                                                  fc._get_config())  # pylint: disable=protected-access
    return generic_utils.serialize_keras_class_and_config(
        fc.__class__.__name__, fc._get_config())  # pylint: disable=protected-access
  else:
    raise ValueError('Instance: {} is not a FeatureColumn'.format(fc))

@@ -111,6 +113,9 @@ def deserialize_feature_column(config,
  Returns:
    A FeatureColumn corresponding to the input `config`.
  """
  # Import here to avoid circular imports.
  from tensorflow.python.keras.utils import generic_utils  # pylint: disable=g-import-not-at-top

  if isinstance(config, six.string_types):
    return config
  # A dict from class_name to class for all FeatureColumns in this module.
@@ -120,11 +125,12 @@ def deserialize_feature_column(config,
  if columns_by_name is None:
    columns_by_name = {}

  (cls, cls_config) = utils.class_and_config_for_serialized_keras_object(
      config,
      module_objects=module_feature_column_classes,
      custom_objects=custom_objects,
      printable_module_name='feature_column_v2')
  (cls,
   cls_config) = generic_utils.class_and_config_for_serialized_keras_object(
       config,
       module_objects=module_feature_column_classes,
       custom_objects=custom_objects,
       printable_module_name='feature_column_v2')

  if not issubclass(cls, fc_lib.FeatureColumn):
    raise ValueError(
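For context (not part of the diff): these helpers produce and consume a Keras-style {'class_name', 'config'} dict. The sketch below is approximate output, inferred from the config keys exercised in the tests in this commit.

    from tensorflow.python.feature_column import feature_column_v2 as fc
    from tensorflow.python.feature_column import serialization

    price = fc.numeric_column('price')
    serialized = serialization.serialize_feature_column(price)
    # Roughly: {'class_name': 'NumericColumn',
    #           'config': {'key': 'price', 'shape': (1,), 'default_value': None,
    #                      'dtype': 'float32', 'normalizer_fn': None}}
    same_price = serialization.deserialize_feature_column(serialized)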
tensorflow/python/feature_column/serialization_test.py (new file, 217 lines)
@@ -0,0 +1,217 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for feature_column and DenseFeatures serialization."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from absl.testing import parameterized

from tensorflow.python.feature_column import feature_column_v2 as fc
from tensorflow.python.feature_column import serialization
from tensorflow.python.framework import test_util
from tensorflow.python.platform import test


class FeatureColumnSerializationTest(test.TestCase):
  """Tests for serialization, deserialization helpers."""

  def test_serialize_non_feature_column(self):

    class NotAFeatureColumn(object):
      pass

    with self.assertRaisesRegexp(ValueError, 'is not a FeatureColumn'):
      serialization.serialize_feature_column(NotAFeatureColumn())

  def test_deserialize_invalid_config(self):
    with self.assertRaisesRegexp(ValueError, 'Improper config format: {}'):
      serialization.deserialize_feature_column({})

  def test_deserialize_config_missing_key(self):
    config_missing_key = {
        'config': {
            # Dtype is missing and should cause a failure.
            # 'dtype': 'int32',
            'default_value': None,
            'key': 'a',
            'normalizer_fn': None,
            'shape': (2,)
        },
        'class_name': 'NumericColumn'
    }

    with self.assertRaisesRegexp(
        ValueError, 'Invalid config:.*expected keys.*dtype'):
      serialization.deserialize_feature_column(config_missing_key)

  def test_deserialize_invalid_class(self):
    with self.assertRaisesRegexp(
        ValueError, 'Unknown feature_column_v2: NotExistingFeatureColumnClass'):
      serialization.deserialize_feature_column({
          'class_name': 'NotExistingFeatureColumnClass',
          'config': {}
      })

  def test_deserialization_deduping(self):
    price = fc.numeric_column('price')
    bucketized_price = fc.bucketized_column(price, boundaries=[0, 1])

    configs = serialization.serialize_feature_columns([price, bucketized_price])

    deserialized_feature_columns = serialization.deserialize_feature_columns(
        configs)
    self.assertLen(deserialized_feature_columns, 2)
    new_price = deserialized_feature_columns[0]
    new_bucketized_price = deserialized_feature_columns[1]

    # Ensure these are not the original objects:
    self.assertIsNot(price, new_price)
    self.assertIsNot(bucketized_price, new_bucketized_price)
    # But they are equivalent:
    self.assertEqual(price, new_price)
    self.assertEqual(bucketized_price, new_bucketized_price)

    # Check that deduping worked:
    self.assertIs(new_bucketized_price.source_column, new_price)

  def deserialization_custom_objects(self):
    # Note that custom_objects is also tested extensively above per class, this
    # test ensures that the public wrappers also handle it correctly.
    def _custom_fn(input_tensor):
      return input_tensor + 42.

    price = fc.numeric_column('price', normalizer_fn=_custom_fn)

    configs = serialization.serialize_feature_columns([price])

    deserialized_feature_columns = serialization.deserialize_feature_columns(
        configs)

    self.assertLen(deserialized_feature_columns, 1)
    new_price = deserialized_feature_columns[0]

    # Ensure these are not the original objects:
    self.assertIsNot(price, new_price)
    # But they are equivalent:
    self.assertEqual(price, new_price)

    # Check that normalizer_fn points to the correct function.
    self.assertIs(new_price.normalizer_fn, _custom_fn)


@test_util.run_all_in_graph_and_eager_modes
class DenseFeaturesSerializationTest(test.TestCase, parameterized.TestCase):

  @parameterized.named_parameters(
      ('default', None, None),
      ('trainable', True, 'trainable'),
      ('not_trainable', False, 'frozen'))
  def test_get_config(self, trainable, name):
    cols = [fc.numeric_column('a'),
            fc.embedding_column(fc.categorical_column_with_identity(
                key='b', num_buckets=3), dimension=2)]
    orig_layer = fc.DenseFeatures(cols, trainable=trainable, name=name)
    config = orig_layer.get_config()

    self.assertEqual(config['name'], orig_layer.name)
    self.assertEqual(config['trainable'], trainable)
    self.assertLen(config['feature_columns'], 2)
    self.assertEqual(
        config['feature_columns'][0]['class_name'], 'NumericColumn')
    self.assertEqual(config['feature_columns'][0]['config']['shape'], (1,))
    self.assertEqual(
        config['feature_columns'][1]['class_name'], 'EmbeddingColumn')

  @parameterized.named_parameters(
      ('default', None, None),
      ('trainable', True, 'trainable'),
      ('not_trainable', False, 'frozen'))
  def test_from_config(self, trainable, name):
    cols = [fc.numeric_column('a'),
            fc.embedding_column(fc.categorical_column_with_vocabulary_list(
                'b', vocabulary_list=['1', '2', '3']), dimension=2),
            fc.indicator_column(fc.categorical_column_with_hash_bucket(
                key='c', hash_bucket_size=3))]
    orig_layer = fc.DenseFeatures(cols, trainable=trainable, name=name)
    config = orig_layer.get_config()

    new_layer = fc.DenseFeatures.from_config(config)

    self.assertEqual(new_layer.name, orig_layer.name)
    self.assertEqual(new_layer.trainable, trainable)
    self.assertLen(new_layer._feature_columns, 3)
    self.assertEqual(new_layer._feature_columns[0].name, 'a')
    self.assertEqual(new_layer._feature_columns[1].initializer.mean, 0.0)
    self.assertEqual(new_layer._feature_columns[1].categorical_column.name, 'b')
    self.assertIsInstance(new_layer._feature_columns[2], fc.IndicatorColumn)


@test_util.run_all_in_graph_and_eager_modes
class LinearModelLayerSerializationTest(test.TestCase, parameterized.TestCase):

  @parameterized.named_parameters(
      ('default', 1, 'sum', None, None),
      ('trainable', 6, 'mean', True, 'trainable'),
      ('not_trainable', 10, 'sum', False, 'frozen'))
  def test_get_config(self, units, sparse_combiner, trainable, name):
    cols = [fc.numeric_column('a'),
            fc.categorical_column_with_identity(key='b', num_buckets=3)]
    layer = fc._LinearModelLayer(
        cols, units=units, sparse_combiner=sparse_combiner,
        trainable=trainable, name=name)
    config = layer.get_config()

    self.assertEqual(config['name'], layer.name)
    self.assertEqual(config['trainable'], trainable)
    self.assertEqual(config['units'], units)
    self.assertEqual(config['sparse_combiner'], sparse_combiner)
    self.assertLen(config['feature_columns'], 2)
    self.assertEqual(
        config['feature_columns'][0]['class_name'], 'NumericColumn')
    self.assertEqual(
        config['feature_columns'][1]['class_name'], 'IdentityCategoricalColumn')

  @parameterized.named_parameters(
      ('default', 1, 'sum', None, None),
      ('trainable', 6, 'mean', True, 'trainable'),
      ('not_trainable', 10, 'sum', False, 'frozen'))
  def test_from_config(self, units, sparse_combiner, trainable, name):
    cols = [fc.numeric_column('a'),
            fc.categorical_column_with_vocabulary_list(
                'b', vocabulary_list=('1', '2', '3')),
            fc.categorical_column_with_hash_bucket(
                key='c', hash_bucket_size=3)]
    orig_layer = fc._LinearModelLayer(
        cols, units=units, sparse_combiner=sparse_combiner,
        trainable=trainable, name=name)
    config = orig_layer.get_config()

    new_layer = fc._LinearModelLayer.from_config(config)

    self.assertEqual(new_layer.name, orig_layer.name)
    self.assertEqual(new_layer._units, units)
    self.assertEqual(new_layer._sparse_combiner, sparse_combiner)
    self.assertEqual(new_layer.trainable, trainable)
    self.assertLen(new_layer._feature_columns, 3)
    self.assertEqual(new_layer._feature_columns[0].name, 'a')
    self.assertEqual(
        new_layer._feature_columns[1].vocabulary_list, ('1', '2', '3'))
    self.assertEqual(new_layer._feature_columns[2].num_buckets, 3)


if __name__ == '__main__':
  test.main()
@@ -387,8 +387,10 @@ py_library(
    ],
)

# A separate build for layers without serialization to avoid circular deps
# with feature column.
py_library(
    name = "layers",
    name = "layers_base",
    srcs = [
        "layers/__init__.py",
        "layers/advanced_activations.py",
@@ -408,7 +410,6 @@ py_library(
        "layers/recurrent.py",
        "layers/recurrent_v2.py",
        "layers/rnn_cell_wrapper_v2.py",
        "layers/serialization.py",
        "layers/wrappers.py",
        "utils/kernelized_utils.py",
        "utils/layer_utils.py",
@@ -439,6 +440,19 @@ py_library(
    ],
)

py_library(
    name = "layers",
    srcs = [
        "layers/serialization.py",
    ],
    srcs_version = "PY2AND3",
    deps = [
        ":layers_base",
        ":tf_utils",
        "//tensorflow/python/feature_column:feature_column_py",
    ],
)

py_library(
    name = "generic_utils",
    srcs = [
@@ -1522,13 +1536,14 @@ tf_py_test(

tf_py_test(
    name = "save_test",
    size = "small",
    size = "medium",
    srcs = ["saving/save_test.py"],
    additional_deps = [
        ":keras",
        "@absl_py//absl/testing:parameterized",
        "//third_party/py/numpy",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python/feature_column:feature_column_v2",
    ],
)
@@ -940,6 +940,7 @@ class Network(base_layer.Layer):
    for layer in self.layers:  # From the earliest layers on.
      layer_class_name = layer.__class__.__name__
      layer_config = layer.get_config()

      filtered_inbound_nodes = []
      for original_node_index, node in enumerate(layer._inbound_nodes):
        node_key = _make_node_key(layer.name, original_node_index)
@@ -974,6 +975,7 @@ class Network(base_layer.Layer):
          # Convert ListWrapper to list for backwards compatible configs.
          node_data = tf_utils.convert_inner_node_data(node_data)
          filtered_inbound_nodes.append(node_data)

      layer_configs.append({
          'name': layer.name,
          'class_name': layer_class_name,
@@ -1083,12 +1085,14 @@ class Network(base_layer.Layer):
      # Call layer on its inputs, thus creating the node
      # and building the layer if needed.
      if input_tensors is not None:
        # Preserve compatibility with older configs.
        # Preserve compatibility with older configs
        flat_input_tensors = nest.flatten(input_tensors)
        if len(flat_input_tensors) == 1:
          layer(flat_input_tensors[0], **kwargs)
        else:
          layer(input_tensors, **kwargs)
        # If this is a single element but not a dict, unwrap. If this is a dict,
        # assume the first layer expects a dict (as is the case with a
        # DenseFeatures layer); pass through.
        if not isinstance(input_tensors, dict) and len(flat_input_tensors) == 1:
          input_tensors = flat_input_tensors[0]
        layer(input_tensors, **kwargs)

    def process_layer(layer_data):
      """Deserializes a layer, then call it on appropriate inputs.
@@ -74,11 +74,18 @@ def deserialize(config, custom_objects=None):
  Returns:
      Layer instance (may be Model, Sequential, Network, Layer...)
  """
  # Prevent circular dependencies.
  from tensorflow.python.keras import models  # pylint: disable=g-import-not-at-top
  from tensorflow.python.feature_column import feature_column_v2  # pylint: disable=g-import-not-at-top

  globs = globals()  # All layers.
  globs['Network'] = models.Network
  globs['Model'] = models.Model
  globs['Sequential'] = models.Sequential

  # Prevent circular dependencies with FeatureColumn serialization.
  globs['DenseFeatures'] = feature_column_v2.DenseFeatures

  layer_class_name = config['class_name']
  if layer_class_name in _DESERIALIZATION_TABLE:
    config['class_name'] = _DESERIALIZATION_TABLE[layer_class_name]
@@ -20,8 +20,13 @@ from __future__ import print_function

import os

import numpy as np

from tensorflow.python import keras
from tensorflow.python.feature_column import feature_column_v2
from tensorflow.python.framework import test_util
from tensorflow.python.keras import testing_utils
from tensorflow.python.keras.saving import model_config
from tensorflow.python.keras.saving import save
from tensorflow.python.platform import test

@@ -66,5 +71,26 @@ class TestSaveModel(test.TestCase):
        'Saving the model as SavedModel is still in experimental stages.'):
      save.save_model(self.model, path, save_format='tf')

  @test_util.run_in_graph_and_eager_modes
  def test_saving_with_dense_features(self):
    cols = [feature_column_v2.numeric_column('a')]
    input_layer = keras.layers.Input(shape=(1,), name='a')
    fc_layer = feature_column_v2.DenseFeatures(cols)({'a': input_layer})
    output = keras.layers.Dense(10)(fc_layer)

    model = keras.models.Model(input_layer, output)

    model.compile(
        loss=keras.losses.MSE,
        optimizer=keras.optimizers.RMSprop(lr=0.0001),
        metrics=[keras.metrics.categorical_accuracy])

    config = model.to_json()
    loaded_model = model_config.model_from_json(config)

    inputs = np.arange(10).reshape(10, 1)
    self.assertLen(loaded_model.predict({'a': inputs}), 10)


if __name__ == '__main__':
  test.main()
@@ -157,7 +157,7 @@ tf_class {
  }
  member_method {
    name: "from_config"
    argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
    argspec: "args=[\'cls\', \'config\', \'custom_objects\'], varargs=None, keywords=None, defaults=[\'None\'], "
  }
  member_method {
    name: "get_config"

@@ -157,7 +157,7 @@ tf_class {
  }
  member_method {
    name: "from_config"
    argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
    argspec: "args=[\'cls\', \'config\', \'custom_objects\'], varargs=None, keywords=None, defaults=[\'None\'], "
  }
  member_method {
    name: "get_config"