Create Python Operations for the TF_Operations created by import_graph_def.
This change also introduces Python functionality for iterating through every TF_Operation in the graph and every newly-added TF_Operation via TF_GraphNextOperation. PiperOrigin-RevId: 176694180
This commit is contained in:
parent
bb287e33f7
commit
8067aa0862
@ -315,6 +315,24 @@ tensorflow::ImportNumpy();
|
||||
$2 = inputs.size();
|
||||
}
|
||||
|
||||
// Typemaps for TF_GraphNextOperation().
|
||||
%typemap(in) size_t* pos (size_t pos) {
|
||||
pos = PyLong_AsUnsignedLong($input);
|
||||
$1 = &pos;
|
||||
}
|
||||
|
||||
// Returns a (TF_Operation*, int pos) tuple.
|
||||
%typemap(argout) size_t* pos {
|
||||
PyObject* new_result = PyTuple_New(2);
|
||||
if (!new_result) {
|
||||
SWIG_exception_fail(SWIG_MemoryError, "$symname: couldn't create tuple");
|
||||
}
|
||||
// Steals $result reference
|
||||
PyTuple_SET_ITEM(new_result, 0, $result);
|
||||
PyTuple_SET_ITEM(new_result, 1, PyLong_FromSize_t(*$1));
|
||||
$result = new_result;
|
||||
}
|
||||
|
||||
// TODO(skyewm): SWIG emits a warning for the const char* in TF_WhileParams,
|
||||
// skip for now
|
||||
%ignore TF_WhileParams;
|
||||
|
@ -110,3 +110,41 @@ def tf_output(c_op, index):
|
||||
ret.oper = c_op
|
||||
ret.index = index
|
||||
return ret
|
||||
|
||||
|
||||
def tf_operations(graph):
|
||||
"""Generator that yields every TF_Operation in `graph`.
|
||||
|
||||
Args:
|
||||
graph: Graph
|
||||
|
||||
Yields:
|
||||
wrapped TF_Operation
|
||||
"""
|
||||
# pylint: disable=protected-access
|
||||
pos = 0
|
||||
c_op, pos = c_api.TF_GraphNextOperation(graph._c_graph, pos)
|
||||
while c_op is not None:
|
||||
yield c_op
|
||||
c_op, pos = c_api.TF_GraphNextOperation(graph._c_graph, pos)
|
||||
# pylint: enable=protected-access
|
||||
|
||||
|
||||
def new_tf_operations(graph):
|
||||
"""Generator that yields newly-added TF_Operations in `graph`.
|
||||
|
||||
Specifically, yields TF_Operations that don't have associated Operations in
|
||||
`graph`. This is useful for processing nodes added by the C API.
|
||||
|
||||
Args:
|
||||
graph: Graph
|
||||
|
||||
Yields:
|
||||
wrapped TF_Operation
|
||||
"""
|
||||
# TODO(b/69679162): do this more efficiently
|
||||
for c_op in tf_operations(graph):
|
||||
try:
|
||||
graph._get_operation_by_tf_operation(c_op) # pylint: disable=protected-access
|
||||
except KeyError:
|
||||
yield c_op
|
||||
|
@ -194,6 +194,14 @@ def _FindAttrInOpDef(attr_name, op_def):
|
||||
return None
|
||||
|
||||
|
||||
def _ProcessNewOps(graph):
|
||||
"""Processes the newly-added TF_Operations in `graph`."""
|
||||
for c_op in c_api_util.new_tf_operations(graph):
|
||||
graph._create_op_from_tf_operation(c_op) # pylint: disable=protected-access
|
||||
|
||||
# TODO(skyewm): colocation logic
|
||||
|
||||
|
||||
@deprecated_args(None, 'Please file an issue at '
|
||||
'https://github.com/tensorflow/tensorflow/issues if you depend'
|
||||
' on this feature.',
|
||||
@ -257,11 +265,13 @@ def import_graph_def(graph_def, input_map=None, return_elements=None,
|
||||
if graph._c_graph: # pylint: disable=protected-access
|
||||
scoped_options = c_api_util.ScopedTFImportGraphDefOptions()
|
||||
|
||||
with errors.raise_exception_on_not_ok_status() as status:
|
||||
with c_api_util.tf_buffer(graph_def.SerializeToString()) as serialized:
|
||||
with c_api_util.tf_buffer(graph_def.SerializeToString()) as serialized:
|
||||
with errors.raise_exception_on_not_ok_status() as status:
|
||||
c_api.TF_GraphImportGraphDefWithResults(
|
||||
graph._c_graph, serialized, scoped_options.options, status) # pylint: disable=protected-access
|
||||
|
||||
_ProcessNewOps(graph)
|
||||
|
||||
if return_elements is not None:
|
||||
raise ValueError('return_elements not yet implemented with C API')
|
||||
return None
|
||||
|
@ -65,19 +65,58 @@ class ImportGraphDefTest(test.TestCase):
|
||||
importer.import_graph_def(
|
||||
self._MakeGraphDef("""
|
||||
node { name: 'A' op: 'IntOutputFloatOutput' }
|
||||
node { name: 'B' op: 'ListOutput'
|
||||
attr { key: 'T'
|
||||
value { list { type: DT_INT32 type: DT_FLOAT } } } }
|
||||
node { name: 'C' op: 'ListInput'
|
||||
attr { key: 'N' value { i: 2 } }
|
||||
attr { key: 'T' value { type: DT_INT32 } }
|
||||
input: 'A:0' input: 'B:0' }
|
||||
node { name: 'D' op: 'ListInput'
|
||||
attr { key: 'N' value { i: 2 } }
|
||||
attr { key: 'T' value { type: DT_FLOAT } }
|
||||
input: 'A:1' input: 'B:1' }
|
||||
node { name: 'B' op: 'ListOutput'
|
||||
attr { key: 'T'
|
||||
value { list { type: DT_INT32 type: DT_FLOAT } } } }
|
||||
node { name: 'C' op: 'ListInput'
|
||||
attr { key: 'N' value { i: 2 } }
|
||||
attr { key: 'T' value { type: DT_INT32 } }
|
||||
input: 'A:0' input: 'B:0' }
|
||||
node { name: 'D' op: 'ListInput'
|
||||
attr { key: 'N' value { i: 2 } }
|
||||
attr { key: 'T' value { type: DT_FLOAT } }
|
||||
input: 'A:1' input: 'B:1' }
|
||||
"""))
|
||||
|
||||
graph = ops.get_default_graph()
|
||||
a = graph.get_operation_by_name("A")
|
||||
b = graph.get_operation_by_name("B")
|
||||
c = graph.get_operation_by_name("C")
|
||||
d = graph.get_operation_by_name("D")
|
||||
|
||||
# Assert that the import process creates distinct tensors.
|
||||
self.assertNotEqual(a.outputs[0].name, a.outputs[1].name)
|
||||
self.assertNotEqual(b.outputs[0].name, b.outputs[1].name)
|
||||
self.assertNotEqual(a.outputs[0].name, b.outputs[0].name)
|
||||
self.assertNotEqual(a.outputs[0].name, b.outputs[1].name)
|
||||
self.assertNotEqual(a.outputs[1].name, b.outputs[0].name)
|
||||
self.assertNotEqual(a.outputs[1].name, b.outputs[1].name)
|
||||
|
||||
# Assert that the ops are connected according to the GraphDef topology.
|
||||
self.assertEqual(c.inputs[0], a.outputs[0])
|
||||
self.assertEqual(c.inputs[1], b.outputs[0])
|
||||
self.assertEqual(d.inputs[0], a.outputs[1])
|
||||
self.assertEqual(d.inputs[1], b.outputs[1])
|
||||
|
||||
# Check the types of the returned ops and tensors.
|
||||
self.assertEqual(a.type, "IntOutputFloatOutput")
|
||||
self.assertEqual(b.type, "ListOutput")
|
||||
self.assertEqual(c.type, "ListInput")
|
||||
self.assertEqual(d.type, "ListInput")
|
||||
self.assertEqual(a.outputs[0].dtype, dtypes.int32)
|
||||
self.assertEqual(a.outputs[1].dtype, dtypes.float32)
|
||||
self.assertEqual(b.outputs[0].dtype, dtypes.int32)
|
||||
self.assertEqual(b.outputs[1].dtype, dtypes.float32)
|
||||
|
||||
# Check the names of the returned ops.
|
||||
self.assertEqual(a.name, "A")
|
||||
self.assertEqual(b.name, "B")
|
||||
self.assertEqual(c.name, "C")
|
||||
self.assertEqual(d.name, "D")
|
||||
|
||||
# Check that the op_def is still available.
|
||||
self.assertNotEqual(None, a.op_def)
|
||||
|
||||
def testBasic(self):
|
||||
with ops.Graph().as_default():
|
||||
a, b, c, d = importer.import_graph_def(
|
||||
|
Loading…
x
Reference in New Issue
Block a user