diff --git a/tensorflow/python/eager/pywrap_tfe.h b/tensorflow/python/eager/pywrap_tfe.h index 41ca3c728d3..aed41d5abcd 100755 --- a/tensorflow/python/eager/pywrap_tfe.h +++ b/tensorflow/python/eager/pywrap_tfe.h @@ -372,4 +372,9 @@ PyObject* TFE_Py_SetEagerContext(PyObject* python_context); // some point. PyObject* GetPyEagerContext(); +// These are exposed since there is SWIG code that calls these. +// Returns a pre-allocated status if it exists. +TF_Status* GetStatus(); +// Returns the pre-allocated status to the code. +void ReturnStatus(TF_Status* status); #endif // TENSORFLOW_PYTHON_EAGER_PYWRAP_TFE_H_ diff --git a/tensorflow/python/eager/pywrap_tfe_src.cc b/tensorflow/python/eager/pywrap_tfe_src.cc index 8258f0cc97c..173ee414585 100644 --- a/tensorflow/python/eager/pywrap_tfe_src.cc +++ b/tensorflow/python/eager/pywrap_tfe_src.cc @@ -46,6 +46,8 @@ namespace { thread_local std::unique_ptr thread_local_eager_operation = // NOLINT nullptr; +thread_local std::unique_ptr thread_local_tf_status = // NOLINT + nullptr; TFE_Op* ReleaseThreadLocalOp() { if (thread_local_eager_operation == nullptr) { @@ -54,21 +56,29 @@ TFE_Op* ReleaseThreadLocalOp() { return thread_local_eager_operation.release(); } -TFE_Op* CreateOrResetOp(TFE_Context* ctx, const char* op_or_function_name, - TF_Status* status, TFE_Op* op_to_reset) { - if (op_to_reset) { - TFE_OpReset(ctx, op_or_function_name, status, op_to_reset); - return op_to_reset; +TFE_Op* GetOp(TFE_Context* ctx, const char* op_or_function_name, + TF_Status* status) { + TFE_Op* maybe_op = ReleaseThreadLocalOp(); + if (maybe_op) { + TFE_OpReset(ctx, op_or_function_name, status, maybe_op); + return maybe_op; } else { return TFE_NewOp(ctx, op_or_function_name, status); } } -void ClearAndReturnThreadLocalOp(TFE_Op* object) { +void ReturnOp(TFE_Op* object) { object->Clear(); thread_local_eager_operation.reset(object); } +TF_Status* ReleaseThreadLocalStatus() { + if (thread_local_tf_status == nullptr) { + return nullptr; + } + return thread_local_tf_status.release(); +} + struct InputInfo { InputInfo(int i, bool is_list) : i(i), is_list(is_list) {} @@ -746,6 +756,21 @@ tensorflow::int64 _uid GUARDED_BY(_uid_mutex) = 0; } // namespace +TF_Status* GetStatus() { + TF_Status* maybe_status = ReleaseThreadLocalStatus(); + if (maybe_status) { + TF_SetStatus(maybe_status, TF_OK, ""); + return maybe_status; + } else { + return TF_NewStatus(); + } +} + +void ReturnStatus(TF_Status* status) { + TF_SetStatus(status, TF_OK, ""); + thread_local_tf_status.reset(status); +} + void TFE_Py_Execute(TFE_Context* ctx, const char* device_name, const char* op_name, TFE_InputTensorHandles* inputs, PyObject* attrs, TFE_OutputTensorHandles* outputs, @@ -761,10 +786,8 @@ void TFE_Py_ExecuteCancelable(TFE_Context* ctx, const char* device_name, TFE_CancellationManager* cancellation_manager, TFE_OutputTensorHandles* outputs, TF_Status* out_status) { - TFE_Op* op = - CreateOrResetOp(ctx, op_name, out_status, ReleaseThreadLocalOp()); - auto cleaner = - tensorflow::gtl::MakeCleanup([op] { ClearAndReturnThreadLocalOp(op); }); + TFE_Op* op = GetOp(ctx, op_name, out_status); + auto cleaner = tensorflow::gtl::MakeCleanup([op] { ReturnOp(op); }); if (!out_status->status.ok()) return; TFE_OpSetDevice(op, device_name, out_status); if (out_status->status.ok()) { @@ -3359,7 +3382,7 @@ PyObject* TFE_Py_FastPathExecute_C(PyObject*, PyObject* args) { op_exec_info.run_callbacks = op_exec_info.run_gradient_callback || op_exec_info.run_post_exec_callbacks; - TF_Status* status = TF_NewStatus(); + TF_Status* status = GetStatus(); const char* op_name = TFE_GetPythonString(op_exec_info.op_name); if (op_name == nullptr) { PyErr_SetString(PyExc_TypeError, @@ -3369,11 +3392,10 @@ PyObject* TFE_Py_FastPathExecute_C(PyObject*, PyObject* args) { return nullptr; } - TFE_Op* op = CreateOrResetOp(op_exec_info.ctx, op_name, status, - ReleaseThreadLocalOp()); + TFE_Op* op = GetOp(op_exec_info.ctx, op_name, status); auto cleaner = tensorflow::gtl::MakeCleanup([status, op] { - TF_DeleteStatus(status); - ClearAndReturnThreadLocalOp(op); + ReturnStatus(status); + ReturnOp(op); }); if (MaybeRaiseExceptionFromTFStatus(status, nullptr)) { return nullptr; diff --git a/tensorflow/python/pywrap_tfe.i b/tensorflow/python/pywrap_tfe.i index 3fb50f7d157..7a36452396a 100755 --- a/tensorflow/python/pywrap_tfe.i +++ b/tensorflow/python/pywrap_tfe.i @@ -425,11 +425,11 @@ static PyObject* TFE_ClearScalarCache(); // Create new Status object. %typemap(in, numinputs=0) TF_Status *out_status { - $1 = TF_NewStatus(); + $1 = GetStatus(); } %typemap(freearg) (TF_Status* out_status) { - TF_DeleteStatus($1); + ReturnStatus($1); } %typemap(argout) (TFE_OutputTensorHandles* outputs, TF_Status* out_status) {