Added a test to validate that the cost graph is properly exported from direct
sessions when requested. Change: 135732642
This commit is contained in:
parent
091f625372
commit
7ba74d62d7
@ -198,5 +198,67 @@ TEST(DirectSessionWithTrackingAllocTest, CostModelWithHardwareStats) {
|
||||
TestHWAccelerator(true);
|
||||
}
|
||||
|
||||
// Verifies that a direct session populates RunMetadata::cost_graph() when the
// session is configured with build_cost_model enabled.
//
// The graph computes y = A * x followed by y_neg = -y, with the nodes spread
// across two CPU devices. After a single Run() the test checks, for the `y`
// and `y_neg` entries of the exported cost graph, that exactly one output is
// recorded with the expected 2x1 shape and a plausible byte size, and, for
// every exported node, that the compute cost is non-negative and no larger
// than the wall-clock duration of the Run() call.
TEST(DirectSessionWithTrackingAllocTest, CostGraph) {
  EnableCPUAllocatorFullStats(true);

  Graph graph(OpRegistry::Global());

  // A is a 2x2 float constant assigned to cpu:0.
  Tensor a_tensor(DT_FLOAT, TensorShape({2, 2}));
  test::FillValues<float>(&a_tensor, {3, 2, -1, 0});
  Node* a = test::graph::Constant(&graph, a_tensor);
  a->set_assigned_device_name("/job:localhost/replica:0/task:0/cpu:0");

  // x is a 2x1 float constant assigned to cpu:1.
  Tensor x_tensor(DT_FLOAT, TensorShape({2, 1}));
  test::FillValues<float>(&x_tensor, {1, 1});
  Node* x = test::graph::Constant(&graph, x_tensor);
  x->set_assigned_device_name("/job:localhost/replica:0/task:0/cpu:1");

  // y = A * x
  Node* y = test::graph::Matmul(&graph, a, x, false, false);
  y->set_assigned_device_name("/job:localhost/replica:0/task:0/cpu:0");

  // y_neg = -y, placed on the other device.
  Node* y_neg = test::graph::Unary(&graph, "Neg", y);
  y_neg->set_assigned_device_name("/job:localhost/replica:0/task:0/cpu:1");

  GraphDef def;
  test::graph::ToGraphDef(&graph, &def);

  // Two CPU devices, cost model enabled, optimizer level L0.
  SessionOptions options;
  (*options.config.mutable_device_count())["CPU"] = 2;
  options.config.mutable_graph_options()->set_build_cost_model(true);
  options.config.mutable_graph_options()
      ->mutable_optimizer_options()
      ->set_opt_level(OptimizerOptions::L0);
  std::unique_ptr<Session> session(NewSession(options));
  TF_ASSERT_OK(session->Create(def));
  std::vector<std::pair<string, Tensor>> inputs;

  // Request two targets: one fetch output and one non-fetched output.
  RunOptions run_options;
  std::vector<string> output_names = {y->name() + ":0"};
  std::vector<string> target_nodes = {y_neg->name()};
  std::vector<Tensor> outputs;
  RunMetadata run_metadata;

  // Time the Run() call; the measured duration is the upper bound used for
  // each node's compute cost below.
  const int64 start_micros = Env::Default()->NowMicros();
  Status s = session->Run(run_options, inputs, output_names, target_nodes,
                          &outputs, &run_metadata);
  const int64 run_duration_micros = Env::Default()->NowMicros() - start_micros;
  TF_ASSERT_OK(s);

  EXPECT_LE(2, run_metadata.cost_graph().node_size());
  // NOTE(review): nothing here verifies that entries named after `y` and
  // `y_neg` are actually present; if neither name matches, the per-output
  // checks below are skipped silently. Consider counting the matches.
  for (const auto& node : run_metadata.cost_graph().node()) {
    if (node.name() == y->name() || node.name() == y_neg->name()) {
      // Fatal assertion: output_info(0) is indexed right below, so the test
      // must stop here rather than read past the end of the repeated field.
      ASSERT_EQ(1, node.output_info_size());
      // A 2x1 float tensor holds two 4-byte elements.
      EXPECT_LE(8, node.output_info(0).size());
      const TensorShapeProto& shape = node.output_info(0).shape();
      // Fatal for the same reason: dim(0) and dim(1) are indexed next.
      ASSERT_EQ(2, shape.dim_size());
      EXPECT_EQ(2, shape.dim(0).size());
      EXPECT_EQ(1, shape.dim(1).size());
    }
    EXPECT_LE(0, node.compute_cost());
    EXPECT_GE(run_duration_micros, node.compute_cost());
  }
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
||||
|
Loading…
Reference in New Issue
Block a user