From 75940041fc857971b440c35da52937589e229737 Mon Sep 17 00:00:00 2001 From: Derek Murray Date: Tue, 21 Apr 2020 15:10:54 -0700 Subject: [PATCH] [tf.data] Use thread-local start time instead of thread-id map in `model::Node`. This change relies on the invariant that `Node::record_stop()` is always called before `Node::record_start()`. PiperOrigin-RevId: 307689582 Change-Id: I52a48b8e52646eee3611ecea1789d91a90786978 --- tensorflow/core/framework/model.cc | 2 ++ tensorflow/core/framework/model.h | 25 ++++++++++++------- .../core/kernels/data/cache_dataset_ops.cc | 9 ++++++- 3 files changed, 26 insertions(+), 10 deletions(-) diff --git a/tensorflow/core/framework/model.cc b/tensorflow/core/framework/model.cc index 85eeeb60a6b..322e30179c5 100644 --- a/tensorflow/core/framework/model.cc +++ b/tensorflow/core/framework/model.cc @@ -636,6 +636,8 @@ class Unknown : public Node { } // namespace +thread_local int64 Node::work_start_; + std::shared_ptr<Parameter> MakeParameter(const string& name, std::shared_ptr<SharedState> state, double min, double max) { diff --git a/tensorflow/core/framework/model.h b/tensorflow/core/framework/model.h index c5d477e2136..1cea4d50d34 100644 --- a/tensorflow/core/framework/model.h +++ b/tensorflow/core/framework/model.h @@ -236,18 +236,15 @@ class Node { // Records that a node thread has started executing. void record_start(int64 time_nanos) TF_LOCKS_EXCLUDED(mu_) { - mutex_lock l(mu_); - work_start_[std::this_thread::get_id()] = time_nanos; + DCHECK_EQ(work_start_, 0); + work_start_ = time_nanos; } // Records that a node thread has stopped executing. 
void record_stop(int64 time_nanos) TF_LOCKS_EXCLUDED(mu_) { - mutex_lock l(mu_); - std::thread::id tid = std::this_thread::get_id(); - auto iter = work_start_.find(tid); - if (iter != work_start_.end()) { - processing_time_ += time_nanos - iter->second; - work_start_.erase(iter); + if (work_start_ != 0) { + processing_time_ += time_nanos - work_start_; + work_start_ = 0; } else { VLOG(1) << "Encountered a stop event without a matching start event."; } @@ -436,6 +433,17 @@ class Node { void TotalMaximumBufferedBytesHelper( absl::flat_hash_map<string, double>* total_bytes) const; + // Stores the time passed to the last call to `Node::record_start()` on the + // current thread. + // + // NOTE: This thread-local variable is shared between all instances of `Node` + // on which the same thread calls `record_start()` or `record_stop()`. It + // relies on the invariant that at most one `Node` can be "active" on a + // particular thread at any time. Therefore if `n->record_start()` is called + // on thread `t`, then `n->record_stop()` must be called before another call + // to `Node::record_start()` (for any node). + static thread_local int64 work_start_; // Will be initialized to zero. + mutable mutex mu_; const int64 id_; const string name_; @@ -454,7 +462,6 @@ class Node { Metrics metrics_; absl::flat_hash_map<string, std::shared_ptr<Parameter>> parameters_ TF_GUARDED_BY(mu_); - absl::flat_hash_map<std::thread::id, int64> work_start_ TF_GUARDED_BY(mu_); // Statistic of inputs processing time history. 
double input_processing_time_sum_ = 0.0L; diff --git a/tensorflow/core/kernels/data/cache_dataset_ops.cc b/tensorflow/core/kernels/data/cache_dataset_ops.cc index 556b859c781..5f7dedc8a36 100644 --- a/tensorflow/core/kernels/data/cache_dataset_ops.cc +++ b/tensorflow/core/kernels/data/cache_dataset_ops.cc @@ -747,7 +747,14 @@ class CacheDatasetOp::MemoryDatasetBase : public DatasetBase { std::vector<Tensor>* out_tensors, bool* end_of_sequence) override { mutex_lock l(mu_); - return iterator_->GetNext(ctx, out_tensors, end_of_sequence); + // TODO(b/154341936): Explicitly stopping and starting this iterator + // should not be necessary, but the `kImpl` added to the prefix passed + // to `iterator_` when it was created prevents the model from identifying + // this iterator as the output of `iterator_`. + RecordStop(ctx); + Status s = iterator_->GetNext(ctx, out_tensors, end_of_sequence); + RecordStart(ctx); + return s; } protected: