Internal change
PiperOrigin-RevId: 289196857 Change-Id: Ie94a93fd536bf3ee532b7d99b916ce4ba614e924
This commit is contained in:
parent
1db7d4b246
commit
3c3f6c03ef
@ -174,10 +174,14 @@ struct NodeItem {
|
||||
bool is_initialization_op : 1; // True iff IsInitializationOp(node)
|
||||
bool is_recv_or_switch : 1; // True iff IsRecv(node) || IsSwitch(node)
|
||||
bool is_next_iteration : 1; // True iff IsNextIteration(node)
|
||||
bool is_noop : 1; // True iff item->kernel->type_string_view() == "NoOp")
|
||||
|
||||
// The kernel for this node.
|
||||
OpKernel* kernel = nullptr;
|
||||
|
||||
// If the kernel is a Const op, this containts points to the constant tensor.
|
||||
const Tensor* const_tensor = nullptr;
|
||||
|
||||
// Cached values of node->num_inputs() and node->num_outputs(), to
|
||||
// avoid levels of indirection.
|
||||
int num_inputs;
|
||||
@ -659,6 +663,8 @@ Status ExecutorImpl::Initialize(const Graph& graph) {
|
||||
CHECK(item->kernel);
|
||||
item->kernel_is_async = (item->kernel->AsAsync() != nullptr);
|
||||
item->is_merge = IsMerge(n);
|
||||
item->const_tensor = item->kernel->const_tensor();
|
||||
item->is_noop = (item->kernel->type_string_view() == "NoOp");
|
||||
item->is_enter = IsEnter(n);
|
||||
if (item->is_enter) {
|
||||
bool is_constant_enter;
|
||||
@ -695,7 +701,7 @@ Status ExecutorImpl::Initialize(const Graph& graph) {
|
||||
|
||||
// Initialize static information about the frames in the graph.
|
||||
frame_info->nodes->push_back(item);
|
||||
if (IsEnter(n)) {
|
||||
if (item->is_enter) {
|
||||
string enter_name;
|
||||
TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "frame_name", &enter_name));
|
||||
EnsureFrameInfo(enter_name)->input_count++;
|
||||
@ -1878,7 +1884,9 @@ void ExecutorState::Process(TaggedNode tagged_node, int64 scheduled_nsec) {
|
||||
OpKernelContext ctx(¶ms, item.num_outputs);
|
||||
nodestats::SetOpStart(stats);
|
||||
|
||||
if (TF_PREDICT_FALSE(MightTrace(item, event_collector_))) {
|
||||
if (TF_PREDICT_FALSE(item.is_noop)) {
|
||||
nodestats::SetOpEnd(stats);
|
||||
} else if (TF_PREDICT_FALSE(MightTrace(item, event_collector_))) {
|
||||
absl::string_view op_name = op_kernel->name_view();
|
||||
const string kernel_label =
|
||||
strings::StrCat(op_name, ":", op_kernel->type_string_view());
|
||||
@ -1892,6 +1900,16 @@ void ExecutorState::Process(TaggedNode tagged_node, int64 scheduled_nsec) {
|
||||
// 'ScopedAnnotation' will trace the OpKernel execution time.
|
||||
profiler::ScopedAnnotation annotation(kernel_label_view);
|
||||
device->Compute(op_kernel, &ctx);
|
||||
nodestats::SetOpEnd(stats);
|
||||
s = ProcessOutputs(item, &ctx, &outputs, stats);
|
||||
} else if (item.const_tensor != nullptr && !ctx.track_allocations()) {
|
||||
// Special case for ConstantOp, which is very common.
|
||||
nodestats::SetOpEnd(stats);
|
||||
outputs.resize(1);
|
||||
outputs[0].has_value = true;
|
||||
outputs[0].val_field_is_set = true;
|
||||
outputs[0].alloc_attr = ctx.output_alloc_attr(0);
|
||||
outputs[0].val.Init(*item.const_tensor);
|
||||
} else {
|
||||
// In the common case, avoid creating any tracing objects.
|
||||
if (op_kernel->IsExpensive()) {
|
||||
@ -1901,10 +1919,9 @@ void ExecutorState::Process(TaggedNode tagged_node, int64 scheduled_nsec) {
|
||||
} else {
|
||||
device->Compute(op_kernel, &ctx);
|
||||
}
|
||||
nodestats::SetOpEnd(stats);
|
||||
s = ProcessOutputs(item, &ctx, &outputs, stats);
|
||||
}
|
||||
|
||||
nodestats::SetOpEnd(stats);
|
||||
s = ProcessOutputs(item, &ctx, &outputs, stats);
|
||||
if (s.ok() && impl_->device_record_tensor_accesses_) {
|
||||
// Get the list of all tensors accessed during the execution
|
||||
ctx.retrieve_accessed_tensors(&accessed_tensors);
|
||||
|
@ -152,6 +152,9 @@ class OpKernel {
|
||||
kOpIsExpensiveThresholdCycles);
|
||||
}
|
||||
|
||||
// Returns a pointer to the tensor stored inside constant ops.
|
||||
virtual const Tensor* const_tensor() const { return nullptr; }
|
||||
|
||||
// Updates the dynamic cost estimate, which is used to determine whether this
|
||||
// op is expensive. The new cost estimate is a weighted average of the old
|
||||
// cost estimate and the latest cost.
|
||||
|
@ -29,6 +29,7 @@ class ConstantOp : public OpKernel {
|
||||
explicit ConstantOp(OpKernelConstruction* ctx);
|
||||
void Compute(OpKernelContext* ctx) override;
|
||||
bool IsExpensive() override { return false; }
|
||||
const Tensor* const_tensor() const override { return &tensor_; };
|
||||
~ConstantOp() override;
|
||||
|
||||
private:
|
||||
|
@ -30,6 +30,7 @@ class _HostConstantOp : public OpKernel {
|
||||
explicit _HostConstantOp(OpKernelConstruction* ctx);
|
||||
void Compute(OpKernelContext* ctx) override;
|
||||
bool IsExpensive() override { return false; }
|
||||
const Tensor* const_tensor() const override { return &tensor_; };
|
||||
~_HostConstantOp() override {}
|
||||
|
||||
private:
|
||||
|
Loading…
Reference in New Issue
Block a user