Added a test to validate that the cost graph is properly exported from direct sessions when requested.
Change: 135732642
Benoit Steiner 2016-10-10 15:02:40 -08:00 committed by TensorFlower Gardener
parent 091f625372
commit 7ba74d62d7

@@ -198,5 +198,67 @@ TEST(DirectSessionWithTrackingAllocTest, CostModelWithHardwareStats) {
  TestHWAccelerator(true);
}

TEST(DirectSessionWithTrackingAllocTest, CostGraph) {
  EnableCPUAllocatorFullStats(true);

  Graph graph(OpRegistry::Global());

  Tensor a_tensor(DT_FLOAT, TensorShape({2, 2}));
  test::FillValues<float>(&a_tensor, {3, 2, -1, 0});
  Node* a = test::graph::Constant(&graph, a_tensor);
  a->set_assigned_device_name("/job:localhost/replica:0/task:0/cpu:0");

  Tensor x_tensor(DT_FLOAT, TensorShape({2, 1}));
  test::FillValues<float>(&x_tensor, {1, 1});
  Node* x = test::graph::Constant(&graph, x_tensor);
  x->set_assigned_device_name("/job:localhost/replica:0/task:0/cpu:1");

  // y = A * x
  Node* y = test::graph::Matmul(&graph, a, x, false, false);
  y->set_assigned_device_name("/job:localhost/replica:0/task:0/cpu:0");

  Node* y_neg = test::graph::Unary(&graph, "Neg", y);
  y_neg->set_assigned_device_name("/job:localhost/replica:0/task:0/cpu:1");

  GraphDef def;
  test::graph::ToGraphDef(&graph, &def);

  SessionOptions options;
  (*options.config.mutable_device_count())["CPU"] = 2;
  options.config.mutable_graph_options()->set_build_cost_model(true);
  options.config.mutable_graph_options()
      ->mutable_optimizer_options()
      ->set_opt_level(OptimizerOptions::L0);
  std::unique_ptr<Session> session(NewSession(options));
  TF_ASSERT_OK(session->Create(def));
  std::vector<std::pair<string, Tensor>> inputs;

  // Request two targets: one fetch output and one non-fetched output.
  RunOptions run_options;
  std::vector<string> output_names = {y->name() + ":0"};
  std::vector<string> target_nodes = {y_neg->name()};
  std::vector<Tensor> outputs;
  RunMetadata run_metadata;
  const int64 start_micros = Env::Default()->NowMicros();
  Status s = session->Run(run_options, inputs, output_names, target_nodes,
                          &outputs, &run_metadata);
  const int64 run_duration_micros = Env::Default()->NowMicros() - start_micros;
  TF_ASSERT_OK(s);

  EXPECT_LE(2, run_metadata.cost_graph().node_size());
  for (const auto& node : run_metadata.cost_graph().node()) {
    if (node.name() == y->name() || node.name() == y_neg->name()) {
      EXPECT_EQ(1, node.output_info_size());
      EXPECT_LE(8, node.output_info(0).size());
      const TensorShapeProto& shape = node.output_info(0).shape();
      EXPECT_EQ(2, shape.dim_size());
      EXPECT_EQ(2, shape.dim(0).size());
      EXPECT_EQ(1, shape.dim(1).size());
    }
    EXPECT_LE(0, node.compute_cost());
    EXPECT_GE(run_duration_micros, node.compute_cost());
  }
}
} // namespace
} // namespace tensorflow
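
For context (not part of the diff above), here is a minimal sketch of how a caller could request the exported cost graph from a direct session, mirroring the options exercised by the new test. The helper name RunAndDumpCostGraph and the fetch name "y:0" are illustrative placeholders, and the sketch assumes a pre-built GraphDef is available.

// Minimal sketch, not part of the commit: request the cost graph from a
// direct session. Assumes `def` is a valid GraphDef and "y:0" names a
// fetchable tensor in it; both are placeholders for illustration.
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/protobuf/config.pb.h"
#include "tensorflow/core/public/session.h"

namespace tensorflow {

Status RunAndDumpCostGraph(const GraphDef& def) {
  SessionOptions options;
  // Ask the session to build a cost model so that Run() can export it in
  // RunMetadata, which is the behavior the new test verifies.
  options.config.mutable_graph_options()->set_build_cost_model(true);

  std::unique_ptr<Session> session(NewSession(options));
  TF_RETURN_IF_ERROR(session->Create(def));

  RunOptions run_options;
  RunMetadata run_metadata;
  std::vector<Tensor> outputs;
  TF_RETURN_IF_ERROR(session->Run(run_options, /*inputs=*/{}, {"y:0"},
                                  /*target_node_names=*/{}, &outputs,
                                  &run_metadata));

  // Each cost graph node reports its output sizes/shapes and compute cost.
  for (const auto& node : run_metadata.cost_graph().node()) {
    LOG(INFO) << node.name() << " compute_cost=" << node.compute_cost();
  }
  return Status::OK();
}

}  // namespace tensorflow

With build_cost_model enabled as above, the session populates run_metadata.cost_graph() during Run(); that field is what the test inspects for per-node output shapes and compute costs.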