Fix save model issue for ops with a list of functions
PiperOrigin-RevId: 315496681 Change-Id: I622550d1a073e4c21c3c7af625cf76481e365dbc
This commit is contained in:
parent
3217649491
commit
75f09e6973
|
@ -400,8 +400,11 @@ def fix_node_def(node_def, functions, shared_name_suffix, debug_name):
|
|||
if node_def.op in functions:
|
||||
node_def.op = functions[node_def.op].name
|
||||
for _, attr_value in node_def.attr.items():
|
||||
if attr_value.func.name:
|
||||
if attr_value.WhichOneof("value") == "func":
|
||||
attr_value.func.name = functions[attr_value.func.name].name
|
||||
elif attr_value.WhichOneof("value") == "list":
|
||||
for fn in attr_value.list.func:
|
||||
fn.name = functions[fn.name].name
|
||||
|
||||
# Fix old table creation bug.
|
||||
if node_def.op == "HashTableV2":
|
||||
|
@ -471,6 +474,10 @@ def _list_function_deps(fdef, library_function_names):
|
|||
for _, attr_value in node_def.attr.items():
|
||||
if attr_value.WhichOneof("value") == "func":
|
||||
deps.add(attr_value.func.name)
|
||||
elif attr_value.WhichOneof("value") == "list":
|
||||
for fn in attr_value.list.func:
|
||||
deps.add(fn.name)
|
||||
|
||||
return deps
|
||||
|
||||
|
||||
|
|
|
@ -44,6 +44,7 @@ from tensorflow.python.keras.optimizer_v2 import adam
|
|||
from tensorflow.python.lib.io import file_io
|
||||
from tensorflow.python.module import module
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.ops import lookup_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import resource_variable_ops
|
||||
|
@ -117,6 +118,24 @@ class SaveTest(test.TestCase):
|
|||
{"output_0": 2.},
|
||||
_import_and_infer(save_dir, {"x": 1.}))
|
||||
|
||||
def test_method_save_list_func(self):
|
||||
root = tracking.AutoTrackable()
|
||||
|
||||
@def_function.function
|
||||
def case_fn(x):
|
||||
branch_index = constant_op.constant(1)
|
||||
branches = [lambda: x, lambda: x + 1]
|
||||
case_out = control_flow_ops.switch_case(branch_index, branches)
|
||||
return case_out
|
||||
|
||||
root.f = def_function.function(
|
||||
lambda x: 2. * case_fn(x),
|
||||
input_signature=[tensor_spec.TensorSpec(None, dtypes.float32)])
|
||||
root.f(constant_op.constant(1.))
|
||||
save_dir = os.path.join(self.get_temp_dir(), "saved_model")
|
||||
save.save(root, save_dir, root.f)
|
||||
self.assertEqual({"output_0": 4.}, _import_and_infer(save_dir, {"x": 1.}))
|
||||
|
||||
def test_method_save_concrete(self):
|
||||
root = tracking.AutoTrackable()
|
||||
root.f = def_function.function(
|
||||
|
|
Loading…
Reference in New Issue