Inherit parent name scope stack when building branches of control flow ops.

PiperOrigin-RevId: 301207416
Change-Id: I1911ee21f4754373424fe6230945af5edab181e2
This commit is contained in:
Saurabh Saxena 2020-03-16 11:58:00 -07:00 committed by TensorFlower Gardener
parent 32aeb9957e
commit e861b664e6
6 changed files with 82 additions and 12 deletions

View File

@ -1281,3 +1281,7 @@ def dismantle_func_graph(func_graph):
"""
func_graph.clear_captures()
ops.dismantle_graph(func_graph)
def override_func_graph_name_scope(func_graph, name_scope):
func_graph._name_stack = name_scope # pylint: disable=protected-access

View File

@ -632,7 +632,7 @@ class OpCallbacksTest(test_util.TensorFlowTestCase):
greater_op_outputs = instrument.graph_internal_ndarrays[_GREATER_OP]
self.assertEqual(len(greater_op_outputs), 1)
self.assertAllClose(greater_op_outputs[0], False)
pow_op_outputs = instrument.graph_internal_ndarrays[b"pow"]
pow_op_outputs = instrument.graph_internal_ndarrays[b"cond/pow"]
self.assertEqual(len(pow_op_outputs), 1)
self.assertAllClose(pow_op_outputs[0], -64.0)
@ -660,9 +660,9 @@ class OpCallbacksTest(test_util.TensorFlowTestCase):
# Check the graph internal ndarrays recorded at runtime.
read_variable_op_outputs = instrument.graph_internal_ndarrays[
_READ_VARIABLE_OP]
b"while/" + _READ_VARIABLE_OP]
self.assertAllClose(read_variable_op_outputs, [1.0, 2.0, 4.0, 8.0])
less_op_outputs = instrument.graph_internal_ndarrays[_LESS_OP]
less_op_outputs = instrument.graph_internal_ndarrays[b"while/" + _LESS_OP]
self.assertAllClose(less_op_outputs, [True, True, True, True, False])
# TODO(cais): The following isn't decorated with

View File

@ -260,6 +260,31 @@ class CondV2Test(test.TestCase):
self.assertRegexpMatches(
cond2_op.get_attr("else_branch").name, r"foo_cond_1_false_\d*")
def testInheritParentNameScope(self):
@def_function.function
def f():
with ops.name_scope("foo"):
def then_branch():
with ops.name_scope("then"):
actual_name_scope = ops.get_name_scope()
expected_name_scope = "foo/cond/then"
self.assertEqual(actual_name_scope, expected_name_scope)
return 0.
def else_branch():
with ops.name_scope("else"):
actual_name_scope = ops.get_name_scope()
expected_name_scope = "foo/cond/else"
self.assertEqual(actual_name_scope, expected_name_scope)
return 0.
return cond_v2.cond_v2(
constant_op.constant(True), then_branch, else_branch)
f()
@test_util.run_v1_only("b/120545219")
def testDefunInCond(self):
x = constant_op.constant(1.0, name="x")

View File

@ -811,7 +811,7 @@ class ControlFlowTest(test.TestCase, parameterized.TestCase):
with self.assertRaisesRegexp(
ValueError,
"Tensor true_branch:0 in true_fn is accessed from false_fn."):
"Tensor cond/true_branch:0 in true_fn is accessed from false_fn."):
f()
def testSwitchCaseAccessBranch1TensorInBranch4Raises(self):
@ -838,7 +838,8 @@ class ControlFlowTest(test.TestCase, parameterized.TestCase):
with self.assertRaisesRegexp(
ValueError,
"Tensor br1_identity:0 in branch 1 is accessed from branch 4."):
"Tensor switch_case/indexed_case/br1_identity:0 in branch 1 is "
"accessed from branch 4."):
f()
def testCondListOutput(self):

View File

@ -1175,6 +1175,34 @@ class WhileV2Test(test.TestCase, parameterized.TestCase):
Fn()
def testInheritParentNameScope(self):
@def_function.function
def F():
with ops.name_scope("foo"):
def Cond(unused_i):
with ops.name_scope("cond"):
actual_name_scope = ops.get_name_scope()
expected_name_scope = "foo/while/cond"
assert actual_name_scope == expected_name_scope, (
"%s does not match %s" %
(actual_name_scope, expected_name_scope))
return False
def Body(i):
with ops.name_scope("body"):
actual_name_scope = ops.get_name_scope()
expected_name_scope = "foo/while/body"
assert actual_name_scope == expected_name_scope, (
"%s does not match %s" %
(actual_name_scope, expected_name_scope))
return i
return while_v2.while_loop(Cond, Body, [0.])
F()
def ScalarShape():
return ops.convert_to_tensor([], dtype=dtypes.int32)

View File

@ -18,28 +18,40 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.framework.func_graph import FuncGraph
from tensorflow.python.framework import func_graph
class CondBranchFuncGraph(FuncGraph):
class CondBranchFuncGraph(func_graph.FuncGraph):
"""FuncGraph for branches of tf.cond().
This is used to distinguish cond branches from other functions.
"""
pass
def __init__(self, *args, **kwargs):
super(CondBranchFuncGraph, self).__init__(*args, **kwargs)
func_graph.override_func_graph_name_scope(self,
self.outer_graph.get_name_scope())
class WhileCondFuncGraph(FuncGraph):
class WhileCondFuncGraph(func_graph.FuncGraph):
"""FuncGraph for the condition of tf.while_loop().
This is used to distinguish while conditions from other functions.
"""
pass
def __init__(self, *args, **kwargs):
super(WhileCondFuncGraph, self).__init__(*args, **kwargs)
func_graph.override_func_graph_name_scope(self,
self.outer_graph.get_name_scope())
class WhileBodyFuncGraph(FuncGraph):
class WhileBodyFuncGraph(func_graph.FuncGraph):
"""FuncGraph for the body of tf.while_loop().
This is used to distinguish while bodies from other functions.
"""
pass
def __init__(self, *args, **kwargs):
super(WhileBodyFuncGraph, self).__init__(*args, **kwargs)
func_graph.override_func_graph_name_scope(self,
self.outer_graph.get_name_scope())