[tf.data] Use thread-local start time instead of thread-id map in model::Node
This change relies on the invariant that `Node::record_stop()` is always called before the next call to `Node::record_start()` on the same thread.

PiperOrigin-RevId: 307689582
Change-Id: I52a48b8e52646eee3611ecea1789d91a90786978
commit 75940041fc
parent ec3f58bc3c
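For orientation, the bookkeeping being removed is a map keyed by `std::thread::id`, which forces a mutex acquisition and a hash lookup on every start/stop event. Below is a minimal standalone sketch of that pattern; the `MapTracker` name and simplified members are illustrative, not from the TensorFlow sources:

```cpp
#include <cstdint>
#include <mutex>
#include <thread>
#include <unordered_map>

// Simplified version of the pre-change bookkeeping: each start records the
// current time under the calling thread's id; each stop looks the entry up,
// accumulates the elapsed time, and erases it. Every call pays for a lock
// and a hash lookup.
class MapTracker {
 public:
  void RecordStart(int64_t now_nanos) {
    std::lock_guard<std::mutex> l(mu_);
    work_start_[std::this_thread::get_id()] = now_nanos;
  }
  void RecordStop(int64_t now_nanos) {
    std::lock_guard<std::mutex> l(mu_);
    auto it = work_start_.find(std::this_thread::get_id());
    if (it != work_start_.end()) {
      processing_time_ += now_nanos - it->second;
      work_start_.erase(it);
    }
  }

 private:
  std::mutex mu_;
  std::unordered_map<std::thread::id, int64_t> work_start_;
  int64_t processing_time_ = 0;
};
```

The hunks below replace this map with a single `thread_local` scalar, which is correct only because at most one `Node` can be actively executing on a given thread at any time.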
@@ -636,6 +636,8 @@ class Unknown : public Node {
 
 }  // namespace
 
+thread_local int64 Node::work_start_;
+
 std::shared_ptr<Parameter> MakeParameter(const string& name,
                                          std::shared_ptr<SharedState> state,
                                          double min, double max) {
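As background for the definition added above: a `static thread_local` data member declared inside a class still needs exactly one out-of-class definition, and an integral `thread_local` is zero-initialized in each thread. A minimal standalone sketch, using a hypothetical `Counter` class rather than the TensorFlow code:

```cpp
#include <iostream>
#include <thread>

class Counter {
 public:
  static void Bump() { ++count_; }
  static int Get() { return count_; }

 private:
  // Declaration only; each thread gets its own zero-initialized copy.
  static thread_local int count_;
};

// Out-of-class definition, analogous to `thread_local int64 Node::work_start_;`.
thread_local int Counter::count_ = 0;

int main() {
  Counter::Bump();
  Counter::Bump();
  std::thread t([] {
    Counter::Bump();
    std::cout << "worker thread sees " << Counter::Get() << "\n";  // prints 1
  });
  t.join();
  std::cout << "main thread sees " << Counter::Get() << "\n";  // prints 2
  return 0;
}
```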
@@ -236,18 +236,15 @@ class Node {
 
   // Records that a node thread has started executing.
   void record_start(int64 time_nanos) TF_LOCKS_EXCLUDED(mu_) {
-    mutex_lock l(mu_);
-    work_start_[std::this_thread::get_id()] = time_nanos;
+    DCHECK_EQ(work_start_, 0);
+    work_start_ = time_nanos;
   }
 
   // Records that a node thread has stopped executing.
   void record_stop(int64 time_nanos) TF_LOCKS_EXCLUDED(mu_) {
-    mutex_lock l(mu_);
-    std::thread::id tid = std::this_thread::get_id();
-    auto iter = work_start_.find(tid);
-    if (iter != work_start_.end()) {
-      processing_time_ += time_nanos - iter->second;
-      work_start_.erase(iter);
+    if (work_start_ != 0) {
+      processing_time_ += time_nanos - work_start_;
+      work_start_ = 0;
     } else {
       VLOG(1) << "Encountered a stop event without a matching start event.";
     }
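A minimal standalone sketch of the new bookkeeping and its pairing invariant, using a hypothetical `Tracker` class in place of `model::Node` and plain `assert`/`std::cerr` in place of `DCHECK`/`VLOG`:

```cpp
#include <cassert>
#include <cstdint>
#include <iostream>

// Because the start time is a single thread-local slot shared by all
// trackers, a thread must stop the tracker it started before starting any
// other one.
class Tracker {
 public:
  void RecordStart(int64_t now) {
    assert(work_start_ == 0 && "start without a matching stop");
    work_start_ = now;
  }
  void RecordStop(int64_t now) {
    if (work_start_ != 0) {
      processing_time_ += now - work_start_;
      work_start_ = 0;
    } else {
      std::cerr << "stop event without a matching start event\n";
    }
  }
  int64_t processing_time() const { return processing_time_; }

 private:
  static thread_local int64_t work_start_;  // one slot per thread, shared by all trackers
  int64_t processing_time_ = 0;
};

thread_local int64_t Tracker::work_start_ = 0;

int main() {
  Tracker a, b;
  a.RecordStart(100);
  a.RecordStop(150);  // OK: 50ns attributed to `a`
  b.RecordStart(200);
  b.RecordStop(260);  // OK: 60ns attributed to `b`
  b.RecordStop(300);  // logs: stop event without a matching start event
  std::cout << a.processing_time() << " " << b.processing_time() << "\n";  // 50 60
  return 0;
}
```

Because all trackers share the one thread-local slot, two starts on the same thread without an intervening stop would silently misattribute time; the `DCHECK_EQ(work_start_, 0)` in the change above is there to catch exactly that.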
@@ -436,6 +433,17 @@ class Node {
   void TotalMaximumBufferedBytesHelper(
       absl::flat_hash_map<string, double>* total_bytes) const;
 
+  // Stores the time passed to the last call to `Node::record_start()` on the
+  // current thread.
+  //
+  // NOTE: This thread-local variable is shared between all instances of `Node`
+  // on which the same thread calls `record_start()` or `record_stop()`. It
+  // relies on the invariant that at most one `Node` can be "active" on a
+  // particular thread at any time. Therefore if `n->record_start()` is called
+  // on thread `t`, then `n->record_stop()` must be called before another call
+  // to `Node::record_start()` (for any node).
+  static thread_local int64 work_start_;  // Will be initialized to zero.
+
   mutable mutex mu_;
   const int64 id_;
   const string name_;
@@ -454,7 +462,6 @@ class Node {
   Metrics metrics_;
   absl::flat_hash_map<string, std::shared_ptr<Parameter>> parameters_
       TF_GUARDED_BY(mu_);
-  absl::flat_hash_map<std::thread::id, int64> work_start_ TF_GUARDED_BY(mu_);
 
   // Statistic of inputs processing time history.
   double input_processing_time_sum_ = 0.0L;
@@ -747,7 +747,14 @@ class CacheDatasetOp::MemoryDatasetBase : public DatasetBase {
                            std::vector<Tensor>* out_tensors,
                            bool* end_of_sequence) override {
       mutex_lock l(mu_);
-      return iterator_->GetNext(ctx, out_tensors, end_of_sequence);
+      // TODO(b/154341936): Explicitly stopping and starting this iterator
+      // should not be necessary, but the `kImpl` added to the prefix passed
+      // to `iterator_` when it was created prevents the model from identifying
+      // this iterator as the output of `iterator_`.
+      RecordStop(ctx);
+      Status s = iterator_->GetNext(ctx, out_tensors, end_of_sequence);
+      RecordStart(ctx);
+      return s;
     }
 
    protected:
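The `RecordStop(ctx)` / `RecordStart(ctx)` bracket above pauses this iterator's per-thread accounting while the delegated `GetNext` runs, so that time is attributed to the inner iterator rather than the wrapper. A minimal standalone sketch of the same bracket pattern, using a hypothetical `Timer` class and a plain clock helper instead of the TensorFlow `IteratorContext` APIs:

```cpp
#include <chrono>
#include <cstdint>
#include <iostream>

// Illustrative per-thread timer; `work_start_` mirrors the thread-local slot
// introduced by this commit.
class Timer {
 public:
  void Start(int64_t now) { work_start_ = now; }
  void Stop(int64_t now) {
    if (work_start_ != 0) {
      elapsed_ += now - work_start_;
      work_start_ = 0;
    }
  }
  int64_t elapsed() const { return elapsed_; }

 private:
  static thread_local int64_t work_start_;
  int64_t elapsed_ = 0;
};

thread_local int64_t Timer::work_start_ = 0;

int64_t NowNanos() {
  return std::chrono::duration_cast<std::chrono::nanoseconds>(
             std::chrono::steady_clock::now().time_since_epoch())
      .count();
}

int main() {
  Timer outer, inner;
  outer.Start(NowNanos());
  // ... outer work ...
  // Delegate to the inner stage: bracket the call with Stop/Start so the
  // delegated time is attributed to `inner`, not `outer`, mirroring the
  // RecordStop(ctx) / GetNext / RecordStart(ctx) sequence in the hunk above.
  outer.Stop(NowNanos());
  inner.Start(NowNanos());
  // ... inner work (the delegated GetNext in the real code) ...
  inner.Stop(NowNanos());
  outer.Start(NowNanos());
  // ... more outer work ...
  outer.Stop(NowNanos());
  std::cout << "outer ns: " << outer.elapsed()
            << " inner ns: " << inner.elapsed() << "\n";
  return 0;
}
```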