From 75f09e6973f84858906983e19d4544a70615ff7f Mon Sep 17 00:00:00 2001 From: Yanhua Sun Date: Tue, 9 Jun 2020 09:16:36 -0700 Subject: [PATCH] Fix save model issue for ops with a list of functions PiperOrigin-RevId: 315496681 Change-Id: I622550d1a073e4c21c3c7af625cf76481e365dbc --- .../saved_model/function_deserialization.py | 9 ++++++++- tensorflow/python/saved_model/save_test.py | 19 +++++++++++++++++++ 2 files changed, 27 insertions(+), 1 deletion(-) diff --git a/tensorflow/python/saved_model/function_deserialization.py b/tensorflow/python/saved_model/function_deserialization.py index dccb222c26e..63fa4a4acbd 100644 --- a/tensorflow/python/saved_model/function_deserialization.py +++ b/tensorflow/python/saved_model/function_deserialization.py @@ -400,8 +400,11 @@ def fix_node_def(node_def, functions, shared_name_suffix, debug_name): if node_def.op in functions: node_def.op = functions[node_def.op].name for _, attr_value in node_def.attr.items(): - if attr_value.func.name: + if attr_value.WhichOneof("value") == "func": attr_value.func.name = functions[attr_value.func.name].name + elif attr_value.WhichOneof("value") == "list": + for fn in attr_value.list.func: + fn.name = functions[fn.name].name # Fix old table creation bug. if node_def.op == "HashTableV2": @@ -471,6 +474,10 @@ def _list_function_deps(fdef, library_function_names): for _, attr_value in node_def.attr.items(): if attr_value.WhichOneof("value") == "func": deps.add(attr_value.func.name) + elif attr_value.WhichOneof("value") == "list": + for fn in attr_value.list.func: + deps.add(fn.name) + return deps diff --git a/tensorflow/python/saved_model/save_test.py b/tensorflow/python/saved_model/save_test.py index 09e7296a483..f94cae8a4de 100644 --- a/tensorflow/python/saved_model/save_test.py +++ b/tensorflow/python/saved_model/save_test.py @@ -44,6 +44,7 @@ from tensorflow.python.keras.optimizer_v2 import adam from tensorflow.python.lib.io import file_io from tensorflow.python.module import module from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import lookup_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import resource_variable_ops @@ -117,6 +118,24 @@ class SaveTest(test.TestCase): {"output_0": 2.}, _import_and_infer(save_dir, {"x": 1.})) + def test_method_save_list_func(self): + root = tracking.AutoTrackable() + + @def_function.function + def case_fn(x): + branch_index = constant_op.constant(1) + branches = [lambda: x, lambda: x + 1] + case_out = control_flow_ops.switch_case(branch_index, branches) + return case_out + + root.f = def_function.function( + lambda x: 2. * case_fn(x), + input_signature=[tensor_spec.TensorSpec(None, dtypes.float32)]) + root.f(constant_op.constant(1.)) + save_dir = os.path.join(self.get_temp_dir(), "saved_model") + save.save(root, save_dir, root.f) + self.assertEqual({"output_0": 4.}, _import_and_infer(save_dir, {"x": 1.})) + def test_method_save_concrete(self): root = tracking.AutoTrackable() root.f = def_function.function(