Remove implicit name scoping from tf.Module
s.
After attempting to integrate `tf.Module` into existing codebases (e.g. `tf.keras`) we've found that the automatic name scoping is too invasive (e.g. changing op and variable names) and it is desirable to disable it ~everywhere. We propose that name scoping for `tf.Module` becomes opt-in: >>> class MyModule(tf.Module): ... ... @tf.Module.with_name_scope ... def auto_name_scope(self, x): ... if not hasattr(self, 'w'): ... self.w = tf.Variable(1., name='w') ... return x * self.w ... ... def manual_name_scope(self, x): ... if not hasattr(self, 'w'): ... with self.name_scope: ... self.w = tf.Variable(1., name='w') ... return x * self.w ... ... def no_name_scope(self, x): ... if not hasattr(self, 'w'): ... self.w = tf.Variable(1., name='w') ... return x * self.w We will move opt-out name scoping into Sonnet: >>> class MyModule(snt.Module): ... ... def auto_name_scope(self, x): ... if not hasattr(self, 'w'): ... self.w = tf.Variable(1., name='w') ... return x * self.w ... ... @snt.no_name_scope ... def no_name_scope(self, x): ... if not hasattr(self, 'w'): ... self.w = tf.Variable(1., name='w') ... return x * self.w In TF2 name scopes are cosmetic and this should be less of a big deal. We might consider encouraging users who want to filter on names to instead use flatten to extract a state dictionary for their objects (c.f. https://github.com/tensorflow/community/pull/56#discussion_r255048762). I have moved the automatic name scoping logic (Metaclass etc) and associated tests into Sonnet 2. PiperOrigin-RevId: 235540184
This commit is contained in:
parent
ad76a4b373
commit
e866995aff
@ -18,160 +18,18 @@ from __future__ import absolute_import
|
|||||||
from __future__ import division
|
from __future__ import division
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
import abc
|
|
||||||
import re
|
import re
|
||||||
import sys
|
|
||||||
|
|
||||||
import six
|
|
||||||
|
|
||||||
from tensorflow.python.eager import def_function
|
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
from tensorflow.python.ops import variables
|
from tensorflow.python.ops import variables
|
||||||
from tensorflow.python.training.tracking import tracking
|
from tensorflow.python.training.tracking import tracking
|
||||||
from tensorflow.python.util import nest
|
from tensorflow.python.util import nest
|
||||||
from tensorflow.python.util import tf_decorator
|
from tensorflow.python.util import tf_decorator
|
||||||
from tensorflow.python.util import tf_inspect
|
|
||||||
from tensorflow.python.util.tf_export import tf_export
|
from tensorflow.python.util.tf_export import tf_export
|
||||||
|
|
||||||
NO_MODULE_NAME_SCOPE = "__no_module_name_scope__"
|
|
||||||
|
|
||||||
|
|
||||||
class ModuleMetaclass(abc.ABCMeta):
|
|
||||||
"""Metaclass for `tf.Module`."""
|
|
||||||
|
|
||||||
def __new__(mcs, name, bases, clsdict):
|
|
||||||
methods = []
|
|
||||||
|
|
||||||
for key, value in clsdict.items():
|
|
||||||
if key == "name_scope":
|
|
||||||
continue
|
|
||||||
|
|
||||||
elif key.startswith("__") and key != "__call__":
|
|
||||||
# Don't patch methods like `__getattr__` or `__del__`.
|
|
||||||
continue
|
|
||||||
|
|
||||||
elif tf_inspect.isfunction(value):
|
|
||||||
# We defer patching methods until after the type is created such that we
|
|
||||||
# can trigger the descriptor binding them to the class.
|
|
||||||
methods.append(key)
|
|
||||||
|
|
||||||
elif isinstance(value, property):
|
|
||||||
# TODO(tomhennigan) Preserve the type of property subclasses.
|
|
||||||
clsdict[key] = property(
|
|
||||||
value.fget if not value.fget else with_name_scope(value.fget),
|
|
||||||
value.fset if not value.fset else with_name_scope(value.fset),
|
|
||||||
value.fdel if not value.fdel else with_name_scope(value.fdel),
|
|
||||||
doc=value.__doc__)
|
|
||||||
|
|
||||||
cls = super(ModuleMetaclass, mcs).__new__(mcs, name, bases, clsdict)
|
|
||||||
|
|
||||||
for method_name in methods:
|
|
||||||
# Note: the below is quite subtle, we need to ensure that we're wrapping
|
|
||||||
# the method bound to the class. In some cases (e.g. `wrapt`) this is
|
|
||||||
# important since the method can trigger different behavior when it is
|
|
||||||
# bound (e.g. in wrapt `FunctionWrapper.__get__(None, cls)` produces a
|
|
||||||
# `BoundFunctionWrapper` which in turn populates the `instance` argument
|
|
||||||
# to decorator functions using args[0]).
|
|
||||||
# Equivalent to: `cls.__dict__[method_name].__get__(None, cls)`
|
|
||||||
method = getattr(cls, method_name)
|
|
||||||
method = with_name_scope(method)
|
|
||||||
setattr(cls, method_name, method)
|
|
||||||
|
|
||||||
return cls
|
|
||||||
|
|
||||||
def __call__(cls, *args, **kwargs):
|
|
||||||
# Call new such that we have an un-initialized module instance that we can
|
|
||||||
# still reference even if there is an exception during __init__. This is
|
|
||||||
# needed such that we can make sure the name_scope constructed in __init__
|
|
||||||
# is closed even if there is an exception.
|
|
||||||
module = cls.__new__(cls, *args, **kwargs)
|
|
||||||
|
|
||||||
# Now attempt to initialize the object.
|
|
||||||
try:
|
|
||||||
module.__init__(*args, **kwargs)
|
|
||||||
except:
|
|
||||||
# We must explicitly catch so that in Python 2 sys.exc_info() is populated
|
|
||||||
# before entering the finally block.
|
|
||||||
raise
|
|
||||||
|
|
||||||
finally:
|
|
||||||
# The base Module constructor enters the modules name scope before
|
|
||||||
# returning such that other functionality in the ctor happens within the
|
|
||||||
# modules name scope.
|
|
||||||
scope = getattr(module, "_ctor_name_scope", None)
|
|
||||||
exc_info = sys.exc_info()
|
|
||||||
if scope is None:
|
|
||||||
if exc_info[0] is None:
|
|
||||||
raise ValueError(
|
|
||||||
"Constructing a tf.Module without calling the super constructor "
|
|
||||||
"is not supported. Add the following as the first line in your "
|
|
||||||
"__init__ method:\n\n"
|
|
||||||
"super(%s, self).__init__()" % cls.__name__)
|
|
||||||
else:
|
|
||||||
scope.__exit__(*exc_info)
|
|
||||||
del module._ctor_name_scope
|
|
||||||
|
|
||||||
return module
|
|
||||||
|
|
||||||
|
|
||||||
def wrap_with_name_scope(unbound_method):
|
|
||||||
"""Patches the given method so it enters the modules name scope."""
|
|
||||||
def enter_name_scope(self, *args, **kwargs):
|
|
||||||
"""Decorator that calls the given function in the module name scope.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
self: Module instance.
|
|
||||||
*args: Positional arguments to `unbound_method`.
|
|
||||||
**kwargs: Keyword arguments to `unbound_method`.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
`with self.name_scope: return unbound_method(self, *args, **kwargs)`
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
module_name_scope = self.name_scope
|
|
||||||
except AttributeError as exc_value_from:
|
|
||||||
exc_value = AttributeError(
|
|
||||||
"The super constructor must be called before any other methods in "
|
|
||||||
"your constructor. If this is not possible then annotate all the "
|
|
||||||
"methods called with `@no_module_name_scope`.")
|
|
||||||
six.raise_from(exc_value, exc_value_from)
|
|
||||||
|
|
||||||
with module_name_scope:
|
|
||||||
# tf.Module enters the module name scope for all methods. To disable this
|
|
||||||
# for a particular method annotate it with `@no_module_name_scope`.
|
|
||||||
return unbound_method(self, *args, **kwargs)
|
|
||||||
|
|
||||||
return enter_name_scope
|
|
||||||
|
|
||||||
|
|
||||||
def wrap_with_name_scope_no_exception(unbound_method):
|
|
||||||
"""Patches the given method so it enters the modules name scope."""
|
|
||||||
def enter_name_scope(self, *args, **kwargs):
|
|
||||||
with self.name_scope:
|
|
||||||
# tf.Module enters the module name scope for all methods. To disable this
|
|
||||||
# for a particular method annotate it with `@no_module_name_scope`.
|
|
||||||
return unbound_method(self, *args, **kwargs)
|
|
||||||
return enter_name_scope
|
|
||||||
|
|
||||||
|
|
||||||
def with_name_scope(unbound_method):
|
|
||||||
"""Patches the given method so it enters the modules name scope."""
|
|
||||||
if getattr(unbound_method, NO_MODULE_NAME_SCOPE, False):
|
|
||||||
# The function has been annotated to say that no autoscoping should be
|
|
||||||
# applied, so do not patch it.
|
|
||||||
return unbound_method
|
|
||||||
|
|
||||||
if isinstance(unbound_method, def_function.Function):
|
|
||||||
# Autograph cannot convert functions that have try/catch.
|
|
||||||
unbound_method._decorate(wrap_with_name_scope_no_exception) # pylint: disable=protected-access
|
|
||||||
return unbound_method
|
|
||||||
else:
|
|
||||||
return tf_decorator.make_decorator(unbound_method,
|
|
||||||
wrap_with_name_scope(unbound_method))
|
|
||||||
|
|
||||||
|
|
||||||
@tf_export("Module")
|
@tf_export("Module")
|
||||||
class Module(six.with_metaclass(ModuleMetaclass, tracking.AutoTrackable)):
|
class Module(tracking.AutoTrackable):
|
||||||
"""Base neural network module class.
|
"""Base neural network module class.
|
||||||
|
|
||||||
A module is a named container for `tf.Variable`s, other `tf.Module`s and
|
A module is a named container for `tf.Variable`s, other `tf.Module`s and
|
||||||
@ -179,37 +37,54 @@ class Module(six.with_metaclass(ModuleMetaclass, tracking.AutoTrackable)):
|
|||||||
network might be implemented as a `tf.Module`:
|
network might be implemented as a `tf.Module`:
|
||||||
|
|
||||||
>>> class Dense(tf.Module):
|
>>> class Dense(tf.Module):
|
||||||
... def __init__(self, in_features, output_features):
|
... def __init__(self, in_features, output_features, name=None):
|
||||||
... super(Dense, self).__init__()
|
... super(Dense, self).__init__(name=name)
|
||||||
... self.w = tf.Variable(
|
... self.w = tf.Variable(
|
||||||
... tf.random_normal([input_features, output_features]), name='w')
|
... tf.random_normal([input_features, output_features]), name='w')
|
||||||
... self.b = tf.Variable(tf.zeros([output_features]), name='b')
|
... self.b = tf.Variable(tf.zeros([output_features]), name='b')
|
||||||
...
|
...
|
||||||
... def __call__(self, x):
|
... def __call__(self, x):
|
||||||
... x = tf.convert_to_tensor(x, name='x')
|
|
||||||
... y = tf.matmul(x, self.w) + self.b
|
... y = tf.matmul(x, self.w) + self.b
|
||||||
... return tf.nn.relu(y)
|
... return tf.nn.relu(y)
|
||||||
|
|
||||||
You can use the dense layer as you would expect:
|
You can use the Dense layer as you would expect:
|
||||||
|
|
||||||
>>> d = Dense(input_features=64, output_features=10)
|
>>> d = Dense(input_features=64, output_features=10)
|
||||||
>>> d(tf.ones([100, 64]))
|
>>> d(tf.ones([100, 64]))
|
||||||
<tf.Tensor: ...>
|
<tf.Tensor: ...>
|
||||||
|
|
||||||
By subclassing `tf.Module` instead of `object` any variables created inside
|
By subclassing `tf.Module` instead of `object` any `tf.Variable` or
|
||||||
the module are automatically created within the modules name scope:
|
`tf.Module` instances assigned to object properties can be collected using
|
||||||
|
the `variables`, `trainable_variables` or `submodules` property:
|
||||||
>>> d.w.name
|
|
||||||
"dense/w:0"
|
|
||||||
|
|
||||||
In eager mode this is useful for debugging, and when used with `@tf.function`
|
|
||||||
the use of name scopes gives operations (e.g. matmul) useful names as well.
|
|
||||||
|
|
||||||
As well as automatic naming, the Dense module inherits methods for tracking
|
|
||||||
its variables:
|
|
||||||
|
|
||||||
>>> d.variables
|
>>> d.variables
|
||||||
(<tf.Variable 'dense/b:0' ...>, <tf.Variable 'dense/w:0' ...>)
|
(<tf.Variable 'b:0' ...>, <tf.Variable 'w:0' ...>)
|
||||||
|
|
||||||
|
Subclasses of `tf.Module` can also take advantage of the `_flatten` method
|
||||||
|
which can be used to implement tracking of any other types.
|
||||||
|
|
||||||
|
All `tf.Module` classes have an associated `tf.name_scope` which can be used
|
||||||
|
to group operations in TensorBoard and create hierarchies for variable names
|
||||||
|
which can help with debugging. We suggest using the name scope when creating
|
||||||
|
nested submodules/parameters or for forward methods whose graph you might want
|
||||||
|
to inspect in TensorBoard. You can enter the name scope explicitly using
|
||||||
|
`with self.name_scope:` or you can annotate methods (apart from `__init__`)
|
||||||
|
with `@tf.Module.with_name_scope`.
|
||||||
|
|
||||||
|
>>> class MLP(tf.Module):
|
||||||
|
... def __init__(self, input_size, sizes, name=None):
|
||||||
|
... super(MLP, self).__init__(name=name)
|
||||||
|
... self.layers = []
|
||||||
|
... with self.name_scope:
|
||||||
|
... for size in sizes:
|
||||||
|
... self.layers.append(Dense(input_size=size, output_size=size))
|
||||||
|
... input_size = size
|
||||||
|
...
|
||||||
|
... @tf.Module.with_name_scope
|
||||||
|
... def __call__(self, x):
|
||||||
|
... for layer in self.layers:
|
||||||
|
... x = layer(x)
|
||||||
|
... return x
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, name=None):
|
def __init__(self, name=None):
|
||||||
@ -225,12 +100,6 @@ class Module(six.with_metaclass(ModuleMetaclass, tracking.AutoTrackable)):
|
|||||||
with ops.name_scope(name) as scope_name:
|
with ops.name_scope(name) as scope_name:
|
||||||
self._scope_name = scope_name
|
self._scope_name = scope_name
|
||||||
|
|
||||||
# Enter the name scope so subsequent code in the contructor (e.g. creating
|
|
||||||
# submodules) happens inside the modules name scope. This is exited when
|
|
||||||
# the subclass __init__ returns (this is implemented in ModuleMetaclass).
|
|
||||||
self._ctor_name_scope = self.name_scope
|
|
||||||
self._ctor_name_scope.__enter__()
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def name(self):
|
def name(self):
|
||||||
"""Returns the name of this module as passed or determined in the ctor.
|
"""Returns the name of this module as passed or determined in the ctor.
|
||||||
@ -360,27 +229,37 @@ class Module(six.with_metaclass(ModuleMetaclass, tracking.AutoTrackable)):
|
|||||||
with_path=with_path)
|
with_path=with_path)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def no_name_scope(cls, method):
|
def with_name_scope(cls, method):
|
||||||
"""Decorator to wrap a method, preventing automatic name scope wrapping.
|
"""Decorator to automatically enter the module name scope.
|
||||||
|
|
||||||
By default, any method on a module is considered as a forwards function, and
|
>>> class MyModule(tf.Module):
|
||||||
so any variables / modules created by the method will be scoped as belonging
|
... @tf.Module.with_name_scope
|
||||||
to the module. In some cases this is undesirable, for example when
|
... def __call__(self, x):
|
||||||
implementing .clone() / .transpose(), as in those cases we want the new
|
... if not hasattr(self, 'w'):
|
||||||
module to have the scope of wherever the .transpose() call is made. To
|
... self.w = tf.Variable(tf.random.normal([x.shape[1], 64]))
|
||||||
allow this, decorate any methods with `no_module_name_scope`.
|
... return tf.matmul(x, self.w)
|
||||||
|
|
||||||
This logic is tied to ModuleMetaclass.__new__, if anything is
|
Using the above module would produce `tf.Variable`s and `tf.Tensor`s whose
|
||||||
changed here corresponding changes will be needed there.
|
names included the module name:
|
||||||
|
|
||||||
|
>>> mod = MyModule()
|
||||||
|
>>> mod(tf.ones([8, 32]))
|
||||||
|
<tf.Tensor: ...>
|
||||||
|
>>> mod.w
|
||||||
|
<tf.Variable ...'my_module/w:0'>
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
method: the method to wrap.
|
method: The method to wrap.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
The method, with a flag indicating no name scope wrapping should occur.
|
The original method wrapped such that it enters the module's name scope.
|
||||||
"""
|
"""
|
||||||
setattr(method, NO_MODULE_NAME_SCOPE, True)
|
def method_with_name_scope(self, *args, **kwargs):
|
||||||
return method
|
with self.name_scope:
|
||||||
|
return method(self, *args, **kwargs)
|
||||||
|
|
||||||
|
return tf_decorator.make_decorator(method, method_with_name_scope)
|
||||||
|
|
||||||
|
|
||||||
_IS_VARIABLE = lambda o: isinstance(o, variables.Variable)
|
_IS_VARIABLE = lambda o: isinstance(o, variables.Variable)
|
||||||
_IS_TRAINABLE_VARIABLE = lambda o: (_IS_VARIABLE(o) and o.trainable)
|
_IS_TRAINABLE_VARIABLE = lambda o: (_IS_VARIABLE(o) and o.trainable)
|
||||||
|
@ -227,24 +227,6 @@ class ModuleTrackingTest(test.TestCase):
|
|||||||
self.assertEqual(set(m.submodules), {leaf1, leaf2})
|
self.assertEqual(set(m.submodules), {leaf1, leaf2})
|
||||||
|
|
||||||
|
|
||||||
class CommonErrorsTest(test.TestCase):
|
|
||||||
|
|
||||||
def test_not_calling_super_constructor(self):
|
|
||||||
msg = ("Constructing a tf.Module without calling the super constructor is "
|
|
||||||
"not supported")
|
|
||||||
with self.assertRaisesRegexp(ValueError, msg):
|
|
||||||
DoesNotCallSuperConstructorModule()
|
|
||||||
|
|
||||||
def test_calls_method_before_super(self):
|
|
||||||
msg = "super constructor must be called before any other methods"
|
|
||||||
with self.assertRaisesRegexp(AttributeError, msg):
|
|
||||||
CallsMethodBeforeSuperConstructorModule(allowed_method=False)
|
|
||||||
|
|
||||||
def test_annotated_method_is_allowed(self):
|
|
||||||
self.assertIsNotNone(
|
|
||||||
CallsMethodBeforeSuperConstructorModule(allowed_method=True))
|
|
||||||
|
|
||||||
|
|
||||||
class ForwardMethodsTest(test.TestCase):
|
class ForwardMethodsTest(test.TestCase):
|
||||||
|
|
||||||
def testFunctionType(self):
|
def testFunctionType(self):
|
||||||
@ -307,6 +289,7 @@ class RecursiveModule(module.Module):
|
|||||||
|
|
||||||
def __init__(self, depth, trainable=True):
|
def __init__(self, depth, trainable=True):
|
||||||
super(RecursiveModule, self).__init__(name="badger")
|
super(RecursiveModule, self).__init__(name="badger")
|
||||||
|
with self.name_scope:
|
||||||
self.child = None
|
self.child = None
|
||||||
if depth > 1:
|
if depth > 1:
|
||||||
self.child = RecursiveModule(depth - 1, trainable=trainable)
|
self.child = RecursiveModule(depth - 1, trainable=trainable)
|
||||||
@ -323,6 +306,7 @@ class AbstractModule(module.Module):
|
|||||||
|
|
||||||
class ConcreteModule(AbstractModule):
|
class ConcreteModule(AbstractModule):
|
||||||
|
|
||||||
|
@module.Module.with_name_scope
|
||||||
def __call__(self, x):
|
def __call__(self, x):
|
||||||
return x ** 2, get_name_scope()
|
return x ** 2, get_name_scope()
|
||||||
|
|
||||||
@ -333,6 +317,7 @@ class TreeModule(module.Module):
|
|||||||
super(TreeModule, self).__init__(name=name)
|
super(TreeModule, self).__init__(name=name)
|
||||||
self._leaves = []
|
self._leaves = []
|
||||||
|
|
||||||
|
@module.Module.with_name_scope
|
||||||
def new_leaf(self, name=None):
|
def new_leaf(self, name=None):
|
||||||
leaf = TreeModule(name=name)
|
leaf = TreeModule(name=name)
|
||||||
self._leaves.append(leaf)
|
self._leaves.append(leaf)
|
||||||
@ -341,15 +326,18 @@ class TreeModule(module.Module):
|
|||||||
|
|
||||||
class ReturnsNameScopeModule(module.Module):
|
class ReturnsNameScopeModule(module.Module):
|
||||||
|
|
||||||
|
@module.Module.with_name_scope
|
||||||
def alternative_forward(self):
|
def alternative_forward(self):
|
||||||
return get_name_scope()
|
return get_name_scope()
|
||||||
|
|
||||||
|
@module.Module.with_name_scope
|
||||||
def __call__(self):
|
def __call__(self):
|
||||||
return get_name_scope()
|
return get_name_scope()
|
||||||
|
|
||||||
|
|
||||||
class SubclassedReturnsNameScopeModule(ReturnsNameScopeModule):
|
class SubclassedReturnsNameScopeModule(ReturnsNameScopeModule):
|
||||||
|
|
||||||
|
@module.Module.with_name_scope
|
||||||
def alternative_alternative_forward(self):
|
def alternative_alternative_forward(self):
|
||||||
return get_name_scope()
|
return get_name_scope()
|
||||||
|
|
||||||
@ -368,37 +356,15 @@ class ModuleOverridingNameScope(ReturnsNameScopeModule):
|
|||||||
return ops.name_scope("yolo/")
|
return ops.name_scope("yolo/")
|
||||||
|
|
||||||
|
|
||||||
class DoesNotCallSuperConstructorModule(module.Module):
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
# NOTE: Intentionally does not call super constructor.
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class CallsMethodBeforeSuperConstructorModule(module.Module):
|
|
||||||
|
|
||||||
def __init__(self, allowed_method):
|
|
||||||
if allowed_method:
|
|
||||||
self.no_name_scope()
|
|
||||||
else:
|
|
||||||
self.with_name_scope()
|
|
||||||
super(CallsMethodBeforeSuperConstructorModule, self).__init__()
|
|
||||||
|
|
||||||
@module.Module.no_name_scope
|
|
||||||
def no_name_scope(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
def with_name_scope(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class ModuleWithFunctionAnnotatedCall(module.Module):
|
class ModuleWithFunctionAnnotatedCall(module.Module):
|
||||||
|
|
||||||
@def_function.function(autograph=False)
|
@def_function.function(autograph=False)
|
||||||
|
@module.Module.with_name_scope
|
||||||
def forward(self):
|
def forward(self):
|
||||||
return get_name_scope()
|
return get_name_scope()
|
||||||
|
|
||||||
@def_function.function(autograph=True)
|
@def_function.function(autograph=True)
|
||||||
|
@module.Module.with_name_scope
|
||||||
def forward_ag(self):
|
def forward_ag(self):
|
||||||
return get_name_scope()
|
return get_name_scope()
|
||||||
|
|
||||||
@ -410,22 +376,22 @@ class PropertyModule(module.Module):
|
|||||||
self._setter_scope_name = None
|
self._setter_scope_name = None
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
@module.Module.with_name_scope
|
||||||
def some_property(self):
|
def some_property(self):
|
||||||
getter_scope_name = get_name_scope()
|
getter_scope_name = get_name_scope()
|
||||||
return getter_scope_name, self._setter_scope_name
|
return getter_scope_name, self._setter_scope_name
|
||||||
|
|
||||||
@some_property.setter
|
@some_property.setter
|
||||||
|
@module.Module.with_name_scope
|
||||||
def some_property(self, my_property):
|
def some_property(self, my_property):
|
||||||
self._setter_scope_name = get_name_scope()
|
self._setter_scope_name = get_name_scope()
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@module.Module.no_name_scope
|
|
||||||
def no_name_scope_property(self):
|
def no_name_scope_property(self):
|
||||||
getter_scope_name = get_name_scope()
|
getter_scope_name = get_name_scope()
|
||||||
return getter_scope_name, self._setter_scope_name
|
return getter_scope_name, self._setter_scope_name
|
||||||
|
|
||||||
@no_name_scope_property.setter
|
@no_name_scope_property.setter
|
||||||
@module.Module.no_name_scope
|
|
||||||
def no_name_scope_property(self, my_property):
|
def no_name_scope_property(self, my_property):
|
||||||
self._setter_scope_name = get_name_scope()
|
self._setter_scope_name = get_name_scope()
|
||||||
|
|
||||||
@ -514,43 +480,6 @@ class SimpleModule(module.Module):
|
|||||||
IS_MEMBER = lambda v: isinstance(v, MemberType)
|
IS_MEMBER = lambda v: isinstance(v, MemberType)
|
||||||
IS_MODULE = lambda v: isinstance(v, module.Module)
|
IS_MODULE = lambda v: isinstance(v, module.Module)
|
||||||
|
|
||||||
|
|
||||||
class CustomMetaclass(type):
|
|
||||||
|
|
||||||
TAG = "__custom_metaclass__"
|
|
||||||
|
|
||||||
def __new__(mcs, name, bases, clsdict):
|
|
||||||
new_type = super(CustomMetaclass, mcs).__new__(mcs, name, bases, clsdict)
|
|
||||||
setattr(new_type, CustomMetaclass.TAG, True)
|
|
||||||
return new_type
|
|
||||||
|
|
||||||
|
|
||||||
class CombiningMetaclass(module.ModuleMetaclass, CustomMetaclass):
|
|
||||||
|
|
||||||
TAG = "__combining_metaclass__"
|
|
||||||
|
|
||||||
def __new__(mcs, name, bases, clsdict):
|
|
||||||
new_type = super(CombiningMetaclass, mcs).__new__(mcs, name, bases, clsdict)
|
|
||||||
setattr(new_type, CombiningMetaclass.TAG, True)
|
|
||||||
return new_type
|
|
||||||
|
|
||||||
|
|
||||||
@six.add_metaclass(CombiningMetaclass)
|
|
||||||
class ModuleWithCustomMetaclass(module.Module):
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
super(ModuleWithCustomMetaclass, self).__init__()
|
|
||||||
self.init_name_scope = get_name_scope()
|
|
||||||
|
|
||||||
|
|
||||||
class CustomMetaclassTest(test.TestCase):
|
|
||||||
|
|
||||||
def testSupportsCustomMetaclass(self):
|
|
||||||
m = ModuleWithCustomMetaclass()
|
|
||||||
self.assertEqual(m.init_name_scope, "module_with_custom_metaclass/")
|
|
||||||
self.assertTrue(getattr(ModuleWithCustomMetaclass, CombiningMetaclass.TAG))
|
|
||||||
self.assertTrue(getattr(ModuleWithCustomMetaclass, CustomMetaclass.TAG))
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
v2_compat.enable_v2_behavior()
|
v2_compat.enable_v2_behavior()
|
||||||
test.main()
|
test.main()
|
||||||
|
@ -29,7 +29,7 @@ tf_class {
|
|||||||
argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||||
}
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "no_name_scope"
|
name: "with_name_scope"
|
||||||
argspec: "args=[\'cls\', \'method\'], varargs=None, keywords=None, defaults=None"
|
argspec: "args=[\'cls\', \'method\'], varargs=None, keywords=None, defaults=None"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -142,7 +142,7 @@ tf_module {
|
|||||||
}
|
}
|
||||||
member {
|
member {
|
||||||
name: "Module"
|
name: "Module"
|
||||||
mtype: "<class \'tensorflow.python.module.module.ModuleMetaclass\'>"
|
mtype: "<type \'type\'>"
|
||||||
}
|
}
|
||||||
member {
|
member {
|
||||||
name: "NameAttrList"
|
name: "NameAttrList"
|
||||||
|
@ -29,7 +29,7 @@ tf_class {
|
|||||||
argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||||
}
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "no_name_scope"
|
name: "with_name_scope"
|
||||||
argspec: "args=[\'cls\', \'method\'], varargs=None, keywords=None, defaults=None"
|
argspec: "args=[\'cls\', \'method\'], varargs=None, keywords=None, defaults=None"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -26,7 +26,7 @@ tf_module {
|
|||||||
}
|
}
|
||||||
member {
|
member {
|
||||||
name: "Module"
|
name: "Module"
|
||||||
mtype: "<class \'tensorflow.python.module.module.ModuleMetaclass\'>"
|
mtype: "<type \'type\'>"
|
||||||
}
|
}
|
||||||
member {
|
member {
|
||||||
name: "Operation"
|
name: "Operation"
|
||||||
|
Loading…
Reference in New Issue
Block a user