Make tf.name_scope(..) reentrant in TF2.

PiperOrigin-RevId: 252989408
This commit is contained in:
Tom Hennigan 2019-06-13 02:17:14 -07:00 committed by TensorFlower Gardener
parent c02a9dab45
commit b11313feca
5 changed files with 139 additions and 45 deletions

View File

@ -6167,25 +6167,9 @@ class name_scope(object): # pylint: disable=invalid-name
return self._name_scope.__enter__() return self._name_scope.__enter__()
if self._in_eager_mode: if self._in_eager_mode:
self._old_name = self._ctx.scope_name scope_name, old_name = enter_eager_name_scope(self._ctx, self._name,
if not self._name: self._default_name)
scope_name = "" self._old_name = old_name
else:
cache_key = self._name, self._old_name, self._default_name
if cache_key in name_scope_cache:
self._ctx.scope_name = name_scope_cache[cache_key]
return self._ctx.scope_name
elif self._name[-1] == "/":
# A trailing slash breaks out of nested name scopes, indicating a
# fully specified scope name, for compatibility with Graph.name_scope.
scope_name = self._name
else:
name_with_trailing_slash = self._name + "/"
scope_name = (
self._old_name + name_with_trailing_slash
if self._old_name else name_with_trailing_slash)
name_scope_cache[cache_key] = scope_name
self._ctx.scope_name = scope_name
return scope_name return scope_name
else: else:
if self._name is None and self._values is not None: if self._name is None and self._values is not None:
@ -6218,6 +6202,29 @@ class name_scope(object): # pylint: disable=invalid-name
return False # False values do not suppress exceptions return False # False values do not suppress exceptions
def enter_eager_name_scope(ctx, name, default_name=None):
"""Updates the eager context to enter the given name scope."""
old_name = ctx.scope_name
if not name:
scope_name = ""
else:
if name[-1] == "/":
# A trailing slash breaks out of nested name scopes, indicating a
# fully specified scope name, for compatibility with Graph.name_scope.
scope_name = name
else:
# TODO(tomhennigan) Benchmark and consider removing the cache.
cache_key = name, old_name, default_name
scope_name = name_scope_cache.get(cache_key, None)
if scope_name is None:
scope_name = name + "/"
if old_name:
scope_name = old_name + scope_name
name_scope_cache[cache_key] = scope_name
ctx.scope_name = scope_name
return scope_name, old_name
@tf_export("name_scope", v1=[]) @tf_export("name_scope", v1=[])
class name_scope_v2(name_scope): class name_scope_v2(name_scope):
"""A context manager for use when defining a Python op. """A context manager for use when defining a Python op.
@ -6256,7 +6263,38 @@ class name_scope_v2(name_scope):
""" """
if name is None or not isinstance(name, six.string_types): if name is None or not isinstance(name, six.string_types):
raise ValueError("name for name_scope must be a string.") raise ValueError("name for name_scope must be a string.")
super(name_scope_v2, self).__init__(name=None, default_name=name) self._name = name
self._exit_fns = []
@property
def name(self):
return self._name
def __enter__(self):
"""Start the scope block.
Returns:
The scope name.
Raises:
ValueError: if neither `name` nor `default_name` is provided
but `values` are.
"""
ctx = context.context()
if ctx.executing_eagerly():
scope_name, old_scope_name = enter_eager_name_scope(ctx, self._name)
self._exit_fns.append(
lambda *a: setattr(ctx, "scope_name", old_scope_name))
else:
scope = get_default_graph().name_scope(self._name)
scope_name = scope.__enter__()
self._exit_fns.append(scope.__exit__)
return scope_name
def __exit__(self, type_arg, value_arg, traceback_arg):
exit_fn = self._exit_fns.pop()
exit_fn(type_arg, value_arg, traceback_arg)
return False # False values do not suppress exceptions
def strip_name_scope(name, export_scope): def strip_name_scope(name, export_scope):

View File

@ -2048,6 +2048,21 @@ class OpScopeTest(test_util.TensorFlowTestCase):
with ops.name_scope(None, "default2") as scope2: with ops.name_scope(None, "default2") as scope2:
self.assertEqual(scope2, "default/default2/") self.assertEqual(scope2, "default/default2/")
@test_util.run_in_graph_and_eager_modes
def testNameScopeV2IsReEntrant(self):
foo = ops.name_scope_v2("foo")
bar = ops.name_scope_v2("bar")
with foo as scope_name:
self.assertEqual("foo/", scope_name)
with foo as scope_name:
self.assertEqual("foo/foo/", scope_name)
with bar as scope_name:
self.assertEqual("foo/bar/", scope_name)
with foo as scope_name:
self.assertEqual("foo/bar/foo/", scope_name)
with bar as scope_name:
self.assertEqual("bar/", scope_name)
@test_util.run_deprecated_v1 @test_util.run_deprecated_v1
def testNoScopeName(self): def testNoScopeName(self):
g0 = ops.Graph() g0 = ops.Graph()

View File

