Add serialization support for feature columns in DenseFeatures layers in Keras.

PiperOrigin-RevId: 251656649
Author: Karmel Allison (2019-06-05 09:23:28 -07:00), committed by TensorFlower Gardener
parent 14d492d7b4
commit cd09510f92
11 changed files with 379 additions and 136 deletions
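
The headline change is that `DenseFeatures` (and the internal `_LinearModelLayer`) now implement `get_config()`/`from_config()`, delegating feature-column (de)serialization to `tensorflow.python.feature_column.serialization`, so Keras models containing these layers can round-trip through config-based saving such as `to_json()`/`model_from_json()`. A minimal usage sketch of the round-trip this enables (not part of the diff; assumes the public `tf.keras`/`tf.feature_column` aliases, and the column names are illustrative):

    # Minimal sketch: assumes the public tf.keras / tf.feature_column aliases;
    # 'price' and 'store_id' are made-up feature names.
    import tensorflow as tf

    columns = [
        tf.feature_column.numeric_column('price'),
        tf.feature_column.indicator_column(
            tf.feature_column.categorical_column_with_identity(
                key='store_id', num_buckets=3)),
    ]
    layer = tf.keras.layers.DenseFeatures(columns)

    # get_config() now serializes the attached feature columns ...
    config = layer.get_config()
    # ... and from_config() revives an equivalent layer from that config,
    # optionally resolving custom objects (e.g. a normalizer_fn).
    revived = tf.keras.layers.DenseFeatures.from_config(config)

    features = {'price': tf.constant([[1.0], [2.0]]),
                'store_id': tf.constant([[0], [2]])}
    print(revived(features))  # shape (2, 4): 1 numeric + 3 indicator outputs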

View File

@@ -46,7 +46,7 @@ py_library(
"//tensorflow/python:util",
"//tensorflow/python:variable_scope",
"//tensorflow/python:variables",
"//tensorflow/python/keras",
"//tensorflow/python/keras:engine",
"@six_archive//:six",
],
)
@@ -84,7 +84,9 @@ py_library(
"//tensorflow/python:util",
"//tensorflow/python:variable_scope",
"//tensorflow/python:variables",
"//tensorflow/python/keras",
"//tensorflow/python/keras:engine",
"//tensorflow/python/keras:generic_utils",
"//tensorflow/python/keras:layers_base",
"//third_party/py/numpy",
"@six_archive//:six",
],
@@ -233,3 +235,16 @@ py_test(
"//tensorflow/python/keras:layers",
],
)
py_test(
name = "serialization_test",
srcs = ["serialization_test.py"],
python_version = "PY2",
srcs_version = "PY2AND3",
deps = [
":feature_column_v2",
"//tensorflow/python:client_testlib",
"//tensorflow/python:util",
"@absl_py//absl/testing:parameterized",
],
)

View File

@@ -144,7 +144,8 @@ from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib
from tensorflow.python.framework import tensor_shape
# TODO(b/118385027): Dependency on keras can be problematic if Keras moves out
# of the main repo.
from tensorflow.python.keras import utils
from tensorflow.python.keras import initializers
from tensorflow.python.keras.utils import generic_utils
from tensorflow.python.keras.engine import training
from tensorflow.python.keras.engine.base_layer import Layer
from tensorflow.python.ops import array_ops
@@ -478,6 +479,27 @@ class DenseFeatures(_BaseFeaturesLayer):
output_tensors.append(processed_tensors)
return self._verify_and_concat_tensors(output_tensors)
def get_config(self):
# Import here to avoid circular imports.
from tensorflow.python.feature_column import serialization # pylint: disable=g-import-not-at-top
column_configs = serialization.serialize_feature_columns(
self._feature_columns)
config = {'feature_columns': column_configs}
base_config = super( # pylint: disable=bad-super-call
DenseFeatures, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
@classmethod
def from_config(cls, config, custom_objects=None):
# Import here to avoid circular imports.
from tensorflow.python.feature_column import serialization # pylint: disable=g-import-not-at-top
config_cp = config.copy()
config_cp['feature_columns'] = serialization.deserialize_feature_columns(
config['feature_columns'], custom_objects=custom_objects)
return cls(**config_cp)
class _LinearModelLayer(Layer):
"""Layer that contains logic for `LinearModel`."""
@@ -572,6 +594,32 @@ class _LinearModelLayer(Layer):
predictions_no_bias, self.bias, name='weighted_sum')
return predictions
def get_config(self):
# Import here to avoid circular imports.
from tensorflow.python.feature_column import serialization # pylint: disable=g-import-not-at-top
column_configs = serialization.serialize_feature_columns(
self._feature_columns)
config = {
'feature_columns': column_configs,
'units': self._units,
'sparse_combiner': self._sparse_combiner
}
base_config = super( # pylint: disable=bad-super-call
_LinearModelLayer, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
@classmethod
def from_config(cls, config, custom_objects=None):
# Import here to avoid circular imports.
from tensorflow.python.feature_column import serialization # pylint: disable=g-import-not-at-top
config_cp = config.copy()
columns = serialization.deserialize_feature_columns(
config_cp['feature_columns'], custom_objects=custom_objects)
del config_cp['feature_columns']
return cls(feature_columns=columns, **config_cp)
class LinearModel(training.Model):
"""Produces a linear prediction `Tensor` based on given `feature_columns`.
@@ -2251,7 +2299,7 @@ class FeatureColumn(object):
config['dtype'] = self.dtype.name
# Non-trivial dependencies should be Keras-(de)serializable.
config['normalizer_fn'] = utils.serialize_keras_object(
config['normalizer_fn'] = generic_utils.serialize_keras_object(
self.normalizer_fn)
return config
@@ -2264,7 +2312,7 @@ class FeatureColumn(object):
kwargs['parent'] = deserialize_feature_column(
config['parent'], custom_objects, columns_by_name)
kwargs['dtype'] = dtypes.as_dtype(config['dtype'])
kwargs['normalizer_fn'] = utils.deserialize_keras_object(
kwargs['normalizer_fn'] = generic_utils.deserialize_keras_object(
config['normalizer_fn'], custom_objects=custom_objects)
return cls(**kwargs)
@@ -2813,7 +2861,8 @@ class NumericColumn(
def _get_config(self):
"""See 'FeatureColumn` base class."""
config = dict(zip(self._fields, self))
config['normalizer_fn'] = utils.serialize_keras_object(self.normalizer_fn)
config['normalizer_fn'] = generic_utils.serialize_keras_object(
self.normalizer_fn)
config['dtype'] = self.dtype.name
return config
@@ -2822,9 +2871,18 @@ class NumericColumn(
"""See 'FeatureColumn` base class."""
_check_config_keys(config, cls._fields)
kwargs = config.copy()
kwargs['normalizer_fn'] = utils.deserialize_keras_object(
kwargs['normalizer_fn'] = generic_utils.deserialize_keras_object(
config['normalizer_fn'], custom_objects=custom_objects)
kwargs['dtype'] = dtypes.as_dtype(config['dtype'])
# Keras serialization uses nest to listify everything.
# This causes problems with the NumericColumn shape, which becomes
# unhashable. We could try to solve this on the Keras side, but that
# would require lots of tracking to avoid changing existing behavior.
# Instead, we ensure here that we revive correctly.
if isinstance(config['shape'], list):
kwargs['shape'] = tuple(config['shape'])
return cls(**kwargs)
@@ -3199,7 +3257,7 @@ class EmbeddingColumn(
config = dict(zip(self._fields, self))
config['categorical_column'] = serialize_feature_column(
self.categorical_column)
config['initializer'] = utils.serialize_keras_object(self.initializer)
config['initializer'] = initializers.serialize(self.initializer)
return config
@classmethod
@@ -3210,7 +3268,7 @@ class EmbeddingColumn(
kwargs = config.copy()
kwargs['categorical_column'] = deserialize_feature_column(
config['categorical_column'], custom_objects, columns_by_name)
kwargs['initializer'] = utils.deserialize_keras_object(
kwargs['initializer'] = initializers.deserialize(
config['initializer'], custom_objects=custom_objects)
return cls(**kwargs)

View File

@@ -33,7 +33,6 @@ from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.feature_column import feature_column as fc_old
from tensorflow.python.feature_column import feature_column_v2 as fc
from tensorflow.python.feature_column import serialization
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
@@ -463,10 +462,10 @@ class NumericColumnTest(test.TestCase):
'normalizer_fn': '_increment_two'
}, config)
self.assertEqual(
price,
fc.NumericColumn._from_config(
config, custom_objects={'_increment_two': _increment_two}))
new_col = fc.NumericColumn._from_config(
config, custom_objects={'_increment_two': _increment_two})
self.assertEqual(price, new_col)
self.assertEqual(new_col.shape, (1,))
class BucketizedColumnTest(test.TestCase):
@@ -8347,109 +8346,5 @@ class WeightedCategoricalColumnTest(test.TestCase):
self.assertEqual(column, new_column)
self.assertIs(categorical_column, new_column.categorical_column)
class FeatureColumnForSerializationTest(BaseFeatureColumnForTests):
@property
def _is_v2_column(self):
return True
@property
def name(self):
return 'BadParentsFeatureColumn'
def transform_feature(self, transformation_cache, state_manager):
return 'Output'
@property
def parse_example_spec(self):
pass
class SerializationTest(test.TestCase):
"""Tests for serialization, deserialization helpers."""
def test_serialize_non_feature_column(self):
class NotAFeatureColumn(object):
pass
with self.assertRaisesRegexp(ValueError, 'is not a FeatureColumn'):
serialization.serialize_feature_column(NotAFeatureColumn())
def test_deserialize_invalid_config(self):
with self.assertRaisesRegexp(ValueError, 'Improper config format: {}'):
serialization.deserialize_feature_column({})
def test_deserialize_config_missing_key(self):
config_missing_key = {
'config': {
# Dtype is missing and should cause a failure.
# 'dtype': 'int32',
'default_value': None,
'key': 'a',
'normalizer_fn': None,
'shape': (2,)
},
'class_name': 'NumericColumn'
}
with self.assertRaisesRegexp(ValueError, 'Invalid config:'):
serialization.deserialize_feature_column(config_missing_key)
def test_deserialize_invalid_class(self):
with self.assertRaisesRegexp(
ValueError, 'Unknown feature_column_v2: NotExistingFeatureColumnClass'):
serialization.deserialize_feature_column({
'class_name': 'NotExistingFeatureColumnClass',
'config': {}
})
def test_deserialization_deduping(self):
price = fc.numeric_column('price')
bucketized_price = fc.bucketized_column(price, boundaries=[0, 1])
configs = serialization.serialize_feature_columns([price, bucketized_price])
deserialized_feature_columns = serialization.deserialize_feature_columns(
configs)
self.assertEqual(2, len(deserialized_feature_columns))
new_price = deserialized_feature_columns[0]
new_bucketized_price = deserialized_feature_columns[1]
# Ensure these are not the original objects:
self.assertIsNot(price, new_price)
self.assertIsNot(bucketized_price, new_bucketized_price)
# But they are equivalent:
self.assertEquals(price, new_price)
self.assertEquals(bucketized_price, new_bucketized_price)
# Check that deduping worked:
self.assertIs(new_bucketized_price.source_column, new_price)
def deserialization_custom_objects(self):
# Note that custom_objects is also tested extensively above per class, this
# test ensures that the public wrappers also handle it correctly.
def _custom_fn(input_tensor):
return input_tensor + 42.
price = fc.numeric_column('price', normalizer_fn=_custom_fn)
configs = serialization.serialize_feature_columns([price])
deserialized_feature_columns = serialization.deserialize_feature_columns(
configs)
self.assertEqual(1, len(deserialized_feature_columns))
new_price = deserialized_feature_columns[0]
# Ensure these are not the original objects:
self.assertIsNot(price, new_price)
# But they are equivalent:
self.assertEquals(price, new_price)
# Check that normalizer_fn points to the correct function.
self.assertIs(new_price.normalizer_fn, _custom_fn)
if __name__ == '__main__':
test.main()

View File

@@ -22,7 +22,6 @@ import six
from tensorflow.python.feature_column import feature_column_v2 as fc_lib
from tensorflow.python.feature_column import sequence_feature_column as sfc_lib
from tensorflow.python.keras import utils
from tensorflow.python.ops import init_ops
@@ -77,11 +76,14 @@ def serialize_feature_column(fc):
Raises:
ValueError if called with input that is not string or FeatureColumn.
"""
# Import here to avoid circular imports.
from tensorflow.python.keras.utils import generic_utils # pylint: disable=g-import-not-at-top
if isinstance(fc, six.string_types):
return fc
elif isinstance(fc, fc_lib.FeatureColumn):
return utils.serialize_keras_class_and_config(fc.__class__.__name__,
fc._get_config()) # pylint: disable=protected-access
return generic_utils.serialize_keras_class_and_config(
fc.__class__.__name__, fc._get_config()) # pylint: disable=protected-access
else:
raise ValueError('Instance: {} is not a FeatureColumn'.format(fc))
@@ -111,6 +113,9 @@ def deserialize_feature_column(config,
Returns:
A FeatureColumn corresponding to the input `config`.
"""
# Import here to avoid circular imports.
from tensorflow.python.keras.utils import generic_utils # pylint: disable=g-import-not-at-top
if isinstance(config, six.string_types):
return config
# A dict from class_name to class for all FeatureColumns in this module.
@@ -120,11 +125,12 @@ def deserialize_feature_column(config,
if columns_by_name is None:
columns_by_name = {}
(cls, cls_config) = utils.class_and_config_for_serialized_keras_object(
config,
module_objects=module_feature_column_classes,
custom_objects=custom_objects,
printable_module_name='feature_column_v2')
(cls,
cls_config) = generic_utils.class_and_config_for_serialized_keras_object(
config,
module_objects=module_feature_column_classes,
custom_objects=custom_objects,
printable_module_name='feature_column_v2')
if not issubclass(cls, fc_lib.FeatureColumn):
raise ValueError(

View File

@@ -0,0 +1,217 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for feature_column and DenseFeatures serialization."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl.testing import parameterized
from tensorflow.python.feature_column import feature_column_v2 as fc
from tensorflow.python.feature_column import serialization
from tensorflow.python.framework import test_util
from tensorflow.python.platform import test
class FeatureColumnSerializationTest(test.TestCase):
"""Tests for serialization, deserialization helpers."""
def test_serialize_non_feature_column(self):
class NotAFeatureColumn(object):
pass
with self.assertRaisesRegexp(ValueError, 'is not a FeatureColumn'):
serialization.serialize_feature_column(NotAFeatureColumn())
def test_deserialize_invalid_config(self):
with self.assertRaisesRegexp(ValueError, 'Improper config format: {}'):
serialization.deserialize_feature_column({})
def test_deserialize_config_missing_key(self):
config_missing_key = {
'config': {
# Dtype is missing and should cause a failure.
# 'dtype': 'int32',
'default_value': None,
'key': 'a',
'normalizer_fn': None,
'shape': (2,)
},
'class_name': 'NumericColumn'
}
with self.assertRaisesRegexp(
ValueError, 'Invalid config:.*expected keys.*dtype'):
serialization.deserialize_feature_column(config_missing_key)
def test_deserialize_invalid_class(self):
with self.assertRaisesRegexp(
ValueError, 'Unknown feature_column_v2: NotExistingFeatureColumnClass'):
serialization.deserialize_feature_column({
'class_name': 'NotExistingFeatureColumnClass',
'config': {}
})
def test_deserialization_deduping(self):
price = fc.numeric_column('price')
bucketized_price = fc.bucketized_column(price, boundaries=[0, 1])
configs = serialization.serialize_feature_columns([price, bucketized_price])
deserialized_feature_columns = serialization.deserialize_feature_columns(
configs)
self.assertLen(deserialized_feature_columns, 2)
new_price = deserialized_feature_columns[0]
new_bucketized_price = deserialized_feature_columns[1]
# Ensure these are not the original objects:
self.assertIsNot(price, new_price)
self.assertIsNot(bucketized_price, new_bucketized_price)
# But they are equivalent:
self.assertEqual(price, new_price)
self.assertEqual(bucketized_price, new_bucketized_price)
# Check that deduping worked:
self.assertIs(new_bucketized_price.source_column, new_price)
def deserialization_custom_objects(self):
# Note that custom_objects is also tested extensively above per class, this
# test ensures that the public wrappers also handle it correctly.
def _custom_fn(input_tensor):
return input_tensor + 42.
price = fc.numeric_column('price', normalizer_fn=_custom_fn)
configs = serialization.serialize_feature_columns([price])
deserialized_feature_columns = serialization.deserialize_feature_columns(
configs)
self.assertLen(deserialized_feature_columns, 1)
new_price = deserialized_feature_columns[0]
# Ensure these are not the original objects:
self.assertIsNot(price, new_price)
# But they are equivalent:
self.assertEqual(price, new_price)
# Check that normalizer_fn points to the correct function.
self.assertIs(new_price.normalizer_fn, _custom_fn)
@test_util.run_all_in_graph_and_eager_modes
class DenseFeaturesSerializationTest(test.TestCase, parameterized.TestCase):
@parameterized.named_parameters(
('default', None, None),
('trainable', True, 'trainable'),
('not_trainable', False, 'frozen'))
def test_get_config(self, trainable, name):
cols = [fc.numeric_column('a'),
fc.embedding_column(fc.categorical_column_with_identity(
key='b', num_buckets=3), dimension=2)]
orig_layer = fc.DenseFeatures(cols, trainable=trainable, name=name)
config = orig_layer.get_config()
self.assertEqual(config['name'], orig_layer.name)
self.assertEqual(config['trainable'], trainable)
self.assertLen(config['feature_columns'], 2)
self.assertEqual(
config['feature_columns'][0]['class_name'], 'NumericColumn')
self.assertEqual(config['feature_columns'][0]['config']['shape'], (1,))
self.assertEqual(
config['feature_columns'][1]['class_name'], 'EmbeddingColumn')
@parameterized.named_parameters(
('default', None, None),
('trainable', True, 'trainable'),
('not_trainable', False, 'frozen'))
def test_from_config(self, trainable, name):
cols = [fc.numeric_column('a'),
fc.embedding_column(fc.categorical_column_with_vocabulary_list(
'b', vocabulary_list=['1', '2', '3']), dimension=2),
fc.indicator_column(fc.categorical_column_with_hash_bucket(
key='c', hash_bucket_size=3))]
orig_layer = fc.DenseFeatures(cols, trainable=trainable, name=name)
config = orig_layer.get_config()
new_layer = fc.DenseFeatures.from_config(config)
self.assertEqual(new_layer.name, orig_layer.name)
self.assertEqual(new_layer.trainable, trainable)
self.assertLen(new_layer._feature_columns, 3)
self.assertEqual(new_layer._feature_columns[0].name, 'a')
self.assertEqual(new_layer._feature_columns[1].initializer.mean, 0.0)
self.assertEqual(new_layer._feature_columns[1].categorical_column.name, 'b')
self.assertIsInstance(new_layer._feature_columns[2], fc.IndicatorColumn)
@test_util.run_all_in_graph_and_eager_modes
class LinearModelLayerSerializationTest(test.TestCase, parameterized.TestCase):
@parameterized.named_parameters(
('default', 1, 'sum', None, None),
('trainable', 6, 'mean', True, 'trainable'),
('not_trainable', 10, 'sum', False, 'frozen'))
def test_get_config(self, units, sparse_combiner, trainable, name):
cols = [fc.numeric_column('a'),
fc.categorical_column_with_identity(key='b', num_buckets=3)]
layer = fc._LinearModelLayer(
cols, units=units, sparse_combiner=sparse_combiner,
trainable=trainable, name=name)
config = layer.get_config()
self.assertEqual(config['name'], layer.name)
self.assertEqual(config['trainable'], trainable)
self.assertEqual(config['units'], units)
self.assertEqual(config['sparse_combiner'], sparse_combiner)
self.assertLen(config['feature_columns'], 2)
self.assertEqual(
config['feature_columns'][0]['class_name'], 'NumericColumn')
self.assertEqual(
config['feature_columns'][1]['class_name'], 'IdentityCategoricalColumn')
@parameterized.named_parameters(
('default', 1, 'sum', None, None),
('trainable', 6, 'mean', True, 'trainable'),
('not_trainable', 10, 'sum', False, 'frozen'))
def test_from_config(self, units, sparse_combiner, trainable, name):
cols = [fc.numeric_column('a'),
fc.categorical_column_with_vocabulary_list(
'b', vocabulary_list=('1', '2', '3')),
fc.categorical_column_with_hash_bucket(
key='c', hash_bucket_size=3)]
orig_layer = fc._LinearModelLayer(
cols, units=units, sparse_combiner=sparse_combiner,
trainable=trainable, name=name)
config = orig_layer.get_config()
new_layer = fc._LinearModelLayer.from_config(config)
self.assertEqual(new_layer.name, orig_layer.name)
self.assertEqual(new_layer._units, units)
self.assertEqual(new_layer._sparse_combiner, sparse_combiner)
self.assertEqual(new_layer.trainable, trainable)
self.assertLen(new_layer._feature_columns, 3)
self.assertEqual(new_layer._feature_columns[0].name, 'a')
self.assertEqual(
new_layer._feature_columns[1].vocabulary_list, ('1', '2', '3'))
self.assertEqual(new_layer._feature_columns[2].num_buckets, 3)
if __name__ == '__main__':
test.main()

View File

@@ -387,8 +387,10 @@ py_library(
],
)
# A separate build for layers without serialization to avoid circular deps
# with feature column.
py_library(
name = "layers",
name = "layers_base",
srcs = [
"layers/__init__.py",
"layers/advanced_activations.py",
@@ -408,7 +410,6 @@ py_library(
"layers/recurrent.py",
"layers/recurrent_v2.py",
"layers/rnn_cell_wrapper_v2.py",
"layers/serialization.py",
"layers/wrappers.py",
"utils/kernelized_utils.py",
"utils/layer_utils.py",
@@ -439,6 +440,19 @@ py_library(
],
)
py_library(
name = "layers",
srcs = [
"layers/serialization.py",
],
srcs_version = "PY2AND3",
deps = [
":layers_base",
":tf_utils",
"//tensorflow/python/feature_column:feature_column_py",
],
)
py_library(
name = "generic_utils",
srcs = [
@@ -1522,13 +1536,14 @@ tf_py_test(
tf_py_test(
name = "save_test",
size = "small",
size = "medium",
srcs = ["saving/save_test.py"],
additional_deps = [
":keras",
"@absl_py//absl/testing:parameterized",
"//third_party/py/numpy",
"//tensorflow/python:client_testlib",
"//tensorflow/python/feature_column:feature_column_v2",
],
)

View File

@@ -940,6 +940,7 @@ class Network(base_layer.Layer):
for layer in self.layers: # From the earliest layers on.
layer_class_name = layer.__class__.__name__
layer_config = layer.get_config()
filtered_inbound_nodes = []
for original_node_index, node in enumerate(layer._inbound_nodes):
node_key = _make_node_key(layer.name, original_node_index)
@@ -974,6 +975,7 @@ class Network(base_layer.Layer):
# Convert ListWrapper to list for backwards compatible configs.
node_data = tf_utils.convert_inner_node_data(node_data)
filtered_inbound_nodes.append(node_data)
layer_configs.append({
'name': layer.name,
'class_name': layer_class_name,
@@ -1083,12 +1085,14 @@ class Network(base_layer.Layer):
# Call layer on its inputs, thus creating the node
# and building the layer if needed.
if input_tensors is not None:
# Preserve compatibility with older configs.
# Preserve compatibility with older configs
flat_input_tensors = nest.flatten(input_tensors)
if len(flat_input_tensors) == 1:
layer(flat_input_tensors[0], **kwargs)
else:
layer(input_tensors, **kwargs)
# If this is a single element but not a dict, unwrap. If this is a dict,
# assume the first layer expects a dict (as is the case with a
# DenseFeatures layer); pass through.
if not isinstance(input_tensors, dict) and len(flat_input_tensors) == 1:
input_tensors = flat_input_tensors[0]
layer(input_tensors, **kwargs)
def process_layer(layer_data):
"""Deserializes a layer, then call it on appropriate inputs.

View File

@@ -74,11 +74,18 @@ def deserialize(config, custom_objects=None):
Returns:
Layer instance (may be Model, Sequential, Network, Layer...)
"""
# Prevent circular dependencies.
from tensorflow.python.keras import models # pylint: disable=g-import-not-at-top
from tensorflow.python.feature_column import feature_column_v2 # pylint: disable=g-import-not-at-top
globs = globals() # All layers.
globs['Network'] = models.Network
globs['Model'] = models.Model
globs['Sequential'] = models.Sequential
# Prevent circular dependencies with FeatureColumn serialization.
globs['DenseFeatures'] = feature_column_v2.DenseFeatures
layer_class_name = config['class_name']
if layer_class_name in _DESERIALIZATION_TABLE:
config['class_name'] = _DESERIALIZATION_TABLE[layer_class_name]

View File

@@ -20,8 +20,13 @@ from __future__ import print_function
import os
import numpy as np
from tensorflow.python import keras
from tensorflow.python.feature_column import feature_column_v2
from tensorflow.python.framework import test_util
from tensorflow.python.keras import testing_utils
from tensorflow.python.keras.saving import model_config
from tensorflow.python.keras.saving import save
from tensorflow.python.platform import test
@@ -66,5 +71,26 @@ class TestSaveModel(test.TestCase):
'Saving the model as SavedModel is still in experimental stages.'):
save.save_model(self.model, path, save_format='tf')
@test_util.run_in_graph_and_eager_modes
def test_saving_with_dense_features(self):
cols = [feature_column_v2.numeric_column('a')]
input_layer = keras.layers.Input(shape=(1,), name='a')
fc_layer = feature_column_v2.DenseFeatures(cols)({'a': input_layer})
output = keras.layers.Dense(10)(fc_layer)
model = keras.models.Model(input_layer, output)
model.compile(
loss=keras.losses.MSE,
optimizer=keras.optimizers.RMSprop(lr=0.0001),
metrics=[keras.metrics.categorical_accuracy])
config = model.to_json()
loaded_model = model_config.model_from_json(config)
inputs = np.arange(10).reshape(10, 1)
self.assertLen(loaded_model.predict({'a': inputs}), 10)
if __name__ == '__main__':
test.main()

View File

@@ -157,7 +157,7 @@ tf_class {
}
member_method {
name: "from_config"
argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
argspec: "args=[\'cls\', \'config\', \'custom_objects\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "get_config"

View File

@@ -157,7 +157,7 @@ tf_class {
}
member_method {
name: "from_config"
argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
argspec: "args=[\'cls\', \'config\', \'custom_objects\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "get_config"