Remove implicit name scoping from tf.Modules.

After attempting to integrate `tf.Module` into existing codebases (e.g.
`tf.keras`), we've found that the automatic name scoping is too invasive (e.g.
it changes op and variable names) and that it is desirable to disable it almost
everywhere.

We propose making name scoping for `tf.Module` opt-in:

>>> class MyModule(tf.Module):
...
...   @tf.Module.with_name_scope
...   def auto_name_scope(self, x):
...     if not hasattr(self, 'w'):
...       self.w = tf.Variable(1., name='w')
...     return x * self.w
...
...   def manual_name_scope(self, x):
...     if not hasattr(self, 'w'):
...       with self.name_scope:
...         self.w = tf.Variable(1., name='w')
...     return x * self.w
...
...   def no_name_scope(self, x):
...     if not hasattr(self, 'w'):
...       self.w = tf.Variable(1., name='w')
...     return x * self.w
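
To illustrate the difference, this is roughly what we expect the variable names
to look like under each approach (a sketch, assuming eager execution and the
default module name `my_module` derived from the class name; `manual_name_scope`
behaves like `auto_name_scope` since it enters `self.name_scope` explicitly):

>>> mod = MyModule()
>>> _ = mod.auto_name_scope(tf.constant(1.))
>>> mod.w.name  # created inside the module's name scope
'my_module/w:0'
>>> mod2 = MyModule()
>>> _ = mod2.no_name_scope(tf.constant(1.))
>>> mod2.w.name  # no name scope entered, so the top-level name is kept
'w:0'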

We will move opt-out name scoping into Sonnet:

>>> class MyModule(snt.Module):
...
...   def auto_name_scope(self, x):
...     if not hasattr(self, 'w'):
...       self.w = tf.Variable(1., name='w')
...     return x * self.w
...
...   @snt.no_name_scope
...   def no_name_scope(self, x):
...     if not hasattr(self, 'w'):
...       self.w = tf.Variable(1., name='w')
...     return x * self.w

In TF2 name scopes are purely cosmetic, so this should be less of a big deal. We
might consider encouraging users who want to filter on names to instead use
`_flatten` to extract a state dictionary for their objects (cf.
https://github.com/tensorflow/community/pull/56#discussion_r255048762).
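
For example, instead of matching on substrings of `v.name`, state could be keyed
on the attribute path of each variable. A rough sketch using the private
`_flatten` method referenced in that thread (the exact API and signature are
still under discussion, so treat this as illustrative only):

>>> mod = MyModule()
>>> _ = mod.no_name_scope(tf.constant(1.))
>>> {'/'.join(map(str, path)): v
...  for path, v in mod._flatten(predicate=lambda v: isinstance(v, tf.Variable),
...                              with_path=True)}
{'w': <tf.Variable 'w:0' ...>}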

I have moved the automatic name scoping logic (the metaclass etc.) and the
associated tests into Sonnet 2.

PiperOrigin-RevId: 235540184
Commit e866995aff (parent ad76a4b373)
Author: Tom Hennigan, 2019-02-25 09:06:45 -08:00; committed by TensorFlower Gardener
6 changed files with 76 additions and 268 deletions

@@ -18,160 +18,18 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

-import abc
 import re
-import sys
-
-import six

-from tensorflow.python.eager import def_function
 from tensorflow.python.framework import ops
 from tensorflow.python.ops import variables
 from tensorflow.python.training.tracking import tracking
 from tensorflow.python.util import nest
 from tensorflow.python.util import tf_decorator
-from tensorflow.python.util import tf_inspect
 from tensorflow.python.util.tf_export import tf_export

-NO_MODULE_NAME_SCOPE = "__no_module_name_scope__"
-
-
-class ModuleMetaclass(abc.ABCMeta):
-  """Metaclass for `tf.Module`."""
-
-  def __new__(mcs, name, bases, clsdict):
-    methods = []
-
-    for key, value in clsdict.items():
-      if key == "name_scope":
-        continue
-      elif key.startswith("__") and key != "__call__":
-        # Don't patch methods like `__getattr__` or `__del__`.
-        continue
-      elif tf_inspect.isfunction(value):
-        # We defer patching methods until after the type is created such that we
-        # can trigger the descriptor binding them to the class.
-        methods.append(key)
-      elif isinstance(value, property):
-        # TODO(tomhennigan) Preserve the type of property subclasses.
-        clsdict[key] = property(
-            value.fget if not value.fget else with_name_scope(value.fget),
-            value.fset if not value.fset else with_name_scope(value.fset),
-            value.fdel if not value.fdel else with_name_scope(value.fdel),
-            doc=value.__doc__)
-
-    cls = super(ModuleMetaclass, mcs).__new__(mcs, name, bases, clsdict)
-
-    for method_name in methods:
-      # Note: the below is quite subtle, we need to ensure that we're wrapping
-      # the method bound to the class. In some cases (e.g. `wrapt`) this is
-      # important since the method can trigger different behavior when it is
-      # bound (e.g. in wrapt `FunctionWrapper.__get__(None, cls)` produces a
-      # `BoundFunctionWrapper` which in turn populates the `instance` argument
-      # to decorator functions using args[0]).
-      # Equivalent to: `cls.__dict__[method_name].__get__(None, cls)`
-      method = getattr(cls, method_name)
-      method = with_name_scope(method)
-      setattr(cls, method_name, method)
-
-    return cls
-
-  def __call__(cls, *args, **kwargs):
-    # Call new such that we have an un-initialized module instance that we can
-    # still reference even if there is an exception during __init__. This is
-    # needed such that we can make sure the name_scope constructed in __init__
-    # is closed even if there is an exception.
-    module = cls.__new__(cls, *args, **kwargs)
-
-    # Now attempt to initialize the object.
-    try:
-      module.__init__(*args, **kwargs)
-    except:
-      # We must explicitly catch so that in Python 2 sys.exc_info() is populated
-      # before entering the finally block.
-      raise
-    finally:
-      # The base Module constructor enters the modules name scope before
-      # returning such that other functionality in the ctor happens within the
-      # modules name scope.
-      scope = getattr(module, "_ctor_name_scope", None)
-      exc_info = sys.exc_info()
-      if scope is None:
-        if exc_info[0] is None:
-          raise ValueError(
-              "Constructing a tf.Module without calling the super constructor "
-              "is not supported. Add the following as the first line in your "
-              "__init__ method:\n\n"
-              "super(%s, self).__init__()" % cls.__name__)
-      else:
-        scope.__exit__(*exc_info)
-        del module._ctor_name_scope
-
-    return module
-
-
-def wrap_with_name_scope(unbound_method):
-  """Patches the given method so it enters the modules name scope."""
-  def enter_name_scope(self, *args, **kwargs):
-    """Decorator that calls the given function in the module name scope.
-
-    Args:
-      self: Module instance.
-      *args: Positional arguments to `unbound_method`.
-      **kwargs: Keyword arguments to `unbound_method`.
-
-    Returns:
-      `with self.name_scope: return unbound_method(self, *args, **kwargs)`
-    """
-    try:
-      module_name_scope = self.name_scope
-    except AttributeError as exc_value_from:
-      exc_value = AttributeError(
-          "The super constructor must be called before any other methods in "
-          "your constructor. If this is not possible then annotate all the "
-          "methods called with `@no_module_name_scope`.")
-      six.raise_from(exc_value, exc_value_from)
-    with module_name_scope:
-      # tf.Module enters the module name scope for all methods. To disable this
-      # for a particular method annotate it with `@no_module_name_scope`.
-      return unbound_method(self, *args, **kwargs)
-  return enter_name_scope
-
-
-def wrap_with_name_scope_no_exception(unbound_method):
-  """Patches the given method so it enters the modules name scope."""
-  def enter_name_scope(self, *args, **kwargs):
-    with self.name_scope:
-      # tf.Module enters the module name scope for all methods. To disable this
-      # for a particular method annotate it with `@no_module_name_scope`.
-      return unbound_method(self, *args, **kwargs)
-  return enter_name_scope
-
-
-def with_name_scope(unbound_method):
-  """Patches the given method so it enters the modules name scope."""
-  if getattr(unbound_method, NO_MODULE_NAME_SCOPE, False):
-    # The function has been annotated to say that no autoscoping should be
-    # applied, so do not patch it.
-    return unbound_method
-
-  if isinstance(unbound_method, def_function.Function):
-    # Autograph cannot convert functions that have try/catch.
-    unbound_method._decorate(wrap_with_name_scope_no_exception)  # pylint: disable=protected-access
-    return unbound_method
-  else:
-    return tf_decorator.make_decorator(unbound_method,
-                                       wrap_with_name_scope(unbound_method))
-
-
 @tf_export("Module")
-class Module(six.with_metaclass(ModuleMetaclass, tracking.AutoTrackable)):
+class Module(tracking.AutoTrackable):
   """Base neural network module class.

   A module is a named container for `tf.Variable`s, other `tf.Module`s and
@@ -179,37 +37,54 @@ class Module(six.with_metaclass(ModuleMetaclass, tracking.AutoTrackable)):
   network might be implemented as a `tf.Module`:

   >>> class Dense(tf.Module):
-  ...   def __init__(self, in_features, output_features):
-  ...     super(Dense, self).__init__()
+  ...   def __init__(self, in_features, output_features, name=None):
+  ...     super(Dense, self).__init__(name=name)
   ...     self.w = tf.Variable(
   ...         tf.random_normal([input_features, output_features]), name='w')
   ...     self.b = tf.Variable(tf.zeros([output_features]), name='b')
   ...
   ...   def __call__(self, x):
-  ...     x = tf.convert_to_tensor(x, name='x')
   ...     y = tf.matmul(x, self.w) + self.b
   ...     return tf.nn.relu(y)

-  You can use the dense layer as you would expect:
+  You can use the Dense layer as you would expect:

   >>> d = Dense(input_features=64, output_features=10)
   >>> d(tf.ones([100, 64]))
   <tf.Tensor: ...>

-  By subclassing `tf.Module` instead of `object` any variables created inside
-  the module are automatically created within the modules name scope:
-
-  >>> d.w.name
-  "dense/w:0"
-
-  In eager mode this is useful for debugging, and when used with `@tf.function`
-  the use of name scopes gives operations (e.g. matmul) useful names as well.
-
-  As well as automatic naming, the Dense module inherits methods for tracking
-  its variables:
+  By subclassing `tf.Module` instead of `object` any `tf.Variable` or
+  `tf.Module` instances assigned to object properties can be collected using
+  the `variables`, `trainable_variables` or `submodules` property:

   >>> d.variables
-  (<tf.Variable 'dense/b:0' ...>, <tf.Variable 'dense/w:0' ...>)
+  (<tf.Variable 'b:0' ...>, <tf.Variable 'w:0' ...>)
+
+  Subclasses of `tf.Module` can also take advantage of the `_flatten` method
+  which can be used to implement tracking of any other types.
+
+  All `tf.Module` classes have an associated `tf.name_scope` which can be used
+  to group operations in TensorBoard and create hierarchies for variable names
+  which can help with debugging. We suggest using the name scope when creating
+  nested submodules/parameters or for forward methods whose graph you might want
+  to inspect in TensorBoard. You can enter the name scope explicitly using
+  `with self.name_scope:` or you can annotate methods (apart from `__init__`)
+  with `@tf.Module.with_name_scope`.
+
+  >>> class MLP(tf.Module):
+  ...   def __init__(self, input_size, sizes, name=None):
+  ...     super(MLP, self).__init__(name=name)
+  ...     self.layers = []
+  ...     with self.name_scope:
+  ...       for size in sizes:
+  ...         self.layers.append(Dense(input_size=size, output_size=size))
+  ...         input_size = size
+  ...
+  ...   @tf.Module.with_name_scope
+  ...   def __call__(self, x):
+  ...     for layer in self.layers:
+  ...       x = layer(x)
+  ...     return x
   """

   def __init__(self, name=None):
@@ -225,12 +100,6 @@ class Module(six.with_metaclass(ModuleMetaclass, tracking.AutoTrackable)):
     with ops.name_scope(name) as scope_name:
       self._scope_name = scope_name

-    # Enter the name scope so subsequent code in the contructor (e.g. creating
-    # submodules) happens inside the modules name scope. This is exited when
-    # the subclass __init__ returns (this is implemented in ModuleMetaclass).
-    self._ctor_name_scope = self.name_scope
-    self._ctor_name_scope.__enter__()
-
   @property
   def name(self):
     """Returns the name of this module as passed or determined in the ctor.
@@ -360,27 +229,37 @@ class Module(six.with_metaclass(ModuleMetaclass, tracking.AutoTrackable)):
         with_path=with_path)

   @classmethod
-  def no_name_scope(cls, method):
-    """Decorator to wrap a method, preventing automatic name scope wrapping.
-
-    By default, any method on a module is considered as a forwards function, and
-    so any variables / modules created by the method will be scoped as belonging
-    to the module. In some cases this is undesirable, for example when
-    implementing .clone() / .transpose(), as in those cases we want the new
-    module to have the scope of wherever the .transpose() call is made. To
-    allow this, decorate any methods with `no_module_name_scope`.
-
-    This logic is tied to ModuleMetaclass.__new__, if anything is
-    changed here corresponding changes will be needed there.
+  def with_name_scope(cls, method):
+    """Decorator to automatically enter the module name scope.
+
+    >>> class MyModule(tf.Module):
+    ...   @tf.Module.with_name_scope
+    ...   def __call__(self, x):
+    ...     if not hasattr(self, 'w'):
+    ...       self.w = tf.Variable(tf.random.normal([x.shape[1], 64]))
+    ...     return tf.matmul(x, self.w)
+
+    Using the above module would produce `tf.Variable`s and `tf.Tensor`s whose
+    names included the module name:
+
+    >>> mod = MyModule()
+    >>> mod(tf.ones([8, 32]))
+    <tf.Tensor: ...>
+    >>> mod.w
+    <tf.Variable ...'my_module/w:0'>

     Args:
-      method: the method to wrap.
+      method: The method to wrap.

     Returns:
-      The method, with a flag indicating no name scope wrapping should occur.
+      The original method wrapped such that it enters the module's name scope.
     """
-    setattr(method, NO_MODULE_NAME_SCOPE, True)
-    return method
+    def method_with_name_scope(self, *args, **kwargs):
+      with self.name_scope:
+        return method(self, *args, **kwargs)
+
+    return tf_decorator.make_decorator(method, method_with_name_scope)


 _IS_VARIABLE = lambda o: isinstance(o, variables.Variable)
 _IS_TRAINABLE_VARIABLE = lambda o: (_IS_VARIABLE(o) and o.trainable)

@@ -227,24 +227,6 @@ class ModuleTrackingTest(test.TestCase):
     self.assertEqual(set(m.submodules), {leaf1, leaf2})


-class CommonErrorsTest(test.TestCase):
-
-  def test_not_calling_super_constructor(self):
-    msg = ("Constructing a tf.Module without calling the super constructor is "
-           "not supported")
-    with self.assertRaisesRegexp(ValueError, msg):
-      DoesNotCallSuperConstructorModule()
-
-  def test_calls_method_before_super(self):
-    msg = "super constructor must be called before any other methods"
-    with self.assertRaisesRegexp(AttributeError, msg):
-      CallsMethodBeforeSuperConstructorModule(allowed_method=False)
-
-  def test_annotated_method_is_allowed(self):
-    self.assertIsNotNone(
-        CallsMethodBeforeSuperConstructorModule(allowed_method=True))
-
-
 class ForwardMethodsTest(test.TestCase):

   def testFunctionType(self):
@@ -307,6 +289,7 @@ class RecursiveModule(module.Module):

   def __init__(self, depth, trainable=True):
     super(RecursiveModule, self).__init__(name="badger")
+    with self.name_scope:
       self.child = None
       if depth > 1:
         self.child = RecursiveModule(depth - 1, trainable=trainable)
@@ -323,6 +306,7 @@ class AbstractModule(module.Module):

 class ConcreteModule(AbstractModule):

+  @module.Module.with_name_scope
   def __call__(self, x):
     return x ** 2, get_name_scope()
@@ -333,6 +317,7 @@ class TreeModule(module.Module):
     super(TreeModule, self).__init__(name=name)
     self._leaves = []

+  @module.Module.with_name_scope
   def new_leaf(self, name=None):
     leaf = TreeModule(name=name)
     self._leaves.append(leaf)
@@ -341,15 +326,18 @@

 class ReturnsNameScopeModule(module.Module):

+  @module.Module.with_name_scope
   def alternative_forward(self):
     return get_name_scope()

+  @module.Module.with_name_scope
   def __call__(self):
     return get_name_scope()


 class SubclassedReturnsNameScopeModule(ReturnsNameScopeModule):

+  @module.Module.with_name_scope
   def alternative_alternative_forward(self):
     return get_name_scope()
@@ -368,37 +356,15 @@ class ModuleOverridingNameScope(ReturnsNameScopeModule):
     return ops.name_scope("yolo/")


-class DoesNotCallSuperConstructorModule(module.Module):
-
-  def __init__(self):
-    # NOTE: Intentionally does not call super constructor.
-    pass
-
-
-class CallsMethodBeforeSuperConstructorModule(module.Module):
-
-  def __init__(self, allowed_method):
-    if allowed_method:
-      self.no_name_scope()
-    else:
-      self.with_name_scope()
-    super(CallsMethodBeforeSuperConstructorModule, self).__init__()
-
-  @module.Module.no_name_scope
-  def no_name_scope(self):
-    pass
-
-  def with_name_scope(self):
-    pass
-
-
 class ModuleWithFunctionAnnotatedCall(module.Module):

   @def_function.function(autograph=False)
+  @module.Module.with_name_scope
   def forward(self):
     return get_name_scope()

   @def_function.function(autograph=True)
+  @module.Module.with_name_scope
   def forward_ag(self):
     return get_name_scope()
@@ -410,22 +376,22 @@ class PropertyModule(module.Module):
     self._setter_scope_name = None

   @property
+  @module.Module.with_name_scope
   def some_property(self):
     getter_scope_name = get_name_scope()
     return getter_scope_name, self._setter_scope_name

   @some_property.setter
+  @module.Module.with_name_scope
   def some_property(self, my_property):
     self._setter_scope_name = get_name_scope()

   @property
-  @module.Module.no_name_scope
   def no_name_scope_property(self):
     getter_scope_name = get_name_scope()
     return getter_scope_name, self._setter_scope_name

   @no_name_scope_property.setter
-  @module.Module.no_name_scope
   def no_name_scope_property(self, my_property):
     self._setter_scope_name = get_name_scope()
@@ -514,43 +480,6 @@ class SimpleModule(module.Module):

 IS_MEMBER = lambda v: isinstance(v, MemberType)
 IS_MODULE = lambda v: isinstance(v, module.Module)

-
-class CustomMetaclass(type):
-
-  TAG = "__custom_metaclass__"
-
-  def __new__(mcs, name, bases, clsdict):
-    new_type = super(CustomMetaclass, mcs).__new__(mcs, name, bases, clsdict)
-    setattr(new_type, CustomMetaclass.TAG, True)
-    return new_type
-
-
-class CombiningMetaclass(module.ModuleMetaclass, CustomMetaclass):
-
-  TAG = "__combining_metaclass__"
-
-  def __new__(mcs, name, bases, clsdict):
-    new_type = super(CombiningMetaclass, mcs).__new__(mcs, name, bases, clsdict)
-    setattr(new_type, CombiningMetaclass.TAG, True)
-    return new_type
-
-
-@six.add_metaclass(CombiningMetaclass)
-class ModuleWithCustomMetaclass(module.Module):
-
-  def __init__(self):
-    super(ModuleWithCustomMetaclass, self).__init__()
-    self.init_name_scope = get_name_scope()
-
-
-class CustomMetaclassTest(test.TestCase):
-
-  def testSupportsCustomMetaclass(self):
-    m = ModuleWithCustomMetaclass()
-    self.assertEqual(m.init_name_scope, "module_with_custom_metaclass/")
-    self.assertTrue(getattr(ModuleWithCustomMetaclass, CombiningMetaclass.TAG))
-    self.assertTrue(getattr(ModuleWithCustomMetaclass, CustomMetaclass.TAG))
-
-
 if __name__ == "__main__":
   v2_compat.enable_v2_behavior()
   test.main()

@@ -29,7 +29,7 @@ tf_class {
    argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
  }
  member_method {
-    name: "no_name_scope"
+    name: "with_name_scope"
    argspec: "args=[\'cls\', \'method\'], varargs=None, keywords=None, defaults=None"
  }
}

@@ -142,7 +142,7 @@ tf_module {
  }
  member {
    name: "Module"
-    mtype: "<class \'tensorflow.python.module.module.ModuleMetaclass\'>"
+    mtype: "<type \'type\'>"
  }
  member {
    name: "NameAttrList"

@@ -29,7 +29,7 @@ tf_class {
    argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
  }
  member_method {
-    name: "no_name_scope"
+    name: "with_name_scope"
    argspec: "args=[\'cls\', \'method\'], varargs=None, keywords=None, defaults=None"
  }
}

@@ -26,7 +26,7 @@ tf_module {
  }
  member {
    name: "Module"
-    mtype: "<class \'tensorflow.python.module.module.ModuleMetaclass\'>"
+    mtype: "<type \'type\'>"
  }
  member {
    name: "Operation"