Break out some import_graph_def logic into helper functions
PiperOrigin-RevId: 174526258
This commit is contained in:
parent
011953754a
commit
743c12a10c
@ -147,6 +147,43 @@ def _MaybeDevice(device):
|
||||
yield
|
||||
|
||||
|
||||
def _ProcessGraphDefParam(graph_def):
|
||||
"""Type-checks and possibly canonicalizes `graph_def`."""
|
||||
if not isinstance(graph_def, graph_pb2.GraphDef):
|
||||
# `graph_def` could be a dynamically-created message, so try a duck-typed
|
||||
# approach
|
||||
try:
|
||||
old_graph_def = graph_def
|
||||
graph_def = graph_pb2.GraphDef()
|
||||
graph_def.MergeFrom(old_graph_def)
|
||||
except TypeError:
|
||||
raise TypeError('graph_def must be a GraphDef proto.')
|
||||
return graph_def
|
||||
|
||||
|
||||
def _ProcessInputMapParam(input_map):
|
||||
"""Type-checks and possibly canonicalizes `input_map`."""
|
||||
if input_map is None:
|
||||
input_map = {}
|
||||
else:
|
||||
if not (isinstance(input_map, dict)
|
||||
and all(isinstance(k, compat.bytes_or_text_types)
|
||||
for k in input_map.keys())):
|
||||
raise TypeError('input_map must be a dictionary mapping strings to '
|
||||
'Tensor objects.')
|
||||
return input_map
|
||||
|
||||
|
||||
def _ProcessReturnElementsParam(return_elements):
|
||||
"""Type-checks and possibly canonicalizes `return_elements`."""
|
||||
if return_elements is not None:
|
||||
return_elements = tuple(return_elements)
|
||||
if not all(isinstance(x, compat.bytes_or_text_types)
|
||||
for x in return_elements):
|
||||
raise TypeError('return_elements must be a list of strings.')
|
||||
return return_elements
|
||||
|
||||
|
||||
def _FindAttrInOpDef(attr_name, op_def):
|
||||
for attr_def in op_def.attr:
|
||||
if attr_name == attr_def.name:
|
||||
@ -201,29 +238,9 @@ def import_graph_def(graph_def, input_map=None, return_elements=None,
|
||||
do not appear in `graph_def`, or `graph_def` is not well-formed (e.g.
|
||||
it refers to an unknown tensor).
|
||||
"""
|
||||
# Type checks for inputs.
|
||||
if not isinstance(graph_def, graph_pb2.GraphDef):
|
||||
# `graph_def` could be a dynamically-created message, so try a duck-typed
|
||||
# approach
|
||||
try:
|
||||
old_graph_def = graph_def
|
||||
graph_def = graph_pb2.GraphDef()
|
||||
graph_def.MergeFrom(old_graph_def)
|
||||
except TypeError:
|
||||
raise TypeError('graph_def must be a GraphDef proto.')
|
||||
if input_map is None:
|
||||
input_map = {}
|
||||
else:
|
||||
if not (isinstance(input_map, dict)
|
||||
and all(isinstance(k, compat.bytes_or_text_types)
|
||||
for k in input_map.keys())):
|
||||
raise TypeError('input_map must be a dictionary mapping strings to '
|
||||
'Tensor objects.')
|
||||
if return_elements is not None:
|
||||
return_elements = tuple(return_elements)
|
||||
if not all(isinstance(x, compat.bytes_or_text_types)
|
||||
for x in return_elements):
|
||||
raise TypeError('return_elements must be a list of strings.')
|
||||
graph_def = _ProcessGraphDefParam(graph_def)
|
||||
input_map = _ProcessInputMapParam(input_map)
|
||||
return_elements = _ProcessReturnElementsParam(return_elements)
|
||||
|
||||
# Use a canonical representation for all tensor names.
|
||||
input_map = {_CanonicalInputName(k): v for k, v in input_map.items()}
|
||||
|
Loading…
Reference in New Issue
Block a user