[tf.data] Use thread-local start time instead of thread-id map in model::Node
.
This change relies on the invariant that `Node::record_stop()` always being called before `Node::record_start()`. PiperOrigin-RevId: 307689582 Change-Id: I52a48b8e52646eee3611ecea1789d91a90786978
This commit is contained in:
parent
ec3f58bc3c
commit
75940041fc
@ -636,6 +636,8 @@ class Unknown : public Node {
|
||||
|
||||
} // namespace
|
||||
|
||||
thread_local int64 Node::work_start_;
|
||||
|
||||
std::shared_ptr<Parameter> MakeParameter(const string& name,
|
||||
std::shared_ptr<SharedState> state,
|
||||
double min, double max) {
|
||||
|
@ -236,18 +236,15 @@ class Node {
|
||||
|
||||
// Records that a node thread has started executing.
|
||||
void record_start(int64 time_nanos) TF_LOCKS_EXCLUDED(mu_) {
|
||||
mutex_lock l(mu_);
|
||||
work_start_[std::this_thread::get_id()] = time_nanos;
|
||||
DCHECK_EQ(work_start_, 0);
|
||||
work_start_ = time_nanos;
|
||||
}
|
||||
|
||||
// Records that a node thread has stopped executing.
|
||||
void record_stop(int64 time_nanos) TF_LOCKS_EXCLUDED(mu_) {
|
||||
mutex_lock l(mu_);
|
||||
std::thread::id tid = std::this_thread::get_id();
|
||||
auto iter = work_start_.find(tid);
|
||||
if (iter != work_start_.end()) {
|
||||
processing_time_ += time_nanos - iter->second;
|
||||
work_start_.erase(iter);
|
||||
if (work_start_ != 0) {
|
||||
processing_time_ += time_nanos - work_start_;
|
||||
work_start_ = 0;
|
||||
} else {
|
||||
VLOG(1) << "Encountered a stop event without a matching start event.";
|
||||
}
|
||||
@ -436,6 +433,17 @@ class Node {
|
||||
void TotalMaximumBufferedBytesHelper(
|
||||
absl::flat_hash_map<string, double>* total_bytes) const;
|
||||
|
||||
// Stores the time passed to the last call to `Node::record_start()` on the
|
||||
// current thread.
|
||||
//
|
||||
// NOTE: This thread-local variable is shared between all instances of `Node`
|
||||
// on which the same thread calls `record_start()` or `record_stop()`. It
|
||||
// relies on the invariant that at most one `Node` can be "active" on a
|
||||
// particular thread at any time. Therefore if `n->record_start()` is called
|
||||
// on thread `t`, then `n->record_stop()` must be called before another call
|
||||
// to `Node::record_start()` (for any node).
|
||||
static thread_local int64 work_start_; // Will be initialized to zero.
|
||||
|
||||
mutable mutex mu_;
|
||||
const int64 id_;
|
||||
const string name_;
|
||||
@ -454,7 +462,6 @@ class Node {
|
||||
Metrics metrics_;
|
||||
absl::flat_hash_map<string, std::shared_ptr<Parameter>> parameters_
|
||||
TF_GUARDED_BY(mu_);
|
||||
absl::flat_hash_map<std::thread::id, int64> work_start_ TF_GUARDED_BY(mu_);
|
||||
|
||||
// Statistic of inputs processing time history.
|
||||
double input_processing_time_sum_ = 0.0L;
|
||||
|
@ -747,7 +747,14 @@ class CacheDatasetOp::MemoryDatasetBase : public DatasetBase {
|
||||
std::vector<Tensor>* out_tensors,
|
||||
bool* end_of_sequence) override {
|
||||
mutex_lock l(mu_);
|
||||
return iterator_->GetNext(ctx, out_tensors, end_of_sequence);
|
||||
// TODO(b/154341936): Explicitly stopping and starting this iterator
|
||||
// should not be necessary, but the `kImpl` added to the prefix passed
|
||||
// to `iterator_` when it was created prevents the model from identifying
|
||||
// this iterator as the output of `iterator_`.
|
||||
RecordStop(ctx);
|
||||
Status s = iterator_->GetNext(ctx, out_tensors, end_of_sequence);
|
||||
RecordStart(ctx);
|
||||
return s;
|
||||
}
|
||||
|
||||
protected:
|
||||
|
Loading…
Reference in New Issue
Block a user