diff --git a/tensorflow/core/profiler/utils/event_span.cc b/tensorflow/core/profiler/utils/event_span.cc index 8c31c55da8c..e6e8fd21406 100644 --- a/tensorflow/core/profiler/utils/event_span.cc +++ b/tensorflow/core/profiler/utils/event_span.cc @@ -116,17 +116,17 @@ EventType ClassifyGpuEvent(absl::string_view event_name) { } EventType ClassifyCpuEvent(absl::string_view event_name, int64 correlation_id) { - if (absl::StartsWithIgnoreCase(event_name, "MEMCPYHtoD")) + if (absl::StartsWithIgnoreCase(event_name, "MEMCPYHtoD") || + absl::StrContains(event_name, "Infeed")) return HOST_TO_DEVICE; if (absl::StartsWithIgnoreCase(event_name, "MEMCPYHtoH")) return HOST_TO_HOST; if (correlation_id >= 0 || absl::StartsWithIgnoreCase(event_name, "ExecutorState::Process")) { return HOST_PREPARE; - } else { - if (absl::StartsWithIgnoreCase(event_name, "IteratorGetNext")) - return HOST_WAIT_INPUT; - return HOST_COMPUTE; } + if (absl::StartsWithIgnoreCase(event_name, "IteratorGetNext")) + return HOST_WAIT_INPUT; + return HOST_COMPUTE; } std::string PrintEventType(EventType event_type) {