[tf.data] Add the optimization test when RAM budget is 0.

PiperOrigin-RevId: 338359725
Change-Id: I5c89aab1f65cf170ed4b64078834c883ad230c96
Authored by Jay Shi on 2020-10-21 16:07:04 -07:00, committed by TensorFlower Gardener
parent 8366f2ecea
commit c2239c1a2a
2 changed files with 64 additions and 1 deletion
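
For context: the new test (shown in full in the test diff further below) builds a three-node tf.data autotuning model whose tunable parameters start at -1 ("autotune"), then calls model.Optimize(algorithm, 40, 0, 0) and expects each parameter to land on its minimum (1, 0, 1). The meanings of those arguments are not spelled out in this diff; assuming the Optimize signature of this period, they are the autotune algorithm, the CPU budget, the RAM budget, and the model input time. The invariant being pinned down, as a standalone C++ sketch with no TensorFlow dependency (Parameter and ClampForRamBudget are hypothetical stand-ins, not the real optimizer):

// Hypothetical stand-in, not TensorFlow code: with a RAM budget of 0 there is
// no memory to spend, so a tunable parameter should settle at its minimum.
#include <cassert>

struct Parameter {
  double value;  // -1 plays the role of "autotune" in the test below
  double min;
  double max;
};

void ClampForRamBudget(Parameter& p, double ram_budget) {
  if (ram_budget <= 0) p.value = p.min;  // nothing to spend: pin to minimum
}

int main() {
  Parameter parallelism{/*value=*/-1, /*min=*/1, /*max=*/5};
  Parameter buffer_size{/*value=*/-1, /*min=*/0, /*max=*/6};
  ClampForRamBudget(parallelism, /*ram_budget=*/0);
  ClampForRamBudget(buffer_size, /*ram_budget=*/0);
  assert(parallelism.value == 1);  // mirrors EXPECT_EQ(node1->parameter_value("parallelism"), 1)
  assert(buffer_size.value == 0);  // mirrors EXPECT_EQ(node2->parameter_value("buffer_size"), 0)
  return 0;
}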

@@ -251,6 +251,12 @@ class Node {
  // Returns the node output.
  Node* output() const { return output_; }

  // Returns the parameter value.
  double parameter_value(const string& name) const TF_LOCKS_EXCLUDED(mu_) {
    tf_shared_lock l(mu_);
    return parameters_.at(name)->state->value;
  }

  // Returns the aggregate processing time.
  int64 processing_time() const TF_LOCKS_EXCLUDED(mu_) {
    return processing_time_;
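
The getter added above reads a tunable parameter's current value under a shared lock; the new test below relies on it to observe what the optimizer decided. A representative call, mirroring the assertions in the added test (here node stands for whatever Node the caller holds):

// Before optimization this returns -1 (the initial SharedState value used in
// the test); after Optimize() it returns the value the optimizer settled on.
double parallelism = node->parameter_value("parallelism");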

@@ -387,7 +387,7 @@ TEST(UnknownTest, Model) {
  EXPECT_EQ(unknown->OutputTime(&input_times, nullptr), 100);
}

TEST(SetterGetterTest, Node) {
TEST(BufferedBytesTest, Node) {
  std::shared_ptr<Node> node = model::MakeAsyncInterleaveManyNode(
      {-1, "TestNode", nullptr},
      {model::MakeParameter("parallelism",
@@ -892,6 +892,63 @@ TEST_P(SelfProcessingTimeTest, Model) {
INSTANTIATE_TEST_SUITE_P(Test, SelfProcessingTimeTest,
                         ::testing::Values(0, 1, 2, 5, 10, 20, 40));

class OptimizeZeroRamBudgetTest
    : public ::testing::TestWithParam<model::AutotuneAlgorithm> {};

TEST_P(OptimizeZeroRamBudgetTest, Model) {
  const model::AutotuneAlgorithm algorithm = GetParam();

  std::shared_ptr<mutex> mutex1 = std::make_shared<mutex>();
  std::shared_ptr<condition_variable> cv1 =
      std::make_shared<condition_variable>();
  std::shared_ptr<Node> node1 = model::MakeAsyncKnownRatioNode(
      {1, "1", nullptr}, 2,
      {model::MakeParameter("parallelism",
                            std::make_shared<SharedState>(-1, mutex1, cv1), 1,
                            5)});
  node1->record_buffer_event(1, 1);

  std::shared_ptr<mutex> mutex2 = std::make_shared<mutex>();
  std::shared_ptr<condition_variable> cv2 =
      std::make_shared<condition_variable>();
  std::shared_ptr<Node> node2 = model::MakeAsyncKnownRatioNode(
      {2, "2", node1}, 5,
      {model::MakeParameter("buffer_size",
                            std::make_shared<SharedState>(-1, mutex2, cv2), 0,
                            6)});
  node2->record_buffer_event(1, 1);

  std::shared_ptr<mutex> mutex3 = std::make_shared<mutex>();
  std::shared_ptr<condition_variable> cv3 =
      std::make_shared<condition_variable>();
  std::shared_ptr<Node> node3 = model::MakeAsyncInterleaveManyNode(
      {3, "3", node2},
      {model::MakeParameter("parallelism",
                            std::make_shared<SharedState>(-1, mutex3, cv3), 1,
                            7)});
  node3->record_buffer_event(1, 1);

  EXPECT_EQ(node1->parameter_value("parallelism"), -1);
  EXPECT_EQ(node2->parameter_value("buffer_size"), -1);
  EXPECT_EQ(node3->parameter_value("parallelism"), -1);

  model::Model model;
  model.AddNode([&node1](model::Node::Args args) { return node1; }, "1",
                nullptr, &node1);
  model.AddNode([&node2](model::Node::Args args) { return node2; }, "2", node1,
                &node2);
  model.AddNode([&node3](model::Node::Args args) { return node3; }, "3", node2,
                &node3);

  model.Optimize(algorithm, 40, 0, 0);
  EXPECT_EQ(node1->parameter_value("parallelism"), 1);
  EXPECT_EQ(node2->parameter_value("buffer_size"), 0);
  EXPECT_EQ(node3->parameter_value("parallelism"), 1);
}

INSTANTIATE_TEST_SUITE_P(Test, OptimizeZeroRamBudgetTest,
                         ::testing::Values(0, 1));
} // namespace
} // namespace model
} // namespace data
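
A closing note for readers outside the TensorFlow tree: the values 0 and 1 passed to ::testing::Values are converted by googletest to model::AutotuneAlgorithm, so the test runs once per autotuning algorithm available at the time. A more explicit spelling of the same instantiation would look roughly like the following; the enumerator names HILL_CLIMB and GRADIENT_DESCENT are an assumption about model.h, not something this diff shows:

// Hypothetical, more explicit form of the instantiation in the diff above.
// Enumerator names are assumed, not taken from this commit.
INSTANTIATE_TEST_SUITE_P(
    Test, OptimizeZeroRamBudgetTest,
    ::testing::Values(model::AutotuneAlgorithm::HILL_CLIMB,
                      model::AutotuneAlgorithm::GRADIENT_DESCENT));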