[tf.data] Improvements to the performance modeling framework.
This CL switches from using the iterator prefix to identify the parent node in the model tree when a node is constructed to passing a parent pointer directly to the constructor. In addition, this CL makes the `IteratorBase::InitializeBase` method public, which makes it possible to fix the `cache` and `snapshot` implementations to reflect their use of nested iterators in the model tree.

PiperOrigin-RevId: 314570434
Change-Id: Ide0b37f404077938ad8dc4fbbd91489b7197c6e1
commit 33b3e6cd06
parent e2d7d94549
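Concretely, the second part of this change means that an iterator which wraps a nested iterator is now expected to register that nested iterator in the model tree itself: it calls the now-public `InitializeBase` with itself as the parent and then runs the nested iterator's own initialization. Below is a minimal sketch of that pattern, not the actual implementation; the enclosing class and the `MakeNestedIterator()` helper are hypothetical placeholders, and the real call sites are the `cache` and `snapshot` hunks in the diff that follows.

```cpp
// Sketch only: intended wiring for a nested iterator after this change.
// `MakeNestedIterator()` and the surrounding class are hypothetical; see the
// `cache`/`snapshot` hunks below for the real call sites.
Status InitializeIterator(IteratorContext* ctx) {
  // Create the nested iterator (e.g. a reader or a writer implementation).
  iterator_ = MakeNestedIterator();
  // Register the nested iterator as a child of `this` in the model tree.
  // Before this change, `InitializeBase` was private and the parent node was
  // looked up from the iterator prefix instead.
  TF_RETURN_IF_ERROR(iterator_->InitializeBase(ctx, this));
  // Run the nested iterator's own initialization.
  return iterator_->Initialize(ctx);
}
```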
@@ -336,8 +336,7 @@ bool GraphDefBuilderWrapper::HasAttr(const string& name,
 }
 
 Status IteratorBase::InitializeBase(IteratorContext* ctx,
-                                    const IteratorBase* parent,
-                                    const string& output_prefix) {
+                                    const IteratorBase* parent) {
   parent_ = parent;
   id_ =
       Hash64CombineUnordered(Hash64(prefix()), reinterpret_cast<uint64>(this));
@@ -349,9 +348,8 @@ Status IteratorBase::InitializeBase(IteratorContext* ctx,
     auto factory = [ctx, this](model::Node::Args args) {
       return CreateNode(ctx, std::move(args));
     };
-    model->AddNode(std::move(factory), prefix(), output_prefix, &node_);
-    cleanup_fns_.push_back(
-        [model, prefix = prefix()]() { model->RemoveNode(prefix); });
+    model->AddNode(std::move(factory), prefix(), parent->model_node(), &node_);
+    cleanup_fns_.push_back([this, model]() { model->RemoveNode(node_); });
   }
   return Status::OK();
 }
@@ -418,7 +416,7 @@ Status DatasetBase::MakeIterator(
     const string& output_prefix,
     std::unique_ptr<IteratorBase>* iterator) const {
   *iterator = MakeIteratorInternal(output_prefix);
-  Status s = (*iterator)->InitializeBase(ctx, parent, output_prefix);
+  Status s = (*iterator)->InitializeBase(ctx, parent);
   if (s.ok()) {
     s.Update((*iterator)->Initialize(ctx));
   }
@@ -610,6 +610,9 @@ class IteratorBase {
   // properly propagate errors.
   virtual Status Initialize(IteratorContext* ctx) { return Status::OK(); }
 
+  // Performs initialization of the base iterator.
+  Status InitializeBase(IteratorContext* ctx, const IteratorBase* parent);
+
   // Saves the state of this iterator.
   virtual Status Save(SerializationContext* ctx, IteratorStateWriter* writer) {
     return SaveInternal(ctx, writer);
@@ -673,10 +676,6 @@ class IteratorBase {
   friend class DatasetBase;
   friend class DatasetBaseIterator;  // for access to `node_`
 
-  // Performs initialization of the base iterator.
-  Status InitializeBase(IteratorContext* ctx, const IteratorBase* parent,
-                        const string& output_prefix);
-
   std::vector<std::function<void()>> cleanup_fns_;
   std::shared_ptr<model::Node> node_ = nullptr;
   const IteratorBase* parent_ = nullptr;  // Not owned.
@@ -1089,22 +1089,18 @@ Node::NodeVector Node::CollectNodes(TraversalOrder order) const
   NodeVector node_vector;
   std::list<std::shared_ptr<Node>> temp_list;
 
-  {
-    for (auto& input : inputs_) {
-      node_vector.push_back(input);
-      temp_list.push_back(input);
-    }
-  }
+  for (auto& input : inputs_) {
+    node_vector.push_back(input);
+    temp_list.push_back(input);
+  }
 
   while (!temp_list.empty()) {
     auto cur_node = temp_list.front();
     temp_list.pop_front();
-    {
-      tf_shared_lock l(cur_node->mu_);
-      for (auto& input : cur_node->inputs_) {
-        node_vector.push_back(input);
-        temp_list.push_back(input);
-      }
-    }
+    tf_shared_lock l(cur_node->mu_);
+    for (auto& input : cur_node->inputs_) {
+      node_vector.push_back(input);
+      temp_list.push_back(input);
+    }
   }
 
@@ -1222,46 +1218,41 @@ void Node::TotalMaximumBufferedBytesHelper(
 }
 
 void Model::AddNode(Node::Factory factory, const string& name,
-                    const string& output_name,
+                    std::shared_ptr<Node> parent,
                     std::shared_ptr<Node>* out_node) {
-  // The name captures the sequence of iterators joined by `::`. We use the full
-  // sequence as the key in the lookup table, but only the last element of the
-  // sequence as the name node.
-  std::vector<string> tokens =
-      str_util::Split(name, ':', str_util::SkipEmpty());
-  // The output name might contain an index. We need to strip it to make it
-  // possible for the model to successfully identify the output node.
-  string sanitized_output_name = output_name;
-  if (str_util::EndsWith(output_name, "]")) {
-    sanitized_output_name = output_name.substr(0, output_name.rfind('['));
-  }
-  std::shared_ptr<Node> output;
+  // The name captures the sequence of iterators joined by `::`. We only use the
+  // last element of the sequence as the name node.
+  auto node_name = str_util::Split(name, ':', str_util::SkipEmpty()).back();
   mutex_lock l(mu_);
-  auto it = lookup_table_.find(sanitized_output_name);
-  if (it != lookup_table_.end()) {
-    output = it->second;
-  }
-  std::shared_ptr<Node> node = factory({id_counter_++, tokens.back(), output});
+  std::shared_ptr<Node> node = factory({id_counter_++, node_name, parent});
   if (!output_) {
     output_ = node;
   }
-  if (output) {
+  if (parent) {
     VLOG(3) << "Adding " << node->long_name() << " as input for "
-            << output->long_name();
-    output->add_input(node);
+            << parent->long_name();
+    parent->add_input(node);
   } else {
     VLOG(3) << "Adding " << node->long_name();
   }
   collect_resource_usage_ =
       collect_resource_usage_ || node->has_tunable_parameters();
-  lookup_table_.insert(std::make_pair(name, node));
-  *out_node = node;
+  *out_node = std::move(node);
 }
 
 void Model::FlushMetrics() {
-  tf_shared_lock l(mu_);
-  for (const auto& pair : lookup_table_) {
-    pair.second->FlushMetrics();
+  std::deque<std::shared_ptr<Node>> queue;
+  {
+    tf_shared_lock l(mu_);
+    if (output_) queue.push_back(output_);
   }
+  while (!queue.empty()) {
+    auto node = queue.front();
+    queue.pop_front();
+    node->FlushMetrics();
+    for (auto input : node->inputs()) {
+      queue.push_back(input);
+    }
+  }
 }
 
@@ -1277,16 +1268,14 @@ void Model::Optimize(AutotuneAlgorithm algorithm, int64 cpu_budget,
   }
 }
 
-void Model::RemoveNode(const string& name) {
+void Model::RemoveNode(std::shared_ptr<Node> node) {
   mutex_lock l(mu_);
-  auto node = gtl::FindOrNull(lookup_table_, name);
   if (node) {
-    if ((*node)->output()) {
-      (*node)->output()->remove_input(*node);
+    if (node->output()) {
+      node->output()->remove_input(node);
     }
-    VLOG(3) << "Removing " << (*node)->long_name();
+    VLOG(3) << "Removing " << node->long_name();
   }
-  lookup_table_.erase(name);
 }
 
 absl::flat_hash_map<string, std::shared_ptr<Parameter>>
@@ -274,6 +274,7 @@ class Node {
 
   // Records that a node thread has stopped executing.
   void record_stop(int64 time_nanos) TF_LOCKS_EXCLUDED(mu_) {
+    // TODO(jsimsa): Use DCHECK_NE(work_start_, 0) here.
     if (work_start_ != 0) {
       processing_time_ += time_nanos - work_start_;
       work_start_ = 0;
@@ -598,10 +599,9 @@ class Model {
   // Indicates whether to collect resource usage.
   bool collect_resource_usage() const { return collect_resource_usage_; }
 
-  // Adds a node with the given name and given output. The method returns
-  // a pointer to the node but does not transfer ownership.
+  // Adds a node with the given name and given parent.
   void AddNode(Node::Factory factory, const string& name,
-               const string& output_name, std::shared_ptr<Node>* out_node)
+               std::shared_ptr<Node> parent, std::shared_ptr<Node>* out_node)
       TF_LOCKS_EXCLUDED(mu_);
 
   // Flushes metrics record by the model.
@@ -612,7 +612,7 @@ class Model {
       TF_LOCKS_EXCLUDED(mu_);
 
   // Removes the given node.
-  void RemoveNode(const string& name) TF_LOCKS_EXCLUDED(mu_);
+  void RemoveNode(std::shared_ptr<Node> node) TF_LOCKS_EXCLUDED(mu_);
 
  private:
   // Collects tunable parameters in the tree rooted in the given node, returning
@@ -670,8 +670,6 @@ class Model {
   mutex mu_;
   int64 id_counter_ TF_GUARDED_BY(mu_) = 1;
   std::shared_ptr<Node> output_ TF_GUARDED_BY(mu_);
-  absl::flat_hash_map<string, std::shared_ptr<Node>> lookup_table_
-      TF_GUARDED_BY(mu_);
 
   // Indicates whether the modeling framework should collect resource usage
   // (e.g. CPU, memory). The logic for collecting this information assumes that
@@ -130,12 +130,11 @@ class CacheDatasetOp::FileDatasetBase : public DatasetBase {
       } else {
         mode_ = Mode::write;
       }
-      InitializeIterator();
     }
 
     Status Initialize(IteratorContext* ctx) override {
       mutex_lock l(mu_);
-      return iterator_->Initialize(ctx);
+      return InitializeIterator(ctx);
     }
 
     Status GetNextInternal(IteratorContext* ctx,
@@ -180,8 +179,7 @@ class CacheDatasetOp::FileDatasetBase : public DatasetBase {
             << "mistake, please remove the above file and try running again.";
         mode_ = Mode::read;
       }
-      InitializeIterator();
-      TF_RETURN_IF_ERROR(iterator_->Initialize(ctx));
+      TF_RETURN_IF_ERROR(InitializeIterator(ctx));
       return RestoreInput(ctx, reader, iterator_);
     }
 
@@ -560,15 +558,16 @@ class CacheDatasetOp::FileDatasetBase : public DatasetBase {
       bool iterator_restored_ TF_GUARDED_BY(mu_);
     };  // FileReaderIterator
 
-    void InitializeIterator() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
-      // We intentionally use the same prefix for both `FileReaderIterator`
-      // and `FileWriterIterator`. Since at any time there will be at most
-      // one of them alive, there should be no conflicts. This allows both
-      // iterators to use a common key for `cur_index`. We leverage this
-      // in the corner case when this iterator is restored from an old
-      // checkpoint in `write` mode and the cache has been completely
-      // flushed to disk since then. In that case we simply build a
-      // `FileReaderIterator` and seek to the `cur_index`.
+    Status InitializeIterator(IteratorContext* ctx)
+        TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+      // We intentionally use the same prefix for both `FileReaderIterator` and
+      // `FileWriterIterator`. Since at any time there will be at most one of
+      // them alive, there should be no conflicts. This allows both iterators to
+      // use a common key for `cur_index`. We leverage this in the corner case
+      // when this iterator is restored from an old checkpoint in `write` mode
+      // and the cache has been completely flushed to disk since then. In that
+      // case we simply build a `FileReaderIterator` and seek to the
+      // `cur_index`.
       switch (mode_) {
         case Mode::read:
           iterator_ =
@@ -580,6 +579,8 @@ class CacheDatasetOp::FileDatasetBase : public DatasetBase {
               absl::make_unique<FileWriterIterator>(FileWriterIterator::Params{
                   dataset(), strings::StrCat(prefix(), kImpl)});
       }
+      TF_RETURN_IF_ERROR(iterator_->InitializeBase(ctx, this));
+      return iterator_->Initialize(ctx);
     }
 
     mutex mu_;
@@ -741,22 +742,14 @@ class CacheDatasetOp::MemoryDatasetBase : public DatasetBase {
 
     Status Initialize(IteratorContext* ctx) override {
       mutex_lock l(mu_);
-      InitializeIterator();
-      return iterator_->Initialize(ctx);
+      return InitializeIterator(ctx);
     }
 
     Status GetNextInternal(IteratorContext* ctx,
                            std::vector<Tensor>* out_tensors,
                            bool* end_of_sequence) override {
       mutex_lock l(mu_);
-      // TODO(b/154341936): Explicitly stopping and starting this iterator
-      // should not be necessary, but the `kImpl` added to the prefix passed
-      // to `iterator_` when it was created prevents the model from identifying
-      // this iterator as the output of `iterator_`.
-      RecordStop(ctx);
-      Status s = iterator_->GetNext(ctx, out_tensors, end_of_sequence);
-      RecordStart(ctx);
-      return s;
+      return iterator_->GetNext(ctx, out_tensors, end_of_sequence);
     }
 
    protected:
@@ -789,8 +782,7 @@ class CacheDatasetOp::MemoryDatasetBase : public DatasetBase {
                          [this](const string& s) { return full_name(s); }));
         cache_->Complete(std::move(temp_cache));
       }
-      InitializeIterator();
-      TF_RETURN_IF_ERROR(iterator_->Initialize(ctx));
+      TF_RETURN_IF_ERROR(InitializeIterator(ctx));
       return RestoreInput(ctx, reader, iterator_);
     }
 
@@ -946,7 +938,8 @@ class CacheDatasetOp::MemoryDatasetBase : public DatasetBase {
       size_t index_ TF_GUARDED_BY(mu_);
     };  // MemoryReaderIterator
 
-    void InitializeIterator() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+    Status InitializeIterator(IteratorContext* ctx)
+        TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
       if (cache_->IsCompleted()) {
         iterator_ = absl::make_unique<MemoryReaderIterator>(
             MemoryReaderIterator::Params{dataset(),
@@ -958,6 +951,8 @@ class CacheDatasetOp::MemoryDatasetBase : public DatasetBase {
                                          strings::StrCat(prefix(), kImpl)},
             cache_);
       }
+      TF_RETURN_IF_ERROR(iterator_->InitializeBase(ctx, this));
+      return iterator_->Initialize(ctx);
     }
 
     mutex mu_;
@@ -536,16 +536,8 @@ Status SnapshotDatasetV2Op::Dataset::Iterator::GetNextInternal(
   if (iterator_ == nullptr) {
     TF_RETURN_IF_ERROR(InitializeIterator(ctx, nullptr));
   }
-  // TODO(b/154341936): Explicitly stopping and starting this iterator
-  // should not be necessary, but the additional
-  // `{Reader,Writer,Passthrough}::kIteratorName` added to the prefix passed to
-  // `iterator_` when it was created prevents the model from identifying this
-  // iterator as the output of `iterator_`.
-  RecordStop(ctx);
-  Status s = iterator_->GetNext(ctx, out_tensors, end_of_sequence);
   index_++;
-  RecordStart(ctx);
-  return s;
+  return iterator_->GetNext(ctx, out_tensors, end_of_sequence);
 }
 
 Status SnapshotDatasetV2Op::Dataset::Iterator::InitializeIterator(
@@ -611,6 +603,7 @@ Status SnapshotDatasetV2Op::Dataset::Iterator::InitializeIterator(
           dataset(), absl::StrCat(prefix(), Passthrough::kIteratorName)});
       break;
   }
+  TF_RETURN_IF_ERROR(iterator_->InitializeBase(ctx, this));
   return iterator_->Initialize(ctx);
 }
 
@@ -1352,6 +1345,7 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
                     dataset(), absl::StrCat(prefix(), "PassthroughImpl")});
             break;
         }
+        TF_RETURN_IF_ERROR(iterator_->InitializeBase(ctx, this));
         return iterator_->Initialize(ctx);
       }
 