[tf.data] Add the optimization test when RAM budget is 0.
PiperOrigin-RevId: 338359725 Change-Id: I5c89aab1f65cf170ed4b64078834c883ad230c96
This commit is contained in:
parent
8366f2ecea
commit
c2239c1a2a
@ -251,6 +251,12 @@ class Node {
|
|||||||
// Returns the node output.
|
// Returns the node output.
|
||||||
Node* output() const { return output_; }
|
Node* output() const { return output_; }
|
||||||
|
|
||||||
|
// Returns the parameter value.
|
||||||
|
double parameter_value(const string& name) const TF_LOCKS_EXCLUDED(mu_) {
|
||||||
|
tf_shared_lock l(mu_);
|
||||||
|
return parameters_.at(name)->state->value;
|
||||||
|
}
|
||||||
|
|
||||||
// Returns the aggregate processing time.
|
// Returns the aggregate processing time.
|
||||||
int64 processing_time() const TF_LOCKS_EXCLUDED(mu_) {
|
int64 processing_time() const TF_LOCKS_EXCLUDED(mu_) {
|
||||||
return processing_time_;
|
return processing_time_;
|
||||||
|
@ -387,7 +387,7 @@ TEST(UnknownTest, Model) {
|
|||||||
EXPECT_EQ(unknown->OutputTime(&input_times, nullptr), 100);
|
EXPECT_EQ(unknown->OutputTime(&input_times, nullptr), 100);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(SetterGetterTest, Node) {
|
TEST(BufferedBytesTest, Node) {
|
||||||
std::shared_ptr<Node> node = model::MakeAsyncInterleaveManyNode(
|
std::shared_ptr<Node> node = model::MakeAsyncInterleaveManyNode(
|
||||||
{-1, "TestNode", nullptr},
|
{-1, "TestNode", nullptr},
|
||||||
{model::MakeParameter("parallelism",
|
{model::MakeParameter("parallelism",
|
||||||
@ -892,6 +892,63 @@ TEST_P(SelfProcessingTimeTest, Model) {
|
|||||||
INSTANTIATE_TEST_SUITE_P(Test, SelfProcessingTimeTest,
|
INSTANTIATE_TEST_SUITE_P(Test, SelfProcessingTimeTest,
|
||||||
::testing::Values(0, 1, 2, 5, 10, 20, 40));
|
::testing::Values(0, 1, 2, 5, 10, 20, 40));
|
||||||
|
|
||||||
|
class OptimizeZeroRamBudgetTest
|
||||||
|
: public ::testing::TestWithParam<model::AutotuneAlgorithm> {};
|
||||||
|
|
||||||
|
TEST_P(OptimizeZeroRamBudgetTest, Model) {
|
||||||
|
const model::AutotuneAlgorithm algorithm = GetParam();
|
||||||
|
|
||||||
|
std::shared_ptr<mutex> mutex1 = std::make_shared<mutex>();
|
||||||
|
std::shared_ptr<condition_variable> cv1 =
|
||||||
|
std::make_shared<condition_variable>();
|
||||||
|
std::shared_ptr<Node> node1 = model::MakeAsyncKnownRatioNode(
|
||||||
|
{1, "1", nullptr}, 2,
|
||||||
|
{model::MakeParameter("parallelism",
|
||||||
|
std::make_shared<SharedState>(-1, mutex1, cv1), 1,
|
||||||
|
5)});
|
||||||
|
node1->record_buffer_event(1, 1);
|
||||||
|
|
||||||
|
std::shared_ptr<mutex> mutex2 = std::make_shared<mutex>();
|
||||||
|
std::shared_ptr<condition_variable> cv2 =
|
||||||
|
std::make_shared<condition_variable>();
|
||||||
|
std::shared_ptr<Node> node2 = model::MakeAsyncKnownRatioNode(
|
||||||
|
{2, "2", node1}, 5,
|
||||||
|
{model::MakeParameter("buffer_size",
|
||||||
|
std::make_shared<SharedState>(-1, mutex2, cv2), 0,
|
||||||
|
6)});
|
||||||
|
node2->record_buffer_event(1, 1);
|
||||||
|
|
||||||
|
std::shared_ptr<mutex> mutex3 = std::make_shared<mutex>();
|
||||||
|
std::shared_ptr<condition_variable> cv3 =
|
||||||
|
std::make_shared<condition_variable>();
|
||||||
|
std::shared_ptr<Node> node3 = model::MakeAsyncInterleaveManyNode(
|
||||||
|
{3, "3", node2},
|
||||||
|
{model::MakeParameter("parallelism",
|
||||||
|
std::make_shared<SharedState>(-1, mutex3, cv3), 1,
|
||||||
|
7)});
|
||||||
|
node3->record_buffer_event(1, 1);
|
||||||
|
|
||||||
|
EXPECT_EQ(node1->parameter_value("parallelism"), -1);
|
||||||
|
EXPECT_EQ(node2->parameter_value("buffer_size"), -1);
|
||||||
|
EXPECT_EQ(node3->parameter_value("parallelism"), -1);
|
||||||
|
|
||||||
|
model::Model model;
|
||||||
|
model.AddNode([&node1](model::Node::Args args) { return node1; }, "1",
|
||||||
|
nullptr, &node1);
|
||||||
|
model.AddNode([&node2](model::Node::Args args) { return node2; }, "2", node1,
|
||||||
|
&node2);
|
||||||
|
model.AddNode([&node3](model::Node::Args args) { return node3; }, "3", node2,
|
||||||
|
&node3);
|
||||||
|
|
||||||
|
model.Optimize(algorithm, 40, 0, 0);
|
||||||
|
EXPECT_EQ(node1->parameter_value("parallelism"), 1);
|
||||||
|
EXPECT_EQ(node2->parameter_value("buffer_size"), 0);
|
||||||
|
EXPECT_EQ(node3->parameter_value("parallelism"), 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
INSTANTIATE_TEST_SUITE_P(Test, OptimizeZeroRamBudgetTest,
|
||||||
|
::testing::Values(0, 1));
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
} // namespace model
|
} // namespace model
|
||||||
} // namespace data
|
} // namespace data
|
||||||
|
Loading…
Reference in New Issue
Block a user