Internal change on model loading.
PiperOrigin-RevId: 356429530 Change-Id: I1416414927aed1ff12ab3491e66140da1a8defde
This commit is contained in:
parent
81e97f2832
commit
f47acb1257
@ -136,7 +136,9 @@ class _EagerSavedModelLoader(loader_impl.SavedModelLoader):
|
||||
signature_functions = {}
|
||||
for signature_key, signature_def in meta_graph_def.signature_def.items():
|
||||
if signature_def.inputs:
|
||||
original_input_names, input_specs = zip(*signature_def.inputs.items())
|
||||
input_items = sorted(
|
||||
signature_def.inputs.items(), key=lambda item: item[1].name)
|
||||
original_input_names, input_specs = zip(*input_items)
|
||||
else:
|
||||
original_input_names = []
|
||||
input_specs = []
|
||||
|
@ -660,6 +660,41 @@ class LoadTest(test.TestCase):
|
||||
self.assertAllEqual(
|
||||
kwargs, {"start": tensor_spec.TensorSpec(shape=None, name="start")})
|
||||
|
||||
def _v1_multi_input_saved_model(self):
|
||||
export_graph = ops.Graph()
|
||||
with export_graph.as_default():
|
||||
input1 = array_ops.placeholder(
|
||||
shape=[None], dtype=dtypes.float32, name="input1")
|
||||
input2 = array_ops.placeholder(
|
||||
shape=[None], dtype=dtypes.float32, name="input2")
|
||||
v = resource_variable_ops.ResourceVariable(21.)
|
||||
output = array_ops.identity(input1 * v + input2, name="output")
|
||||
with session_lib.Session() as session:
|
||||
session.run(v.initializer)
|
||||
path = os.path.join(self.get_temp_dir(), "saved_model", str(ops.uid()))
|
||||
builder = builder_impl.SavedModelBuilder(path)
|
||||
builder.add_meta_graph_and_variables(
|
||||
session,
|
||||
tags=[tag_constants.SERVING],
|
||||
signature_def_map={
|
||||
"serving_default":
|
||||
signature_def_utils.build_signature_def(
|
||||
{
|
||||
"input1": utils_impl.build_tensor_info(input1),
|
||||
"input2": utils_impl.build_tensor_info(input2)
|
||||
}, {"output": utils_impl.build_tensor_info(output)})
|
||||
})
|
||||
builder.save()
|
||||
return path
|
||||
|
||||
def test_v1_input_ordered(self):
|
||||
path = self._v1_multi_input_saved_model()
|
||||
imported = load.load(path)
|
||||
self.assertEqual(imported.signatures["serving_default"].inputs[0].name,
|
||||
"input1:0")
|
||||
self.assertEqual(imported.signatures["serving_default"].inputs[1].name,
|
||||
"input2:0")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
||||
|
Loading…
Reference in New Issue
Block a user