Expose utilities for getting custom Keras objects/object names.

PiperOrigin-RevId: 286474375
Change-Id: I9951b7cd40b409ad70bd3111acef8a735c4170db
This commit is contained in:
A. Unique TensorFlower 2019-12-19 15:23:26 -08:00 committed by TensorFlower Gardener
parent 69bb090113
commit 00526b3758
4 changed files with 111 additions and 47 deletions

View File

@ -96,8 +96,8 @@ def custom_object_scope(*args):
``` ```
Arguments: Arguments:
*args: Variable length list of dictionaries of name, *args: Variable length list of dictionaries of name, class pairs to add to
class pairs to add to custom objects. custom objects.
Returns: Returns:
Object of type `CustomObjectScope`. Object of type `CustomObjectScope`.
@ -180,13 +180,63 @@ def register_keras_serializable(package='Custom', name=None):
return decorator return decorator
def _get_name_or_custom_name(obj): @keras_export('keras.utils.get_registered_name')
def get_registered_name(obj):
"""Returns the name registered to an object within the Keras framework.
This function is part of the Keras serialization and deserialization
framework. It maps objects to the string names associated with those objects
for serialization/deserialization.
Args:
obj: The object to look up.
Returns:
The name associated with the object, or the default Python name if the
object is not registered.
"""
if obj in _GLOBAL_CUSTOM_NAMES: if obj in _GLOBAL_CUSTOM_NAMES:
return _GLOBAL_CUSTOM_NAMES[obj] return _GLOBAL_CUSTOM_NAMES[obj]
else: else:
return obj.__name__ return obj.__name__
@keras_export('keras.utils.get_registered_object')
def get_registered_object(name, custom_objects=None, module_objects=None):
"""Returns the class associated with `name` if it is registered with Keras.
This function is part of the Keras serialization and deserialization
framework. It maps strings to the objects associated with them for
serialization/deserialization.
Example:
```
def from_config(cls, config, custom_objects=None):
if 'my_custom_object_name' in config:
config['hidden_cls'] = tf.keras.utils.get_registered_object(
config['my_custom_object_name'], custom_objects=custom_objects)
```
Args:
name: The name to look up.
custom_objects: A dictionary of custom objects to look the name up in.
Generally, custom_objects is provided by the user.
module_objects: A dictionary of custom objects to look the name up in.
Generally, module_objects is provided by midlevel library implementers.
Returns:
An instantiable class associated with 'name', or None if no such class
exists.
"""
if name in _GLOBAL_CUSTOM_OBJECTS:
return _GLOBAL_CUSTOM_OBJECTS[name]
elif custom_objects and name in custom_objects:
return custom_objects[name]
elif module_objects and name in module_objects:
return module_objects[name]
return None
@keras_export('keras.utils.serialize_keras_object') @keras_export('keras.utils.serialize_keras_object')
def serialize_keras_object(instance): def serialize_keras_object(instance):
"""Serialize Keras object into JSON.""" """Serialize Keras object into JSON."""
@ -212,22 +262,13 @@ def serialize_keras_object(instance):
except ValueError: except ValueError:
serialization_config[key] = item serialization_config[key] = item
name = _get_name_or_custom_name(instance.__class__) name = get_registered_name(instance.__class__)
return serialize_keras_class_and_config(name, serialization_config) return serialize_keras_class_and_config(name, serialization_config)
if hasattr(instance, '__name__'): if hasattr(instance, '__name__'):
return _get_name_or_custom_name(instance) return get_registered_name(instance)
raise ValueError('Cannot serialize', instance) raise ValueError('Cannot serialize', instance)
def _get_custom_objects_by_name(item, custom_objects=None):
"""Returns the item if it is in either local or global custom objects."""
if item in _GLOBAL_CUSTOM_OBJECTS:
return _GLOBAL_CUSTOM_OBJECTS[item]
elif custom_objects and item in custom_objects:
return custom_objects[item]
return None
def class_and_config_for_serialized_keras_object( def class_and_config_for_serialized_keras_object(
config, config,
module_objects=None, module_objects=None,
@ -239,15 +280,9 @@ def class_and_config_for_serialized_keras_object(
raise ValueError('Improper config format: ' + str(config)) raise ValueError('Improper config format: ' + str(config))
class_name = config['class_name'] class_name = config['class_name']
if custom_objects and class_name in custom_objects: cls = get_registered_object(class_name, custom_objects, module_objects)
cls = custom_objects[class_name] if cls is None:
elif class_name in _GLOBAL_CUSTOM_OBJECTS: raise ValueError('Unknown ' + printable_module_name + ': ' + class_name)
cls = _GLOBAL_CUSTOM_OBJECTS[class_name]
else:
module_objects = module_objects or {}
cls = module_objects.get(class_name)
if cls is None:
raise ValueError('Unknown ' + printable_module_name + ': ' + class_name)
cls_config = config['config'] cls_config = config['config']
deserialized_objects = {} deserialized_objects = {}
@ -258,9 +293,9 @@ def class_and_config_for_serialized_keras_object(
module_objects=module_objects, module_objects=module_objects,
custom_objects=custom_objects, custom_objects=custom_objects,
printable_module_name='config_item') printable_module_name='config_item')
# TODO(momernick): Should this also have 'module_objects'?
elif (isinstance(item, six.string_types) and elif (isinstance(item, six.string_types) and
tf_inspect.isfunction( tf_inspect.isfunction(get_registered_object(item, custom_objects))):
_get_custom_objects_by_name(item, custom_objects))):
# Handle custom functions here. When saving functions, we only save the # Handle custom functions here. When saving functions, we only save the
# function's name as a string. If we find a matching string in the custom # function's name as a string. If we find a matching string in the custom
# objects during deserialization, we convert the string back to the # objects during deserialization, we convert the string back to the
@ -269,8 +304,7 @@ def class_and_config_for_serialized_keras_object(
# conflict with a custom function name, but this should be a rare case. # conflict with a custom function name, but this should be a rare case.
# This issue does not occur if a string field has a naming conflict with # This issue does not occur if a string field has a naming conflict with
# a custom object, since the config of an object will always be a dict. # a custom object, since the config of an object will always be a dict.
deserialized_objects[key] = _get_custom_objects_by_name( deserialized_objects[key] = get_registered_object(item, custom_objects)
item, custom_objects)
for key, item in deserialized_objects.items(): for key, item in deserialized_objects.items():
cls_config[key] = deserialized_objects[key] cls_config[key] = deserialized_objects[key]
@ -382,6 +416,7 @@ def func_load(code, defaults=None, closure=None, globs=None):
Returns: Returns:
A value wrapped as a cell object (see function "func_load") A value wrapped as a cell object (see function "func_load")
""" """
def dummy_fn(): def dummy_fn():
# pylint: disable=pointless-statement # pylint: disable=pointless-statement
value # just access it so it gets captured in .__closure__ value # just access it so it gets captured in .__closure__
@ -410,8 +445,8 @@ def has_arg(fn, name, accept_all=False):
Arguments: Arguments:
fn: Callable to inspect. fn: Callable to inspect.
name: Check if `fn` can be called with `name` as a keyword argument. name: Check if `fn` can be called with `name` as a keyword argument.
accept_all: What to return if there is no parameter called `name` accept_all: What to return if there is no parameter called `name` but the
but the function accepts a `**kwargs` argument. function accepts a `**kwargs` argument.
Returns: Returns:
bool, whether `fn` accepts a `name` keyword argument. bool, whether `fn` accepts a `name` keyword argument.
@ -430,16 +465,20 @@ class Progbar(object):
target: Total number of steps expected, None if unknown. target: Total number of steps expected, None if unknown.
width: Progress bar width on screen. width: Progress bar width on screen.
verbose: Verbosity mode, 0 (silent), 1 (verbose), 2 (semi-verbose) verbose: Verbosity mode, 0 (silent), 1 (verbose), 2 (semi-verbose)
stateful_metrics: Iterable of string names of metrics that stateful_metrics: Iterable of string names of metrics that should *not* be
should *not* be averaged over time. Metrics in this list averaged over time. Metrics in this list will be displayed as-is. All
will be displayed as-is. All others will be averaged others will be averaged by the progbar before display.
by the progbar before display.
interval: Minimum visual progress update interval (in seconds). interval: Minimum visual progress update interval (in seconds).
unit_name: Display name for step counts (usually "step" or "sample"). unit_name: Display name for step counts (usually "step" or "sample").
""" """
def __init__(self, target, width=30, verbose=1, interval=0.05, def __init__(self,
stateful_metrics=None, unit_name='step'): target,
width=30,
verbose=1,
interval=0.05,
stateful_metrics=None,
unit_name='step'):
self.target = target self.target = target
self.width = width self.width = width
self.verbose = verbose self.verbose = verbose
@ -469,11 +508,9 @@ class Progbar(object):
Arguments: Arguments:
current: Index of current step. current: Index of current step.
values: List of tuples: values: List of tuples: `(name, value_for_last_step)`. If `name` is in
`(name, value_for_last_step)`. `stateful_metrics`, `value_for_last_step` will be displayed as-is.
If `name` is in `stateful_metrics`, Else, an average of the metric over time will be displayed.
`value_for_last_step` will be displayed as-is.
Else, an average of the metric over time will be displayed.
""" """
values = values or [] values = values or []
for k, v in values: for k, v in values:
@ -538,8 +575,7 @@ class Progbar(object):
eta = time_per_unit * (self.target - current) eta = time_per_unit * (self.target - current)
if eta > 3600: if eta > 3600:
eta_format = '%d:%02d:%02d' % (eta // 3600, eta_format = '%d:%02d:%02d' % (eta // 3600,
(eta % 3600) // 60, (eta % 3600) // 60, eta % 60)
eta % 60)
elif eta > 60: elif eta > 60:
eta_format = '%d:%02d' % (eta // 60, eta % 60) eta_format = '%d:%02d' % (eta // 60, eta % 60)
else: else:
@ -625,10 +661,8 @@ def slice_arrays(arrays, start=None, stop=None):
Arguments: Arguments:
arrays: Single array or list of arrays. arrays: Single array or list of arrays.
start: can be an integer index (start index) start: can be an integer index (start index) or a list/array of indices
or a list/array of indices stop: integer (stop index); should be None if `start` was a list.
stop: integer (stop index); should be None if
`start` was a list.
Returns: Returns:
A slice of the array(s). A slice of the array(s).
@ -711,7 +745,8 @@ def check_for_unexpected_keys(name, input_dict, expected_values):
expected_values)) expected_values))
def validate_kwargs(kwargs, allowed_kwargs, def validate_kwargs(kwargs,
allowed_kwargs,
error_message='Keyword argument not understood:'): error_message='Keyword argument not understood:'):
"""Checks that all keyword arguments are in the set of allowed keys.""" """Checks that all keyword arguments are in the set of allowed keys."""
for kwarg in kwargs: for kwarg in kwargs:

