Internal change on model loading.

PiperOrigin-RevId: 356429530
Change-Id: I1416414927aed1ff12ab3491e66140da1a8defde
This commit is contained in:
Hyeonjong Ryu 2021-02-08 22:23:47 -08:00 committed by TensorFlower Gardener
parent 81e97f2832
commit f47acb1257
2 changed files with 38 additions and 1 deletions

View File

@ -136,7 +136,9 @@ class _EagerSavedModelLoader(loader_impl.SavedModelLoader):
signature_functions = {}
for signature_key, signature_def in meta_graph_def.signature_def.items():
if signature_def.inputs:
original_input_names, input_specs = zip(*signature_def.inputs.items())
input_items = sorted(
signature_def.inputs.items(), key=lambda item: item[1].name)
original_input_names, input_specs = zip(*input_items)
else:
original_input_names = []
input_specs = []

View File

@ -660,6 +660,41 @@ class LoadTest(test.TestCase):
self.assertAllEqual(
kwargs, {"start": tensor_spec.TensorSpec(shape=None, name="start")})
def _v1_multi_input_saved_model(self):
export_graph = ops.Graph()
with export_graph.as_default():
input1 = array_ops.placeholder(
shape=[None], dtype=dtypes.float32, name="input1")
input2 = array_ops.placeholder(
shape=[None], dtype=dtypes.float32, name="input2")
v = resource_variable_ops.ResourceVariable(21.)
output = array_ops.identity(input1 * v + input2, name="output")
with session_lib.Session() as session:
session.run(v.initializer)
path = os.path.join(self.get_temp_dir(), "saved_model", str(ops.uid()))
builder = builder_impl.SavedModelBuilder(path)
builder.add_meta_graph_and_variables(
session,
tags=[tag_constants.SERVING],
signature_def_map={
"serving_default":
signature_def_utils.build_signature_def(
{
"input1": utils_impl.build_tensor_info(input1),
"input2": utils_impl.build_tensor_info(input2)
}, {"output": utils_impl.build_tensor_info(output)})
})
builder.save()
return path
def test_v1_input_ordered(self):
path = self._v1_multi_input_saved_model()
imported = load.load(path)
self.assertEqual(imported.signatures["serving_default"].inputs[0].name,
"input1:0")
self.assertEqual(imported.signatures["serving_default"].inputs[1].name,
"input2:0")
if __name__ == "__main__":
test.main()