[tf.data] Use all produced elements during the process to estimate the average size instead of only current elements in the buffer.
PiperOrigin-RevId: 337941261 Change-Id: I67d4ce30de2cc7dd43442b55de65ab0164aedaec
This commit is contained in:
parent
a361baa862
commit
19fda561d8
@ -311,6 +311,15 @@ class AsyncInterleaveMany : public Node {
|
||||
(*total_processing_times)[long_name()] =
|
||||
self_processing_time + inputs_processing_time;
|
||||
}
|
||||
|
||||
double MaximumBufferedBytes() const TF_SHARED_LOCKS_REQUIRED(mu_) {
|
||||
double result = 0;
|
||||
auto* parameter = gtl::FindOrNull(parameters_, kParallelism);
|
||||
if (parameter) {
|
||||
result += (*parameter)->value * AverageBufferedElementSize();
|
||||
}
|
||||
return result;
|
||||
}
|
||||
};
|
||||
|
||||
class KnownRatio : public Node {
|
||||
@ -593,6 +602,26 @@ class AsyncKnownRatio : public Node {
|
||||
self_processing_time + inputs_processing_time;
|
||||
}
|
||||
|
||||
double MaximumBufferedBytes() const TF_SHARED_LOCKS_REQUIRED(mu_) {
|
||||
double result = 0;
|
||||
auto* parameter = gtl::FindOrNull(parameters_, kBufferSize);
|
||||
if (!parameter) {
|
||||
parameter = gtl::FindOrNull(parameters_, kParallelism);
|
||||
}
|
||||
|
||||
if (parameter) {
|
||||
if (ratio_ == 0) {
|
||||
result += (*parameter)->value * AverageBufferedElementSize();
|
||||
} else {
|
||||
// The estimation is currently not accurate for MapAndBatchDataset for
|
||||
// the maximum buffer size does not match `num_parallel_calls`
|
||||
// parameter.
|
||||
result += (*parameter)->value * AverageBufferedElementSize() / ratio_;
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
private:
|
||||
const double ratio_;
|
||||
};
|
||||
@ -1067,11 +1096,34 @@ double Node::TotalProcessingTime(
|
||||
}
|
||||
|
||||
double Node::AverageBufferedElementSize() const {
|
||||
if (buffered_elements_ == 0) {
|
||||
return 0;
|
||||
DCHECK_GE(num_elements_, 0);
|
||||
DCHECK_GE(buffered_elements_, 0);
|
||||
if (num_elements_ <= 0) {
|
||||
if (buffered_elements_ <= 0) {
|
||||
// If there are no produced elements or buffered elements recorded, return
|
||||
// 0.
|
||||
return 0;
|
||||
}
|
||||
// If there are no produced elements but some buffered elements, return the
|
||||
// average size of all buffered elements.
|
||||
return static_cast<double>(buffered_bytes_) /
|
||||
static_cast<double>(buffered_elements_);
|
||||
}
|
||||
return static_cast<double>(buffered_bytes_) /
|
||||
static_cast<double>(buffered_elements_);
|
||||
|
||||
if (buffered_elements_ <= 0) {
|
||||
// If there are no buffered elements but some produced elements, return the
|
||||
// average size of all produced elements.
|
||||
return static_cast<double>(bytes_produced_) /
|
||||
static_cast<double>(num_elements_);
|
||||
}
|
||||
|
||||
// Otherwise, return the mean value of average size of all produced elements
|
||||
// and average size of all buffered elements.
|
||||
return (static_cast<double>(bytes_produced_) /
|
||||
static_cast<double>(num_elements_) +
|
||||
static_cast<double>(buffered_bytes_) /
|
||||
static_cast<double>(buffered_elements_)) /
|
||||
2.0;
|
||||
}
|
||||
|
||||
double Node::OutputTimeForInputs(
|
||||
@ -1275,20 +1327,17 @@ void Node::TotalMaximumBufferedBytesHelper(
|
||||
return;
|
||||
}
|
||||
|
||||
double result = 0;
|
||||
auto* parameter = gtl::FindOrNull(parameters_, kBufferSize);
|
||||
if (!parameter) {
|
||||
parameter = gtl::FindOrNull(parameters_, kParallelism);
|
||||
}
|
||||
if (parameter) {
|
||||
result = (*parameter)->value * AverageBufferedElementSize();
|
||||
}
|
||||
double result = MaximumBufferedBytes();
|
||||
for (auto& input : inputs_) {
|
||||
result += total_bytes->at(input->long_name());
|
||||
}
|
||||
total_bytes->insert(std::make_pair(long_name(), result));
|
||||
}
|
||||
|
||||
double Node::MaximumBufferedBytes() const TF_SHARED_LOCKS_REQUIRED(mu_) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
void Model::AddNode(Node::Factory factory, const string& name,
|
||||
std::shared_ptr<Node> parent,
|
||||
std::shared_ptr<Node>* out_node) {
|
||||
|
@ -517,6 +517,12 @@ class Node {
|
||||
absl::flat_hash_map<string, double>* total_bytes) const
|
||||
TF_SHARED_LOCKS_REQUIRED(mu_);
|
||||
|
||||
// Compute and return the maximum buffered bytes on the node itself. By
|
||||
// default non-tunable nodes are assumed not to buffer any bytes, so the
|
||||
// tunable nodes as subclasses are expected to override this method to ensure
|
||||
// that the optimization algorithm respects the memory budget.
|
||||
virtual double MaximumBufferedBytes() const TF_SHARED_LOCKS_REQUIRED(mu_);
|
||||
|
||||
// Stores the time passed to the last call to `Node::record_start()` on the
|
||||
// current thread.
|
||||
//
|
||||
|
@ -131,7 +131,9 @@ TEST_P(AsyncKnownRatioTest, Model) {
|
||||
async_known_many->record_buffer_event(110, 10);
|
||||
EXPECT_EQ(async_known_many->TotalBufferedBytes(), 110);
|
||||
EXPECT_EQ(async_known_many->TotalMaximumBufferedBytes(),
|
||||
110 * parallelism / 10);
|
||||
num_inputs_per_output == 0
|
||||
? 110.0 * parallelism / 10
|
||||
: 110.0 * parallelism / 10 / num_inputs_per_output);
|
||||
source1->add_processing_time(100);
|
||||
EXPECT_EQ(async_known_many->TotalProcessingTime(/*processing_times=*/nullptr),
|
||||
0);
|
||||
@ -385,41 +387,12 @@ TEST(UnknownTest, Model) {
|
||||
EXPECT_EQ(unknown->OutputTime(&input_times, nullptr), 100);
|
||||
}
|
||||
|
||||
class TestNode : public model::Node {
|
||||
public:
|
||||
using model::Node::Node;
|
||||
|
||||
virtual ~TestNode() {}
|
||||
|
||||
protected:
|
||||
std::shared_ptr<Node> Clone(std::shared_ptr<Node> output) const override
|
||||
TF_SHARED_LOCKS_REQUIRED(mu_) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
void InputTimeLocked(absl::flat_hash_map<string, double>* input_times)
|
||||
const override TF_SHARED_LOCKS_REQUIRED(mu_) {}
|
||||
|
||||
void OutputTimeLocked(
|
||||
const absl::flat_hash_map<string, double>& input_times,
|
||||
absl::flat_hash_map<string, double>* gradients,
|
||||
absl::flat_hash_map<string, double>* output_times,
|
||||
absl::flat_hash_map<string, double>* output_time_gradients) const override
|
||||
TF_SHARED_LOCKS_REQUIRED(mu_) {
|
||||
(*output_times)[long_name()] = 0;
|
||||
}
|
||||
|
||||
void TotalProcessingTimeLocked(
|
||||
absl::flat_hash_map<string, double>* processing_times,
|
||||
absl::flat_hash_map<string, double>* total_processing_times) override
|
||||
TF_SHARED_LOCKS_REQUIRED(mu_) {
|
||||
(*total_processing_times)[long_name()] = 0;
|
||||
}
|
||||
};
|
||||
|
||||
TEST(SetterGetterTest, Node) {
|
||||
std::shared_ptr<TestNode> node =
|
||||
std::make_shared<TestNode>(model::Node::Args{-1, "TestNode", nullptr});
|
||||
std::shared_ptr<Node> node = model::MakeAsyncInterleaveManyNode(
|
||||
{-1, "TestNode", nullptr},
|
||||
{model::MakeParameter("parallelism",
|
||||
std::make_shared<SharedState>(3, nullptr, nullptr),
|
||||
1, 7)});
|
||||
EXPECT_EQ(node->id(), -1);
|
||||
EXPECT_EQ(node->name(), "TestNode");
|
||||
EXPECT_EQ(node->output(), nullptr);
|
||||
@ -428,16 +401,46 @@ TEST(SetterGetterTest, Node) {
|
||||
EXPECT_EQ(node->buffered_elements(), 0);
|
||||
EXPECT_EQ(node->TotalBufferedBytes(), 0);
|
||||
EXPECT_EQ(node->TotalMaximumBufferedBytes(), 0);
|
||||
node->record_buffer_event(42, 0);
|
||||
EXPECT_EQ(node->buffered_bytes(), 42);
|
||||
EXPECT_EQ(node->TotalBufferedBytes(), 0);
|
||||
EXPECT_EQ(node->TotalMaximumBufferedBytes(), 0);
|
||||
EXPECT_EQ(node->buffered_elements(), 0);
|
||||
node->record_buffer_event(0, 11);
|
||||
EXPECT_EQ(node->buffered_bytes(), 42);
|
||||
EXPECT_EQ(node->TotalBufferedBytes(), 0);
|
||||
EXPECT_EQ(node->TotalMaximumBufferedBytes(), 0);
|
||||
EXPECT_EQ(node->buffered_elements(), 11);
|
||||
|
||||
node->record_buffer_event(20, 1);
|
||||
EXPECT_EQ(node->buffered_bytes(), 20);
|
||||
EXPECT_EQ(node->buffered_elements(), 1);
|
||||
EXPECT_EQ(node->TotalBufferedBytes(), 20);
|
||||
EXPECT_EQ(node->TotalMaximumBufferedBytes(), 60);
|
||||
|
||||
node->record_buffer_event(10, 1);
|
||||
EXPECT_EQ(node->buffered_bytes(), 30);
|
||||
EXPECT_EQ(node->buffered_elements(), 2);
|
||||
EXPECT_EQ(node->TotalBufferedBytes(), 30);
|
||||
EXPECT_EQ(node->TotalMaximumBufferedBytes(), 45);
|
||||
|
||||
node->record_buffer_event(18, 1);
|
||||
EXPECT_EQ(node->buffered_bytes(), 48);
|
||||
EXPECT_EQ(node->buffered_elements(), 3);
|
||||
EXPECT_EQ(node->bytes_produced(), 0);
|
||||
EXPECT_EQ(node->num_elements(), 0);
|
||||
EXPECT_EQ(node->TotalBufferedBytes(), 48);
|
||||
EXPECT_EQ(node->TotalMaximumBufferedBytes(), 48);
|
||||
|
||||
node->record_buffer_event(-20, -1);
|
||||
node->record_element();
|
||||
node->record_bytes_produced(20);
|
||||
EXPECT_EQ(node->buffered_bytes(), 28);
|
||||
EXPECT_EQ(node->buffered_elements(), 2);
|
||||
EXPECT_EQ(node->bytes_produced(), 20);
|
||||
EXPECT_EQ(node->num_elements(), 1);
|
||||
EXPECT_EQ(node->TotalBufferedBytes(), 28);
|
||||
EXPECT_EQ(node->TotalMaximumBufferedBytes(), 51);
|
||||
|
||||
node->record_buffer_event(-10, -1);
|
||||
node->record_element();
|
||||
node->record_bytes_produced(10);
|
||||
EXPECT_EQ(node->buffered_bytes(), 18);
|
||||
EXPECT_EQ(node->buffered_elements(), 1);
|
||||
EXPECT_EQ(node->bytes_produced(), 30);
|
||||
EXPECT_EQ(node->num_elements(), 2);
|
||||
EXPECT_EQ(node->TotalBufferedBytes(), 18);
|
||||
EXPECT_EQ(node->TotalMaximumBufferedBytes(), 49.5);
|
||||
|
||||
EXPECT_EQ(node->processing_time(), 0);
|
||||
node->record_start(1);
|
||||
@ -447,22 +450,32 @@ TEST(SetterGetterTest, Node) {
|
||||
node->add_processing_time(2);
|
||||
EXPECT_EQ(node->processing_time(), 42);
|
||||
|
||||
std::shared_ptr<TestNode> input =
|
||||
std::make_shared<TestNode>(model::Node::Args{-1, "TestInput", node});
|
||||
std::shared_ptr<Node> input = model::MakeAsyncKnownRatioNode(
|
||||
{0, "TestInput", node}, 2,
|
||||
{model::MakeParameter("parallelism",
|
||||
std::make_shared<SharedState>(5, nullptr, nullptr),
|
||||
0, 6)});
|
||||
EXPECT_EQ(input->output(), node.get());
|
||||
EXPECT_EQ(node->inputs().size(), 0);
|
||||
node->add_input(input);
|
||||
EXPECT_EQ(node->inputs().size(), 1);
|
||||
EXPECT_EQ(node->inputs().front(), input);
|
||||
input->record_buffer_event(13, 0);
|
||||
EXPECT_EQ(node->TotalBufferedBytes(), 0);
|
||||
EXPECT_EQ(node->TotalMaximumBufferedBytes(), 0);
|
||||
|
||||
input->record_buffer_event(28, 1);
|
||||
EXPECT_EQ(node->bytes_consumed(), 0);
|
||||
EXPECT_EQ(node->TotalBufferedBytes(), 46);
|
||||
EXPECT_EQ(node->TotalMaximumBufferedBytes(), 119.5);
|
||||
|
||||
input->record_buffer_event(-28, -1);
|
||||
input->record_element();
|
||||
input->record_bytes_produced(28);
|
||||
node->record_bytes_consumed(28);
|
||||
EXPECT_EQ(node->bytes_consumed(), 28);
|
||||
EXPECT_EQ(node->TotalBufferedBytes(), 18);
|
||||
EXPECT_EQ(node->TotalMaximumBufferedBytes(), 119.5);
|
||||
|
||||
node->remove_input(input);
|
||||
EXPECT_EQ(node->inputs().size(), 0);
|
||||
|
||||
EXPECT_EQ(node->num_elements(), 0);
|
||||
node->record_element();
|
||||
EXPECT_EQ(node->num_elements(), 1);
|
||||
}
|
||||
|
||||
// Returns a weighted sum of a prior and the actual processing time.
|
||||
|
@ -527,6 +527,7 @@ class MapAndBatchDatasetOp::Dataset : public DatasetBase {
|
||||
"Failed to allocate memory for the batch of component ", i);
|
||||
}
|
||||
}
|
||||
RecordBufferEnqueue(ctx.get(), result->output);
|
||||
result->output_allocated = true;
|
||||
return Status::OK();
|
||||
}
|
||||
@ -536,6 +537,9 @@ class MapAndBatchDatasetOp::Dataset : public DatasetBase {
|
||||
std::vector<Tensor>* out_tensors,
|
||||
bool* end_of_sequence) {
|
||||
mutex_lock l(result->mu);
|
||||
if (result->output_allocated) {
|
||||
RecordBufferDequeue(ctx, result->output);
|
||||
}
|
||||
if (result->num_elements == 0) {
|
||||
if (result->status.ok() || errors::IsOutOfRange(result->status)) {
|
||||
*end_of_sequence = true;
|
||||
|
Loading…
x
Reference in New Issue
Block a user