Stop allocating a new tf_status on every op execution.

This is similar to the CL which stopped allocating a tfe_op for every new op
execution.

Benchmark benchmark_tf_identity goes from ~2.25 -> ~2.16

PiperOrigin-RevId: 272721880
This commit is contained in:
Akshay Modi 2019-10-03 12:56:14 -07:00 committed by TensorFlower Gardener
parent 8ce19dbfef
commit 172aa31854
3 changed files with 44 additions and 17 deletions

View File

@ -372,4 +372,9 @@ PyObject* TFE_Py_SetEagerContext(PyObject* python_context);
// some point.
PyObject* GetPyEagerContext();
// These are exposed since there is SWIG code that calls these.
// Returns a pre-allocated status if it exists.
TF_Status* GetStatus();
// Returns the pre-allocated status to the code.
void ReturnStatus(TF_Status* status);
#endif // TENSORFLOW_PYTHON_EAGER_PYWRAP_TFE_H_

View File

@ -46,6 +46,8 @@ namespace {
thread_local std::unique_ptr<TFE_Op> thread_local_eager_operation = // NOLINT
nullptr;
thread_local std::unique_ptr<TF_Status> thread_local_tf_status = // NOLINT
nullptr;
TFE_Op* ReleaseThreadLocalOp() {
if (thread_local_eager_operation == nullptr) {
@ -54,21 +56,29 @@ TFE_Op* ReleaseThreadLocalOp() {
return thread_local_eager_operation.release();
}
TFE_Op* CreateOrResetOp(TFE_Context* ctx, const char* op_or_function_name,
TF_Status* status, TFE_Op* op_to_reset) {
if (op_to_reset) {
TFE_OpReset(ctx, op_or_function_name, status, op_to_reset);
return op_to_reset;
TFE_Op* GetOp(TFE_Context* ctx, const char* op_or_function_name,
TF_Status* status) {
TFE_Op* maybe_op = ReleaseThreadLocalOp();
if (maybe_op) {
TFE_OpReset(ctx, op_or_function_name, status, maybe_op);
return maybe_op;
} else {
return TFE_NewOp(ctx, op_or_function_name, status);
}
}
void ClearAndReturnThreadLocalOp(TFE_Op* object) {
void ReturnOp(TFE_Op* object) {
object->Clear();
thread_local_eager_operation.reset(object);
}
TF_Status* ReleaseThreadLocalStatus() {
if (thread_local_tf_status == nullptr) {
return nullptr;
}
return thread_local_tf_status.release();
}
struct InputInfo {
InputInfo(int i, bool is_list) : i(i), is_list(is_list) {}
@ -746,6 +756,21 @@ tensorflow::int64 _uid GUARDED_BY(_uid_mutex) = 0;
} // namespace
TF_Status* GetStatus() {
TF_Status* maybe_status = ReleaseThreadLocalStatus();
if (maybe_status) {
TF_SetStatus(maybe_status, TF_OK, "");
return maybe_status;
} else {
return TF_NewStatus();
}
}
void ReturnStatus(TF_Status* status) {
TF_SetStatus(status, TF_OK, "");
thread_local_tf_status.reset(status);
}
void TFE_Py_Execute(TFE_Context* ctx, const char* device_name,
const char* op_name, TFE_InputTensorHandles* inputs,
PyObject* attrs, TFE_OutputTensorHandles* outputs,
@ -761,10 +786,8 @@ void TFE_Py_ExecuteCancelable(TFE_Context* ctx, const char* device_name,
TFE_CancellationManager* cancellation_manager,
TFE_OutputTensorHandles* outputs,
TF_Status* out_status) {
TFE_Op* op =
CreateOrResetOp(ctx, op_name, out_status, ReleaseThreadLocalOp());
auto cleaner =
tensorflow::gtl::MakeCleanup([op] { ClearAndReturnThreadLocalOp(op); });
TFE_Op* op = GetOp(ctx, op_name, out_status);
auto cleaner = tensorflow::gtl::MakeCleanup([op] { ReturnOp(op); });
if (!out_status->status.ok()) return;
TFE_OpSetDevice(op, device_name, out_status);
if (out_status->status.ok()) {
@ -3359,7 +3382,7 @@ PyObject* TFE_Py_FastPathExecute_C(PyObject*, PyObject* args) {
op_exec_info.run_callbacks = op_exec_info.run_gradient_callback ||
op_exec_info.run_post_exec_callbacks;
TF_Status* status = TF_NewStatus();
TF_Status* status = GetStatus();
const char* op_name = TFE_GetPythonString(op_exec_info.op_name);
if (op_name == nullptr) {
PyErr_SetString(PyExc_TypeError,
@ -3369,11 +3392,10 @@ PyObject* TFE_Py_FastPathExecute_C(PyObject*, PyObject* args) {
return nullptr;
}
TFE_Op* op = CreateOrResetOp(op_exec_info.ctx, op_name, status,
ReleaseThreadLocalOp());
TFE_Op* op = GetOp(op_exec_info.ctx, op_name, status);
auto cleaner = tensorflow::gtl::MakeCleanup([status, op] {
TF_DeleteStatus(status);
ClearAndReturnThreadLocalOp(op);
ReturnStatus(status);
ReturnOp(op);
});
if (MaybeRaiseExceptionFromTFStatus(status, nullptr)) {
return nullptr;

View File

@ -425,11 +425,11 @@ static PyObject* TFE_ClearScalarCache();
// Create new Status object.
%typemap(in, numinputs=0) TF_Status *out_status {
$1 = TF_NewStatus();
$1 = GetStatus();
}
%typemap(freearg) (TF_Status* out_status) {
TF_DeleteStatus($1);
ReturnStatus($1);
}
%typemap(argout) (TFE_OutputTensorHandles* outputs, TF_Status* out_status) {