[tf.data] Use all elements produced so far to estimate the average element size, instead of only the elements currently in the buffer.

PiperOrigin-RevId: 337941261
Change-Id: I67d4ce30de2cc7dd43442b55de65ab0164aedaec
Authored by Jay Shi on 2020-10-19 15:00:21 -07:00, committed by TensorFlower Gardener
Parent: a361baa862
Commit: 19fda561d8
4 changed files with 138 additions and 66 deletions
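At the heart of the change is `Node::AverageBufferedElementSize()` in the first changed file below: instead of estimating element size only from what currently sits in the buffer, it now also folds in the lifetime average over all elements the node has produced, and blends the two. The standalone sketch below mirrors that logic for reference; the names follow the diff, but it is an illustration, not the TensorFlow API.

#include <cstdint>

// Sketch of the estimator introduced by this commit: when both lifetime
// production stats and current buffer stats are available, the estimate is
// the mean of the two per-element averages.
double AverageBufferedElementSize(int64_t num_elements, int64_t bytes_produced,
                                  int64_t buffered_elements,
                                  int64_t buffered_bytes) {
  if (num_elements <= 0) {
    // Nothing produced yet: fall back to the buffer-only average (the
    // pre-change behavior), or 0 if the buffer is empty too.
    if (buffered_elements <= 0) return 0.0;
    return static_cast<double>(buffered_bytes) / buffered_elements;
  }
  if (buffered_elements <= 0) {
    // Elements were produced but none are buffered right now: use the
    // lifetime average of produced elements.
    return static_cast<double>(bytes_produced) / num_elements;
  }
  // Both signals available: blend the two estimates.
  return (static_cast<double>(bytes_produced) / num_elements +
          static_cast<double>(buffered_bytes) / buffered_elements) /
         2.0;
}

Using the lifetime average damps the estimate against whatever happens to be in the buffer at the instant the model is optimized, which is the point of the commit message above.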

View File

@@ -311,6 +311,15 @@ class AsyncInterleaveMany : public Node {
     (*total_processing_times)[long_name()] =
         self_processing_time + inputs_processing_time;
   }
+
+  double MaximumBufferedBytes() const TF_SHARED_LOCKS_REQUIRED(mu_) {
+    double result = 0;
+    auto* parameter = gtl::FindOrNull(parameters_, kParallelism);
+    if (parameter) {
+      result += (*parameter)->value * AverageBufferedElementSize();
+    }
+    return result;
+  }
 };
 
 class KnownRatio : public Node {
@@ -593,6 +602,26 @@ class AsyncKnownRatio : public Node {
     (*total_processing_times)[long_name()] =
         self_processing_time + inputs_processing_time;
   }
+
+  double MaximumBufferedBytes() const TF_SHARED_LOCKS_REQUIRED(mu_) {
+    double result = 0;
+    auto* parameter = gtl::FindOrNull(parameters_, kBufferSize);
+    if (!parameter) {
+      parameter = gtl::FindOrNull(parameters_, kParallelism);
+    }
+    if (parameter) {
+      if (ratio_ == 0) {
+        result += (*parameter)->value * AverageBufferedElementSize();
+      } else {
+        // The estimate is currently not accurate for MapAndBatchDataset,
+        // because the maximum buffer size does not match the
+        // `num_parallel_calls` parameter.
+        result += (*parameter)->value * AverageBufferedElementSize() / ratio_;
+      }
+    }
+    return result;
+  }
+
  private:
   const double ratio_;
 };
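When `ratio_` (the number of input elements consumed per output element) is nonzero, the bound divides by it: `value * AverageBufferedElementSize() / ratio_`. Worked with the values from `SetterGetterTest` in the third changed file: `parallelism = 5`, an average buffered element size of 28 bytes, and `ratio_ = 2` give 5 * 28 / 2 = 70 maximum buffered bytes, which is the input node's share of the 119.5 expectation there (49.5 + 70).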
@@ -1067,11 +1096,34 @@ double Node::TotalProcessingTime(
 }
 
 double Node::AverageBufferedElementSize() const {
-  if (buffered_elements_ == 0) {
-    return 0;
+  DCHECK_GE(num_elements_, 0);
+  DCHECK_GE(buffered_elements_, 0);
+  if (num_elements_ <= 0) {
+    if (buffered_elements_ <= 0) {
+      // If there are no produced elements or buffered elements recorded,
+      // return 0.
+      return 0;
+    }
+    // If there are no produced elements but some buffered elements, return the
+    // average size of all buffered elements.
+    return static_cast<double>(buffered_bytes_) /
+           static_cast<double>(buffered_elements_);
   }
-  return static_cast<double>(buffered_bytes_) /
-         static_cast<double>(buffered_elements_);
+
+  if (buffered_elements_ <= 0) {
+    // If there are no buffered elements but some produced elements, return the
+    // average size of all produced elements.
+    return static_cast<double>(bytes_produced_) /
+           static_cast<double>(num_elements_);
+  }
+
+  // Otherwise, return the mean of the average size of all produced elements
+  // and the average size of all buffered elements.
+  return (static_cast<double>(bytes_produced_) /
+              static_cast<double>(num_elements_) +
+          static_cast<double>(buffered_bytes_) /
+              static_cast<double>(buffered_elements_)) /
+         2.0;
 }
 
 double Node::OutputTimeForInputs(
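A quick worked case for the final branch: with 100 bytes produced over 4 elements and 60 bytes currently buffered over 2 elements, the new estimate is (100/4 + 60/2) / 2 = 27.5, where the old code would have reported the buffer-only average of 30.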
@@ -1275,20 +1327,17 @@ void Node::TotalMaximumBufferedBytesHelper(
     return;
   }
-  double result = 0;
-  auto* parameter = gtl::FindOrNull(parameters_, kBufferSize);
-  if (!parameter) {
-    parameter = gtl::FindOrNull(parameters_, kParallelism);
-  }
-  if (parameter) {
-    result = (*parameter)->value * AverageBufferedElementSize();
-  }
+  double result = MaximumBufferedBytes();
   for (auto& input : inputs_) {
     result += total_bytes->at(input->long_name());
   }
   total_bytes->insert(std::make_pair(long_name(), result));
 }
 
+double Node::MaximumBufferedBytes() const TF_SHARED_LOCKS_REQUIRED(mu_) {
+  return 0;
+}
+
 void Model::AddNode(Node::Factory factory, const string& name,
                     std::shared_ptr<Node> parent,
                     std::shared_ptr<Node>* out_node) {

View File

@@ -517,6 +517,12 @@ class Node {
       absl::flat_hash_map<string, double>* total_bytes) const
       TF_SHARED_LOCKS_REQUIRED(mu_);
 
+  // Computes and returns the maximum buffered bytes on the node itself. By
+  // default, non-tunable nodes are assumed not to buffer any bytes, so
+  // tunable node subclasses are expected to override this method to ensure
+  // that the optimization algorithm respects the memory budget.
+  virtual double MaximumBufferedBytes() const TF_SHARED_LOCKS_REQUIRED(mu_);
+
   // Stores the time passed to the last call to `Node::record_start()` on the
   // current thread.
   //
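With this hook in place, `TotalMaximumBufferedBytesHelper` in the previous file reduces to a plain aggregation: each node contributes `MaximumBufferedBytes()` for itself (0 unless an async subclass overrides it) plus the totals of its inputs. For the two-node graph in `SetterGetterTest` below, that is 3 * 16.5 = 49.5 for the `AsyncInterleaveMany` node plus 5 * 28 / 2 = 70 for its `AsyncKnownRatio` input, i.e. 119.5.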

View File

@@ -131,7 +131,9 @@ TEST_P(AsyncKnownRatioTest, Model) {
   async_known_many->record_buffer_event(110, 10);
   EXPECT_EQ(async_known_many->TotalBufferedBytes(), 110);
   EXPECT_EQ(async_known_many->TotalMaximumBufferedBytes(),
-            110 * parallelism / 10);
+            num_inputs_per_output == 0
+                ? 110.0 * parallelism / 10
+                : 110.0 * parallelism / 10 / num_inputs_per_output);
   source1->add_processing_time(100);
   EXPECT_EQ(async_known_many->TotalProcessingTime(/*processing_times=*/nullptr),
             0);
@@ -385,41 +387,12 @@ TEST(UnknownTest, Model) {
   EXPECT_EQ(unknown->OutputTime(&input_times, nullptr), 100);
 }
 
-class TestNode : public model::Node {
- public:
-  using model::Node::Node;
-
-  virtual ~TestNode() {}
-
- protected:
-  std::shared_ptr<Node> Clone(std::shared_ptr<Node> output) const override
-      TF_SHARED_LOCKS_REQUIRED(mu_) {
-    return nullptr;
-  }
-
-  void InputTimeLocked(absl::flat_hash_map<string, double>* input_times)
-      const override TF_SHARED_LOCKS_REQUIRED(mu_) {}
-
-  void OutputTimeLocked(
-      const absl::flat_hash_map<string, double>& input_times,
-      absl::flat_hash_map<string, double>* gradients,
-      absl::flat_hash_map<string, double>* output_times,
-      absl::flat_hash_map<string, double>* output_time_gradients) const override
-      TF_SHARED_LOCKS_REQUIRED(mu_) {
-    (*output_times)[long_name()] = 0;
-  }
-
-  void TotalProcessingTimeLocked(
-      absl::flat_hash_map<string, double>* processing_times,
-      absl::flat_hash_map<string, double>* total_processing_times) override
-      TF_SHARED_LOCKS_REQUIRED(mu_) {
-    (*total_processing_times)[long_name()] = 0;
-  }
-};
-
 TEST(SetterGetterTest, Node) {
-  std::shared_ptr<TestNode> node =
-      std::make_shared<TestNode>(model::Node::Args{-1, "TestNode", nullptr});
+  std::shared_ptr<Node> node = model::MakeAsyncInterleaveManyNode(
+      {-1, "TestNode", nullptr},
+      {model::MakeParameter("parallelism",
+                            std::make_shared<SharedState>(3, nullptr, nullptr),
+                            1, 7)});
   EXPECT_EQ(node->id(), -1);
   EXPECT_EQ(node->name(), "TestNode");
   EXPECT_EQ(node->output(), nullptr);
@@ -428,16 +401,46 @@ TEST(SetterGetterTest, Node) {
   EXPECT_EQ(node->buffered_elements(), 0);
   EXPECT_EQ(node->TotalBufferedBytes(), 0);
   EXPECT_EQ(node->TotalMaximumBufferedBytes(), 0);
-  node->record_buffer_event(42, 0);
-  EXPECT_EQ(node->buffered_bytes(), 42);
-  EXPECT_EQ(node->TotalBufferedBytes(), 0);
-  EXPECT_EQ(node->TotalMaximumBufferedBytes(), 0);
-  EXPECT_EQ(node->buffered_elements(), 0);
-  node->record_buffer_event(0, 11);
-  EXPECT_EQ(node->buffered_bytes(), 42);
-  EXPECT_EQ(node->TotalBufferedBytes(), 0);
-  EXPECT_EQ(node->TotalMaximumBufferedBytes(), 0);
-  EXPECT_EQ(node->buffered_elements(), 11);
+  node->record_buffer_event(20, 1);
+  EXPECT_EQ(node->buffered_bytes(), 20);
+  EXPECT_EQ(node->buffered_elements(), 1);
+  EXPECT_EQ(node->TotalBufferedBytes(), 20);
+  EXPECT_EQ(node->TotalMaximumBufferedBytes(), 60);
+  node->record_buffer_event(10, 1);
+  EXPECT_EQ(node->buffered_bytes(), 30);
+  EXPECT_EQ(node->buffered_elements(), 2);
+  EXPECT_EQ(node->TotalBufferedBytes(), 30);
+  EXPECT_EQ(node->TotalMaximumBufferedBytes(), 45);
+  node->record_buffer_event(18, 1);
+  EXPECT_EQ(node->buffered_bytes(), 48);
+  EXPECT_EQ(node->buffered_elements(), 3);
+  EXPECT_EQ(node->bytes_produced(), 0);
+  EXPECT_EQ(node->num_elements(), 0);
+  EXPECT_EQ(node->TotalBufferedBytes(), 48);
+  EXPECT_EQ(node->TotalMaximumBufferedBytes(), 48);
+  node->record_buffer_event(-20, -1);
+  node->record_element();
+  node->record_bytes_produced(20);
+  EXPECT_EQ(node->buffered_bytes(), 28);
+  EXPECT_EQ(node->buffered_elements(), 2);
+  EXPECT_EQ(node->bytes_produced(), 20);
+  EXPECT_EQ(node->num_elements(), 1);
+  EXPECT_EQ(node->TotalBufferedBytes(), 28);
+  EXPECT_EQ(node->TotalMaximumBufferedBytes(), 51);
+  node->record_buffer_event(-10, -1);
+  node->record_element();
+  node->record_bytes_produced(10);
+  EXPECT_EQ(node->buffered_bytes(), 18);
+  EXPECT_EQ(node->buffered_elements(), 1);
+  EXPECT_EQ(node->bytes_produced(), 30);
+  EXPECT_EQ(node->num_elements(), 2);
+  EXPECT_EQ(node->TotalBufferedBytes(), 18);
+  EXPECT_EQ(node->TotalMaximumBufferedBytes(), 49.5);
   EXPECT_EQ(node->processing_time(), 0);
   node->record_start(1);
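Each expectation above is the new estimator with `parallelism = 3`. While nothing has been produced, the estimate is the buffer-only average: after (20, 1) it is 20/1 = 20, bound 3 * 20 = 60; after (10, 1) it is 30/2 = 15, bound 45; after (18, 1) it is 48/3 = 16, bound 48. Once elements are produced, the two averages are blended: 20 bytes over 1 produced element with 28 bytes over 2 buffered gives (20 + 14) / 2 = 17, bound 51; 30 bytes over 2 produced with 18 over 1 buffered gives (15 + 18) / 2 = 16.5, bound 49.5.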
@@ -447,22 +450,32 @@ TEST(SetterGetterTest, Node) {
   node->add_processing_time(2);
   EXPECT_EQ(node->processing_time(), 42);
-  std::shared_ptr<TestNode> input =
-      std::make_shared<TestNode>(model::Node::Args{-1, "TestInput", node});
+  std::shared_ptr<Node> input = model::MakeAsyncKnownRatioNode(
+      {0, "TestInput", node}, 2,
+      {model::MakeParameter("parallelism",
+                            std::make_shared<SharedState>(5, nullptr, nullptr),
+                            0, 6)});
   EXPECT_EQ(input->output(), node.get());
   EXPECT_EQ(node->inputs().size(), 0);
   node->add_input(input);
   EXPECT_EQ(node->inputs().size(), 1);
   EXPECT_EQ(node->inputs().front(), input);
-  input->record_buffer_event(13, 0);
-  EXPECT_EQ(node->TotalBufferedBytes(), 0);
-  EXPECT_EQ(node->TotalMaximumBufferedBytes(), 0);
+  input->record_buffer_event(28, 1);
+  EXPECT_EQ(node->bytes_consumed(), 0);
+  EXPECT_EQ(node->TotalBufferedBytes(), 46);
+  EXPECT_EQ(node->TotalMaximumBufferedBytes(), 119.5);
+  input->record_buffer_event(-28, -1);
+  input->record_element();
+  input->record_bytes_produced(28);
+  node->record_bytes_consumed(28);
+  EXPECT_EQ(node->bytes_consumed(), 28);
+  EXPECT_EQ(node->TotalBufferedBytes(), 18);
+  EXPECT_EQ(node->TotalMaximumBufferedBytes(), 119.5);
   node->remove_input(input);
   EXPECT_EQ(node->inputs().size(), 0);
-  EXPECT_EQ(node->num_elements(), 0);
-  node->record_element();
-  EXPECT_EQ(node->num_elements(), 1);
 }
 
 // Returns a weighted sum of a prior and the actual processing time.
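The totals here combine both nodes. After `input->record_buffer_event(28, 1)`, buffered bytes are the node's 18 plus the input's 28 = 46, and the maximum is the node's 49.5 plus the input's 5 * (28/1) / 2 = 70, hence 119.5. Once the input's only buffered element is drained and recorded as produced (28 bytes over 1 element, empty buffer), its estimate is still 28, so the maximum stays at 119.5 while the buffered total falls back to 18.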

View File

@@ -527,6 +527,7 @@ class MapAndBatchDatasetOp::Dataset : public DatasetBase {
             "Failed to allocate memory for the batch of component ", i);
       }
     }
+    RecordBufferEnqueue(ctx.get(), result->output);
     result->output_allocated = true;
     return Status::OK();
   }
@@ -536,6 +537,9 @@ class MapAndBatchDatasetOp::Dataset : public DatasetBase {
                        std::vector<Tensor>* out_tensors,
                        bool* end_of_sequence) {
     mutex_lock l(result->mu);
+    if (result->output_allocated) {
+      RecordBufferDequeue(ctx, result->output);
+    }
     if (result->num_elements == 0) {
       if (result->status.ok() || errors::IsOutOfRange(result->status)) {
         *end_of_sequence = true;
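Finally, `MapAndBatchDataset` starts feeding the model real buffer events: a batch is recorded as enqueued as soon as its memory is allocated, and as dequeued when a consumer claims it, with the `output_allocated` guard ensuring only batches that were actually enqueued get subtracted. The toy stand-in below shows the paired +/- accounting these helpers reduce to; it assumes (per the model code above) that they ultimately call `record_buffer_event`, and it is a sketch, not the iterator implementation.

#include <cassert>
#include <cstdint>

// Toy stand-in for a node's buffer accounting: enqueue adds an element's
// bytes and count, dequeue subtracts them, keeping the counters that
// AverageBufferedElementSize() reads in balance.
struct NodeStats {
  int64_t buffered_bytes = 0;
  int64_t buffered_elements = 0;
  void record_buffer_event(int64_t bytes_delta, int64_t elements_delta) {
    buffered_bytes += bytes_delta;
    buffered_elements += elements_delta;
  }
};

int main() {
  NodeStats stats;
  stats.record_buffer_event(512, 1);    // batch allocated -> enqueue event
  stats.record_buffer_event(-512, -1);  // batch consumed -> dequeue event
  assert(stats.buffered_bytes == 0 && stats.buffered_elements == 0);
  return 0;
}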