@ -12,6 +12,7 @@ py_library(
srcs = ["module.py"], srcs = ["module.py"],
deps = [ deps = [
"//tensorflow/python:framework_ops", "//tensorflow/python:framework_ops",
"//tensorflow/python:tf2",
"//tensorflow/python:util", "//tensorflow/python:util",
"//tensorflow/python:variables", "//tensorflow/python:variables",
"//tensorflow/python/training/tracking", "//tensorflow/python/training/tracking",

View File

@ -20,6 +20,7 @@ from __future__ import print_function
import re import re
from tensorflow.python import tf2
from tensorflow.python.framework import ops from tensorflow.python.framework import ops
from tensorflow.python.ops import variables from tensorflow.python.ops import variables
from tensorflow.python.training.tracking import tracking from tensorflow.python.training.tracking import tracking
@ -113,6 +114,10 @@ class Module(tracking.AutoTrackable):
"identifiers (e.g. a valid class name)." % name) "identifiers (e.g. a valid class name)." % name)
self._name = name self._name = name
if tf2.enabled():
with ops.name_scope_v2(name) as scope_name:
self._name_scope = ops.name_scope_v2(scope_name)
else:
with ops.name_scope(name) as scope_name: with ops.name_scope(name) as scope_name:
self._scope_name = scope_name self._scope_name = scope_name
@ -128,7 +133,10 @@ class Module(tracking.AutoTrackable):
@property @property
def name_scope(self): def name_scope(self):
"""Returns a `tf.name_scope` instance for this class.""" """Returns a `tf.name_scope` instance for this class."""
# TODO(tomhennigan) Memoize once name scopes are re-entrant. if tf2.enabled():
return self._name_scope
else:
# In TF1 name_scope is not re-entrant in eager so we cannot memoize it.
return ops.name_scope(self._scope_name) return ops.name_scope(self._scope_name)
@property @property

View File

@ -25,10 +25,12 @@ import itertools
from absl.testing import parameterized from absl.testing import parameterized
import six import six
from tensorflow.python.compat import v2_compat from tensorflow.python import tf2
from tensorflow.python.distribute import values as distributed_values from tensorflow.python.distribute import values as distributed_values
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function from tensorflow.python.eager import def_function
from tensorflow.python.framework import ops from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.keras import layers from tensorflow.python.keras import layers
from tensorflow.python.keras import models from tensorflow.python.keras import models
from tensorflow.python.module import module from tensorflow.python.module import module
@ -36,7 +38,7 @@ from tensorflow.python.ops import variables
from tensorflow.python.platform import test from tensorflow.python.platform import test
class TestModuleNaming(test.TestCase): class TestModuleNaming(test_util.TensorFlowTestCase):
def test_single_name(self): def test_single_name(self):
mod = module.Module(name="simple") mod = module.Module(name="simple")
@ -110,7 +112,11 @@ class TestModuleNaming(test.TestCase):
with self.assertRaisesRegexp(ValueError, msg): with self.assertRaisesRegexp(ValueError, msg):
module.Module(name="$Foo") module.Module(name="$Foo")
@test_util.run_in_graph_and_eager_modes
def test_modules_not_numbered_in_eager(self): def test_modules_not_numbered_in_eager(self):
if not context.executing_eagerly():
self.skipTest("Eager specific")
mod = RecursiveModule(2) mod = RecursiveModule(2)
self.assertEqual(mod.name_scope.name, "badger/") self.assertEqual(mod.name_scope.name, "badger/")
self.assertEqual(mod.child.name_scope.name, "badger/badger/") self.assertEqual(mod.child.name_scope.name, "badger/badger/")
@ -119,8 +125,11 @@ class TestModuleNaming(test.TestCase):
self.assertEqual(mod.name_scope.name, "badger/") self.assertEqual(mod.name_scope.name, "badger/")
self.assertEqual(mod.child.name_scope.name, "badger/badger/") self.assertEqual(mod.child.name_scope.name, "badger/badger/")
@test_util.run_in_graph_and_eager_modes
def test_module_numbering_in_graph(self): def test_module_numbering_in_graph(self):
with ops.Graph().as_default(): if context.executing_eagerly():
self.skipTest("Graph specific")
mod = RecursiveModule(2) mod = RecursiveModule(2)
self.assertEqual(mod.name_scope.name, "badger/") self.assertEqual(mod.name_scope.name, "badger/")
self.assertEqual(mod.child.name_scope.name, "badger/badger/") self.assertEqual(mod.child.name_scope.name, "badger/badger/")
@ -183,7 +192,7 @@ class TestModuleNaming(test.TestCase):
self.assertIn(("does_not_exist", ""), scope_names) self.assertIn(("does_not_exist", ""), scope_names)
class VariableNamingTest(test.TestCase): class VariableNamingTest(test_util.TensorFlowTestCase):
def test_variable_names(self): def test_variable_names(self):
mod = RecursiveModule(3) mod = RecursiveModule(3)
@ -192,7 +201,30 @@ class VariableNamingTest(test.TestCase):
self.assertEqual(mod.child.child.w.name, "badger/badger/badger/mushroom:0") self.assertEqual(mod.child.child.w.name, "badger/badger/badger/mushroom:0")
class VariableTrackingTest(test.TestCase): class NameScopeTest(test_util.TensorFlowTestCase):
@test_util.run_deprecated_v1
def test_not_memoized_in_tf1(self):
if tf2.enabled():
self.skipTest("Requires TF1")
mod = module.Module(name="name")
name_scope_1 = mod.name_scope
name_scope_2 = mod.name_scope
self.assertIsNot(name_scope_1, name_scope_2)
self.assertEqual(name_scope_1.name, name_scope_2.name)
def test_memoized_in_tf2(self):
if not tf2.enabled():
self.skipTest("Requires TF2")
mod = module.Module(name="name")
name_scope_1 = mod.name_scope
name_scope_2 = mod.name_scope
self.assertIs(name_scope_1, name_scope_2)
class VariableTrackingTest(test_util.TensorFlowTestCase):
def test_variables(self): def test_variables(self):
m = RecursiveModule(3) m = RecursiveModule(3)
@ -234,7 +266,7 @@ class VariableTrackingTest(test.TestCase):
self.assertEqual(m.variables, (mirrored, tpu, aggregating)) self.assertEqual(m.variables, (mirrored, tpu, aggregating))
class ModuleTrackingTest(test.TestCase): class ModuleTrackingTest(test_util.TensorFlowTestCase):
def test_submodules(self): def test_submodules(self):
m = RecursiveModule(3) m = RecursiveModule(3)
@ -250,7 +282,7 @@ class ModuleTrackingTest(test.TestCase):
self.assertEqual(set(m.submodules), {leaf1, leaf2}) self.assertEqual(set(m.submodules), {leaf1, leaf2})
class ForwardMethodsTest(test.TestCase): class ForwardMethodsTest(test_util.TensorFlowTestCase):
def testFunctionType(self): def testFunctionType(self):
mod = ModuleWithFunctionAnnotatedCall() mod = ModuleWithFunctionAnnotatedCall()
@ -259,20 +291,20 @@ class ForwardMethodsTest(test.TestCase):
def testEntersNameScope_call(self): def testEntersNameScope_call(self):
mod = ModuleWithFunctionAnnotatedCall() mod = ModuleWithFunctionAnnotatedCall()
self.assertEqual(mod.forward().numpy(), self.assertEqual(self.evaluate(mod.forward()),
b"module_with_function_annotated_call/") b"module_with_function_annotated_call/")
self.assertEqual(mod.forward_ag().numpy(), self.assertEqual(self.evaluate(mod.forward_ag()),
b"module_with_function_annotated_call/") b"module_with_function_annotated_call/")
def testEntersNameScope_concreteFunction(self): def testEntersNameScope_concreteFunction(self):
mod = ModuleWithFunctionAnnotatedCall() mod = ModuleWithFunctionAnnotatedCall()
self.assertEqual(mod.forward.get_concrete_function()().numpy(), self.assertEqual(self.evaluate(mod.forward.get_concrete_function()()),
b"module_with_function_annotated_call/") b"module_with_function_annotated_call/")
self.assertEqual(mod.forward_ag.get_concrete_function()().numpy(), self.assertEqual(self.evaluate(mod.forward_ag.get_concrete_function()()),
b"module_with_function_annotated_call/") b"module_with_function_annotated_call/")
class AbcTest(test.TestCase): class AbcTest(test_util.TensorFlowTestCase):
def testAbstract(self): def testAbstract(self):
msg = "Can't instantiate .* abstract methods" msg = "Can't instantiate .* abstract methods"
@ -289,7 +321,8 @@ class AbcTest(test.TestCase):
def get_name_scope(): def get_name_scope():
with ops.name_scope("x") as ns: with ops.name_scope("x") as ns:
return ns[:-2] ns = "/".join(ns.split("/")[:-2])
return ns + "/" if ns else ""
class ErrorModuleError(Exception): class ErrorModuleError(Exception):
@ -422,7 +455,7 @@ NamedPair = collections.namedtuple("NamedPair", ("first", "second"))
mk_index_dict = lambda v: dict(enumerate(v)) mk_index_dict = lambda v: dict(enumerate(v))
class FlattenTest(parameterized.TestCase, test.TestCase): class FlattenTest(parameterized.TestCase, test_util.TensorFlowTestCase):
@parameterized.parameters(lambda v: NamedPair(*v), list, tuple, mk_index_dict) @parameterized.parameters(lambda v: NamedPair(*v), list, tuple, mk_index_dict)
def test_flatten(self, container_type): def test_flatten(self, container_type):
@ -548,5 +581,4 @@ class SimpleModule(module.Module):
is_member = lambda v: isinstance(v, MemberType) is_member = lambda v: isinstance(v, MemberType)
if __name__ == "__main__": if __name__ == "__main__":
v2_compat.enable_v2_behavior()
test.main() test.main()