Create Python Operations for the TF_Operations created by import_graph_def.

This change also introduces Python functionality for iterating through
every TF_Operation in the graph and every newly-added TF_Operation via
TF_GraphNextOperation.

PiperOrigin-RevId: 176694180
This commit is contained in:
Skye Wanderman-Milne 2017-11-22 13:26:25 -08:00 committed by TensorFlower Gardener
parent bb287e33f7
commit 8067aa0862
4 changed files with 118 additions and 13 deletions

View File

@ -315,6 +315,24 @@ tensorflow::ImportNumpy();
$2 = inputs.size();
}
// Typemaps for TF_GraphNextOperation().
%typemap(in) size_t* pos (size_t pos) {
pos = PyLong_AsUnsignedLong($input);
$1 = &pos;
}
// Returns a (TF_Operation*, int pos) tuple.
%typemap(argout) size_t* pos {
PyObject* new_result = PyTuple_New(2);
if (!new_result) {
SWIG_exception_fail(SWIG_MemoryError, "$symname: couldn't create tuple");
}
// Steals $result reference
PyTuple_SET_ITEM(new_result, 0, $result);
PyTuple_SET_ITEM(new_result, 1, PyLong_FromSize_t(*$1));
$result = new_result;
}
// TODO(skyewm): SWIG emits a warning for the const char* in TF_WhileParams,
// skip for now
%ignore TF_WhileParams;

View File

@ -110,3 +110,41 @@ def tf_output(c_op, index):
ret.oper = c_op
ret.index = index
return ret
def tf_operations(graph):
"""Generator that yields every TF_Operation in `graph`.
Args:
graph: Graph
Yields:
wrapped TF_Operation
"""
# pylint: disable=protected-access
pos = 0
c_op, pos = c_api.TF_GraphNextOperation(graph._c_graph, pos)
while c_op is not None:
yield c_op
c_op, pos = c_api.TF_GraphNextOperation(graph._c_graph, pos)
# pylint: enable=protected-access
def new_tf_operations(graph):
"""Generator that yields newly-added TF_Operations in `graph`.
Specifically, yields TF_Operations that don't have associated Operations in
`graph`. This is useful for processing nodes added by the C API.
Args:
graph: Graph
Yields:
wrapped TF_Operation
"""
# TODO(b/69679162): do this more efficiently
for c_op in tf_operations(graph):
try:
graph._get_operation_by_tf_operation(c_op) # pylint: disable=protected-access
except KeyError:
yield c_op

View File

@ -194,6 +194,14 @@ def _FindAttrInOpDef(attr_name, op_def):
return None
def _ProcessNewOps(graph):
"""Processes the newly-added TF_Operations in `graph`."""
for c_op in c_api_util.new_tf_operations(graph):
graph._create_op_from_tf_operation(c_op) # pylint: disable=protected-access
# TODO(skyewm): colocation logic
@deprecated_args(None, 'Please file an issue at '
'https://github.com/tensorflow/tensorflow/issues if you depend'
' on this feature.',
@ -257,11 +265,13 @@ def import_graph_def(graph_def, input_map=None, return_elements=None,
if graph._c_graph: # pylint: disable=protected-access
scoped_options = c_api_util.ScopedTFImportGraphDefOptions()
with errors.raise_exception_on_not_ok_status() as status:
with c_api_util.tf_buffer(graph_def.SerializeToString()) as serialized:
with c_api_util.tf_buffer(graph_def.SerializeToString()) as serialized:
with errors.raise_exception_on_not_ok_status() as status:
c_api.TF_GraphImportGraphDefWithResults(
graph._c_graph, serialized, scoped_options.options, status) # pylint: disable=protected-access
_ProcessNewOps(graph)
if return_elements is not None:
raise ValueError('return_elements not yet implemented with C API')
return None

View File

@ -65,19 +65,58 @@ class ImportGraphDefTest(test.TestCase):
importer.import_graph_def(
self._MakeGraphDef("""
node { name: 'A' op: 'IntOutputFloatOutput' }
node { name: 'B' op: 'ListOutput'
attr { key: 'T'
value { list { type: DT_INT32 type: DT_FLOAT } } } }
node { name: 'C' op: 'ListInput'
attr { key: 'N' value { i: 2 } }
attr { key: 'T' value { type: DT_INT32 } }
input: 'A:0' input: 'B:0' }
node { name: 'D' op: 'ListInput'
attr { key: 'N' value { i: 2 } }
attr { key: 'T' value { type: DT_FLOAT } }
input: 'A:1' input: 'B:1' }
node { name: 'B' op: 'ListOutput'
attr { key: 'T'
value { list { type: DT_INT32 type: DT_FLOAT } } } }
node { name: 'C' op: 'ListInput'
attr { key: 'N' value { i: 2 } }
attr { key: 'T' value { type: DT_INT32 } }
input: 'A:0' input: 'B:0' }
node { name: 'D' op: 'ListInput'
attr { key: 'N' value { i: 2 } }
attr { key: 'T' value { type: DT_FLOAT } }
input: 'A:1' input: 'B:1' }
"""))
graph = ops.get_default_graph()
a = graph.get_operation_by_name("A")
b = graph.get_operation_by_name("B")
c = graph.get_operation_by_name("C")
d = graph.get_operation_by_name("D")
# Assert that the import process creates distinct tensors.
self.assertNotEqual(a.outputs[0].name, a.outputs[1].name)
self.assertNotEqual(b.outputs[0].name, b.outputs[1].name)
self.assertNotEqual(a.outputs[0].name, b.outputs[0].name)
self.assertNotEqual(a.outputs[0].name, b.outputs[1].name)
self.assertNotEqual(a.outputs[1].name, b.outputs[0].name)
self.assertNotEqual(a.outputs[1].name, b.outputs[1].name)
# Assert that the ops are connected according to the GraphDef topology.
self.assertEqual(c.inputs[0], a.outputs[0])
self.assertEqual(c.inputs[1], b.outputs[0])
self.assertEqual(d.inputs[0], a.outputs[1])
self.assertEqual(d.inputs[1], b.outputs[1])
# Check the types of the returned ops and tensors.
self.assertEqual(a.type, "IntOutputFloatOutput")
self.assertEqual(b.type, "ListOutput")
self.assertEqual(c.type, "ListInput")
self.assertEqual(d.type, "ListInput")
self.assertEqual(a.outputs[0].dtype, dtypes.int32)
self.assertEqual(a.outputs[1].dtype, dtypes.float32)
self.assertEqual(b.outputs[0].dtype, dtypes.int32)
self.assertEqual(b.outputs[1].dtype, dtypes.float32)
# Check the names of the returned ops.
self.assertEqual(a.name, "A")
self.assertEqual(b.name, "B")
self.assertEqual(c.name, "C")
self.assertEqual(d.name, "D")
# Check that the op_def is still available.
self.assertNotEqual(None, a.op_def)
def testBasic(self):
with ops.Graph().as_default():
a, b, c, d = importer.import_graph_def(