Inherit parent name scope stack when building branches of control flow ops.
PiperOrigin-RevId: 301207416 Change-Id: I1911ee21f4754373424fe6230945af5edab181e2
This commit is contained in:
parent
32aeb9957e
commit
e861b664e6
@ -1281,3 +1281,7 @@ def dismantle_func_graph(func_graph):
|
|||||||
"""
|
"""
|
||||||
func_graph.clear_captures()
|
func_graph.clear_captures()
|
||||||
ops.dismantle_graph(func_graph)
|
ops.dismantle_graph(func_graph)
|
||||||
|
|
||||||
|
|
||||||
|
def override_func_graph_name_scope(func_graph, name_scope):
|
||||||
|
func_graph._name_stack = name_scope # pylint: disable=protected-access
|
||||||
|
@ -632,7 +632,7 @@ class OpCallbacksTest(test_util.TensorFlowTestCase):
|
|||||||
greater_op_outputs = instrument.graph_internal_ndarrays[_GREATER_OP]
|
greater_op_outputs = instrument.graph_internal_ndarrays[_GREATER_OP]
|
||||||
self.assertEqual(len(greater_op_outputs), 1)
|
self.assertEqual(len(greater_op_outputs), 1)
|
||||||
self.assertAllClose(greater_op_outputs[0], False)
|
self.assertAllClose(greater_op_outputs[0], False)
|
||||||
pow_op_outputs = instrument.graph_internal_ndarrays[b"pow"]
|
pow_op_outputs = instrument.graph_internal_ndarrays[b"cond/pow"]
|
||||||
self.assertEqual(len(pow_op_outputs), 1)
|
self.assertEqual(len(pow_op_outputs), 1)
|
||||||
self.assertAllClose(pow_op_outputs[0], -64.0)
|
self.assertAllClose(pow_op_outputs[0], -64.0)
|
||||||
|
|
||||||
@ -660,9 +660,9 @@ class OpCallbacksTest(test_util.TensorFlowTestCase):
|
|||||||
|
|
||||||
# Check the graph internal ndarrays recorded at runtime.
|
# Check the graph internal ndarrays recorded at runtime.
|
||||||
read_variable_op_outputs = instrument.graph_internal_ndarrays[
|
read_variable_op_outputs = instrument.graph_internal_ndarrays[
|
||||||
_READ_VARIABLE_OP]
|
b"while/" + _READ_VARIABLE_OP]
|
||||||
self.assertAllClose(read_variable_op_outputs, [1.0, 2.0, 4.0, 8.0])
|
self.assertAllClose(read_variable_op_outputs, [1.0, 2.0, 4.0, 8.0])
|
||||||
less_op_outputs = instrument.graph_internal_ndarrays[_LESS_OP]
|
less_op_outputs = instrument.graph_internal_ndarrays[b"while/" + _LESS_OP]
|
||||||
self.assertAllClose(less_op_outputs, [True, True, True, True, False])
|
self.assertAllClose(less_op_outputs, [True, True, True, True, False])
|
||||||
|
|
||||||
# TODO(cais): The following isn't decorated with
|
# TODO(cais): The following isn't decorated with
|
||||||
|
@ -260,6 +260,31 @@ class CondV2Test(test.TestCase):
|
|||||||
self.assertRegexpMatches(
|
self.assertRegexpMatches(
|
||||||
cond2_op.get_attr("else_branch").name, r"foo_cond_1_false_\d*")
|
cond2_op.get_attr("else_branch").name, r"foo_cond_1_false_\d*")
|
||||||
|
|
||||||
|
def testInheritParentNameScope(self):
|
||||||
|
|
||||||
|
@def_function.function
|
||||||
|
def f():
|
||||||
|
with ops.name_scope("foo"):
|
||||||
|
|
||||||
|
def then_branch():
|
||||||
|
with ops.name_scope("then"):
|
||||||
|
actual_name_scope = ops.get_name_scope()
|
||||||
|
expected_name_scope = "foo/cond/then"
|
||||||
|
self.assertEqual(actual_name_scope, expected_name_scope)
|
||||||
|
return 0.
|
||||||
|
|
||||||
|
def else_branch():
|
||||||
|
with ops.name_scope("else"):
|
||||||
|
actual_name_scope = ops.get_name_scope()
|
||||||
|
expected_name_scope = "foo/cond/else"
|
||||||
|
self.assertEqual(actual_name_scope, expected_name_scope)
|
||||||
|
return 0.
|
||||||
|
|
||||||
|
return cond_v2.cond_v2(
|
||||||
|
constant_op.constant(True), then_branch, else_branch)
|
||||||
|
|
||||||
|
f()
|
||||||
|
|
||||||
@test_util.run_v1_only("b/120545219")
|
@test_util.run_v1_only("b/120545219")
|
||||||
def testDefunInCond(self):
|
def testDefunInCond(self):
|
||||||
x = constant_op.constant(1.0, name="x")
|
x = constant_op.constant(1.0, name="x")
|
||||||
|
@ -811,7 +811,7 @@ class ControlFlowTest(test.TestCase, parameterized.TestCase):
|
|||||||
|
|
||||||
with self.assertRaisesRegexp(
|
with self.assertRaisesRegexp(
|
||||||
ValueError,
|
ValueError,
|
||||||
"Tensor true_branch:0 in true_fn is accessed from false_fn."):
|
"Tensor cond/true_branch:0 in true_fn is accessed from false_fn."):
|
||||||
f()
|
f()
|
||||||
|
|
||||||
def testSwitchCaseAccessBranch1TensorInBranch4Raises(self):
|
def testSwitchCaseAccessBranch1TensorInBranch4Raises(self):
|
||||||
@ -838,7 +838,8 @@ class ControlFlowTest(test.TestCase, parameterized.TestCase):
|
|||||||
|
|
||||||
with self.assertRaisesRegexp(
|
with self.assertRaisesRegexp(
|
||||||
ValueError,
|
ValueError,
|
||||||
"Tensor br1_identity:0 in branch 1 is accessed from branch 4."):
|
"Tensor switch_case/indexed_case/br1_identity:0 in branch 1 is "
|
||||||
|
"accessed from branch 4."):
|
||||||
f()
|
f()
|
||||||
|
|
||||||
def testCondListOutput(self):
|
def testCondListOutput(self):
|
||||||
|
@ -1175,6 +1175,34 @@ class WhileV2Test(test.TestCase, parameterized.TestCase):
|
|||||||
|
|
||||||
Fn()
|
Fn()
|
||||||
|
|
||||||
|
def testInheritParentNameScope(self):
|
||||||
|
|
||||||
|
@def_function.function
|
||||||
|
def F():
|
||||||
|
with ops.name_scope("foo"):
|
||||||
|
|
||||||
|
def Cond(unused_i):
|
||||||
|
with ops.name_scope("cond"):
|
||||||
|
actual_name_scope = ops.get_name_scope()
|
||||||
|
expected_name_scope = "foo/while/cond"
|
||||||
|
assert actual_name_scope == expected_name_scope, (
|
||||||
|
"%s does not match %s" %
|
||||||
|
(actual_name_scope, expected_name_scope))
|
||||||
|
return False
|
||||||
|
|
||||||
|
def Body(i):
|
||||||
|
with ops.name_scope("body"):
|
||||||
|
actual_name_scope = ops.get_name_scope()
|
||||||
|
expected_name_scope = "foo/while/body"
|
||||||
|
assert actual_name_scope == expected_name_scope, (
|
||||||
|
"%s does not match %s" %
|
||||||
|
(actual_name_scope, expected_name_scope))
|
||||||
|
return i
|
||||||
|
|
||||||
|
return while_v2.while_loop(Cond, Body, [0.])
|
||||||
|
|
||||||
|
F()
|
||||||
|
|
||||||
|
|
||||||
def ScalarShape():
|
def ScalarShape():
|
||||||
return ops.convert_to_tensor([], dtype=dtypes.int32)
|
return ops.convert_to_tensor([], dtype=dtypes.int32)
|
||||||
|
@ -18,28 +18,40 @@ from __future__ import absolute_import
|
|||||||
from __future__ import division
|
from __future__ import division
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
from tensorflow.python.framework.func_graph import FuncGraph
|
from tensorflow.python.framework import func_graph
|
||||||
|
|
||||||
|
|
||||||
class CondBranchFuncGraph(FuncGraph):
|
class CondBranchFuncGraph(func_graph.FuncGraph):
|
||||||
"""FuncGraph for branches of tf.cond().
|
"""FuncGraph for branches of tf.cond().
|
||||||
|
|
||||||
This is used to distinguish cond branches from other functions.
|
This is used to distinguish cond branches from other functions.
|
||||||
"""
|
"""
|
||||||
pass
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
super(CondBranchFuncGraph, self).__init__(*args, **kwargs)
|
||||||
|
func_graph.override_func_graph_name_scope(self,
|
||||||
|
self.outer_graph.get_name_scope())
|
||||||
|
|
||||||
|
|
||||||
class WhileCondFuncGraph(FuncGraph):
|
class WhileCondFuncGraph(func_graph.FuncGraph):
|
||||||
"""FuncGraph for the condition of tf.while_loop().
|
"""FuncGraph for the condition of tf.while_loop().
|
||||||
|
|
||||||
This is used to distinguish while conditions from other functions.
|
This is used to distinguish while conditions from other functions.
|
||||||
"""
|
"""
|
||||||
pass
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
super(WhileCondFuncGraph, self).__init__(*args, **kwargs)
|
||||||
|
func_graph.override_func_graph_name_scope(self,
|
||||||
|
self.outer_graph.get_name_scope())
|
||||||
|
|
||||||
|
|
||||||
class WhileBodyFuncGraph(FuncGraph):
|
class WhileBodyFuncGraph(func_graph.FuncGraph):
|
||||||
"""FuncGraph for the body of tf.while_loop().
|
"""FuncGraph for the body of tf.while_loop().
|
||||||
|
|
||||||
This is used to distinguish while bodies from other functions.
|
This is used to distinguish while bodies from other functions.
|
||||||
"""
|
"""
|
||||||
pass
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
super(WhileBodyFuncGraph, self).__init__(*args, **kwargs)
|
||||||
|
func_graph.override_func_graph_name_scope(self,
|
||||||
|
self.outer_graph.get_name_scope())
|
||||||
|
Loading…
Reference in New Issue
Block a user