Give clear errors for bad input names.

PiperOrigin-RevId: 155857515
This commit is contained in:
Mark Daoust 2017-05-12 05:29:21 -07:00 committed by TensorFlower Gardener
parent 8a13f77eda
commit ebb278520a
2 changed files with 31 additions and 7 deletions

View File

@ -41,14 +41,26 @@ def strip_unused(input_graph_def, input_node_names, output_node_names,
a list that specifies one value per input node name.
Returns:
A GraphDef with all unnecessary ops removed.
A `GraphDef` with all unnecessary ops removed.
Raises:
ValueError: If any element in `input_node_names` refers to a tensor instead
of an operation.
KeyError: If any element in `input_node_names` is not found in the graph.
"""
for name in input_node_names:
if ":" in name:
raise ValueError("Name '%s' appears to refer to a Tensor, "
"not a Operation." % name)
# Here we replace the nodes we're going to override as inputs with
# placeholders so that any unused nodes that are inputs to them are
# automatically stripped out by extract_sub_graph().
not_found = {name for name in input_node_names}
inputs_replaced_graph_def = graph_pb2.GraphDef()
for node in input_graph_def.node:
if node.name in input_node_names:
not_found.remove(node.name)
placeholder_node = node_def_pb2.NodeDef()
placeholder_node.op = "Placeholder"
placeholder_node.name = node.name
@ -67,6 +79,9 @@ def strip_unused(input_graph_def, input_node_names, output_node_names,
else:
inputs_replaced_graph_def.node.extend([copy.deepcopy(node)])
if not_found:
raise KeyError("The following input nodes were not found: %s\n" % not_found)
output_graph_def = graph_util.extract_sub_graph(inputs_replaced_graph_def,
output_node_names)
return output_graph_def

View File

@ -58,16 +58,25 @@ class StripUnusedTest(test_util.TensorFlowTestCase):
# routine.
input_graph_path = os.path.join(self.get_temp_dir(), input_graph_name)
input_binary = False
input_node_names = "wanted_input_node"
output_binary = True
output_node_names = "output_node"
output_graph_path = os.path.join(self.get_temp_dir(), output_graph_name)
strip_unused_lib.strip_unused_from_files(input_graph_path, input_binary,
output_graph_path, output_binary,
input_node_names,
output_node_names,
dtypes.float32.as_datatype_enum)
def strip(input_node_names):
strip_unused_lib.strip_unused_from_files(input_graph_path, input_binary,
output_graph_path, output_binary,
input_node_names,
output_node_names,
dtypes.float32.as_datatype_enum)
with self.assertRaises(KeyError):
strip("does_not_exist")
with self.assertRaises(ValueError):
strip("wanted_input_node:0")
input_node_names = "wanted_input_node"
strip(input_node_names)
# Now we make sure the variable is now a constant, and that the graph still
# produces the expected result.