Change filename from constant to placeholder with default so that grappler can run on graph restoration.

PiperOrigin-RevId: 221205392
This commit is contained in:
A. Unique TensorFlower 2018-11-12 20:41:28 -08:00 committed by TensorFlower Gardener
parent cca5a0bfe6
commit 45622121ac
2 changed files with 13 additions and 8 deletions

View File

@ -781,8 +781,12 @@ class BaseSaverBuilder(object):
with ops.name_scope(name, "save", with ops.name_scope(name, "save",
[saveable.op for saveable in saveables]) as name: [saveable.op for saveable in saveables]) as name:
# Add the Constant string tensor for the filename. # Add a placeholder string tensor for the filename.
filename_tensor = constant_op.constant(filename or "model") filename_tensor = array_ops.placeholder_with_default(
filename or "model", shape=(), name="filename")
# Keep the name "Const" for backwards compatibility.
filename_tensor = array_ops.placeholder_with_default(
filename_tensor, shape=(), name="Const")
# Add the save ops. # Add the save ops.
if sharded: if sharded:

View File

@ -1832,8 +1832,8 @@ class MetaGraphTest(test.TestCase):
self.assertEqual(1, len(savers.value)) self.assertEqual(1, len(savers.value))
# Verifies that saver0 graph nodes are omitted from the saver1 export # Verifies that saver0 graph nodes are omitted from the saver1 export
self.assertEqual(29, len(meta_graph_def0.graph_def.node)) self.assertEqual(33, len(meta_graph_def0.graph_def.node))
self.assertEqual(19, len(meta_graph_def1.graph_def.node)) self.assertEqual(21, len(meta_graph_def1.graph_def.node))
def testBinaryAndTextFormat(self): def testBinaryAndTextFormat(self):
test_dir = self._get_test_dir("binary_and_text") test_dir = self._get_test_dir("binary_and_text")
@ -2140,13 +2140,14 @@ class MetaGraphTest(test.TestCase):
ops = [o.name for o in meta_graph_def.meta_info_def.stripped_op_list.op] ops = [o.name for o in meta_graph_def.meta_info_def.stripped_op_list.op]
if save._write_version is saver_pb2.SaverDef.V1: if save._write_version is saver_pb2.SaverDef.V1:
self.assertEqual(ops, [ self.assertEqual(ops, [
"Add", "Assign", "Const", "Identity", "NoOp", "RestoreV2", "Add", "Assign", "Const", "Identity", "NoOp",
"SaveSlices", "Sub", "VariableV2" "PlaceholderWithDefault", "RestoreV2", "SaveSlices", "Sub",
"VariableV2"
]) ])
else: else:
self.assertEqual(ops, [ self.assertEqual(ops, [
"Add", "Assign", "Const", "Identity", "NoOp", "RestoreV2", "SaveV2", "Add", "Assign", "Const", "Identity", "NoOp",
"Sub", "VariableV2" "PlaceholderWithDefault", "RestoreV2", "SaveV2", "Sub", "VariableV2"
]) ])
# Test calling stripped_op_list_for_graph directly # Test calling stripped_op_list_for_graph directly