From 00526b3758c86999e895e9a225eeec0931ea961f Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 19 Dec 2019 15:23:26 -0800 Subject: [PATCH] Expose utilities for getting custom Keras objects/object names. PiperOrigin-RevId: 286474375 Change-Id: I9951b7cd40b409ad70bd3111acef8a735c4170db --- .../python/keras/utils/generic_utils.py | 129 +++++++++++------- .../python/keras/utils/generic_utils_test.py | 13 ++ .../golden/v1/tensorflow.keras.utils.pbtxt | 8 ++ .../golden/v2/tensorflow.keras.utils.pbtxt | 8 ++ 4 files changed, 111 insertions(+), 47 deletions(-) diff --git a/tensorflow/python/keras/utils/generic_utils.py b/tensorflow/python/keras/utils/generic_utils.py index 4fbb6d68eeb..ebab3d79424 100644 --- a/tensorflow/python/keras/utils/generic_utils.py +++ b/tensorflow/python/keras/utils/generic_utils.py @@ -96,8 +96,8 @@ def custom_object_scope(*args): ``` Arguments: - *args: Variable length list of dictionaries of name, - class pairs to add to custom objects. + *args: Variable length list of dictionaries of name, class pairs to add to + custom objects. Returns: Object of type `CustomObjectScope`. @@ -180,13 +180,63 @@ def register_keras_serializable(package='Custom', name=None): return decorator -def _get_name_or_custom_name(obj): +@keras_export('keras.utils.get_registered_name') +def get_registered_name(obj): + """Returns the name registered to an object within the Keras framework. + + This function is part of the Keras serialization and deserialization + framework. It maps objects to the string names associated with those objects + for serialization/deserialization. + + Args: + obj: The object to look up. + + Returns: + The name associated with the object, or the default Python name if the + object is not registered. + """ if obj in _GLOBAL_CUSTOM_NAMES: return _GLOBAL_CUSTOM_NAMES[obj] else: return obj.__name__ +@keras_export('keras.utils.get_registered_object') +def get_registered_object(name, custom_objects=None, module_objects=None): + """Returns the class associated with `name` if it is registered with Keras. + + This function is part of the Keras serialization and deserialization + framework. It maps strings to the objects associated with them for + serialization/deserialization. + + Example: + ``` + def from_config(cls, config, custom_objects=None): + if 'my_custom_object_name' in config: + config['hidden_cls'] = tf.keras.utils.get_registered_object( + config['my_custom_object_name'], custom_objects=custom_objects) + ``` + + Args: + name: The name to look up. + custom_objects: A dictionary of custom objects to look the name up in. + Generally, custom_objects is provided by the user. + module_objects: A dictionary of custom objects to look the name up in. + Generally, module_objects is provided by midlevel library implementers. + + Returns: + An instantiable class associated with 'name', or None if no such class + exists. + """ + if name in _GLOBAL_CUSTOM_OBJECTS: + return _GLOBAL_CUSTOM_OBJECTS[name] + elif custom_objects and name in custom_objects: + return custom_objects[name] + elif module_objects and name in module_objects: + return module_objects[name] + return None + + @keras_export('keras.utils.serialize_keras_object') def serialize_keras_object(instance): """Serialize Keras object into JSON.""" @@ -212,22 +262,13 @@ def serialize_keras_object(instance): except ValueError: serialization_config[key] = item - name = _get_name_or_custom_name(instance.__class__) + name = get_registered_name(instance.__class__) return serialize_keras_class_and_config(name, serialization_config) if hasattr(instance, '__name__'): - return _get_name_or_custom_name(instance) + return get_registered_name(instance) raise ValueError('Cannot serialize', instance) -def _get_custom_objects_by_name(item, custom_objects=None): - """Returns the item if it is in either local or global custom objects.""" - if item in _GLOBAL_CUSTOM_OBJECTS: - return _GLOBAL_CUSTOM_OBJECTS[item] - elif custom_objects and item in custom_objects: - return custom_objects[item] - return None - - def class_and_config_for_serialized_keras_object( config, module_objects=None, @@ -239,15 +280,9 @@ def class_and_config_for_serialized_keras_object( raise ValueError('Improper config format: ' + str(config)) class_name = config['class_name'] - if custom_objects and class_name in custom_objects: - cls = custom_objects[class_name] - elif class_name in _GLOBAL_CUSTOM_OBJECTS: - cls = _GLOBAL_CUSTOM_OBJECTS[class_name] - else: - module_objects = module_objects or {} - cls = module_objects.get(class_name) - if cls is None: - raise ValueError('Unknown ' + printable_module_name + ': ' + class_name) + cls = get_registered_object(class_name, custom_objects, module_objects) + if cls is None: + raise ValueError('Unknown ' + printable_module_name + ': ' + class_name) cls_config = config['config'] deserialized_objects = {} @@ -258,9 +293,9 @@ def class_and_config_for_serialized_keras_object( module_objects=module_objects, custom_objects=custom_objects, printable_module_name='config_item') + # TODO(momernick): Should this also have 'module_objects'? elif (isinstance(item, six.string_types) and - tf_inspect.isfunction( - _get_custom_objects_by_name(item, custom_objects))): + tf_inspect.isfunction(get_registered_object(item, custom_objects))): # Handle custom functions here. When saving functions, we only save the # function's name as a string. If we find a matching string in the custom # objects during deserialization, we convert the string back to the @@ -269,8 +304,7 @@ def class_and_config_for_serialized_keras_object( # conflict with a custom function name, but this should be a rare case. # This issue does not occur if a string field has a naming conflict with # a custom object, since the config of an object will always be a dict. - deserialized_objects[key] = _get_custom_objects_by_name( - item, custom_objects) + deserialized_objects[key] = get_registered_object(item, custom_objects) for key, item in deserialized_objects.items(): cls_config[key] = deserialized_objects[key] @@ -382,6 +416,7 @@ def func_load(code, defaults=None, closure=None, globs=None): Returns: A value wrapped as a cell object (see function "func_load") """ + def dummy_fn(): # pylint: disable=pointless-statement value # just access it so it gets captured in .__closure__ @@ -410,8 +445,8 @@ def has_arg(fn, name, accept_all=False): Arguments: fn: Callable to inspect. name: Check if `fn` can be called with `name` as a keyword argument. - accept_all: What to return if there is no parameter called `name` - but the function accepts a `**kwargs` argument. + accept_all: What to return if there is no parameter called `name` but the + function accepts a `**kwargs` argument. Returns: bool, whether `fn` accepts a `name` keyword argument. @@ -430,16 +465,20 @@ class Progbar(object): target: Total number of steps expected, None if unknown. width: Progress bar width on screen. verbose: Verbosity mode, 0 (silent), 1 (verbose), 2 (semi-verbose) - stateful_metrics: Iterable of string names of metrics that - should *not* be averaged over time. Metrics in this list - will be displayed as-is. All others will be averaged - by the progbar before display. + stateful_metrics: Iterable of string names of metrics that should *not* be + averaged over time. Metrics in this list will be displayed as-is. All + others will be averaged by the progbar before display. interval: Minimum visual progress update interval (in seconds). unit_name: Display name for step counts (usually "step" or "sample"). """ - def __init__(self, target, width=30, verbose=1, interval=0.05, - stateful_metrics=None, unit_name='step'): + def __init__(self, + target, + width=30, + verbose=1, + interval=0.05, + stateful_metrics=None, + unit_name='step'): self.target = target self.width = width self.verbose = verbose @@ -469,11 +508,9 @@ class Progbar(object): Arguments: current: Index of current step. - values: List of tuples: - `(name, value_for_last_step)`. - If `name` is in `stateful_metrics`, - `value_for_last_step` will be displayed as-is. - Else, an average of the metric over time will be displayed. + values: List of tuples: `(name, value_for_last_step)`. If `name` is in + `stateful_metrics`, `value_for_last_step` will be displayed as-is. + Else, an average of the metric over time will be displayed. """ values = values or [] for k, v in values: @@ -538,8 +575,7 @@ class Progbar(object): eta = time_per_unit * (self.target - current) if eta > 3600: eta_format = '%d:%02d:%02d' % (eta // 3600, - (eta % 3600) // 60, - eta % 60) + (eta % 3600) // 60, eta % 60) elif eta > 60: eta_format = '%d:%02d' % (eta // 60, eta % 60) else: @@ -625,10 +661,8 @@ def slice_arrays(arrays, start=None, stop=None): Arguments: arrays: Single array or list of arrays. - start: can be an integer index (start index) - or a list/array of indices - stop: integer (stop index); should be None if - `start` was a list. + start: can be an integer index (start index) or a list/array of indices + stop: integer (stop index); should be None if `start` was a list. Returns: A slice of the array(s). @@ -711,7 +745,8 @@ def check_for_unexpected_keys(name, input_dict, expected_values): expected_values)) -def validate_kwargs(kwargs, allowed_kwargs, +def validate_kwargs(kwargs, + allowed_kwargs, error_message='Keyword argument not understood:'): """Checks that all keyword arguments are in the set of allowed keys.""" for kwarg in kwargs: diff --git a/tensorflow/python/keras/utils/generic_utils_test.py b/tensorflow/python/keras/utils/generic_utils_test.py index 619d31e8f8c..334758871fa 100644 --- a/tensorflow/python/keras/utils/generic_utils_test.py +++ b/tensorflow/python/keras/utils/generic_utils_test.py @@ -129,6 +129,13 @@ class SerializeKerasObjectTest(test.TestCase): inst = OtherTestClass(val=5) class_name = keras.utils.generic_utils._GLOBAL_CUSTOM_NAMES[OtherTestClass] self.assertEqual(serialized_name, class_name) + fn_class_name = keras.utils.generic_utils.get_registered_name( + OtherTestClass) + self.assertEqual(fn_class_name, class_name) + + cls = keras.utils.generic_utils.get_registered_object(fn_class_name) + self.assertEqual(OtherTestClass, cls) + config = keras.utils.generic_utils.serialize_keras_object(inst) self.assertEqual(class_name, config['class_name']) new_inst = keras.utils.generic_utils.deserialize_keras_object(config) @@ -145,11 +152,17 @@ class SerializeKerasObjectTest(test.TestCase): serialized_name = 'Custom>my_fn' class_name = keras.utils.generic_utils._GLOBAL_CUSTOM_NAMES[my_fn] self.assertEqual(serialized_name, class_name) + fn_class_name = keras.utils.generic_utils.get_registered_name(my_fn) + self.assertEqual(fn_class_name, class_name) + config = keras.utils.generic_utils.serialize_keras_object(my_fn) self.assertEqual(class_name, config) fn = keras.utils.generic_utils.deserialize_keras_object(config) self.assertEqual(42, fn()) + fn_2 = keras.utils.generic_utils.get_registered_object(fn_class_name) + self.assertEqual(42, fn_2()) + def test_serialize_custom_class_without_get_config_fails(self): with self.assertRaisesRegex( diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.utils.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.utils.pbtxt index e6a82676a73..6f0000b84fb 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.utils.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.utils.pbtxt @@ -48,6 +48,14 @@ tf_module { name: "get_file" argspec: "args=[\'fname\', \'origin\', \'untar\', \'md5_hash\', \'file_hash\', \'cache_subdir\', \'hash_algorithm\', \'extract\', \'archive_format\', \'cache_dir\'], varargs=None, keywords=None, defaults=[\'False\', \'None\', \'None\', \'datasets\', \'auto\', \'False\', \'auto\', \'None\'], " } + member_method { + name: "get_registered_name" + argspec: "args=[\'obj\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_registered_object" + argspec: "args=[\'name\', \'custom_objects\', \'module_objects\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } member_method { name: "get_source_inputs" argspec: "args=[\'tensor\', \'layer\', \'node_index\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.utils.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.utils.pbtxt index e6a82676a73..6f0000b84fb 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.utils.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.utils.pbtxt @@ -48,6 +48,14 @@ tf_module { name: "get_file" argspec: "args=[\'fname\', \'origin\', \'untar\', \'md5_hash\', \'file_hash\', \'cache_subdir\', \'hash_algorithm\', \'extract\', \'archive_format\', \'cache_dir\'], varargs=None, keywords=None, defaults=[\'False\', \'None\', \'None\', \'datasets\', \'auto\', \'False\', \'auto\', \'None\'], " } + member_method { + name: "get_registered_name" + argspec: "args=[\'obj\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_registered_object" + argspec: "args=[\'name\', \'custom_objects\', \'module_objects\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } member_method { name: "get_source_inputs" argspec: "args=[\'tensor\', \'layer\', \'node_index\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "