Fix save model issue for ops with a list of functions
PiperOrigin-RevId: 315496681 Change-Id: I622550d1a073e4c21c3c7af625cf76481e365dbc
This commit is contained in:
parent
3217649491
commit
75f09e6973
|
@ -400,8 +400,11 @@ def fix_node_def(node_def, functions, shared_name_suffix, debug_name):
|
||||||
if node_def.op in functions:
|
if node_def.op in functions:
|
||||||
node_def.op = functions[node_def.op].name
|
node_def.op = functions[node_def.op].name
|
||||||
for _, attr_value in node_def.attr.items():
|
for _, attr_value in node_def.attr.items():
|
||||||
if attr_value.func.name:
|
if attr_value.WhichOneof("value") == "func":
|
||||||
attr_value.func.name = functions[attr_value.func.name].name
|
attr_value.func.name = functions[attr_value.func.name].name
|
||||||
|
elif attr_value.WhichOneof("value") == "list":
|
||||||
|
for fn in attr_value.list.func:
|
||||||
|
fn.name = functions[fn.name].name
|
||||||
|
|
||||||
# Fix old table creation bug.
|
# Fix old table creation bug.
|
||||||
if node_def.op == "HashTableV2":
|
if node_def.op == "HashTableV2":
|
||||||
|
@ -471,6 +474,10 @@ def _list_function_deps(fdef, library_function_names):
|
||||||
for _, attr_value in node_def.attr.items():
|
for _, attr_value in node_def.attr.items():
|
||||||
if attr_value.WhichOneof("value") == "func":
|
if attr_value.WhichOneof("value") == "func":
|
||||||
deps.add(attr_value.func.name)
|
deps.add(attr_value.func.name)
|
||||||
|
elif attr_value.WhichOneof("value") == "list":
|
||||||
|
for fn in attr_value.list.func:
|
||||||
|
deps.add(fn.name)
|
||||||
|
|
||||||
return deps
|
return deps
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -44,6 +44,7 @@ from tensorflow.python.keras.optimizer_v2 import adam
|
||||||
from tensorflow.python.lib.io import file_io
|
from tensorflow.python.lib.io import file_io
|
||||||
from tensorflow.python.module import module
|
from tensorflow.python.module import module
|
||||||
from tensorflow.python.ops import array_ops
|
from tensorflow.python.ops import array_ops
|
||||||
|
from tensorflow.python.ops import control_flow_ops
|
||||||
from tensorflow.python.ops import lookup_ops
|
from tensorflow.python.ops import lookup_ops
|
||||||
from tensorflow.python.ops import math_ops
|
from tensorflow.python.ops import math_ops
|
||||||
from tensorflow.python.ops import resource_variable_ops
|
from tensorflow.python.ops import resource_variable_ops
|
||||||
|
@ -117,6 +118,24 @@ class SaveTest(test.TestCase):
|
||||||
{"output_0": 2.},
|
{"output_0": 2.},
|
||||||
_import_and_infer(save_dir, {"x": 1.}))
|
_import_and_infer(save_dir, {"x": 1.}))
|
||||||
|
|
||||||
|
def test_method_save_list_func(self):
|
||||||
|
root = tracking.AutoTrackable()
|
||||||
|
|
||||||
|
@def_function.function
|
||||||
|
def case_fn(x):
|
||||||
|
branch_index = constant_op.constant(1)
|
||||||
|
branches = [lambda: x, lambda: x + 1]
|
||||||
|
case_out = control_flow_ops.switch_case(branch_index, branches)
|
||||||
|
return case_out
|
||||||
|
|
||||||
|
root.f = def_function.function(
|
||||||
|
lambda x: 2. * case_fn(x),
|
||||||
|
input_signature=[tensor_spec.TensorSpec(None, dtypes.float32)])
|
||||||
|
root.f(constant_op.constant(1.))
|
||||||
|
save_dir = os.path.join(self.get_temp_dir(), "saved_model")
|
||||||
|
save.save(root, save_dir, root.f)
|
||||||
|
self.assertEqual({"output_0": 4.}, _import_and_infer(save_dir, {"x": 1.}))
|
||||||
|
|
||||||
def test_method_save_concrete(self):
|
def test_method_save_concrete(self):
|
||||||
root = tracking.AutoTrackable()
|
root = tracking.AutoTrackable()
|
||||||
root.f = def_function.function(
|
root.f = def_function.function(
|
||||||
|
|
Loading…
Reference in New Issue