Switch cancellation manager to wrap/unwrap instead of pimpl
This is yak shaving for another change, which would otherwise need a bunch of extra headers in its pybind rule (the whole eager runtime basically). I think it's generally a positive change, more consistency with our other C types. PiperOrigin-RevId: 355203151 Change-Id: Ic1c491c2e0bafc21fabed1bfc533cbbffafab399
This commit is contained in:
parent
937a4232b1
commit
cd8a31e3f7
@ -620,6 +620,7 @@ cc_library(
|
||||
"//tensorflow:internal",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/c:conversion_macros",
|
||||
"//tensorflow/core:framework",
|
||||
],
|
||||
)
|
||||
@ -709,6 +710,19 @@ cc_header_only_library(
|
||||
],
|
||||
)
|
||||
|
||||
cc_header_only_library(
|
||||
name = "tfe_cancellationmanager_internal_hdrs_only",
|
||||
extra_deps = [
|
||||
"@com_google_absl//absl/strings",
|
||||
],
|
||||
visibility = [
|
||||
"//tensorflow:internal",
|
||||
],
|
||||
deps = [
|
||||
":tfe_cancellation_manager_internal",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cuda_library(
|
||||
name = "c_api_test_util",
|
||||
testonly = 1,
|
||||
|
||||
@ -487,22 +487,22 @@ void TFE_ContextOptionsSetTfrt(TFE_ContextOptions* options, bool use_tfrt) {
|
||||
}
|
||||
|
||||
TFE_CancellationManager* TFE_NewCancellationManager() {
|
||||
return new TFE_CancellationManager;
|
||||
return tensorflow::wrap(new tensorflow::CancellationManager);
|
||||
}
|
||||
|
||||
void TFE_CancellationManagerStartCancel(
|
||||
TFE_CancellationManager* cancellation_manager) {
|
||||
cancellation_manager->cancellation_manager.StartCancel();
|
||||
tensorflow::unwrap(cancellation_manager)->StartCancel();
|
||||
}
|
||||
|
||||
bool TFE_CancellationManagerIsCancelled(
|
||||
TFE_CancellationManager* cancellation_manager) {
|
||||
return cancellation_manager->cancellation_manager.IsCancelled();
|
||||
return tensorflow::unwrap(cancellation_manager)->IsCancelled();
|
||||
}
|
||||
|
||||
void TFE_DeleteCancellationManager(
|
||||
TFE_CancellationManager* cancellation_manager) {
|
||||
delete cancellation_manager;
|
||||
delete tensorflow::unwrap(cancellation_manager);
|
||||
}
|
||||
|
||||
void TFE_OpSetCancellationManager(TFE_Op* op,
|
||||
@ -510,8 +510,7 @@ void TFE_OpSetCancellationManager(TFE_Op* op,
|
||||
TF_Status* status) {
|
||||
tensorflow::EagerOperation* operation =
|
||||
tensorflow::OperationFromInterface(tensorflow::unwrap(op));
|
||||
operation->SetCancellationManager(
|
||||
&cancellation_manager->cancellation_manager);
|
||||
operation->SetCancellationManager(tensorflow::unwrap(cancellation_manager));
|
||||
status->status = tensorflow::Status::OK();
|
||||
}
|
||||
|
||||
|
||||
@ -15,10 +15,17 @@ limitations under the License.
|
||||
#ifndef TENSORFLOW_C_EAGER_TFE_CANCELLATION_MANAGER_INTERNAL_H_
|
||||
#define TENSORFLOW_C_EAGER_TFE_CANCELLATION_MANAGER_INTERNAL_H_
|
||||
|
||||
#include "tensorflow/c/conversion_macros.h"
|
||||
#include "tensorflow/core/framework/cancellation.h"
|
||||
|
||||
struct TFE_CancellationManager {
|
||||
tensorflow::CancellationManager cancellation_manager;
|
||||
};
|
||||
struct TFE_CancellationManager;
|
||||
typedef struct TFE_CancellationManager TFE_CancellationManager;
|
||||
|
||||
namespace tensorflow {
|
||||
DEFINE_CONVERSION_FUNCTIONS(tensorflow::CancellationManager,
|
||||
TFE_CancellationManager);
|
||||
DEFINE_CONVERSION_FUNCTIONS(tensorflow::CancellationManager*,
|
||||
TFE_CancellationManager*);
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_C_EAGER_TFE_CANCELLATION_MANAGER_INTERNAL_H_
|
||||
|
||||
@ -29,6 +29,7 @@ limitations under the License.
|
||||
#include "tensorflow/c/eager/c_api_experimental.h"
|
||||
#include "tensorflow/c/eager/c_api_internal.h"
|
||||
#include "tensorflow/c/eager/dlpack.h"
|
||||
#include "tensorflow/c/eager/tfe_cancellation_manager_internal.h"
|
||||
#include "tensorflow/c/eager/tfe_context_internal.h"
|
||||
#include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
|
||||
#include "tensorflow/c/tf_status.h"
|
||||
@ -48,7 +49,7 @@ namespace py = pybind11;
|
||||
|
||||
PYBIND11_MAKE_OPAQUE(TFE_Executor);
|
||||
PYBIND11_MAKE_OPAQUE(TFE_ContextOptions);
|
||||
PYBIND11_MAKE_OPAQUE(TFE_CancellationManager);
|
||||
PYBIND11_MAKE_OPAQUE(tensorflow::CancellationManager);
|
||||
|
||||
PYBIND11_MAKE_OPAQUE(TFE_MonitoringCounter0);
|
||||
PYBIND11_MAKE_OPAQUE(TFE_MonitoringCounter1);
|
||||
@ -243,7 +244,7 @@ py::object TFE_Py_PackEagerTensors_wrapper(const py::handle& context,
|
||||
py::object TFE_Py_ExecuteCancelable_wrapper(
|
||||
const py::handle& context, const char* device_name, const char* op_name,
|
||||
const py::handle& inputs, const py::handle& attrs,
|
||||
TFE_CancellationManager* cancellation_manager,
|
||||
tensorflow::CancellationManager* cancellation_manager,
|
||||
const py::handle& num_outputs) {
|
||||
TFE_Context* ctx = tensorflow::InputTFE_Context(context);
|
||||
TFE_InputTensorHandles input_tensor_handles =
|
||||
@ -252,7 +253,7 @@ py::object TFE_Py_ExecuteCancelable_wrapper(
|
||||
InputTFE_OutputTensorHandles(num_outputs);
|
||||
tensorflow::Safe_TF_StatusPtr status = tensorflow::make_safe(TF_NewStatus());
|
||||
TFE_Py_ExecuteCancelable(ctx, device_name, op_name, &input_tensor_handles,
|
||||
attrs.ptr(), cancellation_manager,
|
||||
attrs.ptr(), tensorflow::wrap(cancellation_manager),
|
||||
&output_tensor_handles, status.get());
|
||||
|
||||
int output_len = output_tensor_handles.size();
|
||||
@ -509,7 +510,7 @@ PYBIND11_MODULE(_pywrap_tfe, m) {
|
||||
m, "TFE_MonitoringSampler1");
|
||||
py::class_<TFE_MonitoringSampler2> TFE_MonitoringSampler2_class(
|
||||
m, "TFE_MonitoringSampler2");
|
||||
py::class_<TFE_CancellationManager> TFE_CancellationManager_class(
|
||||
py::class_<tensorflow::CancellationManager> TFE_CancellationManager_class(
|
||||
m, "TFE_CancellationManager");
|
||||
|
||||
py::class_<TF_DeviceList> TF_DeviceList_class(m, "TF_DeviceList");
|
||||
@ -855,7 +856,7 @@ PYBIND11_MODULE(_pywrap_tfe, m) {
|
||||
"TFE_Py_ExecuteCancelable",
|
||||
[](const py::handle& context, const char* device_name,
|
||||
const char* op_name, const py::handle& inputs, const py::handle& attrs,
|
||||
TFE_CancellationManager& cancellation_manager,
|
||||
tensorflow::CancellationManager& cancellation_manager,
|
||||
const py::handle& num_outputs) {
|
||||
return tensorflow::TFE_Py_ExecuteCancelable_wrapper(
|
||||
context, device_name, op_name, inputs, attrs.ptr(),
|
||||
@ -1353,14 +1354,16 @@ PYBIND11_MODULE(_pywrap_tfe, m) {
|
||||
py::return_value_policy::reference);
|
||||
|
||||
// TFE_CancellationManager Logic
|
||||
m.def("TFE_NewCancellationManager", &TFE_NewCancellationManager,
|
||||
py::return_value_policy::reference);
|
||||
m.def(
|
||||
"TFE_NewCancellationManager",
|
||||
[]() { return new tensorflow::CancellationManager(); },
|
||||
py::return_value_policy::reference);
|
||||
m.def("TFE_CancellationManagerIsCancelled",
|
||||
&TFE_CancellationManagerIsCancelled);
|
||||
&tensorflow::CancellationManager::IsCancelled);
|
||||
m.def("TFE_CancellationManagerStartCancel",
|
||||
&TFE_CancellationManagerStartCancel);
|
||||
m.def("TFE_DeleteCancellationManager", &TFE_DeleteCancellationManager,
|
||||
py::return_value_policy::reference);
|
||||
&tensorflow::CancellationManager::StartCancel);
|
||||
m.def("TFE_DeleteCancellationManager",
|
||||
[](tensorflow::CancellationManager* cm) { delete cm; });
|
||||
|
||||
m.def("TFE_ClearScalarCache", &tensorflow::TFE_ClearScalarCache);
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user