From 3c3f6c03efbcd8d69b02025bb6771df13c37f038 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" <gardener@tensorflow.org>
Date: Fri, 10 Jan 2020 17:41:28 -0800
Subject: [PATCH] Add executor fast paths for NoOp and Const kernels
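
The executor currently dispatches every node through Device::Compute(),
including kernels that do no work (NoOp) and kernels whose output is a
statically known tensor (Const). This change:

* adds a virtual OpKernel::const_tensor() accessor, overridden by
  ConstantOp and _HostConstantOp to expose their cached tensor;
* caches const_tensor() and an is_noop bit in NodeItem when the
  executor is initialized;
* short-circuits ExecutorState::Process(): NoOp kernels skip Compute()
  entirely, and Const kernels copy the cached tensor straight into the
  output slot (unless the context is tracking allocations).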

PiperOrigin-RevId: 289196857
Change-Id: Ie94a93fd536bf3ee532b7d99b916ce4ba614e924
---
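
The caching idea in miniature, as a standalone sketch with stand-in
Tensor and Kernel types (not the real TensorFlow classes): the executor
queries const_tensor() once per node at initialization, then each
execution copies the cached value instead of dispatching Compute().

  #include <iostream>
  #include <memory>
  #include <string>

  struct Tensor {  // stand-in for tensorflow::Tensor
    std::string payload;
  };

  struct Kernel {  // stand-in for tensorflow::OpKernel
    virtual ~Kernel() = default;
    // Default: no constant result, hence no fast path.
    virtual const Tensor* const_tensor() const { return nullptr; }
    virtual void Compute(Tensor* out) = 0;
  };

  struct ConstKernel : Kernel {
    explicit ConstKernel(std::string v) : tensor_{std::move(v)} {}
    const Tensor* const_tensor() const override { return &tensor_; }
    void Compute(Tensor* out) override { *out = tensor_; }
    Tensor tensor_;
  };

  int main() {
    std::unique_ptr<Kernel> kernel = std::make_unique<ConstKernel>("42");
    // Initialization time: query and cache the pointer once per node.
    const Tensor* cached = kernel->const_tensor();

    Tensor output;
    if (cached != nullptr) {
      output = *cached;          // fast path: plain copy, no dispatch
    } else {
      kernel->Compute(&output);  // general path: run the kernel
    }
    std::cout << output.payload << "\n";
  }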
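
And a sketch of how a constant-like kernel opts into the fast path,
mirroring the ConstantOp change below; MyImmutableOp is a hypothetical
name, and const_tensor() is the only hook this patch actually adds:

  #include "tensorflow/core/framework/op_kernel.h"
  #include "tensorflow/core/framework/tensor.h"

  namespace tensorflow {

  class MyImmutableOp : public OpKernel {
   public:
    explicit MyImmutableOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
      // A real kernel would materialize tensor_ here, e.g. from an attr.
    }
    // Compute() still runs when the fast path is skipped, e.g. when the
    // context is tracking allocations.
    void Compute(OpKernelContext* ctx) override { ctx->set_output(0, tensor_); }
    bool IsExpensive() override { return false; }
    // New hook from this patch: lets the executor copy tensor_ directly.
    const Tensor* const_tensor() const override { return &tensor_; }

   private:
    Tensor tensor_;
  };

  }  // namespace tensorflow
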
 tensorflow/core/common_runtime/executor.cc | 27 ++++++++++++++++++----
 tensorflow/core/framework/op_kernel.h      |  3 +++
 tensorflow/core/kernels/constant_op.h      |  1 +
 tensorflow/core/kernels/host_constant_op.h |  1 +
 4 files changed, 27 insertions(+), 5 deletions(-)

diff --git a/tensorflow/core/common_runtime/executor.cc b/tensorflow/core/common_runtime/executor.cc
index 57f3321850d..30c256d9895 100644
--- a/tensorflow/core/common_runtime/executor.cc
+++ b/tensorflow/core/common_runtime/executor.cc
@@ -174,10 +174,14 @@ struct NodeItem {
   bool is_initialization_op : 1;  // True iff IsInitializationOp(node)
   bool is_recv_or_switch : 1;     // True iff IsRecv(node) || IsSwitch(node)
   bool is_next_iteration : 1;     // True iff IsNextIteration(node)
+  bool is_noop : 1;  // True iff item->kernel->type_string_view() == "NoOp"
 
   // The kernel for this node.
   OpKernel* kernel = nullptr;
 
+  // If the kernel is a Const op, this points to the constant tensor.
+  const Tensor* const_tensor = nullptr;
+
   // Cached values of node->num_inputs() and node->num_outputs(), to
   // avoid levels of indirection.
   int num_inputs;
@@ -659,6 +663,8 @@ Status ExecutorImpl::Initialize(const Graph& graph) {
     CHECK(item->kernel);
     item->kernel_is_async = (item->kernel->AsAsync() != nullptr);
     item->is_merge = IsMerge(n);
+    item->const_tensor = item->kernel->const_tensor();
+    item->is_noop = (item->kernel->type_string_view() == "NoOp");
     item->is_enter = IsEnter(n);
     if (item->is_enter) {
       bool is_constant_enter;
@@ -695,7 +701,7 @@ Status ExecutorImpl::Initialize(const Graph& graph) {
 
     // Initialize static information about the frames in the graph.
     frame_info->nodes->push_back(item);
-    if (IsEnter(n)) {
+    if (item->is_enter) {
       string enter_name;
       TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "frame_name", &enter_name));
       EnsureFrameInfo(enter_name)->input_count++;
@@ -1878,7 +1884,9 @@ void ExecutorState::Process(TaggedNode tagged_node, int64 scheduled_nsec) {
         OpKernelContext ctx(&params, item.num_outputs);
         nodestats::SetOpStart(stats);
 
-        if (TF_PREDICT_FALSE(MightTrace(item, event_collector_))) {
+        if (TF_PREDICT_FALSE(item.is_noop)) {
+          nodestats::SetOpEnd(stats);
+        } else if (TF_PREDICT_FALSE(MightTrace(item, event_collector_))) {
           absl::string_view op_name = op_kernel->name_view();
           const string kernel_label =
               strings::StrCat(op_name, ":", op_kernel->type_string_view());
@@ -1892,6 +1900,16 @@ void ExecutorState::Process(TaggedNode tagged_node, int64 scheduled_nsec) {
           // 'ScopedAnnotation' will trace the OpKernel execution time.
           profiler::ScopedAnnotation annotation(kernel_label_view);
           device->Compute(op_kernel, &ctx);
+          nodestats::SetOpEnd(stats);
+          s = ProcessOutputs(item, &ctx, &outputs, stats);
+        } else if (item.const_tensor != nullptr && !ctx.track_allocations()) {
+          // Special case for ConstantOp, which is very common.
+          nodestats::SetOpEnd(stats);
+          outputs.resize(1);
+          outputs[0].has_value = true;
+          outputs[0].val_field_is_set = true;
+          outputs[0].alloc_attr = ctx.output_alloc_attr(0);
+          outputs[0].val.Init(*item.const_tensor);
         } else {
           // In the common case, avoid creating any tracing objects.
           if (op_kernel->IsExpensive()) {
@@ -1901,10 +1919,9 @@ void ExecutorState::Process(TaggedNode tagged_node, int64 scheduled_nsec) {
           } else {
             device->Compute(op_kernel, &ctx);
           }
+          nodestats::SetOpEnd(stats);
+          s = ProcessOutputs(item, &ctx, &outputs, stats);
         }
-
-        nodestats::SetOpEnd(stats);
-        s = ProcessOutputs(item, &ctx, &outputs, stats);
         if (s.ok() && impl_->device_record_tensor_accesses_) {
           // Get the list of all tensors accessed during the execution
           ctx.retrieve_accessed_tensors(&accessed_tensors);
diff --git a/tensorflow/core/framework/op_kernel.h b/tensorflow/core/framework/op_kernel.h
index ea82aff6442..82a3b8ab15d 100644
--- a/tensorflow/core/framework/op_kernel.h
+++ b/tensorflow/core/framework/op_kernel.h
@@ -152,6 +152,9 @@ class OpKernel {
                           kOpIsExpensiveThresholdCycles);
   }
 
+  // Returns a pointer to the tensor stored inside constant ops, or nullptr.
+  virtual const Tensor* const_tensor() const { return nullptr; }
+
   // Updates the dynamic cost estimate, which is used to determine whether this
   // op is expensive. The new cost estimate is a weighted average of the old
   // cost estimate and the latest cost.
diff --git a/tensorflow/core/kernels/constant_op.h b/tensorflow/core/kernels/constant_op.h
index 77ba4418637..34f7036adf2 100644
--- a/tensorflow/core/kernels/constant_op.h
+++ b/tensorflow/core/kernels/constant_op.h
@@ -29,6 +29,7 @@ class ConstantOp : public OpKernel {
   explicit ConstantOp(OpKernelConstruction* ctx);
   void Compute(OpKernelContext* ctx) override;
   bool IsExpensive() override { return false; }
+  const Tensor* const_tensor() const override { return &tensor_; }
   ~ConstantOp() override;
 
  private:
diff --git a/tensorflow/core/kernels/host_constant_op.h b/tensorflow/core/kernels/host_constant_op.h
index 1b887ea1aab..d06c6d37fe0 100644
--- a/tensorflow/core/kernels/host_constant_op.h
+++ b/tensorflow/core/kernels/host_constant_op.h
@@ -30,6 +30,7 @@ class _HostConstantOp : public OpKernel {
   explicit _HostConstantOp(OpKernelConstruction* ctx);
   void Compute(OpKernelContext* ctx) override;
   bool IsExpensive() override { return false; }
+  const Tensor* const_tensor() const override { return &tensor_; }
   ~_HostConstantOp() override {}
 
  private: