Fix save model issue for ops with a list of functions

PiperOrigin-RevId: 315496681
Change-Id: I622550d1a073e4c21c3c7af625cf76481e365dbc
This commit is contained in:
Yanhua Sun 2020-06-09 09:16:36 -07:00 committed by TensorFlower Gardener
parent 3217649491
commit 75f09e6973
2 changed files with 27 additions and 1 deletions

View File

@ -400,8 +400,11 @@ def fix_node_def(node_def, functions, shared_name_suffix, debug_name):
if node_def.op in functions: if node_def.op in functions:
node_def.op = functions[node_def.op].name node_def.op = functions[node_def.op].name
for _, attr_value in node_def.attr.items(): for _, attr_value in node_def.attr.items():
if attr_value.func.name: if attr_value.WhichOneof("value") == "func":
attr_value.func.name = functions[attr_value.func.name].name attr_value.func.name = functions[attr_value.func.name].name
elif attr_value.WhichOneof("value") == "list":
for fn in attr_value.list.func:
fn.name = functions[fn.name].name
# Fix old table creation bug. # Fix old table creation bug.
if node_def.op == "HashTableV2": if node_def.op == "HashTableV2":
@ -471,6 +474,10 @@ def _list_function_deps(fdef, library_function_names):
for _, attr_value in node_def.attr.items(): for _, attr_value in node_def.attr.items():
if attr_value.WhichOneof("value") == "func": if attr_value.WhichOneof("value") == "func":
deps.add(attr_value.func.name) deps.add(attr_value.func.name)
elif attr_value.WhichOneof("value") == "list":
for fn in attr_value.list.func:
deps.add(fn.name)
return deps return deps

View File

@ -44,6 +44,7 @@ from tensorflow.python.keras.optimizer_v2 import adam
from tensorflow.python.lib.io import file_io from tensorflow.python.lib.io import file_io
from tensorflow.python.module import module from tensorflow.python.module import module
from tensorflow.python.ops import array_ops from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import lookup_ops from tensorflow.python.ops import lookup_ops
from tensorflow.python.ops import math_ops from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import resource_variable_ops
@ -117,6 +118,24 @@ class SaveTest(test.TestCase):
{"output_0": 2.}, {"output_0": 2.},
_import_and_infer(save_dir, {"x": 1.})) _import_and_infer(save_dir, {"x": 1.}))
def test_method_save_list_func(self):
root = tracking.AutoTrackable()
@def_function.function
def case_fn(x):
branch_index = constant_op.constant(1)
branches = [lambda: x, lambda: x + 1]
case_out = control_flow_ops.switch_case(branch_index, branches)
return case_out
root.f = def_function.function(
lambda x: 2. * case_fn(x),
input_signature=[tensor_spec.TensorSpec(None, dtypes.float32)])
root.f(constant_op.constant(1.))
save_dir = os.path.join(self.get_temp_dir(), "saved_model")
save.save(root, save_dir, root.f)
self.assertEqual({"output_0": 4.}, _import_and_infer(save_dir, {"x": 1.}))
def test_method_save_concrete(self): def test_method_save_concrete(self):
root = tracking.AutoTrackable() root = tracking.AutoTrackable()
root.f = def_function.function( root.f = def_function.function(