Copy the generic_util logic to feature column for serialization/deserialization.
This is the final dependency from feature_column to Keras. The copied functions are trimmed down version since it doesn't have access to Keras global custom object registration, which I don't think are used by feature column. The custom object scope will still work. PiperOrigin-RevId: 315621765 Change-Id: I2ae22af83d625c8e55c7fe21b42194bbdbfded23
This commit is contained in:
parent
2e9eaece7b
commit
090f260aab
@ -90,7 +90,6 @@ py_library(
|
|||||||
"//tensorflow/python:variable_scope",
|
"//tensorflow/python:variable_scope",
|
||||||
"//tensorflow/python:variables",
|
"//tensorflow/python:variables",
|
||||||
"//tensorflow/python/eager:context",
|
"//tensorflow/python/eager:context",
|
||||||
"//tensorflow/python/keras/utils:generic_utils",
|
|
||||||
"//tensorflow/python/training/tracking",
|
"//tensorflow/python/training/tracking",
|
||||||
"//tensorflow/python/training/tracking:data_structures",
|
"//tensorflow/python/training/tracking:data_structures",
|
||||||
"//third_party/py/numpy",
|
"//third_party/py/numpy",
|
||||||
|
@ -142,9 +142,6 @@ from tensorflow.python.framework import dtypes
|
|||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib
|
from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib
|
||||||
from tensorflow.python.framework import tensor_shape
|
from tensorflow.python.framework import tensor_shape
|
||||||
# TODO(b/118385027): Dependency on keras can be problematic if Keras moves out
|
|
||||||
# of the main repo.
|
|
||||||
from tensorflow.python.keras.utils import generic_utils
|
|
||||||
from tensorflow.python.ops import array_ops
|
from tensorflow.python.ops import array_ops
|
||||||
from tensorflow.python.ops import check_ops
|
from tensorflow.python.ops import check_ops
|
||||||
from tensorflow.python.ops import control_flow_ops
|
from tensorflow.python.ops import control_flow_ops
|
||||||
@ -2609,7 +2606,8 @@ class NumericColumn(
|
|||||||
def get_config(self):
|
def get_config(self):
|
||||||
"""See 'FeatureColumn` base class."""
|
"""See 'FeatureColumn` base class."""
|
||||||
config = dict(zip(self._fields, self))
|
config = dict(zip(self._fields, self))
|
||||||
config['normalizer_fn'] = generic_utils.serialize_keras_object(
|
from tensorflow.python.feature_column import serialization # pylint: disable=g-import-not-at-top
|
||||||
|
config['normalizer_fn'] = serialization._serialize_keras_object( # pylint: disable=protected-access
|
||||||
self.normalizer_fn)
|
self.normalizer_fn)
|
||||||
config['dtype'] = self.dtype.name
|
config['dtype'] = self.dtype.name
|
||||||
return config
|
return config
|
||||||
@ -2618,8 +2616,9 @@ class NumericColumn(
|
|||||||
def from_config(cls, config, custom_objects=None, columns_by_name=None):
|
def from_config(cls, config, custom_objects=None, columns_by_name=None):
|
||||||
"""See 'FeatureColumn` base class."""
|
"""See 'FeatureColumn` base class."""
|
||||||
_check_config_keys(config, cls._fields)
|
_check_config_keys(config, cls._fields)
|
||||||
|
from tensorflow.python.feature_column import serialization # pylint: disable=g-import-not-at-top
|
||||||
kwargs = _standardize_and_copy_config(config)
|
kwargs = _standardize_and_copy_config(config)
|
||||||
kwargs['normalizer_fn'] = generic_utils.deserialize_keras_object(
|
kwargs['normalizer_fn'] = serialization._deserialize_keras_object( # pylint: disable=protected-access
|
||||||
config['normalizer_fn'], custom_objects=custom_objects)
|
config['normalizer_fn'], custom_objects=custom_objects)
|
||||||
kwargs['dtype'] = dtypes.as_dtype(config['dtype'])
|
kwargs['dtype'] = dtypes.as_dtype(config['dtype'])
|
||||||
|
|
||||||
@ -3027,11 +3026,11 @@ class EmbeddingColumn(
|
|||||||
|
|
||||||
def get_config(self):
|
def get_config(self):
|
||||||
"""See 'FeatureColumn` base class."""
|
"""See 'FeatureColumn` base class."""
|
||||||
from tensorflow.python.feature_column.serialization import serialize_feature_column # pylint: disable=g-import-not-at-top
|
from tensorflow.python.feature_column import serialization # pylint: disable=g-import-not-at-top
|
||||||
config = dict(zip(self._fields, self))
|
config = dict(zip(self._fields, self))
|
||||||
config['categorical_column'] = serialize_feature_column(
|
config['categorical_column'] = serialization.serialize_feature_column(
|
||||||
self.categorical_column)
|
self.categorical_column)
|
||||||
config['initializer'] = generic_utils.serialize_keras_object(
|
config['initializer'] = serialization._serialize_keras_object( # pylint: disable=protected-access
|
||||||
self.initializer)
|
self.initializer)
|
||||||
return config
|
return config
|
||||||
|
|
||||||
@ -3040,13 +3039,13 @@ class EmbeddingColumn(
|
|||||||
"""See 'FeatureColumn` base class."""
|
"""See 'FeatureColumn` base class."""
|
||||||
if 'use_safe_embedding_lookup' not in config:
|
if 'use_safe_embedding_lookup' not in config:
|
||||||
config['use_safe_embedding_lookup'] = True
|
config['use_safe_embedding_lookup'] = True
|
||||||
from tensorflow.python.feature_column.serialization import deserialize_feature_column # pylint: disable=g-import-not-at-top
|
from tensorflow.python.feature_column import serialization # pylint: disable=g-import-not-at-top
|
||||||
_check_config_keys(config, cls._fields)
|
_check_config_keys(config, cls._fields)
|
||||||
kwargs = _standardize_and_copy_config(config)
|
kwargs = _standardize_and_copy_config(config)
|
||||||
kwargs['categorical_column'] = deserialize_feature_column(
|
kwargs['categorical_column'] = serialization.deserialize_feature_column(
|
||||||
config['categorical_column'], custom_objects, columns_by_name)
|
config['categorical_column'], custom_objects, columns_by_name)
|
||||||
all_initializers = dict(tf_inspect.getmembers(init_ops, tf_inspect.isclass))
|
all_initializers = dict(tf_inspect.getmembers(init_ops, tf_inspect.isclass))
|
||||||
kwargs['initializer'] = generic_utils.deserialize_keras_object(
|
kwargs['initializer'] = serialization._deserialize_keras_object( # pylint: disable=protected-access
|
||||||
config['initializer'],
|
config['initializer'],
|
||||||
module_objects=all_initializers,
|
module_objects=all_initializers,
|
||||||
custom_objects=custom_objects)
|
custom_objects=custom_objects)
|
||||||
|
@ -23,12 +23,9 @@ import six
|
|||||||
from tensorflow.python.feature_column import feature_column_v2 as fc_lib
|
from tensorflow.python.feature_column import feature_column_v2 as fc_lib
|
||||||
from tensorflow.python.feature_column import sequence_feature_column as sfc_lib
|
from tensorflow.python.feature_column import sequence_feature_column as sfc_lib
|
||||||
from tensorflow.python.ops import init_ops
|
from tensorflow.python.ops import init_ops
|
||||||
from tensorflow.python.util.lazy_loader import LazyLoader
|
from tensorflow.python.util import tf_decorator
|
||||||
|
from tensorflow.python.util import tf_inspect
|
||||||
|
|
||||||
# Prevent circular dependencies with Keras serialization.
|
|
||||||
generic_utils = LazyLoader(
|
|
||||||
'generic_utils', globals(),
|
|
||||||
'tensorflow.python.keras.utils.generic_utils')
|
|
||||||
|
|
||||||
_FEATURE_COLUMNS = [
|
_FEATURE_COLUMNS = [
|
||||||
fc_lib.BucketizedColumn, fc_lib.CrossedColumn, fc_lib.EmbeddingColumn,
|
fc_lib.BucketizedColumn, fc_lib.CrossedColumn, fc_lib.EmbeddingColumn,
|
||||||
@ -124,7 +121,7 @@ def deserialize_feature_column(config,
|
|||||||
columns_by_name = {}
|
columns_by_name = {}
|
||||||
|
|
||||||
(cls,
|
(cls,
|
||||||
cls_config) = generic_utils.class_and_config_for_serialized_keras_object(
|
cls_config) = _class_and_config_for_serialized_keras_object(
|
||||||
config,
|
config,
|
||||||
module_objects=module_feature_column_classes,
|
module_objects=module_feature_column_classes,
|
||||||
custom_objects=custom_objects,
|
custom_objects=custom_objects,
|
||||||
@ -205,3 +202,138 @@ def _column_name_with_class_name(fc):
|
|||||||
A unique name as a string.
|
A unique name as a string.
|
||||||
"""
|
"""
|
||||||
return fc.__class__.__name__ + ':' + fc.name
|
return fc.__class__.__name__ + ':' + fc.name
|
||||||
|
|
||||||
|
|
||||||
|
def _serialize_keras_object(instance):
|
||||||
|
"""Serialize a Keras object into a JSON-compatible representation."""
|
||||||
|
_, instance = tf_decorator.unwrap(instance)
|
||||||
|
if instance is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
if hasattr(instance, 'get_config'):
|
||||||
|
name = instance.__class__.__name__
|
||||||
|
config = instance.get_config()
|
||||||
|
serialization_config = {}
|
||||||
|
for key, item in config.items():
|
||||||
|
if isinstance(item, six.string_types):
|
||||||
|
serialization_config[key] = item
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Any object of a different type needs to be converted to string or dict
|
||||||
|
# for serialization (e.g. custom functions, custom classes)
|
||||||
|
try:
|
||||||
|
serialized_item = _serialize_keras_object(item)
|
||||||
|
if isinstance(serialized_item, dict) and not isinstance(item, dict):
|
||||||
|
serialized_item['__passive_serialization__'] = True
|
||||||
|
serialization_config[key] = serialized_item
|
||||||
|
except ValueError:
|
||||||
|
serialization_config[key] = item
|
||||||
|
|
||||||
|
return {'class_name': name, 'config': serialization_config}
|
||||||
|
if hasattr(instance, '__name__'):
|
||||||
|
return instance.__name__
|
||||||
|
raise ValueError('Cannot serialize', instance)
|
||||||
|
|
||||||
|
|
||||||
|
def _deserialize_keras_object(identifier,
|
||||||
|
module_objects=None,
|
||||||
|
custom_objects=None,
|
||||||
|
printable_module_name='object'):
|
||||||
|
"""Turns the serialized form of a Keras object back into an actual object."""
|
||||||
|
if identifier is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
if isinstance(identifier, dict):
|
||||||
|
# In this case we are dealing with a Keras config dictionary.
|
||||||
|
config = identifier
|
||||||
|
(cls, cls_config) = _class_and_config_for_serialized_keras_object(
|
||||||
|
config, module_objects, custom_objects, printable_module_name)
|
||||||
|
|
||||||
|
if hasattr(cls, 'from_config'):
|
||||||
|
arg_spec = tf_inspect.getfullargspec(cls.from_config)
|
||||||
|
custom_objects = custom_objects or {}
|
||||||
|
|
||||||
|
if 'custom_objects' in arg_spec.args:
|
||||||
|
return cls.from_config(
|
||||||
|
cls_config,
|
||||||
|
custom_objects=dict(
|
||||||
|
list(custom_objects.items())))
|
||||||
|
return cls.from_config(cls_config)
|
||||||
|
else:
|
||||||
|
# Then `cls` may be a function returning a class.
|
||||||
|
# in this case by convention `config` holds
|
||||||
|
# the kwargs of the function.
|
||||||
|
custom_objects = custom_objects or {}
|
||||||
|
return cls(**cls_config)
|
||||||
|
elif isinstance(identifier, six.string_types):
|
||||||
|
object_name = identifier
|
||||||
|
if custom_objects and object_name in custom_objects:
|
||||||
|
obj = custom_objects.get(object_name)
|
||||||
|
else:
|
||||||
|
obj = module_objects.get(object_name)
|
||||||
|
if obj is None:
|
||||||
|
raise ValueError(
|
||||||
|
'Unknown ' + printable_module_name + ': ' + object_name)
|
||||||
|
# Classes passed by name are instantiated with no args, functions are
|
||||||
|
# returned as-is.
|
||||||
|
if tf_inspect.isclass(obj):
|
||||||
|
return obj()
|
||||||
|
return obj
|
||||||
|
elif tf_inspect.isfunction(identifier):
|
||||||
|
# If a function has already been deserialized, return as is.
|
||||||
|
return identifier
|
||||||
|
else:
|
||||||
|
raise ValueError('Could not interpret serialized %s: %s' %
|
||||||
|
(printable_module_name, identifier))
|
||||||
|
|
||||||
|
|
||||||
|
def _class_and_config_for_serialized_keras_object(
|
||||||
|
config,
|
||||||
|
module_objects=None,
|
||||||
|
custom_objects=None,
|
||||||
|
printable_module_name='object'):
|
||||||
|
"""Returns the class name and config for a serialized keras object."""
|
||||||
|
if (not isinstance(config, dict) or 'class_name' not in config or
|
||||||
|
'config' not in config):
|
||||||
|
raise ValueError('Improper config format: ' + str(config))
|
||||||
|
|
||||||
|
class_name = config['class_name']
|
||||||
|
cls = _get_registered_object(class_name, custom_objects=custom_objects,
|
||||||
|
module_objects=module_objects)
|
||||||
|
if cls is None:
|
||||||
|
raise ValueError('Unknown ' + printable_module_name + ': ' + class_name)
|
||||||
|
|
||||||
|
cls_config = config['config']
|
||||||
|
|
||||||
|
deserialized_objects = {}
|
||||||
|
for key, item in cls_config.items():
|
||||||
|
if isinstance(item, dict) and '__passive_serialization__' in item:
|
||||||
|
deserialized_objects[key] = _deserialize_keras_object(
|
||||||
|
item,
|
||||||
|
module_objects=module_objects,
|
||||||
|
custom_objects=custom_objects,
|
||||||
|
printable_module_name='config_item')
|
||||||
|
elif (isinstance(item, six.string_types) and
|
||||||
|
tf_inspect.isfunction(_get_registered_object(item, custom_objects))):
|
||||||
|
# Handle custom functions here. When saving functions, we only save the
|
||||||
|
# function's name as a string. If we find a matching string in the custom
|
||||||
|
# objects during deserialization, we convert the string back to the
|
||||||
|
# original function.
|
||||||
|
# Note that a potential issue is that a string field could have a naming
|
||||||
|
# conflict with a custom function name, but this should be a rare case.
|
||||||
|
# This issue does not occur if a string field has a naming conflict with
|
||||||
|
# a custom object, since the config of an object will always be a dict.
|
||||||
|
deserialized_objects[key] = _get_registered_object(item, custom_objects)
|
||||||
|
for key, item in deserialized_objects.items():
|
||||||
|
cls_config[key] = deserialized_objects[key]
|
||||||
|
|
||||||
|
return (cls, cls_config)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_registered_object(name, custom_objects=None, module_objects=None):
|
||||||
|
if custom_objects and name in custom_objects:
|
||||||
|
return custom_objects[name]
|
||||||
|
elif module_objects and name in module_objects:
|
||||||
|
return module_objects[name]
|
||||||
|
return None
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user