Stop allocating a new tf_status on every op execution.
This is similar to the CL which stopped allocating a TFE_Op for every new op execution. Benchmark benchmark_tf_identity goes from ~2.25 to ~2.16.

PiperOrigin-RevId: 272721880
This commit is contained in:
parent
8ce19dbfef
commit
172aa31854
tensorflow/python
@ -372,4 +372,9 @@ PyObject* TFE_Py_SetEagerContext(PyObject* python_context);
|
||||
// some point.
|
||||
PyObject* GetPyEagerContext();
|
||||
|
||||
// These are exposed since there is SWIG code that calls these.

// Returns a pre-allocated status if it exists, otherwise allocates a new one.
// NOTE(review): the cache appears to be thread-local (see pywrap_tfe.cc);
// the caller should hand the status back via ReturnStatus() — confirm.
TF_Status* GetStatus();

// Returns the pre-allocated status to the cache so it can be reused by a
// later GetStatus() call.
void ReturnStatus(TF_Status* status);

#endif  // TENSORFLOW_PYTHON_EAGER_PYWRAP_TFE_H_
|
||||
|
@ -46,6 +46,8 @@ namespace {
|
||||
|
||||
// Per-thread caches that let each thread reuse a single TFE_Op and TF_Status
// across op executions instead of allocating fresh ones on every call
// (see GetOp/ReturnOp and GetStatus/ReturnStatus below).
thread_local std::unique_ptr<TFE_Op> thread_local_eager_operation =  // NOLINT
    nullptr;
thread_local std::unique_ptr<TF_Status> thread_local_tf_status =  // NOLINT
    nullptr;
|
||||
|
||||
TFE_Op* ReleaseThreadLocalOp() {
|
||||
if (thread_local_eager_operation == nullptr) {
|
||||
@ -54,21 +56,29 @@ TFE_Op* ReleaseThreadLocalOp() {
|
||||
return thread_local_eager_operation.release();
|
||||
}
|
||||
|
||||
// Returns a TFE_Op ready to execute `op_or_function_name` in `ctx`.
// Reuses the thread-local cached op when one is available (avoiding an
// allocation per op execution); otherwise allocates a new op. On failure
// `status` is set. The caller takes ownership and should hand the op back
// via ReturnOp() when done.
//
// NOTE: the source diff fused the removed CreateOrResetOp with the added
// GetOp; only the new GetOp is kept here — the old remnant was unclosed
// dead code that could not compile.
TFE_Op* GetOp(TFE_Context* ctx, const char* op_or_function_name,
              TF_Status* status) {
  TFE_Op* maybe_op = ReleaseThreadLocalOp();
  if (maybe_op) {
    // Reinitialize the cached op in place instead of allocating a new one.
    TFE_OpReset(ctx, op_or_function_name, status, maybe_op);
    return maybe_op;
  } else {
    // No cached op for this thread (or it is currently checked out).
    return TFE_NewOp(ctx, op_or_function_name, status);
  }
}
|
||||
|
||||
void ClearAndReturnThreadLocalOp(TFE_Op* object) {
|
||||
void ReturnOp(TFE_Op* object) {
|
||||
object->Clear();
|
||||
thread_local_eager_operation.reset(object);
|
||||
}
|
||||
|
||||
// Hands ownership of the thread-local cached TF_Status, if any, to the
// caller. Returns nullptr when nothing is cached on this thread.
TF_Status* ReleaseThreadLocalStatus() {
  return thread_local_tf_status ? thread_local_tf_status.release() : nullptr;
}
|
||||
|
||||
struct InputInfo {
|
||||
InputInfo(int i, bool is_list) : i(i), is_list(is_list) {}
|
||||
|
||||
@ -746,6 +756,21 @@ tensorflow::int64 _uid GUARDED_BY(_uid_mutex) = 0;
|
||||
|
||||
} // namespace
|
||||
|
||||
// Returns a TF_Status ready for use (code reset to TF_OK). Reuses the
// thread-local cached status when available; otherwise allocates a new one.
// The caller should hand it back via ReturnStatus() for reuse.
TF_Status* GetStatus() {
  if (TF_Status* cached = ReleaseThreadLocalStatus()) {
    // Scrub any stale error state before reuse.
    TF_SetStatus(cached, TF_OK, "");
    return cached;
  }
  return TF_NewStatus();
}
|
||||
|
||||
// Resets `status` to TF_OK and places it in the thread-local cache so a
// later GetStatus() call on this thread can reuse it.
void ReturnStatus(TF_Status* status) {
  TF_SetStatus(status, TF_OK, "");
  thread_local_tf_status.reset(status);
}
|
||||
|
||||
void TFE_Py_Execute(TFE_Context* ctx, const char* device_name,
|
||||
const char* op_name, TFE_InputTensorHandles* inputs,
|
||||
PyObject* attrs, TFE_OutputTensorHandles* outputs,
|
||||
@ -761,10 +786,8 @@ void TFE_Py_ExecuteCancelable(TFE_Context* ctx, const char* device_name,
|
||||
TFE_CancellationManager* cancellation_manager,
|
||||
TFE_OutputTensorHandles* outputs,
|
||||
TF_Status* out_status) {
|
||||
TFE_Op* op =
|
||||
CreateOrResetOp(ctx, op_name, out_status, ReleaseThreadLocalOp());
|
||||
auto cleaner =
|
||||
tensorflow::gtl::MakeCleanup([op] { ClearAndReturnThreadLocalOp(op); });
|
||||
TFE_Op* op = GetOp(ctx, op_name, out_status);
|
||||
auto cleaner = tensorflow::gtl::MakeCleanup([op] { ReturnOp(op); });
|
||||
if (!out_status->status.ok()) return;
|
||||
TFE_OpSetDevice(op, device_name, out_status);
|
||||
if (out_status->status.ok()) {
|
||||
@ -3359,7 +3382,7 @@ PyObject* TFE_Py_FastPathExecute_C(PyObject*, PyObject* args) {
|
||||
op_exec_info.run_callbacks = op_exec_info.run_gradient_callback ||
|
||||
op_exec_info.run_post_exec_callbacks;
|
||||
|
||||
TF_Status* status = TF_NewStatus();
|
||||
TF_Status* status = GetStatus();
|
||||
const char* op_name = TFE_GetPythonString(op_exec_info.op_name);
|
||||
if (op_name == nullptr) {
|
||||
PyErr_SetString(PyExc_TypeError,
|
||||
@ -3369,11 +3392,10 @@ PyObject* TFE_Py_FastPathExecute_C(PyObject*, PyObject* args) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
TFE_Op* op = CreateOrResetOp(op_exec_info.ctx, op_name, status,
|
||||
ReleaseThreadLocalOp());
|
||||
TFE_Op* op = GetOp(op_exec_info.ctx, op_name, status);
|
||||
auto cleaner = tensorflow::gtl::MakeCleanup([status, op] {
|
||||
TF_DeleteStatus(status);
|
||||
ClearAndReturnThreadLocalOp(op);
|
||||
ReturnStatus(status);
|
||||
ReturnOp(op);
|
||||
});
|
||||
if (MaybeRaiseExceptionFromTFStatus(status, nullptr)) {
|
||||
return nullptr;
|
||||
|
@ -425,11 +425,11 @@ static PyObject* TFE_ClearScalarCache();
|
||||
|
||||
// Obtain a Status object for the call, drawing from the thread-local cache
// via GetStatus() when possible instead of allocating a fresh one.
//
// NOTE: the source diff kept both the removed `$1 = TF_NewStatus();` and
// the added `$1 = GetStatus();` — the stale first assignment would leak a
// status on every call, so only the new line is kept.
%typemap(in, numinputs=0) TF_Status *out_status {
  $1 = GetStatus();
}
|
||||
|
||||
// Hand the status back to the thread-local cache instead of deleting it.
//
// NOTE: the source diff kept both the removed `TF_DeleteStatus($1);` and
// the added `ReturnStatus($1);` — running both would cache a freed status
// (use-after-free), so only the new line is kept.
%typemap(freearg) (TF_Status* out_status) {
  ReturnStatus($1);
}
|
||||
|
||||
%typemap(argout) (TFE_OutputTensorHandles* outputs, TF_Status* out_status) {
|
||||
|
Loading…
Reference in New Issue
Block a user