diff --git a/tensorflow/python/client/tf_session.i b/tensorflow/python/client/tf_session.i index 284c98639d7..243c870e0d7 100644 --- a/tensorflow/python/client/tf_session.i +++ b/tensorflow/python/client/tf_session.i @@ -68,6 +68,27 @@ tensorflow::ImportNumpy(); $result = PyUnicode_FromString($1); } +// We use TF_OperationGetControlInputs_wrapper instead of +// TF_OperationGetControlInputs +%ignore TF_OperationGetControlInputs; +%unignore TF_OperationGetControlInputs_wrapper; +// See comment for "%noexception TF_SessionRun_wrapper;" +%noexception TF_OperationGetControlInputs_wrapper; + +// Build a Python list of TF_Operation* and return it. +%typemap(out) std::vector tensorflow::TF_OperationGetControlInputs_wrapper { + $result = PyList_New($1.size()); + if (!$result) { + SWIG_exception_fail(SWIG_MemoryError, "$symname: couldn't create list"); + } + + for (size_t i = 0; i < $1.size(); ++i) { + PyList_SET_ITEM($result, i, SWIG_NewPointerObj( + $1[i], SWIGTYPE_p_TF_Operation, 0)); + } +} + + //////////////////////////////////////////////////////////////////////////////// // BEGIN TYPEMAPS FOR tensorflow::TF_Run_wrapper() //////////////////////////////////////////////////////////////////////////////// diff --git a/tensorflow/python/client/tf_session_helper.cc b/tensorflow/python/client/tf_session_helper.cc index 7ebb1a7fe4c..0e89ae2426d 100644 --- a/tensorflow/python/client/tf_session_helper.cc +++ b/tensorflow/python/client/tf_session_helper.cc @@ -766,4 +766,12 @@ void TF_SessionPRun_wrapper(TF_Session* session, const char* handle, ClearDecrefCache(); } +std::vector TF_OperationGetControlInputs_wrapper( + TF_Operation* oper) { + std::vector control_inputs(TF_OperationNumControlInputs(oper)); + TF_OperationGetControlInputs(oper, control_inputs.data(), + control_inputs.size()); + return control_inputs; +} + } // namespace tensorflow diff --git a/tensorflow/python/client/tf_session_helper.h b/tensorflow/python/client/tf_session_helper.h index 9937b6aeeb3..f1f70a9a1d2 100644 --- a/tensorflow/python/client/tf_session_helper.h +++ b/tensorflow/python/client/tf_session_helper.h @@ -158,6 +158,11 @@ void TF_SessionPRun_wrapper(TF_Session* session, const char* handle, TF_Status* out_status, std::vector* py_outputs); +// Retrieves control inputs of this operation. +// control_inputs should be empty. +std::vector TF_OperationGetControlInputs_wrapper( + TF_Operation* oper); + } // namespace tensorflow #endif // TENSORFLOW_PYTHON_CLIENT_TF_SESSION_HELPER_H_ diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py index f52127f1190..a8c2930cbe9 100644 --- a/tensorflow/python/framework/ops.py +++ b/tensorflow/python/framework/ops.py @@ -1583,7 +1583,14 @@ class Operation(object): A list of `Operation` objects. """ - return self._control_inputs + if _USE_C_API: + control_c_ops = c_api.TF_OperationGetControlInputs_wrapper(self._c_op) + # pylint: disable=protected-access + return [self.graph._get_operation_by_name_unsafe( + c_api.TF_OperationName(c_op)) for c_op in control_c_ops] + # pylint: enable=protected-access + else: + return self._control_inputs @property def type(self): @@ -2781,6 +2788,29 @@ class Graph(object): % type(name).__name__) return self.as_graph_element(name, allow_tensor=False, allow_operation=True) + def _get_operation_by_name_unsafe(self, name): + """Returns the `Operation` with the given `name`. + + This is a internal unsafe version of get_operation_by_name. It skips many + checks and does not have user friedly error messages but runs considerably + faster. This method may be called concurrently from multiple threads. + + Args: + name: The name of the `Operation` to return. + + Returns: + The `Operation` with the given `name`. + + Raises: + KeyError: If `name` does not correspond to an operation in this graph. + """ + + if self._finalized: + return self._nodes_by_name[name] + + with self._lock: + return self._nodes_by_name[name] + def get_tensor_by_name(self, name): """Returns the `Tensor` with the given `name`.