From 495e179730a97e8dc0be7b0145488542c9b77f03 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" <gardener@tensorflow.org>
Date: Mon, 9 Dec 2019 22:01:06 -0800
Subject: [PATCH] [Perf] Skip EagerOperation::SetDeviceName(...) call if input
 device name didn't change.

PiperOrigin-RevId: 284700133
Change-Id: I7716abe6968b0686df00ea15dec3d85bf16e8cf5
---
 tensorflow/c/eager/c_api.cc                   |  2 +-
 tensorflow/c/eager/c_api_experimental.cc      |  6 +++--
 tensorflow/c/eager/c_api_experimental.h       |  8 ++++++
 tensorflow/c/eager/c_api_internal.cc          | 16 +++++++-----
 tensorflow/c/eager/c_api_internal.h           | 13 ++++++----
 .../common_runtime/eager/eager_operation.cc   | 25 +++++++++++++------
 .../common_runtime/eager/eager_operation.h    | 21 ++++++++++------
 tensorflow/python/eager/pywrap_tfe_src.cc     | 25 ++++++++-----------
 8 files changed, 71 insertions(+), 45 deletions(-)

diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc
index 8793e308466..169d54a2dc8 100644
--- a/tensorflow/c/eager/c_api.cc
+++ b/tensorflow/c/eager/c_api.cc
@@ -1009,7 +1009,7 @@ TF_Tensor* TFE_TensorHandleResolve(TFE_TensorHandle* h, TF_Status* status) {
 
 TFE_Op* TFE_NewOp(TFE_Context* ctx, const char* op_or_function_name,
                   TF_Status* status) {
-  return NewOrResetOp(ctx, op_or_function_name, status,
+  return NewOrResetOp(ctx, op_or_function_name, nullptr, status,
                       /* op_to_reset= */ nullptr);
 }
 
diff --git a/tensorflow/c/eager/c_api_experimental.cc b/tensorflow/c/eager/c_api_experimental.cc
index b513fcedc59..aa6bbb2b8e5 100644
--- a/tensorflow/c/eager/c_api_experimental.cc
+++ b/tensorflow/c/eager/c_api_experimental.cc
@@ -29,9 +29,11 @@ limitations under the License.
 using tensorflow::string;
 
 void TFE_OpReset(TFE_Context* ctx, const char* op_or_function_name,
-                 TF_Status* status, TFE_Op* op_to_reset) {
+                 const char* raw_device_name, TF_Status* status,
+                 TFE_Op* op_to_reset) {
   if (op_to_reset) {
-    NewOrResetOp(ctx, op_or_function_name, status, op_to_reset);
+    NewOrResetOp(ctx, op_or_function_name, raw_device_name, status,
+                 op_to_reset);
   } else {
     TF_SetStatus(status, TF_INVALID_ARGUMENT,
                  "op_to_reset should not be nullptr");
diff --git a/tensorflow/c/eager/c_api_experimental.h b/tensorflow/c/eager/c_api_experimental.h
index 82cebed76fb..5b7dddb0699 100644
--- a/tensorflow/c/eager/c_api_experimental.h
+++ b/tensorflow/c/eager/c_api_experimental.h
@@ -22,8 +22,16 @@ limitations under the License.
 extern "C" {
 #endif
 
+// Resets `op_to_reset` with `op_or_function_name` and `raw_device_name`. This
+// is for performance optimization by reusing an exiting unused op rather than
+// creating a new op every time. If `raw_device_name` is `NULL` or empty, it
+// does not set the device name. If it's not `NULL`, then it attempts to parse
+// and set the device name. It's effectively `TFE_OpSetDevice`, but it is faster
+// than seperately calling it because if the existing op has the same
+// `raw_device_name`, it skips parsing and just leave as it is.
 TF_CAPI_EXPORT extern void TFE_OpReset(TFE_Context* ctx,
                                        const char* op_or_function_name,
+                                       const char* raw_device_name,
                                        TF_Status* status, TFE_Op* op_to_reset);
 
 TF_CAPI_EXPORT extern void TFE_OpConsumeInput(TFE_Op* op, TFE_TensorHandle* h,
diff --git a/tensorflow/c/eager/c_api_internal.cc b/tensorflow/c/eager/c_api_internal.cc
index 772fae13faf..f6092715e17 100644
--- a/tensorflow/c/eager/c_api_internal.cc
+++ b/tensorflow/c/eager/c_api_internal.cc
@@ -17,7 +17,8 @@ limitations under the License.
 #include "tensorflow/core/platform/host_info.h"
 
 TFE_Op* NewOrResetOp(TFE_Context* ctx, const char* op_or_function_name,
-                     TF_Status* status, TFE_Op* op_to_reset) {
+                     const char* raw_device_name, TF_Status* status,
+                     TFE_Op* op_to_reset) {
   const char* name = op_or_function_name;  // Shorthand
   const tensorflow::AttrTypeMap* types;
   bool is_function = false;
@@ -25,14 +26,17 @@ TFE_Op* NewOrResetOp(TFE_Context* ctx, const char* op_or_function_name,
   if (!status->status.ok()) {
     return nullptr;
   }
-  auto create_or_reset = [&op_to_reset, &ctx, &name, &types](
-                             bool is_function,
-                             TFE_OpInferenceContext* inference_ctx) -> TFE_Op* {
+  auto create_or_reset =
+      [&op_to_reset, &ctx, &name, &types, &raw_device_name, &status](
+          bool is_function, TFE_OpInferenceContext* inference_ctx) -> TFE_Op* {
     if (op_to_reset) {
-      op_to_reset->Reset(ctx, name, is_function, types, inference_ctx);
+      status->status = op_to_reset->Reset(ctx, name, is_function, types,
+                                          raw_device_name, inference_ctx);
       return op_to_reset;
     } else {
-      return new TFE_Op(ctx, name, is_function, types, inference_ctx);
+      TFE_Op* new_op = new TFE_Op(ctx, name, is_function, types, inference_ctx);
+      status->status = new_op->operation.SetDeviceName(raw_device_name);
+      return new_op;
     }
   };
 
diff --git a/tensorflow/c/eager/c_api_internal.h b/tensorflow/c/eager/c_api_internal.h
index 024ed66e560..e2a4cb97ac0 100644
--- a/tensorflow/c/eager/c_api_internal.h
+++ b/tensorflow/c/eager/c_api_internal.h
@@ -134,11 +134,13 @@ struct TFE_Op {
     inference_ctx.reset();
   }
 
-  void Reset(TFE_Context* ctx, const char* op, bool is_function,
-             const tensorflow::AttrTypeMap* t,
-             TFE_OpInferenceContext* infer_ctx) {
-    operation.Reset(ctx->context, op, is_function, t, nullptr);
+  tensorflow::Status Reset(TFE_Context* ctx, const char* op, bool is_function,
+                           const tensorflow::AttrTypeMap* t,
+                           const char* raw_device_name,
+                           TFE_OpInferenceContext* infer_ctx) {
     inference_ctx.reset(infer_ctx);
+    return operation.Reset(ctx->context, op, is_function, t, raw_device_name,
+                           nullptr);
   }
 
   tensorflow::EagerOperation operation;
@@ -146,7 +148,8 @@ struct TFE_Op {
 };
 
 TFE_Op* NewOrResetOp(TFE_Context* ctx, const char* op_or_function_name,
-                     TF_Status* status, TFE_Op* op_to_reset = nullptr);
+                     const char* raw_device_name, TF_Status* status,
+                     TFE_Op* op_to_reset = nullptr);
 
 struct TFE_Profiler {
   explicit TFE_Profiler() { profiler = tensorflow::ProfilerSession::Create(); }
diff --git a/tensorflow/core/common_runtime/eager/eager_operation.cc b/tensorflow/core/common_runtime/eager/eager_operation.cc
index 26515f7ec37..975be6efde0 100644
--- a/tensorflow/core/common_runtime/eager/eager_operation.cc
+++ b/tensorflow/core/common_runtime/eager/eager_operation.cc
@@ -16,16 +16,25 @@ limitations under the License.
 
 namespace tensorflow {
 
-tensorflow::Status EagerOperation::SetDeviceName(const char* device) {
+tensorflow::Status EagerOperation::SetDeviceName(const char* device,
+                                                 const bool reset) {
   if (device != nullptr && strlen(device) > 0) {
-    if (!DeviceNameUtils::ParseFullName(device, &device_parsed_name_)) {
-      return errors::InvalidArgument("Malformed device specification '", device,
-                                     "' in eager op: ", DebugString());
+    if (device != raw_device_name_) {
+      if (!DeviceNameUtils::ParseFullName(device, &device_parsed_name_)) {
+        return errors::InvalidArgument("Malformed device specification '",
+                                       device,
+                                       "' in eager op: ", DebugString());
+      }
+      raw_device_name_ = device;
+      device_name_ =
+          DeviceNameUtils::HasSomeDetails(device_parsed_name_)
+              ? DeviceNameUtils::ParsedNameToString(device_parsed_name_)
+              : "";
     }
-    device_name_ =
-        DeviceNameUtils::HasSomeDetails(device_parsed_name_)
-            ? DeviceNameUtils::ParsedNameToString(device_parsed_name_)
-            : "";
+  } else if (reset) {
+    raw_device_name_.clear();
+    device_name_.clear();
+    device_parsed_name_.Clear();
   }
   return Status::OK();
 }
diff --git a/tensorflow/core/common_runtime/eager/eager_operation.h b/tensorflow/core/common_runtime/eager/eager_operation.h
index 14f71c27529..87da5bf8245 100644
--- a/tensorflow/core/common_runtime/eager/eager_operation.h
+++ b/tensorflow/core/common_runtime/eager/eager_operation.h
@@ -33,7 +33,9 @@ class EagerOperation {
                  const absl::optional<EagerRemoteFunctionParams>
                      remote_func_params = absl::nullopt)
       : ctx_(nullptr) {
-    Reset(ctx, op, is_function, t, executor, remote_func_params);
+    tensorflow::Status status =
+        Reset(ctx, op, is_function, t, nullptr, executor, remote_func_params);
+    DCHECK(status.ok());
   }
 
   ~EagerOperation() {
@@ -53,10 +55,11 @@ class EagerOperation {
     inputs_.clear();
   }
 
-  void Reset(tensorflow::EagerContext* ctx, const char* op, bool is_function,
-             const tensorflow::AttrTypeMap* t, EagerExecutor* executor,
-             const absl::optional<EagerRemoteFunctionParams>
-                 remote_func_params = absl::nullopt) {
+  tensorflow::Status Reset(tensorflow::EagerContext* ctx, const char* op,
+                           bool is_function, const tensorflow::AttrTypeMap* t,
+                           const char* raw_device_name, EagerExecutor* executor,
+                           const absl::optional<EagerRemoteFunctionParams>
+                               remote_func_params = absl::nullopt) {
     DCHECK(ctx_ == nullptr) << "Calling Reset without first calling Release";
     DCHECK(inputs_.empty());
     ctx_ = ctx;
@@ -67,8 +70,6 @@ class EagerOperation {
     }
     attr_types_ = t;
     device_ = nullptr;
-    device_name_ = "";
-    device_parsed_name_.Clear();
     use_xla_ = false;
     is_function_ = is_function;
     cancellation_manager_ = nullptr;
@@ -77,6 +78,7 @@ class EagerOperation {
 #ifdef TENSORFLOW_MEM_DEBUG
     op_name_ = op;
 #endif
+    return SetDeviceName(raw_device_name, true);
   }
 
   bool is_function() const { return is_function_; }
@@ -105,6 +107,7 @@ class EagerOperation {
   tensorflow::Device* Device() const { return device_; }
   void SetDevice(tensorflow::Device* device) {
     device_ = device;
+    raw_device_name_.clear();
     device_name_ = device->name();
     device_parsed_name_ = device->parsed_name();
   }
@@ -113,7 +116,8 @@ class EagerOperation {
   const DeviceNameUtils::ParsedName& GetDeviceParsedName() const {
     return device_parsed_name_;
   }
-  tensorflow::Status SetDeviceName(const char* device);
+  tensorflow::Status SetDeviceName(const char* device,
+                                   const bool reset = false);
 
   // Indicates whether the op is assigned to a device that is local to the
   // current host.
@@ -147,6 +151,7 @@ class EagerOperation {
   const tensorflow::AttrTypeMap* attr_types_;
   tensorflow::gtl::InlinedVector<tensorflow::TensorHandle*, 4> inputs_;
   tensorflow::Device* device_;
+  string raw_device_name_;
   string device_name_;
   DeviceNameUtils::ParsedName device_parsed_name_;
   bool use_xla_ = false;
diff --git a/tensorflow/python/eager/pywrap_tfe_src.cc b/tensorflow/python/eager/pywrap_tfe_src.cc
index f5f336c2323..f5508d76583 100644
--- a/tensorflow/python/eager/pywrap_tfe_src.cc
+++ b/tensorflow/python/eager/pywrap_tfe_src.cc
@@ -68,13 +68,14 @@ TFE_Op* ReleaseThreadLocalOp() {
 }
 
 TFE_Op* GetOp(TFE_Context* ctx, const char* op_or_function_name,
-              TF_Status* status) {
+              const char* raw_device_name, TF_Status* status) {
   TFE_Op* maybe_op = ReleaseThreadLocalOp();
   if (maybe_op) {
-    TFE_OpReset(ctx, op_or_function_name, status, maybe_op);
+    TFE_OpReset(ctx, op_or_function_name, raw_device_name, status, maybe_op);
     return maybe_op;
   } else {
-    return TFE_NewOp(ctx, op_or_function_name, status);
+    return NewOrResetOp(ctx, op_or_function_name, raw_device_name, status,
+                        nullptr);
   }
 }
 
@@ -834,14 +835,12 @@ void TFE_Py_ExecuteCancelable(TFE_Context* ctx, const char* device_name,
                               TFE_CancellationManager* cancellation_manager,
                               TFE_OutputTensorHandles* outputs,
                               TF_Status* out_status) {
-  TFE_Op* op = GetOp(ctx, op_name, out_status);
+  TFE_Op* op = GetOp(ctx, op_name, device_name, out_status);
   auto cleaner = tensorflow::gtl::MakeCleanup([op] { ReturnOp(op); });
   if (!out_status->status.ok()) return;
-  TFE_OpSetDevice(op, device_name, out_status);
-  if (out_status->status.ok()) {
-    for (int i = 0; i < inputs->size() && out_status->status.ok(); ++i) {
-      TFE_OpAddInput(op, inputs->at(i), out_status);
-    }
+
+  for (int i = 0; i < inputs->size() && out_status->status.ok(); ++i) {
+    TFE_OpAddInput(op, inputs->at(i), out_status);
   }
   if (cancellation_manager && out_status->status.ok()) {
     TFE_OpSetCancellationManager(op, cancellation_manager, out_status);
@@ -3506,7 +3505,8 @@ PyObject* TFE_Py_FastPathExecute_C(PyObject*, PyObject* args) {
     return nullptr;
   }
 
-  TFE_Op* op = GetOp(op_exec_info.ctx, op_name, status);
+  TFE_Op* op =
+      GetOp(op_exec_info.ctx, op_name, op_exec_info.device_name, status);
   auto cleaner = tensorflow::gtl::MakeCleanup([status, op] {
     ReturnStatus(status);
     ReturnOp(op);
@@ -3573,11 +3573,6 @@ PyObject* TFE_Py_FastPathExecute_C(PyObject*, PyObject* args) {
     }
   }
 
-  TFE_OpSetDevice(op, op_exec_info.device_name, status);
-  if (MaybeRaiseExceptionFromTFStatus(status, nullptr)) {
-    return nullptr;
-  }
-
   // Flat attrs and inputs as required by the record_gradient call. The attrs
   // here only contain inferred attrs (non-inferred attrs are added directly
   // from the input args).