Remove implicit name scoping from tf.Modules.

After attempting to integrate `tf.Module` into existing codebases (e.g.
`tf.keras`) we've found that the automatic name scoping is too invasive (e.g.
it changes op and variable names) and it is desirable to disable it almost
everywhere.

We propose that name scoping for `tf.Module` becomes opt-in:

>>> class MyModule(tf.Module):
...
...   @tf.Module.with_name_scope
...   def auto_name_scope(self, x):
...     if not hasattr(self, 'w'):
...       self.w = tf.Variable(1., name='w')
...     return x * self.w
...
...   def manual_name_scope(self, x):
...     if not hasattr(self, 'w'):
...       with self.name_scope:
...         self.w = tf.Variable(1., name='w')
...     return x * self.w
...
...   def no_name_scope(self, x):
...     if not hasattr(self, 'w'):
...       self.w = tf.Variable(1., name='w')
...     return x * self.w
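
For reference, a hedged sketch of the variable names each approach above would
produce (assuming a fresh eager program so there are no uniquifying suffixes;
`my_module` is the name `tf.Module` derives from the class name):

>>> m = MyModule()
>>> _ = m.auto_name_scope(tf.constant(1.))
>>> m.w.name  # Created inside the module's name scope by the decorator.
'my_module/w:0'

>>> m = MyModule()
>>> _ = m.no_name_scope(tf.constant(1.))
>>> m.w.name  # No scoping is applied by default.
'w:0'

(`manual_name_scope` produces the same name as `auto_name_scope`.)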

We will move opt-out name scoping into Sonnet:

>>> class MyModule(snt.Module):
...
...   def auto_name_scope(self, x):
...     if not hasattr(self, 'w'):
...       self.w = tf.Variable(1., name='w')
...     return x * self.w
...
...   @snt.no_name_scope
...   def no_name_scope(self, x):
...     if not hasattr(self, 'w'):
...       self.w = tf.Variable(1., name='w')
...     return x * self.w
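
For illustration, one plausible implementation of `snt.no_name_scope`,
mirroring the `NO_MODULE_NAME_SCOPE` tagging mechanism this change removes
from TF (see the diff below); the actual Sonnet 2 implementation may differ:

>>> _SNT_NO_NAME_SCOPE = "__snt_no_name_scope__"
>>>
>>> def no_name_scope(method):
...   # Tag the method; the snt.Module metaclass checks this flag and skips
...   # wrapping tagged methods in the module name scope.
...   setattr(method, _SNT_NO_NAME_SCOPE, True)
...   return method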

In TF2 name scopes are purely cosmetic (they affect op and variable names but
not program semantics), so this should be less of a big deal. We might
consider encouraging users who want to filter on names to instead use
`_flatten` to extract a state dictionary for their objects (cf.
https://github.com/tensorflow/community/pull/56#discussion_r255048762).
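
For example, a hedged sketch of that pattern built on the `_flatten` method as
it exists at this revision (the `state_dict` helper is illustrative, not an
existing API):

>>> def state_dict(module):
...   # Key variables by their attribute path on the module tree (e.g.
...   # 'layers/0/w') instead of filtering on op or variable names.
...   return {
...       '/'.join(map(str, path)): v
...       for path, v in module._flatten(
...           predicate=lambda v: isinstance(v, tf.Variable),
...           with_path=True)}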

I have moved the automatic name scoping logic (the metaclass etc.) and the
associated tests into Sonnet 2.

PiperOrigin-RevId: 235540184
Author: Tom Hennigan (2019-02-25 09:06:45 -08:00), committed by TensorFlower Gardener
Commit: e866995aff (parent: ad76a4b373)
6 changed files with 76 additions and 268 deletions

File: tensorflow/python/module/module.py (path inferred from the diff contents)

@@ -18,160 +18,18 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import abc
import re
import sys
import six
from tensorflow.python.eager import def_function
from tensorflow.python.framework import ops
from tensorflow.python.ops import variables
from tensorflow.python.training.tracking import tracking
from tensorflow.python.util import nest
from tensorflow.python.util import tf_decorator
from tensorflow.python.util import tf_inspect
from tensorflow.python.util.tf_export import tf_export
NO_MODULE_NAME_SCOPE = "__no_module_name_scope__"
class ModuleMetaclass(abc.ABCMeta):
"""Metaclass for `tf.Module`."""
def __new__(mcs, name, bases, clsdict):
methods = []
for key, value in clsdict.items():
if key == "name_scope":
continue
elif key.startswith("__") and key != "__call__":
# Don't patch methods like `__getattr__` or `__del__`.
continue
elif tf_inspect.isfunction(value):
# We defer patching methods until after the type is created such that we
# can trigger the descriptor binding them to the class.
methods.append(key)
elif isinstance(value, property):
# TODO(tomhennigan) Preserve the type of property subclasses.
clsdict[key] = property(
value.fget if not value.fget else with_name_scope(value.fget),
value.fset if not value.fset else with_name_scope(value.fset),
value.fdel if not value.fdel else with_name_scope(value.fdel),
doc=value.__doc__)
cls = super(ModuleMetaclass, mcs).__new__(mcs, name, bases, clsdict)
for method_name in methods:
# Note: the below is quite subtle, we need to ensure that we're wrapping
# the method bound to the class. In some cases (e.g. `wrapt`) this is
# important since the method can trigger different behavior when it is
# bound (e.g. in wrapt `FunctionWrapper.__get__(None, cls)` produces a
# `BoundFunctionWrapper` which in turn populates the `instance` argument
# to decorator functions using args[0]).
# Equivalent to: `cls.__dict__[method_name].__get__(None, cls)`
method = getattr(cls, method_name)
method = with_name_scope(method)
setattr(cls, method_name, method)
return cls
def __call__(cls, *args, **kwargs):
# Call new such that we have an un-initialized module instance that we can
# still reference even if there is an exception during __init__. This is
# needed such that we can make sure the name_scope constructed in __init__
# is closed even if there is an exception.
module = cls.__new__(cls, *args, **kwargs)
# Now attempt to initialize the object.
try:
module.__init__(*args, **kwargs)
except:
# We must explicitly catch so that in Python 2 sys.exc_info() is populated
# before entering the finally block.
raise
finally:
# The base Module constructor enters the module's name scope before
# returning such that other functionality in the ctor happens within the
# module's name scope.
scope = getattr(module, "_ctor_name_scope", None)
exc_info = sys.exc_info()
if scope is None:
if exc_info[0] is None:
raise ValueError(
"Constructing a tf.Module without calling the super constructor "
"is not supported. Add the following as the first line in your "
"__init__ method:\n\n"
"super(%s, self).__init__()" % cls.__name__)
else:
scope.__exit__(*exc_info)
del module._ctor_name_scope
return module
def wrap_with_name_scope(unbound_method):
"""Patches the given method so it enters the modules name scope."""
def enter_name_scope(self, *args, **kwargs):
"""Decorator that calls the given function in the module name scope.
Args:
self: Module instance.
*args: Positional arguments to `unbound_method`.
**kwargs: Keyword arguments to `unbound_method`.
Returns:
`with self.name_scope: return unbound_method(self, *args, **kwargs)`
"""
try:
module_name_scope = self.name_scope
except AttributeError as exc_value_from:
exc_value = AttributeError(
"The super constructor must be called before any other methods in "
"your constructor. If this is not possible then annotate all the "
"methods called with `@no_module_name_scope`.")
six.raise_from(exc_value, exc_value_from)
with module_name_scope:
# tf.Module enters the module name scope for all methods. To disable this
# for a particular method annotate it with `@no_module_name_scope`.
return unbound_method(self, *args, **kwargs)
return enter_name_scope
def wrap_with_name_scope_no_exception(unbound_method):
"""Patches the given method so it enters the modules name scope."""
def enter_name_scope(self, *args, **kwargs):
with self.name_scope:
# tf.Module enters the module name scope for all methods. To disable this
# for a particular method annotate it with `@no_module_name_scope`.
return unbound_method(self, *args, **kwargs)
return enter_name_scope
def with_name_scope(unbound_method):
"""Patches the given method so it enters the modules name scope."""
if getattr(unbound_method, NO_MODULE_NAME_SCOPE, False):
# The function has been annotated to say that no autoscoping should be
# applied, so do not patch it.
return unbound_method
if isinstance(unbound_method, def_function.Function):
# Autograph cannot convert functions that contain try/except.
unbound_method._decorate(wrap_with_name_scope_no_exception) # pylint: disable=protected-access
return unbound_method
else:
return tf_decorator.make_decorator(unbound_method,
wrap_with_name_scope(unbound_method))
@tf_export("Module")
-class Module(six.with_metaclass(ModuleMetaclass, tracking.AutoTrackable)):
+class Module(tracking.AutoTrackable):
"""Base neural network module class.
A module is a named container for `tf.Variable`s, other `tf.Module`s and
@@ -179,37 +37,54 @@ class Module(six.with_metaclass(ModuleMetaclass, tracking.AutoTrackable)):
network might be implemented as a `tf.Module`:
>>> class Dense(tf.Module):
-...   def __init__(self, in_features, output_features):
-...     super(Dense, self).__init__()
+...   def __init__(self, in_features, output_features, name=None):
+...     super(Dense, self).__init__(name=name)
...     self.w = tf.Variable(
...         tf.random.normal([in_features, output_features]), name='w')
...     self.b = tf.Variable(tf.zeros([output_features]), name='b')
...
...   def __call__(self, x):
...     x = tf.convert_to_tensor(x, name='x')
...     y = tf.matmul(x, self.w) + self.b
...     return tf.nn.relu(y)

-You can use the dense layer as you would expect:
+You can use the Dense layer as you would expect:

>>> d = Dense(in_features=64, output_features=10)
>>> d(tf.ones([100, 64]))
<tf.Tensor: ...>
-By subclassing `tf.Module` instead of `object` any variables created inside
-the module are automatically created within the module's name scope:
-
->>> d.w.name
-"dense/w:0"
-
-In eager mode this is useful for debugging, and when used with `@tf.function`
-the use of name scopes gives operations (e.g. matmul) useful names as well.
-
-As well as automatic naming, the Dense module inherits methods for tracking
-its variables:
+By subclassing `tf.Module` instead of `object` any `tf.Variable` or
+`tf.Module` instances assigned to object properties can be collected using
+the `variables`, `trainable_variables` or `submodules` property:

>>> d.variables
-(<tf.Variable 'dense/b:0' ...>, <tf.Variable 'dense/w:0' ...>)
+(<tf.Variable 'b:0' ...>, <tf.Variable 'w:0' ...>)
Subclasses of `tf.Module` can also take advantage of the `_flatten` method
which can be used to implement tracking of any other types.
+All `tf.Module` classes have an associated `tf.name_scope` which can be used
+to group operations in TensorBoard and create hierarchies for variable names
+which can help with debugging. We suggest using the name scope when creating
+nested submodules/parameters or for forward methods whose graph you might want
+to inspect in TensorBoard. You can enter the name scope explicitly using
+`with self.name_scope:` or you can annotate methods (apart from `__init__`)
+with `@tf.Module.with_name_scope`.
+
+>>> class MLP(tf.Module):
+...   def __init__(self, input_size, sizes, name=None):
+...     super(MLP, self).__init__(name=name)
+...     self.layers = []
+...     with self.name_scope:
+...       for size in sizes:
+...         self.layers.append(
+...             Dense(in_features=input_size, output_features=size))
+...         input_size = size
+...
+...   @tf.Module.with_name_scope
+...   def __call__(self, x):
+...     for layer in self.layers:
+...       x = layer(x)
+...     return x
"""
def __init__(self, name=None):
@@ -225,12 +100,6 @@ class Module(six.with_metaclass(ModuleMetaclass, tracking.AutoTrackable)):
with ops.name_scope(name) as scope_name:
self._scope_name = scope_name
-    # Enter the name scope so subsequent code in the constructor (e.g. creating
-    # submodules) happens inside the module's name scope. This is exited when
-    # the subclass __init__ returns (this is implemented in ModuleMetaclass).
-    self._ctor_name_scope = self.name_scope
-    self._ctor_name_scope.__enter__()
@property
def name(self):
"""Returns the name of this module as passed or determined in the ctor.
@@ -360,27 +229,37 @@ class Module(six.with_metaclass(ModuleMetaclass, tracking.AutoTrackable)):
with_path=with_path)
@classmethod
-  def no_name_scope(cls, method):
-    """Decorator to wrap a method, preventing automatic name scope wrapping.
+  def with_name_scope(cls, method):
+    """Decorator to automatically enter the module name scope.
-
-    By default, any method on a module is considered as a forwards function,
-    and so any variables / modules created by the method will be scoped as
-    belonging to the module. In some cases this is undesirable, for example
-    when implementing .clone() / .transpose(), as in those cases we want the
-    new module to have the scope of wherever the .transpose() call is made.
-    To allow this, decorate any methods with `no_module_name_scope`.
-
-    This logic is tied to ModuleMetaclass.__new__; if anything is changed
-    here corresponding changes will be needed there.
+
+    >>> class MyModule(tf.Module):
+    ...   @tf.Module.with_name_scope
+    ...   def __call__(self, x):
+    ...     if not hasattr(self, 'w'):
+    ...       self.w = tf.Variable(tf.random.normal([x.shape[1], 64]))
+    ...     return tf.matmul(x, self.w)
+
+    Using the above module would produce `tf.Variable`s and `tf.Tensor`s whose
+    names include the module name:
+
+    >>> mod = MyModule()
+    >>> mod(tf.ones([8, 32]))
+    <tf.Tensor: ...>
+    >>> mod.w
+    <tf.Variable ...'my_module/w:0'>

    Args:
-      method: the method to wrap.
+      method: The method to wrap.

    Returns:
-      The method, with a flag indicating no name scope wrapping should occur.
+      The original method wrapped such that it enters the module's name scope.
    """
-    setattr(method, NO_MODULE_NAME_SCOPE, True)
-    return method
+    def method_with_name_scope(self, *args, **kwargs):
+      with self.name_scope:
+        return method(self, *args, **kwargs)
+
+    return tf_decorator.make_decorator(method, method_with_name_scope)
_IS_VARIABLE = lambda o: isinstance(o, variables.Variable)
_IS_TRAINABLE_VARIABLE = lambda o: (_IS_VARIABLE(o) and o.trainable)
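
Since methods are no longer patched by a metaclass, `tf.Module.with_name_scope`
composes with `tf.function` by ordinary decorator stacking, as exercised in the
updated tests below. A minimal sketch (the `Counter` module is hypothetical):

class Counter(tf.Module):

  @tf.function  # Applied last, so it traces the already name-scoped method.
  @tf.Module.with_name_scope
  def __call__(self, x):
    if not hasattr(self, 'v'):
      # Created on the first trace only; named 'counter/v:0' via the scope.
      self.v = tf.Variable(0., name='v')
    self.v.assign_add(1.)
    return x + self.v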

File: tensorflow/python/module/module_test.py (path inferred from the diff contents)

@@ -227,24 +227,6 @@ class ModuleTrackingTest(test.TestCase):
    self.assertEqual(set(m.submodules), {leaf1, leaf2})

class CommonErrorsTest(test.TestCase):

-  def test_not_calling_super_constructor(self):
-    msg = ("Constructing a tf.Module without calling the super constructor is "
-           "not supported")
-    with self.assertRaisesRegexp(ValueError, msg):
-      DoesNotCallSuperConstructorModule()
-
-  def test_calls_method_before_super(self):
-    msg = "super constructor must be called before any other methods"
-    with self.assertRaisesRegexp(AttributeError, msg):
-      CallsMethodBeforeSuperConstructorModule(allowed_method=False)
-
-  def test_annotated_method_is_allowed(self):
-    self.assertIsNotNone(
-        CallsMethodBeforeSuperConstructorModule(allowed_method=True))
class ForwardMethodsTest(test.TestCase):
def testFunctionType(self):
@@ -307,6 +289,7 @@ class RecursiveModule(module.Module):
  def __init__(self, depth, trainable=True):
    super(RecursiveModule, self).__init__(name="badger")
+    with self.name_scope:
      self.child = None
      if depth > 1:
        self.child = RecursiveModule(depth - 1, trainable=trainable)
@@ -323,6 +306,7 @@ class AbstractModule(module.Module):
class ConcreteModule(AbstractModule):

+  @module.Module.with_name_scope
  def __call__(self, x):
    return x ** 2, get_name_scope()
@@ -333,6 +317,7 @@ class TreeModule(module.Module):
    super(TreeModule, self).__init__(name=name)
    self._leaves = []

+  @module.Module.with_name_scope
  def new_leaf(self, name=None):
    leaf = TreeModule(name=name)
    self._leaves.append(leaf)
@@ -341,15 +326,18 @@
class ReturnsNameScopeModule(module.Module):

+  @module.Module.with_name_scope
  def alternative_forward(self):
    return get_name_scope()

+  @module.Module.with_name_scope
  def __call__(self):
    return get_name_scope()

class SubclassedReturnsNameScopeModule(ReturnsNameScopeModule):

+  @module.Module.with_name_scope
  def alternative_alternative_forward(self):
    return get_name_scope()
@@ -368,37 +356,15 @@ class ModuleOverridingNameScope(ReturnsNameScopeModule):
    return ops.name_scope("yolo/")

-class DoesNotCallSuperConstructorModule(module.Module):
-
-  def __init__(self):
-    # NOTE: Intentionally does not call super constructor.
-    pass
-
-class CallsMethodBeforeSuperConstructorModule(module.Module):
-
-  def __init__(self, allowed_method):
-    if allowed_method:
-      self.no_name_scope()
-    else:
-      self.with_name_scope()
-    super(CallsMethodBeforeSuperConstructorModule, self).__init__()
-
-  @module.Module.no_name_scope
-  def no_name_scope(self):
-    pass
-
-  def with_name_scope(self):
-    pass

class ModuleWithFunctionAnnotatedCall(module.Module):

  @def_function.function(autograph=False)
+  @module.Module.with_name_scope
  def forward(self):
    return get_name_scope()

  @def_function.function(autograph=True)
+  @module.Module.with_name_scope
  def forward_ag(self):
    return get_name_scope()
@@ -410,22 +376,22 @@ class PropertyModule(module.Module):
    self._setter_scope_name = None

  @property
+  @module.Module.with_name_scope
  def some_property(self):
    getter_scope_name = get_name_scope()
    return getter_scope_name, self._setter_scope_name

  @some_property.setter
+  @module.Module.with_name_scope
  def some_property(self, my_property):
    self._setter_scope_name = get_name_scope()

  @property
-  @module.Module.no_name_scope
  def no_name_scope_property(self):
    getter_scope_name = get_name_scope()
    return getter_scope_name, self._setter_scope_name

  @no_name_scope_property.setter
-  @module.Module.no_name_scope
  def no_name_scope_property(self, my_property):
    self._setter_scope_name = get_name_scope()
@@ -514,43 +480,6 @@ class SimpleModule(module.Module):
IS_MEMBER = lambda v: isinstance(v, MemberType)
IS_MODULE = lambda v: isinstance(v, module.Module)

-class CustomMetaclass(type):
-
-  TAG = "__custom_metaclass__"
-
-  def __new__(mcs, name, bases, clsdict):
-    new_type = super(CustomMetaclass, mcs).__new__(mcs, name, bases, clsdict)
-    setattr(new_type, CustomMetaclass.TAG, True)
-    return new_type
-
-class CombiningMetaclass(module.ModuleMetaclass, CustomMetaclass):
-
-  TAG = "__combining_metaclass__"
-
-  def __new__(mcs, name, bases, clsdict):
-    new_type = super(CombiningMetaclass, mcs).__new__(mcs, name, bases, clsdict)
-    setattr(new_type, CombiningMetaclass.TAG, True)
-    return new_type
-
-@six.add_metaclass(CombiningMetaclass)
-class ModuleWithCustomMetaclass(module.Module):
-
-  def __init__(self):
-    super(ModuleWithCustomMetaclass, self).__init__()
-    self.init_name_scope = get_name_scope()
-
-class CustomMetaclassTest(test.TestCase):
-
-  def testSupportsCustomMetaclass(self):
-    m = ModuleWithCustomMetaclass()
-    self.assertEqual(m.init_name_scope, "module_with_custom_metaclass/")
-    self.assertTrue(getattr(ModuleWithCustomMetaclass, CombiningMetaclass.TAG))
-    self.assertTrue(getattr(ModuleWithCustomMetaclass, CustomMetaclass.TAG))
if __name__ == "__main__":
v2_compat.enable_v2_behavior()
test.main()

File: API golden for the tf.Module class (.pbtxt, likely v1; exact path not shown)

@@ -29,7 +29,7 @@ tf_class {
    argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
  }
  member_method {
-    name: "no_name_scope"
+    name: "with_name_scope"
    argspec: "args=[\'cls\', \'method\'], varargs=None, keywords=None, defaults=None"
  }
}

File: API golden for the top-level tf module (.pbtxt, likely v1; exact path not shown)

@@ -142,7 +142,7 @@ tf_module {
  }
  member {
    name: "Module"
-    mtype: "<class \'tensorflow.python.module.module.ModuleMetaclass\'>"
+    mtype: "<type \'type\'>"
  }
  member {
    name: "NameAttrList"

File: API golden for the tf.Module class (.pbtxt, likely v2; exact path not shown)

@@ -29,7 +29,7 @@ tf_class {
    argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
  }
  member_method {
-    name: "no_name_scope"
+    name: "with_name_scope"
    argspec: "args=[\'cls\', \'method\'], varargs=None, keywords=None, defaults=None"
  }
}

File: API golden for the top-level tf module (.pbtxt, likely v2; exact path not shown)

@@ -26,7 +26,7 @@ tf_module {
  }
  member {
    name: "Module"
-    mtype: "<class \'tensorflow.python.module.module.ModuleMetaclass\'>"
+    mtype: "<type \'type\'>"
  }
  member {
    name: "Operation"