Use C API to implement Operation.control_inputs
Operation.control_inputs must return Python Operation instances, but C API returns TF_Operation*s. After retrieving TF_Operation*s, we get the Python Operation object from the graph using TF_Operation's name. There is currently just one test for this functionality, namely ControlDependenciesTest.testBasic. We can't enable other tests because C API does not accept unregistered operations. In following changes we will deal with this problem thereby enabling more tests. PiperOrigin-RevId: 161106091
This commit is contained in:
parent
06d25a7e62
commit
cab048ecde
@ -68,6 +68,27 @@ tensorflow::ImportNumpy();
|
||||
$result = PyUnicode_FromString($1);
|
||||
}
|
||||
|
||||
// We use TF_OperationGetControlInputs_wrapper instead of
|
||||
// TF_OperationGetControlInputs
|
||||
%ignore TF_OperationGetControlInputs;
|
||||
%unignore TF_OperationGetControlInputs_wrapper;
|
||||
// See comment for "%noexception TF_SessionRun_wrapper;"
|
||||
%noexception TF_OperationGetControlInputs_wrapper;
|
||||
|
||||
// Build a Python list of TF_Operation* and return it.
|
||||
%typemap(out) std::vector<TF_Operation*> tensorflow::TF_OperationGetControlInputs_wrapper {
|
||||
$result = PyList_New($1.size());
|
||||
if (!$result) {
|
||||
SWIG_exception_fail(SWIG_MemoryError, "$symname: couldn't create list");
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < $1.size(); ++i) {
|
||||
PyList_SET_ITEM($result, i, SWIG_NewPointerObj(
|
||||
$1[i], SWIGTYPE_p_TF_Operation, 0));
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
// BEGIN TYPEMAPS FOR tensorflow::TF_Run_wrapper()
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
@ -766,4 +766,12 @@ void TF_SessionPRun_wrapper(TF_Session* session, const char* handle,
|
||||
ClearDecrefCache();
|
||||
}
|
||||
|
||||
std::vector<TF_Operation*> TF_OperationGetControlInputs_wrapper(
|
||||
TF_Operation* oper) {
|
||||
std::vector<TF_Operation*> control_inputs(TF_OperationNumControlInputs(oper));
|
||||
TF_OperationGetControlInputs(oper, control_inputs.data(),
|
||||
control_inputs.size());
|
||||
return control_inputs;
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
||||
|
@ -158,6 +158,11 @@ void TF_SessionPRun_wrapper(TF_Session* session, const char* handle,
|
||||
TF_Status* out_status,
|
||||
std::vector<PyObject*>* py_outputs);
|
||||
|
||||
// Retrieves control inputs of this operation.
|
||||
// control_inputs should be empty.
|
||||
std::vector<TF_Operation*> TF_OperationGetControlInputs_wrapper(
|
||||
TF_Operation* oper);
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_PYTHON_CLIENT_TF_SESSION_HELPER_H_
|
||||
|
@ -1583,6 +1583,13 @@ class Operation(object):
|
||||
A list of `Operation` objects.
|
||||
|
||||
"""
|
||||
if _USE_C_API:
|
||||
control_c_ops = c_api.TF_OperationGetControlInputs_wrapper(self._c_op)
|
||||
# pylint: disable=protected-access
|
||||
return [self.graph._get_operation_by_name_unsafe(
|
||||
c_api.TF_OperationName(c_op)) for c_op in control_c_ops]
|
||||
# pylint: enable=protected-access
|
||||
else:
|
||||
return self._control_inputs
|
||||
|
||||
@property
|
||||
@ -2781,6 +2788,29 @@ class Graph(object):
|
||||
% type(name).__name__)
|
||||
return self.as_graph_element(name, allow_tensor=False, allow_operation=True)
|
||||
|
||||
def _get_operation_by_name_unsafe(self, name):
|
||||
"""Returns the `Operation` with the given `name`.
|
||||
|
||||
This is a internal unsafe version of get_operation_by_name. It skips many
|
||||
checks and does not have user friedly error messages but runs considerably
|
||||
faster. This method may be called concurrently from multiple threads.
|
||||
|
||||
Args:
|
||||
name: The name of the `Operation` to return.
|
||||
|
||||
Returns:
|
||||
The `Operation` with the given `name`.
|
||||
|
||||
Raises:
|
||||
KeyError: If `name` does not correspond to an operation in this graph.
|
||||
"""
|
||||
|
||||
if self._finalized:
|
||||
return self._nodes_by_name[name]
|
||||
|
||||
with self._lock:
|
||||
return self._nodes_by_name[name]
|
||||
|
||||
def get_tensor_by_name(self, name):
|
||||
"""Returns the `Tensor` with the given `name`.
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user