Add element tracing for tf.data.experimental.parallel_interleave.

PiperOrigin-RevId: 324696858
Change-Id: I099b9b8935a38e263bd24f008e123c0623432e40
This commit is contained in:
Jiho Choi 2020-08-03 15:31:34 -07:00 committed by TensorFlower Gardener
parent f292f31b57
commit f18d09553b
2 changed files with 21 additions and 1 deletions
tensorflow/core/kernels/data/experimental

View File

@ -394,6 +394,8 @@ tf_kernel_library(
"//tensorflow/core/kernels/data:captured_function",
"//tensorflow/core/kernels/data:dataset_utils",
"//tensorflow/core/kernels/data:name_utils",
"//tensorflow/core/profiler/lib:traceme",
"//tensorflow/core/profiler/lib:traceme_encode",
],
)

View File

@ -31,6 +31,8 @@ limitations under the License.
#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/platform/blocking_counter.h"
#include "tensorflow/core/platform/stringprintf.h"
#include "tensorflow/core/profiler/lib/traceme.h"
#include "tensorflow/core/profiler/lib/traceme_encode.h"
namespace tensorflow {
namespace data {
@ -323,6 +325,11 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
}
*end_of_sequence = false;
Status s = current_worker->outputs.front().status;
profiler::TraceMe traceme([&] {
return profiler::TraceMeEncode(
"ParallelInterleaveConsume",
{{"element_id", current_worker->outputs.front().id}});
});
current_worker->outputs.front().output.swap(*out_tensors);
current_worker->outputs.pop_front();
current_worker->cond_var.notify_one();
@ -564,8 +571,10 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
Status status;
// The buffered data element.
std::vector<Tensor> output;
int64 id = -1;
explicit OutputElem(const Status& s) : status(s) {}
OutputElem(const Status& s, int64 id) : status(s), id(id) {}
};
// Worker threads operate on their relevant WorkerState structs.
@ -813,6 +822,14 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
worker_thread_states_[thread_index]
.output_elem.output.empty() &&
!worker_thread_states_[thread_index].end_of_sequence) {
int64& id = worker_thread_states_[thread_index].output_elem.id;
profiler::TraceMe traceme(
[&] {
id = profiler::TraceMe::NewActivityId();
return profiler::TraceMeEncode(
"ParallelInterleaveProduce", {{"element_id", id}});
},
profiler::kInfo);
worker_thread_states_[thread_index].output_elem.status =
worker_thread_states_[thread_index].iterator->GetNext(
ctx.get(),
@ -856,7 +873,8 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
worker_thread_states_[thread_index].end_of_sequence = false;
} else {
workers_[thread_index].outputs.emplace_back(
worker_thread_states_[thread_index].output_elem.status);
worker_thread_states_[thread_index].output_elem.status,
worker_thread_states_[thread_index].output_elem.id);
workers_[thread_index].outputs.back().output.swap(
worker_thread_states_[thread_index].output_elem.output);
}