Make tf.name_scope(..) reentrant in TF2.
PiperOrigin-RevId: 252989408
This commit is contained in:
parent
c02a9dab45
commit
b11313feca
@ -6167,25 +6167,9 @@ class name_scope(object): # pylint: disable=invalid-name
|
|||||||
return self._name_scope.__enter__()
|
return self._name_scope.__enter__()
|
||||||
|
|
||||||
if self._in_eager_mode:
|
if self._in_eager_mode:
|
||||||
self._old_name = self._ctx.scope_name
|
scope_name, old_name = enter_eager_name_scope(self._ctx, self._name,
|
||||||
if not self._name:
|
self._default_name)
|
||||||
scope_name = ""
|
self._old_name = old_name
|
||||||
else:
|
|
||||||
cache_key = self._name, self._old_name, self._default_name
|
|
||||||
if cache_key in name_scope_cache:
|
|
||||||
self._ctx.scope_name = name_scope_cache[cache_key]
|
|
||||||
return self._ctx.scope_name
|
|
||||||
elif self._name[-1] == "/":
|
|
||||||
# A trailing slash breaks out of nested name scopes, indicating a
|
|
||||||
# fully specified scope name, for compatibility with Graph.name_scope.
|
|
||||||
scope_name = self._name
|
|
||||||
else:
|
|
||||||
name_with_trailing_slash = self._name + "/"
|
|
||||||
scope_name = (
|
|
||||||
self._old_name + name_with_trailing_slash
|
|
||||||
if self._old_name else name_with_trailing_slash)
|
|
||||||
name_scope_cache[cache_key] = scope_name
|
|
||||||
self._ctx.scope_name = scope_name
|
|
||||||
return scope_name
|
return scope_name
|
||||||
else:
|
else:
|
||||||
if self._name is None and self._values is not None:
|
if self._name is None and self._values is not None:
|
||||||
@ -6218,6 +6202,29 @@ class name_scope(object): # pylint: disable=invalid-name
|
|||||||
return False # False values do not suppress exceptions
|
return False # False values do not suppress exceptions
|
||||||
|
|
||||||
|
|
||||||
|
def enter_eager_name_scope(ctx, name, default_name=None):
|
||||||
|
"""Updates the eager context to enter the given name scope."""
|
||||||
|
old_name = ctx.scope_name
|
||||||
|
if not name:
|
||||||
|
scope_name = ""
|
||||||
|
else:
|
||||||
|
if name[-1] == "/":
|
||||||
|
# A trailing slash breaks out of nested name scopes, indicating a
|
||||||
|
# fully specified scope name, for compatibility with Graph.name_scope.
|
||||||
|
scope_name = name
|
||||||
|
else:
|
||||||
|
# TODO(tomhennigan) Benchmark and consider removing the cache.
|
||||||
|
cache_key = name, old_name, default_name
|
||||||
|
scope_name = name_scope_cache.get(cache_key, None)
|
||||||
|
if scope_name is None:
|
||||||
|
scope_name = name + "/"
|
||||||
|
if old_name:
|
||||||
|
scope_name = old_name + scope_name
|
||||||
|
name_scope_cache[cache_key] = scope_name
|
||||||
|
ctx.scope_name = scope_name
|
||||||
|
return scope_name, old_name
|
||||||
|
|
||||||
|
|
||||||
@tf_export("name_scope", v1=[])
|
@tf_export("name_scope", v1=[])
|
||||||
class name_scope_v2(name_scope):
|
class name_scope_v2(name_scope):
|
||||||
"""A context manager for use when defining a Python op.
|
"""A context manager for use when defining a Python op.
|
||||||
@ -6256,7 +6263,38 @@ class name_scope_v2(name_scope):
|
|||||||
"""
|
"""
|
||||||
if name is None or not isinstance(name, six.string_types):
|
if name is None or not isinstance(name, six.string_types):
|
||||||
raise ValueError("name for name_scope must be a string.")
|
raise ValueError("name for name_scope must be a string.")
|
||||||
super(name_scope_v2, self).__init__(name=None, default_name=name)
|
self._name = name
|
||||||
|
self._exit_fns = []
|
||||||
|
|
||||||
|
@property
|
||||||
|
def name(self):
|
||||||
|
return self._name
|
||||||
|
|
||||||
|
def __enter__(self):
|
||||||
|
"""Start the scope block.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The scope name.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: if neither `name` nor `default_name` is provided
|
||||||
|
but `values` are.
|
||||||
|
"""
|
||||||
|
ctx = context.context()
|
||||||
|
if ctx.executing_eagerly():
|
||||||
|
scope_name, old_scope_name = enter_eager_name_scope(ctx, self._name)
|
||||||
|
self._exit_fns.append(
|
||||||
|
lambda *a: setattr(ctx, "scope_name", old_scope_name))
|
||||||
|
else:
|
||||||
|
scope = get_default_graph().name_scope(self._name)
|
||||||
|
scope_name = scope.__enter__()
|
||||||
|
self._exit_fns.append(scope.__exit__)
|
||||||
|
return scope_name
|
||||||
|
|
||||||
|
def __exit__(self, type_arg, value_arg, traceback_arg):
|
||||||
|
exit_fn = self._exit_fns.pop()
|
||||||
|
exit_fn(type_arg, value_arg, traceback_arg)
|
||||||
|
return False # False values do not suppress exceptions
|
||||||
|
|
||||||
|
|
||||||
def strip_name_scope(name, export_scope):
|
def strip_name_scope(name, export_scope):
|
||||||
|
@ -2048,6 +2048,21 @@ class OpScopeTest(test_util.TensorFlowTestCase):
|
|||||||
with ops.name_scope(None, "default2") as scope2:
|
with ops.name_scope(None, "default2") as scope2:
|
||||||
self.assertEqual(scope2, "default/default2/")
|
self.assertEqual(scope2, "default/default2/")
|
||||||
|
|
||||||
|
@test_util.run_in_graph_and_eager_modes
|
||||||
|
def testNameScopeV2IsReEntrant(self):
|
||||||
|
foo = ops.name_scope_v2("foo")
|
||||||
|
bar = ops.name_scope_v2("bar")
|
||||||
|
with foo as scope_name:
|
||||||
|
self.assertEqual("foo/", scope_name)
|
||||||
|
with foo as scope_name:
|
||||||
|
self.assertEqual("foo/foo/", scope_name)
|
||||||
|
with bar as scope_name:
|
||||||
|
self.assertEqual("foo/bar/", scope_name)
|
||||||
|
with foo as scope_name:
|
||||||
|
self.assertEqual("foo/bar/foo/", scope_name)
|
||||||
|
with bar as scope_name:
|
||||||
|
self.assertEqual("bar/", scope_name)
|
||||||
|
|
||||||
@test_util.run_deprecated_v1
|
@test_util.run_deprecated_v1
|
||||||
def testNoScopeName(self):
|
def testNoScopeName(self):
|
||||||
g0 = ops.Graph()
|
g0 = ops.Graph()
|
||||||
|
@ -12,6 +12,7 @@ py_library(
|
|||||||
srcs = ["module.py"],
|
srcs = ["module.py"],
|
||||||
deps = [
|
deps = [
|
||||||
"//tensorflow/python:framework_ops",
|
"//tensorflow/python:framework_ops",
|
||||||
|
"//tensorflow/python:tf2",
|
||||||
"//tensorflow/python:util",
|
"//tensorflow/python:util",
|
||||||
"//tensorflow/python:variables",
|
"//tensorflow/python:variables",
|
||||||
"//tensorflow/python/training/tracking",
|
"//tensorflow/python/training/tracking",
|
||||||
|
@ -20,6 +20,7 @@ from __future__ import print_function
|
|||||||
|
|
||||||
import re
|
import re
|
||||||
|
|
||||||
|
from tensorflow.python import tf2
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
from tensorflow.python.ops import variables
|
from tensorflow.python.ops import variables
|
||||||
from tensorflow.python.training.tracking import tracking
|
from tensorflow.python.training.tracking import tracking
|
||||||
@ -113,8 +114,12 @@ class Module(tracking.AutoTrackable):
|
|||||||
"identifiers (e.g. a valid class name)." % name)
|
"identifiers (e.g. a valid class name)." % name)
|
||||||
|
|
||||||
self._name = name
|
self._name = name
|
||||||
with ops.name_scope(name) as scope_name:
|
if tf2.enabled():
|
||||||
self._scope_name = scope_name
|
with ops.name_scope_v2(name) as scope_name:
|
||||||
|
self._name_scope = ops.name_scope_v2(scope_name)
|
||||||
|
else:
|
||||||
|
with ops.name_scope(name) as scope_name:
|
||||||
|
self._scope_name = scope_name
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def name(self):
|
def name(self):
|
||||||
@ -128,8 +133,11 @@ class Module(tracking.AutoTrackable):
|
|||||||
@property
|
@property
|
||||||
def name_scope(self):
|
def name_scope(self):
|
||||||
"""Returns a `tf.name_scope` instance for this class."""
|
"""Returns a `tf.name_scope` instance for this class."""
|
||||||
# TODO(tomhennigan) Memoize once name scopes are re-entrant.
|
if tf2.enabled():
|
||||||
return ops.name_scope(self._scope_name)
|
return self._name_scope
|
||||||
|
else:
|
||||||
|
# In TF1 name_scope is not re-entrant in eager so we cannot memoize it.
|
||||||
|
return ops.name_scope(self._scope_name)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def variables(self):
|
def variables(self):
|
||||||
|
@ -25,10 +25,12 @@ import itertools
|
|||||||
from absl.testing import parameterized
|
from absl.testing import parameterized
|
||||||
import six
|
import six
|
||||||
|
|
||||||
from tensorflow.python.compat import v2_compat
|
from tensorflow.python import tf2
|
||||||
from tensorflow.python.distribute import values as distributed_values
|
from tensorflow.python.distribute import values as distributed_values
|
||||||
|
from tensorflow.python.eager import context
|
||||||
from tensorflow.python.eager import def_function
|
from tensorflow.python.eager import def_function
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
|
from tensorflow.python.framework import test_util
|
||||||
from tensorflow.python.keras import layers
|
from tensorflow.python.keras import layers
|
||||||
from tensorflow.python.keras import models
|
from tensorflow.python.keras import models
|
||||||
from tensorflow.python.module import module
|
from tensorflow.python.module import module
|
||||||
@ -36,7 +38,7 @@ from tensorflow.python.ops import variables
|
|||||||
from tensorflow.python.platform import test
|
from tensorflow.python.platform import test
|
||||||
|
|
||||||
|
|
||||||
class TestModuleNaming(test.TestCase):
|
class TestModuleNaming(test_util.TensorFlowTestCase):
|
||||||
|
|
||||||
def test_single_name(self):
|
def test_single_name(self):
|
||||||
mod = module.Module(name="simple")
|
mod = module.Module(name="simple")
|
||||||
@ -110,7 +112,11 @@ class TestModuleNaming(test.TestCase):
|
|||||||
with self.assertRaisesRegexp(ValueError, msg):
|
with self.assertRaisesRegexp(ValueError, msg):
|
||||||
module.Module(name="$Foo")
|
module.Module(name="$Foo")
|
||||||
|
|
||||||
|
@test_util.run_in_graph_and_eager_modes
|
||||||
def test_modules_not_numbered_in_eager(self):
|
def test_modules_not_numbered_in_eager(self):
|
||||||
|
if not context.executing_eagerly():
|
||||||
|
self.skipTest("Eager specific")
|
||||||
|
|
||||||
mod = RecursiveModule(2)
|
mod = RecursiveModule(2)
|
||||||
self.assertEqual(mod.name_scope.name, "badger/")
|
self.assertEqual(mod.name_scope.name, "badger/")
|
||||||
self.assertEqual(mod.child.name_scope.name, "badger/badger/")
|
self.assertEqual(mod.child.name_scope.name, "badger/badger/")
|
||||||
@ -119,15 +125,18 @@ class TestModuleNaming(test.TestCase):
|
|||||||
self.assertEqual(mod.name_scope.name, "badger/")
|
self.assertEqual(mod.name_scope.name, "badger/")
|
||||||
self.assertEqual(mod.child.name_scope.name, "badger/badger/")
|
self.assertEqual(mod.child.name_scope.name, "badger/badger/")
|
||||||
|
|
||||||
|
@test_util.run_in_graph_and_eager_modes
|
||||||
def test_module_numbering_in_graph(self):
|
def test_module_numbering_in_graph(self):
|
||||||
with ops.Graph().as_default():
|
if context.executing_eagerly():
|
||||||
mod = RecursiveModule(2)
|
self.skipTest("Graph specific")
|
||||||
self.assertEqual(mod.name_scope.name, "badger/")
|
|
||||||
self.assertEqual(mod.child.name_scope.name, "badger/badger/")
|
|
||||||
|
|
||||||
mod = RecursiveModule(2)
|
mod = RecursiveModule(2)
|
||||||
self.assertEqual(mod.name_scope.name, "badger_1/")
|
self.assertEqual(mod.name_scope.name, "badger/")
|
||||||
self.assertEqual(mod.child.name_scope.name, "badger_1/badger/")
|
self.assertEqual(mod.child.name_scope.name, "badger/badger/")
|
||||||
|
|
||||||
|
mod = RecursiveModule(2)
|
||||||
|
self.assertEqual(mod.name_scope.name, "badger_1/")
|
||||||
|
self.assertEqual(mod.child.name_scope.name, "badger_1/badger/")
|
||||||
|
|
||||||
def test_ctor_error_closes_name_scope(self):
|
def test_ctor_error_closes_name_scope(self):
|
||||||
with self.assertRaises(ErrorModuleError):
|
with self.assertRaises(ErrorModuleError):
|
||||||
@ -183,7 +192,7 @@ class TestModuleNaming(test.TestCase):
|
|||||||
self.assertIn(("does_not_exist", ""), scope_names)
|
self.assertIn(("does_not_exist", ""), scope_names)
|
||||||
|
|
||||||
|
|
||||||
class VariableNamingTest(test.TestCase):
|
class VariableNamingTest(test_util.TensorFlowTestCase):
|
||||||
|
|
||||||
def test_variable_names(self):
|
def test_variable_names(self):
|
||||||
mod = RecursiveModule(3)
|
mod = RecursiveModule(3)
|
||||||
@ -192,7 +201,30 @@ class VariableNamingTest(test.TestCase):
|
|||||||
self.assertEqual(mod.child.child.w.name, "badger/badger/badger/mushroom:0")
|
self.assertEqual(mod.child.child.w.name, "badger/badger/badger/mushroom:0")
|
||||||
|
|
||||||
|
|
||||||
class VariableTrackingTest(test.TestCase):
|
class NameScopeTest(test_util.TensorFlowTestCase):
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
|
def test_not_memoized_in_tf1(self):
|
||||||
|
if tf2.enabled():
|
||||||
|
self.skipTest("Requires TF1")
|
||||||
|
|
||||||
|
mod = module.Module(name="name")
|
||||||
|
name_scope_1 = mod.name_scope
|
||||||
|
name_scope_2 = mod.name_scope
|
||||||
|
self.assertIsNot(name_scope_1, name_scope_2)
|
||||||
|
self.assertEqual(name_scope_1.name, name_scope_2.name)
|
||||||
|
|
||||||
|
def test_memoized_in_tf2(self):
|
||||||
|
if not tf2.enabled():
|
||||||
|
self.skipTest("Requires TF2")
|
||||||
|
|
||||||
|
mod = module.Module(name="name")
|
||||||
|
name_scope_1 = mod.name_scope
|
||||||
|
name_scope_2 = mod.name_scope
|
||||||
|
self.assertIs(name_scope_1, name_scope_2)
|
||||||
|
|
||||||
|
|
||||||
|
class VariableTrackingTest(test_util.TensorFlowTestCase):
|
||||||
|
|
||||||
def test_variables(self):
|
def test_variables(self):
|
||||||
m = RecursiveModule(3)
|
m = RecursiveModule(3)
|
||||||
@ -234,7 +266,7 @@ class VariableTrackingTest(test.TestCase):
|
|||||||
self.assertEqual(m.variables, (mirrored, tpu, aggregating))
|
self.assertEqual(m.variables, (mirrored, tpu, aggregating))
|
||||||
|
|
||||||
|
|
||||||
class ModuleTrackingTest(test.TestCase):
|
class ModuleTrackingTest(test_util.TensorFlowTestCase):
|
||||||
|
|
||||||
def test_submodules(self):
|
def test_submodules(self):
|
||||||
m = RecursiveModule(3)
|
m = RecursiveModule(3)
|
||||||
@ -250,7 +282,7 @@ class ModuleTrackingTest(test.TestCase):
|
|||||||
self.assertEqual(set(m.submodules), {leaf1, leaf2})
|
self.assertEqual(set(m.submodules), {leaf1, leaf2})
|
||||||
|
|
||||||
|
|
||||||
class ForwardMethodsTest(test.TestCase):
|
class ForwardMethodsTest(test_util.TensorFlowTestCase):
|
||||||
|
|
||||||
def testFunctionType(self):
|
def testFunctionType(self):
|
||||||
mod = ModuleWithFunctionAnnotatedCall()
|
mod = ModuleWithFunctionAnnotatedCall()
|
||||||
@ -259,20 +291,20 @@ class ForwardMethodsTest(test.TestCase):
|
|||||||
|
|
||||||
def testEntersNameScope_call(self):
|
def testEntersNameScope_call(self):
|
||||||
mod = ModuleWithFunctionAnnotatedCall()
|
mod = ModuleWithFunctionAnnotatedCall()
|
||||||
self.assertEqual(mod.forward().numpy(),
|
self.assertEqual(self.evaluate(mod.forward()),
|
||||||
b"module_with_function_annotated_call/")
|
b"module_with_function_annotated_call/")
|
||||||
self.assertEqual(mod.forward_ag().numpy(),
|
self.assertEqual(self.evaluate(mod.forward_ag()),
|
||||||
b"module_with_function_annotated_call/")
|
b"module_with_function_annotated_call/")
|
||||||
|
|
||||||
def testEntersNameScope_concreteFunction(self):
|
def testEntersNameScope_concreteFunction(self):
|
||||||
mod = ModuleWithFunctionAnnotatedCall()
|
mod = ModuleWithFunctionAnnotatedCall()
|
||||||
self.assertEqual(mod.forward.get_concrete_function()().numpy(),
|
self.assertEqual(self.evaluate(mod.forward.get_concrete_function()()),
|
||||||
b"module_with_function_annotated_call/")
|
b"module_with_function_annotated_call/")
|
||||||
self.assertEqual(mod.forward_ag.get_concrete_function()().numpy(),
|
self.assertEqual(self.evaluate(mod.forward_ag.get_concrete_function()()),
|
||||||
b"module_with_function_annotated_call/")
|
b"module_with_function_annotated_call/")
|
||||||
|
|
||||||
|
|
||||||
class AbcTest(test.TestCase):
|
class AbcTest(test_util.TensorFlowTestCase):
|
||||||
|
|
||||||
def testAbstract(self):
|
def testAbstract(self):
|
||||||
msg = "Can't instantiate .* abstract methods"
|
msg = "Can't instantiate .* abstract methods"
|
||||||
@ -289,7 +321,8 @@ class AbcTest(test.TestCase):
|
|||||||
|
|
||||||
def get_name_scope():
|
def get_name_scope():
|
||||||
with ops.name_scope("x") as ns:
|
with ops.name_scope("x") as ns:
|
||||||
return ns[:-2]
|
ns = "/".join(ns.split("/")[:-2])
|
||||||
|
return ns + "/" if ns else ""
|
||||||
|
|
||||||
|
|
||||||
class ErrorModuleError(Exception):
|
class ErrorModuleError(Exception):
|
||||||
@ -422,7 +455,7 @@ NamedPair = collections.namedtuple("NamedPair", ("first", "second"))
|
|||||||
mk_index_dict = lambda v: dict(enumerate(v))
|
mk_index_dict = lambda v: dict(enumerate(v))
|
||||||
|
|
||||||
|
|
||||||
class FlattenTest(parameterized.TestCase, test.TestCase):
|
class FlattenTest(parameterized.TestCase, test_util.TensorFlowTestCase):
|
||||||
|
|
||||||
@parameterized.parameters(lambda v: NamedPair(*v), list, tuple, mk_index_dict)
|
@parameterized.parameters(lambda v: NamedPair(*v), list, tuple, mk_index_dict)
|
||||||
def test_flatten(self, container_type):
|
def test_flatten(self, container_type):
|
||||||
@ -548,5 +581,4 @@ class SimpleModule(module.Module):
|
|||||||
is_member = lambda v: isinstance(v, MemberType)
|
is_member = lambda v: isinstance(v, MemberType)
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
v2_compat.enable_v2_behavior()
|
|
||||||
test.main()
|
test.main()
|
||||||
|
Loading…
Reference in New Issue
Block a user