ops.name_scope now dispatches no name_scope_{v1,v2} or Graph.name_scope
ops.name_scope is a context manager which is used internally and exported as tf.name_scope in V1. Internal usages mostly come from ops which enter a new name scope prior to computing anything. This change generalizes ops.name_scope to dispatch to either * Graph.name_scope if any of the values= is a graph tensor; * name_scope_v1 (old ops.name_scope) in graph mode and * name_scope_v2 in eager mode. This allows to streamline name_scope_v1 implementation and make it slightly faster as a result. Of course, a more substantial speedup would come from making name_scope a C++ class, but this isn't easily done atm. Microbenchmarks: | V | Code | Before | After | |---+--------------------------------------+--------+--------| | 1 | ops.name_scope("foo") | 679ns | 614ns | | 1 | ops.nane_scope("foo", values=[t, t]) | 808ns | 674ns | | 2 | ops.name_scope("foo") | 690ns | 691ns | | 2 | ops.name_scope("foo", values=[t, t]) | 1.47?s | 1.09?s | where t is either a graph (in V1) or an eager (in V2) tensor. PiperOrigin-RevId: 266641251
This commit is contained in:
parent
47ff0028a1
commit
a340629472
@ -6230,11 +6230,121 @@ def get_all_collection_keys():
|
|||||||
return get_default_graph().get_all_collection_keys()
|
return get_default_graph().get_all_collection_keys()
|
||||||
|
|
||||||
|
|
||||||
|
def name_scope(name, default_name=None, values=None):
|
||||||
|
"""Internal-only entry point for `name_scope*`.
|
||||||
|
|
||||||
|
Internal ops do not use the public API and instead rely on
|
||||||
|
`ops.name_scope` regardless of the execution mode. This function
|
||||||
|
dispatches to the correct `name_scope*` implementation based on
|
||||||
|
the arguments provided and the current mode. Specifically,
|
||||||
|
|
||||||
|
* if `values` contains a graph tensor `Graph.name_scope` is used;
|
||||||
|
* `name_scope_v1` is used in graph mode;
|
||||||
|
* `name_scope_v2` -- in eager mode.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: The name argument that is passed to the op function.
|
||||||
|
default_name: The default name to use if the `name` argument is `None`.
|
||||||
|
values: The list of `Tensor` arguments that are passed to the op function.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
`name_scope*` context manager.
|
||||||
|
"""
|
||||||
|
ctx = context.context()
|
||||||
|
in_eager_mode = ctx.executing_eagerly()
|
||||||
|
if not in_eager_mode:
|
||||||
|
return internal_name_scope_v1(name, default_name, values)
|
||||||
|
|
||||||
|
name = default_name if name is None else name
|
||||||
|
if values:
|
||||||
|
# The presence of a graph tensor in `values` overrides the context.
|
||||||
|
# TODO(slebedev): this is Keras-specific and should be removed.
|
||||||
|
# pylint: disable=unidiomatic-typecheck
|
||||||
|
graph_value = next((value for value in values if type(value) == Tensor),
|
||||||
|
None)
|
||||||
|
# pylint: enable=unidiomatic-typecheck
|
||||||
|
if graph_value is not None:
|
||||||
|
return graph_value.graph.name_scope(name)
|
||||||
|
return name_scope_v2(name or "")
|
||||||
|
|
||||||
|
|
||||||
|
class internal_name_scope_v1(object): # pylint: disable=invalid-name
|
||||||
|
"""Graph-only version of `name_scope_v1`."""
|
||||||
|
|
||||||
|
@property
|
||||||
|
def name(self):
|
||||||
|
return self._name
|
||||||
|
|
||||||
|
def __init__(self, name, default_name=None, values=None):
|
||||||
|
"""Initialize the context manager.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: The name argument that is passed to the op function.
|
||||||
|
default_name: The default name to use if the `name` argument is `None`.
|
||||||
|
values: The list of `Tensor` arguments that are passed to the op function.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
TypeError: if `default_name` is passed in but not a string.
|
||||||
|
"""
|
||||||
|
if not (default_name is None or isinstance(default_name, six.string_types)):
|
||||||
|
raise TypeError(
|
||||||
|
"`default_name` type (%s) is not a string type. You likely meant to "
|
||||||
|
"pass this into the `values` kwarg." % type(default_name))
|
||||||
|
self._name = default_name if name is None else name
|
||||||
|
self._default_name = default_name
|
||||||
|
self._values = values
|
||||||
|
|
||||||
|
def __enter__(self):
|
||||||
|
"""Start the scope block.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The scope name.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: if neither `name` nor `default_name` is provided
|
||||||
|
but `values` are.
|
||||||
|
"""
|
||||||
|
if self._name is None and self._values is not None:
|
||||||
|
# We only raise an error if values is not None (provided) because
|
||||||
|
# currently tf.name_scope(None) (values=None then) is sometimes used as
|
||||||
|
# an idiom to reset to top scope.
|
||||||
|
raise ValueError(
|
||||||
|
"At least one of name (%s) and default_name (%s) must be provided."
|
||||||
|
% (self._name, self._default_name))
|
||||||
|
|
||||||
|
g = get_default_graph()
|
||||||
|
if self._values and not g.building_function:
|
||||||
|
# Specialize based on the knowledge that `_get_graph_from_inputs()`
|
||||||
|
# ignores `inputs` when building a function.
|
||||||
|
g_from_inputs = _get_graph_from_inputs(self._values)
|
||||||
|
if g_from_inputs is not g:
|
||||||
|
g = g_from_inputs
|
||||||
|
self._g_manager = g.as_default()
|
||||||
|
self._g_manager.__enter__()
|
||||||
|
else:
|
||||||
|
self._g_manager = None
|
||||||
|
else:
|
||||||
|
self._g_manager = None
|
||||||
|
|
||||||
|
try:
|
||||||
|
self._name_scope = g.name_scope(self._name)
|
||||||
|
return self._name_scope.__enter__()
|
||||||
|
except:
|
||||||
|
if self._g_manager is not None:
|
||||||
|
self._g_manager.__exit__(*sys.exc_info())
|
||||||
|
raise
|
||||||
|
|
||||||
|
def __exit__(self, *exc_info):
|
||||||
|
self._name_scope.__exit__(*exc_info)
|
||||||
|
if self._g_manager is not None:
|
||||||
|
self._g_manager.__exit__(*exc_info)
|
||||||
|
|
||||||
|
|
||||||
# Named like a function for backwards compatibility with the
|
# Named like a function for backwards compatibility with the
|
||||||
# @tf_contextlib.contextmanager version, which was switched to a class to avoid
|
# @tf_contextlib.contextmanager version, which was switched to a class to avoid
|
||||||
# some object creation overhead.
|
# some object creation overhead.
|
||||||
@tf_export(v1=["name_scope"])
|
@tf_export(v1=["name_scope"])
|
||||||
class name_scope(object): # pylint: disable=invalid-name
|
class name_scope_v1(object): # pylint: disable=invalid-name
|
||||||
"""A context manager for use when defining a Python op.
|
"""A context manager for use when defining a Python op.
|
||||||
|
|
||||||
This context manager validates that the given `values` are from the
|
This context manager validates that the given `values` are from the
|
||||||
@ -6271,80 +6381,14 @@ class name_scope(object): # pylint: disable=invalid-name
|
|||||||
Raises:
|
Raises:
|
||||||
TypeError: if `default_name` is passed in but not a string.
|
TypeError: if `default_name` is passed in but not a string.
|
||||||
"""
|
"""
|
||||||
if not (default_name is None or isinstance(default_name, six.string_types)):
|
self._name_scope = name_scope(name, default_name, values)
|
||||||
raise TypeError(
|
|
||||||
"`default_name` type (%s) is not a string type. You likely meant to "
|
|
||||||
"pass this into the `values` kwarg." % type(default_name))
|
|
||||||
self._name = default_name if name is None else name
|
self._name = default_name if name is None else name
|
||||||
self._default_name = default_name
|
|
||||||
self._values = values
|
|
||||||
self._ctx = context.context()
|
|
||||||
self._in_eager_mode = self._ctx.executing_eagerly()
|
|
||||||
self._has_symbolic_input_in_eager = False
|
|
||||||
if self._values and self._in_eager_mode:
|
|
||||||
# The presence of a graph tensor in `self._values` overrides the context.
|
|
||||||
for value in self._values:
|
|
||||||
if hasattr(value, "graph"):
|
|
||||||
self._has_symbolic_input_in_eager = True
|
|
||||||
self._name_scope = value.graph.name_scope(self._name)
|
|
||||||
|
|
||||||
def __enter__(self):
|
def __enter__(self):
|
||||||
"""Start the scope block.
|
return self._name_scope.__enter__()
|
||||||
|
|
||||||
Returns:
|
def __exit__(self, *exc_info):
|
||||||
The scope name.
|
return self._name_scope.__exit__(*exc_info)
|
||||||
|
|
||||||
Raises:
|
|
||||||
ValueError: if neither `name` nor `default_name` is provided
|
|
||||||
but `values` are.
|
|
||||||
"""
|
|
||||||
if self._has_symbolic_input_in_eager:
|
|
||||||
return self._name_scope.__enter__()
|
|
||||||
|
|
||||||
if self._in_eager_mode:
|
|
||||||
scope_name, self._old_name = enter_eager_name_scope(self._ctx, self._name)
|
|
||||||
return scope_name
|
|
||||||
else:
|
|
||||||
if self._name is None and self._values is not None:
|
|
||||||
# We only raise an error if values is not None (provided) because
|
|
||||||
# currently tf.name_scope(None) (values=None then) is sometimes used as
|
|
||||||
# an idiom to reset to top scope.
|
|
||||||
raise ValueError(
|
|
||||||
"At least one of name (%s) and default_name (%s) must be provided."
|
|
||||||
% (self._name, self._default_name))
|
|
||||||
|
|
||||||
g = get_default_graph()
|
|
||||||
if self._values and not g.building_function:
|
|
||||||
# Specialize based on the knowledge that `_get_graph_from_inputs()`
|
|
||||||
# ignores `inputs` when building a function.
|
|
||||||
g_from_inputs = _get_graph_from_inputs(self._values)
|
|
||||||
if g_from_inputs is not g:
|
|
||||||
g = g_from_inputs
|
|
||||||
self._g_manager = g.as_default()
|
|
||||||
self._g_manager.__enter__()
|
|
||||||
else:
|
|
||||||
self._g_manager = None
|
|
||||||
else:
|
|
||||||
self._g_manager = None
|
|
||||||
|
|
||||||
try:
|
|
||||||
self._name_scope = g.name_scope(self._name)
|
|
||||||
return self._name_scope.__enter__()
|
|
||||||
except:
|
|
||||||
if self._g_manager is not None:
|
|
||||||
self._g_manager.__exit__(*sys.exc_info())
|
|
||||||
raise
|
|
||||||
|
|
||||||
def __exit__(self, type_arg, value_arg, traceback_arg):
|
|
||||||
if self._has_symbolic_input_in_eager:
|
|
||||||
self._name_scope.__exit__(type_arg, value_arg, traceback_arg)
|
|
||||||
elif self._in_eager_mode:
|
|
||||||
self._ctx.scope_name = self._old_name
|
|
||||||
else:
|
|
||||||
self._name_scope.__exit__(type_arg, value_arg, traceback_arg)
|
|
||||||
if self._g_manager is not None:
|
|
||||||
self._g_manager.__exit__(type_arg, value_arg, traceback_arg)
|
|
||||||
return False # False values do not suppress exceptions
|
|
||||||
|
|
||||||
|
|
||||||
def enter_eager_name_scope(ctx, name):
|
def enter_eager_name_scope(ctx, name):
|
||||||
@ -6366,7 +6410,7 @@ def enter_eager_name_scope(ctx, name):
|
|||||||
|
|
||||||
|
|
||||||
@tf_export("name_scope", v1=[])
|
@tf_export("name_scope", v1=[])
|
||||||
class name_scope_v2(name_scope):
|
class name_scope_v2(object):
|
||||||
"""A context manager for use when defining a Python op.
|
"""A context manager for use when defining a Python op.
|
||||||
|
|
||||||
This context manager pushes a name scope, which will make the name of all
|
This context manager pushes a name scope, which will make the name of all
|
||||||
|
@ -98,4 +98,4 @@ keras_export("keras.initializers.TruncatedNormal", v1=[])(
|
|||||||
init_ops_v2.TruncatedNormal)
|
init_ops_v2.TruncatedNormal)
|
||||||
# pylint: enable=bad-continuation
|
# pylint: enable=bad-continuation
|
||||||
|
|
||||||
keras_export(v1=["keras.backend.name_scope"])(ops.name_scope)
|
keras_export(v1=["keras.backend.name_scope"])(ops.name_scope_v1)
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
path: "tensorflow.keras.backend.name_scope"
|
path: "tensorflow.keras.backend.name_scope"
|
||||||
tf_class {
|
tf_class {
|
||||||
is_instance: "<class \'tensorflow.python.framework.ops.name_scope\'>"
|
is_instance: "<class \'tensorflow.python.framework.ops.name_scope_v1\'>"
|
||||||
is_instance: "<type \'object\'>"
|
is_instance: "<type \'object\'>"
|
||||||
member {
|
member {
|
||||||
name: "name"
|
name: "name"
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
path: "tensorflow.name_scope"
|
path: "tensorflow.name_scope"
|
||||||
tf_class {
|
tf_class {
|
||||||
is_instance: "<class \'tensorflow.python.framework.ops.name_scope\'>"
|
is_instance: "<class \'tensorflow.python.framework.ops.name_scope_v1\'>"
|
||||||
is_instance: "<type \'object\'>"
|
is_instance: "<type \'object\'>"
|
||||||
member {
|
member {
|
||||||
name: "name"
|
name: "name"
|
||||||
|
@ -1,7 +1,6 @@
|
|||||||
path: "tensorflow.name_scope"
|
path: "tensorflow.name_scope"
|
||||||
tf_class {
|
tf_class {
|
||||||
is_instance: "<class \'tensorflow.python.framework.ops.name_scope_v2\'>"
|
is_instance: "<class \'tensorflow.python.framework.ops.name_scope_v2\'>"
|
||||||
is_instance: "<class \'tensorflow.python.framework.ops.name_scope\'>"
|
|
||||||
is_instance: "<type \'object\'>"
|
is_instance: "<type \'object\'>"
|
||||||
member {
|
member {
|
||||||
name: "name"
|
name: "name"
|
||||||
|
Loading…
x
Reference in New Issue
Block a user