Add _SwitchN op to the handled op lists in convert_variables_to_constants_v2.
Grappler transforms tf.switch_case() to use _SwitchN instead of Case op, and because _SwitchN is currently not handled convert_variables_to_constants_v2() fails when function contains tf.switch_case. PiperOrigin-RevId: 317707465 Change-Id: I3524c0c50f0acd02c2347c3c4f49ece5a5d671f7
This commit is contained in:
parent
9ca89a201c
commit
af8f596d21
@ -235,7 +235,8 @@ class _Node(_Convertible):
|
|||||||
return _If(node, function, enclosing_graph)
|
return _If(node, function, enclosing_graph)
|
||||||
elif node.op in ["While", "StatelessWhile"]:
|
elif node.op in ["While", "StatelessWhile"]:
|
||||||
return _While(node, function, enclosing_graph)
|
return _While(node, function, enclosing_graph)
|
||||||
elif node.op in ["Enter", "Exit", "Identity", "NextIteration", "Switch"]:
|
elif node.op in [
|
||||||
|
"Enter", "Exit", "Identity", "NextIteration", "Switch", "_SwitchN"]:
|
||||||
return _Intermediate(node, function, enclosing_graph)
|
return _Intermediate(node, function, enclosing_graph)
|
||||||
else:
|
else:
|
||||||
return _Node(node, function, enclosing_graph)
|
return _Node(node, function, enclosing_graph)
|
||||||
|
@ -486,6 +486,40 @@ class VariablesToConstantsTest(test.TestCase):
|
|||||||
root, output_func = self._freezeModel(model)
|
root, output_func = self._freezeModel(model)
|
||||||
self._testConvertedFunction(root, root.f, output_func, input_data)
|
self._testConvertedFunction(root, root.f, output_func, input_data)
|
||||||
|
|
||||||
|
@test_util.run_v2_only
|
||||||
|
def testSwitchCase(self):
|
||||||
|
"""Test a switch_case statement."""
|
||||||
|
input_data = {
|
||||||
|
"i": constant_op.constant(np.random.randint(0, 3, dtype=np.int32)),
|
||||||
|
"x": constant_op.constant(
|
||||||
|
np.asarray(np.random.random_sample((10, 3)), dtype=np.float32)),
|
||||||
|
}
|
||||||
|
|
||||||
|
w0 = variables.Variable(np.random.random_sample((3, 4)), dtype=np.float32)
|
||||||
|
w1 = variables.Variable(np.random.random_sample((3, 4)), dtype=np.float32)
|
||||||
|
w2 = variables.Variable(np.random.random_sample((4,)), dtype=np.float32)
|
||||||
|
|
||||||
|
def branch0(x):
|
||||||
|
return math_ops.matmul(x, w0)
|
||||||
|
|
||||||
|
def branch1(x):
|
||||||
|
return math_ops.matmul(x, w1)
|
||||||
|
|
||||||
|
def branch2(x):
|
||||||
|
x = array_ops.pad(x, [[0, 0], [0, 1]])
|
||||||
|
return x + w2
|
||||||
|
|
||||||
|
@def_function.function(input_signature=[
|
||||||
|
tensor_spec.TensorSpec(shape=[], dtype=dtypes.int32),
|
||||||
|
tensor_spec.TensorSpec(shape=[10, 3], dtype=dtypes.float32),
|
||||||
|
])
|
||||||
|
def model(i, x):
|
||||||
|
return control_flow_ops.switch_case(i, [
|
||||||
|
lambda: branch0(x), lambda: branch1(x), lambda: branch2(x)])
|
||||||
|
|
||||||
|
root, output_func = self._freezeModel(model)
|
||||||
|
self._testConvertedFunction(root, root.f, output_func, input_data)
|
||||||
|
|
||||||
|
|
||||||
class ConvertVariablesToConstantsSessionTest(test.TestCase):
|
class ConvertVariablesToConstantsSessionTest(test.TestCase):
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user