[tf.data] Calculate the average input time for the root node of the data input pipeline.
PiperOrigin-RevId: 315740720 Change-Id: I55967b18f1f039847049656ccfd849714e83a4cf
This commit is contained in:
parent
26ee75e596
commit
57c09eb4d8
@ -480,6 +480,7 @@ Status DatasetBaseIterator::GetNext(IteratorContext* ctx,
|
||||
profiler::TraceMe activity([&] { return BuildTraceMeName(); },
|
||||
profiler::TraceMeLevel::kInfo);
|
||||
DVLOG(3) << prefix() << " GetNext enter";
|
||||
RecordInput(ctx);
|
||||
RecordStart(ctx, /*stop_output=*/true);
|
||||
Status s = GetNextInternal(ctx, out_tensors, end_of_sequence);
|
||||
if (s.ok() && !*end_of_sequence) RecordElement(ctx, out_tensors);
|
||||
|
@ -971,6 +971,15 @@ class DatasetBaseIterator : public IteratorBase {
|
||||
}
|
||||
}
|
||||
|
||||
// When modeling is enabled, this method records the fact that this iterator
|
||||
// is called.
|
||||
void RecordInput(IteratorContext* ctx) {
  // Nothing is recorded unless `collect_resource_usage(ctx)` reports true.
  if (!collect_resource_usage(ctx)) {
    return;
  }
  // Pass the current timestamp (in nanoseconds) to the model node.
  node_->record_input(EnvTime::NowNanos());
}
|
||||
|
||||
// When modeling is enabled, this method records the fact that a thread of
|
||||
// this iterator has started work.
|
||||
void RecordStart(IteratorContext* ctx, bool stop_output = false) {
|
||||
|
@ -59,12 +59,12 @@ class InterleaveMany : public Node {
|
||||
(*input_times)[long_name()] = old_input_time;
|
||||
return;
|
||||
}
|
||||
// Here `old_input_time + SelfProcessingTimeLocked()` is the average input
|
||||
// Here `old_input_time + SelfProcessingTime()` is the average input
|
||||
// time for the interleave node to call one of the `(num_inputs() - 1)`
|
||||
// input nodes(except the first one) to return an element. Regardless of the
|
||||
// `block_length` parameter of interleave node, the average input time for
|
||||
// any of the `(num_inputs() - 1)` input nodes to be called is computed as:
|
||||
double new_input_time = (old_input_time + SelfProcessingTimeLocked()) *
|
||||
double new_input_time = (old_input_time + SelfProcessingTime()) *
|
||||
static_cast<double>(num_inputs() - 1);
|
||||
(*input_times)[long_name()] = new_input_time;
|
||||
}
|
||||
@ -77,7 +77,7 @@ class InterleaveMany : public Node {
|
||||
absl::flat_hash_map<string, double>* output_times,
|
||||
absl::flat_hash_map<string, double>* output_time_gradients) const override
|
||||
TF_SHARED_LOCKS_REQUIRED(mu_) {
|
||||
double self_processing_time = SelfProcessingTimeLocked();
|
||||
double self_processing_time = SelfProcessingTime();
|
||||
if (num_inputs() <= 1) {
|
||||
(*output_times)[long_name()] = self_processing_time;
|
||||
if (gradients) {
|
||||
@ -123,7 +123,7 @@ class InterleaveMany : public Node {
|
||||
absl::flat_hash_map<string, double>* processing_times,
|
||||
absl::flat_hash_map<string, double>* total_processing_times) override
|
||||
TF_SHARED_LOCKS_REQUIRED(mu_) {
|
||||
double self_processing_time = SelfProcessingTimeLocked();
|
||||
double self_processing_time = SelfProcessingTime();
|
||||
if (processing_times) {
|
||||
(*processing_times)[long_name()] = self_processing_time;
|
||||
}
|
||||
@ -179,8 +179,7 @@ class AsyncInterleaveMany : public Node {
|
||||
input_time = gtl::FindWithDefault(*input_times, kInputTimeKey, 0.0L);
|
||||
}
|
||||
} else {
|
||||
input_time =
|
||||
SelfProcessingTimeLocked() * static_cast<double>(num_inputs() - 1);
|
||||
input_time = SelfProcessingTime() * static_cast<double>(num_inputs() - 1);
|
||||
}
|
||||
(*input_times)[long_name()] = input_time;
|
||||
}
|
||||
@ -196,7 +195,7 @@ class AsyncInterleaveMany : public Node {
|
||||
absl::flat_hash_map<string, double>* output_times,
|
||||
absl::flat_hash_map<string, double>* output_time_gradients) const override
|
||||
TF_SHARED_LOCKS_REQUIRED(mu_) {
|
||||
double self_processing_time = SelfProcessingTimeLocked();
|
||||
double self_processing_time = SelfProcessingTime();
|
||||
if (num_inputs() <= 1) {
|
||||
(*output_times)[long_name()] = self_processing_time;
|
||||
if (gradients) {
|
||||
@ -277,7 +276,7 @@ class AsyncInterleaveMany : public Node {
|
||||
absl::flat_hash_map<string, double>* processing_times,
|
||||
absl::flat_hash_map<string, double>* total_processing_times) override
|
||||
TF_SHARED_LOCKS_REQUIRED(mu_) {
|
||||
double self_processing_time = SelfProcessingTimeLocked();
|
||||
double self_processing_time = SelfProcessingTime();
|
||||
if (processing_times) {
|
||||
(*processing_times)[long_name()] = self_processing_time;
|
||||
}
|
||||
@ -320,8 +319,7 @@ class KnownRatio : public Node {
|
||||
(*input_times)[long_name()] = old_input_time;
|
||||
return;
|
||||
}
|
||||
double new_input_time =
|
||||
(old_input_time + SelfProcessingTimeLocked()) / ratio_;
|
||||
double new_input_time = (old_input_time + SelfProcessingTime()) / ratio_;
|
||||
(*input_times)[long_name()] = new_input_time;
|
||||
}
|
||||
|
||||
@ -333,7 +331,7 @@ class KnownRatio : public Node {
|
||||
absl::flat_hash_map<string, double>* output_times,
|
||||
absl::flat_hash_map<string, double>* output_time_gradients) const override
|
||||
TF_SHARED_LOCKS_REQUIRED(mu_) {
|
||||
double self_processing_time = SelfProcessingTimeLocked();
|
||||
double self_processing_time = SelfProcessingTime();
|
||||
if (ratio_ == 0) {
|
||||
(*output_times)[long_name()] = self_processing_time;
|
||||
if (gradients) {
|
||||
@ -364,7 +362,7 @@ class KnownRatio : public Node {
|
||||
absl::flat_hash_map<string, double>* processing_times,
|
||||
absl::flat_hash_map<string, double>* total_processing_times) override
|
||||
TF_SHARED_LOCKS_REQUIRED(mu_) {
|
||||
double self_processing_time = SelfProcessingTimeLocked();
|
||||
double self_processing_time = SelfProcessingTime();
|
||||
if (processing_times) {
|
||||
(*processing_times)[long_name()] = self_processing_time;
|
||||
}
|
||||
@ -420,7 +418,7 @@ class AsyncKnownRatio : public Node {
|
||||
if (parallelism_parameter) {
|
||||
parallelism = (*parallelism_parameter)->value;
|
||||
}
|
||||
input_time = SelfProcessingTimeLocked() / ratio_ / parallelism;
|
||||
input_time = SelfProcessingTime() / ratio_ / parallelism;
|
||||
(*input_times)[long_name()] = input_time;
|
||||
}
|
||||
|
||||
@ -447,7 +445,7 @@ class AsyncKnownRatio : public Node {
|
||||
} else if (buffer_size_parameter) {
|
||||
buffer_size = (*buffer_size_parameter)->value;
|
||||
}
|
||||
double self_processing_time = SelfProcessingTimeLocked();
|
||||
double self_processing_time = SelfProcessingTime();
|
||||
double result;
|
||||
double input_time;
|
||||
if (output_) {
|
||||
@ -535,7 +533,7 @@ class AsyncKnownRatio : public Node {
|
||||
absl::flat_hash_map<string, double>* processing_times,
|
||||
absl::flat_hash_map<string, double>* total_processing_times) override
|
||||
TF_SHARED_LOCKS_REQUIRED(mu_) {
|
||||
double self_processing_time = SelfProcessingTimeLocked();
|
||||
double self_processing_time = SelfProcessingTime();
|
||||
if (processing_times) {
|
||||
(*processing_times)[long_name()] = self_processing_time;
|
||||
}
|
||||
@ -578,8 +576,7 @@ class UnknownRatio : public Node {
|
||||
std::shared_ptr<Node> input = inputs_.front();
|
||||
double ratio = static_cast<double>(input->num_elements()) /
|
||||
static_cast<double>(num_elements_);
|
||||
double new_input_time =
|
||||
(old_input_time + SelfProcessingTimeLocked()) / ratio;
|
||||
double new_input_time = (old_input_time + SelfProcessingTime()) / ratio;
|
||||
(*input_times)[long_name()] = new_input_time;
|
||||
}
|
||||
|
||||
@ -591,7 +588,7 @@ class UnknownRatio : public Node {
|
||||
absl::flat_hash_map<string, double>* output_times,
|
||||
absl::flat_hash_map<string, double>* output_time_gradients) const override
|
||||
TF_SHARED_LOCKS_REQUIRED(mu_) {
|
||||
double self_processing_time = SelfProcessingTimeLocked();
|
||||
double self_processing_time = SelfProcessingTime();
|
||||
if (num_elements_ == 0 || inputs_.empty() ||
|
||||
inputs_.front()->num_elements() == 0) {
|
||||
(*output_times)[long_name()] = self_processing_time;
|
||||
@ -627,7 +624,7 @@ class UnknownRatio : public Node {
|
||||
absl::flat_hash_map<string, double>* processing_times,
|
||||
absl::flat_hash_map<string, double>* total_processing_times) override
|
||||
TF_SHARED_LOCKS_REQUIRED(mu_) {
|
||||
double self_processing_time = SelfProcessingTimeLocked();
|
||||
double self_processing_time = SelfProcessingTime();
|
||||
if (processing_times) {
|
||||
(*processing_times)[long_name()] = self_processing_time;
|
||||
}
|
||||
@ -957,11 +954,6 @@ std::shared_ptr<Node> Node::Snapshot(std::shared_ptr<Node> output) const {
|
||||
return result;
|
||||
}
|
||||
|
||||
double Node::SelfProcessingTime() const {
|
||||
tf_shared_lock l(mu_);
|
||||
return SelfProcessingTimeLocked();
|
||||
}
|
||||
|
||||
double Node::TotalBufferedBytes() const {
|
||||
absl::flat_hash_map<string, double> total_bytes;
|
||||
tf_shared_lock l(mu_);
|
||||
@ -1080,7 +1072,15 @@ double Node::TotalProcessingTimeForInputs(
|
||||
return sum;
|
||||
}
|
||||
|
||||
double Node::SelfProcessingTimeLocked() const {
|
||||
double Node::SelfInputTime() const {
|
||||
if (num_elements_ <= 1) {
|
||||
return 0;
|
||||
}
|
||||
return static_cast<double>(input_time_) /
|
||||
static_cast<double>(num_elements_ - 1);
|
||||
}
|
||||
|
||||
double Node::SelfProcessingTime() const {
|
||||
if (num_elements_ == 0) {
|
||||
return 0;
|
||||
}
|
||||
@ -1167,6 +1167,8 @@ std::shared_ptr<Node> Node::SnapshotHelper(
|
||||
result_node->num_elements_.store(num_elements_);
|
||||
result_node->record_metrics_.store(false);
|
||||
result_node->processing_time_.store(processing_time_);
|
||||
result_node->input_time_.store(input_time_);
|
||||
result_node->last_input_time_.store(last_input_time_);
|
||||
mutex_lock l2(result_node->mu_);
|
||||
result_node->parameters_ = parameters_;
|
||||
}
|
||||
@ -1460,6 +1462,7 @@ double Model::OutputTime(std::shared_ptr<Node> node,
|
||||
absl::flat_hash_map<string, double>* gradients) {
|
||||
// To store the input time for each node.
|
||||
absl::flat_hash_map<string, double> input_times;
|
||||
input_times[kInputTimeKey] = node->SelfInputTime();
|
||||
|
||||
// TODO(jsimsa): Now that we are accounting for buffer size in wait time
|
||||
// computation, assuming that the input is infinitely fast will result in
|
||||
|
@ -146,6 +146,8 @@ class Node {
|
||||
bytes_produced_(0),
|
||||
num_elements_(0),
|
||||
processing_time_(0),
|
||||
input_time_(0),
|
||||
last_input_time_(0),
|
||||
record_metrics_(true),
|
||||
metrics_(name_),
|
||||
output_(args.output.get()) {}
|
||||
@ -266,6 +268,15 @@ class Node {
|
||||
num_elements_++;
|
||||
}
|
||||
|
||||
// Records that an element has been requested.
|
||||
void record_input(int64 time_nanos) TF_LOCKS_EXCLUDED(mu_) {
  // Install the new timestamp and fetch the previous one in a single atomic
  // step. Reading `last_input_time_` several times and storing it separately
  // would let concurrent callers measure against the same previous timestamp
  // and add the same interval to `input_time_` twice.
  const int64 previous_nanos = last_input_time_.exchange(time_nanos);
  // A previous value of 0 means no request has been recorded yet, so there is
  // no interval to accumulate.
  if (previous_nanos != 0) {
    DCHECK_LE(previous_nanos, time_nanos);
    input_time_ += time_nanos - previous_nanos;
  }
}
|
||||
|
||||
// Records that a node thread has started executing.
|
||||
void record_start(int64 time_nanos) TF_LOCKS_EXCLUDED(mu_) {
|
||||
DCHECK_EQ(work_start_, 0);
|
||||
@ -340,8 +351,11 @@ class Node {
|
||||
std::shared_ptr<Node> Snapshot(std::shared_ptr<Node> output) const
|
||||
TF_LOCKS_EXCLUDED(mu_);
|
||||
|
||||
// Returns the average time between consecutive requests made to this node.
|
||||
double SelfInputTime() const;
|
||||
|
||||
// Returns the per-element processing time spent in this node.
|
||||
double SelfProcessingTime() const TF_LOCKS_EXCLUDED(mu_);
|
||||
double SelfProcessingTime() const;
|
||||
|
||||
// Returns the total number of bytes buffered in all nodes in the subtree for
|
||||
// which autotuning is enabled.
|
||||
@ -463,9 +477,6 @@ class Node {
|
||||
const absl::flat_hash_map<string, double>& total_processing_times)
|
||||
TF_SHARED_LOCKS_REQUIRED(mu_);
|
||||
|
||||
// Returns the per-element processing time spent in this node.
|
||||
double SelfProcessingTimeLocked() const TF_SHARED_LOCKS_REQUIRED(mu_);
|
||||
|
||||
// Computes the per-element CPU time spent in the subtree rooted in this node
|
||||
// and stores it in `total_processing_times`. If `processing_times` is not
|
||||
// `nullptr`, collects the per-element CPU time spent in each node of the
|
||||
@ -530,6 +541,9 @@ class Node {
|
||||
std::atomic<int64> bytes_produced_;
|
||||
std::atomic<int64> num_elements_;
|
||||
std::atomic<int64> processing_time_;
|
||||
std::atomic<int64> input_time_;
|
||||
// Timestamp of the most recent `record_input` call (0 if there has been none).
|
||||
std::atomic<int64> last_input_time_;
|
||||
std::atomic<bool> record_metrics_;
|
||||
Metrics metrics_;
|
||||
absl::flat_hash_map<string, std::shared_ptr<Parameter>> parameters_
|
||||
|
@ -843,6 +843,40 @@ INSTANTIATE_TEST_SUITE_P(
|
||||
::testing::Values(0, 20, 40, 80, 100),
|
||||
::testing::Values(0, 1, 2, 4, 10, 20, 40)));
|
||||
|
||||
class SelfProcessingTimeTest : public ::testing::TestWithParam<int64> {};
|
||||
|
||||
TEST_P(SelfProcessingTimeTest, Model) {
|
||||
const int64 add_times = GetParam();
|
||||
std::shared_ptr<Node> source = model::MakeSourceNode({0, "source", nullptr});
|
||||
for (int i = 0; i < add_times; i++) {
|
||||
source->add_processing_time(i);
|
||||
source->record_element();
|
||||
}
|
||||
double self_processing_time =
|
||||
(add_times == 0 ? 0.0 : (static_cast<double>(add_times) - 1.0) / 2.0);
|
||||
EXPECT_EQ(source->SelfProcessingTime(), self_processing_time);
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(Test, SelfProcessingTimeTest,
|
||||
::testing::Values(0, 1, 2, 5, 10, 20, 40));
|
||||
|
||||
class SelfInputTimeTest : public ::testing::TestWithParam<int64> {};
|
||||
|
||||
TEST_P(SelfInputTimeTest, Model) {
|
||||
const int64 add_times = GetParam();
|
||||
std::shared_ptr<Node> source = model::MakeSourceNode({0, "source", nullptr});
|
||||
for (int i = 0; i < add_times; i++) {
|
||||
source->record_input((1 + i) * i / 2 + 1);
|
||||
source->record_element();
|
||||
}
|
||||
double self_input_time =
|
||||
(add_times <= 1 ? 0.0 : static_cast<double>(add_times) / 2.0);
|
||||
EXPECT_EQ(source->SelfInputTime(), self_input_time);
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(Test, SelfInputTimeTest,
|
||||
::testing::Values(0, 1, 2, 5, 10, 20, 40));
|
||||
|
||||
} // namespace
|
||||
} // namespace model
|
||||
} // namespace data
|
||||
|
Loading…
Reference in New Issue
Block a user