Expose utilities for getting custom Keras objects/object names.
PiperOrigin-RevId: 286474375 Change-Id: I9951b7cd40b409ad70bd3111acef8a735c4170db
This commit is contained in:
parent
69bb090113
commit
00526b3758
@ -96,8 +96,8 @@ def custom_object_scope(*args):
|
||||
```
|
||||
|
||||
Arguments:
|
||||
*args: Variable length list of dictionaries of name,
|
||||
class pairs to add to custom objects.
|
||||
*args: Variable length list of dictionaries of name, class pairs to add to
|
||||
custom objects.
|
||||
|
||||
Returns:
|
||||
Object of type `CustomObjectScope`.
|
||||
@ -180,13 +180,63 @@ def register_keras_serializable(package='Custom', name=None):
|
||||
return decorator
|
||||
|
||||
|
||||
def _get_name_or_custom_name(obj):
|
||||
@keras_export('keras.utils.get_registered_name')
|
||||
def get_registered_name(obj):
|
||||
"""Returns the name registered to an object within the Keras framework.
|
||||
|
||||
This function is part of the Keras serialization and deserialization
|
||||
framework. It maps objects to the string names associated with those objects
|
||||
for serialization/deserialization.
|
||||
|
||||
Args:
|
||||
obj: The object to look up.
|
||||
|
||||
Returns:
|
||||
The name associated with the object, or the default Python name if the
|
||||
object is not registered.
|
||||
"""
|
||||
if obj in _GLOBAL_CUSTOM_NAMES:
|
||||
return _GLOBAL_CUSTOM_NAMES[obj]
|
||||
else:
|
||||
return obj.__name__
|
||||
|
||||
|
||||
@keras_export('keras.utils.get_registered_object')
|
||||
def get_registered_object(name, custom_objects=None, module_objects=None):
|
||||
"""Returns the class associated with `name` if it is registered with Keras.
|
||||
|
||||
This function is part of the Keras serialization and deserialization
|
||||
framework. It maps strings to the objects associated with them for
|
||||
serialization/deserialization.
|
||||
|
||||
Example:
|
||||
```
|
||||
def from_config(cls, config, custom_objects=None):
|
||||
if 'my_custom_object_name' in config:
|
||||
config['hidden_cls'] = tf.keras.utils.get_registered_object(
|
||||
config['my_custom_object_name'], custom_objects=custom_objects)
|
||||
```
|
||||
|
||||
Args:
|
||||
name: The name to look up.
|
||||
custom_objects: A dictionary of custom objects to look the name up in.
|
||||
Generally, custom_objects is provided by the user.
|
||||
module_objects: A dictionary of custom objects to look the name up in.
|
||||
Generally, module_objects is provided by midlevel library implementers.
|
||||
|
||||
Returns:
|
||||
An instantiable class associated with 'name', or None if no such class
|
||||
exists.
|
||||
"""
|
||||
if name in _GLOBAL_CUSTOM_OBJECTS:
|
||||
return _GLOBAL_CUSTOM_OBJECTS[name]
|
||||
elif custom_objects and name in custom_objects:
|
||||
return custom_objects[name]
|
||||
elif module_objects and name in module_objects:
|
||||
return module_objects[name]
|
||||
return None
|
||||
|
||||
|
||||
@keras_export('keras.utils.serialize_keras_object')
|
||||
def serialize_keras_object(instance):
|
||||
"""Serialize Keras object into JSON."""
|
||||
@ -212,22 +262,13 @@ def serialize_keras_object(instance):
|
||||
except ValueError:
|
||||
serialization_config[key] = item
|
||||
|
||||
name = _get_name_or_custom_name(instance.__class__)
|
||||
name = get_registered_name(instance.__class__)
|
||||
return serialize_keras_class_and_config(name, serialization_config)
|
||||
if hasattr(instance, '__name__'):
|
||||
return _get_name_or_custom_name(instance)
|
||||
return get_registered_name(instance)
|
||||
raise ValueError('Cannot serialize', instance)
|
||||
|
||||
|
||||
def _get_custom_objects_by_name(item, custom_objects=None):
|
||||
"""Returns the item if it is in either local or global custom objects."""
|
||||
if item in _GLOBAL_CUSTOM_OBJECTS:
|
||||
return _GLOBAL_CUSTOM_OBJECTS[item]
|
||||
elif custom_objects and item in custom_objects:
|
||||
return custom_objects[item]
|
||||
return None
|
||||
|
||||
|
||||
def class_and_config_for_serialized_keras_object(
|
||||
config,
|
||||
module_objects=None,
|
||||
@ -239,15 +280,9 @@ def class_and_config_for_serialized_keras_object(
|
||||
raise ValueError('Improper config format: ' + str(config))
|
||||
|
||||
class_name = config['class_name']
|
||||
if custom_objects and class_name in custom_objects:
|
||||
cls = custom_objects[class_name]
|
||||
elif class_name in _GLOBAL_CUSTOM_OBJECTS:
|
||||
cls = _GLOBAL_CUSTOM_OBJECTS[class_name]
|
||||
else:
|
||||
module_objects = module_objects or {}
|
||||
cls = module_objects.get(class_name)
|
||||
if cls is None:
|
||||
raise ValueError('Unknown ' + printable_module_name + ': ' + class_name)
|
||||
cls = get_registered_object(class_name, custom_objects, module_objects)
|
||||
if cls is None:
|
||||
raise ValueError('Unknown ' + printable_module_name + ': ' + class_name)
|
||||
|
||||
cls_config = config['config']
|
||||
deserialized_objects = {}
|
||||
@ -258,9 +293,9 @@ def class_and_config_for_serialized_keras_object(
|
||||
module_objects=module_objects,
|
||||
custom_objects=custom_objects,
|
||||
printable_module_name='config_item')
|
||||
# TODO(momernick): Should this also have 'module_objects'?
|
||||
elif (isinstance(item, six.string_types) and
|
||||
tf_inspect.isfunction(
|
||||
_get_custom_objects_by_name(item, custom_objects))):
|
||||
tf_inspect.isfunction(get_registered_object(item, custom_objects))):
|
||||
# Handle custom functions here. When saving functions, we only save the
|
||||
# function's name as a string. If we find a matching string in the custom
|
||||
# objects during deserialization, we convert the string back to the
|
||||
@ -269,8 +304,7 @@ def class_and_config_for_serialized_keras_object(
|
||||
# conflict with a custom function name, but this should be a rare case.
|
||||
# This issue does not occur if a string field has a naming conflict with
|
||||
# a custom object, since the config of an object will always be a dict.
|
||||
deserialized_objects[key] = _get_custom_objects_by_name(
|
||||
item, custom_objects)
|
||||
deserialized_objects[key] = get_registered_object(item, custom_objects)
|
||||
for key, item in deserialized_objects.items():
|
||||
cls_config[key] = deserialized_objects[key]
|
||||
|
||||
@ -382,6 +416,7 @@ def func_load(code, defaults=None, closure=None, globs=None):
|
||||
Returns:
|
||||
A value wrapped as a cell object (see function "func_load")
|
||||
"""
|
||||
|
||||
def dummy_fn():
|
||||
# pylint: disable=pointless-statement
|
||||
value # just access it so it gets captured in .__closure__
|
||||
@ -410,8 +445,8 @@ def has_arg(fn, name, accept_all=False):
|
||||
Arguments:
|
||||
fn: Callable to inspect.
|
||||
name: Check if `fn` can be called with `name` as a keyword argument.
|
||||
accept_all: What to return if there is no parameter called `name`
|
||||
but the function accepts a `**kwargs` argument.
|
||||
accept_all: What to return if there is no parameter called `name` but the
|
||||
function accepts a `**kwargs` argument.
|
||||
|
||||
Returns:
|
||||
bool, whether `fn` accepts a `name` keyword argument.
|
||||
@ -430,16 +465,20 @@ class Progbar(object):
|
||||
target: Total number of steps expected, None if unknown.
|
||||
width: Progress bar width on screen.
|
||||
verbose: Verbosity mode, 0 (silent), 1 (verbose), 2 (semi-verbose)
|
||||
stateful_metrics: Iterable of string names of metrics that
|
||||
should *not* be averaged over time. Metrics in this list
|
||||
will be displayed as-is. All others will be averaged
|
||||
by the progbar before display.
|
||||
stateful_metrics: Iterable of string names of metrics that should *not* be
|
||||
averaged over time. Metrics in this list will be displayed as-is. All
|
||||
others will be averaged by the progbar before display.
|
||||
interval: Minimum visual progress update interval (in seconds).
|
||||
unit_name: Display name for step counts (usually "step" or "sample").
|
||||
"""
|
||||
|
||||
def __init__(self, target, width=30, verbose=1, interval=0.05,
|
||||
stateful_metrics=None, unit_name='step'):
|
||||
def __init__(self,
|
||||
target,
|
||||
width=30,
|
||||
verbose=1,
|
||||
interval=0.05,
|
||||
stateful_metrics=None,
|
||||
unit_name='step'):
|
||||
self.target = target
|
||||
self.width = width
|
||||
self.verbose = verbose
|
||||
@ -469,11 +508,9 @@ class Progbar(object):
|
||||
|
||||
Arguments:
|
||||
current: Index of current step.
|
||||
values: List of tuples:
|
||||
`(name, value_for_last_step)`.
|
||||
If `name` is in `stateful_metrics`,
|
||||
`value_for_last_step` will be displayed as-is.
|
||||
Else, an average of the metric over time will be displayed.
|
||||
values: List of tuples: `(name, value_for_last_step)`. If `name` is in
|
||||
`stateful_metrics`, `value_for_last_step` will be displayed as-is.
|
||||
Else, an average of the metric over time will be displayed.
|
||||
"""
|
||||
values = values or []
|
||||
for k, v in values:
|
||||
@ -538,8 +575,7 @@ class Progbar(object):
|
||||
eta = time_per_unit * (self.target - current)
|
||||
if eta > 3600:
|
||||
eta_format = '%d:%02d:%02d' % (eta // 3600,
|
||||
(eta % 3600) // 60,
|
||||
eta % 60)
|
||||
(eta % 3600) // 60, eta % 60)
|
||||
elif eta > 60:
|
||||
eta_format = '%d:%02d' % (eta // 60, eta % 60)
|
||||
else:
|
||||
@ -625,10 +661,8 @@ def slice_arrays(arrays, start=None, stop=None):
|
||||
|
||||
Arguments:
|
||||
arrays: Single array or list of arrays.
|
||||
start: can be an integer index (start index)
|
||||
or a list/array of indices
|
||||
stop: integer (stop index); should be None if
|
||||
`start` was a list.
|
||||
start: can be an integer index (start index) or a list/array of indices
|
||||
stop: integer (stop index); should be None if `start` was a list.
|
||||
|
||||
Returns:
|
||||
A slice of the array(s).
|
||||
@ -711,7 +745,8 @@ def check_for_unexpected_keys(name, input_dict, expected_values):
|
||||
expected_values))
|
||||
|
||||
|
||||
def validate_kwargs(kwargs, allowed_kwargs,
|
||||
def validate_kwargs(kwargs,
|
||||
allowed_kwargs,
|
||||
error_message='Keyword argument not understood:'):
|
||||
"""Checks that all keyword arguments are in the set of allowed keys."""
|
||||
for kwarg in kwargs:
|
||||
|
@ -129,6 +129,13 @@ class SerializeKerasObjectTest(test.TestCase):
|
||||
inst = OtherTestClass(val=5)
|
||||
class_name = keras.utils.generic_utils._GLOBAL_CUSTOM_NAMES[OtherTestClass]
|
||||
self.assertEqual(serialized_name, class_name)
|
||||
fn_class_name = keras.utils.generic_utils.get_registered_name(
|
||||
OtherTestClass)
|
||||
self.assertEqual(fn_class_name, class_name)
|
||||
|
||||
cls = keras.utils.generic_utils.get_registered_object(fn_class_name)
|
||||
self.assertEqual(OtherTestClass, cls)
|
||||
|
||||
config = keras.utils.generic_utils.serialize_keras_object(inst)
|
||||
self.assertEqual(class_name, config['class_name'])
|
||||
new_inst = keras.utils.generic_utils.deserialize_keras_object(config)
|
||||
@ -145,11 +152,17 @@ class SerializeKerasObjectTest(test.TestCase):
|
||||
serialized_name = 'Custom>my_fn'
|
||||
class_name = keras.utils.generic_utils._GLOBAL_CUSTOM_NAMES[my_fn]
|
||||
self.assertEqual(serialized_name, class_name)
|
||||
fn_class_name = keras.utils.generic_utils.get_registered_name(my_fn)
|
||||
self.assertEqual(fn_class_name, class_name)
|
||||
|
||||
config = keras.utils.generic_utils.serialize_keras_object(my_fn)
|
||||
self.assertEqual(class_name, config)
|
||||
fn = keras.utils.generic_utils.deserialize_keras_object(config)
|
||||
self.assertEqual(42, fn())
|
||||
|
||||
fn_2 = keras.utils.generic_utils.get_registered_object(fn_class_name)
|
||||
self.assertEqual(42, fn_2())
|
||||
|
||||
def test_serialize_custom_class_without_get_config_fails(self):
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
|
@ -48,6 +48,14 @@ tf_module {
|
||||
name: "get_file"
|
||||
argspec: "args=[\'fname\', \'origin\', \'untar\', \'md5_hash\', \'file_hash\', \'cache_subdir\', \'hash_algorithm\', \'extract\', \'archive_format\', \'cache_dir\'], varargs=None, keywords=None, defaults=[\'False\', \'None\', \'None\', \'datasets\', \'auto\', \'False\', \'auto\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "get_registered_name"
|
||||
argspec: "args=[\'obj\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "get_registered_object"
|
||||
argspec: "args=[\'name\', \'custom_objects\', \'module_objects\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "get_source_inputs"
|
||||
argspec: "args=[\'tensor\', \'layer\', \'node_index\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
|
||||
|
@ -48,6 +48,14 @@ tf_module {
|
||||
name: "get_file"
|
||||
argspec: "args=[\'fname\', \'origin\', \'untar\', \'md5_hash\', \'file_hash\', \'cache_subdir\', \'hash_algorithm\', \'extract\', \'archive_format\', \'cache_dir\'], varargs=None, keywords=None, defaults=[\'False\', \'None\', \'None\', \'datasets\', \'auto\', \'False\', \'auto\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "get_registered_name"
|
||||
argspec: "args=[\'obj\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "get_registered_object"
|
||||
argspec: "args=[\'name\', \'custom_objects\', \'module_objects\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "get_source_inputs"
|
||||
argspec: "args=[\'tensor\', \'layer\', \'node_index\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
|
||||
|
Loading…
Reference in New Issue
Block a user