Switch cancellation manager to wrap/unwrap instead of pimpl

This is yak shaving for another change, which would otherwise need a bunch of extra headers in its pybind rule (the whole eager runtime basically). I think it's generally a positive change, more consistency with our other C types.

PiperOrigin-RevId: 355203151
Change-Id: Ic1c491c2e0bafc21fabed1bfc533cbbffafab399
This commit is contained in:
Allen Lavoie 2021-02-02 10:47:52 -08:00 committed by TensorFlower Gardener
parent 937a4232b1
commit cd8a31e3f7
4 changed files with 43 additions and 20 deletions

View File

@ -620,6 +620,7 @@ cc_library(
"//tensorflow:internal",
],
deps = [
"//tensorflow/c:conversion_macros",
"//tensorflow/core:framework",
],
)
@ -709,6 +710,19 @@ cc_header_only_library(
],
)
cc_header_only_library(
name = "tfe_cancellationmanager_internal_hdrs_only",
extra_deps = [
"@com_google_absl//absl/strings",
],
visibility = [
"//tensorflow:internal",
],
deps = [
":tfe_cancellation_manager_internal",
],
)
tf_cuda_library(
name = "c_api_test_util",
testonly = 1,

View File

@ -487,22 +487,22 @@ void TFE_ContextOptionsSetTfrt(TFE_ContextOptions* options, bool use_tfrt) {
}
TFE_CancellationManager* TFE_NewCancellationManager() {
return new TFE_CancellationManager;
return tensorflow::wrap(new tensorflow::CancellationManager);
}
void TFE_CancellationManagerStartCancel(
TFE_CancellationManager* cancellation_manager) {
cancellation_manager->cancellation_manager.StartCancel();
tensorflow::unwrap(cancellation_manager)->StartCancel();
}
bool TFE_CancellationManagerIsCancelled(
TFE_CancellationManager* cancellation_manager) {
return cancellation_manager->cancellation_manager.IsCancelled();
return tensorflow::unwrap(cancellation_manager)->IsCancelled();
}
void TFE_DeleteCancellationManager(
TFE_CancellationManager* cancellation_manager) {
delete cancellation_manager;
delete tensorflow::unwrap(cancellation_manager);
}
void TFE_OpSetCancellationManager(TFE_Op* op,
@ -510,8 +510,7 @@ void TFE_OpSetCancellationManager(TFE_Op* op,
TF_Status* status) {
tensorflow::EagerOperation* operation =
tensorflow::OperationFromInterface(tensorflow::unwrap(op));
operation->SetCancellationManager(
&cancellation_manager->cancellation_manager);
operation->SetCancellationManager(tensorflow::unwrap(cancellation_manager));
status->status = tensorflow::Status::OK();
}

View File

@ -15,10 +15,17 @@ limitations under the License.
#ifndef TENSORFLOW_C_EAGER_TFE_CANCELLATION_MANAGER_INTERNAL_H_
#define TENSORFLOW_C_EAGER_TFE_CANCELLATION_MANAGER_INTERNAL_H_
#include "tensorflow/c/conversion_macros.h"
#include "tensorflow/core/framework/cancellation.h"
struct TFE_CancellationManager {
tensorflow::CancellationManager cancellation_manager;
};
struct TFE_CancellationManager;
typedef struct TFE_CancellationManager TFE_CancellationManager;
namespace tensorflow {
DEFINE_CONVERSION_FUNCTIONS(tensorflow::CancellationManager,
TFE_CancellationManager);
DEFINE_CONVERSION_FUNCTIONS(tensorflow::CancellationManager*,
TFE_CancellationManager*);
} // namespace tensorflow
#endif // TENSORFLOW_C_EAGER_TFE_CANCELLATION_MANAGER_INTERNAL_H_

View File

@ -29,6 +29,7 @@ limitations under the License.
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/c_api_internal.h"
#include "tensorflow/c/eager/dlpack.h"
#include "tensorflow/c/eager/tfe_cancellation_manager_internal.h"
#include "tensorflow/c/eager/tfe_context_internal.h"
#include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
#include "tensorflow/c/tf_status.h"
@ -48,7 +49,7 @@ namespace py = pybind11;
PYBIND11_MAKE_OPAQUE(TFE_Executor);
PYBIND11_MAKE_OPAQUE(TFE_ContextOptions);
PYBIND11_MAKE_OPAQUE(TFE_CancellationManager);
PYBIND11_MAKE_OPAQUE(tensorflow::CancellationManager);
PYBIND11_MAKE_OPAQUE(TFE_MonitoringCounter0);
PYBIND11_MAKE_OPAQUE(TFE_MonitoringCounter1);
@ -243,7 +244,7 @@ py::object TFE_Py_PackEagerTensors_wrapper(const py::handle& context,
py::object TFE_Py_ExecuteCancelable_wrapper(
const py::handle& context, const char* device_name, const char* op_name,
const py::handle& inputs, const py::handle& attrs,
TFE_CancellationManager* cancellation_manager,
tensorflow::CancellationManager* cancellation_manager,
const py::handle& num_outputs) {
TFE_Context* ctx = tensorflow::InputTFE_Context(context);
TFE_InputTensorHandles input_tensor_handles =
@ -252,7 +253,7 @@ py::object TFE_Py_ExecuteCancelable_wrapper(
InputTFE_OutputTensorHandles(num_outputs);
tensorflow::Safe_TF_StatusPtr status = tensorflow::make_safe(TF_NewStatus());
TFE_Py_ExecuteCancelable(ctx, device_name, op_name, &input_tensor_handles,
attrs.ptr(), cancellation_manager,
attrs.ptr(), tensorflow::wrap(cancellation_manager),
&output_tensor_handles, status.get());
int output_len = output_tensor_handles.size();
@ -509,7 +510,7 @@ PYBIND11_MODULE(_pywrap_tfe, m) {
m, "TFE_MonitoringSampler1");
py::class_<TFE_MonitoringSampler2> TFE_MonitoringSampler2_class(
m, "TFE_MonitoringSampler2");
py::class_<TFE_CancellationManager> TFE_CancellationManager_class(
py::class_<tensorflow::CancellationManager> TFE_CancellationManager_class(
m, "TFE_CancellationManager");
py::class_<TF_DeviceList> TF_DeviceList_class(m, "TF_DeviceList");
@ -855,7 +856,7 @@ PYBIND11_MODULE(_pywrap_tfe, m) {
"TFE_Py_ExecuteCancelable",
[](const py::handle& context, const char* device_name,
const char* op_name, const py::handle& inputs, const py::handle& attrs,
TFE_CancellationManager& cancellation_manager,
tensorflow::CancellationManager& cancellation_manager,
const py::handle& num_outputs) {
return tensorflow::TFE_Py_ExecuteCancelable_wrapper(
context, device_name, op_name, inputs, attrs.ptr(),
@ -1353,14 +1354,16 @@ PYBIND11_MODULE(_pywrap_tfe, m) {
py::return_value_policy::reference);
// TFE_CancellationManager Logic
m.def("TFE_NewCancellationManager", &TFE_NewCancellationManager,
py::return_value_policy::reference);
m.def(
"TFE_NewCancellationManager",
[]() { return new tensorflow::CancellationManager(); },
py::return_value_policy::reference);
m.def("TFE_CancellationManagerIsCancelled",
&TFE_CancellationManagerIsCancelled);
&tensorflow::CancellationManager::IsCancelled);
m.def("TFE_CancellationManagerStartCancel",
&TFE_CancellationManagerStartCancel);
m.def("TFE_DeleteCancellationManager", &TFE_DeleteCancellationManager,
py::return_value_policy::reference);
&tensorflow::CancellationManager::StartCancel);
m.def("TFE_DeleteCancellationManager",
[](tensorflow::CancellationManager* cm) { delete cm; });
m.def("TFE_ClearScalarCache", &tensorflow::TFE_ClearScalarCache);