From 111879a5bc7a2317a92421b0c07c4be27b96b094 Mon Sep 17 00:00:00 2001 From: Saurabh Saxena Date: Thu, 19 Mar 2020 12:43:10 -0700 Subject: [PATCH] Do not propagate parent name_scope in v2 control flow when inside v1 graph since that would break existing TF2 estimator models. PiperOrigin-RevId: 301874559 Change-Id: Ie42fdd40ce52ada1c4a1307a920625b489d5db67 --- tensorflow/python/framework/op_callbacks_test.py | 5 ++++- tensorflow/python/kernel_tests/cond_v2_test.py | 1 + .../kernel_tests/control_flow_ops_py_test.py | 15 +++++++++++---- tensorflow/python/kernel_tests/while_v2_test.py | 1 + .../python/ops/control_flow_v2_func_graphs.py | 16 ++++++++++------ 5 files changed, 27 insertions(+), 11 deletions(-) diff --git a/tensorflow/python/framework/op_callbacks_test.py b/tensorflow/python/framework/op_callbacks_test.py index f04d85bba21..14304536f65 100644 --- a/tensorflow/python/framework/op_callbacks_test.py +++ b/tensorflow/python/framework/op_callbacks_test.py @@ -632,7 +632,10 @@ class OpCallbacksTest(test_util.TensorFlowTestCase): greater_op_outputs = instrument.graph_internal_ndarrays[_GREATER_OP] self.assertEqual(len(greater_op_outputs), 1) self.assertAllClose(greater_op_outputs[0], False) - pow_op_outputs = instrument.graph_internal_ndarrays[b"cond/pow"] + # This was needed for backwards compatibility with TF2 Estimators which + # rely on variable names. + prefix = b"cond/" if context.executing_eagerly() else b"" + pow_op_outputs = instrument.graph_internal_ndarrays[b"%spow" % prefix] self.assertEqual(len(pow_op_outputs), 1) self.assertAllClose(pow_op_outputs[0], -64.0) diff --git a/tensorflow/python/kernel_tests/cond_v2_test.py b/tensorflow/python/kernel_tests/cond_v2_test.py index de8ea8d89d7..1682f2275c1 100644 --- a/tensorflow/python/kernel_tests/cond_v2_test.py +++ b/tensorflow/python/kernel_tests/cond_v2_test.py @@ -260,6 +260,7 @@ class CondV2Test(test.TestCase): self.assertRegexpMatches( cond2_op.get_attr("else_branch").name, r"foo_cond_1_false_\d*") + @test_util.run_v2_only def testInheritParentNameScope(self): @def_function.function diff --git a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py index 99fff136314..2533cf0a645 100644 --- a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py +++ b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py @@ -809,9 +809,14 @@ class ControlFlowTest(test.TestCase, parameterized.TestCase): return control_flow_ops.cond( pred, lambda: true_fn(inputs), lambda: false_fn(inputs)) + # This was needed for backwards compatibility with TF2 Estimators which + # rely on variable names. + prefix = "cond/" if context.executing_eagerly() else "" + with self.assertRaisesRegexp( ValueError, - "Tensor cond/true_branch:0 in true_fn is accessed from false_fn."): + "Tensor %strue_branch:0 in true_fn is accessed from false_fn." % + prefix): f() def testSwitchCaseAccessBranch1TensorInBranch4Raises(self): @@ -836,10 +841,12 @@ class ControlFlowTest(test.TestCase, parameterized.TestCase): [other_fn, lambda: br1_fn(inputs), other_fn, other_fn, lambda: br4_fn(inputs)]) + # This was needed for backwards compatibility with TF2 Estimators which + # rely on variable names. + prefix = "switch_case/indexed_case/" if context.executing_eagerly() else "" with self.assertRaisesRegexp( - ValueError, - "Tensor switch_case/indexed_case/br1_identity:0 in branch 1 is " - "accessed from branch 4."): + ValueError, "Tensor %sbr1_identity:0 in branch 1 is " + "accessed from branch 4." % prefix): f() def testCondListOutput(self): diff --git a/tensorflow/python/kernel_tests/while_v2_test.py b/tensorflow/python/kernel_tests/while_v2_test.py index 1fa6c179e7a..3f53f49fc30 100644 --- a/tensorflow/python/kernel_tests/while_v2_test.py +++ b/tensorflow/python/kernel_tests/while_v2_test.py @@ -1175,6 +1175,7 @@ class WhileV2Test(test.TestCase, parameterized.TestCase): Fn() + @test_util.run_v2_only def testInheritParentNameScope(self): @def_function.function diff --git a/tensorflow/python/ops/control_flow_v2_func_graphs.py b/tensorflow/python/ops/control_flow_v2_func_graphs.py index 537ad2b4b8a..97e04f8d73d 100644 --- a/tensorflow/python/ops/control_flow_v2_func_graphs.py +++ b/tensorflow/python/ops/control_flow_v2_func_graphs.py @@ -19,6 +19,7 @@ from __future__ import division from __future__ import print_function from tensorflow.python.framework import func_graph +from tensorflow.python.framework import ops class CondBranchFuncGraph(func_graph.FuncGraph): @@ -29,8 +30,9 @@ class CondBranchFuncGraph(func_graph.FuncGraph): def __init__(self, *args, **kwargs): super(CondBranchFuncGraph, self).__init__(*args, **kwargs) - func_graph.override_func_graph_name_scope(self, - self.outer_graph.get_name_scope()) + if ops.executing_eagerly_outside_functions(): + func_graph.override_func_graph_name_scope( + self, self.outer_graph.get_name_scope()) class WhileCondFuncGraph(func_graph.FuncGraph): @@ -41,8 +43,9 @@ class WhileCondFuncGraph(func_graph.FuncGraph): def __init__(self, *args, **kwargs): super(WhileCondFuncGraph, self).__init__(*args, **kwargs) - func_graph.override_func_graph_name_scope(self, - self.outer_graph.get_name_scope()) + if ops.executing_eagerly_outside_functions(): + func_graph.override_func_graph_name_scope( + self, self.outer_graph.get_name_scope()) class WhileBodyFuncGraph(func_graph.FuncGraph): @@ -53,5 +56,6 @@ class WhileBodyFuncGraph(func_graph.FuncGraph): def __init__(self, *args, **kwargs): super(WhileBodyFuncGraph, self).__init__(*args, **kwargs) - func_graph.override_func_graph_name_scope(self, - self.outer_graph.get_name_scope()) + if ops.executing_eagerly_outside_functions(): + func_graph.override_func_graph_name_scope( + self, self.outer_graph.get_name_scope())