Break out some import_graph_def logic into helper functions

PiperOrigin-RevId: 174526258
This commit is contained in:
Skye Wanderman-Milne 2017-11-03 15:56:09 -07:00 committed by TensorFlower Gardener
parent 011953754a
commit 743c12a10c

View File

@ -147,6 +147,43 @@ def _MaybeDevice(device):
yield
def _ProcessGraphDefParam(graph_def):
"""Type-checks and possibly canonicalizes `graph_def`."""
if not isinstance(graph_def, graph_pb2.GraphDef):
# `graph_def` could be a dynamically-created message, so try a duck-typed
# approach
try:
old_graph_def = graph_def
graph_def = graph_pb2.GraphDef()
graph_def.MergeFrom(old_graph_def)
except TypeError:
raise TypeError('graph_def must be a GraphDef proto.')
return graph_def
def _ProcessInputMapParam(input_map):
"""Type-checks and possibly canonicalizes `input_map`."""
if input_map is None:
input_map = {}
else:
if not (isinstance(input_map, dict)
and all(isinstance(k, compat.bytes_or_text_types)
for k in input_map.keys())):
raise TypeError('input_map must be a dictionary mapping strings to '
'Tensor objects.')
return input_map
def _ProcessReturnElementsParam(return_elements):
"""Type-checks and possibly canonicalizes `return_elements`."""
if return_elements is not None:
return_elements = tuple(return_elements)
if not all(isinstance(x, compat.bytes_or_text_types)
for x in return_elements):
raise TypeError('return_elements must be a list of strings.')
return return_elements
def _FindAttrInOpDef(attr_name, op_def):
for attr_def in op_def.attr:
if attr_name == attr_def.name:
@ -201,29 +238,9 @@ def import_graph_def(graph_def, input_map=None, return_elements=None,
do not appear in `graph_def`, or `graph_def` is not well-formed (e.g.
it refers to an unknown tensor).
"""
# Type checks for inputs.
if not isinstance(graph_def, graph_pb2.GraphDef):
# `graph_def` could be a dynamically-created message, so try a duck-typed
# approach
try:
old_graph_def = graph_def
graph_def = graph_pb2.GraphDef()
graph_def.MergeFrom(old_graph_def)
except TypeError:
raise TypeError('graph_def must be a GraphDef proto.')
if input_map is None:
input_map = {}
else:
if not (isinstance(input_map, dict)
and all(isinstance(k, compat.bytes_or_text_types)
for k in input_map.keys())):
raise TypeError('input_map must be a dictionary mapping strings to '
'Tensor objects.')
if return_elements is not None:
return_elements = tuple(return_elements)
if not all(isinstance(x, compat.bytes_or_text_types)
for x in return_elements):
raise TypeError('return_elements must be a list of strings.')
graph_def = _ProcessGraphDefParam(graph_def)
input_map = _ProcessInputMapParam(input_map)
return_elements = _ProcessReturnElementsParam(return_elements)
# Use a canonical representation for all tensor names.
input_map = {_CanonicalInputName(k): v for k, v in input_map.items()}