Reduce Layer.__call__ overhead by ~3%
Reduces method call invocation overhead: these methods are used regularly, so reducing Python method call overhead is meaningful here. Improvements: - Faster get_default_graph (~3x faster) - Check `if self.stack` rather than `if len(self.stack) >= 1` - Replace expensive calls to `super` and `_GetGlobalDefaultGraph()` with explicit logic. - Faster name_scope (~15% faster) - Remove redundant `name is None` check. - One str concat operation instead of two (for nested scopes) - Move enter_eager_name_scope logic directly to __enter__ - Use `name[-1] == '/'` instead of `name.endswith('/')`, faster for 1 char - Use `ctx.scope_name = old_name` rather than more expensive `setattr(ctx, 'scope_name', old_name)` PiperOrigin-RevId: 313464894 Change-Id: I9721b8af66f1d08c6abf5ebd92d14f15c9a9e8cc
This commit is contained in:
parent
8dd818b0f6
commit
4eae0941b7
@ -5353,7 +5353,7 @@ class _DefaultStack(threading.local):
|
||||
self.stack = []
|
||||
|
||||
def get_default(self):
|
||||
return self.stack[-1] if len(self.stack) >= 1 else None
|
||||
return self.stack[-1] if self.stack else None
|
||||
|
||||
def reset(self):
|
||||
self.stack = []
|
||||
@ -5541,10 +5541,13 @@ class _DefaultGraphStack(_DefaultStack): # pylint: disable=protected-access
|
||||
|
||||
def get_default(self):
|
||||
"""Override that returns a global default if the stack is empty."""
|
||||
ret = super(_DefaultGraphStack, self).get_default()
|
||||
if ret is None:
|
||||
ret = self._GetGlobalDefaultGraph()
|
||||
return ret
|
||||
if self.stack:
|
||||
return self.stack[-1]
|
||||
elif self._global_default_graph:
|
||||
return self._global_default_graph
|
||||
else:
|
||||
self._global_default_graph = Graph()
|
||||
return self._global_default_graph
|
||||
|
||||
def _GetGlobalDefaultGraph(self):
|
||||
if self._global_default_graph is None:
|
||||
@ -6535,24 +6538,6 @@ class name_scope_v1(object): # pylint: disable=invalid-name
|
||||
return self._name_scope.__exit__(*exc_info)
|
||||
|
||||
|
||||
def enter_eager_name_scope(ctx, name):
|
||||
"""Updates the eager context to enter the given name scope."""
|
||||
old_name = ctx.scope_name
|
||||
if not name:
|
||||
scope_name = ""
|
||||
else:
|
||||
if name.endswith("/"):
|
||||
# A trailing slash breaks out of nested name scopes, indicating a
|
||||
# fully specified scope name, for compatibility with Graph.name_scope.
|
||||
scope_name = name
|
||||
else:
|
||||
scope_name = name + "/"
|
||||
if old_name:
|
||||
scope_name = old_name + scope_name
|
||||
ctx.scope_name = scope_name
|
||||
return scope_name, old_name
|
||||
|
||||
|
||||
@tf_export("name_scope", v1=[])
|
||||
class name_scope_v2(object):
|
||||
"""A context manager for use when defining a Python op.
|
||||
@ -6575,9 +6560,9 @@ class name_scope_v2(object):
|
||||
When executed, the Tensors `a`, `b`, `c`, will have names `MyOp/a`, `MyOp/b`,
|
||||
and `MyOp/c`.
|
||||
|
||||
If the scope name already exists, the name will be made unique by appending
|
||||
`_n`. For example, calling `my_op` the second time will generate `MyOp_1/a`,
|
||||
etc.
|
||||
Inside a `tf.function`, if the scope name already exists, the name will be
|
||||
made unique by appending `_n`. For example, calling `my_op` the second time
|
||||
will generate `MyOp_1/a`, etc.
|
||||
"""
|
||||
|
||||
def __init__(self, name):
|
||||
@ -6587,9 +6572,9 @@ class name_scope_v2(object):
|
||||
name: The prefix to use on all names created within the name scope.
|
||||
|
||||
Raises:
|
||||
ValueError: If name is None, or not a string.
|
||||
ValueError: If name is not a string.
|
||||
"""
|
||||
if name is None or not isinstance(name, six.string_types):
|
||||
if not isinstance(name, six.string_types):
|
||||
raise ValueError("name for name_scope must be a string.")
|
||||
self._name = name
|
||||
self._exit_fns = []
|
||||
@ -6603,16 +6588,29 @@ class name_scope_v2(object):
|
||||
|
||||
Returns:
|
||||
The scope name.
|
||||
|
||||
Raises:
|
||||
ValueError: if neither `name` nor `default_name` is provided
|
||||
but `values` are.
|
||||
"""
|
||||
ctx = context.context()
|
||||
if ctx.executing_eagerly():
|
||||
scope_name, old_scope_name = enter_eager_name_scope(ctx, self._name)
|
||||
self._exit_fns.append(
|
||||
lambda *a: setattr(ctx, "scope_name", old_scope_name))
|
||||
# Names are not auto-incremented in eager mode.
|
||||
# A trailing slash breaks out of nested name scopes, indicating a
|
||||
# fully specified scope name, for compatibility with Graph.name_scope.
|
||||
# This also prevents auto-incrementing.
|
||||
old_name = ctx.scope_name
|
||||
name = self._name
|
||||
if not name:
|
||||
scope_name = ""
|
||||
elif name[-1] == "/":
|
||||
scope_name = name
|
||||
elif old_name:
|
||||
scope_name = old_name + name + "/"
|
||||
else:
|
||||
scope_name = name + "/"
|
||||
ctx.scope_name = scope_name
|
||||
|
||||
def _restore_name_scope(*_):
|
||||
ctx.scope_name = old_name
|
||||
|
||||
self._exit_fns.append(_restore_name_scope)
|
||||
else:
|
||||
scope = get_default_graph().name_scope(self._name)
|
||||
scope_name = scope.__enter__()
|
||||
@ -6620,8 +6618,7 @@ class name_scope_v2(object):
|
||||
return scope_name
|
||||
|
||||
def __exit__(self, type_arg, value_arg, traceback_arg):
|
||||
exit_fn = self._exit_fns.pop()
|
||||
exit_fn(type_arg, value_arg, traceback_arg)
|
||||
self._exit_fns.pop()(type_arg, value_arg, traceback_arg)
|
||||
return False # False values do not suppress exceptions
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user