Improves public-facing documentation for tf.keras.utils.custom_object_scope
.
PiperOrigin-RevId: 302504282 Change-Id: Ieb892c2681557a99ddec6f6ba2f9db01efde61df
This commit is contained in:
parent
28c22acb3a
commit
25ae1e130a
tensorflow
@ -47,25 +47,31 @@ _SKIP_FAILED_SERIALIZATION = False
|
||||
_LAYER_UNDEFINED_CONFIG_KEY = 'layer was saved without config'
|
||||
|
||||
|
||||
@keras_export('keras.utils.CustomObjectScope')
|
||||
@keras_export('keras.utils.custom_object_scope', # pylint: disable=g-classes-have-attributes
|
||||
'keras.utils.CustomObjectScope')
|
||||
class CustomObjectScope(object):
|
||||
"""Provides a scope that changes to `_GLOBAL_CUSTOM_OBJECTS` cannot escape.
|
||||
"""Exposes custom classes/functions to Keras deserialization internals.
|
||||
|
||||
Code within a `with` statement will be able to access custom objects
|
||||
by name. Changes to global custom objects persist
|
||||
within the enclosing `with` statement. At end of the `with` statement,
|
||||
global custom objects are reverted to state
|
||||
at beginning of the `with` statement.
|
||||
Under a scope `with custom_object_scope(objects_dict)`, Keras methods such
|
||||
as `tf.keras.models.load_model` or `tf.keras.models.model_from_config`
|
||||
will be able to deserialize any custom object referenced by a
|
||||
saved config (e.g. a custom layer or metric).
|
||||
|
||||
Example:
|
||||
|
||||
Consider a custom object `MyObject` (e.g. a class):
|
||||
Consider a custom regularizer `my_regularizer`:
|
||||
|
||||
```python
|
||||
with CustomObjectScope({'MyObject':MyObject}):
|
||||
layer = Dense(..., kernel_regularizer='MyObject')
|
||||
# save, load, etc. will recognize custom object by name
|
||||
layer = Dense(3, kernel_regularizer=my_regularizer)
|
||||
config = layer.get_config() # Config contains a reference to `my_regularizer`
|
||||
...
|
||||
# Later:
|
||||
with custom_object_scope({'my_regularizer': my_regularizer}):
|
||||
layer = Dense.from_config(config)
|
||||
```
|
||||
|
||||
Arguments:
|
||||
*args: Dictionary or dictionaries of `{name: object}` pairs.
|
||||
"""
|
||||
|
||||
def __init__(self, *args):
|
||||
@ -83,50 +89,19 @@ class CustomObjectScope(object):
|
||||
_GLOBAL_CUSTOM_OBJECTS.update(self.backup)
|
||||
|
||||
|
||||
@keras_export('keras.utils.custom_object_scope')
|
||||
def custom_object_scope(*args):
|
||||
"""Provides a scope that changes to `_GLOBAL_CUSTOM_OBJECTS` cannot escape.
|
||||
|
||||
Convenience wrapper for `CustomObjectScope`.
|
||||
Code within a `with` statement will be able to access custom objects
|
||||
by name. Changes to global custom objects persist
|
||||
within the enclosing `with` statement. At end of the `with` statement,
|
||||
global custom objects are reverted to state
|
||||
at beginning of the `with` statement.
|
||||
|
||||
Example:
|
||||
|
||||
Consider a custom object `MyObject`
|
||||
|
||||
```python
|
||||
with custom_object_scope({'MyObject':MyObject}):
|
||||
layer = Dense(..., kernel_regularizer='MyObject')
|
||||
# save, load, etc. will recognize custom object by name
|
||||
```
|
||||
|
||||
Arguments:
|
||||
*args: Variable length list of dictionaries of name, class pairs to add to
|
||||
custom objects.
|
||||
|
||||
Returns:
|
||||
Object of type `CustomObjectScope`.
|
||||
"""
|
||||
return CustomObjectScope(*args)
|
||||
|
||||
|
||||
@keras_export('keras.utils.get_custom_objects')
|
||||
def get_custom_objects():
|
||||
"""Retrieves a live reference to the global dictionary of custom objects.
|
||||
|
||||
Updating and clearing custom objects using `custom_object_scope`
|
||||
is preferred, but `get_custom_objects` can
|
||||
be used to directly access `_GLOBAL_CUSTOM_OBJECTS`.
|
||||
be used to directly access the current collection of custom objects.
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
get_custom_objects().clear()
|
||||
get_custom_objects()['MyObject'] = MyObject
|
||||
get_custom_objects().clear()
|
||||
get_custom_objects()['MyObject'] = MyObject
|
||||
```
|
||||
|
||||
Returns:
|
||||
@ -158,7 +133,7 @@ def register_keras_serializable(package='Custom', name=None):
|
||||
Arguments:
|
||||
package: The package that this class belongs to.
|
||||
name: The name to serialize this class under in this package. If None, the
|
||||
class's name will be used.
|
||||
class' name will be used.
|
||||
|
||||
Returns:
|
||||
A decorator that registers the decorated class with the passed names.
|
||||
@ -806,3 +781,9 @@ def default(method):
|
||||
def is_default(method):
|
||||
"""Check if a method is decorated with the `default` wrapper."""
|
||||
return getattr(method, '_is_default', False)
|
||||
|
||||
|
||||
# Aliases
|
||||
|
||||
|
||||
custom_object_scope = CustomObjectScope # pylint: disable=invalid-name
|
||||
|
@ -0,0 +1,9 @@
|
||||
path: "tensorflow.keras.utils.custom_object_scope"
|
||||
tf_class {
|
||||
is_instance: "<class \'tensorflow.python.keras.utils.generic_utils.CustomObjectScope\'>"
|
||||
is_instance: "<type \'object\'>"
|
||||
member_method {
|
||||
name: "__init__"
|
||||
argspec: "args=[\'self\'], varargs=args, keywords=None, defaults=None"
|
||||
}
|
||||
}
|
@ -28,14 +28,14 @@ tf_module {
|
||||
name: "SequenceEnqueuer"
|
||||
mtype: "<type \'type\'>"
|
||||
}
|
||||
member {
|
||||
name: "custom_object_scope"
|
||||
mtype: "<type \'type\'>"
|
||||
}
|
||||
member_method {
|
||||
name: "convert_all_kernels_in_model"
|
||||
argspec: "args=[\'model\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "custom_object_scope"
|
||||
argspec: "args=[], varargs=args, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "deserialize_keras_object"
|
||||
argspec: "args=[\'identifier\', \'module_objects\', \'custom_objects\', \'printable_module_name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'object\'], "
|
||||
|
@ -0,0 +1,9 @@
|
||||
path: "tensorflow.keras.utils.custom_object_scope"
|
||||
tf_class {
|
||||
is_instance: "<class \'tensorflow.python.keras.utils.generic_utils.CustomObjectScope\'>"
|
||||
is_instance: "<type \'object\'>"
|
||||
member_method {
|
||||
name: "__init__"
|
||||
argspec: "args=[\'self\'], varargs=args, keywords=None, defaults=None"
|
||||
}
|
||||
}
|
@ -28,14 +28,14 @@ tf_module {
|
||||
name: "SequenceEnqueuer"
|
||||
mtype: "<type \'type\'>"
|
||||
}
|
||||
member {
|
||||
name: "custom_object_scope"
|
||||
mtype: "<type \'type\'>"
|
||||
}
|
||||
member_method {
|
||||
name: "convert_all_kernels_in_model"
|
||||
argspec: "args=[\'model\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "custom_object_scope"
|
||||
argspec: "args=[], varargs=args, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "deserialize_keras_object"
|
||||
argspec: "args=[\'identifier\', \'module_objects\', \'custom_objects\', \'printable_module_name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'object\'], "
|
||||
|
Loading…
Reference in New Issue
Block a user