diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc
index 8793e308466..169d54a2dc8 100644
--- a/tensorflow/c/eager/c_api.cc
+++ b/tensorflow/c/eager/c_api.cc
@@ -1009,7 +1009,7 @@ TF_Tensor* TFE_TensorHandleResolve(TFE_TensorHandle* h, TF_Status* status) {
 
 TFE_Op* TFE_NewOp(TFE_Context* ctx, const char* op_or_function_name,
                   TF_Status* status) {
-  return NewOrResetOp(ctx, op_or_function_name, status,
+  return NewOrResetOp(ctx, op_or_function_name, nullptr, status,
                       /* op_to_reset= */ nullptr);
 }
 
diff --git a/tensorflow/c/eager/c_api_experimental.cc b/tensorflow/c/eager/c_api_experimental.cc
index b513fcedc59..aa6bbb2b8e5 100644
--- a/tensorflow/c/eager/c_api_experimental.cc
+++ b/tensorflow/c/eager/c_api_experimental.cc
@@ -29,9 +29,11 @@ limitations under the License.
 using tensorflow::string;
 
 void TFE_OpReset(TFE_Context* ctx, const char* op_or_function_name,
-                 TF_Status* status, TFE_Op* op_to_reset) {
+                 const char* raw_device_name, TF_Status* status,
+                 TFE_Op* op_to_reset) {
   if (op_to_reset) {
-    NewOrResetOp(ctx, op_or_function_name, status, op_to_reset);
+    NewOrResetOp(ctx, op_or_function_name, raw_device_name, status,
+                 op_to_reset);
   } else {
     TF_SetStatus(status, TF_INVALID_ARGUMENT,
                  "op_to_reset should not be nullptr");
diff --git a/tensorflow/c/eager/c_api_experimental.h b/tensorflow/c/eager/c_api_experimental.h
index 82cebed76fb..5b7dddb0699 100644
--- a/tensorflow/c/eager/c_api_experimental.h
+++ b/tensorflow/c/eager/c_api_experimental.h
@@ -22,8 +22,16 @@ limitations under the License.
 extern "C" {
 #endif
 
+// Resets `op_to_reset` with `op_or_function_name` and `raw_device_name`. This
+// is a performance optimization that reuses an existing unused op rather than
+// creating a new op every time. If `raw_device_name` is `NULL` or empty, it
+// does not set the device name. If it's not `NULL`, then it attempts to parse
+// and set the device name. It's effectively `TFE_OpSetDevice`, but it is
+// faster than calling that separately because if the existing op already has
+// the same `raw_device_name`, it skips parsing and leaves the device as is.
 TF_CAPI_EXPORT extern void TFE_OpReset(TFE_Context* ctx,
                                        const char* op_or_function_name,
+                                       const char* raw_device_name,
                                        TF_Status* status, TFE_Op* op_to_reset);
 
 TF_CAPI_EXPORT extern void TFE_OpConsumeInput(TFE_Op* op, TFE_TensorHandle* h,
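
For context on the new header comment above, here is a minimal usage sketch of the updated `TFE_OpReset` signature. The sketch is not part of the patch: `recycled_op` is assumed to be an idle op from a caller-managed pool (as the Python fast path keeps in a thread-local slot), and the device string is illustrative.

    #include "tensorflow/c/eager/c_api.h"
    #include "tensorflow/c/eager/c_api_experimental.h"

    // Rebind an idle op to MatMul on an explicit device in a single call.
    void PrepareMatMul(TFE_Context* ctx, TFE_Op* recycled_op,
                       TF_Status* status) {
      TFE_OpReset(ctx, "MatMul",
                  "/job:localhost/replica:0/task:0/device:CPU:0", status,
                  recycled_op);
      // The device string is parsed once and cached on the op; later resets of
      // this op with the same string skip DeviceNameUtils::ParseFullName.
    }
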
#include "tensorflow/core/platform/host_info.h" TFE_Op* NewOrResetOp(TFE_Context* ctx, const char* op_or_function_name, - TF_Status* status, TFE_Op* op_to_reset) { + const char* raw_device_name, TF_Status* status, + TFE_Op* op_to_reset) { const char* name = op_or_function_name; // Shorthand const tensorflow::AttrTypeMap* types; bool is_function = false; @@ -25,14 +26,17 @@ TFE_Op* NewOrResetOp(TFE_Context* ctx, const char* op_or_function_name, if (!status->status.ok()) { return nullptr; } - auto create_or_reset = [&op_to_reset, &ctx, &name, &types]( - bool is_function, - TFE_OpInferenceContext* inference_ctx) -> TFE_Op* { + auto create_or_reset = + [&op_to_reset, &ctx, &name, &types, &raw_device_name, &status]( + bool is_function, TFE_OpInferenceContext* inference_ctx) -> TFE_Op* { if (op_to_reset) { - op_to_reset->Reset(ctx, name, is_function, types, inference_ctx); + status->status = op_to_reset->Reset(ctx, name, is_function, types, + raw_device_name, inference_ctx); return op_to_reset; } else { - return new TFE_Op(ctx, name, is_function, types, inference_ctx); + TFE_Op* new_op = new TFE_Op(ctx, name, is_function, types, inference_ctx); + status->status = new_op->operation.SetDeviceName(raw_device_name); + return new_op; } }; diff --git a/tensorflow/c/eager/c_api_internal.h b/tensorflow/c/eager/c_api_internal.h index 024ed66e560..e2a4cb97ac0 100644 --- a/tensorflow/c/eager/c_api_internal.h +++ b/tensorflow/c/eager/c_api_internal.h @@ -134,11 +134,13 @@ struct TFE_Op { inference_ctx.reset(); } - void Reset(TFE_Context* ctx, const char* op, bool is_function, - const tensorflow::AttrTypeMap* t, - TFE_OpInferenceContext* infer_ctx) { - operation.Reset(ctx->context, op, is_function, t, nullptr); + tensorflow::Status Reset(TFE_Context* ctx, const char* op, bool is_function, + const tensorflow::AttrTypeMap* t, + const char* raw_device_name, + TFE_OpInferenceContext* infer_ctx) { inference_ctx.reset(infer_ctx); + return operation.Reset(ctx->context, op, is_function, t, raw_device_name, + nullptr); } tensorflow::EagerOperation operation; @@ -146,7 +148,8 @@ struct TFE_Op { }; TFE_Op* NewOrResetOp(TFE_Context* ctx, const char* op_or_function_name, - TF_Status* status, TFE_Op* op_to_reset = nullptr); + const char* raw_device_name, TF_Status* status, + TFE_Op* op_to_reset = nullptr); struct TFE_Profiler { explicit TFE_Profiler() { profiler = tensorflow::ProfilerSession::Create(); } diff --git a/tensorflow/core/common_runtime/eager/eager_operation.cc b/tensorflow/core/common_runtime/eager/eager_operation.cc index 26515f7ec37..975be6efde0 100644 --- a/tensorflow/core/common_runtime/eager/eager_operation.cc +++ b/tensorflow/core/common_runtime/eager/eager_operation.cc @@ -16,16 +16,25 @@ limitations under the License. namespace tensorflow { -tensorflow::Status EagerOperation::SetDeviceName(const char* device) { +tensorflow::Status EagerOperation::SetDeviceName(const char* device, + const bool reset) { if (device != nullptr && strlen(device) > 0) { - if (!DeviceNameUtils::ParseFullName(device, &device_parsed_name_)) { - return errors::InvalidArgument("Malformed device specification '", device, - "' in eager op: ", DebugString()); + if (device != raw_device_name_) { + if (!DeviceNameUtils::ParseFullName(device, &device_parsed_name_)) { + return errors::InvalidArgument("Malformed device specification '", + device, + "' in eager op: ", DebugString()); + } + raw_device_name_ = device; + device_name_ = + DeviceNameUtils::HasSomeDetails(device_parsed_name_) + ? 
+              ? DeviceNameUtils::ParsedNameToString(device_parsed_name_)
+              : "";
     }
-    device_name_ =
-        DeviceNameUtils::HasSomeDetails(device_parsed_name_)
-            ? DeviceNameUtils::ParsedNameToString(device_parsed_name_)
-            : "";
+  } else if (reset) {
+    raw_device_name_.clear();
+    device_name_.clear();
+    device_parsed_name_.Clear();
   }
   return Status::OK();
 }
diff --git a/tensorflow/core/common_runtime/eager/eager_operation.h b/tensorflow/core/common_runtime/eager/eager_operation.h
index 14f71c27529..87da5bf8245 100644
--- a/tensorflow/core/common_runtime/eager/eager_operation.h
+++ b/tensorflow/core/common_runtime/eager/eager_operation.h
@@ -33,7 +33,9 @@ class EagerOperation {
                  const absl::optional<EagerRemoteFunctionParams>
                      remote_func_params = absl::nullopt)
       : ctx_(nullptr) {
-    Reset(ctx, op, is_function, t, executor, remote_func_params);
+    tensorflow::Status status =
+        Reset(ctx, op, is_function, t, nullptr, executor, remote_func_params);
+    DCHECK(status.ok());
   }
 
   ~EagerOperation() {
@@ -53,10 +55,11 @@ class EagerOperation {
     inputs_.clear();
   }
 
-  void Reset(tensorflow::EagerContext* ctx, const char* op, bool is_function,
-             const tensorflow::AttrTypeMap* t, EagerExecutor* executor,
-             const absl::optional<EagerRemoteFunctionParams>
-                 remote_func_params = absl::nullopt) {
+  tensorflow::Status Reset(tensorflow::EagerContext* ctx, const char* op,
+                           bool is_function, const tensorflow::AttrTypeMap* t,
+                           const char* raw_device_name, EagerExecutor* executor,
+                           const absl::optional<EagerRemoteFunctionParams>
+                               remote_func_params = absl::nullopt) {
     DCHECK(ctx_ == nullptr) << "Calling Reset without first calling Release";
     DCHECK(inputs_.empty());
     ctx_ = ctx;
@@ -67,8 +70,6 @@ class EagerOperation {
     }
     attr_types_ = t;
     device_ = nullptr;
-    device_name_ = "";
-    device_parsed_name_.Clear();
     use_xla_ = false;
     is_function_ = is_function;
     cancellation_manager_ = nullptr;
@@ -77,6 +78,7 @@ class EagerOperation {
 #ifdef TENSORFLOW_MEM_DEBUG
     op_name_ = op;
 #endif
+    return SetDeviceName(raw_device_name, true);
   }
 
   bool is_function() const { return is_function_; }
@@ -105,6 +107,7 @@ class EagerOperation {
   tensorflow::Device* Device() const { return device_; }
   void SetDevice(tensorflow::Device* device) {
     device_ = device;
+    raw_device_name_.clear();
     device_name_ = device->name();
     device_parsed_name_ = device->parsed_name();
   }
@@ -113,7 +116,8 @@ class EagerOperation {
   const DeviceNameUtils::ParsedName& GetDeviceParsedName() const {
     return device_parsed_name_;
   }
-  tensorflow::Status SetDeviceName(const char* device);
+  tensorflow::Status SetDeviceName(const char* device,
+                                   const bool reset = false);
 
   // Indicates whether the op is assigned to a device that is local to the
   // current host.
@@ -147,6 +151,7 @@ class EagerOperation {
   const tensorflow::AttrTypeMap* attr_types_;
   tensorflow::gtl::InlinedVector<TensorHandle*, 4> inputs_;
   tensorflow::Device* device_;
+  string raw_device_name_;
   string device_name_;
   DeviceNameUtils::ParsedName device_parsed_name_;
   bool use_xla_ = false;
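
For readability, here is EagerOperation::SetDeviceName as it reads once the eager_operation.cc hunks above are applied (consolidated from the diff, with explanatory comments added; not new code):

    tensorflow::Status EagerOperation::SetDeviceName(const char* device,
                                                     const bool reset) {
      if (device != nullptr && strlen(device) > 0) {
        // Compares the incoming C string against the cached std::string by
        // value; an identical specification skips the parse below entirely.
        if (device != raw_device_name_) {
          if (!DeviceNameUtils::ParseFullName(device, &device_parsed_name_)) {
            return errors::InvalidArgument("Malformed device specification '",
                                           device,
                                           "' in eager op: ", DebugString());
          }
          raw_device_name_ = device;
          device_name_ =
              DeviceNameUtils::HasSomeDetails(device_parsed_name_)
                  ? DeviceNameUtils::ParsedNameToString(device_parsed_name_)
                  : "";
        }
      } else if (reset) {
        // Reset() passes reset=true so a recycled op does not keep a stale
        // device from its previous use.
        raw_device_name_.clear();
        device_name_.clear();
        device_parsed_name_.Clear();
      }
      return Status::OK();
    }

Because `device != raw_device_name_` is a value comparison, repeatedly resetting the same op with the same device specification (the common case for the Python fast path's thread-local op) pays for DeviceNameUtils::ParseFullName only once.
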
diff --git a/tensorflow/python/eager/pywrap_tfe_src.cc b/tensorflow/python/eager/pywrap_tfe_src.cc
index f5f336c2323..f5508d76583 100644
--- a/tensorflow/python/eager/pywrap_tfe_src.cc
+++ b/tensorflow/python/eager/pywrap_tfe_src.cc
@@ -68,13 +68,14 @@ TFE_Op* ReleaseThreadLocalOp() {
 }
 
 TFE_Op* GetOp(TFE_Context* ctx, const char* op_or_function_name,
-              TF_Status* status) {
+              const char* raw_device_name, TF_Status* status) {
   TFE_Op* maybe_op = ReleaseThreadLocalOp();
   if (maybe_op) {
-    TFE_OpReset(ctx, op_or_function_name, status, maybe_op);
+    TFE_OpReset(ctx, op_or_function_name, raw_device_name, status, maybe_op);
     return maybe_op;
   } else {
-    return TFE_NewOp(ctx, op_or_function_name, status);
+    return NewOrResetOp(ctx, op_or_function_name, raw_device_name, status,
+                        nullptr);
   }
 }
 
@@ -834,14 +835,12 @@ void TFE_Py_ExecuteCancelable(TFE_Context* ctx, const char* device_name,
                               TFE_CancellationManager* cancellation_manager,
                               TFE_OutputTensorHandles* outputs,
                               TF_Status* out_status) {
-  TFE_Op* op = GetOp(ctx, op_name, out_status);
+  TFE_Op* op = GetOp(ctx, op_name, device_name, out_status);
   auto cleaner = tensorflow::gtl::MakeCleanup([op] { ReturnOp(op); });
   if (!out_status->status.ok()) return;
-  TFE_OpSetDevice(op, device_name, out_status);
-  if (out_status->status.ok()) {
-    for (int i = 0; i < inputs->size() && out_status->status.ok(); ++i) {
-      TFE_OpAddInput(op, inputs->at(i), out_status);
-    }
+
+  for (int i = 0; i < inputs->size() && out_status->status.ok(); ++i) {
+    TFE_OpAddInput(op, inputs->at(i), out_status);
   }
   if (cancellation_manager && out_status->status.ok()) {
     TFE_OpSetCancellationManager(op, cancellation_manager, out_status);
@@ -3506,7 +3505,8 @@ PyObject* TFE_Py_FastPathExecute_C(PyObject*, PyObject* args) {
     return nullptr;
   }
 
-  TFE_Op* op = GetOp(op_exec_info.ctx, op_name, status);
+  TFE_Op* op =
+      GetOp(op_exec_info.ctx, op_name, op_exec_info.device_name, status);
   auto cleaner = tensorflow::gtl::MakeCleanup([status, op] {
     ReturnStatus(status);
     ReturnOp(op);
@@ -3573,11 +3573,6 @@ PyObject* TFE_Py_FastPathExecute_C(PyObject*, PyObject* args) {
     }
   }
 
-  TFE_OpSetDevice(op, op_exec_info.device_name, status);
-  if (MaybeRaiseExceptionFromTFStatus(status, nullptr)) {
-    return nullptr;
-  }
-
   // Flat attrs and inputs as required by the record_gradient call. The attrs
   // here only contain inferred attrs (non-inferred attrs are added directly
   // from the input args).
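
One consequence of folding the device name into GetOp/TFE_OpReset, sketched below for illustration (not part of the patch; `recycled_op` is assumed to be an idle op from a caller-managed pool): a malformed device specification is now reported by the reset itself, whereas the old fast path only saw it from the separate TFE_OpSetDevice call that this change removes.

    #include "tensorflow/c/eager/c_api.h"
    #include "tensorflow/c/eager/c_api_experimental.h"

    // Returns true if the op was successfully rebound to `device`.
    bool ResetOnDevice(TFE_Context* ctx, TFE_Op* recycled_op,
                       const char* device, TF_Status* status) {
      TFE_OpReset(ctx, "MatMul", device, status, recycled_op);
      // With a device string that DeviceNameUtils::ParseFullName rejects, the
      // status carries TF_INVALID_ARGUMENT ("Malformed device specification
      // ...") here, before any input is added or the op is executed.
      return TF_GetCode(status) == TF_OK;
    }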