Add _SwitchN op to the handled op lists in convert_variables_to_constants_v2.
Grappler transforms tf.switch_case() to use _SwitchN instead of Case op, and because _SwitchN is currently not handled convert_variables_to_constants_v2() fails when function contains tf.switch_case. PiperOrigin-RevId: 317707465 Change-Id: I3524c0c50f0acd02c2347c3c4f49ece5a5d671f7
This commit is contained in:
parent
9ca89a201c
commit
af8f596d21
tensorflow/python/framework
@ -235,7 +235,8 @@ class _Node(_Convertible):
|
||||
return _If(node, function, enclosing_graph)
|
||||
elif node.op in ["While", "StatelessWhile"]:
|
||||
return _While(node, function, enclosing_graph)
|
||||
elif node.op in ["Enter", "Exit", "Identity", "NextIteration", "Switch"]:
|
||||
elif node.op in [
|
||||
"Enter", "Exit", "Identity", "NextIteration", "Switch", "_SwitchN"]:
|
||||
return _Intermediate(node, function, enclosing_graph)
|
||||
else:
|
||||
return _Node(node, function, enclosing_graph)
|
||||
|
@ -486,6 +486,40 @@ class VariablesToConstantsTest(test.TestCase):
|
||||
root, output_func = self._freezeModel(model)
|
||||
self._testConvertedFunction(root, root.f, output_func, input_data)
|
||||
|
||||
@test_util.run_v2_only
|
||||
def testSwitchCase(self):
|
||||
"""Test a switch_case statement."""
|
||||
input_data = {
|
||||
"i": constant_op.constant(np.random.randint(0, 3, dtype=np.int32)),
|
||||
"x": constant_op.constant(
|
||||
np.asarray(np.random.random_sample((10, 3)), dtype=np.float32)),
|
||||
}
|
||||
|
||||
w0 = variables.Variable(np.random.random_sample((3, 4)), dtype=np.float32)
|
||||
w1 = variables.Variable(np.random.random_sample((3, 4)), dtype=np.float32)
|
||||
w2 = variables.Variable(np.random.random_sample((4,)), dtype=np.float32)
|
||||
|
||||
def branch0(x):
|
||||
return math_ops.matmul(x, w0)
|
||||
|
||||
def branch1(x):
|
||||
return math_ops.matmul(x, w1)
|
||||
|
||||
def branch2(x):
|
||||
x = array_ops.pad(x, [[0, 0], [0, 1]])
|
||||
return x + w2
|
||||
|
||||
@def_function.function(input_signature=[
|
||||
tensor_spec.TensorSpec(shape=[], dtype=dtypes.int32),
|
||||
tensor_spec.TensorSpec(shape=[10, 3], dtype=dtypes.float32),
|
||||
])
|
||||
def model(i, x):
|
||||
return control_flow_ops.switch_case(i, [
|
||||
lambda: branch0(x), lambda: branch1(x), lambda: branch2(x)])
|
||||
|
||||
root, output_func = self._freezeModel(model)
|
||||
self._testConvertedFunction(root, root.f, output_func, input_data)
|
||||
|
||||
|
||||
class ConvertVariablesToConstantsSessionTest(test.TestCase):
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user