Copy the generic_util logic to feature column for serialization/deserialization.
This is the final dependency from feature_column to Keras. The copied functions are trimmed down version since it doesn't have access to Keras global custom object registration, which I don't think are used by feature column. The custom object scope will still work. PiperOrigin-RevId: 315621765 Change-Id: I2ae22af83d625c8e55c7fe21b42194bbdbfded23
This commit is contained in:
parent
2e9eaece7b
commit
090f260aab
@ -90,7 +90,6 @@ py_library(
|
||||
"//tensorflow/python:variable_scope",
|
||||
"//tensorflow/python:variables",
|
||||
"//tensorflow/python/eager:context",
|
||||
"//tensorflow/python/keras/utils:generic_utils",
|
||||
"//tensorflow/python/training/tracking",
|
||||
"//tensorflow/python/training/tracking:data_structures",
|
||||
"//third_party/py/numpy",
|
||||
|
@ -142,9 +142,6 @@ from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
# TODO(b/118385027): Dependency on keras can be problematic if Keras moves out
|
||||
# of the main repo.
|
||||
from tensorflow.python.keras.utils import generic_utils
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import check_ops
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
@ -2609,7 +2606,8 @@ class NumericColumn(
|
||||
def get_config(self):
|
||||
"""See 'FeatureColumn` base class."""
|
||||
config = dict(zip(self._fields, self))
|
||||
config['normalizer_fn'] = generic_utils.serialize_keras_object(
|
||||
from tensorflow.python.feature_column import serialization # pylint: disable=g-import-not-at-top
|
||||
config['normalizer_fn'] = serialization._serialize_keras_object( # pylint: disable=protected-access
|
||||
self.normalizer_fn)
|
||||
config['dtype'] = self.dtype.name
|
||||
return config
|
||||
@ -2618,8 +2616,9 @@ class NumericColumn(
|
||||
def from_config(cls, config, custom_objects=None, columns_by_name=None):
|
||||
"""See 'FeatureColumn` base class."""
|
||||
_check_config_keys(config, cls._fields)
|
||||
from tensorflow.python.feature_column import serialization # pylint: disable=g-import-not-at-top
|
||||
kwargs = _standardize_and_copy_config(config)
|
||||
kwargs['normalizer_fn'] = generic_utils.deserialize_keras_object(
|
||||
kwargs['normalizer_fn'] = serialization._deserialize_keras_object( # pylint: disable=protected-access
|
||||
config['normalizer_fn'], custom_objects=custom_objects)
|
||||
kwargs['dtype'] = dtypes.as_dtype(config['dtype'])
|
||||
|
||||
@ -3027,11 +3026,11 @@ class EmbeddingColumn(
|
||||
|
||||
def get_config(self):
|
||||
"""See 'FeatureColumn` base class."""
|
||||
from tensorflow.python.feature_column.serialization import serialize_feature_column # pylint: disable=g-import-not-at-top
|
||||
from tensorflow.python.feature_column import serialization # pylint: disable=g-import-not-at-top
|
||||
config = dict(zip(self._fields, self))
|
||||
config['categorical_column'] = serialize_feature_column(
|
||||
config['categorical_column'] = serialization.serialize_feature_column(
|
||||
self.categorical_column)
|
||||
config['initializer'] = generic_utils.serialize_keras_object(
|
||||
config['initializer'] = serialization._serialize_keras_object( # pylint: disable=protected-access
|
||||
self.initializer)
|
||||
return config
|
||||
|
||||
@ -3040,13 +3039,13 @@ class EmbeddingColumn(
|
||||
"""See 'FeatureColumn` base class."""
|
||||
if 'use_safe_embedding_lookup' not in config:
|
||||
config['use_safe_embedding_lookup'] = True
|
||||
from tensorflow.python.feature_column.serialization import deserialize_feature_column # pylint: disable=g-import-not-at-top
|
||||
from tensorflow.python.feature_column import serialization # pylint: disable=g-import-not-at-top
|
||||
_check_config_keys(config, cls._fields)
|
||||
kwargs = _standardize_and_copy_config(config)
|
||||
kwargs['categorical_column'] = deserialize_feature_column(
|
||||
kwargs['categorical_column'] = serialization.deserialize_feature_column(
|
||||
config['categorical_column'], custom_objects, columns_by_name)
|
||||
all_initializers = dict(tf_inspect.getmembers(init_ops, tf_inspect.isclass))
|
||||
kwargs['initializer'] = generic_utils.deserialize_keras_object(
|
||||
kwargs['initializer'] = serialization._deserialize_keras_object( # pylint: disable=protected-access
|
||||
config['initializer'],
|
||||
module_objects=all_initializers,
|
||||
custom_objects=custom_objects)
|
||||
|
@ -23,12 +23,9 @@ import six
|
||||
from tensorflow.python.feature_column import feature_column_v2 as fc_lib
|
||||
from tensorflow.python.feature_column import sequence_feature_column as sfc_lib
|
||||
from tensorflow.python.ops import init_ops
|
||||
from tensorflow.python.util.lazy_loader import LazyLoader
|
||||
from tensorflow.python.util import tf_decorator
|
||||
from tensorflow.python.util import tf_inspect
|
||||
|
||||
# Prevent circular dependencies with Keras serialization.
|
||||
generic_utils = LazyLoader(
|
||||
'generic_utils', globals(),
|
||||
'tensorflow.python.keras.utils.generic_utils')
|
||||
|
||||
_FEATURE_COLUMNS = [
|
||||
fc_lib.BucketizedColumn, fc_lib.CrossedColumn, fc_lib.EmbeddingColumn,
|
||||
@ -124,7 +121,7 @@ def deserialize_feature_column(config,
|
||||
columns_by_name = {}
|
||||
|
||||
(cls,
|
||||
cls_config) = generic_utils.class_and_config_for_serialized_keras_object(
|
||||
cls_config) = _class_and_config_for_serialized_keras_object(
|
||||
config,
|
||||
module_objects=module_feature_column_classes,
|
||||
custom_objects=custom_objects,
|
||||
@ -205,3 +202,138 @@ def _column_name_with_class_name(fc):
|
||||
A unique name as a string.
|
||||
"""
|
||||
return fc.__class__.__name__ + ':' + fc.name
|
||||
|
||||
|
||||
def _serialize_keras_object(instance):
|
||||
"""Serialize a Keras object into a JSON-compatible representation."""
|
||||
_, instance = tf_decorator.unwrap(instance)
|
||||
if instance is None:
|
||||
return None
|
||||
|
||||
if hasattr(instance, 'get_config'):
|
||||
name = instance.__class__.__name__
|
||||
config = instance.get_config()
|
||||
serialization_config = {}
|
||||
for key, item in config.items():
|
||||
if isinstance(item, six.string_types):
|
||||
serialization_config[key] = item
|
||||
continue
|
||||
|
||||
# Any object of a different type needs to be converted to string or dict
|
||||
# for serialization (e.g. custom functions, custom classes)
|
||||
try:
|
||||
serialized_item = _serialize_keras_object(item)
|
||||
if isinstance(serialized_item, dict) and not isinstance(item, dict):
|
||||
serialized_item['__passive_serialization__'] = True
|
||||
serialization_config[key] = serialized_item
|
||||
except ValueError:
|
||||
serialization_config[key] = item
|
||||
|
||||
return {'class_name': name, 'config': serialization_config}
|
||||
if hasattr(instance, '__name__'):
|
||||
return instance.__name__
|
||||
raise ValueError('Cannot serialize', instance)
|
||||
|
||||
|
||||
def _deserialize_keras_object(identifier,
|
||||
module_objects=None,
|
||||
custom_objects=None,
|
||||
printable_module_name='object'):
|
||||
"""Turns the serialized form of a Keras object back into an actual object."""
|
||||
if identifier is None:
|
||||
return None
|
||||
|
||||
if isinstance(identifier, dict):
|
||||
# In this case we are dealing with a Keras config dictionary.
|
||||
config = identifier
|
||||
(cls, cls_config) = _class_and_config_for_serialized_keras_object(
|
||||
config, module_objects, custom_objects, printable_module_name)
|
||||
|
||||
if hasattr(cls, 'from_config'):
|
||||
arg_spec = tf_inspect.getfullargspec(cls.from_config)
|
||||
custom_objects = custom_objects or {}
|
||||
|
||||
if 'custom_objects' in arg_spec.args:
|
||||
return cls.from_config(
|
||||
cls_config,
|
||||
custom_objects=dict(
|
||||
list(custom_objects.items())))
|
||||
return cls.from_config(cls_config)
|
||||
else:
|
||||
# Then `cls` may be a function returning a class.
|
||||
# in this case by convention `config` holds
|
||||
# the kwargs of the function.
|
||||
custom_objects = custom_objects or {}
|
||||
return cls(**cls_config)
|
||||
elif isinstance(identifier, six.string_types):
|
||||
object_name = identifier
|
||||
if custom_objects and object_name in custom_objects:
|
||||
obj = custom_objects.get(object_name)
|
||||
else:
|
||||
obj = module_objects.get(object_name)
|
||||
if obj is None:
|
||||
raise ValueError(
|
||||
'Unknown ' + printable_module_name + ': ' + object_name)
|
||||
# Classes passed by name are instantiated with no args, functions are
|
||||
# returned as-is.
|
||||
if tf_inspect.isclass(obj):
|
||||
return obj()
|
||||
return obj
|
||||
elif tf_inspect.isfunction(identifier):
|
||||
# If a function has already been deserialized, return as is.
|
||||
return identifier
|
||||
else:
|
||||
raise ValueError('Could not interpret serialized %s: %s' %
|
||||
(printable_module_name, identifier))
|
||||
|
||||
|
||||
def _class_and_config_for_serialized_keras_object(
|
||||
config,
|
||||
module_objects=None,
|
||||
custom_objects=None,
|
||||
printable_module_name='object'):
|
||||
"""Returns the class name and config for a serialized keras object."""
|
||||
if (not isinstance(config, dict) or 'class_name' not in config or
|
||||
'config' not in config):
|
||||
raise ValueError('Improper config format: ' + str(config))
|
||||
|
||||
class_name = config['class_name']
|
||||
cls = _get_registered_object(class_name, custom_objects=custom_objects,
|
||||
module_objects=module_objects)
|
||||
if cls is None:
|
||||
raise ValueError('Unknown ' + printable_module_name + ': ' + class_name)
|
||||
|
||||
cls_config = config['config']
|
||||
|
||||
deserialized_objects = {}
|
||||
for key, item in cls_config.items():
|
||||
if isinstance(item, dict) and '__passive_serialization__' in item:
|
||||
deserialized_objects[key] = _deserialize_keras_object(
|
||||
item,
|
||||
module_objects=module_objects,
|
||||
custom_objects=custom_objects,
|
||||
printable_module_name='config_item')
|
||||
elif (isinstance(item, six.string_types) and
|
||||
tf_inspect.isfunction(_get_registered_object(item, custom_objects))):
|
||||
# Handle custom functions here. When saving functions, we only save the
|
||||
# function's name as a string. If we find a matching string in the custom
|
||||
# objects during deserialization, we convert the string back to the
|
||||
# original function.
|
||||
# Note that a potential issue is that a string field could have a naming
|
||||
# conflict with a custom function name, but this should be a rare case.
|
||||
# This issue does not occur if a string field has a naming conflict with
|
||||
# a custom object, since the config of an object will always be a dict.
|
||||
deserialized_objects[key] = _get_registered_object(item, custom_objects)
|
||||
for key, item in deserialized_objects.items():
|
||||
cls_config[key] = deserialized_objects[key]
|
||||
|
||||
return (cls, cls_config)
|
||||
|
||||
|
||||
def _get_registered_object(name, custom_objects=None, module_objects=None):
|
||||
if custom_objects and name in custom_objects:
|
||||
return custom_objects[name]
|
||||
elif module_objects and name in module_objects:
|
||||
return module_objects[name]
|
||||
return None
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user