Expose utilities for getting custom Keras objects/object names.
PiperOrigin-RevId: 286474375 Change-Id: I9951b7cd40b409ad70bd3111acef8a735c4170db
This commit is contained in:
parent
69bb090113
commit
00526b3758
@ -96,8 +96,8 @@ def custom_object_scope(*args):
|
|||||||
```
|
```
|
||||||
|
|
||||||
Arguments:
|
Arguments:
|
||||||
*args: Variable length list of dictionaries of name,
|
*args: Variable length list of dictionaries of name, class pairs to add to
|
||||||
class pairs to add to custom objects.
|
custom objects.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Object of type `CustomObjectScope`.
|
Object of type `CustomObjectScope`.
|
||||||
@ -180,13 +180,63 @@ def register_keras_serializable(package='Custom', name=None):
|
|||||||
return decorator
|
return decorator
|
||||||
|
|
||||||
|
|
||||||
def _get_name_or_custom_name(obj):
|
@keras_export('keras.utils.get_registered_name')
|
||||||
|
def get_registered_name(obj):
|
||||||
|
"""Returns the name registered to an object within the Keras framework.
|
||||||
|
|
||||||
|
This function is part of the Keras serialization and deserialization
|
||||||
|
framework. It maps objects to the string names associated with those objects
|
||||||
|
for serialization/deserialization.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
obj: The object to look up.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The name associated with the object, or the default Python name if the
|
||||||
|
object is not registered.
|
||||||
|
"""
|
||||||
if obj in _GLOBAL_CUSTOM_NAMES:
|
if obj in _GLOBAL_CUSTOM_NAMES:
|
||||||
return _GLOBAL_CUSTOM_NAMES[obj]
|
return _GLOBAL_CUSTOM_NAMES[obj]
|
||||||
else:
|
else:
|
||||||
return obj.__name__
|
return obj.__name__
|
||||||
|
|
||||||
|
|
||||||
|
@keras_export('keras.utils.get_registered_object')
|
||||||
|
def get_registered_object(name, custom_objects=None, module_objects=None):
|
||||||
|
"""Returns the class associated with `name` if it is registered with Keras.
|
||||||
|
|
||||||
|
This function is part of the Keras serialization and deserialization
|
||||||
|
framework. It maps strings to the objects associated with them for
|
||||||
|
serialization/deserialization.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```
|
||||||
|
def from_config(cls, config, custom_objects=None):
|
||||||
|
if 'my_custom_object_name' in config:
|
||||||
|
config['hidden_cls'] = tf.keras.utils.get_registered_object(
|
||||||
|
config['my_custom_object_name'], custom_objects=custom_objects)
|
||||||
|
```
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: The name to look up.
|
||||||
|
custom_objects: A dictionary of custom objects to look the name up in.
|
||||||
|
Generally, custom_objects is provided by the user.
|
||||||
|
module_objects: A dictionary of custom objects to look the name up in.
|
||||||
|
Generally, module_objects is provided by midlevel library implementers.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
An instantiable class associated with 'name', or None if no such class
|
||||||
|
exists.
|
||||||
|
"""
|
||||||
|
if name in _GLOBAL_CUSTOM_OBJECTS:
|
||||||
|
return _GLOBAL_CUSTOM_OBJECTS[name]
|
||||||
|
elif custom_objects and name in custom_objects:
|
||||||
|
return custom_objects[name]
|
||||||
|
elif module_objects and name in module_objects:
|
||||||
|
return module_objects[name]
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
@keras_export('keras.utils.serialize_keras_object')
|
@keras_export('keras.utils.serialize_keras_object')
|
||||||
def serialize_keras_object(instance):
|
def serialize_keras_object(instance):
|
||||||
"""Serialize Keras object into JSON."""
|
"""Serialize Keras object into JSON."""
|
||||||
@ -212,22 +262,13 @@ def serialize_keras_object(instance):
|
|||||||
except ValueError:
|
except ValueError:
|
||||||
serialization_config[key] = item
|
serialization_config[key] = item
|
||||||
|
|
||||||
name = _get_name_or_custom_name(instance.__class__)
|
name = get_registered_name(instance.__class__)
|
||||||
return serialize_keras_class_and_config(name, serialization_config)
|
return serialize_keras_class_and_config(name, serialization_config)
|
||||||
if hasattr(instance, '__name__'):
|
if hasattr(instance, '__name__'):
|
||||||
return _get_name_or_custom_name(instance)
|
return get_registered_name(instance)
|
||||||
raise ValueError('Cannot serialize', instance)
|
raise ValueError('Cannot serialize', instance)
|
||||||
|
|
||||||
|
|
||||||
def _get_custom_objects_by_name(item, custom_objects=None):
|
|
||||||
"""Returns the item if it is in either local or global custom objects."""
|
|
||||||
if item in _GLOBAL_CUSTOM_OBJECTS:
|
|
||||||
return _GLOBAL_CUSTOM_OBJECTS[item]
|
|
||||||
elif custom_objects and item in custom_objects:
|
|
||||||
return custom_objects[item]
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def class_and_config_for_serialized_keras_object(
|
def class_and_config_for_serialized_keras_object(
|
||||||
config,
|
config,
|
||||||
module_objects=None,
|
module_objects=None,
|
||||||
@ -239,15 +280,9 @@ def class_and_config_for_serialized_keras_object(
|
|||||||
raise ValueError('Improper config format: ' + str(config))
|
raise ValueError('Improper config format: ' + str(config))
|
||||||
|
|
||||||
class_name = config['class_name']
|
class_name = config['class_name']
|
||||||
if custom_objects and class_name in custom_objects:
|
cls = get_registered_object(class_name, custom_objects, module_objects)
|
||||||
cls = custom_objects[class_name]
|
if cls is None:
|
||||||
elif class_name in _GLOBAL_CUSTOM_OBJECTS:
|
raise ValueError('Unknown ' + printable_module_name + ': ' + class_name)
|
||||||
cls = _GLOBAL_CUSTOM_OBJECTS[class_name]
|
|
||||||
else:
|
|
||||||
module_objects = module_objects or {}
|
|
||||||
cls = module_objects.get(class_name)
|
|
||||||
if cls is None:
|
|
||||||
raise ValueError('Unknown ' + printable_module_name + ': ' + class_name)
|
|
||||||
|
|
||||||
cls_config = config['config']
|
cls_config = config['config']
|
||||||
deserialized_objects = {}
|
deserialized_objects = {}
|
||||||
@ -258,9 +293,9 @@ def class_and_config_for_serialized_keras_object(
|
|||||||
module_objects=module_objects,
|
module_objects=module_objects,
|
||||||
custom_objects=custom_objects,
|
custom_objects=custom_objects,
|
||||||
printable_module_name='config_item')
|
printable_module_name='config_item')
|
||||||
|
# TODO(momernick): Should this also have 'module_objects'?
|
||||||
elif (isinstance(item, six.string_types) and
|
elif (isinstance(item, six.string_types) and
|
||||||
tf_inspect.isfunction(
|
tf_inspect.isfunction(get_registered_object(item, custom_objects))):
|
||||||
_get_custom_objects_by_name(item, custom_objects))):
|
|
||||||
# Handle custom functions here. When saving functions, we only save the
|
# Handle custom functions here. When saving functions, we only save the
|
||||||
# function's name as a string. If we find a matching string in the custom
|
# function's name as a string. If we find a matching string in the custom
|
||||||
# objects during deserialization, we convert the string back to the
|
# objects during deserialization, we convert the string back to the
|
||||||
@ -269,8 +304,7 @@ def class_and_config_for_serialized_keras_object(
|
|||||||
# conflict with a custom function name, but this should be a rare case.
|
# conflict with a custom function name, but this should be a rare case.
|
||||||
# This issue does not occur if a string field has a naming conflict with
|
# This issue does not occur if a string field has a naming conflict with
|
||||||
# a custom object, since the config of an object will always be a dict.
|
# a custom object, since the config of an object will always be a dict.
|
||||||
deserialized_objects[key] = _get_custom_objects_by_name(
|
deserialized_objects[key] = get_registered_object(item, custom_objects)
|
||||||
item, custom_objects)
|
|
||||||
for key, item in deserialized_objects.items():
|
for key, item in deserialized_objects.items():
|
||||||
cls_config[key] = deserialized_objects[key]
|
cls_config[key] = deserialized_objects[key]
|
||||||
|
|
||||||
@ -382,6 +416,7 @@ def func_load(code, defaults=None, closure=None, globs=None):
|
|||||||
Returns:
|
Returns:
|
||||||
A value wrapped as a cell object (see function "func_load")
|
A value wrapped as a cell object (see function "func_load")
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def dummy_fn():
|
def dummy_fn():
|
||||||
# pylint: disable=pointless-statement
|
# pylint: disable=pointless-statement
|
||||||
value # just access it so it gets captured in .__closure__
|
value # just access it so it gets captured in .__closure__
|
||||||
@ -410,8 +445,8 @@ def has_arg(fn, name, accept_all=False):
|
|||||||
Arguments:
|
Arguments:
|
||||||
fn: Callable to inspect.
|
fn: Callable to inspect.
|
||||||
name: Check if `fn` can be called with `name` as a keyword argument.
|
name: Check if `fn` can be called with `name` as a keyword argument.
|
||||||
accept_all: What to return if there is no parameter called `name`
|
accept_all: What to return if there is no parameter called `name` but the
|
||||||
but the function accepts a `**kwargs` argument.
|
function accepts a `**kwargs` argument.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
bool, whether `fn` accepts a `name` keyword argument.
|
bool, whether `fn` accepts a `name` keyword argument.
|
||||||
@ -430,16 +465,20 @@ class Progbar(object):
|
|||||||
target: Total number of steps expected, None if unknown.
|
target: Total number of steps expected, None if unknown.
|
||||||
width: Progress bar width on screen.
|
width: Progress bar width on screen.
|
||||||
verbose: Verbosity mode, 0 (silent), 1 (verbose), 2 (semi-verbose)
|
verbose: Verbosity mode, 0 (silent), 1 (verbose), 2 (semi-verbose)
|
||||||
stateful_metrics: Iterable of string names of metrics that
|
stateful_metrics: Iterable of string names of metrics that should *not* be
|
||||||
should *not* be averaged over time. Metrics in this list
|
averaged over time. Metrics in this list will be displayed as-is. All
|
||||||
will be displayed as-is. All others will be averaged
|
others will be averaged by the progbar before display.
|
||||||
by the progbar before display.
|
|
||||||
interval: Minimum visual progress update interval (in seconds).
|
interval: Minimum visual progress update interval (in seconds).
|
||||||
unit_name: Display name for step counts (usually "step" or "sample").
|
unit_name: Display name for step counts (usually "step" or "sample").
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, target, width=30, verbose=1, interval=0.05,
|
def __init__(self,
|
||||||
stateful_metrics=None, unit_name='step'):
|
target,
|
||||||
|
width=30,
|
||||||
|
verbose=1,
|
||||||
|
interval=0.05,
|
||||||
|
stateful_metrics=None,
|
||||||
|
unit_name='step'):
|
||||||
self.target = target
|
self.target = target
|
||||||
self.width = width
|
self.width = width
|
||||||
self.verbose = verbose
|
self.verbose = verbose
|
||||||
@ -469,11 +508,9 @@ class Progbar(object):
|
|||||||
|
|
||||||
Arguments:
|
Arguments:
|
||||||
current: Index of current step.
|
current: Index of current step.
|
||||||
values: List of tuples:
|
values: List of tuples: `(name, value_for_last_step)`. If `name` is in
|
||||||
`(name, value_for_last_step)`.
|
`stateful_metrics`, `value_for_last_step` will be displayed as-is.
|
||||||
If `name` is in `stateful_metrics`,
|
Else, an average of the metric over time will be displayed.
|
||||||
`value_for_last_step` will be displayed as-is.
|
|
||||||
Else, an average of the metric over time will be displayed.
|
|
||||||
"""
|
"""
|
||||||
values = values or []
|
values = values or []
|
||||||
for k, v in values:
|
for k, v in values:
|
||||||
@ -538,8 +575,7 @@ class Progbar(object):
|
|||||||
eta = time_per_unit * (self.target - current)
|
eta = time_per_unit * (self.target - current)
|
||||||
if eta > 3600:
|
if eta > 3600:
|
||||||
eta_format = '%d:%02d:%02d' % (eta // 3600,
|
eta_format = '%d:%02d:%02d' % (eta // 3600,
|
||||||
(eta % 3600) // 60,
|
(eta % 3600) // 60, eta % 60)
|
||||||
eta % 60)
|
|
||||||
elif eta > 60:
|
elif eta > 60:
|
||||||
eta_format = '%d:%02d' % (eta // 60, eta % 60)
|
eta_format = '%d:%02d' % (eta // 60, eta % 60)
|
||||||
else:
|
else:
|
||||||
@ -625,10 +661,8 @@ def slice_arrays(arrays, start=None, stop=None):
|
|||||||
|
|
||||||
Arguments:
|
Arguments:
|
||||||
arrays: Single array or list of arrays.
|
arrays: Single array or list of arrays.
|
||||||
start: can be an integer index (start index)
|
start: can be an integer index (start index) or a list/array of indices
|
||||||
or a list/array of indices
|
stop: integer (stop index); should be None if `start` was a list.
|
||||||
stop: integer (stop index); should be None if
|
|
||||||
`start` was a list.
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A slice of the array(s).
|
A slice of the array(s).
|
||||||
@ -711,7 +745,8 @@ def check_for_unexpected_keys(name, input_dict, expected_values):
|
|||||||
expected_values))
|
expected_values))
|
||||||
|
|
||||||
|
|
||||||
def validate_kwargs(kwargs, allowed_kwargs,
|
def validate_kwargs(kwargs,
|
||||||
|
allowed_kwargs,
|
||||||
error_message='Keyword argument not understood:'):
|
error_message='Keyword argument not understood:'):
|
||||||
"""Checks that all keyword arguments are in the set of allowed keys."""
|
"""Checks that all keyword arguments are in the set of allowed keys."""
|
||||||
for kwarg in kwargs:
|
for kwarg in kwargs:
|
||||||
|
@ -129,6 +129,13 @@ class SerializeKerasObjectTest(test.TestCase):
|
|||||||
inst = OtherTestClass(val=5)
|
inst = OtherTestClass(val=5)
|
||||||
class_name = keras.utils.generic_utils._GLOBAL_CUSTOM_NAMES[OtherTestClass]
|
class_name = keras.utils.generic_utils._GLOBAL_CUSTOM_NAMES[OtherTestClass]
|
||||||
self.assertEqual(serialized_name, class_name)
|
self.assertEqual(serialized_name, class_name)
|
||||||
|
fn_class_name = keras.utils.generic_utils.get_registered_name(
|
||||||
|
OtherTestClass)
|
||||||
|
self.assertEqual(fn_class_name, class_name)
|
||||||
|
|
||||||
|
cls = keras.utils.generic_utils.get_registered_object(fn_class_name)
|
||||||
|
self.assertEqual(OtherTestClass, cls)
|
||||||
|
|
||||||
config = keras.utils.generic_utils.serialize_keras_object(inst)
|
config = keras.utils.generic_utils.serialize_keras_object(inst)
|
||||||
self.assertEqual(class_name, config['class_name'])
|
self.assertEqual(class_name, config['class_name'])
|
||||||
new_inst = keras.utils.generic_utils.deserialize_keras_object(config)
|
new_inst = keras.utils.generic_utils.deserialize_keras_object(config)
|
||||||
@ -145,11 +152,17 @@ class SerializeKerasObjectTest(test.TestCase):
|
|||||||
serialized_name = 'Custom>my_fn'
|
serialized_name = 'Custom>my_fn'
|
||||||
class_name = keras.utils.generic_utils._GLOBAL_CUSTOM_NAMES[my_fn]
|
class_name = keras.utils.generic_utils._GLOBAL_CUSTOM_NAMES[my_fn]
|
||||||
self.assertEqual(serialized_name, class_name)
|
self.assertEqual(serialized_name, class_name)
|
||||||
|
fn_class_name = keras.utils.generic_utils.get_registered_name(my_fn)
|
||||||
|
self.assertEqual(fn_class_name, class_name)
|
||||||
|
|
||||||
config = keras.utils.generic_utils.serialize_keras_object(my_fn)
|
config = keras.utils.generic_utils.serialize_keras_object(my_fn)
|
||||||
self.assertEqual(class_name, config)
|
self.assertEqual(class_name, config)
|
||||||
fn = keras.utils.generic_utils.deserialize_keras_object(config)
|
fn = keras.utils.generic_utils.deserialize_keras_object(config)
|
||||||
self.assertEqual(42, fn())
|
self.assertEqual(42, fn())
|
||||||
|
|
||||||
|
fn_2 = keras.utils.generic_utils.get_registered_object(fn_class_name)
|
||||||
|
self.assertEqual(42, fn_2())
|
||||||
|
|
||||||
def test_serialize_custom_class_without_get_config_fails(self):
|
def test_serialize_custom_class_without_get_config_fails(self):
|
||||||
|
|
||||||
with self.assertRaisesRegex(
|
with self.assertRaisesRegex(
|
||||||
|
@ -48,6 +48,14 @@ tf_module {
|
|||||||
name: "get_file"
|
name: "get_file"
|
||||||
argspec: "args=[\'fname\', \'origin\', \'untar\', \'md5_hash\', \'file_hash\', \'cache_subdir\', \'hash_algorithm\', \'extract\', \'archive_format\', \'cache_dir\'], varargs=None, keywords=None, defaults=[\'False\', \'None\', \'None\', \'datasets\', \'auto\', \'False\', \'auto\', \'None\'], "
|
argspec: "args=[\'fname\', \'origin\', \'untar\', \'md5_hash\', \'file_hash\', \'cache_subdir\', \'hash_algorithm\', \'extract\', \'archive_format\', \'cache_dir\'], varargs=None, keywords=None, defaults=[\'False\', \'None\', \'None\', \'datasets\', \'auto\', \'False\', \'auto\', \'None\'], "
|
||||||
}
|
}
|
||||||
|
member_method {
|
||||||
|
name: "get_registered_name"
|
||||||
|
argspec: "args=[\'obj\'], varargs=None, keywords=None, defaults=None"
|
||||||
|
}
|
||||||
|
member_method {
|
||||||
|
name: "get_registered_object"
|
||||||
|
argspec: "args=[\'name\', \'custom_objects\', \'module_objects\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
|
||||||
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "get_source_inputs"
|
name: "get_source_inputs"
|
||||||
argspec: "args=[\'tensor\', \'layer\', \'node_index\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
|
argspec: "args=[\'tensor\', \'layer\', \'node_index\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
|
||||||
|
@ -48,6 +48,14 @@ tf_module {
|
|||||||
name: "get_file"
|
name: "get_file"
|
||||||
argspec: "args=[\'fname\', \'origin\', \'untar\', \'md5_hash\', \'file_hash\', \'cache_subdir\', \'hash_algorithm\', \'extract\', \'archive_format\', \'cache_dir\'], varargs=None, keywords=None, defaults=[\'False\', \'None\', \'None\', \'datasets\', \'auto\', \'False\', \'auto\', \'None\'], "
|
argspec: "args=[\'fname\', \'origin\', \'untar\', \'md5_hash\', \'file_hash\', \'cache_subdir\', \'hash_algorithm\', \'extract\', \'archive_format\', \'cache_dir\'], varargs=None, keywords=None, defaults=[\'False\', \'None\', \'None\', \'datasets\', \'auto\', \'False\', \'auto\', \'None\'], "
|
||||||
}
|
}
|
||||||
|
member_method {
|
||||||
|
name: "get_registered_name"
|
||||||
|
argspec: "args=[\'obj\'], varargs=None, keywords=None, defaults=None"
|
||||||
|
}
|
||||||
|
member_method {
|
||||||
|
name: "get_registered_object"
|
||||||
|
argspec: "args=[\'name\', \'custom_objects\', \'module_objects\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
|
||||||
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "get_source_inputs"
|
name: "get_source_inputs"
|
||||||
argspec: "args=[\'tensor\', \'layer\', \'node_index\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
|
argspec: "args=[\'tensor\', \'layer\', \'node_index\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
|
||||||
|
Loading…
Reference in New Issue
Block a user