Inherit parent name scope stack when building branches of control flow ops.

PiperOrigin-RevId: 301207416
Change-Id: I1911ee21f4754373424fe6230945af5edab181e2
This commit is contained in:
Saurabh Saxena 2020-03-16 11:58:00 -07:00 committed by TensorFlower Gardener
parent 32aeb9957e
commit e861b664e6
6 changed files with 82 additions and 12 deletions

View File

@ -1281,3 +1281,7 @@ def dismantle_func_graph(func_graph):
""" """
func_graph.clear_captures() func_graph.clear_captures()
ops.dismantle_graph(func_graph) ops.dismantle_graph(func_graph)
def override_func_graph_name_scope(func_graph, name_scope):
func_graph._name_stack = name_scope # pylint: disable=protected-access

View File

@ -632,7 +632,7 @@ class OpCallbacksTest(test_util.TensorFlowTestCase):
greater_op_outputs = instrument.graph_internal_ndarrays[_GREATER_OP] greater_op_outputs = instrument.graph_internal_ndarrays[_GREATER_OP]
self.assertEqual(len(greater_op_outputs), 1) self.assertEqual(len(greater_op_outputs), 1)
self.assertAllClose(greater_op_outputs[0], False) self.assertAllClose(greater_op_outputs[0], False)
pow_op_outputs = instrument.graph_internal_ndarrays[b"pow"] pow_op_outputs = instrument.graph_internal_ndarrays[b"cond/pow"]
self.assertEqual(len(pow_op_outputs), 1) self.assertEqual(len(pow_op_outputs), 1)
self.assertAllClose(pow_op_outputs[0], -64.0) self.assertAllClose(pow_op_outputs[0], -64.0)
@ -660,9 +660,9 @@ class OpCallbacksTest(test_util.TensorFlowTestCase):
# Check the graph internal ndarrays recorded at runtime. # Check the graph internal ndarrays recorded at runtime.
read_variable_op_outputs = instrument.graph_internal_ndarrays[ read_variable_op_outputs = instrument.graph_internal_ndarrays[
_READ_VARIABLE_OP] b"while/" + _READ_VARIABLE_OP]
self.assertAllClose(read_variable_op_outputs, [1.0, 2.0, 4.0, 8.0]) self.assertAllClose(read_variable_op_outputs, [1.0, 2.0, 4.0, 8.0])
less_op_outputs = instrument.graph_internal_ndarrays[_LESS_OP] less_op_outputs = instrument.graph_internal_ndarrays[b"while/" + _LESS_OP]
self.assertAllClose(less_op_outputs, [True, True, True, True, False]) self.assertAllClose(less_op_outputs, [True, True, True, True, False])
# TODO(cais): The following isn't decorated with # TODO(cais): The following isn't decorated with

View File

@ -260,6 +260,31 @@ class CondV2Test(test.TestCase):
self.assertRegexpMatches( self.assertRegexpMatches(
cond2_op.get_attr("else_branch").name, r"foo_cond_1_false_\d*") cond2_op.get_attr("else_branch").name, r"foo_cond_1_false_\d*")
def testInheritParentNameScope(self):
@def_function.function
def f():
with ops.name_scope("foo"):
def then_branch():
with ops.name_scope("then"):
actual_name_scope = ops.get_name_scope()
expected_name_scope = "foo/cond/then"
self.assertEqual(actual_name_scope, expected_name_scope)
return 0.
def else_branch():
with ops.name_scope("else"):
actual_name_scope = ops.get_name_scope()
expected_name_scope = "foo/cond/else"
self.assertEqual(actual_name_scope, expected_name_scope)
return 0.
return cond_v2.cond_v2(
constant_op.constant(True), then_branch, else_branch)
f()
@test_util.run_v1_only("b/120545219") @test_util.run_v1_only("b/120545219")
def testDefunInCond(self): def testDefunInCond(self):
x = constant_op.constant(1.0, name="x") x = constant_op.constant(1.0, name="x")

