Add _SwitchN op to the handled op lists in convert_variables_to_constants_v2.

Grappler transforms tf.switch_case() to use _SwitchN instead of Case op, and
because _SwitchN is currently not handled convert_variables_to_constants_v2()
fails when function contains tf.switch_case.

PiperOrigin-RevId: 317707465
Change-Id: I3524c0c50f0acd02c2347c3c4f49ece5a5d671f7
This commit is contained in:
Sung Jin Hwang 2020-06-22 12:13:10 -07:00 committed by TensorFlower Gardener
parent 9ca89a201c
commit af8f596d21
2 changed files with 36 additions and 1 deletions

View File

@ -235,7 +235,8 @@ class _Node(_Convertible):
return _If(node, function, enclosing_graph)
elif node.op in ["While", "StatelessWhile"]:
return _While(node, function, enclosing_graph)
elif node.op in ["Enter", "Exit", "Identity", "NextIteration", "Switch"]:
elif node.op in [
"Enter", "Exit", "Identity", "NextIteration", "Switch", "_SwitchN"]:
return _Intermediate(node, function, enclosing_graph)
else:
return _Node(node, function, enclosing_graph)

View File

@ -486,6 +486,40 @@ class VariablesToConstantsTest(test.TestCase):
root, output_func = self._freezeModel(model)
self._testConvertedFunction(root, root.f, output_func, input_data)
@test_util.run_v2_only
def testSwitchCase(self):
"""Test a switch_case statement."""
input_data = {
"i": constant_op.constant(np.random.randint(0, 3, dtype=np.int32)),
"x": constant_op.constant(
np.asarray(np.random.random_sample((10, 3)), dtype=np.float32)),
}
w0 = variables.Variable(np.random.random_sample((3, 4)), dtype=np.float32)
w1 = variables.Variable(np.random.random_sample((3, 4)), dtype=np.float32)
w2 = variables.Variable(np.random.random_sample((4,)), dtype=np.float32)
def branch0(x):
return math_ops.matmul(x, w0)
def branch1(x):
return math_ops.matmul(x, w1)
def branch2(x):
x = array_ops.pad(x, [[0, 0], [0, 1]])
return x + w2
@def_function.function(input_signature=[
tensor_spec.TensorSpec(shape=[], dtype=dtypes.int32),
tensor_spec.TensorSpec(shape=[10, 3], dtype=dtypes.float32),
])
def model(i, x):
return control_flow_ops.switch_case(i, [
lambda: branch0(x), lambda: branch1(x), lambda: branch2(x)])
root, output_func = self._freezeModel(model)
self._testConvertedFunction(root, root.f, output_func, input_data)
class ConvertVariablesToConstantsSessionTest(test.TestCase):