Do not propagate parent name_scope in v2 control flow when inside v1 graph since that would break existing TF2 estimator models.

PiperOrigin-RevId: 301874559
Change-Id: Ie42fdd40ce52ada1c4a1307a920625b489d5db67
This commit is contained in:
Saurabh Saxena 2020-03-19 12:43:10 -07:00 committed by TensorFlower Gardener
parent 511a7490e3
commit 111879a5bc
5 changed files with 27 additions and 11 deletions

View File

@ -632,7 +632,10 @@ class OpCallbacksTest(test_util.TensorFlowTestCase):
greater_op_outputs = instrument.graph_internal_ndarrays[_GREATER_OP]
self.assertEqual(len(greater_op_outputs), 1)
self.assertAllClose(greater_op_outputs[0], False)
pow_op_outputs = instrument.graph_internal_ndarrays[b"cond/pow"]
# This was needed for backwards compatibility with TF2 Estimators which
# rely on variable names.
prefix = b"cond/" if context.executing_eagerly() else b""
pow_op_outputs = instrument.graph_internal_ndarrays[b"%spow" % prefix]
self.assertEqual(len(pow_op_outputs), 1)
self.assertAllClose(pow_op_outputs[0], -64.0)

View File

@ -260,6 +260,7 @@ class CondV2Test(test.TestCase):
self.assertRegexpMatches(
cond2_op.get_attr("else_branch").name, r"foo_cond_1_false_\d*")
@test_util.run_v2_only
def testInheritParentNameScope(self):
@def_function.function

View File

@ -809,9 +809,14 @@ class ControlFlowTest(test.TestCase, parameterized.TestCase):
return control_flow_ops.cond(
pred, lambda: true_fn(inputs), lambda: false_fn(inputs))
# This was needed for backwards compatibility with TF2 Estimators which
# rely on variable names.
prefix = "cond/" if context.executing_eagerly() else ""
with self.assertRaisesRegexp(
ValueError,
"Tensor cond/true_branch:0 in true_fn is accessed from false_fn."):
"Tensor %strue_branch:0 in true_fn is accessed from false_fn." %
prefix):
f()
def testSwitchCaseAccessBranch1TensorInBranch4Raises(self):
@ -836,10 +841,12 @@ class ControlFlowTest(test.TestCase, parameterized.TestCase):
[other_fn, lambda: br1_fn(inputs), other_fn, other_fn,
lambda: br4_fn(inputs)])
# This was needed for backwards compatibility with TF2 Estimators which
# rely on variable names.
prefix = "switch_case/indexed_case/" if context.executing_eagerly() else ""
with self.assertRaisesRegexp(
ValueError,
"Tensor switch_case/indexed_case/br1_identity:0 in branch 1 is "
"accessed from branch 4."):
ValueError, "Tensor %sbr1_identity:0 in branch 1 is "
"accessed from branch 4." % prefix):
f()
def testCondListOutput(self):

View File

@ -1175,6 +1175,7 @@ class WhileV2Test(test.TestCase, parameterized.TestCase):
Fn()
@test_util.run_v2_only
def testInheritParentNameScope(self):
@def_function.function

View File

@ -19,6 +19,7 @@ from __future__ import division
from __future__ import print_function
from tensorflow.python.framework import func_graph
from tensorflow.python.framework import ops
class CondBranchFuncGraph(func_graph.FuncGraph):
@ -29,8 +30,9 @@ class CondBranchFuncGraph(func_graph.FuncGraph):
def __init__(self, *args, **kwargs):
super(CondBranchFuncGraph, self).__init__(*args, **kwargs)
func_graph.override_func_graph_name_scope(self,
self.outer_graph.get_name_scope())
if ops.executing_eagerly_outside_functions():
func_graph.override_func_graph_name_scope(
self, self.outer_graph.get_name_scope())
class WhileCondFuncGraph(func_graph.FuncGraph):
@ -41,8 +43,9 @@ class WhileCondFuncGraph(func_graph.FuncGraph):
def __init__(self, *args, **kwargs):
super(WhileCondFuncGraph, self).__init__(*args, **kwargs)
func_graph.override_func_graph_name_scope(self,
self.outer_graph.get_name_scope())
if ops.executing_eagerly_outside_functions():
func_graph.override_func_graph_name_scope(
self, self.outer_graph.get_name_scope())
class WhileBodyFuncGraph(func_graph.FuncGraph):
@ -53,5 +56,6 @@ class WhileBodyFuncGraph(func_graph.FuncGraph):
def __init__(self, *args, **kwargs):
super(WhileBodyFuncGraph, self).__init__(*args, **kwargs)
func_graph.override_func_graph_name_scope(self,
self.outer_graph.get_name_scope())
if ops.executing_eagerly_outside_functions():
func_graph.override_func_graph_name_scope(
self, self.outer_graph.get_name_scope())