Use C API to implement Operation.control_inputs

Operation.control_inputs must return Python Operation instances,
but C API returns TF_Operation*s. After retrieving TF_Operation*s,
we get the Python Operation object from the graph using TF_Operation's
name.

There is currently just one test for this functionality, namely
ControlDependenciesTest.testBasic. We can't enable other tests
because C API does not accept unregistered operations.
In following changes we will deal with this problem thereby enabling
more tests.

PiperOrigin-RevId: 161106091
This commit is contained in:
A. Unique TensorFlower 2017-07-06 11:26:54 -07:00 committed by TensorFlower Gardener
parent 06d25a7e62
commit cab048ecde
4 changed files with 65 additions and 1 deletions

View File

@ -68,6 +68,27 @@ tensorflow::ImportNumpy();
$result = PyUnicode_FromString($1); $result = PyUnicode_FromString($1);
} }
// We use TF_OperationGetControlInputs_wrapper instead of
// TF_OperationGetControlInputs
%ignore TF_OperationGetControlInputs;
%unignore TF_OperationGetControlInputs_wrapper;
// See comment for "%noexception TF_SessionRun_wrapper;"
%noexception TF_OperationGetControlInputs_wrapper;
// Build a Python list of TF_Operation* and return it.
%typemap(out) std::vector<TF_Operation*> tensorflow::TF_OperationGetControlInputs_wrapper {
$result = PyList_New($1.size());
if (!$result) {
SWIG_exception_fail(SWIG_MemoryError, "$symname: couldn't create list");
}
for (size_t i = 0; i < $1.size(); ++i) {
PyList_SET_ITEM($result, i, SWIG_NewPointerObj(
$1[i], SWIGTYPE_p_TF_Operation, 0));
}
}
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
// BEGIN TYPEMAPS FOR tensorflow::TF_Run_wrapper() // BEGIN TYPEMAPS FOR tensorflow::TF_Run_wrapper()
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////

View File

@ -766,4 +766,12 @@ void TF_SessionPRun_wrapper(TF_Session* session, const char* handle,
ClearDecrefCache(); ClearDecrefCache();
} }
std::vector<TF_Operation*> TF_OperationGetControlInputs_wrapper(
TF_Operation* oper) {
std::vector<TF_Operation*> control_inputs(TF_OperationNumControlInputs(oper));
TF_OperationGetControlInputs(oper, control_inputs.data(),
control_inputs.size());
return control_inputs;
}
} // namespace tensorflow } // namespace tensorflow

View File

@ -158,6 +158,11 @@ void TF_SessionPRun_wrapper(TF_Session* session, const char* handle,
TF_Status* out_status, TF_Status* out_status,
std::vector<PyObject*>* py_outputs); std::vector<PyObject*>* py_outputs);
// Retrieves control inputs of this operation.
// control_inputs should be empty.
std::vector<TF_Operation*> TF_OperationGetControlInputs_wrapper(
TF_Operation* oper);
} // namespace tensorflow } // namespace tensorflow
#endif // TENSORFLOW_PYTHON_CLIENT_TF_SESSION_HELPER_H_ #endif // TENSORFLOW_PYTHON_CLIENT_TF_SESSION_HELPER_H_

View File

@ -1583,7 +1583,14 @@ class Operation(object):
A list of `Operation` objects. A list of `Operation` objects.
""" """
return self._control_inputs if _USE_C_API:
control_c_ops = c_api.TF_OperationGetControlInputs_wrapper(self._c_op)
# pylint: disable=protected-access
return [self.graph._get_operation_by_name_unsafe(
c_api.TF_OperationName(c_op)) for c_op in control_c_ops]
# pylint: enable=protected-access
else:
return self._control_inputs
@property @property
def type(self): def type(self):
@ -2781,6 +2788,29 @@ class Graph(object):
% type(name).__name__) % type(name).__name__)
return self.as_graph_element(name, allow_tensor=False, allow_operation=True) return self.as_graph_element(name, allow_tensor=False, allow_operation=True)
def _get_operation_by_name_unsafe(self, name):
"""Returns the `Operation` with the given `name`.
This is a internal unsafe version of get_operation_by_name. It skips many
checks and does not have user friedly error messages but runs considerably
faster. This method may be called concurrently from multiple threads.
Args:
name: The name of the `Operation` to return.
Returns:
The `Operation` with the given `name`.
Raises:
KeyError: If `name` does not correspond to an operation in this graph.
"""
if self._finalized:
return self._nodes_by_name[name]
with self._lock:
return self._nodes_by_name[name]
def get_tensor_by_name(self, name): def get_tensor_by_name(self, name):
"""Returns the `Tensor` with the given `name`. """Returns the `Tensor` with the given `name`.