Fix graph partitioning algorithm to use real node indices from execution plan.
PiperOrigin-RevId: 255037006
This commit is contained in:
parent
64a9a35d4e
commit
6a88595b5d
@ -137,6 +137,7 @@ class TestGraphInfo : public GraphInfo {
|
||||
const TfLiteNode& node(size_t index) const override {
|
||||
return graph_->nodes()[index];
|
||||
}
|
||||
size_t node_index(size_t index) const override { return index; }
|
||||
const std::vector<int>& inputs() const override { return graph_->inputs(); }
|
||||
const std::vector<int>& outputs() const override { return graph_->outputs(); }
|
||||
const std::vector<int>& variables() const override {
|
||||
|
@ -135,6 +135,9 @@ class InterpreterInfo : public GraphInfo {
|
||||
int node_index = subgraph_->execution_plan()[index];
|
||||
return subgraph_->nodes_and_registration()[node_index].first;
|
||||
}
|
||||
size_t node_index(size_t index) const override {
|
||||
return subgraph_->execution_plan()[index];
|
||||
}
|
||||
const std::vector<int>& inputs() const override {
|
||||
return subgraph_->inputs();
|
||||
}
|
||||
|
@ -159,7 +159,7 @@ class PartitionGraphIntoIndependentNodeSubsetsImpl {
|
||||
// automatically true.
|
||||
if (current_subset.type == node_type_[node_index]) {
|
||||
node_epochs_[node_index] = current_epoch;
|
||||
current_subset.nodes.push_back(node_index);
|
||||
current_subset.nodes.push_back(info_->node_index(node_index));
|
||||
// All outputs of this node now are assigned to this epoch as
|
||||
// well.
|
||||
for (int output_tensor_index : TfLiteIntArrayView(node.outputs)) {
|
||||
|
@ -41,6 +41,10 @@ class GraphInfo {
|
||||
// num_nodes().
|
||||
virtual const TfLiteNode& node(size_t index) const = 0;
|
||||
|
||||
// Returns an implementation-speicfic node index which may be different from
|
||||
// index.
|
||||
virtual size_t node_index(size_t index) const = 0;
|
||||
|
||||
// Returns the indices of the input tensors.
|
||||
virtual const std::vector<int>& inputs() const = 0;
|
||||
|
||||
|
@ -32,6 +32,11 @@ TfLiteIntArray* ConvertVector(const std::vector<int>& x) {
|
||||
// A very simple test graph that supports setting in/out tensors on nodes.
|
||||
class SimpleTestGraph : public GraphInfo {
|
||||
public:
|
||||
explicit SimpleTestGraph(int node_index_offset = 0)
|
||||
: node_index_offset_(node_index_offset) {
|
||||
for (int i = 0; i < node_index_offset; ++i) AddNode({}, {});
|
||||
}
|
||||
|
||||
~SimpleTestGraph() override {
|
||||
for (auto& node : nodes_) {
|
||||
TfLiteIntArrayFree(node.inputs);
|
||||
@ -39,9 +44,16 @@ class SimpleTestGraph : public GraphInfo {
|
||||
}
|
||||
}
|
||||
|
||||
size_t num_nodes() const override {
|
||||
return nodes_.size() - node_index_offset_;
|
||||
}
|
||||
const TfLiteNode& node(size_t index) const override {
|
||||
return nodes_[index + node_index_offset_];
|
||||
}
|
||||
size_t node_index(size_t index) const override {
|
||||
return index + node_index_offset_;
|
||||
}
|
||||
size_t num_tensors() const override { return tensors_.size(); }
|
||||
size_t num_nodes() const override { return nodes_.size(); }
|
||||
const TfLiteNode& node(size_t index) const override { return nodes_[index]; }
|
||||
TfLiteTensor* tensor(size_t index) override { return &tensors_[index]; }
|
||||
const std::vector<int>& inputs() const override { return inputs_; }
|
||||
const std::vector<int>& outputs() const override { return outputs_; }
|
||||
@ -64,6 +76,7 @@ class SimpleTestGraph : public GraphInfo {
|
||||
}
|
||||
|
||||
private:
|
||||
size_t node_index_offset_;
|
||||
std::vector<TfLiteNode> nodes_;
|
||||
std::vector<TfLiteTensor> tensors_;
|
||||
std::vector<int> inputs_;
|
||||
@ -143,6 +156,24 @@ TEST(PartitionTest, Nodes1PartitionNodes0) {
|
||||
CheckPartitionSubgraphs(generated_subgraphs, {expected_subgraph});
|
||||
}
|
||||
|
||||
TEST(PartitionTest, Nodes1PartitionNodes0WithOffset) {
|
||||
constexpr int node_index_offset = 17;
|
||||
SimpleTestGraph graph(node_index_offset);
|
||||
graph.AddTensors(2);
|
||||
graph.AddNode({0}, {1});
|
||||
graph.SetInputsAndOutputs({0}, {1});
|
||||
std::vector<int> nodes_to_partition = {};
|
||||
std::vector<NodeSubset> generated_subgraphs;
|
||||
PartitionGraph(graph, nodes_to_partition, &generated_subgraphs);
|
||||
|
||||
NodeSubset expected_subgraph;
|
||||
expected_subgraph.type = NodeSubset::kTfNonPartition;
|
||||
expected_subgraph.nodes = {node_index_offset};
|
||||
expected_subgraph.input_tensors = {0};
|
||||
expected_subgraph.output_tensors = {1};
|
||||
CheckPartitionSubgraphs(generated_subgraphs, {expected_subgraph});
|
||||
}
|
||||
|
||||
// Test a 1 node graph with no inputs that is fully partitioned.
|
||||
// Input: node(0) -> tensor(1), nodes_to_partition=[node0]
|
||||
// Output: [kTfPartition, node(0) -> tensor(1)]
|
||||
|
Loading…
x
Reference in New Issue
Block a user