[tf.data] Improvements to the performance modeling framework.

This CL switches from using the iterator prefix to identify the parent node in the model tree when a node is constructed to passing the parent pointer directly to the constructor. In addition, this CL makes the `IteratorBase::InitializeBase` method public, which makes it possible to fix the `cache` and `snapshot` implementations so that their use of nested iterators is reflected in the model tree.

PiperOrigin-RevId: 314570434
Change-Id: Ide0b37f404077938ad8dc4fbbd91489b7197c6e1
commit 33b3e6cd06
parent e2d7d94549
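To illustrate the shape of the change, here is a minimal sketch (toy C++, not TensorFlow code; the `Model` and `Node` types and all names below are hypothetical): the parent is handed to a node at construction time, so no prefix-keyed lookup table is needed to find it, and bookkeeping such as metric flushing can simply walk the tree from the root.

    // A toy sketch, not TensorFlow code: hypothetical Model/Node types that
    // illustrate passing the parent pointer at construction time instead of
    // reconstructing parent/child edges from string prefixes.
    #include <deque>
    #include <iostream>
    #include <memory>
    #include <string>
    #include <utility>
    #include <vector>

    class Node {
     public:
      Node(std::string name, Node* output)
          : name_(std::move(name)), output_(output) {}

      const std::string& name() const { return name_; }
      Node* output() const { return output_; }
      const std::vector<std::shared_ptr<Node>>& inputs() const {
        return inputs_;
      }
      void add_input(std::shared_ptr<Node> input) {
        inputs_.push_back(std::move(input));
      }

     private:
      const std::string name_;
      Node* const output_;  // The node consuming this node's output. Not owned.
      std::vector<std::shared_ptr<Node>> inputs_;
    };

    class Model {
     public:
      // The parent is passed directly instead of being looked up by name.
      std::shared_ptr<Node> AddNode(const std::string& name,
                                    const std::shared_ptr<Node>& parent) {
        auto node = std::make_shared<Node>(name, parent.get());
        if (!output_) output_ = node;  // The first node added is the root.
        if (parent) parent->add_input(node);
        return node;
      }

      // Breadth-first walk from the root, in the spirit of the new
      // Model::FlushMetrics below.
      void Print() const {
        std::deque<std::shared_ptr<Node>> queue;
        if (output_) queue.push_back(output_);
        while (!queue.empty()) {
          auto node = queue.front();
          queue.pop_front();
          std::cout << node->name() << "\n";
          for (const auto& input : node->inputs()) queue.push_back(input);
        }
      }

     private:
      std::shared_ptr<Node> output_;  // Root of the tree: outermost iterator.
    };

    int main() {
      Model model;
      auto prefetch = model.AddNode("Prefetch", nullptr);
      auto map = model.AddNode("Map", prefetch);
      model.AddNode("Range", map);
      model.Print();  // Prints: Prefetch, Map, Range.
      return 0;
    }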
--- a/tensorflow/core/framework/dataset.cc
+++ b/tensorflow/core/framework/dataset.cc
@@ -336,8 +336,7 @@ bool GraphDefBuilderWrapper::HasAttr(const string& name,
 }
 
 Status IteratorBase::InitializeBase(IteratorContext* ctx,
-                                    const IteratorBase* parent,
-                                    const string& output_prefix) {
+                                    const IteratorBase* parent) {
   parent_ = parent;
   id_ =
       Hash64CombineUnordered(Hash64(prefix()), reinterpret_cast<uint64>(this));
@@ -349,9 +348,8 @@ Status IteratorBase::InitializeBase(IteratorContext* ctx,
     auto factory = [ctx, this](model::Node::Args args) {
       return CreateNode(ctx, std::move(args));
     };
-    model->AddNode(std::move(factory), prefix(), output_prefix, &node_);
-    cleanup_fns_.push_back(
-        [model, prefix = prefix()]() { model->RemoveNode(prefix); });
+    model->AddNode(std::move(factory), prefix(), parent->model_node(), &node_);
+    cleanup_fns_.push_back([this, model]() { model->RemoveNode(node_); });
   }
   return Status::OK();
 }
@@ -418,7 +416,7 @@ Status DatasetBase::MakeIterator(
     const string& output_prefix,
     std::unique_ptr<IteratorBase>* iterator) const {
   *iterator = MakeIteratorInternal(output_prefix);
-  Status s = (*iterator)->InitializeBase(ctx, parent, output_prefix);
+  Status s = (*iterator)->InitializeBase(ctx, parent);
   if (s.ok()) {
     s.Update((*iterator)->Initialize(ctx));
   }
--- a/tensorflow/core/framework/dataset.h
+++ b/tensorflow/core/framework/dataset.h
@@ -610,6 +610,9 @@ class IteratorBase {
   // properly propagate errors.
   virtual Status Initialize(IteratorContext* ctx) { return Status::OK(); }
 
+  // Performs initialization of the base iterator.
+  Status InitializeBase(IteratorContext* ctx, const IteratorBase* parent);
+
   // Saves the state of this iterator.
   virtual Status Save(SerializationContext* ctx, IteratorStateWriter* writer) {
     return SaveInternal(ctx, writer);
@@ -673,10 +676,6 @@ class IteratorBase {
   friend class DatasetBase;
   friend class DatasetBaseIterator;  // for access to `node_`
 
-  // Performs initialization of the base iterator.
-  Status InitializeBase(IteratorContext* ctx, const IteratorBase* parent,
-                        const string& output_prefix);
-
   std::vector<std::function<void()>> cleanup_fns_;
   std::shared_ptr<model::Node> node_ = nullptr;
   const IteratorBase* parent_ = nullptr;  // Not owned.
--- a/tensorflow/core/framework/model.cc
+++ b/tensorflow/core/framework/model.cc
@@ -1089,22 +1089,18 @@ Node::NodeVector Node::CollectNodes(TraversalOrder order) const
   NodeVector node_vector;
   std::list<std::shared_ptr<Node>> temp_list;
 
-  {
-    for (auto& input : inputs_) {
-      node_vector.push_back(input);
-      temp_list.push_back(input);
-    }
-  }
+  for (auto& input : inputs_) {
+    node_vector.push_back(input);
+    temp_list.push_back(input);
+  }
 
   while (!temp_list.empty()) {
     auto cur_node = temp_list.front();
     temp_list.pop_front();
-    {
-      tf_shared_lock l(cur_node->mu_);
-      for (auto& input : cur_node->inputs_) {
-        node_vector.push_back(input);
-        temp_list.push_back(input);
-      }
-    }
+    tf_shared_lock l(cur_node->mu_);
+    for (auto& input : cur_node->inputs_) {
+      node_vector.push_back(input);
+      temp_list.push_back(input);
+    }
   }
 
@@ -1222,46 +1218,41 @@ void Node::TotalMaximumBufferedBytesHelper(
 }
 
 void Model::AddNode(Node::Factory factory, const string& name,
-                    const string& output_name,
+                    std::shared_ptr<Node> parent,
                     std::shared_ptr<Node>* out_node) {
-  // The name captures the sequence of iterators joined by `::`. We use the full
-  // sequence as the key in the lookup table, but only the last element of the
-  // sequence as the name node.
-  std::vector<string> tokens =
-      str_util::Split(name, ':', str_util::SkipEmpty());
-  // The output name might contain an index. We need to strip it to make it
-  // possible for the model to successfully identify the output node.
-  string sanitized_output_name = output_name;
-  if (str_util::EndsWith(output_name, "]")) {
-    sanitized_output_name = output_name.substr(0, output_name.rfind('['));
-  }
-  std::shared_ptr<Node> output;
+  // The name captures the sequence of iterators joined by `::`. We only use the
+  // last element of the sequence as the name node.
+  auto node_name = str_util::Split(name, ':', str_util::SkipEmpty()).back();
   mutex_lock l(mu_);
-  auto it = lookup_table_.find(sanitized_output_name);
-  if (it != lookup_table_.end()) {
-    output = it->second;
-  }
-  std::shared_ptr<Node> node = factory({id_counter_++, tokens.back(), output});
+  std::shared_ptr<Node> node = factory({id_counter_++, node_name, parent});
   if (!output_) {
     output_ = node;
   }
-  if (output) {
+  if (parent) {
     VLOG(3) << "Adding " << node->long_name() << " as input for "
-            << output->long_name();
-    output->add_input(node);
+            << parent->long_name();
+    parent->add_input(node);
   } else {
     VLOG(3) << "Adding " << node->long_name();
   }
   collect_resource_usage_ =
       collect_resource_usage_ || node->has_tunable_parameters();
-  lookup_table_.insert(std::make_pair(name, node));
-  *out_node = node;
+  *out_node = std::move(node);
 }
 
 void Model::FlushMetrics() {
-  tf_shared_lock l(mu_);
-  for (const auto& pair : lookup_table_) {
-    pair.second->FlushMetrics();
+  std::deque<std::shared_ptr<Node>> queue;
+  {
+    tf_shared_lock l(mu_);
+    if (output_) queue.push_back(output_);
+  }
+  while (!queue.empty()) {
+    auto node = queue.front();
+    queue.pop_front();
+    node->FlushMetrics();
+    for (auto input : node->inputs()) {
+      queue.push_back(input);
+    }
   }
 }
 
@@ -1277,16 +1268,14 @@ void Model::Optimize(AutotuneAlgorithm algorithm, int64 cpu_budget,
   }
 }
 
-void Model::RemoveNode(const string& name) {
+void Model::RemoveNode(std::shared_ptr<Node> node) {
   mutex_lock l(mu_);
-  auto node = gtl::FindOrNull(lookup_table_, name);
   if (node) {
-    if ((*node)->output()) {
-      (*node)->output()->remove_input(*node);
+    if (node->output()) {
+      node->output()->remove_input(node);
     }
-    VLOG(3) << "Removing " << (*node)->long_name();
+    VLOG(3) << "Removing " << node->long_name();
   }
-  lookup_table_.erase(name);
 }
 
 absl::flat_hash_map<string, std::shared_ptr<Parameter>>
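With this change the model no longer needs a prefix-keyed lookup table (its declaration is dropped from model.h below): nodes are added and removed by identity rather than by name. A minimal sketch of remove-by-identity, with a hypothetical `Node` type rather than the TensorFlow one:

    // A toy sketch, not TensorFlow code: removal by node identity instead of
    // by name key, mirroring the RemoveNode change above.
    #include <algorithm>
    #include <cassert>
    #include <memory>
    #include <vector>

    struct Node {
      std::weak_ptr<Node> output;                 // Parent in the tree, if any.
      std::vector<std::shared_ptr<Node>> inputs;  // Children.
    };

    // Detaches `node` from its parent. No string key or lookup table is
    // involved, so renaming or re-prefixing iterators cannot break cleanup.
    void RemoveNode(const std::shared_ptr<Node>& node) {
      if (auto parent = node->output.lock()) {
        auto& inputs = parent->inputs;
        inputs.erase(std::remove(inputs.begin(), inputs.end(), node),
                     inputs.end());
      }
    }

    int main() {
      auto parent = std::make_shared<Node>();
      auto child = std::make_shared<Node>();
      child->output = parent;
      parent->inputs.push_back(child);
      RemoveNode(child);
      assert(parent->inputs.empty());
      return 0;
    }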
--- a/tensorflow/core/framework/model.h
+++ b/tensorflow/core/framework/model.h
@@ -274,6 +274,7 @@ class Node {
 
   // Records that a node thread has stopped executing.
   void record_stop(int64 time_nanos) TF_LOCKS_EXCLUDED(mu_) {
+    // TODO(jsimsa): Use DCHECK_NE(work_start_, 0) here.
     if (work_start_ != 0) {
       processing_time_ += time_nanos - work_start_;
       work_start_ = 0;
@@ -598,10 +599,9 @@ class Model {
   // Indicates whether to collect resource usage.
   bool collect_resource_usage() const { return collect_resource_usage_; }
 
-  // Adds a node with the given name and given output. The method returns
-  // a pointer to the node but does not transfer ownership.
+  // Adds a node with the given name and given parent.
   void AddNode(Node::Factory factory, const string& name,
-               const string& output_name, std::shared_ptr<Node>* out_node)
+               std::shared_ptr<Node> parent, std::shared_ptr<Node>* out_node)
       TF_LOCKS_EXCLUDED(mu_);
 
   // Flushes metrics record by the model.
@@ -612,7 +612,7 @@ class Model {
       TF_LOCKS_EXCLUDED(mu_);
 
   // Removes the given node.
-  void RemoveNode(const string& name) TF_LOCKS_EXCLUDED(mu_);
+  void RemoveNode(std::shared_ptr<Node> node) TF_LOCKS_EXCLUDED(mu_);
 
  private:
   // Collects tunable parameters in the tree rooted in the given node, returning
@@ -670,8 +670,6 @@ class Model {
   mutex mu_;
   int64 id_counter_ TF_GUARDED_BY(mu_) = 1;
   std::shared_ptr<Node> output_ TF_GUARDED_BY(mu_);
-  absl::flat_hash_map<string, std::shared_ptr<Node>> lookup_table_
-      TF_GUARDED_BY(mu_);
 
   // Indicates whether the modeling framework should collect resource usage
   // (e.g. CPU, memory). The logic for collecting this information assumes that
--- a/tensorflow/core/kernels/data/cache_dataset_ops.cc
+++ b/tensorflow/core/kernels/data/cache_dataset_ops.cc
@@ -130,12 +130,11 @@ class CacheDatasetOp::FileDatasetBase : public DatasetBase {
      } else {
        mode_ = Mode::write;
      }
-      InitializeIterator();
    }
 
    Status Initialize(IteratorContext* ctx) override {
      mutex_lock l(mu_);
-      return iterator_->Initialize(ctx);
+      return InitializeIterator(ctx);
    }
 
    Status GetNextInternal(IteratorContext* ctx,
@@ -180,8 +179,7 @@ class CacheDatasetOp::FileDatasetBase : public DatasetBase {
            << "mistake, please remove the above file and try running again.";
        mode_ = Mode::read;
      }
-      InitializeIterator();
-      TF_RETURN_IF_ERROR(iterator_->Initialize(ctx));
+      TF_RETURN_IF_ERROR(InitializeIterator(ctx));
      return RestoreInput(ctx, reader, iterator_);
    }
 
@@ -560,15 +558,16 @@ class CacheDatasetOp::FileDatasetBase : public DatasetBase {
      bool iterator_restored_ TF_GUARDED_BY(mu_);
    };  // FileReaderIterator
 
-    void InitializeIterator() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
-      // We intentionally use the same prefix for both `FileReaderIterator`
-      // and `FileWriterIterator`. Since at any time there will be at most
-      // one of them alive, there should be no conflicts. This allows both
-      // iterators to use a common key for `cur_index`. We leverage this
-      // in the corner case when this iterator is restored from an old
-      // checkpoint in `write` mode and the cache has been completely
-      // flushed to disk since then. In that case we simply build a
-      // `FileReaderIterator` and seek to the `cur_index`.
+    Status InitializeIterator(IteratorContext* ctx)
+        TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+      // We intentionally use the same prefix for both `FileReaderIterator` and
+      // `FileWriterIterator`. Since at any time there will be at most one of
+      // them alive, there should be no conflicts. This allows both iterators to
+      // use a common key for `cur_index`. We leverage this in the corner case
+      // when this iterator is restored from an old checkpoint in `write` mode
+      // and the cache has been completely flushed to disk since then. In that
+      // case we simply build a `FileReaderIterator` and seek to the
+      // `cur_index`.
      switch (mode_) {
        case Mode::read:
          iterator_ =
@@ -580,6 +579,8 @@ class CacheDatasetOp::FileDatasetBase : public DatasetBase {
              absl::make_unique<FileWriterIterator>(FileWriterIterator::Params{
                  dataset(), strings::StrCat(prefix(), kImpl)});
      }
+      TF_RETURN_IF_ERROR(iterator_->InitializeBase(ctx, this));
+      return iterator_->Initialize(ctx);
    }
 
    mutex mu_;
@@ -741,22 +742,14 @@ class CacheDatasetOp::MemoryDatasetBase : public DatasetBase {
 
    Status Initialize(IteratorContext* ctx) override {
      mutex_lock l(mu_);
-      InitializeIterator();
-      return iterator_->Initialize(ctx);
+      return InitializeIterator(ctx);
    }
 
    Status GetNextInternal(IteratorContext* ctx,
                           std::vector<Tensor>* out_tensors,
                           bool* end_of_sequence) override {
      mutex_lock l(mu_);
-      // TODO(b/154341936): Explicitly stopping and starting this iterator
-      // should not be necessary, but the `kImpl` added to the prefix passed
-      // to `iterator_` when it was created prevents the model from identifying
-      // this iterator as the output of `iterator_`.
-      RecordStop(ctx);
-      Status s = iterator_->GetNext(ctx, out_tensors, end_of_sequence);
-      RecordStart(ctx);
-      return s;
+      return iterator_->GetNext(ctx, out_tensors, end_of_sequence);
    }
 
   protected:
@@ -789,8 +782,7 @@ class CacheDatasetOp::MemoryDatasetBase : public DatasetBase {
                            [this](const string& s) { return full_name(s); }));
        cache_->Complete(std::move(temp_cache));
      }
-      InitializeIterator();
-      TF_RETURN_IF_ERROR(iterator_->Initialize(ctx));
+      TF_RETURN_IF_ERROR(InitializeIterator(ctx));
      return RestoreInput(ctx, reader, iterator_);
    }
 
@@ -946,7 +938,8 @@ class CacheDatasetOp::MemoryDatasetBase : public DatasetBase {
      size_t index_ TF_GUARDED_BY(mu_);
    };  // MemoryReaderIterator
 
-    void InitializeIterator() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+    Status InitializeIterator(IteratorContext* ctx)
+        TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
      if (cache_->IsCompleted()) {
        iterator_ = absl::make_unique<MemoryReaderIterator>(
            MemoryReaderIterator::Params{dataset(),
@@ -958,6 +951,8 @@ class CacheDatasetOp::MemoryDatasetBase : public DatasetBase {
                                         strings::StrCat(prefix(), kImpl)},
            cache_);
      }
+      TF_RETURN_IF_ERROR(iterator_->InitializeBase(ctx, this));
+      return iterator_->Initialize(ctx);
    }
 
    mutex mu_;
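Both cache variants now follow the same two-step pattern: register the nested iterator in the model tree via the now-public `InitializeBase`, then run its iterator-specific `Initialize`. The snapshot kernels below do the same. A minimal sketch of the pattern (toy C++; the types and names here are hypothetical, not the TensorFlow API):

    // A toy sketch, not the TensorFlow API: the two-step nested-iterator
    // setup. InitializeBase is the public registration step that records the
    // parent edge; Initialize is the iterator-specific setup.
    #include <iostream>
    #include <memory>

    class Iterator {
     public:
      virtual ~Iterator() = default;

      // Registers this iterator under `parent`. Public so that a wrapping
      // iterator can register a nested iterator it creates internally.
      void InitializeBase(const Iterator* parent) { parent_ = parent; }

      // Iterator-specific setup, run after the base registration.
      virtual void Initialize() {}

      const Iterator* parent() const { return parent_; }

     private:
      const Iterator* parent_ = nullptr;  // Not owned.
    };

    class InnerIterator : public Iterator {
     public:
      void Initialize() override { std::cout << "inner iterator ready\n"; }
    };

    class WrappingIterator : public Iterator {
     public:
      void Initialize() override {
        inner_ = std::make_unique<InnerIterator>();
        // Register the nested iterator as a child of the wrapper before
        // running its own initialization, so the tree reflects the nesting.
        inner_->InitializeBase(/*parent=*/this);
        inner_->Initialize();
      }

     private:
      std::unique_ptr<InnerIterator> inner_;
    };

    int main() {
      WrappingIterator wrapper;
      wrapper.InitializeBase(/*parent=*/nullptr);
      wrapper.Initialize();
      return 0;
    }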
--- a/tensorflow/core/kernels/data/experimental/snapshot_dataset_op.cc
+++ b/tensorflow/core/kernels/data/experimental/snapshot_dataset_op.cc
@@ -536,16 +536,8 @@ Status SnapshotDatasetV2Op::Dataset::Iterator::GetNextInternal(
   if (iterator_ == nullptr) {
     TF_RETURN_IF_ERROR(InitializeIterator(ctx, nullptr));
   }
-  // TODO(b/154341936): Explicitly stopping and starting this iterator
-  // should not be necessary, but the additional
-  // `{Reader,Writer,Passthrough}::kIteratorName` added to the prefix passed to
-  // `iterator_` when it was created prevents the model from identifying this
-  // iterator as the output of `iterator_`.
-  RecordStop(ctx);
-  Status s = iterator_->GetNext(ctx, out_tensors, end_of_sequence);
   index_++;
-  RecordStart(ctx);
-  return s;
+  return iterator_->GetNext(ctx, out_tensors, end_of_sequence);
 }
 
 Status SnapshotDatasetV2Op::Dataset::Iterator::InitializeIterator(
@@ -611,6 +603,7 @@ Status SnapshotDatasetV2Op::Dataset::Iterator::InitializeIterator(
           dataset(), absl::StrCat(prefix(), Passthrough::kIteratorName)});
       break;
   }
+  TF_RETURN_IF_ERROR(iterator_->InitializeBase(ctx, this));
   return iterator_->Initialize(ctx);
 }
 
@@ -1352,6 +1345,7 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
            dataset(), absl::StrCat(prefix(), "PassthroughImpl")});
        break;
      }
+      TF_RETURN_IF_ERROR(iterator_->InitializeBase(ctx, this));
      return iterator_->Initialize(ctx);
    }
 