From 4eae0941b70bab5a3d00ce8e077e1ffa32416e4e Mon Sep 17 00:00:00 2001 From: Thomas O'Malley Date: Wed, 27 May 2020 14:31:00 -0700 Subject: [PATCH] Reduce Layer.__call__ overhead by ~3% Reduces method call invocation overhead since these methods are regularly used and reducing Python method call overhead is meaningful here Improvements: - Faster get_default_graph (~3x faster) - Check `if self.stack` rather than `if len(self.stack) >= 1' - Replace expensive calls to `super` and `_GetGlobalDefaultGraph()` with explicit logic. - Faster name_scope (~15% faster) - Remove redundant `name is not None` check. - One str concat operation instead of two (for nested scopes) - Move enter_eager_name_scope logic directly to __enter__ - Use `name[-] == '/'` instead of `name.endswith('/')`, faster for 1 char - Use `ctx.scope_name = old_name` rather than more expensive `setattr(ctx, 'scope_name', old_name)` PiperOrigin-RevId: 313464894 Change-Id: I9721b8af66f1d08c6abf5ebd92d14f15c9a9e8cc --- tensorflow/python/framework/ops.py | 71 ++++++++++++++---------------- 1 file changed, 34 insertions(+), 37 deletions(-) diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py index 5b6dac5be34..b68d613e045 100644 --- a/tensorflow/python/framework/ops.py +++ b/tensorflow/python/framework/ops.py @@ -5353,7 +5353,7 @@ class _DefaultStack(threading.local): self.stack = [] def get_default(self): - return self.stack[-1] if len(self.stack) >= 1 else None + return self.stack[-1] if self.stack else None def reset(self): self.stack = [] @@ -5541,10 +5541,13 @@ class _DefaultGraphStack(_DefaultStack): # pylint: disable=protected-access def get_default(self): """Override that returns a global default if the stack is empty.""" - ret = super(_DefaultGraphStack, self).get_default() - if ret is None: - ret = self._GetGlobalDefaultGraph() - return ret + if self.stack: + return self.stack[-1] + elif self._global_default_graph: + return self._global_default_graph + else: + self._global_default_graph = Graph() + return self._global_default_graph def _GetGlobalDefaultGraph(self): if self._global_default_graph is None: @@ -6535,24 +6538,6 @@ class name_scope_v1(object): # pylint: disable=invalid-name return self._name_scope.__exit__(*exc_info) -def enter_eager_name_scope(ctx, name): - """Updates the eager context to enter the given name scope.""" - old_name = ctx.scope_name - if not name: - scope_name = "" - else: - if name.endswith("/"): - # A trailing slash breaks out of nested name scopes, indicating a - # fully specified scope name, for compatibility with Graph.name_scope. - scope_name = name - else: - scope_name = name + "/" - if old_name: - scope_name = old_name + scope_name - ctx.scope_name = scope_name - return scope_name, old_name - - @tf_export("name_scope", v1=[]) class name_scope_v2(object): """A context manager for use when defining a Python op. @@ -6575,9 +6560,9 @@ class name_scope_v2(object): When executed, the Tensors `a`, `b`, `c`, will have names `MyOp/a`, `MyOp/b`, and `MyOp/c`. - If the scope name already exists, the name will be made unique by appending - `_n`. For example, calling `my_op` the second time will generate `MyOp_1/a`, - etc. + Inside a `tf.function`, if the scope name already exists, the name will be + made unique by appending `_n`. For example, calling `my_op` the second time + will generate `MyOp_1/a`, etc. """ def __init__(self, name): @@ -6587,9 +6572,9 @@ class name_scope_v2(object): name: The prefix to use on all names created within the name scope. Raises: - ValueError: If name is None, or not a string. + ValueError: If name is not a string. """ - if name is None or not isinstance(name, six.string_types): + if not isinstance(name, six.string_types): raise ValueError("name for name_scope must be a string.") self._name = name self._exit_fns = [] @@ -6603,16 +6588,29 @@ class name_scope_v2(object): Returns: The scope name. - - Raises: - ValueError: if neither `name` nor `default_name` is provided - but `values` are. """ ctx = context.context() if ctx.executing_eagerly(): - scope_name, old_scope_name = enter_eager_name_scope(ctx, self._name) - self._exit_fns.append( - lambda *a: setattr(ctx, "scope_name", old_scope_name)) + # Names are not auto-incremented in eager mode. + # A trailing slash breaks out of nested name scopes, indicating a + # fully specified scope name, for compatibility with Graph.name_scope. + # This also prevents auto-incrementing. + old_name = ctx.scope_name + name = self._name + if not name: + scope_name = "" + elif name[-1] == "/": + scope_name = name + elif old_name: + scope_name = old_name + name + "/" + else: + scope_name = name + "/" + ctx.scope_name = scope_name + + def _restore_name_scope(*_): + ctx.scope_name = old_name + + self._exit_fns.append(_restore_name_scope) else: scope = get_default_graph().name_scope(self._name) scope_name = scope.__enter__() @@ -6620,8 +6618,7 @@ class name_scope_v2(object): return scope_name def __exit__(self, type_arg, value_arg, traceback_arg): - exit_fn = self._exit_fns.pop() - exit_fn(type_arg, value_arg, traceback_arg) + self._exit_fns.pop()(type_arg, value_arg, traceback_arg) return False # False values do not suppress exceptions