From af8f596d21f03372291384227cae75d95fa6e468 Mon Sep 17 00:00:00 2001 From: Sung Jin Hwang Date: Mon, 22 Jun 2020 12:13:10 -0700 Subject: [PATCH] Add _SwitchN op to the handled op lists in convert_variables_to_constants_v2. Grappler transforms tf.switch_case() to use _SwitchN instead of Case op, and because _SwitchN is currently not handled convert_variables_to_constants_v2() fails when function contains tf.switch_case. PiperOrigin-RevId: 317707465 Change-Id: I3524c0c50f0acd02c2347c3c4f49ece5a5d671f7 --- .../python/framework/convert_to_constants.py | 3 +- .../framework/convert_to_constants_test.py | 34 +++++++++++++++++++ 2 files changed, 36 insertions(+), 1 deletion(-) diff --git a/tensorflow/python/framework/convert_to_constants.py b/tensorflow/python/framework/convert_to_constants.py index 87c74c3263d..4c3cbb06bf1 100644 --- a/tensorflow/python/framework/convert_to_constants.py +++ b/tensorflow/python/framework/convert_to_constants.py @@ -235,7 +235,8 @@ class _Node(_Convertible): return _If(node, function, enclosing_graph) elif node.op in ["While", "StatelessWhile"]: return _While(node, function, enclosing_graph) - elif node.op in ["Enter", "Exit", "Identity", "NextIteration", "Switch"]: + elif node.op in [ + "Enter", "Exit", "Identity", "NextIteration", "Switch", "_SwitchN"]: return _Intermediate(node, function, enclosing_graph) else: return _Node(node, function, enclosing_graph) diff --git a/tensorflow/python/framework/convert_to_constants_test.py b/tensorflow/python/framework/convert_to_constants_test.py index b1e11003939..7252082d084 100644 --- a/tensorflow/python/framework/convert_to_constants_test.py +++ b/tensorflow/python/framework/convert_to_constants_test.py @@ -486,6 +486,40 @@ class VariablesToConstantsTest(test.TestCase): root, output_func = self._freezeModel(model) self._testConvertedFunction(root, root.f, output_func, input_data) + @test_util.run_v2_only + def testSwitchCase(self): + """Test a switch_case statement.""" + input_data = { + "i": constant_op.constant(np.random.randint(0, 3, dtype=np.int32)), + "x": constant_op.constant( + np.asarray(np.random.random_sample((10, 3)), dtype=np.float32)), + } + + w0 = variables.Variable(np.random.random_sample((3, 4)), dtype=np.float32) + w1 = variables.Variable(np.random.random_sample((3, 4)), dtype=np.float32) + w2 = variables.Variable(np.random.random_sample((4,)), dtype=np.float32) + + def branch0(x): + return math_ops.matmul(x, w0) + + def branch1(x): + return math_ops.matmul(x, w1) + + def branch2(x): + x = array_ops.pad(x, [[0, 0], [0, 1]]) + return x + w2 + + @def_function.function(input_signature=[ + tensor_spec.TensorSpec(shape=[], dtype=dtypes.int32), + tensor_spec.TensorSpec(shape=[10, 3], dtype=dtypes.float32), + ]) + def model(i, x): + return control_flow_ops.switch_case(i, [ + lambda: branch0(x), lambda: branch1(x), lambda: branch2(x)]) + + root, output_func = self._freezeModel(model) + self._testConvertedFunction(root, root.f, output_func, input_data) + class ConvertVariablesToConstantsSessionTest(test.TestCase):