[tf.data] Calculate the average input time for the root node of the data input pipeline.

PiperOrigin-RevId: 315904116
Change-Id: Ica73e017ab796a8ff97b92552c8b2aa698edebb3
This commit is contained in:
Jiri Simsa 2020-06-11 08:32:47 -07:00 committed by TensorFlower Gardener
parent c674577870
commit 40704b8933
5 changed files with 29 additions and 90 deletions

View File

@ -480,7 +480,6 @@ Status DatasetBaseIterator::GetNext(IteratorContext* ctx,
profiler::TraceMe activity([&] { return BuildTraceMeName(); },
profiler::TraceMeLevel::kInfo);
DVLOG(3) << prefix() << " GetNext enter";
RecordInput(ctx);
RecordStart(ctx, /*stop_output=*/true);
Status s = GetNextInternal(ctx, out_tensors, end_of_sequence);
if (s.ok() && !*end_of_sequence) RecordElement(ctx, out_tensors);

View File

@ -971,15 +971,6 @@ class DatasetBaseIterator : public IteratorBase {
}
}
// When modeling is enabled, this method records the fact that this iterator
// is called.
void RecordInput(IteratorContext* ctx) {
if (collect_resource_usage(ctx)) {
int64 now_nanos = EnvTime::NowNanos();
node_->record_input(now_nanos);
}
}
// When modeling is enabled, this method records the fact that a thread of
// this iterator has started work.
void RecordStart(IteratorContext* ctx, bool stop_output = false) {

View File

@ -59,12 +59,12 @@ class InterleaveMany : public Node {
(*input_times)[long_name()] = old_input_time;
return;
}
// Here `old_input_time + SelfProcessingTime()` is the average input
// Here `old_input_time + SelfProcessingTimeLocked()` is the average input
// time for the interleave node to call one of the `(num_inputs() - 1)`
// input nodes(except the first one) to return an element. Regardless of the
// `block_length` parameter of interleave node, the average input time for
// any of the `(num_inputs() - 1)` input nodes to be called is computed as:
double new_input_time = (old_input_time + SelfProcessingTime()) *
double new_input_time = (old_input_time + SelfProcessingTimeLocked()) *
static_cast<double>(num_inputs() - 1);
(*input_times)[long_name()] = new_input_time;
}
@ -77,7 +77,7 @@ class InterleaveMany : public Node {
absl::flat_hash_map<string, double>* output_times,
absl::flat_hash_map<string, double>* output_time_gradients) const override
TF_SHARED_LOCKS_REQUIRED(mu_) {
double self_processing_time = SelfProcessingTime();
double self_processing_time = SelfProcessingTimeLocked();
if (num_inputs() <= 1) {
(*output_times)[long_name()] = self_processing_time;
if (gradients) {
@ -123,7 +123,7 @@ class InterleaveMany : public Node {
absl::flat_hash_map<string, double>* processing_times,
absl::flat_hash_map<string, double>* total_processing_times) override
TF_SHARED_LOCKS_REQUIRED(mu_) {
double self_processing_time = SelfProcessingTime();
double self_processing_time = SelfProcessingTimeLocked();
if (processing_times) {
(*processing_times)[long_name()] = self_processing_time;
}
@ -179,7 +179,8 @@ class AsyncInterleaveMany : public Node {
input_time = gtl::FindWithDefault(*input_times, kInputTimeKey, 0.0L);
}
} else {
input_time = SelfProcessingTime() * static_cast<double>(num_inputs() - 1);
input_time =
SelfProcessingTimeLocked() * static_cast<double>(num_inputs() - 1);
}
(*input_times)[long_name()] = input_time;
}
@ -195,7 +196,7 @@ class AsyncInterleaveMany : public Node {
absl::flat_hash_map<string, double>* output_times,
absl::flat_hash_map<string, double>* output_time_gradients) const override
TF_SHARED_LOCKS_REQUIRED(mu_) {
double self_processing_time = SelfProcessingTime();
double self_processing_time = SelfProcessingTimeLocked();
if (num_inputs() <= 1) {
(*output_times)[long_name()] = self_processing_time;
if (gradients) {
@ -276,7 +277,7 @@ class AsyncInterleaveMany : public Node {
absl::flat_hash_map<string, double>* processing_times,
absl::flat_hash_map<string, double>* total_processing_times) override
TF_SHARED_LOCKS_REQUIRED(mu_) {
double self_processing_time = SelfProcessingTime();
double self_processing_time = SelfProcessingTimeLocked();
if (processing_times) {
(*processing_times)[long_name()] = self_processing_time;
}
@ -319,7 +320,8 @@ class KnownRatio : public Node {
(*input_times)[long_name()] = old_input_time;
return;
}
double new_input_time = (old_input_time + SelfProcessingTime()) / ratio_;
double new_input_time =
(old_input_time + SelfProcessingTimeLocked()) / ratio_;
(*input_times)[long_name()] = new_input_time;
}
@ -331,7 +333,7 @@ class KnownRatio : public Node {
absl::flat_hash_map<string, double>* output_times,
absl::flat_hash_map<string, double>* output_time_gradients) const override
TF_SHARED_LOCKS_REQUIRED(mu_) {
double self_processing_time = SelfProcessingTime();
double self_processing_time = SelfProcessingTimeLocked();
if (ratio_ == 0) {
(*output_times)[long_name()] = self_processing_time;
if (gradients) {
@ -362,7 +364,7 @@ class KnownRatio : public Node {
absl::flat_hash_map<string, double>* processing_times,
absl::flat_hash_map<string, double>* total_processing_times) override
TF_SHARED_LOCKS_REQUIRED(mu_) {
double self_processing_time = SelfProcessingTime();
double self_processing_time = SelfProcessingTimeLocked();
if (processing_times) {
(*processing_times)[long_name()] = self_processing_time;
}
@ -418,7 +420,7 @@ class AsyncKnownRatio : public Node {
if (parallelism_parameter) {
parallelism = (*parallelism_parameter)->value;
}
input_time = SelfProcessingTime() / ratio_ / parallelism;
input_time = SelfProcessingTimeLocked() / ratio_ / parallelism;
(*input_times)[long_name()] = input_time;
}
@ -445,7 +447,7 @@ class AsyncKnownRatio : public Node {
} else if (buffer_size_parameter) {
buffer_size = (*buffer_size_parameter)->value;
}
double self_processing_time = SelfProcessingTime();
double self_processing_time = SelfProcessingTimeLocked();
double result;
double input_time;
if (output_) {
@ -533,7 +535,7 @@ class AsyncKnownRatio : public Node {
absl::flat_hash_map<string, double>* processing_times,
absl::flat_hash_map<string, double>* total_processing_times) override
TF_SHARED_LOCKS_REQUIRED(mu_) {
double self_processing_time = SelfProcessingTime();
double self_processing_time = SelfProcessingTimeLocked();
if (processing_times) {
(*processing_times)[long_name()] = self_processing_time;
}
@ -576,7 +578,8 @@ class UnknownRatio : public Node {
std::shared_ptr<Node> input = inputs_.front();
double ratio = static_cast<double>(input->num_elements()) /
static_cast<double>(num_elements_);
double new_input_time = (old_input_time + SelfProcessingTime()) / ratio;
double new_input_time =
(old_input_time + SelfProcessingTimeLocked()) / ratio;
(*input_times)[long_name()] = new_input_time;
}
@ -588,7 +591,7 @@ class UnknownRatio : public Node {
absl::flat_hash_map<string, double>* output_times,
absl::flat_hash_map<string, double>* output_time_gradients) const override
TF_SHARED_LOCKS_REQUIRED(mu_) {
double self_processing_time = SelfProcessingTime();
double self_processing_time = SelfProcessingTimeLocked();
if (num_elements_ == 0 || inputs_.empty() ||
inputs_.front()->num_elements() == 0) {
(*output_times)[long_name()] = self_processing_time;
@ -624,7 +627,7 @@ class UnknownRatio : public Node {
absl::flat_hash_map<string, double>* processing_times,
absl::flat_hash_map<string, double>* total_processing_times) override
TF_SHARED_LOCKS_REQUIRED(mu_) {
double self_processing_time = SelfProcessingTime();
double self_processing_time = SelfProcessingTimeLocked();
if (processing_times) {
(*processing_times)[long_name()] = self_processing_time;
}
@ -954,6 +957,11 @@ std::shared_ptr<Node> Node::Snapshot(std::shared_ptr<Node> output) const {
return result;
}
double Node::SelfProcessingTime() const {
tf_shared_lock l(mu_);
return SelfProcessingTimeLocked();
}
double Node::TotalBufferedBytes() const {
absl::flat_hash_map<string, double> total_bytes;
tf_shared_lock l(mu_);
@ -1072,15 +1080,7 @@ double Node::TotalProcessingTimeForInputs(
return sum;
}
double Node::SelfInputTime() const {
if (num_elements_ <= 1) {
return 0;
}
return static_cast<double>(input_time_) /
static_cast<double>(num_elements_ - 1);
}
double Node::SelfProcessingTime() const {
double Node::SelfProcessingTimeLocked() const {
if (num_elements_ == 0) {
return 0;
}
@ -1167,8 +1167,6 @@ std::shared_ptr<Node> Node::SnapshotHelper(
result_node->num_elements_.store(num_elements_);
result_node->record_metrics_.store(false);
result_node->processing_time_.store(processing_time_);
result_node->input_time_.store(input_time_);
result_node->last_input_time_.store(last_input_time_);
mutex_lock l2(result_node->mu_);
result_node->parameters_ = parameters_;
}
@ -1462,7 +1460,6 @@ double Model::OutputTime(std::shared_ptr<Node> node,
absl::flat_hash_map<string, double>* gradients) {
// To store the input time for each node.
absl::flat_hash_map<string, double> input_times;
input_times[kInputTimeKey] = node->SelfInputTime();
// TODO(jsimsa): Now that we are accounting for buffer size in wait time
// computation, assuming that the input is infinitely fast will result in

View File

@ -146,8 +146,6 @@ class Node {
bytes_produced_(0),
num_elements_(0),
processing_time_(0),
input_time_(0),
last_input_time_(0),
record_metrics_(true),
metrics_(name_),
output_(args.output.get()) {}
@ -268,15 +266,6 @@ class Node {
num_elements_++;
}
// Records that an element has been requested.
void record_input(int64 time_nanos) TF_LOCKS_EXCLUDED(mu_) {
if (last_input_time_ != 0) {
DCHECK_LE(last_input_time_, time_nanos);
input_time_ += time_nanos - last_input_time_;
}
last_input_time_ = time_nanos;
}
// Records that a node thread has started executing.
void record_start(int64 time_nanos) TF_LOCKS_EXCLUDED(mu_) {
DCHECK_EQ(work_start_, 0);
@ -351,11 +340,8 @@ class Node {
std::shared_ptr<Node> Snapshot(std::shared_ptr<Node> output) const
TF_LOCKS_EXCLUDED(mu_);
// Returns the per-element input time this node is called.
double SelfInputTime() const;
// Returns the per-element processing time spent in this node.
double SelfProcessingTime() const;
double SelfProcessingTime() const TF_LOCKS_EXCLUDED(mu_);
// Returns the total number of bytes buffered in all nodes in the subtree for
// which autotuning is enabled.
@ -477,6 +463,9 @@ class Node {
const absl::flat_hash_map<string, double>& total_processing_times)
TF_SHARED_LOCKS_REQUIRED(mu_);
// Returns the per-element processing time spent in this node.
double SelfProcessingTimeLocked() const TF_SHARED_LOCKS_REQUIRED(mu_);
// Computes the per-element CPU time spent in the subtree rooted in this node
// and stores it in `total_processing_times`. If `processing_times` is not
// `nullptr`, collects the per-element CPU time spent in each node of the
@ -541,9 +530,6 @@ class Node {
std::atomic<int64> bytes_produced_;
std::atomic<int64> num_elements_;
std::atomic<int64> processing_time_;
std::atomic<int64> input_time_;
// Records the time current node is called for future use.
std::atomic<int64> last_input_time_;
std::atomic<bool> record_metrics_;
Metrics metrics_;
absl::flat_hash_map<string, std::shared_ptr<Parameter>> parameters_

View File

@ -843,40 +843,6 @@ INSTANTIATE_TEST_SUITE_P(
::testing::Values(0, 20, 40, 80, 100),
::testing::Values(0, 1, 2, 4, 10, 20, 40)));
class SelfProcessingTimeTest : public ::testing::TestWithParam<int64> {};
TEST_P(SelfProcessingTimeTest, Model) {
const int64 add_times = GetParam();
std::shared_ptr<Node> source = model::MakeSourceNode({0, "source", nullptr});
for (int i = 0; i < add_times; i++) {
source->add_processing_time(i);
source->record_element();
}
double self_processing_time =
(add_times == 0 ? 0.0 : (static_cast<double>(add_times) - 1.0) / 2.0);
EXPECT_EQ(source->SelfProcessingTime(), self_processing_time);
}
INSTANTIATE_TEST_SUITE_P(Test, SelfProcessingTimeTest,
::testing::Values(0, 1, 2, 5, 10, 20, 40));
class SelfInputTimeTest : public ::testing::TestWithParam<int64> {};
TEST_P(SelfInputTimeTest, Model) {
const int64 add_times = GetParam();
std::shared_ptr<Node> source = model::MakeSourceNode({0, "source", nullptr});
for (int i = 0; i < add_times; i++) {
source->record_input((1 + i) * i / 2 + 1);
source->record_element();
}
double self_input_time =
(add_times <= 1 ? 0.0 : static_cast<double>(add_times) / 2.0);
EXPECT_EQ(source->SelfInputTime(), self_input_time);
}
INSTANTIATE_TEST_SUITE_P(Test, SelfInputTimeTest,
::testing::Values(0, 1, 2, 5, 10, 20, 40));
} // namespace
} // namespace model
} // namespace data