Give clear errors for bad input names.
PiperOrigin-RevId: 155857515
This commit is contained in:
parent
8a13f77eda
commit
ebb278520a
@ -41,14 +41,26 @@ def strip_unused(input_graph_def, input_node_names, output_node_names,
|
||||
a list that specifies one value per input node name.
|
||||
|
||||
Returns:
|
||||
A GraphDef with all unnecessary ops removed.
|
||||
A `GraphDef` with all unnecessary ops removed.
|
||||
|
||||
Raises:
|
||||
ValueError: If any element in `input_node_names` refers to a tensor instead
|
||||
of an operation.
|
||||
KeyError: If any element in `input_node_names` is not found in the graph.
|
||||
"""
|
||||
for name in input_node_names:
|
||||
if ":" in name:
|
||||
raise ValueError("Name '%s' appears to refer to a Tensor, "
|
||||
"not a Operation." % name)
|
||||
|
||||
# Here we replace the nodes we're going to override as inputs with
|
||||
# placeholders so that any unused nodes that are inputs to them are
|
||||
# automatically stripped out by extract_sub_graph().
|
||||
not_found = {name for name in input_node_names}
|
||||
inputs_replaced_graph_def = graph_pb2.GraphDef()
|
||||
for node in input_graph_def.node:
|
||||
if node.name in input_node_names:
|
||||
not_found.remove(node.name)
|
||||
placeholder_node = node_def_pb2.NodeDef()
|
||||
placeholder_node.op = "Placeholder"
|
||||
placeholder_node.name = node.name
|
||||
@ -67,6 +79,9 @@ def strip_unused(input_graph_def, input_node_names, output_node_names,
|
||||
else:
|
||||
inputs_replaced_graph_def.node.extend([copy.deepcopy(node)])
|
||||
|
||||
if not_found:
|
||||
raise KeyError("The following input nodes were not found: %s\n" % not_found)
|
||||
|
||||
output_graph_def = graph_util.extract_sub_graph(inputs_replaced_graph_def,
|
||||
output_node_names)
|
||||
return output_graph_def
|
||||
|
@ -58,16 +58,25 @@ class StripUnusedTest(test_util.TensorFlowTestCase):
|
||||
# routine.
|
||||
input_graph_path = os.path.join(self.get_temp_dir(), input_graph_name)
|
||||
input_binary = False
|
||||
input_node_names = "wanted_input_node"
|
||||
output_binary = True
|
||||
output_node_names = "output_node"
|
||||
output_graph_path = os.path.join(self.get_temp_dir(), output_graph_name)
|
||||
|
||||
strip_unused_lib.strip_unused_from_files(input_graph_path, input_binary,
|
||||
output_graph_path, output_binary,
|
||||
input_node_names,
|
||||
output_node_names,
|
||||
dtypes.float32.as_datatype_enum)
|
||||
def strip(input_node_names):
|
||||
strip_unused_lib.strip_unused_from_files(input_graph_path, input_binary,
|
||||
output_graph_path, output_binary,
|
||||
input_node_names,
|
||||
output_node_names,
|
||||
dtypes.float32.as_datatype_enum)
|
||||
|
||||
with self.assertRaises(KeyError):
|
||||
strip("does_not_exist")
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
strip("wanted_input_node:0")
|
||||
|
||||
input_node_names = "wanted_input_node"
|
||||
strip(input_node_names)
|
||||
|
||||
# Now we make sure the variable is now a constant, and that the graph still
|
||||
# produces the expected result.
|
||||
|
Loading…
Reference in New Issue
Block a user