Change filename from constant to placeholder with default so that grappler can run on graph restoration.

PiperOrigin-RevId: 221205392
This commit is contained in:
A. Unique TensorFlower 2018-11-12 20:41:28 -08:00 committed by TensorFlower Gardener
parent cca5a0bfe6
commit 45622121ac
2 changed files with 13 additions and 8 deletions

View File

@ -781,8 +781,12 @@ class BaseSaverBuilder(object):
with ops.name_scope(name, "save",
[saveable.op for saveable in saveables]) as name:
# Add the Constant string tensor for the filename.
filename_tensor = constant_op.constant(filename or "model")
# Add a placeholder string tensor for the filename.
filename_tensor = array_ops.placeholder_with_default(
filename or "model", shape=(), name="filename")
# Keep the name "Const" for backwards compatibility.
filename_tensor = array_ops.placeholder_with_default(
filename_tensor, shape=(), name="Const")
# Add the save ops.
if sharded:

View File

@ -1832,8 +1832,8 @@ class MetaGraphTest(test.TestCase):
self.assertEqual(1, len(savers.value))
# Verifies that saver0 graph nodes are omitted from the saver1 export
self.assertEqual(29, len(meta_graph_def0.graph_def.node))
self.assertEqual(19, len(meta_graph_def1.graph_def.node))
self.assertEqual(33, len(meta_graph_def0.graph_def.node))
self.assertEqual(21, len(meta_graph_def1.graph_def.node))
def testBinaryAndTextFormat(self):
test_dir = self._get_test_dir("binary_and_text")
@ -2140,13 +2140,14 @@ class MetaGraphTest(test.TestCase):
ops = [o.name for o in meta_graph_def.meta_info_def.stripped_op_list.op]
if save._write_version is saver_pb2.SaverDef.V1:
self.assertEqual(ops, [
"Add", "Assign", "Const", "Identity", "NoOp", "RestoreV2",
"SaveSlices", "Sub", "VariableV2"
"Add", "Assign", "Const", "Identity", "NoOp",
"PlaceholderWithDefault", "RestoreV2", "SaveSlices", "Sub",
"VariableV2"
])
else:
self.assertEqual(ops, [
"Add", "Assign", "Const", "Identity", "NoOp", "RestoreV2", "SaveV2",
"Sub", "VariableV2"
"Add", "Assign", "Const", "Identity", "NoOp",
"PlaceholderWithDefault", "RestoreV2", "SaveV2", "Sub", "VariableV2"
])
# Test calling stripped_op_list_for_graph directly