Fix graph partitioning algorithm to use real node indices from execution plan.

PiperOrigin-RevId: 255037006
This commit is contained in:
A. Unique TensorFlower 2019-06-25 13:36:21 -07:00 committed by TensorFlower Gardener
parent 64a9a35d4e
commit 6a88595b5d
5 changed files with 42 additions and 3 deletions

View File

@ -137,6 +137,7 @@ class TestGraphInfo : public GraphInfo {
const TfLiteNode& node(size_t index) const override {
return graph_->nodes()[index];
}
size_t node_index(size_t index) const override { return index; }
const std::vector<int>& inputs() const override { return graph_->inputs(); }
const std::vector<int>& outputs() const override { return graph_->outputs(); }
const std::vector<int>& variables() const override {

View File

@ -135,6 +135,9 @@ class InterpreterInfo : public GraphInfo {
int node_index = subgraph_->execution_plan()[index];
return subgraph_->nodes_and_registration()[node_index].first;
}
size_t node_index(size_t index) const override {
return subgraph_->execution_plan()[index];
}
const std::vector<int>& inputs() const override {
return subgraph_->inputs();
}

View File

@ -159,7 +159,7 @@ class PartitionGraphIntoIndependentNodeSubsetsImpl {
// automatically true.
if (current_subset.type == node_type_[node_index]) {
node_epochs_[node_index] = current_epoch;
current_subset.nodes.push_back(node_index);
current_subset.nodes.push_back(info_->node_index(node_index));
// All outputs of this node now are assigned to this epoch as
// well.
for (int output_tensor_index : TfLiteIntArrayView(node.outputs)) {

View File

@ -41,6 +41,10 @@ class GraphInfo {
// num_nodes().
virtual const TfLiteNode& node(size_t index) const = 0;
// Returns an implementation-speicfic node index which may be different from
// index.
virtual size_t node_index(size_t index) const = 0;
// Returns the indices of the input tensors.
virtual const std::vector<int>& inputs() const = 0;

View File

@ -32,6 +32,11 @@ TfLiteIntArray* ConvertVector(const std::vector<int>& x) {
// A very simple test graph that supports setting in/out tensors on nodes.
class SimpleTestGraph : public GraphInfo {
public:
explicit SimpleTestGraph(int node_index_offset = 0)
: node_index_offset_(node_index_offset) {
for (int i = 0; i < node_index_offset; ++i) AddNode({}, {});
}
~SimpleTestGraph() override {
for (auto& node : nodes_) {
TfLiteIntArrayFree(node.inputs);
@ -39,9 +44,16 @@ class SimpleTestGraph : public GraphInfo {
}
}
size_t num_nodes() const override {
return nodes_.size() - node_index_offset_;
}
const TfLiteNode& node(size_t index) const override {
return nodes_[index + node_index_offset_];
}
size_t node_index(size_t index) const override {
return index + node_index_offset_;
}
size_t num_tensors() const override { return tensors_.size(); }
size_t num_nodes() const override { return nodes_.size(); }
const TfLiteNode& node(size_t index) const override { return nodes_[index]; }
TfLiteTensor* tensor(size_t index) override { return &tensors_[index]; }
const std::vector<int>& inputs() const override { return inputs_; }
const std::vector<int>& outputs() const override { return outputs_; }
@ -64,6 +76,7 @@ class SimpleTestGraph : public GraphInfo {
}
private:
size_t node_index_offset_;
std::vector<TfLiteNode> nodes_;
std::vector<TfLiteTensor> tensors_;
std::vector<int> inputs_;
@ -143,6 +156,24 @@ TEST(PartitionTest, Nodes1PartitionNodes0) {
CheckPartitionSubgraphs(generated_subgraphs, {expected_subgraph});
}
TEST(PartitionTest, Nodes1PartitionNodes0WithOffset) {
constexpr int node_index_offset = 17;
SimpleTestGraph graph(node_index_offset);
graph.AddTensors(2);
graph.AddNode({0}, {1});
graph.SetInputsAndOutputs({0}, {1});
std::vector<int> nodes_to_partition = {};
std::vector<NodeSubset> generated_subgraphs;
PartitionGraph(graph, nodes_to_partition, &generated_subgraphs);
NodeSubset expected_subgraph;
expected_subgraph.type = NodeSubset::kTfNonPartition;
expected_subgraph.nodes = {node_index_offset};
expected_subgraph.input_tensors = {0};
expected_subgraph.output_tensors = {1};
CheckPartitionSubgraphs(generated_subgraphs, {expected_subgraph});
}
// Test a 1 node graph with no inputs that is fully partitioned.
// Input: node(0) -> tensor(1), nodes_to_partition=[node0]
// Output: [kTfPartition, node(0) -> tensor(1)]