[tf.data] Improvements to the performance modeling framework.
This CL switches from using iterator prefix for identifying the parent node in the model tree when a node is constructed to directly passing a parent pointer to the constructor. In addition, this CL makes the `IteratorBase::InitializeBase` method public, which makes it possible to fix `cache` and `snapshot` implementations to reflect their use of nested iterators in the model tree. PiperOrigin-RevId: 314570434 Change-Id: Ide0b37f404077938ad8dc4fbbd91489b7197c6e1
This commit is contained in:
		
							parent
							
								
									e2d7d94549
								
							
						
					
					
						commit
						33b3e6cd06
					
				| @ -336,8 +336,7 @@ bool GraphDefBuilderWrapper::HasAttr(const string& name, | ||||
| } | ||||
| 
 | ||||
| Status IteratorBase::InitializeBase(IteratorContext* ctx, | ||||
|                                     const IteratorBase* parent, | ||||
|                                     const string& output_prefix) { | ||||
|                                     const IteratorBase* parent) { | ||||
|   parent_ = parent; | ||||
|   id_ = | ||||
|       Hash64CombineUnordered(Hash64(prefix()), reinterpret_cast<uint64>(this)); | ||||
| @ -349,9 +348,8 @@ Status IteratorBase::InitializeBase(IteratorContext* ctx, | ||||
|     auto factory = [ctx, this](model::Node::Args args) { | ||||
|       return CreateNode(ctx, std::move(args)); | ||||
|     }; | ||||
|     model->AddNode(std::move(factory), prefix(), output_prefix, &node_); | ||||
|     cleanup_fns_.push_back( | ||||
|         [model, prefix = prefix()]() { model->RemoveNode(prefix); }); | ||||
|     model->AddNode(std::move(factory), prefix(), parent->model_node(), &node_); | ||||
|     cleanup_fns_.push_back([this, model]() { model->RemoveNode(node_); }); | ||||
|   } | ||||
|   return Status::OK(); | ||||
| } | ||||
| @ -418,7 +416,7 @@ Status DatasetBase::MakeIterator( | ||||
|     const string& output_prefix, | ||||
|     std::unique_ptr<IteratorBase>* iterator) const { | ||||
|   *iterator = MakeIteratorInternal(output_prefix); | ||||
|   Status s = (*iterator)->InitializeBase(ctx, parent, output_prefix); | ||||
|   Status s = (*iterator)->InitializeBase(ctx, parent); | ||||
|   if (s.ok()) { | ||||
|     s.Update((*iterator)->Initialize(ctx)); | ||||
|   } | ||||
|  | ||||
| @ -610,6 +610,9 @@ class IteratorBase { | ||||
|   // properly propagate errors.
 | ||||
|   virtual Status Initialize(IteratorContext* ctx) { return Status::OK(); } | ||||
| 
 | ||||
|   // Performs initialization of the base iterator.
 | ||||
|   Status InitializeBase(IteratorContext* ctx, const IteratorBase* parent); | ||||
| 
 | ||||
|   // Saves the state of this iterator.
 | ||||
|   virtual Status Save(SerializationContext* ctx, IteratorStateWriter* writer) { | ||||
|     return SaveInternal(ctx, writer); | ||||
| @ -673,10 +676,6 @@ class IteratorBase { | ||||
|   friend class DatasetBase; | ||||
|   friend class DatasetBaseIterator;  // for access to `node_`
 | ||||
| 
 | ||||
|   // Performs initialization of the base iterator.
 | ||||
|   Status InitializeBase(IteratorContext* ctx, const IteratorBase* parent, | ||||
|                         const string& output_prefix); | ||||
| 
 | ||||
|   std::vector<std::function<void()>> cleanup_fns_; | ||||
|   std::shared_ptr<model::Node> node_ = nullptr; | ||||
|   const IteratorBase* parent_ = nullptr;  // Not owned.
 | ||||
|  | ||||
| @ -1089,22 +1089,18 @@ Node::NodeVector Node::CollectNodes(TraversalOrder order) const | ||||
|   NodeVector node_vector; | ||||
|   std::list<std::shared_ptr<Node>> temp_list; | ||||
| 
 | ||||
|   { | ||||
|     for (auto& input : inputs_) { | ||||
|       node_vector.push_back(input); | ||||
|       temp_list.push_back(input); | ||||
|     } | ||||
|   for (auto& input : inputs_) { | ||||
|     node_vector.push_back(input); | ||||
|     temp_list.push_back(input); | ||||
|   } | ||||
| 
 | ||||
|   while (!temp_list.empty()) { | ||||
|     auto cur_node = temp_list.front(); | ||||
|     temp_list.pop_front(); | ||||
|     { | ||||
|       tf_shared_lock l(cur_node->mu_); | ||||
|       for (auto& input : cur_node->inputs_) { | ||||
|         node_vector.push_back(input); | ||||
|         temp_list.push_back(input); | ||||
|       } | ||||
|     tf_shared_lock l(cur_node->mu_); | ||||
|     for (auto& input : cur_node->inputs_) { | ||||
|       node_vector.push_back(input); | ||||
|       temp_list.push_back(input); | ||||
|     } | ||||
|   } | ||||
| 
 | ||||
| @ -1222,46 +1218,41 @@ void Node::TotalMaximumBufferedBytesHelper( | ||||
| } | ||||
| 
 | ||||
| void Model::AddNode(Node::Factory factory, const string& name, | ||||
|                     const string& output_name, | ||||
|                     std::shared_ptr<Node> parent, | ||||
|                     std::shared_ptr<Node>* out_node) { | ||||
|   // The name captures the sequence of iterators joined by `::`. We use the full
 | ||||
|   // sequence as the key in the lookup table, but only the last element of the
 | ||||
|   // sequence as the name node.
 | ||||
|   std::vector<string> tokens = | ||||
|       str_util::Split(name, ':', str_util::SkipEmpty()); | ||||
|   // The output name might contain an index. We need to strip it to make it
 | ||||
|   // possible for the model to successfully identify the output node.
 | ||||
|   string sanitized_output_name = output_name; | ||||
|   if (str_util::EndsWith(output_name, "]")) { | ||||
|     sanitized_output_name = output_name.substr(0, output_name.rfind('[')); | ||||
|   } | ||||
|   std::shared_ptr<Node> output; | ||||
|   // The name captures the sequence of iterators joined by `::`. We only use the
 | ||||
|   // last element of the sequence as the name node.
 | ||||
|   auto node_name = str_util::Split(name, ':', str_util::SkipEmpty()).back(); | ||||
|   mutex_lock l(mu_); | ||||
|   auto it = lookup_table_.find(sanitized_output_name); | ||||
|   if (it != lookup_table_.end()) { | ||||
|     output = it->second; | ||||
|   } | ||||
|   std::shared_ptr<Node> node = factory({id_counter_++, tokens.back(), output}); | ||||
|   std::shared_ptr<Node> node = factory({id_counter_++, node_name, parent}); | ||||
|   if (!output_) { | ||||
|     output_ = node; | ||||
|   } | ||||
|   if (output) { | ||||
|   if (parent) { | ||||
|     VLOG(3) << "Adding " << node->long_name() << " as input for " | ||||
|             << output->long_name(); | ||||
|     output->add_input(node); | ||||
|             << parent->long_name(); | ||||
|     parent->add_input(node); | ||||
|   } else { | ||||
|     VLOG(3) << "Adding " << node->long_name(); | ||||
|   } | ||||
|   collect_resource_usage_ = | ||||
|       collect_resource_usage_ || node->has_tunable_parameters(); | ||||
|   lookup_table_.insert(std::make_pair(name, node)); | ||||
|   *out_node = node; | ||||
|   *out_node = std::move(node); | ||||
| } | ||||
| 
 | ||||
| void Model::FlushMetrics() { | ||||
|   tf_shared_lock l(mu_); | ||||
|   for (const auto& pair : lookup_table_) { | ||||
|     pair.second->FlushMetrics(); | ||||
|   std::deque<std::shared_ptr<Node>> queue; | ||||
|   { | ||||
|     tf_shared_lock l(mu_); | ||||
|     if (output_) queue.push_back(output_); | ||||
|   } | ||||
|   while (!queue.empty()) { | ||||
|     auto node = queue.front(); | ||||
|     queue.pop_front(); | ||||
|     node->FlushMetrics(); | ||||
|     for (auto input : node->inputs()) { | ||||
|       queue.push_back(input); | ||||
|     } | ||||
|   } | ||||
| } | ||||
| 
 | ||||
| @ -1277,16 +1268,14 @@ void Model::Optimize(AutotuneAlgorithm algorithm, int64 cpu_budget, | ||||
|   } | ||||
| } | ||||
| 
 | ||||
| void Model::RemoveNode(const string& name) { | ||||
| void Model::RemoveNode(std::shared_ptr<Node> node) { | ||||
|   mutex_lock l(mu_); | ||||
|   auto node = gtl::FindOrNull(lookup_table_, name); | ||||
|   if (node) { | ||||
|     if ((*node)->output()) { | ||||
|       (*node)->output()->remove_input(*node); | ||||
|     if (node->output()) { | ||||
|       node->output()->remove_input(node); | ||||
|     } | ||||
|     VLOG(3) << "Removing " << (*node)->long_name(); | ||||
|     VLOG(3) << "Removing " << node->long_name(); | ||||
|   } | ||||
|   lookup_table_.erase(name); | ||||
| } | ||||
| 
 | ||||
| absl::flat_hash_map<string, std::shared_ptr<Parameter>> | ||||
|  | ||||
| @ -274,6 +274,7 @@ class Node { | ||||
| 
 | ||||
|   // Records that a node thread has stopped executing.
 | ||||
|   void record_stop(int64 time_nanos) TF_LOCKS_EXCLUDED(mu_) { | ||||
|     // TODO(jsimsa): Use DCHECK_NE(work_start_, 0) here.
 | ||||
|     if (work_start_ != 0) { | ||||
|       processing_time_ += time_nanos - work_start_; | ||||
|       work_start_ = 0; | ||||
| @ -598,10 +599,9 @@ class Model { | ||||
|   // Indicates whether to collect resource usage.
 | ||||
|   bool collect_resource_usage() const { return collect_resource_usage_; } | ||||
| 
 | ||||
|   // Adds a node with the given name and given output. The method returns
 | ||||
|   // a pointer to the node but does not transfer ownership.
 | ||||
|   // Adds a node with the given name and given parent.
 | ||||
|   void AddNode(Node::Factory factory, const string& name, | ||||
|                const string& output_name, std::shared_ptr<Node>* out_node) | ||||
|                std::shared_ptr<Node> parent, std::shared_ptr<Node>* out_node) | ||||
|       TF_LOCKS_EXCLUDED(mu_); | ||||
| 
 | ||||
|   // Flushes metrics record by the model.
 | ||||
| @ -612,7 +612,7 @@ class Model { | ||||
|       TF_LOCKS_EXCLUDED(mu_); | ||||
| 
 | ||||
|   // Removes the given node.
 | ||||
|   void RemoveNode(const string& name) TF_LOCKS_EXCLUDED(mu_); | ||||
|   void RemoveNode(std::shared_ptr<Node> node) TF_LOCKS_EXCLUDED(mu_); | ||||
| 
 | ||||
|  private: | ||||
|   // Collects tunable parameters in the tree rooted in the given node, returning
 | ||||
| @ -670,8 +670,6 @@ class Model { | ||||
|   mutex mu_; | ||||
|   int64 id_counter_ TF_GUARDED_BY(mu_) = 1; | ||||
|   std::shared_ptr<Node> output_ TF_GUARDED_BY(mu_); | ||||
|   absl::flat_hash_map<string, std::shared_ptr<Node>> lookup_table_ | ||||
|       TF_GUARDED_BY(mu_); | ||||
| 
 | ||||
|   // Indicates whether the modeling framework should collect resource usage
 | ||||
|   // (e.g. CPU, memory). The logic for collecting this information assumes that
 | ||||
|  | ||||
| @ -130,12 +130,11 @@ class CacheDatasetOp::FileDatasetBase : public DatasetBase { | ||||
|       } else { | ||||
|         mode_ = Mode::write; | ||||
|       } | ||||
|       InitializeIterator(); | ||||
|     } | ||||
| 
 | ||||
|     Status Initialize(IteratorContext* ctx) override { | ||||
|       mutex_lock l(mu_); | ||||
|       return iterator_->Initialize(ctx); | ||||
|       return InitializeIterator(ctx); | ||||
|     } | ||||
| 
 | ||||
|     Status GetNextInternal(IteratorContext* ctx, | ||||
| @ -180,8 +179,7 @@ class CacheDatasetOp::FileDatasetBase : public DatasetBase { | ||||
|             << "mistake, please remove the above file and try running again."; | ||||
|         mode_ = Mode::read; | ||||
|       } | ||||
|       InitializeIterator(); | ||||
|       TF_RETURN_IF_ERROR(iterator_->Initialize(ctx)); | ||||
|       TF_RETURN_IF_ERROR(InitializeIterator(ctx)); | ||||
|       return RestoreInput(ctx, reader, iterator_); | ||||
|     } | ||||
| 
 | ||||
| @ -560,15 +558,16 @@ class CacheDatasetOp::FileDatasetBase : public DatasetBase { | ||||
|       bool iterator_restored_ TF_GUARDED_BY(mu_); | ||||
|     };  // FileReaderIterator
 | ||||
| 
 | ||||
|     void InitializeIterator() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { | ||||
|       // We intentionally use the same prefix for both `FileReaderIterator`
 | ||||
|       // and `FileWriterIterator`. Since at any time there will be at most
 | ||||
|       // one of them alive, there should be no conflicts. This allows both
 | ||||
|       // iterators to use a common key for `cur_index`. We leverage this
 | ||||
|       // in the corner case when this iterator is restored from an old
 | ||||
|       // checkpoint in `write` mode and the cache has been completely
 | ||||
|       // flushed to disk since then. In that case we simply build a
 | ||||
|       // `FileReaderIterator` and seek to the `cur_index`.
 | ||||
|     Status InitializeIterator(IteratorContext* ctx) | ||||
|         TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { | ||||
|       // We intentionally use the same prefix for both `FileReaderIterator` and
 | ||||
|       // `FileWriterIterator`. Since at any time there will be at most one of
 | ||||
|       // them alive, there should be no conflicts. This allows both iterators to
 | ||||
|       // use a common key for `cur_index`. We leverage this in the corner case
 | ||||
|       // when this iterator is restored from an old checkpoint in `write` mode
 | ||||
|       // and the cache has been completely flushed to disk since then. In that
 | ||||
|       // case we simply build a `FileReaderIterator` and seek to the
 | ||||
|       // `cur_index`.
 | ||||
|       switch (mode_) { | ||||
|         case Mode::read: | ||||
|           iterator_ = | ||||
| @ -580,6 +579,8 @@ class CacheDatasetOp::FileDatasetBase : public DatasetBase { | ||||
|               absl::make_unique<FileWriterIterator>(FileWriterIterator::Params{ | ||||
|                   dataset(), strings::StrCat(prefix(), kImpl)}); | ||||
|       } | ||||
|       TF_RETURN_IF_ERROR(iterator_->InitializeBase(ctx, this)); | ||||
|       return iterator_->Initialize(ctx); | ||||
|     } | ||||
| 
 | ||||
|     mutex mu_; | ||||
| @ -741,22 +742,14 @@ class CacheDatasetOp::MemoryDatasetBase : public DatasetBase { | ||||
| 
 | ||||
|     Status Initialize(IteratorContext* ctx) override { | ||||
|       mutex_lock l(mu_); | ||||
|       InitializeIterator(); | ||||
|       return iterator_->Initialize(ctx); | ||||
|       return InitializeIterator(ctx); | ||||
|     } | ||||
| 
 | ||||
|     Status GetNextInternal(IteratorContext* ctx, | ||||
|                            std::vector<Tensor>* out_tensors, | ||||
|                            bool* end_of_sequence) override { | ||||
|       mutex_lock l(mu_); | ||||
|       // TODO(b/154341936): Explicitly stopping and starting this iterator
 | ||||
|       // should not be necessary, but the `kImpl` added to the prefix passed
 | ||||
|       // to `iterator_` when it was created prevents the model from identifying
 | ||||
|       // this iterator as the output of `iterator_`.
 | ||||
|       RecordStop(ctx); | ||||
|       Status s = iterator_->GetNext(ctx, out_tensors, end_of_sequence); | ||||
|       RecordStart(ctx); | ||||
|       return s; | ||||
|       return iterator_->GetNext(ctx, out_tensors, end_of_sequence); | ||||
|     } | ||||
| 
 | ||||
|    protected: | ||||
| @ -789,8 +782,7 @@ class CacheDatasetOp::MemoryDatasetBase : public DatasetBase { | ||||
|                          [this](const string& s) { return full_name(s); })); | ||||
|         cache_->Complete(std::move(temp_cache)); | ||||
|       } | ||||
|       InitializeIterator(); | ||||
|       TF_RETURN_IF_ERROR(iterator_->Initialize(ctx)); | ||||
|       TF_RETURN_IF_ERROR(InitializeIterator(ctx)); | ||||
|       return RestoreInput(ctx, reader, iterator_); | ||||
|     } | ||||
| 
 | ||||
| @ -946,7 +938,8 @@ class CacheDatasetOp::MemoryDatasetBase : public DatasetBase { | ||||
|       size_t index_ TF_GUARDED_BY(mu_); | ||||
|     };  // MemoryReaderIterator
 | ||||
| 
 | ||||
|     void InitializeIterator() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { | ||||
|     Status InitializeIterator(IteratorContext* ctx) | ||||
|         TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { | ||||
|       if (cache_->IsCompleted()) { | ||||
|         iterator_ = absl::make_unique<MemoryReaderIterator>( | ||||
|             MemoryReaderIterator::Params{dataset(), | ||||
| @ -958,6 +951,8 @@ class CacheDatasetOp::MemoryDatasetBase : public DatasetBase { | ||||
|                                          strings::StrCat(prefix(), kImpl)}, | ||||
|             cache_); | ||||
|       } | ||||
|       TF_RETURN_IF_ERROR(iterator_->InitializeBase(ctx, this)); | ||||
|       return iterator_->Initialize(ctx); | ||||
|     } | ||||
| 
 | ||||
|     mutex mu_; | ||||
|  | ||||
| @ -536,16 +536,8 @@ Status SnapshotDatasetV2Op::Dataset::Iterator::GetNextInternal( | ||||
|   if (iterator_ == nullptr) { | ||||
|     TF_RETURN_IF_ERROR(InitializeIterator(ctx, nullptr)); | ||||
|   } | ||||
|   // TODO(b/154341936): Explicitly stopping and starting this iterator
 | ||||
|   // should not be necessary, but the additional
 | ||||
|   // `{Reader,Writer,Passthrough}::kIteratorName` added to the prefix passed to
 | ||||
|   // `iterator_` when it was created prevents the model from identifying this
 | ||||
|   // iterator as the output of `iterator_`.
 | ||||
|   RecordStop(ctx); | ||||
|   Status s = iterator_->GetNext(ctx, out_tensors, end_of_sequence); | ||||
|   index_++; | ||||
|   RecordStart(ctx); | ||||
|   return s; | ||||
|   return iterator_->GetNext(ctx, out_tensors, end_of_sequence); | ||||
| } | ||||
| 
 | ||||
| Status SnapshotDatasetV2Op::Dataset::Iterator::InitializeIterator( | ||||
| @ -611,6 +603,7 @@ Status SnapshotDatasetV2Op::Dataset::Iterator::InitializeIterator( | ||||
|           dataset(), absl::StrCat(prefix(), Passthrough::kIteratorName)}); | ||||
|       break; | ||||
|   } | ||||
|   TF_RETURN_IF_ERROR(iterator_->InitializeBase(ctx, this)); | ||||
|   return iterator_->Initialize(ctx); | ||||
| } | ||||
| 
 | ||||
| @ -1352,6 +1345,7 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel { | ||||
|                     dataset(), absl::StrCat(prefix(), "PassthroughImpl")}); | ||||
|             break; | ||||
|         } | ||||
|         TF_RETURN_IF_ERROR(iterator_->InitializeBase(ctx, this)); | ||||
|         return iterator_->Initialize(ctx); | ||||
|       } | ||||
| 
 | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user