Add GraphTopologyView to efficiently traverse node-to-node connections.
+ Remove SimpleGraphView from topological sorting. PiperOrigin-RevId: 226932668
This commit is contained in:
parent
87af907e8c
commit
9585116b80
@ -62,6 +62,36 @@ tf_cuda_library(
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "graph_topology_view",
|
||||
srcs = ["graph_topology_view.cc"],
|
||||
hdrs = ["graph_topology_view.h"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":graph_view",
|
||||
"//tensorflow/core:graph",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"@com_google_absl//absl/container:flat_hash_map",
|
||||
"@com_google_absl//absl/container:inlined_vector",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@com_google_absl//absl/types:optional",
|
||||
"@com_google_absl//absl/types:span",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "graph_topology_view_test",
|
||||
srcs = ["graph_topology_view_test.cc"],
|
||||
deps = [
|
||||
":graph_topology_view",
|
||||
":graph_view",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "graph_view",
|
||||
srcs = ["graph_view.cc"],
|
||||
|
@ -425,9 +425,11 @@ NodeDef MakeConstNodeDefFromShape(InferenceContext* ic,
|
||||
// information is refined.
|
||||
class TopoQueue {
|
||||
public:
|
||||
explicit TopoQueue(const std::unordered_map<const NodeDef*, int>& topo_order)
|
||||
: topo_order_(topo_order) {}
|
||||
explicit TopoQueue(const std::vector<const NodeDef*>& topo_order)
|
||||
: topo_order_(TopoOrder(topo_order)) {}
|
||||
|
||||
void push(const NodeDef* n) { queue_.emplace(n, topo_order_.at(n)); }
|
||||
|
||||
const NodeDef* pop() {
|
||||
CHECK(!empty());
|
||||
auto it = queue_.begin();
|
||||
@ -448,7 +450,18 @@ class TopoQueue {
|
||||
return lhs.second < rhs.second;
|
||||
}
|
||||
};
|
||||
const std::unordered_map<const NodeDef*, int>& topo_order_;
|
||||
|
||||
const std::unordered_map<const NodeDef*, int> TopoOrder(
|
||||
const std::vector<const NodeDef*>& topo_order) const {
|
||||
std::unordered_map<const NodeDef*, int> map;
|
||||
map.reserve(topo_order.size());
|
||||
for (int i = 0; i < topo_order.size(); ++i) {
|
||||
map.emplace(topo_order[i], i);
|
||||
}
|
||||
return map;
|
||||
}
|
||||
|
||||
const std::unordered_map<const NodeDef*, int> topo_order_;
|
||||
std::set<NodeAndId, OrderByIdAscending> queue_;
|
||||
};
|
||||
|
||||
@ -1970,7 +1983,7 @@ Status GraphProperties::InferStatically(bool assume_valid_feeds,
|
||||
}
|
||||
|
||||
std::unordered_map<const NodeDef*, const NodeDef*> resource_handles;
|
||||
std::vector<std::pair<const NodeDef*, const NodeDef*>> extra_deps;
|
||||
std::vector<TopologicalDependency> extra_deps;
|
||||
for (const auto& resource : resources) {
|
||||
for (const NodeDef* src : resource.second.first) {
|
||||
resource_handles[src] = resource.first;
|
||||
@ -1982,8 +1995,8 @@ Status GraphProperties::InferStatically(bool assume_valid_feeds,
|
||||
}
|
||||
}
|
||||
|
||||
std::unordered_map<const NodeDef*, int> topo_order;
|
||||
Status s = ComputeTopologicalOrder(item_.graph, &topo_order, &extra_deps);
|
||||
std::vector<const NodeDef*> topo_order;
|
||||
Status s = ComputeTopologicalOrder(item_.graph, extra_deps, &topo_order);
|
||||
if (!s.ok()) {
|
||||
if (extra_deps.empty()) {
|
||||
return s;
|
||||
@ -1992,8 +2005,7 @@ Status GraphProperties::InferStatically(bool assume_valid_feeds,
|
||||
// order. This will make the shape inference less precise but since this
|
||||
// isn't common it's not worth to figure out where to break the loop and
|
||||
// do a proper relaxation.
|
||||
TF_RETURN_IF_ERROR(
|
||||
ComputeTopologicalOrder(item_.graph, &topo_order, nullptr));
|
||||
TF_RETURN_IF_ERROR(ComputeTopologicalOrder(item_.graph, &topo_order));
|
||||
}
|
||||
}
|
||||
|
||||
|
163
tensorflow/core/grappler/graph_topology_view.cc
Normal file
163
tensorflow/core/grappler/graph_topology_view.cc
Normal file
@ -0,0 +1,163 @@
|
||||
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/grappler/graph_topology_view.h"
|
||||
|
||||
#include <algorithm>
|
||||
|
||||
#include "absl/container/flat_hash_map.h"
|
||||
#include "absl/container/inlined_vector.h"
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "absl/types/optional.h"
|
||||
#include "absl/types/span.h"
|
||||
#include "tensorflow/core/framework/graph.pb.h"
|
||||
#include "tensorflow/core/framework/node_def.pb.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace grappler {
|
||||
|
||||
namespace {
|
||||
|
||||
template <typename T>
|
||||
inline void SortAndRemoveDuplicates(T* v) {
|
||||
std::sort(v->begin(), v->end());
|
||||
v->erase(std::unique(v->begin(), v->end()), v->end());
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
Status GraphTopologyView::InitializeFromGraph(
|
||||
const GraphDef& graph,
|
||||
const absl::Span<const GraphView::Edge> ephemeral_edges) {
|
||||
if (graph_ != nullptr) {
|
||||
return errors::InvalidArgument("GraphTopologyView is already initialized.");
|
||||
}
|
||||
|
||||
graph_ = &graph;
|
||||
num_nodes_ = graph.node_size();
|
||||
index_to_node_name_.resize(num_nodes_);
|
||||
node_name_to_index_.rehash(num_nodes_);
|
||||
fanins_.resize(num_nodes_);
|
||||
fanouts_.resize(num_nodes_);
|
||||
|
||||
// Build map from name to index and vice versa.
|
||||
for (int node_idx = 0; node_idx < num_nodes_; ++node_idx) {
|
||||
const NodeDef& node = graph.node(node_idx);
|
||||
node_name_to_index_.emplace(node.name(), node_idx);
|
||||
index_to_node_name_.emplace_back(node.name());
|
||||
}
|
||||
|
||||
// 1. Add ephemeral edges to the adjacency lists.
|
||||
for (const GraphView::Edge& edge : ephemeral_edges) {
|
||||
const auto src = node_name_to_index_.find(edge.src.node->name());
|
||||
if (src == node_name_to_index_.end()) {
|
||||
return errors::InvalidArgument("Non-existent src node: ",
|
||||
edge.src.node->name());
|
||||
}
|
||||
const auto dst = node_name_to_index_.find(edge.dst.node->name());
|
||||
if (dst == node_name_to_index_.end()) {
|
||||
return errors::InvalidArgument("Non-existent dst node: ",
|
||||
edge.dst.node->name());
|
||||
}
|
||||
const int src_idx = src->second;
|
||||
const int dst_idx = dst->second;
|
||||
fanins_[dst_idx].push_back(src_idx);
|
||||
fanouts_[src_idx].push_back(dst_idx);
|
||||
}
|
||||
|
||||
// 2. Add graph edges to the adjacency lists.
|
||||
for (int node_idx = 0; node_idx < num_nodes_; ++node_idx) {
|
||||
const NodeDef& node = graph.node(node_idx);
|
||||
fanins_[node_idx].reserve(node.input_size());
|
||||
|
||||
for (const string& input : node.input()) {
|
||||
TensorId tensor = ParseTensorName(input);
|
||||
const auto it = node_name_to_index_.find(tensor.node());
|
||||
if (it == node_name_to_index_.end()) {
|
||||
return errors::InvalidArgument("Non-existent input ", input,
|
||||
" for node ", node.name());
|
||||
}
|
||||
const int input_idx = it->second;
|
||||
fanins_[node_idx].push_back(input_idx);
|
||||
fanouts_[input_idx].push_back(node_idx);
|
||||
}
|
||||
|
||||
// Dedup the input list while it's still hot in cache.
|
||||
SortAndRemoveDuplicates(&fanins_[node_idx]);
|
||||
}
|
||||
|
||||
// Dedup outputs for all the graph nodes.
|
||||
for (int node_idx = 0; node_idx < num_nodes_; ++node_idx) {
|
||||
SortAndRemoveDuplicates(&fanouts_[node_idx]);
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status GraphTopologyView::InitializeFromGraph(const GraphDef& graph) {
|
||||
return InitializeFromGraph(graph, absl::Span<GraphView::Edge>());
|
||||
}
|
||||
|
||||
bool GraphTopologyView::HasNode(const absl::string_view node_name) const {
|
||||
DCHECK(is_initialized()) << "GraphTopologyView is not initialized";
|
||||
const auto it = node_name_to_index_.find(node_name);
|
||||
return it != node_name_to_index_.end();
|
||||
}
|
||||
|
||||
const NodeDef* GraphTopologyView::GetNode(
|
||||
const absl::string_view node_name) const {
|
||||
DCHECK(is_initialized()) << "GraphTopologyView is not initialized";
|
||||
const auto it = node_name_to_index_.find(node_name);
|
||||
return it == node_name_to_index_.end() ? nullptr : &graph_->node(it->second);
|
||||
}
|
||||
|
||||
const NodeDef* GraphTopologyView::GetNode(int node_idx) const {
|
||||
DCHECK(is_initialized()) << "GraphTopologyView is not initialized";
|
||||
DCHECK(node_idx >= 0 && node_idx < num_nodes_) << "node_idx is out of range";
|
||||
return &graph_->node(node_idx);
|
||||
}
|
||||
|
||||
const absl::optional<int> GraphTopologyView::GetNodeIndex(
|
||||
const absl::string_view node_name) const {
|
||||
DCHECK(is_initialized()) << "GraphTopologyView is not initialized";
|
||||
const auto it = node_name_to_index_.find(node_name);
|
||||
DCHECK(it != node_name_to_index_.end()) << "Node doesn't exist in a graph";
|
||||
return it == node_name_to_index_.end() ? absl::nullopt
|
||||
: absl::make_optional(it->second);
|
||||
}
|
||||
|
||||
const absl::optional<int> GraphTopologyView::GetNodeIndex(
|
||||
const NodeDef& node) const {
|
||||
return GetNodeIndex(node.name());
|
||||
}
|
||||
|
||||
const absl::InlinedVector<int, 4>& GraphTopologyView::GetFanin(
|
||||
int node_idx) const {
|
||||
DCHECK(is_initialized()) << "GraphTopologyView is not initialized";
|
||||
const bool is_valid_node_idx = node_idx >= 0 && node_idx < num_nodes_;
|
||||
DCHECK(is_valid_node_idx) << "node_idx is out of range";
|
||||
return is_valid_node_idx ? fanins_[node_idx] : empty_fanin_;
|
||||
}
|
||||
|
||||
const absl::InlinedVector<int, 2>& GraphTopologyView::GetFanout(
|
||||
int node_idx) const {
|
||||
DCHECK(is_initialized()) << "GraphTopologyView is not initialized";
|
||||
const bool is_valid_node_idx = node_idx >= 0 && node_idx < num_nodes_;
|
||||
DCHECK(is_valid_node_idx) << "node_idx is out of range";
|
||||
return is_valid_node_idx ? fanouts_[node_idx] : empty_fanout_;
|
||||
}
|
||||
|
||||
} // end namespace grappler
|
||||
} // end namespace tensorflow
|
105
tensorflow/core/grappler/graph_topology_view.h
Normal file
105
tensorflow/core/grappler/graph_topology_view.h
Normal file
@ -0,0 +1,105 @@
|
||||
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_CORE_GRAPPLER_GRAPH_TOPOLOGY_VIEW_H_
|
||||
#define TENSORFLOW_CORE_GRAPPLER_GRAPH_TOPOLOGY_VIEW_H_
|
||||
|
||||
#include "absl/container/flat_hash_map.h"
|
||||
#include "absl/container/inlined_vector.h"
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "absl/types/optional.h"
|
||||
#include "absl/types/span.h"
|
||||
#include "tensorflow/core/graph/tensor_id.h"
|
||||
#include "tensorflow/core/grappler/graph_view.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace grappler {
|
||||
|
||||
// GraphTopologyView is a helper class to simplify `node-to-node` connectivity
|
||||
// traversals. Regular `GraphView` simplifies `tensor-to-tensor` traversals:
|
||||
// connections between output tensors and inputs of a consumer nodes. For the
|
||||
// topology view we are focused on nodes connected to nodes, and it's irrelevant
|
||||
// if this connection is formed by one or multiple individual tensors.
|
||||
//
|
||||
// Example:
|
||||
// a = Placeholder(..)
|
||||
// b = Placeholder(..)
|
||||
// c = AddN([a, a, b])
|
||||
//
|
||||
// GraphView edges: [a:0 -> c:0, a:0 -> c:1, b:0 -> c:3]
|
||||
// GraphTopologyView edges: [a -> c, b -> c]
|
||||
//
|
||||
// GraphView is used for exploring single node fanins and fanouts, and
|
||||
// GraphTopologyView is focused on efficient full graph traversals (computing
|
||||
// graph node properties from transitive fanouts, etc...).
|
||||
class GraphTopologyView {
|
||||
public:
|
||||
GraphTopologyView() = default;
|
||||
|
||||
// Initialize graph topology view from the graph. It's possible to pass
|
||||
// additional edges that do not exist in a graph, but must be respected when
|
||||
// computing graph topology. Example: Tensorflow runtime allows concurrent
|
||||
// execution of dequeue/enqueue ops from the same queue resource, but we might
|
||||
// want to enforce ordering between them for the purpose of graph analysis.
|
||||
Status InitializeFromGraph(const GraphDef& graph,
|
||||
absl::Span<const GraphView::Edge> ephemeral_edges);
|
||||
Status InitializeFromGraph(const GraphDef& graph);
|
||||
|
||||
bool is_initialized() const { return graph_ != nullptr; }
|
||||
int num_nodes() const { return num_nodes_; }
|
||||
const GraphDef* graph() const { return graph_; }
|
||||
|
||||
// Returns true iff the node exists in the underlying graph.
|
||||
bool HasNode(absl::string_view node_name) const;
|
||||
|
||||
// Finds a node by name or returns `nullptr` if it's not in the graph.
|
||||
const NodeDef* GetNode(absl::string_view node_name) const;
|
||||
// Returns a node corresponding to the given node index.
|
||||
const NodeDef* GetNode(int node_idx) const;
|
||||
|
||||
// Returns a node index for the given node name, if the name exists in the
|
||||
// underlying graph. Otherwise returns empty optional.
|
||||
const absl::optional<int> GetNodeIndex(absl::string_view node_name) const;
|
||||
// Returns a node index for the given node, if the node belongs to the
|
||||
// underlying graph. Otherwise returns empty optional.
|
||||
const absl::optional<int> GetNodeIndex(const NodeDef& node) const;
|
||||
|
||||
// Returns all the node indexes that are in the direct fanin of the given
|
||||
// node. If the `node_idx` is outside of [0, num_nodes_) returns empty vector.
|
||||
const absl::InlinedVector<int, 4>& GetFanin(int node_idx) const;
|
||||
// Returns all the node indexes that are in the direct fanout of the given
|
||||
// node. If the `node_idx` is outside of [0, num_nodes_) returns empty vector.
|
||||
const absl::InlinedVector<int, 2>& GetFanout(int node_idx) const;
|
||||
|
||||
private:
|
||||
// WARN: `graph_` must outlive this object and graph nodes must not be
|
||||
// destructed, because node names captured with absl::string_view.
|
||||
const GraphDef* graph_ = nullptr; // do not own
|
||||
int num_nodes_ = 0;
|
||||
std::vector<absl::string_view> index_to_node_name_;
|
||||
absl::flat_hash_map<absl::string_view, int> node_name_to_index_;
|
||||
std::vector<absl::InlinedVector<int, 4>> fanins_; // node_idx->input nodes
|
||||
std::vector<absl::InlinedVector<int, 2>> fanouts_; // node_idx->output nodes
|
||||
|
||||
// We need a valid reference to return from GetFanin/GetFanout if the
|
||||
// `node_idx` argument is outside of the [0, num_nodes_) range.
|
||||
absl::InlinedVector<int, 4> empty_fanin_;
|
||||
absl::InlinedVector<int, 2> empty_fanout_;
|
||||
};
|
||||
|
||||
} // end namespace grappler
|
||||
} // end namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_CORE_GRAPPLER_GRAPH_TOPOLOGY_VIEW_H_
|
117
tensorflow/core/grappler/graph_topology_view_test.cc
Normal file
117
tensorflow/core/grappler/graph_topology_view_test.cc
Normal file
@ -0,0 +1,117 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/grappler/graph_topology_view.h"
|
||||
|
||||
#include "tensorflow/core/framework/graph.pb.h"
|
||||
#include "tensorflow/core/framework/node_def.pb.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace grappler {
|
||||
|
||||
class GraphTopologyViewTest : public ::testing::Test {
|
||||
protected:
|
||||
using NodeConfig = std::pair<string, std::vector<string>>;
|
||||
|
||||
static GraphDef CreateGraph(const std::vector<NodeConfig>& nodes) {
|
||||
GraphDef graph;
|
||||
|
||||
for (const NodeConfig& node : nodes) {
|
||||
const auto& node_name = node.first;
|
||||
const auto& node_inputs = node.second;
|
||||
|
||||
NodeDef node_def;
|
||||
node_def.set_name(node_name);
|
||||
for (const string& input : node_inputs) {
|
||||
node_def.add_input(input);
|
||||
}
|
||||
|
||||
*graph.add_node() = std::move(node_def);
|
||||
}
|
||||
|
||||
return graph;
|
||||
}
|
||||
};
|
||||
|
||||
TEST_F(GraphTopologyViewTest, SimpleGraph) {
|
||||
const GraphDef graph = CreateGraph({
|
||||
{"a", {}}, // idx: 0
|
||||
{"b", {}}, // idx: 1
|
||||
{"c", {"a", "b"}}, // idx: 2
|
||||
{"d", {"a", "c"}}, // idx: 3
|
||||
});
|
||||
|
||||
GraphTopologyView graph_view;
|
||||
TF_CHECK_OK(graph_view.InitializeFromGraph(graph));
|
||||
|
||||
EXPECT_TRUE(graph_view.is_initialized());
|
||||
|
||||
const NodeDef* a_by_name = graph_view.GetNode("a");
|
||||
const NodeDef* a_by_idx = graph_view.GetNode(0);
|
||||
ASSERT_TRUE(a_by_name);
|
||||
ASSERT_TRUE(a_by_idx);
|
||||
EXPECT_EQ(a_by_name, a_by_idx);
|
||||
|
||||
const NodeDef* b_by_name = graph_view.GetNode("b");
|
||||
const NodeDef* b_by_idx = graph_view.GetNode(1);
|
||||
ASSERT_TRUE(b_by_name);
|
||||
ASSERT_TRUE(b_by_idx);
|
||||
EXPECT_EQ(b_by_name, b_by_idx);
|
||||
|
||||
const absl::optional<int> b_idx = graph_view.GetNodeIndex(*b_by_name);
|
||||
ASSERT_TRUE(b_idx.has_value());
|
||||
EXPECT_EQ(b_idx.value(), 1);
|
||||
|
||||
const absl::optional<int> c_idx = graph_view.GetNodeIndex("c");
|
||||
ASSERT_TRUE(c_idx.has_value());
|
||||
EXPECT_EQ(c_idx.value(), 2);
|
||||
|
||||
using Fanin = absl::InlinedVector<int, 4>;
|
||||
EXPECT_EQ(graph_view.GetFanin(0), Fanin());
|
||||
EXPECT_EQ(graph_view.GetFanin(1), Fanin());
|
||||
EXPECT_EQ(graph_view.GetFanin(2), Fanin({0, 1}));
|
||||
EXPECT_EQ(graph_view.GetFanin(3), Fanin({0, 2}));
|
||||
|
||||
using Fanout = absl::InlinedVector<int, 2>;
|
||||
EXPECT_EQ(graph_view.GetFanout(0), Fanout({2, 3}));
|
||||
EXPECT_EQ(graph_view.GetFanout(1), Fanout({2}));
|
||||
EXPECT_EQ(graph_view.GetFanout(2), Fanout({3}));
|
||||
EXPECT_EQ(graph_view.GetFanout(3), Fanout());
|
||||
}
|
||||
|
||||
TEST_F(GraphTopologyViewTest, GraphWithALoop) {
|
||||
const GraphDef graph = CreateGraph({
|
||||
{"a", {}}, // idx: 0
|
||||
{"b", {}}, // idx: 1
|
||||
{"c", {"a", "b", "d"}}, // idx: 2 <<<--- 'c' and 'd' have a loop
|
||||
{"d", {"a", "c"}}, // idx: 3
|
||||
});
|
||||
|
||||
GraphTopologyView graph_view;
|
||||
TF_CHECK_OK(graph_view.InitializeFromGraph(graph));
|
||||
EXPECT_TRUE(graph_view.is_initialized());
|
||||
|
||||
using Fanin = absl::InlinedVector<int, 4>;
|
||||
EXPECT_EQ(graph_view.GetFanin(2), Fanin({0, 1, 3}));
|
||||
EXPECT_EQ(graph_view.GetFanin(3), Fanin({0, 2}));
|
||||
|
||||
using Fanout = absl::InlinedVector<int, 2>;
|
||||
EXPECT_EQ(graph_view.GetFanout(2), Fanout({3}));
|
||||
EXPECT_EQ(graph_view.GetFanout(3), Fanout({2}));
|
||||
}
|
||||
|
||||
} // namespace grappler
|
||||
} // namespace tensorflow
|
@ -48,8 +48,11 @@ cc_library(
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:lib_internal",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core/grappler:graph_topology_view",
|
||||
"//tensorflow/core/grappler:graph_view",
|
||||
"//tensorflow/core/grappler:op_types",
|
||||
"//tensorflow/core/grappler:utils",
|
||||
"@com_google_absl//absl/types:span",
|
||||
],
|
||||
)
|
||||
|
||||
@ -58,10 +61,11 @@ tf_cc_test(
|
||||
srcs = ["topological_sort_test.cc"],
|
||||
deps = [
|
||||
":topological_sort",
|
||||
"//tensorflow/core:lib_proto_parsing",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
"@com_google_absl//absl/strings",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -14,10 +14,15 @@ limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/grappler/utils/topological_sort.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <deque>
|
||||
#include <unordered_map>
|
||||
|
||||
#include "absl/types/span.h"
|
||||
#include "tensorflow/core/framework/node_def.pb.h"
|
||||
#include "tensorflow/core/grappler/graph_topology_view.h"
|
||||
#include "tensorflow/core/grappler/graph_view.h"
|
||||
#include "tensorflow/core/grappler/op_types.h"
|
||||
#include "tensorflow/core/grappler/utils.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
@ -25,27 +30,46 @@ limitations under the License.
|
||||
namespace tensorflow {
|
||||
namespace grappler {
|
||||
|
||||
namespace {
|
||||
|
||||
std::vector<GraphView::Edge> MakeEphemeralEdges(
|
||||
const absl::Span<const TopologicalDependency> extra_dependencies) {
|
||||
std::vector<GraphView::Edge> ephemeral_edges;
|
||||
ephemeral_edges.reserve(extra_dependencies.size());
|
||||
for (const auto& dep : extra_dependencies) {
|
||||
ephemeral_edges.emplace_back(
|
||||
GraphView::OutputPort(dep.from, Graph::kControlSlot),
|
||||
GraphView::InputPort(dep.to, Graph::kControlSlot));
|
||||
}
|
||||
return ephemeral_edges;
|
||||
}
|
||||
|
||||
// Kahn's algorithm is implemented.
|
||||
// For details, see https://en.wikipedia.org/wiki/Topological_sorting
|
||||
Status ComputeTopologicalOrder(
|
||||
const GraphDef& graph, std::vector<int>* ready_nodes,
|
||||
const std::vector<std::pair<const NodeDef*, const NodeDef*>>*
|
||||
extra_dependencies) {
|
||||
SimpleGraphView graph_view;
|
||||
TF_RETURN_IF_ERROR(graph_view.Initialize(graph, extra_dependencies));
|
||||
const GraphDef& graph,
|
||||
const absl::Span<const TopologicalDependency> extra_dependencies,
|
||||
std::vector<int>* ready_nodes) {
|
||||
GraphTopologyView graph_view;
|
||||
TF_RETURN_IF_ERROR(graph_view.InitializeFromGraph(
|
||||
graph, MakeEphemeralEdges(extra_dependencies)));
|
||||
|
||||
ready_nodes->reserve(graph_view.num_nodes());
|
||||
// Keep track of how many inputs are ready for the given node.
|
||||
std::vector<int> num_ready_inputs(graph.node_size(), 0);
|
||||
|
||||
// We'll push index of ready nodes to this output vector.
|
||||
ready_nodes->reserve(graph.node_size());
|
||||
|
||||
int front = 0;
|
||||
int back = 0;
|
||||
std::vector<int> num_ready_inputs(graph_view.num_nodes(), 0);
|
||||
for (int i = 0; i < graph_view.num_nodes(); i++) {
|
||||
if (graph_view.inputs(i).empty()) {
|
||||
|
||||
for (int i = 0; i < graph.node_size(); i++) {
|
||||
if (graph_view.GetFanin(i).empty()) {
|
||||
ready_nodes->push_back(i);
|
||||
back++;
|
||||
}
|
||||
if (IsMerge(graph.node(i))) {
|
||||
for (int input : graph_view.inputs(i)) {
|
||||
for (int input : graph_view.GetFanin(i)) {
|
||||
if (IsNextIteration(graph.node(input))) {
|
||||
num_ready_inputs[i]++;
|
||||
}
|
||||
@ -55,9 +79,9 @@ Status ComputeTopologicalOrder(
|
||||
|
||||
while (front != back) {
|
||||
int ready_node = (*ready_nodes)[front];
|
||||
for (int fanout : graph_view.outputs(ready_node)) {
|
||||
for (int fanout : graph_view.GetFanout(ready_node)) {
|
||||
++num_ready_inputs[fanout];
|
||||
if (num_ready_inputs[fanout] == graph_view.inputs(fanout).size()) {
|
||||
if (num_ready_inputs[fanout] == graph_view.GetFanin(fanout).size()) {
|
||||
ready_nodes->push_back(fanout);
|
||||
++back;
|
||||
}
|
||||
@ -72,23 +96,32 @@ Status ComputeTopologicalOrder(
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
Status ComputeTopologicalOrder(
|
||||
const GraphDef& graph, std::unordered_map<const NodeDef*, int>* topo_order,
|
||||
const std::vector<std::pair<const NodeDef*, const NodeDef*>>*
|
||||
extra_dependencies) {
|
||||
const GraphDef& graph,
|
||||
const absl::Span<const TopologicalDependency> extra_dependencies,
|
||||
std::vector<const NodeDef*>* topo_order) {
|
||||
std::vector<int> ready_nodes;
|
||||
TF_RETURN_IF_ERROR(
|
||||
ComputeTopologicalOrder(graph, &ready_nodes, extra_dependencies));
|
||||
topo_order->reserve(graph.node_size());
|
||||
for (int i = 0; i < ready_nodes.size(); ++i) {
|
||||
(*topo_order)[&graph.node(ready_nodes[i])] = i;
|
||||
ComputeTopologicalOrder(graph, extra_dependencies, &ready_nodes));
|
||||
|
||||
topo_order->reserve(ready_nodes.size());
|
||||
for (int ready_node_idx : ready_nodes) {
|
||||
topo_order->emplace_back(&graph.node(ready_node_idx));
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status ComputeTopologicalOrder(const GraphDef& graph,
|
||||
std::vector<const NodeDef*>* topo_order) {
|
||||
return ComputeTopologicalOrder(graph, {}, topo_order);
|
||||
}
|
||||
|
||||
Status ReversedTopologicalSort(GraphDef* graph) {
|
||||
std::vector<int> ready_nodes;
|
||||
TF_RETURN_IF_ERROR(ComputeTopologicalOrder(*graph, &ready_nodes, nullptr));
|
||||
TF_RETURN_IF_ERROR(ComputeTopologicalOrder(*graph, {}, &ready_nodes));
|
||||
std::reverse(ready_nodes.begin(), ready_nodes.end());
|
||||
PermuteNodesInPlace(graph, &ready_nodes, /*invert_permutation=*/true);
|
||||
return Status::OK();
|
||||
@ -96,7 +129,7 @@ Status ReversedTopologicalSort(GraphDef* graph) {
|
||||
|
||||
Status TopologicalSort(GraphDef* graph) {
|
||||
std::vector<int> ready_nodes;
|
||||
TF_RETURN_IF_ERROR(ComputeTopologicalOrder(*graph, &ready_nodes, nullptr));
|
||||
TF_RETURN_IF_ERROR(ComputeTopologicalOrder(*graph, {}, &ready_nodes));
|
||||
PermuteNodesInPlace(graph, &ready_nodes, /*invert_permutation=*/true);
|
||||
return Status::OK();
|
||||
}
|
||||
|
@ -16,22 +16,40 @@ limitations under the License.
|
||||
#ifndef TENSORFLOW_CORE_GRAPPLER_UTILS_TOPOLOGICAL_SORT_H_
|
||||
#define TENSORFLOW_CORE_GRAPPLER_UTILS_TOPOLOGICAL_SORT_H_
|
||||
|
||||
#include "absl/types/span.h"
|
||||
#include "tensorflow/core/framework/graph.pb.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace grappler {
|
||||
|
||||
// Compute a topological ordering for the graph nodes.
|
||||
Status ComputeTopologicalOrder(
|
||||
const GraphDef& graph, std::unordered_map<const NodeDef*, int>* topo_order,
|
||||
const std::vector<std::pair<const NodeDef*, const NodeDef*>>*
|
||||
extra_dependencies);
|
||||
// TODO(ezhulenev, b/121379902): We should be consistent with GraphTopologyView
|
||||
// and use `GraphView::Edge` to pass extra dependencies.
|
||||
struct TopologicalDependency {
|
||||
TopologicalDependency(const NodeDef* from, const NodeDef* to)
|
||||
: from(from), to(to) {}
|
||||
const NodeDef* from;
|
||||
const NodeDef* to;
|
||||
};
|
||||
|
||||
// Sort a graph in topological order.
|
||||
// Computes a topological ordering for the graph nodes and outputs nodes in the
|
||||
// topological order to the `topo_order` output argument.
|
||||
//
|
||||
// It's possible to pass additional edges that do not exists in a graph, but
|
||||
// must be respected when computing graph topological order. Example: Tensorflow
|
||||
// runtime allows concurrent execution of dequeue/enqueue ops from the same
|
||||
// queue resource, but we might want to enforce ordering between them.
|
||||
Status ComputeTopologicalOrder(
|
||||
const GraphDef& graph,
|
||||
absl::Span<const TopologicalDependency> extra_dependencies,
|
||||
std::vector<const NodeDef*>* topo_order);
|
||||
Status ComputeTopologicalOrder(const GraphDef& graph,
|
||||
std::vector<const NodeDef*>* topo_order);
|
||||
|
||||
// Sorts a graph in topological order.
|
||||
Status TopologicalSort(GraphDef* graph);
|
||||
|
||||
// Sort a graph in topological order and reverse it.
|
||||
// Sorts a graph in topological order and reverse it.
|
||||
Status ReversedTopologicalSort(GraphDef* graph);
|
||||
|
||||
} // namespace grappler
|
||||
|
@ -14,79 +14,94 @@ limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/grappler/utils/topological_sort.h"
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "tensorflow/core/framework/node_def.pb.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
#include "tensorflow/core/lib/random/philox_random.h"
|
||||
#include "tensorflow/core/lib/random/simple_philox.h"
|
||||
#include "tensorflow/core/platform/protobuf.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
#include "tensorflow/core/platform/test_benchmark.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace grappler {
|
||||
namespace {
|
||||
|
||||
class TopologicalSortTest : public ::testing::Test {
|
||||
protected:
|
||||
static NodeDef CreateNode(const string& name,
|
||||
const std::vector<string>& inputs) {
|
||||
return CreateNode(name, "", inputs);
|
||||
}
|
||||
static NodeDef CreateNode(const string& name, const string& op,
|
||||
const std::vector<string>& inputs) {
|
||||
NodeDef node;
|
||||
node.set_name(name);
|
||||
if (!op.empty()) {
|
||||
node.set_op(op);
|
||||
struct NodeConfig {
|
||||
NodeConfig(string name, std::vector<string> inputs)
|
||||
: name(std::move(name)), inputs(std::move(inputs)) {}
|
||||
NodeConfig(string name, string op, std::vector<string> inputs)
|
||||
: name(std::move(name)), op(std::move(op)), inputs(std::move(inputs)) {}
|
||||
|
||||
string name;
|
||||
string op;
|
||||
std::vector<string> inputs;
|
||||
};
|
||||
|
||||
static GraphDef CreateGraph(const std::vector<NodeConfig>& nodes) {
|
||||
GraphDef graph;
|
||||
|
||||
for (const NodeConfig& node : nodes) {
|
||||
NodeDef node_def;
|
||||
node_def.set_name(node.name);
|
||||
node_def.set_op(node.op);
|
||||
for (const string& input : node.inputs) {
|
||||
node_def.add_input(input);
|
||||
}
|
||||
*graph.add_node() = std::move(node_def);
|
||||
}
|
||||
for (const string& input : inputs) {
|
||||
node.add_input(input);
|
||||
}
|
||||
return node;
|
||||
|
||||
return graph;
|
||||
}
|
||||
};
|
||||
|
||||
TEST_F(TopologicalSortTest, NoLoop) {
|
||||
GraphDef graph;
|
||||
*graph.add_node() = CreateNode("2", {"5"});
|
||||
*graph.add_node() = CreateNode("0", {"5", "4"});
|
||||
*graph.add_node() = CreateNode("1", {"4", "3"});
|
||||
*graph.add_node() = CreateNode("3", {"2"});
|
||||
*graph.add_node() = CreateNode("5", {});
|
||||
*graph.add_node() = CreateNode("4", {});
|
||||
GraphDef graph = CreateGraph({
|
||||
{"2", {"5"}}, //
|
||||
{"0", {"5", "4"}}, //
|
||||
{"1", {"4", "3"}}, //
|
||||
{"3", {"2"}}, //
|
||||
{"5", {}}, //
|
||||
{"4", {}} //
|
||||
});
|
||||
|
||||
std::unordered_map<const NodeDef*, int> topo_order;
|
||||
TF_EXPECT_OK(ComputeTopologicalOrder(graph, &topo_order, nullptr));
|
||||
std::vector<const NodeDef*> topo_order;
|
||||
TF_EXPECT_OK(ComputeTopologicalOrder(graph, &topo_order));
|
||||
|
||||
const std::vector<string> order = {"5", "4", "2", "0", "3", "1"};
|
||||
for (const auto& topo : topo_order) {
|
||||
const string& node_name = topo.first->name();
|
||||
const int topo_order = topo.second;
|
||||
std::cout << "Node " << node_name << " at order " << topo_order
|
||||
<< std::endl;
|
||||
EXPECT_EQ(node_name, order[topo_order]);
|
||||
|
||||
ASSERT_EQ(topo_order.size(), order.size());
|
||||
for (int i = 0; i < topo_order.size(); ++i) {
|
||||
const NodeDef* node = topo_order[i];
|
||||
EXPECT_EQ(node->name(), order[i]);
|
||||
}
|
||||
|
||||
TF_EXPECT_OK(TopologicalSort(&graph));
|
||||
for (int i = 0; i < order.size(); i++) {
|
||||
for (int i = 0; i < topo_order.size(); i++) {
|
||||
EXPECT_EQ(graph.node(i).name(), order[i]);
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(TopologicalSortTest, WithLoop) {
|
||||
GraphDef graph;
|
||||
// Create a loop
|
||||
*graph.add_node() = CreateNode("2", "Merge", {"1", "5"});
|
||||
*graph.add_node() = CreateNode("3", "Switch", {"2"});
|
||||
*graph.add_node() = CreateNode("4", "Identity", {"3"});
|
||||
*graph.add_node() = CreateNode("5", "NextIteration", {"4"});
|
||||
*graph.add_node() = CreateNode("1", {});
|
||||
GraphDef graph = CreateGraph({
|
||||
// Graph with a loop.
|
||||
{"2", "Merge", {"1", "5"}}, //
|
||||
{"3", "Switch", {"2"}}, //
|
||||
{"4", "Identity", {"3"}}, //
|
||||
{"5", "NextIteration", {"4"}}, //
|
||||
{"1", {}} //
|
||||
});
|
||||
|
||||
std::unordered_map<const NodeDef*, int> topo_order;
|
||||
TF_EXPECT_OK(ComputeTopologicalOrder(graph, &topo_order, nullptr));
|
||||
std::vector<const NodeDef*> topo_order;
|
||||
TF_EXPECT_OK(ComputeTopologicalOrder(graph, &topo_order));
|
||||
|
||||
const std::vector<string> order = {"1", "2", "3", "4", "5"};
|
||||
for (const auto& topo : topo_order) {
|
||||
const string& node_name = topo.first->name();
|
||||
const int topo_order = topo.second;
|
||||
EXPECT_EQ(node_name, order[topo_order]);
|
||||
|
||||
ASSERT_EQ(topo_order.size(), order.size());
|
||||
for (int i = 0; i < topo_order.size(); ++i) {
|
||||
const NodeDef* node = topo_order[i];
|
||||
EXPECT_EQ(node->name(), order[i]);
|
||||
}
|
||||
|
||||
TF_EXPECT_OK(TopologicalSort(&graph));
|
||||
@ -96,12 +111,13 @@ TEST_F(TopologicalSortTest, WithLoop) {
|
||||
}
|
||||
|
||||
TEST_F(TopologicalSortTest, WithIllegalLoop) {
|
||||
GraphDef graph;
|
||||
// A loop without Merge and NextIteration is illegal and the original node
|
||||
// order and graph will be preserved.
|
||||
*graph.add_node() = CreateNode("2", {"1", "3"});
|
||||
*graph.add_node() = CreateNode("3", {"2"});
|
||||
*graph.add_node() = CreateNode("1", {});
|
||||
GraphDef graph = CreateGraph({
|
||||
{"2", {"1", "3"}}, //
|
||||
{"3", {"2"}}, //
|
||||
{"1", {}} //
|
||||
});
|
||||
|
||||
EXPECT_FALSE(TopologicalSort(&graph).ok());
|
||||
std::vector<string> order = {"2", "3", "1"};
|
||||
@ -111,9 +127,10 @@ TEST_F(TopologicalSortTest, WithIllegalLoop) {
|
||||
}
|
||||
|
||||
TEST_F(TopologicalSortTest, DuplicatedInputs) {
|
||||
GraphDef graph;
|
||||
*graph.add_node() = CreateNode("2", {"1", "1"});
|
||||
*graph.add_node() = CreateNode("1", {});
|
||||
GraphDef graph = CreateGraph({
|
||||
{"2", {"1", "1"}}, //
|
||||
{"1", {}} //
|
||||
});
|
||||
|
||||
TF_EXPECT_OK(TopologicalSort(&graph));
|
||||
std::vector<string> order = {"1", "2"};
|
||||
@ -123,12 +140,13 @@ TEST_F(TopologicalSortTest, DuplicatedInputs) {
|
||||
}
|
||||
|
||||
TEST_F(TopologicalSortTest, Idempotent) {
|
||||
GraphDef graph;
|
||||
*graph.add_node() = CreateNode("1", {});
|
||||
*graph.add_node() = CreateNode("2", {});
|
||||
*graph.add_node() = CreateNode("3", {"1", "2"});
|
||||
*graph.add_node() = CreateNode("4", {"1", "3"});
|
||||
*graph.add_node() = CreateNode("5", {"2", "3"});
|
||||
GraphDef graph = CreateGraph({
|
||||
{"1", {}}, //
|
||||
{"2", {}}, //
|
||||
{"3", {"1", "2"}}, //
|
||||
{"4", {"1", "3"}}, //
|
||||
{"5", {"2", "3"}} //
|
||||
});
|
||||
|
||||
TF_EXPECT_OK(TopologicalSort(&graph));
|
||||
std::vector<string> order = {"1", "2", "3", "4", "5"};
|
||||
@ -136,7 +154,7 @@ TEST_F(TopologicalSortTest, Idempotent) {
|
||||
EXPECT_EQ(graph.node(i).name(), order[i]);
|
||||
}
|
||||
|
||||
// Run topo sort again to verify that it is idenpotent.
|
||||
// Run topo sort again to verify that it is idempotent.
|
||||
TF_EXPECT_OK(TopologicalSort(&graph));
|
||||
for (int i = 0; i < order.size(); i++) {
|
||||
EXPECT_EQ(graph.node(i).name(), order[i]);
|
||||
@ -144,35 +162,81 @@ TEST_F(TopologicalSortTest, Idempotent) {
|
||||
}
|
||||
|
||||
TEST_F(TopologicalSortTest, ExtraDependencies) {
|
||||
GraphDef graph;
|
||||
*graph.add_node() = CreateNode("2", {"5"});
|
||||
*graph.add_node() = CreateNode("0", {"5", "4"});
|
||||
*graph.add_node() = CreateNode("1", {"4", "3"});
|
||||
*graph.add_node() = CreateNode("3", {"2"});
|
||||
*graph.add_node() = CreateNode("5", {});
|
||||
*graph.add_node() = CreateNode("4", {});
|
||||
GraphDef graph = CreateGraph({
|
||||
{"2", {"5"}}, //
|
||||
{"0", {"5", "4"}}, //
|
||||
{"1", {"4", "3"}}, //
|
||||
{"3", {"2"}}, //
|
||||
{"5", {}}, //
|
||||
{"4", {}} //
|
||||
});
|
||||
|
||||
// Add an edge from 4 to 5.
|
||||
std::vector<std::pair<const NodeDef*, const NodeDef*>> extra_dependencies;
|
||||
extra_dependencies.emplace_back(&graph.node(5), &graph.node(4));
|
||||
std::vector<TopologicalDependency> extra_dependencies;
|
||||
extra_dependencies.push_back({&graph.node(5), &graph.node(4)});
|
||||
|
||||
std::unordered_map<const NodeDef*, int> topo_order;
|
||||
TF_EXPECT_OK(
|
||||
ComputeTopologicalOrder(graph, &topo_order, &extra_dependencies));
|
||||
std::vector<const NodeDef*> topo_order;
|
||||
TF_EXPECT_OK(ComputeTopologicalOrder(graph, extra_dependencies, &topo_order));
|
||||
|
||||
const std::vector<string> order = {"4", "5", "2", "0", "3", "1"};
|
||||
for (const auto& topo : topo_order) {
|
||||
const string& node_name = topo.first->name();
|
||||
const int topo_order = topo.second;
|
||||
EXPECT_EQ(node_name, order[topo_order]);
|
||||
const std::vector<string> valid_order_1 = {"4", "5", "2", "0", "3", "1"};
|
||||
const std::vector<string> valid_order_2 = {"4", "5", "0", "2", "3", "1"};
|
||||
|
||||
ASSERT_EQ(topo_order.size(), valid_order_1.size());
|
||||
|
||||
std::vector<string> computed_order(6, "");
|
||||
for (int i = 0; i < topo_order.size(); ++i) {
|
||||
const NodeDef* node = topo_order[i];
|
||||
computed_order[i] = node->name();
|
||||
}
|
||||
EXPECT_TRUE(computed_order == valid_order_1 ||
|
||||
computed_order == valid_order_2);
|
||||
|
||||
// Add an edge from 0 to 4. This will create a loop
|
||||
extra_dependencies.emplace_back(&graph.node(1), &graph.node(5));
|
||||
// Add an edge from `0` to `4`. This will create a loop.
|
||||
extra_dependencies.push_back({&graph.node(1), &graph.node(5)});
|
||||
EXPECT_FALSE(
|
||||
ComputeTopologicalOrder(graph, &topo_order, &extra_dependencies).ok());
|
||||
ComputeTopologicalOrder(graph, extra_dependencies, &topo_order).ok());
|
||||
}
|
||||
|
||||
} // namespace
|
||||
static void BM_ComputeTopologicalOrder(int iters, int size) {
|
||||
testing::StopTiming();
|
||||
|
||||
random::PhiloxRandom philox(0x12345);
|
||||
random::SimplePhilox rnd(&philox);
|
||||
|
||||
string prefix = "long_node_name_prefix_to_measure_string_copy_overhead";
|
||||
|
||||
GraphDef graph;
|
||||
for (int i = 0; i < size; ++i) {
|
||||
const string name = absl::StrCat(prefix, i);
|
||||
const uint32 num_inputs = rnd.Uniform(std::min(i, 5));
|
||||
|
||||
NodeDef node;
|
||||
node.set_name(name);
|
||||
for (int n = 0; n < num_inputs; ++n) {
|
||||
const uint32 input_node = rnd.Uniform(i);
|
||||
node.add_input(absl::StrCat(prefix, input_node));
|
||||
}
|
||||
|
||||
*graph.add_node() = std::move(node);
|
||||
}
|
||||
|
||||
testing::StartTiming();
|
||||
std::vector<const NodeDef*> topo_order;
|
||||
for (int i = 0; i < iters; i++) {
|
||||
topo_order.clear();
|
||||
Status st = ComputeTopologicalOrder(graph, &topo_order);
|
||||
CHECK(st.ok()) << "Failed to compute topological order";
|
||||
}
|
||||
testing::StopTiming();
|
||||
}
|
||||
BENCHMARK(BM_ComputeTopologicalOrder)
|
||||
->Arg(10)
|
||||
->Arg(100)
|
||||
->Arg(1000)
|
||||
->Arg(10000)
|
||||
->Arg(25000)
|
||||
->Arg(50000)
|
||||
->Arg(100000);
|
||||
|
||||
} // namespace grappler
|
||||
} // namespace tensorflow
|
||||
|
Loading…
Reference in New Issue
Block a user