View File

@ -811,7 +811,7 @@ class ControlFlowTest(test.TestCase, parameterized.TestCase):
with self.assertRaisesRegexp( with self.assertRaisesRegexp(
ValueError, ValueError,
"Tensor true_branch:0 in true_fn is accessed from false_fn."): "Tensor cond/true_branch:0 in true_fn is accessed from false_fn."):
f() f()
def testSwitchCaseAccessBranch1TensorInBranch4Raises(self): def testSwitchCaseAccessBranch1TensorInBranch4Raises(self):
@ -838,7 +838,8 @@ class ControlFlowTest(test.TestCase, parameterized.TestCase):
with self.assertRaisesRegexp( with self.assertRaisesRegexp(
ValueError, ValueError,
"Tensor br1_identity:0 in branch 1 is accessed from branch 4."): "Tensor switch_case/indexed_case/br1_identity:0 in branch 1 is "
"accessed from branch 4."):
f() f()
def testCondListOutput(self): def testCondListOutput(self):

View File

@ -1175,6 +1175,34 @@ class WhileV2Test(test.TestCase, parameterized.TestCase):
Fn() Fn()
def testInheritParentNameScope(self):
@def_function.function
def F():
with ops.name_scope("foo"):
def Cond(unused_i):
with ops.name_scope("cond"):
actual_name_scope = ops.get_name_scope()
expected_name_scope = "foo/while/cond"
assert actual_name_scope == expected_name_scope, (
"%s does not match %s" %
(actual_name_scope, expected_name_scope))
return False
def Body(i):
with ops.name_scope("body"):
actual_name_scope = ops.get_name_scope()
expected_name_scope = "foo/while/body"
assert actual_name_scope == expected_name_scope, (
"%s does not match %s" %
(actual_name_scope, expected_name_scope))
return i
return while_v2.while_loop(Cond, Body, [0.])
F()
def ScalarShape(): def ScalarShape():
return ops.convert_to_tensor([], dtype=dtypes.int32) return ops.convert_to_tensor([], dtype=dtypes.int32)

View File

@ -18,28 +18,40 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
from tensorflow.python.framework.func_graph import FuncGraph from tensorflow.python.framework import func_graph
class CondBranchFuncGraph(FuncGraph): class CondBranchFuncGraph(func_graph.FuncGraph):
"""FuncGraph for branches of tf.cond(). """FuncGraph for branches of tf.cond().
This is used to distinguish cond branches from other functions. This is used to distinguish cond branches from other functions.
""" """
pass
def __init__(self, *args, **kwargs):
super(CondBranchFuncGraph, self).__init__(*args, **kwargs)
func_graph.override_func_graph_name_scope(self,
self.outer_graph.get_name_scope())
class WhileCondFuncGraph(FuncGraph): class WhileCondFuncGraph(func_graph.FuncGraph):
"""FuncGraph for the condition of tf.while_loop(). """FuncGraph for the condition of tf.while_loop().
This is used to distinguish while conditions from other functions. This is used to distinguish while conditions from other functions.
""" """
pass
def __init__(self, *args, **kwargs):
super(WhileCondFuncGraph, self).__init__(*args, **kwargs)
func_graph.override_func_graph_name_scope(self,
self.outer_graph.get_name_scope())
class WhileBodyFuncGraph(FuncGraph): class WhileBodyFuncGraph(func_graph.FuncGraph):
"""FuncGraph for the body of tf.while_loop(). """FuncGraph for the body of tf.while_loop().
This is used to distinguish while bodies from other functions. This is used to distinguish while bodies from other functions.
""" """
pass
def __init__(self, *args, **kwargs):
super(WhileBodyFuncGraph, self).__init__(*args, **kwargs)
func_graph.override_func_graph_name_scope(self,
self.outer_graph.get_name_scope())