[tf.data] Use thread-local start time instead of thread-id map in model::Node.

This change relies on the invariant that `Node::record_stop()` is always called before the next `Node::record_start()`.

PiperOrigin-RevId: 307689582
Change-Id: I52a48b8e52646eee3611ecea1789d91a90786978
parent ec3f58bc3c
commit 75940041fc
@@ -636,6 +636,8 @@ class Unknown : public Node {
 
 }  // namespace
 
+thread_local int64 Node::work_start_;
+
 std::shared_ptr<Parameter> MakeParameter(const string& name,
                                          std::shared_ptr<SharedState> state,
                                          double min, double max) {
@@ -236,18 +236,15 @@ class Node {
 
   // Records that a node thread has started executing.
   void record_start(int64 time_nanos) TF_LOCKS_EXCLUDED(mu_) {
-    mutex_lock l(mu_);
-    work_start_[std::this_thread::get_id()] = time_nanos;
+    DCHECK_EQ(work_start_, 0);
+    work_start_ = time_nanos;
   }
 
   // Records that a node thread has stopped executing.
   void record_stop(int64 time_nanos) TF_LOCKS_EXCLUDED(mu_) {
-    mutex_lock l(mu_);
-    std::thread::id tid = std::this_thread::get_id();
-    auto iter = work_start_.find(tid);
-    if (iter != work_start_.end()) {
-      processing_time_ += time_nanos - iter->second;
-      work_start_.erase(iter);
+    if (work_start_ != 0) {
+      processing_time_ += time_nanos - work_start_;
+      work_start_ = 0;
     } else {
       VLOG(1) << "Encountered a stop event without a matching start event.";
     }
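
As a usage note (not part of this commit), one way to keep the stop-before-next-start invariant is to pair the two calls with a small RAII guard. `WorkGuard` is a hypothetical helper built on the `SketchNode`/`NowNanos()` sketch above, not a TensorFlow API.

// Hypothetical RAII helper: entering a scope records a start, leaving it
// records the matching stop, so the thread-local slot is always released
// before the same thread can start work on another node.
class WorkGuard {
 public:
  explicit WorkGuard(SketchNode* node) : node_(node) {
    node_->record_start(NowNanos());
  }
  ~WorkGuard() { node_->record_stop(NowNanos()); }

  WorkGuard(const WorkGuard&) = delete;
  WorkGuard& operator=(const WorkGuard&) = delete;

 private:
  SketchNode* node_;
};

// Usage: the guard's scope is the region attributed to `node`.
// {
//   WorkGuard guard(&node);
//   // ...do work for node...
// }
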
@@ -436,6 +433,17 @@ class Node {
   void TotalMaximumBufferedBytesHelper(
       absl::flat_hash_map<string, double>* total_bytes) const;
 
+  // Stores the time passed to the last call to `Node::record_start()` on the
+  // current thread.
+  //
+  // NOTE: This thread-local variable is shared between all instances of `Node`
+  // on which the same thread calls `record_start()` or `record_stop()`. It
+  // relies on the invariant that at most one `Node` can be "active" on a
+  // particular thread at any time. Therefore if `n->record_start()` is called
+  // on thread `t`, then `n->record_stop()` must be called before another call
+  // to `Node::record_start()` (for any node).
+  static thread_local int64 work_start_;  // Will be initialized to zero.
+
   mutable mutex mu_;
   const int64 id_;
   const string name_;
@@ -454,7 +462,6 @@ class Node {
   Metrics metrics_;
   absl::flat_hash_map<string, std::shared_ptr<Parameter>> parameters_
       TF_GUARDED_BY(mu_);
-  absl::flat_hash_map<std::thread::id, int64> work_start_ TF_GUARDED_BY(mu_);
 
   // Statistic of inputs processing time history.
   double input_processing_time_sum_ = 0.0L;
@@ -747,7 +747,14 @@ class CacheDatasetOp::MemoryDatasetBase : public DatasetBase {
                            std::vector<Tensor>* out_tensors,
                            bool* end_of_sequence) override {
       mutex_lock l(mu_);
-      return iterator_->GetNext(ctx, out_tensors, end_of_sequence);
+      // TODO(b/154341936): Explicitly stopping and starting this iterator
+      // should not be necessary, but the `kImpl` added to the prefix passed
+      // to `iterator_` when it was created prevents the model from identifying
+      // this iterator as the output of `iterator_`.
+      RecordStop(ctx);
+      Status s = iterator_->GetNext(ctx, out_tensors, end_of_sequence);
+      RecordStart(ctx);
+      return s;
     }
 
    protected: