Do not propagate parent name_scope in v2 control flow when inside v1 graph since that would break existing TF2 estimator models.
PiperOrigin-RevId: 301874559 Change-Id: Ie42fdd40ce52ada1c4a1307a920625b489d5db67
This commit is contained in:
parent
511a7490e3
commit
111879a5bc
|
@ -632,7 +632,10 @@ class OpCallbacksTest(test_util.TensorFlowTestCase):
|
|||
greater_op_outputs = instrument.graph_internal_ndarrays[_GREATER_OP]
|
||||
self.assertEqual(len(greater_op_outputs), 1)
|
||||
self.assertAllClose(greater_op_outputs[0], False)
|
||||
pow_op_outputs = instrument.graph_internal_ndarrays[b"cond/pow"]
|
||||
# This was needed for backwards compatibility with TF2 Estimators which
|
||||
# rely on variable names.
|
||||
prefix = b"cond/" if context.executing_eagerly() else b""
|
||||
pow_op_outputs = instrument.graph_internal_ndarrays[b"%spow" % prefix]
|
||||
self.assertEqual(len(pow_op_outputs), 1)
|
||||
self.assertAllClose(pow_op_outputs[0], -64.0)
|
||||
|
||||
|
|
|
@ -260,6 +260,7 @@ class CondV2Test(test.TestCase):
|
|||
self.assertRegexpMatches(
|
||||
cond2_op.get_attr("else_branch").name, r"foo_cond_1_false_\d*")
|
||||
|
||||
@test_util.run_v2_only
|
||||
def testInheritParentNameScope(self):
|
||||
|
||||
@def_function.function
|
||||
|
|
|
@ -809,9 +809,14 @@ class ControlFlowTest(test.TestCase, parameterized.TestCase):
|
|||
return control_flow_ops.cond(
|
||||
pred, lambda: true_fn(inputs), lambda: false_fn(inputs))
|
||||
|
||||
# This was needed for backwards compatibility with TF2 Estimators which
|
||||
# rely on variable names.
|
||||
prefix = "cond/" if context.executing_eagerly() else ""
|
||||
|
||||
with self.assertRaisesRegexp(
|
||||
ValueError,
|
||||
"Tensor cond/true_branch:0 in true_fn is accessed from false_fn."):
|
||||
"Tensor %strue_branch:0 in true_fn is accessed from false_fn." %
|
||||
prefix):
|
||||
f()
|
||||
|
||||
def testSwitchCaseAccessBranch1TensorInBranch4Raises(self):
|
||||
|
@ -836,10 +841,12 @@ class ControlFlowTest(test.TestCase, parameterized.TestCase):
|
|||
[other_fn, lambda: br1_fn(inputs), other_fn, other_fn,
|
||||
lambda: br4_fn(inputs)])
|
||||
|
||||
# This was needed for backwards compatibility with TF2 Estimators which
|
||||
# rely on variable names.
|
||||
prefix = "switch_case/indexed_case/" if context.executing_eagerly() else ""
|
||||
with self.assertRaisesRegexp(
|
||||
ValueError,
|
||||
"Tensor switch_case/indexed_case/br1_identity:0 in branch 1 is "
|
||||
"accessed from branch 4."):
|
||||
ValueError, "Tensor %sbr1_identity:0 in branch 1 is "
|
||||
"accessed from branch 4." % prefix):
|
||||
f()
|
||||
|
||||
def testCondListOutput(self):
|
||||
|
|
|
@ -1175,6 +1175,7 @@ class WhileV2Test(test.TestCase, parameterized.TestCase):
|
|||
|
||||
Fn()
|
||||
|
||||
@test_util.run_v2_only
|
||||
def testInheritParentNameScope(self):
|
||||
|
||||
@def_function.function
|
||||
|
|
|
@ -19,6 +19,7 @@ from __future__ import division
|
|||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python.framework import func_graph
|
||||
from tensorflow.python.framework import ops
|
||||
|
||||
|
||||
class CondBranchFuncGraph(func_graph.FuncGraph):
|
||||
|
@ -29,8 +30,9 @@ class CondBranchFuncGraph(func_graph.FuncGraph):
|
|||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(CondBranchFuncGraph, self).__init__(*args, **kwargs)
|
||||
func_graph.override_func_graph_name_scope(self,
|
||||
self.outer_graph.get_name_scope())
|
||||
if ops.executing_eagerly_outside_functions():
|
||||
func_graph.override_func_graph_name_scope(
|
||||
self, self.outer_graph.get_name_scope())
|
||||
|
||||
|
||||
class WhileCondFuncGraph(func_graph.FuncGraph):
|
||||
|
@ -41,8 +43,9 @@ class WhileCondFuncGraph(func_graph.FuncGraph):
|
|||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(WhileCondFuncGraph, self).__init__(*args, **kwargs)
|
||||
func_graph.override_func_graph_name_scope(self,
|
||||
self.outer_graph.get_name_scope())
|
||||
if ops.executing_eagerly_outside_functions():
|
||||
func_graph.override_func_graph_name_scope(
|
||||
self, self.outer_graph.get_name_scope())
|
||||
|
||||
|
||||
class WhileBodyFuncGraph(func_graph.FuncGraph):
|
||||
|
@ -53,5 +56,6 @@ class WhileBodyFuncGraph(func_graph.FuncGraph):
|
|||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(WhileBodyFuncGraph, self).__init__(*args, **kwargs)
|
||||
func_graph.override_func_graph_name_scope(self,
|
||||
self.outer_graph.get_name_scope())
|
||||
if ops.executing_eagerly_outside_functions():
|
||||
func_graph.override_func_graph_name_scope(
|
||||
self, self.outer_graph.get_name_scope())
|
||||
|
|
Loading…
Reference in New Issue