From 3c3f6c03efbcd8d69b02025bb6771df13c37f038 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" <gardener@tensorflow.org> Date: Fri, 10 Jan 2020 17:41:28 -0800 Subject: [PATCH] Internal change PiperOrigin-RevId: 289196857 Change-Id: Ie94a93fd536bf3ee532b7d99b916ce4ba614e924 --- tensorflow/core/common_runtime/executor.cc | 27 ++++++++++++++++++---- tensorflow/core/framework/op_kernel.h | 3 +++ tensorflow/core/kernels/constant_op.h | 1 + tensorflow/core/kernels/host_constant_op.h | 1 + 4 files changed, 27 insertions(+), 5 deletions(-) diff --git a/tensorflow/core/common_runtime/executor.cc b/tensorflow/core/common_runtime/executor.cc index 57f3321850d..30c256d9895 100644 --- a/tensorflow/core/common_runtime/executor.cc +++ b/tensorflow/core/common_runtime/executor.cc @@ -174,10 +174,14 @@ struct NodeItem { bool is_initialization_op : 1; // True iff IsInitializationOp(node) bool is_recv_or_switch : 1; // True iff IsRecv(node) || IsSwitch(node) bool is_next_iteration : 1; // True iff IsNextIteration(node) + bool is_noop : 1; // True iff item->kernel->type_string_view() == "NoOp" // The kernel for this node. OpKernel* kernel = nullptr; + // If the kernel is a Const op, this contains a pointer to the constant tensor. + const Tensor* const_tensor = nullptr; + // Cached values of node->num_inputs() and node->num_outputs(), to // avoid levels of indirection. int num_inputs; @@ -659,6 +663,8 @@ Status ExecutorImpl::Initialize(const Graph& graph) { CHECK(item->kernel); item->kernel_is_async = (item->kernel->AsAsync() != nullptr); item->is_merge = IsMerge(n); + item->const_tensor = item->kernel->const_tensor(); + item->is_noop = (item->kernel->type_string_view() == "NoOp"); item->is_enter = IsEnter(n); if (item->is_enter) { bool is_constant_enter; @@ -695,7 +701,7 @@ Status ExecutorImpl::Initialize(const Graph& graph) { // Initialize static information about the frames in the graph. 
frame_info->nodes->push_back(item); - if (IsEnter(n)) { + if (item->is_enter) { string enter_name; TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "frame_name", &enter_name)); EnsureFrameInfo(enter_name)->input_count++; @@ -1878,7 +1884,9 @@ void ExecutorState::Process(TaggedNode tagged_node, int64 scheduled_nsec) { OpKernelContext ctx(&params, item.num_outputs); nodestats::SetOpStart(stats); - if (TF_PREDICT_FALSE(MightTrace(item, event_collector_))) { + if (TF_PREDICT_FALSE(item.is_noop)) { + nodestats::SetOpEnd(stats); + } else if (TF_PREDICT_FALSE(MightTrace(item, event_collector_))) { absl::string_view op_name = op_kernel->name_view(); const string kernel_label = strings::StrCat(op_name, ":", op_kernel->type_string_view()); @@ -1892,6 +1900,16 @@ void ExecutorState::Process(TaggedNode tagged_node, int64 scheduled_nsec) { // 'ScopedAnnotation' will trace the OpKernel execution time. profiler::ScopedAnnotation annotation(kernel_label_view); device->Compute(op_kernel, &ctx); + nodestats::SetOpEnd(stats); + s = ProcessOutputs(item, &ctx, &outputs, stats); + } else if (item.const_tensor != nullptr && !ctx.track_allocations()) { + // Special case for ConstantOp, which is very common. + nodestats::SetOpEnd(stats); + outputs.resize(1); + outputs[0].has_value = true; + outputs[0].val_field_is_set = true; + outputs[0].alloc_attr = ctx.output_alloc_attr(0); + outputs[0].val.Init(*item.const_tensor); } else { // In the common case, avoid creating any tracing objects. 
if (op_kernel->IsExpensive()) { @@ -1901,10 +1919,9 @@ void ExecutorState::Process(TaggedNode tagged_node, int64 scheduled_nsec) { } else { device->Compute(op_kernel, &ctx); } + nodestats::SetOpEnd(stats); + s = ProcessOutputs(item, &ctx, &outputs, stats); } - - nodestats::SetOpEnd(stats); - s = ProcessOutputs(item, &ctx, &outputs, stats); if (s.ok() && impl_->device_record_tensor_accesses_) { // Get the list of all tensors accessed during the execution ctx.retrieve_accessed_tensors(&accessed_tensors); diff --git a/tensorflow/core/framework/op_kernel.h b/tensorflow/core/framework/op_kernel.h index ea82aff6442..82a3b8ab15d 100644 --- a/tensorflow/core/framework/op_kernel.h +++ b/tensorflow/core/framework/op_kernel.h @@ -152,6 +152,9 @@ class OpKernel { kOpIsExpensiveThresholdCycles); } + // Returns a pointer to the tensor stored inside constant ops. + virtual const Tensor* const_tensor() const { return nullptr; } + // Updates the dynamic cost estimate, which is used to determine whether this // op is expensive. The new cost estimate is a weighted average of the old // cost estimate and the latest cost. 
diff --git a/tensorflow/core/kernels/constant_op.h b/tensorflow/core/kernels/constant_op.h index 77ba4418637..34f7036adf2 100644 --- a/tensorflow/core/kernels/constant_op.h +++ b/tensorflow/core/kernels/constant_op.h @@ -29,6 +29,7 @@ class ConstantOp : public OpKernel { explicit ConstantOp(OpKernelConstruction* ctx); void Compute(OpKernelContext* ctx) override; bool IsExpensive() override { return false; } + const Tensor* const_tensor() const override { return &tensor_; } ~ConstantOp() override; private: diff --git a/tensorflow/core/kernels/host_constant_op.h b/tensorflow/core/kernels/host_constant_op.h index 1b887ea1aab..d06c6d37fe0 100644 --- a/tensorflow/core/kernels/host_constant_op.h +++ b/tensorflow/core/kernels/host_constant_op.h @@ -30,6 +30,7 @@ class _HostConstantOp : public OpKernel { explicit _HostConstantOp(OpKernelConstruction* ctx); void Compute(OpKernelContext* ctx) override; bool IsExpensive() override { return false; } + const Tensor* const_tensor() const override { return &tensor_; } ~_HostConstantOp() override {} private: