Change filename from constant to placeholder with default so that grappler can run on graph restoration.
PiperOrigin-RevId: 221205392
This commit is contained in:
parent
cca5a0bfe6
commit
45622121ac
@ -781,8 +781,12 @@ class BaseSaverBuilder(object):
|
|||||||
|
|
||||||
with ops.name_scope(name, "save",
|
with ops.name_scope(name, "save",
|
||||||
[saveable.op for saveable in saveables]) as name:
|
[saveable.op for saveable in saveables]) as name:
|
||||||
# Add the Constant string tensor for the filename.
|
# Add a placeholder string tensor for the filename.
|
||||||
filename_tensor = constant_op.constant(filename or "model")
|
filename_tensor = array_ops.placeholder_with_default(
|
||||||
|
filename or "model", shape=(), name="filename")
|
||||||
|
# Keep the name "Const" for backwards compatibility.
|
||||||
|
filename_tensor = array_ops.placeholder_with_default(
|
||||||
|
filename_tensor, shape=(), name="Const")
|
||||||
|
|
||||||
# Add the save ops.
|
# Add the save ops.
|
||||||
if sharded:
|
if sharded:
|
||||||
|
@ -1832,8 +1832,8 @@ class MetaGraphTest(test.TestCase):
|
|||||||
self.assertEqual(1, len(savers.value))
|
self.assertEqual(1, len(savers.value))
|
||||||
|
|
||||||
# Verifies that saver0 graph nodes are omitted from the saver1 export
|
# Verifies that saver0 graph nodes are omitted from the saver1 export
|
||||||
self.assertEqual(29, len(meta_graph_def0.graph_def.node))
|
self.assertEqual(33, len(meta_graph_def0.graph_def.node))
|
||||||
self.assertEqual(19, len(meta_graph_def1.graph_def.node))
|
self.assertEqual(21, len(meta_graph_def1.graph_def.node))
|
||||||
|
|
||||||
def testBinaryAndTextFormat(self):
|
def testBinaryAndTextFormat(self):
|
||||||
test_dir = self._get_test_dir("binary_and_text")
|
test_dir = self._get_test_dir("binary_and_text")
|
||||||
@ -2140,13 +2140,14 @@ class MetaGraphTest(test.TestCase):
|
|||||||
ops = [o.name for o in meta_graph_def.meta_info_def.stripped_op_list.op]
|
ops = [o.name for o in meta_graph_def.meta_info_def.stripped_op_list.op]
|
||||||
if save._write_version is saver_pb2.SaverDef.V1:
|
if save._write_version is saver_pb2.SaverDef.V1:
|
||||||
self.assertEqual(ops, [
|
self.assertEqual(ops, [
|
||||||
"Add", "Assign", "Const", "Identity", "NoOp", "RestoreV2",
|
"Add", "Assign", "Const", "Identity", "NoOp",
|
||||||
"SaveSlices", "Sub", "VariableV2"
|
"PlaceholderWithDefault", "RestoreV2", "SaveSlices", "Sub",
|
||||||
|
"VariableV2"
|
||||||
])
|
])
|
||||||
else:
|
else:
|
||||||
self.assertEqual(ops, [
|
self.assertEqual(ops, [
|
||||||
"Add", "Assign", "Const", "Identity", "NoOp", "RestoreV2", "SaveV2",
|
"Add", "Assign", "Const", "Identity", "NoOp",
|
||||||
"Sub", "VariableV2"
|
"PlaceholderWithDefault", "RestoreV2", "SaveV2", "Sub", "VariableV2"
|
||||||
])
|
])
|
||||||
|
|
||||||
# Test calling stripped_op_list_for_graph directly
|
# Test calling stripped_op_list_for_graph directly
|
||||||
|
Loading…
Reference in New Issue
Block a user