Internal change

PiperOrigin-RevId: 289196857
Change-Id: Ie94a93fd536bf3ee532b7d99b916ce4ba614e924
This commit is contained in:
A. Unique TensorFlower 2020-01-10 17:41:28 -08:00 committed by TensorFlower Gardener
parent 1db7d4b246
commit 3c3f6c03ef
4 changed files with 27 additions and 5 deletions

View File

@ -174,10 +174,14 @@ struct NodeItem {
bool is_initialization_op : 1; // True iff IsInitializationOp(node)
bool is_recv_or_switch : 1; // True iff IsRecv(node) || IsSwitch(node)
bool is_next_iteration : 1; // True iff IsNextIteration(node)
bool is_noop : 1; // True iff item->kernel->type_string_view() == "NoOp")
// The kernel for this node.
OpKernel* kernel = nullptr;
// If the kernel is a Const op, this containts points to the constant tensor.
const Tensor* const_tensor = nullptr;
// Cached values of node->num_inputs() and node->num_outputs(), to
// avoid levels of indirection.
int num_inputs;
@ -659,6 +663,8 @@ Status ExecutorImpl::Initialize(const Graph& graph) {
CHECK(item->kernel);
item->kernel_is_async = (item->kernel->AsAsync() != nullptr);
item->is_merge = IsMerge(n);
item->const_tensor = item->kernel->const_tensor();
item->is_noop = (item->kernel->type_string_view() == "NoOp");
item->is_enter = IsEnter(n);
if (item->is_enter) {
bool is_constant_enter;
@ -695,7 +701,7 @@ Status ExecutorImpl::Initialize(const Graph& graph) {
// Initialize static information about the frames in the graph.
frame_info->nodes->push_back(item);
if (IsEnter(n)) {
if (item->is_enter) {
string enter_name;
TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "frame_name", &enter_name));
EnsureFrameInfo(enter_name)->input_count++;
@ -1878,7 +1884,9 @@ void ExecutorState::Process(TaggedNode tagged_node, int64 scheduled_nsec) {
OpKernelContext ctx(&params, item.num_outputs);
nodestats::SetOpStart(stats);
if (TF_PREDICT_FALSE(MightTrace(item, event_collector_))) {
if (TF_PREDICT_FALSE(item.is_noop)) {
nodestats::SetOpEnd(stats);
} else if (TF_PREDICT_FALSE(MightTrace(item, event_collector_))) {
absl::string_view op_name = op_kernel->name_view();
const string kernel_label =
strings::StrCat(op_name, ":", op_kernel->type_string_view());
@ -1892,6 +1900,16 @@ void ExecutorState::Process(TaggedNode tagged_node, int64 scheduled_nsec) {
// 'ScopedAnnotation' will trace the OpKernel execution time.
profiler::ScopedAnnotation annotation(kernel_label_view);
device->Compute(op_kernel, &ctx);
nodestats::SetOpEnd(stats);
s = ProcessOutputs(item, &ctx, &outputs, stats);
} else if (item.const_tensor != nullptr && !ctx.track_allocations()) {
// Special case for ConstantOp, which is very common.
nodestats::SetOpEnd(stats);
outputs.resize(1);
outputs[0].has_value = true;
outputs[0].val_field_is_set = true;
outputs[0].alloc_attr = ctx.output_alloc_attr(0);
outputs[0].val.Init(*item.const_tensor);
} else {
// In the common case, avoid creating any tracing objects.
if (op_kernel->IsExpensive()) {
@ -1901,10 +1919,9 @@ void ExecutorState::Process(TaggedNode tagged_node, int64 scheduled_nsec) {
} else {
device->Compute(op_kernel, &ctx);
}
nodestats::SetOpEnd(stats);
s = ProcessOutputs(item, &ctx, &outputs, stats);
}
nodestats::SetOpEnd(stats);
s = ProcessOutputs(item, &ctx, &outputs, stats);
if (s.ok() && impl_->device_record_tensor_accesses_) {
// Get the list of all tensors accessed during the execution
ctx.retrieve_accessed_tensors(&accessed_tensors);

View File

@ -152,6 +152,9 @@ class OpKernel {
kOpIsExpensiveThresholdCycles);
}
// Returns a pointer to the tensor stored inside constant ops.
virtual const Tensor* const_tensor() const { return nullptr; }
// Updates the dynamic cost estimate, which is used to determine whether this
// op is expensive. The new cost estimate is a weighted average of the old
// cost estimate and the latest cost.

View File

@ -29,6 +29,7 @@ class ConstantOp : public OpKernel {
explicit ConstantOp(OpKernelConstruction* ctx);
void Compute(OpKernelContext* ctx) override;
bool IsExpensive() override { return false; }
const Tensor* const_tensor() const override { return &tensor_; };
~ConstantOp() override;
private:

View File

@ -30,6 +30,7 @@ class _HostConstantOp : public OpKernel {
explicit _HostConstantOp(OpKernelConstruction* ctx);
void Compute(OpKernelContext* ctx) override;
bool IsExpensive() override { return false; }
const Tensor* const_tensor() const override { return &tensor_; };
~_HostConstantOp() override {}
private: