[tf.data] Improvements to the performance modeling framework.

This CL switches from using iterator prefix for identifying the parent node in the model tree when a node is constructed to directly passing a parent pointer to the constructor. In addition, this CL makes the `IteratorBase::InitializeBase` method public, which makes it possible to fix `cache` and `snapshot` implementations to reflect their use of nested iterators in the model tree.

PiperOrigin-RevId: 314570434
Change-Id: Ide0b37f404077938ad8dc4fbbd91489b7197c6e1
This commit is contained in:
Jiri Simsa 2020-06-03 11:31:26 -07:00 committed by TensorFlower Gardener
parent e2d7d94549
commit 33b3e6cd06
6 changed files with 67 additions and 94 deletions

View File

@ -336,8 +336,7 @@ bool GraphDefBuilderWrapper::HasAttr(const string& name,
}
Status IteratorBase::InitializeBase(IteratorContext* ctx,
const IteratorBase* parent,
const string& output_prefix) {
const IteratorBase* parent) {
parent_ = parent;
id_ =
Hash64CombineUnordered(Hash64(prefix()), reinterpret_cast<uint64>(this));
@ -349,9 +348,8 @@ Status IteratorBase::InitializeBase(IteratorContext* ctx,
auto factory = [ctx, this](model::Node::Args args) {
return CreateNode(ctx, std::move(args));
};
model->AddNode(std::move(factory), prefix(), output_prefix, &node_);
cleanup_fns_.push_back(
[model, prefix = prefix()]() { model->RemoveNode(prefix); });
model->AddNode(std::move(factory), prefix(), parent->model_node(), &node_);
cleanup_fns_.push_back([this, model]() { model->RemoveNode(node_); });
}
return Status::OK();
}
@ -418,7 +416,7 @@ Status DatasetBase::MakeIterator(
const string& output_prefix,
std::unique_ptr<IteratorBase>* iterator) const {
*iterator = MakeIteratorInternal(output_prefix);
Status s = (*iterator)->InitializeBase(ctx, parent, output_prefix);
Status s = (*iterator)->InitializeBase(ctx, parent);
if (s.ok()) {
s.Update((*iterator)->Initialize(ctx));
}

View File

@ -610,6 +610,9 @@ class IteratorBase {
// properly propagate errors.
virtual Status Initialize(IteratorContext* ctx) { return Status::OK(); }
// Performs initialization of the base iterator.
Status InitializeBase(IteratorContext* ctx, const IteratorBase* parent);
// Saves the state of this iterator.
virtual Status Save(SerializationContext* ctx, IteratorStateWriter* writer) {
return SaveInternal(ctx, writer);
@ -673,10 +676,6 @@ class IteratorBase {
friend class DatasetBase;
friend class DatasetBaseIterator; // for access to `node_`
// Performs initialization of the base iterator.
Status InitializeBase(IteratorContext* ctx, const IteratorBase* parent,
const string& output_prefix);
std::vector<std::function<void()>> cleanup_fns_;
std::shared_ptr<model::Node> node_ = nullptr;
const IteratorBase* parent_ = nullptr; // Not owned.

View File

@ -1089,22 +1089,18 @@ Node::NodeVector Node::CollectNodes(TraversalOrder order) const
NodeVector node_vector;
std::list<std::shared_ptr<Node>> temp_list;
{
for (auto& input : inputs_) {
node_vector.push_back(input);
temp_list.push_back(input);
}
for (auto& input : inputs_) {
node_vector.push_back(input);
temp_list.push_back(input);
}
while (!temp_list.empty()) {
auto cur_node = temp_list.front();
temp_list.pop_front();
{
tf_shared_lock l(cur_node->mu_);
for (auto& input : cur_node->inputs_) {
node_vector.push_back(input);
temp_list.push_back(input);
}
tf_shared_lock l(cur_node->mu_);
for (auto& input : cur_node->inputs_) {
node_vector.push_back(input);
temp_list.push_back(input);
}
}
@ -1222,46 +1218,41 @@ void Node::TotalMaximumBufferedBytesHelper(
}
void Model::AddNode(Node::Factory factory, const string& name,
const string& output_name,
std::shared_ptr<Node> parent,
std::shared_ptr<Node>* out_node) {
// The name captures the sequence of iterators joined by `::`. We use the full
// sequence as the key in the lookup table, but only the last element of the
// sequence as the node name.
std::vector<string> tokens =
str_util::Split(name, ':', str_util::SkipEmpty());
// The output name might contain an index. We need to strip it to make it
// possible for the model to successfully identify the output node.
string sanitized_output_name = output_name;
if (str_util::EndsWith(output_name, "]")) {
sanitized_output_name = output_name.substr(0, output_name.rfind('['));
}
std::shared_ptr<Node> output;
// The name captures the sequence of iterators joined by `::`. We only use the
// last element of the sequence as the node name.
auto node_name = str_util::Split(name, ':', str_util::SkipEmpty()).back();
mutex_lock l(mu_);
auto it = lookup_table_.find(sanitized_output_name);
if (it != lookup_table_.end()) {
output = it->second;
}
std::shared_ptr<Node> node = factory({id_counter_++, tokens.back(), output});
std::shared_ptr<Node> node = factory({id_counter_++, node_name, parent});
if (!output_) {
output_ = node;
}
if (output) {
if (parent) {
VLOG(3) << "Adding " << node->long_name() << " as input for "
<< output->long_name();
output->add_input(node);
<< parent->long_name();
parent->add_input(node);
} else {
VLOG(3) << "Adding " << node->long_name();
}
collect_resource_usage_ =
collect_resource_usage_ || node->has_tunable_parameters();
lookup_table_.insert(std::make_pair(name, node));
*out_node = node;
*out_node = std::move(node);
}
void Model::FlushMetrics() {
tf_shared_lock l(mu_);
for (const auto& pair : lookup_table_) {
pair.second->FlushMetrics();
std::deque<std::shared_ptr<Node>> queue;
{
tf_shared_lock l(mu_);
if (output_) queue.push_back(output_);
}
while (!queue.empty()) {
auto node = queue.front();
queue.pop_front();
node->FlushMetrics();
for (auto input : node->inputs()) {
queue.push_back(input);
}
}
}
@ -1277,16 +1268,14 @@ void Model::Optimize(AutotuneAlgorithm algorithm, int64 cpu_budget,
}
}
void Model::RemoveNode(const string& name) {
void Model::RemoveNode(std::shared_ptr<Node> node) {
mutex_lock l(mu_);
auto node = gtl::FindOrNull(lookup_table_, name);
if (node) {
if ((*node)->output()) {
(*node)->output()->remove_input(*node);
if (node->output()) {
node->output()->remove_input(node);
}
VLOG(3) << "Removing " << (*node)->long_name();
VLOG(3) << "Removing " << node->long_name();
}
lookup_table_.erase(name);
}
absl::flat_hash_map<string, std::shared_ptr<Parameter>>

View File

@ -274,6 +274,7 @@ class Node {
// Records that a node thread has stopped executing.
void record_stop(int64 time_nanos) TF_LOCKS_EXCLUDED(mu_) {
// TODO(jsimsa): Use DCHECK_NE(work_start_, 0) here.
if (work_start_ != 0) {
processing_time_ += time_nanos - work_start_;
work_start_ = 0;
@ -598,10 +599,9 @@ class Model {
// Indicates whether to collect resource usage.
bool collect_resource_usage() const { return collect_resource_usage_; }
// Adds a node with the given name and given output. The method returns
// a pointer to the node but does not transfer ownership.
// Adds a node with the given name and given parent.
void AddNode(Node::Factory factory, const string& name,
const string& output_name, std::shared_ptr<Node>* out_node)
std::shared_ptr<Node> parent, std::shared_ptr<Node>* out_node)
TF_LOCKS_EXCLUDED(mu_);
// Flushes metrics recorded by the model.
@ -612,7 +612,7 @@ class Model {
TF_LOCKS_EXCLUDED(mu_);
// Removes the given node.
void RemoveNode(const string& name) TF_LOCKS_EXCLUDED(mu_);
void RemoveNode(std::shared_ptr<Node> node) TF_LOCKS_EXCLUDED(mu_);
private:
// Collects tunable parameters in the tree rooted in the given node, returning
@ -670,8 +670,6 @@ class Model {
mutex mu_;
int64 id_counter_ TF_GUARDED_BY(mu_) = 1;
std::shared_ptr<Node> output_ TF_GUARDED_BY(mu_);
absl::flat_hash_map<string, std::shared_ptr<Node>> lookup_table_
TF_GUARDED_BY(mu_);
// Indicates whether the modeling framework should collect resource usage
// (e.g. CPU, memory). The logic for collecting this information assumes that

View File

@ -130,12 +130,11 @@ class CacheDatasetOp::FileDatasetBase : public DatasetBase {
} else {
mode_ = Mode::write;
}
InitializeIterator();
}
Status Initialize(IteratorContext* ctx) override {
mutex_lock l(mu_);
return iterator_->Initialize(ctx);
return InitializeIterator(ctx);
}
Status GetNextInternal(IteratorContext* ctx,
@ -180,8 +179,7 @@ class CacheDatasetOp::FileDatasetBase : public DatasetBase {
<< "mistake, please remove the above file and try running again.";
mode_ = Mode::read;
}
InitializeIterator();
TF_RETURN_IF_ERROR(iterator_->Initialize(ctx));
TF_RETURN_IF_ERROR(InitializeIterator(ctx));
return RestoreInput(ctx, reader, iterator_);
}
@ -560,15 +558,16 @@ class CacheDatasetOp::FileDatasetBase : public DatasetBase {
bool iterator_restored_ TF_GUARDED_BY(mu_);
}; // FileReaderIterator
void InitializeIterator() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
// We intentionally use the same prefix for both `FileReaderIterator`
// and `FileWriterIterator`. Since at any time there will be at most
// one of them alive, there should be no conflicts. This allows both
// iterators to use a common key for `cur_index`. We leverage this
// in the corner case when this iterator is restored from an old
// checkpoint in `write` mode and the cache has been completely
// flushed to disk since then. In that case we simply build a
// `FileReaderIterator` and seek to the `cur_index`.
Status InitializeIterator(IteratorContext* ctx)
TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
// We intentionally use the same prefix for both `FileReaderIterator` and
// `FileWriterIterator`. Since at any time there will be at most one of
// them alive, there should be no conflicts. This allows both iterators to
// use a common key for `cur_index`. We leverage this in the corner case
// when this iterator is restored from an old checkpoint in `write` mode
// and the cache has been completely flushed to disk since then. In that
// case we simply build a `FileReaderIterator` and seek to the
// `cur_index`.
switch (mode_) {
case Mode::read:
iterator_ =
@ -580,6 +579,8 @@ class CacheDatasetOp::FileDatasetBase : public DatasetBase {
absl::make_unique<FileWriterIterator>(FileWriterIterator::Params{
dataset(), strings::StrCat(prefix(), kImpl)});
}
TF_RETURN_IF_ERROR(iterator_->InitializeBase(ctx, this));
return iterator_->Initialize(ctx);
}
mutex mu_;
@ -741,22 +742,14 @@ class CacheDatasetOp::MemoryDatasetBase : public DatasetBase {
Status Initialize(IteratorContext* ctx) override {
mutex_lock l(mu_);
InitializeIterator();
return iterator_->Initialize(ctx);
return InitializeIterator(ctx);
}
Status GetNextInternal(IteratorContext* ctx,
std::vector<Tensor>* out_tensors,
bool* end_of_sequence) override {
mutex_lock l(mu_);
// TODO(b/154341936): Explicitly stopping and starting this iterator
// should not be necessary, but the `kImpl` added to the prefix passed
// to `iterator_` when it was created prevents the model from identifying
// this iterator as the output of `iterator_`.
RecordStop(ctx);
Status s = iterator_->GetNext(ctx, out_tensors, end_of_sequence);
RecordStart(ctx);
return s;
return iterator_->GetNext(ctx, out_tensors, end_of_sequence);
}
protected:
@ -789,8 +782,7 @@ class CacheDatasetOp::MemoryDatasetBase : public DatasetBase {
[this](const string& s) { return full_name(s); }));
cache_->Complete(std::move(temp_cache));
}
InitializeIterator();
TF_RETURN_IF_ERROR(iterator_->Initialize(ctx));
TF_RETURN_IF_ERROR(InitializeIterator(ctx));
return RestoreInput(ctx, reader, iterator_);
}
@ -946,7 +938,8 @@ class CacheDatasetOp::MemoryDatasetBase : public DatasetBase {
size_t index_ TF_GUARDED_BY(mu_);
}; // MemoryReaderIterator
void InitializeIterator() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
Status InitializeIterator(IteratorContext* ctx)
TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
if (cache_->IsCompleted()) {
iterator_ = absl::make_unique<MemoryReaderIterator>(
MemoryReaderIterator::Params{dataset(),
@ -958,6 +951,8 @@ class CacheDatasetOp::MemoryDatasetBase : public DatasetBase {
strings::StrCat(prefix(), kImpl)},
cache_);
}
TF_RETURN_IF_ERROR(iterator_->InitializeBase(ctx, this));
return iterator_->Initialize(ctx);
}
mutex mu_;

View File

@ -536,16 +536,8 @@ Status SnapshotDatasetV2Op::Dataset::Iterator::GetNextInternal(
if (iterator_ == nullptr) {
TF_RETURN_IF_ERROR(InitializeIterator(ctx, nullptr));
}
// TODO(b/154341936): Explicitly stopping and starting this iterator
// should not be necessary, but the additional
// `{Reader,Writer,Passthrough}::kIteratorName` added to the prefix passed to
// `iterator_` when it was created prevents the model from identifying this
// iterator as the output of `iterator_`.
RecordStop(ctx);
Status s = iterator_->GetNext(ctx, out_tensors, end_of_sequence);
index_++;
RecordStart(ctx);
return s;
return iterator_->GetNext(ctx, out_tensors, end_of_sequence);
}
Status SnapshotDatasetV2Op::Dataset::Iterator::InitializeIterator(
@ -611,6 +603,7 @@ Status SnapshotDatasetV2Op::Dataset::Iterator::InitializeIterator(
dataset(), absl::StrCat(prefix(), Passthrough::kIteratorName)});
break;
}
TF_RETURN_IF_ERROR(iterator_->InitializeBase(ctx, this));
return iterator_->Initialize(ctx);
}
@ -1352,6 +1345,7 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
dataset(), absl::StrCat(prefix(), "PassthroughImpl")});
break;
}
TF_RETURN_IF_ERROR(iterator_->InitializeBase(ctx, this));
return iterator_->Initialize(ctx);
}