[tf.data] Use thread-local start time instead of thread-id map in model::Node.

This change relies on the invariant that `Node::record_stop()` always being called before `Node::record_start()`.

PiperOrigin-RevId: 307689582
Change-Id: I52a48b8e52646eee3611ecea1789d91a90786978
This commit is contained in:
Derek Murray 2020-04-21 15:10:54 -07:00 committed by TensorFlower Gardener
parent ec3f58bc3c
commit 75940041fc
3 changed files with 26 additions and 10 deletions

View File

@ -636,6 +636,8 @@ class Unknown : public Node {
} // namespace
thread_local int64 Node::work_start_;
std::shared_ptr<Parameter> MakeParameter(const string& name,
std::shared_ptr<SharedState> state,
double min, double max) {

View File

@ -236,18 +236,15 @@ class Node {
// Records that a node thread has started executing.
void record_start(int64 time_nanos) TF_LOCKS_EXCLUDED(mu_) {
mutex_lock l(mu_);
work_start_[std::this_thread::get_id()] = time_nanos;
DCHECK_EQ(work_start_, 0);
work_start_ = time_nanos;
}
// Records that a node thread has stopped executing.
void record_stop(int64 time_nanos) TF_LOCKS_EXCLUDED(mu_) {
mutex_lock l(mu_);
std::thread::id tid = std::this_thread::get_id();
auto iter = work_start_.find(tid);
if (iter != work_start_.end()) {
processing_time_ += time_nanos - iter->second;
work_start_.erase(iter);
if (work_start_ != 0) {
processing_time_ += time_nanos - work_start_;
work_start_ = 0;
} else {
VLOG(1) << "Encountered a stop event without a matching start event.";
}
@ -436,6 +433,17 @@ class Node {
void TotalMaximumBufferedBytesHelper(
absl::flat_hash_map<string, double>* total_bytes) const;
// Stores the time passed to the last call to `Node::record_start()` on the
// current thread.
//
// NOTE: This thread-local variable is shared between all instances of `Node`
// on which the same thread calls `record_start()` or `record_stop()`. It
// relies on the invariant that at most one `Node` can be "active" on a
// particular thread at any time. Therefore if `n->record_start()` is called
// on thread `t`, then `n->record_stop()` must be called before another call
// to `Node::record_start()` (for any node).
static thread_local int64 work_start_; // Will be initialized to zero.
mutable mutex mu_;
const int64 id_;
const string name_;
@ -454,7 +462,6 @@ class Node {
Metrics metrics_;
absl::flat_hash_map<string, std::shared_ptr<Parameter>> parameters_
TF_GUARDED_BY(mu_);
absl::flat_hash_map<std::thread::id, int64> work_start_ TF_GUARDED_BY(mu_);
// Statistic of inputs processing time history.
double input_processing_time_sum_ = 0.0L;

View File

@ -747,7 +747,14 @@ class CacheDatasetOp::MemoryDatasetBase : public DatasetBase {
std::vector<Tensor>* out_tensors,
bool* end_of_sequence) override {
mutex_lock l(mu_);
return iterator_->GetNext(ctx, out_tensors, end_of_sequence);
// TODO(b/154341936): Explicitly stopping and starting this iterator
// should not be necessary, but the `kImpl` added to the prefix passed
// to `iterator_` when it was created prevents the model from identifying
// this iterator as the output of `iterator_`.
RecordStop(ctx);
Status s = iterator_->GetNext(ctx, out_tensors, end_of_sequence);
RecordStart(ctx);
return s;
}
protected: