Add utility to get current name scope in tf.contrib.framework.
Change: 153409845
This commit is contained in:
parent
7448b27527
commit
949d457dc9
@ -53,6 +53,7 @@ See the @{$python/contrib.framework} guide.
|
|||||||
@@get_or_create_global_step
|
@@get_or_create_global_step
|
||||||
@@get_local_variables
|
@@get_local_variables
|
||||||
@@get_model_variables
|
@@get_model_variables
|
||||||
|
@@get_name_scope
|
||||||
@@get_trainable_variables
|
@@get_trainable_variables
|
||||||
@@get_unique_variable
|
@@get_unique_variable
|
||||||
@@get_variables_by_name
|
@@get_variables_by_name
|
||||||
|
@ -52,3 +52,21 @@ def get_graph_from_inputs(op_input_list, graph=None):
|
|||||||
"""
|
"""
|
||||||
# pylint: disable=protected-access
|
# pylint: disable=protected-access
|
||||||
return ops._get_graph_from_inputs(op_input_list, graph)
|
return ops._get_graph_from_inputs(op_input_list, graph)
|
||||||
|
|
||||||
|
|
||||||
|
def get_name_scope():
|
||||||
|
"""Returns the current name scope of the default graph.
|
||||||
|
|
||||||
|
For example:
|
||||||
|
|
||||||
|
```python
|
||||||
|
with tf.name_scope('scope1'):
|
||||||
|
with tf.name_scope('scope2'):
|
||||||
|
print(tf.get_name_scope())
|
||||||
|
```
|
||||||
|
would print the string `scope1/scope2`.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A string represnting the current name scope.
|
||||||
|
"""
|
||||||
|
return ops.get_default_graph().get_name_scope()
|
||||||
|
@ -57,6 +57,15 @@ class OpsTest(test.TestCase):
|
|||||||
with self.assertRaisesRegexp(ValueError, "not from the passed-in graph"):
|
with self.assertRaisesRegexp(ValueError, "not from the passed-in graph"):
|
||||||
ops_lib.get_graph_from_inputs(values, g1)
|
ops_lib.get_graph_from_inputs(values, g1)
|
||||||
|
|
||||||
|
def testGetNameScope(self):
|
||||||
|
with ops.name_scope("scope1"):
|
||||||
|
with ops.name_scope("scope2"):
|
||||||
|
with ops.name_scope("scope3"):
|
||||||
|
self.assertEqual("scope1/scope2/scope3", ops_lib.get_name_scope())
|
||||||
|
self.assertEqual("scope1/scope2", ops_lib.get_name_scope())
|
||||||
|
self.assertEqual("scope1", ops_lib.get_name_scope())
|
||||||
|
self.assertEqual("", ops_lib.get_name_scope())
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
test.main()
|
test.main()
|
||||||
|
@ -2907,6 +2907,23 @@ class Graph(object):
|
|||||||
self._names_in_use[name] = 1
|
self._names_in_use[name] = 1
|
||||||
return name
|
return name
|
||||||
|
|
||||||
|
def get_name_scope(self):
|
||||||
|
"""Returns the current name scope.
|
||||||
|
|
||||||
|
For example:
|
||||||
|
|
||||||
|
```python
|
||||||
|
with tf.name_scope('scope1'):
|
||||||
|
with tf.name_scope('scope2'):
|
||||||
|
print(tf.get_default_graph().get_name_scope())
|
||||||
|
```
|
||||||
|
would print the string `scope1/scope2`.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A string representing the current name scope.
|
||||||
|
"""
|
||||||
|
return self._name_stack
|
||||||
|
|
||||||
@contextlib.contextmanager
|
@contextlib.contextmanager
|
||||||
def colocate_with(self, op, ignore_existing=False):
|
def colocate_with(self, op, ignore_existing=False):
|
||||||
"""Returns a context manager that specifies an op to colocate with.
|
"""Returns a context manager that specifies an op to colocate with.
|
||||||
|
@ -1662,6 +1662,16 @@ class NameScopeTest(test_util.TensorFlowTestCase):
|
|||||||
self.assertEqual(es, striped)
|
self.assertEqual(es, striped)
|
||||||
self.assertEqual(ep, ops.prepend_name_scope(striped, name_scope_to_add))
|
self.assertEqual(ep, ops.prepend_name_scope(striped, name_scope_to_add))
|
||||||
|
|
||||||
|
def testGetNameScope(self):
|
||||||
|
with ops.Graph().as_default() as g:
|
||||||
|
with ops.name_scope("scope1"):
|
||||||
|
with ops.name_scope("scope2"):
|
||||||
|
with ops.name_scope("scope3"):
|
||||||
|
self.assertEqual("scope1/scope2/scope3", g.get_name_scope())
|
||||||
|
self.assertEqual("scope1/scope2", g.get_name_scope())
|
||||||
|
self.assertEqual("scope1", g.get_name_scope())
|
||||||
|
self.assertEqual("", g.get_name_scope())
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
googletest.main()
|
googletest.main()
|
||||||
|
@ -86,6 +86,10 @@ tf_class {
|
|||||||
name: "get_collection_ref"
|
name: "get_collection_ref"
|
||||||
argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=None"
|
argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=None"
|
||||||
}
|
}
|
||||||
|
member_method {
|
||||||
|
name: "get_name_scope"
|
||||||
|
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
|
||||||
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "get_operation_by_name"
|
name: "get_operation_by_name"
|
||||||
argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=None"
|
argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=None"
|
||||||
|
Loading…
Reference in New Issue
Block a user