View File

@ -129,6 +129,13 @@ class SerializeKerasObjectTest(test.TestCase):
inst = OtherTestClass(val=5) inst = OtherTestClass(val=5)
class_name = keras.utils.generic_utils._GLOBAL_CUSTOM_NAMES[OtherTestClass] class_name = keras.utils.generic_utils._GLOBAL_CUSTOM_NAMES[OtherTestClass]
self.assertEqual(serialized_name, class_name) self.assertEqual(serialized_name, class_name)
fn_class_name = keras.utils.generic_utils.get_registered_name(
OtherTestClass)
self.assertEqual(fn_class_name, class_name)
cls = keras.utils.generic_utils.get_registered_object(fn_class_name)
self.assertEqual(OtherTestClass, cls)
config = keras.utils.generic_utils.serialize_keras_object(inst) config = keras.utils.generic_utils.serialize_keras_object(inst)
self.assertEqual(class_name, config['class_name']) self.assertEqual(class_name, config['class_name'])
new_inst = keras.utils.generic_utils.deserialize_keras_object(config) new_inst = keras.utils.generic_utils.deserialize_keras_object(config)
@ -145,11 +152,17 @@ class SerializeKerasObjectTest(test.TestCase):
serialized_name = 'Custom>my_fn' serialized_name = 'Custom>my_fn'
class_name = keras.utils.generic_utils._GLOBAL_CUSTOM_NAMES[my_fn] class_name = keras.utils.generic_utils._GLOBAL_CUSTOM_NAMES[my_fn]
self.assertEqual(serialized_name, class_name) self.assertEqual(serialized_name, class_name)
fn_class_name = keras.utils.generic_utils.get_registered_name(my_fn)
self.assertEqual(fn_class_name, class_name)
config = keras.utils.generic_utils.serialize_keras_object(my_fn) config = keras.utils.generic_utils.serialize_keras_object(my_fn)
self.assertEqual(class_name, config) self.assertEqual(class_name, config)
fn = keras.utils.generic_utils.deserialize_keras_object(config) fn = keras.utils.generic_utils.deserialize_keras_object(config)
self.assertEqual(42, fn()) self.assertEqual(42, fn())
fn_2 = keras.utils.generic_utils.get_registered_object(fn_class_name)
self.assertEqual(42, fn_2())
def test_serialize_custom_class_without_get_config_fails(self): def test_serialize_custom_class_without_get_config_fails(self):
with self.assertRaisesRegex( with self.assertRaisesRegex(

View File

@ -48,6 +48,14 @@ tf_module {
name: "get_file" name: "get_file"
argspec: "args=[\'fname\', \'origin\', \'untar\', \'md5_hash\', \'file_hash\', \'cache_subdir\', \'hash_algorithm\', \'extract\', \'archive_format\', \'cache_dir\'], varargs=None, keywords=None, defaults=[\'False\', \'None\', \'None\', \'datasets\', \'auto\', \'False\', \'auto\', \'None\'], " argspec: "args=[\'fname\', \'origin\', \'untar\', \'md5_hash\', \'file_hash\', \'cache_subdir\', \'hash_algorithm\', \'extract\', \'archive_format\', \'cache_dir\'], varargs=None, keywords=None, defaults=[\'False\', \'None\', \'None\', \'datasets\', \'auto\', \'False\', \'auto\', \'None\'], "
} }
member_method {
name: "get_registered_name"
argspec: "args=[\'obj\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "get_registered_object"
argspec: "args=[\'name\', \'custom_objects\', \'module_objects\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
}
member_method { member_method {
name: "get_source_inputs" name: "get_source_inputs"
argspec: "args=[\'tensor\', \'layer\', \'node_index\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " argspec: "args=[\'tensor\', \'layer\', \'node_index\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "

View File

@ -48,6 +48,14 @@ tf_module {
name: "get_file" name: "get_file"
argspec: "args=[\'fname\', \'origin\', \'untar\', \'md5_hash\', \'file_hash\', \'cache_subdir\', \'hash_algorithm\', \'extract\', \'archive_format\', \'cache_dir\'], varargs=None, keywords=None, defaults=[\'False\', \'None\', \'None\', \'datasets\', \'auto\', \'False\', \'auto\', \'None\'], " argspec: "args=[\'fname\', \'origin\', \'untar\', \'md5_hash\', \'file_hash\', \'cache_subdir\', \'hash_algorithm\', \'extract\', \'archive_format\', \'cache_dir\'], varargs=None, keywords=None, defaults=[\'False\', \'None\', \'None\', \'datasets\', \'auto\', \'False\', \'auto\', \'None\'], "
} }
member_method {
name: "get_registered_name"
argspec: "args=[\'obj\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "get_registered_object"
argspec: "args=[\'name\', \'custom_objects\', \'module_objects\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
}
member_method { member_method {
name: "get_source_inputs" name: "get_source_inputs"
argspec: "args=[\'tensor\', \'layer\', \'node_index\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " argspec: "args=[\'tensor\', \'layer\', \'node_index\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "