Do not propagate parent name_scope in v2 control flow when inside v1 graph since that would break existing TF2 estimator models.
PiperOrigin-RevId: 301874559 Change-Id: Ie42fdd40ce52ada1c4a1307a920625b489d5db67
This commit is contained in:
parent
511a7490e3
commit
111879a5bc
@ -632,7 +632,10 @@ class OpCallbacksTest(test_util.TensorFlowTestCase):
|
|||||||
greater_op_outputs = instrument.graph_internal_ndarrays[_GREATER_OP]
|
greater_op_outputs = instrument.graph_internal_ndarrays[_GREATER_OP]
|
||||||
self.assertEqual(len(greater_op_outputs), 1)
|
self.assertEqual(len(greater_op_outputs), 1)
|
||||||
self.assertAllClose(greater_op_outputs[0], False)
|
self.assertAllClose(greater_op_outputs[0], False)
|
||||||
pow_op_outputs = instrument.graph_internal_ndarrays[b"cond/pow"]
|
# This was needed for backwards compatibility with TF2 Estimators which
|
||||||
|
# rely on variable names.
|
||||||
|
prefix = b"cond/" if context.executing_eagerly() else b""
|
||||||
|
pow_op_outputs = instrument.graph_internal_ndarrays[b"%spow" % prefix]
|
||||||
self.assertEqual(len(pow_op_outputs), 1)
|
self.assertEqual(len(pow_op_outputs), 1)
|
||||||
self.assertAllClose(pow_op_outputs[0], -64.0)
|
self.assertAllClose(pow_op_outputs[0], -64.0)
|
||||||
|
|
||||||
|
@ -260,6 +260,7 @@ class CondV2Test(test.TestCase):
|
|||||||
self.assertRegexpMatches(
|
self.assertRegexpMatches(
|
||||||
cond2_op.get_attr("else_branch").name, r"foo_cond_1_false_\d*")
|
cond2_op.get_attr("else_branch").name, r"foo_cond_1_false_\d*")
|
||||||
|
|
||||||
|
@test_util.run_v2_only
|
||||||
def testInheritParentNameScope(self):
|
def testInheritParentNameScope(self):
|
||||||
|
|
||||||
@def_function.function
|
@def_function.function
|
||||||
|
@ -809,9 +809,14 @@ class ControlFlowTest(test.TestCase, parameterized.TestCase):
|
|||||||
return control_flow_ops.cond(
|
return control_flow_ops.cond(
|
||||||
pred, lambda: true_fn(inputs), lambda: false_fn(inputs))
|
pred, lambda: true_fn(inputs), lambda: false_fn(inputs))
|
||||||
|
|
||||||
|
# This was needed for backwards compatibility with TF2 Estimators which
|
||||||
|
# rely on variable names.
|
||||||
|
prefix = "cond/" if context.executing_eagerly() else ""
|
||||||
|
|
||||||
with self.assertRaisesRegexp(
|
with self.assertRaisesRegexp(
|
||||||
ValueError,
|
ValueError,
|
||||||
"Tensor cond/true_branch:0 in true_fn is accessed from false_fn."):
|
"Tensor %strue_branch:0 in true_fn is accessed from false_fn." %
|
||||||
|
prefix):
|
||||||
f()
|
f()
|
||||||
|
|
||||||
def testSwitchCaseAccessBranch1TensorInBranch4Raises(self):
|
def testSwitchCaseAccessBranch1TensorInBranch4Raises(self):
|
||||||
@ -836,10 +841,12 @@ class ControlFlowTest(test.TestCase, parameterized.TestCase):
|
|||||||
[other_fn, lambda: br1_fn(inputs), other_fn, other_fn,
|
[other_fn, lambda: br1_fn(inputs), other_fn, other_fn,
|
||||||
lambda: br4_fn(inputs)])
|
lambda: br4_fn(inputs)])
|
||||||
|
|
||||||
|
# This was needed for backwards compatibility with TF2 Estimators which
|
||||||
|
# rely on variable names.
|
||||||
|
prefix = "switch_case/indexed_case/" if context.executing_eagerly() else ""
|
||||||
with self.assertRaisesRegexp(
|
with self.assertRaisesRegexp(
|
||||||
ValueError,
|
ValueError, "Tensor %sbr1_identity:0 in branch 1 is "
|
||||||
"Tensor switch_case/indexed_case/br1_identity:0 in branch 1 is "
|
"accessed from branch 4." % prefix):
|
||||||
"accessed from branch 4."):
|
|
||||||
f()
|
f()
|
||||||
|
|
||||||
def testCondListOutput(self):
|
def testCondListOutput(self):
|
||||||
|
@ -1175,6 +1175,7 @@ class WhileV2Test(test.TestCase, parameterized.TestCase):
|
|||||||
|
|
||||||
Fn()
|
Fn()
|
||||||
|
|
||||||
|
@test_util.run_v2_only
|
||||||
def testInheritParentNameScope(self):
|
def testInheritParentNameScope(self):
|
||||||
|
|
||||||
@def_function.function
|
@def_function.function
|
||||||
|
@ -19,6 +19,7 @@ from __future__ import division
|
|||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
from tensorflow.python.framework import func_graph
|
from tensorflow.python.framework import func_graph
|
||||||
|
from tensorflow.python.framework import ops
|
||||||
|
|
||||||
|
|
||||||
class CondBranchFuncGraph(func_graph.FuncGraph):
|
class CondBranchFuncGraph(func_graph.FuncGraph):
|
||||||
@ -29,8 +30,9 @@ class CondBranchFuncGraph(func_graph.FuncGraph):
|
|||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
super(CondBranchFuncGraph, self).__init__(*args, **kwargs)
|
super(CondBranchFuncGraph, self).__init__(*args, **kwargs)
|
||||||
func_graph.override_func_graph_name_scope(self,
|
if ops.executing_eagerly_outside_functions():
|
||||||
self.outer_graph.get_name_scope())
|
func_graph.override_func_graph_name_scope(
|
||||||
|
self, self.outer_graph.get_name_scope())
|
||||||
|
|
||||||
|
|
||||||
class WhileCondFuncGraph(func_graph.FuncGraph):
|
class WhileCondFuncGraph(func_graph.FuncGraph):
|
||||||
@ -41,8 +43,9 @@ class WhileCondFuncGraph(func_graph.FuncGraph):
|
|||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
super(WhileCondFuncGraph, self).__init__(*args, **kwargs)
|
super(WhileCondFuncGraph, self).__init__(*args, **kwargs)
|
||||||
func_graph.override_func_graph_name_scope(self,
|
if ops.executing_eagerly_outside_functions():
|
||||||
self.outer_graph.get_name_scope())
|
func_graph.override_func_graph_name_scope(
|
||||||
|
self, self.outer_graph.get_name_scope())
|
||||||
|
|
||||||
|
|
||||||
class WhileBodyFuncGraph(func_graph.FuncGraph):
|
class WhileBodyFuncGraph(func_graph.FuncGraph):
|
||||||
@ -53,5 +56,6 @@ class WhileBodyFuncGraph(func_graph.FuncGraph):
|
|||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
super(WhileBodyFuncGraph, self).__init__(*args, **kwargs)
|
super(WhileBodyFuncGraph, self).__init__(*args, **kwargs)
|
||||||
func_graph.override_func_graph_name_scope(self,
|
if ops.executing_eagerly_outside_functions():
|
||||||
self.outer_graph.get_name_scope())
|
func_graph.override_func_graph_name_scope(
|
||||||
|
self, self.outer_graph.get_name_scope())
|
||||||
|
Loading…
x
Reference in New Issue
Block a user