Copy the generic_utils logic into feature_column for serialization/deserialization.

This removes the final dependency of feature_column on Keras.

The copied functions are trimmed-down versions, since they don't have access to the Keras global custom object registration, which I don't think is used by feature columns. The custom object scope will still work.

PiperOrigin-RevId: 315621765
Change-Id: I2ae22af83d625c8e55c7fe21b42194bbdbfded23
This commit is contained in:
Scott Zhu 2020-06-09 21:15:43 -07:00 committed by TensorFlower Gardener
parent 2e9eaece7b
commit 090f260aab
3 changed files with 148 additions and 18 deletions

View File

@ -90,7 +90,6 @@ py_library(
"//tensorflow/python:variable_scope",
"//tensorflow/python:variables",
"//tensorflow/python/eager:context",
"//tensorflow/python/keras/utils:generic_utils",
"//tensorflow/python/training/tracking",
"//tensorflow/python/training/tracking:data_structures",
"//third_party/py/numpy",

View File

@ -142,9 +142,6 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib
from tensorflow.python.framework import tensor_shape
# TODO(b/118385027): Dependency on keras can be problematic if Keras moves out
# of the main repo.
from tensorflow.python.keras.utils import generic_utils
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
@ -2609,7 +2606,8 @@ class NumericColumn(
def get_config(self):
"""See 'FeatureColumn` base class."""
config = dict(zip(self._fields, self))
config['normalizer_fn'] = generic_utils.serialize_keras_object(
from tensorflow.python.feature_column import serialization # pylint: disable=g-import-not-at-top
config['normalizer_fn'] = serialization._serialize_keras_object( # pylint: disable=protected-access
self.normalizer_fn)
config['dtype'] = self.dtype.name
return config
@ -2618,8 +2616,9 @@ class NumericColumn(
def from_config(cls, config, custom_objects=None, columns_by_name=None):
"""See 'FeatureColumn` base class."""
_check_config_keys(config, cls._fields)
from tensorflow.python.feature_column import serialization # pylint: disable=g-import-not-at-top
kwargs = _standardize_and_copy_config(config)
kwargs['normalizer_fn'] = generic_utils.deserialize_keras_object(
kwargs['normalizer_fn'] = serialization._deserialize_keras_object( # pylint: disable=protected-access
config['normalizer_fn'], custom_objects=custom_objects)
kwargs['dtype'] = dtypes.as_dtype(config['dtype'])
@ -3027,11 +3026,11 @@ class EmbeddingColumn(
def get_config(self):
"""See 'FeatureColumn` base class."""
from tensorflow.python.feature_column.serialization import serialize_feature_column # pylint: disable=g-import-not-at-top
from tensorflow.python.feature_column import serialization # pylint: disable=g-import-not-at-top
config = dict(zip(self._fields, self))
config['categorical_column'] = serialize_feature_column(
config['categorical_column'] = serialization.serialize_feature_column(
self.categorical_column)
config['initializer'] = generic_utils.serialize_keras_object(
config['initializer'] = serialization._serialize_keras_object( # pylint: disable=protected-access
self.initializer)
return config
@ -3040,13 +3039,13 @@ class EmbeddingColumn(
"""See 'FeatureColumn` base class."""
if 'use_safe_embedding_lookup' not in config:
config['use_safe_embedding_lookup'] = True
from tensorflow.python.feature_column.serialization import deserialize_feature_column # pylint: disable=g-import-not-at-top
from tensorflow.python.feature_column import serialization # pylint: disable=g-import-not-at-top
_check_config_keys(config, cls._fields)
kwargs = _standardize_and_copy_config(config)
kwargs['categorical_column'] = deserialize_feature_column(
kwargs['categorical_column'] = serialization.deserialize_feature_column(
config['categorical_column'], custom_objects, columns_by_name)
all_initializers = dict(tf_inspect.getmembers(init_ops, tf_inspect.isclass))
kwargs['initializer'] = generic_utils.deserialize_keras_object(
kwargs['initializer'] = serialization._deserialize_keras_object( # pylint: disable=protected-access
config['initializer'],
module_objects=all_initializers,
custom_objects=custom_objects)

View File

@ -23,12 +23,9 @@ import six
from tensorflow.python.feature_column import feature_column_v2 as fc_lib
from tensorflow.python.feature_column import sequence_feature_column as sfc_lib
from tensorflow.python.ops import init_ops
from tensorflow.python.util.lazy_loader import LazyLoader
from tensorflow.python.util import tf_decorator
from tensorflow.python.util import tf_inspect
# Prevent circular dependencies with Keras serialization.
generic_utils = LazyLoader(
'generic_utils', globals(),
'tensorflow.python.keras.utils.generic_utils')
_FEATURE_COLUMNS = [
fc_lib.BucketizedColumn, fc_lib.CrossedColumn, fc_lib.EmbeddingColumn,
@ -124,7 +121,7 @@ def deserialize_feature_column(config,
columns_by_name = {}
(cls,
cls_config) = generic_utils.class_and_config_for_serialized_keras_object(
cls_config) = _class_and_config_for_serialized_keras_object(
config,
module_objects=module_feature_column_classes,
custom_objects=custom_objects,
@ -205,3 +202,138 @@ def _column_name_with_class_name(fc):
A unique name as a string.
"""
return fc.__class__.__name__ + ':' + fc.name
def _serialize_keras_object(instance):
  """Converts an object into a JSON-compatible representation.

  Args:
    instance: The object to serialize. It is first passed through
      `tf_decorator.unwrap` and the unwrapped target is what gets serialized.

  Returns:
    `None` if the unwrapped target is `None`; a
    `{'class_name': ..., 'config': ...}` dict if the target exposes
    `get_config`; otherwise the target's `__name__`.

  Raises:
    ValueError: If the target has neither `get_config` nor `__name__`.
  """
  _, target = tf_decorator.unwrap(instance)
  if target is None:
    return None

  if not hasattr(target, 'get_config'):
    # Functions and classes are stored by name only.
    if hasattr(target, '__name__'):
      return target.__name__
    raise ValueError('Cannot serialize', target)

  class_name = target.__class__.__name__
  serialized = {}
  for field, value in target.get_config().items():
    if isinstance(value, six.string_types):
      serialized[field] = value
      continue

    # Non-string values are serialized recursively (e.g. custom functions,
    # custom classes). Values that cannot be serialized are stored unchanged.
    try:
      entry = _serialize_keras_object(value)
      if isinstance(entry, dict) and not isinstance(value, dict):
        # Flag dicts produced by serialization so that deserialization can
        # tell them apart from values that were dicts to begin with.
        entry['__passive_serialization__'] = True
      serialized[field] = entry
    except ValueError:
      serialized[field] = value

  return {'class_name': class_name, 'config': serialized}
def _deserialize_keras_object(identifier,
                              module_objects=None,
                              custom_objects=None,
                              printable_module_name='object'):
  """Turns the serialized form of a Keras object back into an actual object.

  Args:
    identifier: `None`, a `{'class_name': ..., 'config': ...}` dict, the name
      of a registered object, or an already deserialized function.
    module_objects: Optional dict mapping names to built-in classes or
      functions. May be `None`.
    custom_objects: Optional dict mapping names to user-provided classes or
      functions. Entries take precedence over `module_objects`.
    printable_module_name: Human-readable noun used in error messages.

  Returns:
    `None` for a `None` identifier; otherwise the deserialized object. Classes
    referenced by name are instantiated with no arguments, functions are
    returned as-is.

  Raises:
    ValueError: If a name cannot be resolved or the identifier has an
      unsupported type.
  """
  if identifier is None:
    return None

  if isinstance(identifier, dict):
    # In this case we are dealing with a Keras config dictionary.
    cls, cls_config = _class_and_config_for_serialized_keras_object(
        identifier, module_objects, custom_objects, printable_module_name)

    if not hasattr(cls, 'from_config'):
      # Then `cls` may be a function returning a class. In this case by
      # convention `config` holds the kwargs of the function.
      return cls(**cls_config)

    arg_spec = tf_inspect.getfullargspec(cls.from_config)
    if 'custom_objects' in arg_spec.args:
      # Hand `from_config` its own copy so it cannot mutate the caller's dict.
      return cls.from_config(
          cls_config, custom_objects=dict(custom_objects or {}))
    return cls.from_config(cls_config)

  if isinstance(identifier, six.string_types):
    if custom_objects and identifier in custom_objects:
      obj = custom_objects[identifier]
    else:
      # `module_objects` defaults to None; treat that as an empty registry so
      # an unknown name raises the ValueError below, not an AttributeError.
      obj = (module_objects or {}).get(identifier)
      if obj is None:
        raise ValueError(
            'Unknown ' + printable_module_name + ': ' + identifier)
    # Classes passed by name are instantiated with no args, functions are
    # returned as-is.
    if tf_inspect.isclass(obj):
      return obj()
    return obj

  if tf_inspect.isfunction(identifier):
    # If a function has already been deserialized, return as is.
    return identifier

  raise ValueError('Could not interpret serialized %s: %s' %
                   (printable_module_name, identifier))
def _class_and_config_for_serialized_keras_object(
    config,
    module_objects=None,
    custom_objects=None,
    printable_module_name='object'):
  """Returns the class and resolved config for a serialized Keras object.

  Args:
    config: Dict with a `'class_name'` entry and a `'config'` entry.
    module_objects: Optional dict mapping names to built-in classes.
    custom_objects: Optional dict mapping names to user-provided objects.
      Entries take precedence over `module_objects`.
    printable_module_name: Human-readable noun used in error messages.

  Returns:
    A `(cls, cls_config)` tuple. `cls` is the object registered under
    `config['class_name']`. `cls_config` is a shallow copy of
    `config['config']` in which passively serialized entries and names of
    custom functions have been replaced by the deserialized objects. The
    caller's `config` is left unmodified.

  Raises:
    ValueError: If `config` is malformed or the class name is not registered.
  """
  if (not isinstance(config, dict) or 'class_name' not in config or
      'config' not in config):
    raise ValueError('Improper config format: ' + str(config))

  class_name = config['class_name']
  cls = _get_registered_object(class_name, custom_objects=custom_objects,
                               module_objects=module_objects)
  if cls is None:
    raise ValueError('Unknown ' + printable_module_name + ': ' + class_name)

  serialized_config = config['config']
  # Work on a shallow copy so that the caller's config keeps its serialized
  # (JSON-compatible) values and can be reused.
  cls_config = dict(serialized_config)
  for key, item in serialized_config.items():
    if isinstance(item, dict) and '__passive_serialization__' in item:
      cls_config[key] = _deserialize_keras_object(
          item,
          module_objects=module_objects,
          custom_objects=custom_objects,
          printable_module_name='config_item')
    elif isinstance(item, six.string_types):
      # Handle custom functions here. When saving functions, we only save the
      # function's name as a string. If we find a matching string in the
      # custom objects during deserialization, we convert the string back to
      # the original function.
      # Note that a potential issue is that a string field could have a naming
      # conflict with a custom function name, but this should be a rare case.
      # This issue does not occur if a string field has a naming conflict with
      # a custom object, since the config of an object will always be a dict.
      registered = _get_registered_object(item, custom_objects)
      if tf_inspect.isfunction(registered):
        cls_config[key] = registered

  return (cls, cls_config)
def _get_registered_object(name, custom_objects=None, module_objects=None):
  """Looks up `name` in `custom_objects` first, then in `module_objects`.

  Args:
    name: Key to look up.
    custom_objects: Optional dict searched first.
    module_objects: Optional dict searched second.

  Returns:
    The first matching value, or `None` if neither dict contains `name`.
  """
  # Registries are listed in priority order; empty or missing ones are skipped.
  for registry in (custom_objects, module_objects):
    if registry and name in registry:
      return registry[name]
  return None