[tf.data] Calculate the average input time for the root node of the data input pipeline.

PiperOrigin-RevId: 315740720
Change-Id: I55967b18f1f039847049656ccfd849714e83a4cf
Jay Shi 2020-06-10 12:13:27 -07:00 committed by TensorFlower Gardener
parent 26ee75e596
commit 57c09eb4d8
5 changed files with 90 additions and 29 deletions
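In short: every call into an iterator now time-stamps its model node via `record_input`, `SelfInputTime()` reports the average gap between successive calls, and `Model::OutputTime` seeds the root node's input time with that value instead of assuming an infinitely fast consumer (input time 0). A minimal sketch of the new accounting, not part of the commit, using the `model::Node` API shown in the diff below (the timestamps are made-up nanosecond values; the real call sites use `EnvTime::NowNanos()`):

std::shared_ptr<Node> root = model::MakeSourceNode({0, "root", nullptr});
root->record_input(100);  // first call only seeds last_input_time_
root->record_element();
root->record_input(250);  // input_time_ += 250 - 100
root->record_element();
root->record_input(400);  // input_time_ += 400 - 250
root->record_element();
// Average time between requests: (150 + 150) / (3 - 1) = 150ns.
double average_input_time = root->SelfInputTime();  // == 150.0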

@ -480,6 +480,7 @@ Status DatasetBaseIterator::GetNext(IteratorContext* ctx,
profiler::TraceMe activity([&] { return BuildTraceMeName(); },
profiler::TraceMeLevel::kInfo);
DVLOG(3) << prefix() << " GetNext enter";
RecordInput(ctx);
RecordStart(ctx, /*stop_output=*/true);
Status s = GetNextInternal(ctx, out_tensors, end_of_sequence);
if (s.ok() && !*end_of_sequence) RecordElement(ctx, out_tensors);

@ -971,6 +971,15 @@ class DatasetBaseIterator : public IteratorBase {
}
}
// When modeling is enabled, this method records the fact that this iterator
// has been called.
void RecordInput(IteratorContext* ctx) {
if (collect_resource_usage(ctx)) {
int64 now_nanos = EnvTime::NowNanos();
node_->record_input(now_nanos);
}
}
// When modeling is enabled, this method records the fact that a thread of
// this iterator has started work.
void RecordStart(IteratorContext* ctx, bool stop_output = false) {

@ -59,12 +59,12 @@ class InterleaveMany : public Node {
(*input_times)[long_name()] = old_input_time;
return;
}
// Here `old_input_time + SelfProcessingTimeLocked()` is the average input
// Here `old_input_time + SelfProcessingTime()` is the average input
// time for the interleave node to call one of the `(num_inputs() - 1)`
// input nodes (except the first one) to return an element. Regardless of the
// `block_length` parameter of the interleave node, the average input time for
// any of the `(num_inputs() - 1)` input nodes to be called is computed as:
double new_input_time = (old_input_time + SelfProcessingTimeLocked()) *
double new_input_time = (old_input_time + SelfProcessingTime()) *
static_cast<double>(num_inputs() - 1);
(*input_times)[long_name()] = new_input_time;
}
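A worked illustration of the formula above, with made-up values that are not from the commit:

double old_input_time = 200.0;       // hypothetical upstream input time (ns)
double self_processing_time = 50.0;  // hypothetical SelfProcessingTime() (ns)
int64 num_inputs = 3;                // one source input plus two interleaved inputs
// Each of the (num_inputs - 1) interleaved inputs is asked for an element only
// once per (num_inputs - 1) requests to the interleave node, so:
double new_input_time =
    (old_input_time + self_processing_time) * static_cast<double>(num_inputs - 1);
// new_input_time == 500ns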
@ -77,7 +77,7 @@ class InterleaveMany : public Node {
absl::flat_hash_map<string, double>* output_times,
absl::flat_hash_map<string, double>* output_time_gradients) const override
TF_SHARED_LOCKS_REQUIRED(mu_) {
double self_processing_time = SelfProcessingTimeLocked();
double self_processing_time = SelfProcessingTime();
if (num_inputs() <= 1) {
(*output_times)[long_name()] = self_processing_time;
if (gradients) {
@ -123,7 +123,7 @@ class InterleaveMany : public Node {
absl::flat_hash_map<string, double>* processing_times,
absl::flat_hash_map<string, double>* total_processing_times) override
TF_SHARED_LOCKS_REQUIRED(mu_) {
double self_processing_time = SelfProcessingTimeLocked();
double self_processing_time = SelfProcessingTime();
if (processing_times) {
(*processing_times)[long_name()] = self_processing_time;
}
@ -179,8 +179,7 @@ class AsyncInterleaveMany : public Node {
input_time = gtl::FindWithDefault(*input_times, kInputTimeKey, 0.0L);
}
} else {
input_time =
SelfProcessingTimeLocked() * static_cast<double>(num_inputs() - 1);
input_time = SelfProcessingTime() * static_cast<double>(num_inputs() - 1);
}
(*input_times)[long_name()] = input_time;
}
@ -196,7 +195,7 @@ class AsyncInterleaveMany : public Node {
absl::flat_hash_map<string, double>* output_times,
absl::flat_hash_map<string, double>* output_time_gradients) const override
TF_SHARED_LOCKS_REQUIRED(mu_) {
double self_processing_time = SelfProcessingTimeLocked();
double self_processing_time = SelfProcessingTime();
if (num_inputs() <= 1) {
(*output_times)[long_name()] = self_processing_time;
if (gradients) {
@ -277,7 +276,7 @@ class AsyncInterleaveMany : public Node {
absl::flat_hash_map<string, double>* processing_times,
absl::flat_hash_map<string, double>* total_processing_times) override
TF_SHARED_LOCKS_REQUIRED(mu_) {
double self_processing_time = SelfProcessingTimeLocked();
double self_processing_time = SelfProcessingTime();
if (processing_times) {
(*processing_times)[long_name()] = self_processing_time;
}
@ -320,8 +319,7 @@ class KnownRatio : public Node {
(*input_times)[long_name()] = old_input_time;
return;
}
double new_input_time =
(old_input_time + SelfProcessingTimeLocked()) / ratio_;
double new_input_time = (old_input_time + SelfProcessingTime()) / ratio_;
(*input_times)[long_name()] = new_input_time;
}
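With made-up numbers, e.g. a batch node where `ratio_ == 32` (32 input elements per output element), the input is called 32 times per request, so its per-element input time shrinks accordingly:

double old_input_time = 640.0;        // hypothetical (ns)
double self_processing_time = 320.0;  // hypothetical (ns)
double ratio = 32.0;                  // ratio_ of the KnownRatio node
double new_input_time = (old_input_time + self_processing_time) / ratio;  // == 30ns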
@ -333,7 +331,7 @@ class KnownRatio : public Node {
absl::flat_hash_map<string, double>* output_times,
absl::flat_hash_map<string, double>* output_time_gradients) const override
TF_SHARED_LOCKS_REQUIRED(mu_) {
double self_processing_time = SelfProcessingTimeLocked();
double self_processing_time = SelfProcessingTime();
if (ratio_ == 0) {
(*output_times)[long_name()] = self_processing_time;
if (gradients) {
@ -364,7 +362,7 @@ class KnownRatio : public Node {
absl::flat_hash_map<string, double>* processing_times,
absl::flat_hash_map<string, double>* total_processing_times) override
TF_SHARED_LOCKS_REQUIRED(mu_) {
double self_processing_time = SelfProcessingTimeLocked();
double self_processing_time = SelfProcessingTime();
if (processing_times) {
(*processing_times)[long_name()] = self_processing_time;
}
@ -420,7 +418,7 @@ class AsyncKnownRatio : public Node {
if (parallelism_parameter) {
parallelism = (*parallelism_parameter)->value;
}
input_time = SelfProcessingTimeLocked() / ratio_ / parallelism;
input_time = SelfProcessingTime() / ratio_ / parallelism;
(*input_times)[long_name()] = input_time;
}
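For illustration with hypothetical values (e.g. a parallel map modeled as an AsyncKnownRatio node with `ratio_ == 1`): running `parallelism` copies in parallel means the input is asked for elements `parallelism` times as often:

double self_processing_time = 400.0;  // hypothetical (ns)
double ratio = 1.0;                   // one input element per output element
double parallelism = 4.0;             // value of the parallelism parameter
double input_time = self_processing_time / ratio / parallelism;  // == 100ns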
@ -447,7 +445,7 @@ class AsyncKnownRatio : public Node {
} else if (buffer_size_parameter) {
buffer_size = (*buffer_size_parameter)->value;
}
double self_processing_time = SelfProcessingTimeLocked();
double self_processing_time = SelfProcessingTime();
double result;
double input_time;
if (output_) {
@ -535,7 +533,7 @@ class AsyncKnownRatio : public Node {
absl::flat_hash_map<string, double>* processing_times,
absl::flat_hash_map<string, double>* total_processing_times) override
TF_SHARED_LOCKS_REQUIRED(mu_) {
double self_processing_time = SelfProcessingTimeLocked();
double self_processing_time = SelfProcessingTime();
if (processing_times) {
(*processing_times)[long_name()] = self_processing_time;
}
@ -578,8 +576,7 @@ class UnknownRatio : public Node {
std::shared_ptr<Node> input = inputs_.front();
double ratio = static_cast<double>(input->num_elements()) /
static_cast<double>(num_elements_);
double new_input_time =
(old_input_time + SelfProcessingTimeLocked()) / ratio;
double new_input_time = (old_input_time + SelfProcessingTime()) / ratio;
(*input_times)[long_name()] = new_input_time;
}
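Unlike `KnownRatio`, the ratio here is estimated from observed element counts. With hypothetical counts (e.g. a flat_map style node):

double input_elements = 640.0;        // hypothetical input->num_elements()
double output_elements = 20.0;        // hypothetical num_elements_
double ratio = input_elements / output_elements;  // == 32
double old_input_time = 200.0;        // hypothetical (ns)
double self_processing_time = 120.0;  // hypothetical (ns)
double new_input_time = (old_input_time + self_processing_time) / ratio;  // == 10ns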
@ -591,7 +588,7 @@ class UnknownRatio : public Node {
absl::flat_hash_map<string, double>* output_times,
absl::flat_hash_map<string, double>* output_time_gradients) const override
TF_SHARED_LOCKS_REQUIRED(mu_) {
double self_processing_time = SelfProcessingTimeLocked();
double self_processing_time = SelfProcessingTime();
if (num_elements_ == 0 || inputs_.empty() ||
inputs_.front()->num_elements() == 0) {
(*output_times)[long_name()] = self_processing_time;
@ -627,7 +624,7 @@ class UnknownRatio : public Node {
absl::flat_hash_map<string, double>* processing_times,
absl::flat_hash_map<string, double>* total_processing_times) override
TF_SHARED_LOCKS_REQUIRED(mu_) {
double self_processing_time = SelfProcessingTimeLocked();
double self_processing_time = SelfProcessingTime();
if (processing_times) {
(*processing_times)[long_name()] = self_processing_time;
}
@ -957,11 +954,6 @@ std::shared_ptr<Node> Node::Snapshot(std::shared_ptr<Node> output) const {
return result;
}
double Node::SelfProcessingTime() const {
tf_shared_lock l(mu_);
return SelfProcessingTimeLocked();
}
double Node::TotalBufferedBytes() const {
absl::flat_hash_map<string, double> total_bytes;
tf_shared_lock l(mu_);
@ -1080,7 +1072,15 @@ double Node::TotalProcessingTimeForInputs(
return sum;
}
double Node::SelfProcessingTimeLocked() const {
double Node::SelfInputTime() const {
if (num_elements_ <= 1) {
return 0;
}
return static_cast<double>(input_time_) /
static_cast<double>(num_elements_ - 1);
}
double Node::SelfProcessingTime() const {
if (num_elements_ == 0) {
return 0;
}
@ -1167,6 +1167,8 @@ std::shared_ptr<Node> Node::SnapshotHelper(
result_node->num_elements_.store(num_elements_);
result_node->record_metrics_.store(false);
result_node->processing_time_.store(processing_time_);
result_node->input_time_.store(input_time_);
result_node->last_input_time_.store(last_input_time_);
mutex_lock l2(result_node->mu_);
result_node->parameters_ = parameters_;
}
@ -1460,6 +1462,7 @@ double Model::OutputTime(std::shared_ptr<Node> node,
absl::flat_hash_map<string, double>* gradients) {
// To store the input time for each node.
absl::flat_hash_map<string, double> input_times;
input_times[kInputTimeKey] = node->SelfInputTime();
// TODO(jsimsa): Now that we are accounting for buffer size in wait time
// computation, assuming that the input is infinitely fast will result in

@ -146,6 +146,8 @@ class Node {
bytes_produced_(0),
num_elements_(0),
processing_time_(0),
input_time_(0),
last_input_time_(0),
record_metrics_(true),
metrics_(name_),
output_(args.output.get()) {}
@ -266,6 +268,15 @@ class Node {
num_elements_++;
}
// Records that an element has been requested.
void record_input(int64 time_nanos) TF_LOCKS_EXCLUDED(mu_) {
if (last_input_time_ != 0) {
DCHECK_LE(last_input_time_, time_nanos);
input_time_ += time_nanos - last_input_time_;
}
last_input_time_ = time_nanos;
}
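// Illustrative numbers (not part of the original change): calls at
// time_nanos = 100, 250, and 400 leave last_input_time_ == 400 and accumulate
// input_time_ == (250 - 100) + (400 - 250) == 300; SelfInputTime() then
// reports 300 / (3 - 1) == 150ns between calls.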
// Records that a node thread has started executing.
void record_start(int64 time_nanos) TF_LOCKS_EXCLUDED(mu_) {
DCHECK_EQ(work_start_, 0);
@ -340,8 +351,11 @@ class Node {
std::shared_ptr<Node> Snapshot(std::shared_ptr<Node> output) const
TF_LOCKS_EXCLUDED(mu_);
// Returns the per-element input time, i.e. the average time between
// successive calls to this node.
double SelfInputTime() const;
// Returns the per-element processing time spent in this node.
double SelfProcessingTime() const TF_LOCKS_EXCLUDED(mu_);
double SelfProcessingTime() const;
// Returns the total number of bytes buffered in all nodes in the subtree for
// which autotuning is enabled.
@ -463,9 +477,6 @@ class Node {
const absl::flat_hash_map<string, double>& total_processing_times)
TF_SHARED_LOCKS_REQUIRED(mu_);
// Returns the per-element processing time spent in this node.
double SelfProcessingTimeLocked() const TF_SHARED_LOCKS_REQUIRED(mu_);
// Computes the per-element CPU time spent in the subtree rooted in this node
// and stores it in `total_processing_times`. If `processing_times` is not
// `nullptr`, collects the per-element CPU time spent in each node of the
@ -530,6 +541,9 @@ class Node {
std::atomic<int64> bytes_produced_;
std::atomic<int64> num_elements_;
std::atomic<int64> processing_time_;
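// Accumulates the time elapsed between successive calls to this node.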
std::atomic<int64> input_time_;
// Time at which this node was most recently called; used to update
// `input_time_`.
std::atomic<int64> last_input_time_;
std::atomic<bool> record_metrics_;
Metrics metrics_;
absl::flat_hash_map<string, std::shared_ptr<Parameter>> parameters_

@ -843,6 +843,40 @@ INSTANTIATE_TEST_SUITE_P(
::testing::Values(0, 20, 40, 80, 100),
::testing::Values(0, 1, 2, 4, 10, 20, 40)));
class SelfProcessingTimeTest : public ::testing::TestWithParam<int64> {};
TEST_P(SelfProcessingTimeTest, Model) {
const int64 add_times = GetParam();
std::shared_ptr<Node> source = model::MakeSourceNode({0, "source", nullptr});
for (int i = 0; i < add_times; i++) {
source->add_processing_time(i);
source->record_element();
}
double self_processing_time =
(add_times == 0 ? 0.0 : (static_cast<double>(add_times) - 1.0) / 2.0);
EXPECT_EQ(source->SelfProcessingTime(), self_processing_time);
}
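// The expectation follows from the loop above: the recorded processing times
// are 0, 1, ..., add_times - 1, summing to add_times * (add_times - 1) / 2
// over add_times recorded elements, so the average is (add_times - 1) / 2.
// For example, add_times == 5: processing_time_ == 0+1+2+3+4 == 10 and
// num_elements_ == 5, so SelfProcessingTime() == 10 / 5 == 2.0 == (5 - 1) / 2.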
INSTANTIATE_TEST_SUITE_P(Test, SelfProcessingTimeTest,
::testing::Values(0, 1, 2, 5, 10, 20, 40));
class SelfInputTimeTest : public ::testing::TestWithParam<int64> {};
TEST_P(SelfInputTimeTest, Model) {
const int64 add_times = GetParam();
std::shared_ptr<Node> source = model::MakeSourceNode({0, "source", nullptr});
for (int i = 0; i < add_times; i++) {
source->record_input((1 + i) * i / 2 + 1);
source->record_element();
}
double self_input_time =
(add_times <= 1 ? 0.0 : static_cast<double>(add_times) / 2.0);
EXPECT_EQ(source->SelfInputTime(), self_input_time);
}
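// The expectation follows from the timestamps above: the i-th call to
// record_input happens at i * (i + 1) / 2 + 1, so consecutive calls are i
// nanoseconds apart. The add_times - 1 gaps are 1, 2, ..., add_times - 1,
// summing to add_times * (add_times - 1) / 2, and dividing by add_times - 1
// gives add_times / 2. For example, add_times == 5: timestamps 1, 2, 4, 7, 11
// give gaps 1, 2, 3, 4, so SelfInputTime() == 10 / (5 - 1) == 2.5 == 5 / 2.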
INSTANTIATE_TEST_SUITE_P(Test, SelfInputTimeTest,
::testing::Values(0, 1, 2, 5, 10, 20, 40));
} // namespace
} // namespace model
} // namespace data