Don't prune nodes that have reference inputs.
PiperOrigin-RevId: 163390862
This commit is contained in:
parent
2265108340
commit
e5353c941c
@ -134,6 +134,7 @@ cc_library(
|
||||
],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core/grappler:grappler_item",
|
||||
"//tensorflow/core/grappler:utils",
|
||||
|
@ -18,6 +18,8 @@ limitations under the License.
|
||||
#include <unordered_set>
|
||||
#include "tensorflow/core/framework/function.pb.h"
|
||||
#include "tensorflow/core/framework/node_def.pb.h"
|
||||
#include "tensorflow/core/framework/node_def_util.h"
|
||||
#include "tensorflow/core/framework/op.h"
|
||||
#include "tensorflow/core/framework/op_def.pb.h"
|
||||
#include "tensorflow/core/grappler/grappler_item.h"
|
||||
#include "tensorflow/core/grappler/utils.h"
|
||||
@ -26,8 +28,24 @@ namespace tensorflow {
|
||||
namespace grappler {
|
||||
|
||||
GraphRewriter::GraphRewriter(const GrapplerItem& item) {
|
||||
OpRegistryInterface* op_registry = OpRegistry::Global();
|
||||
for (auto& node : item.graph.node()) {
|
||||
nodes_[node.name()] = &node;
|
||||
NodeInfo* info = new NodeInfo();
|
||||
info->def = &node;
|
||||
|
||||
const OpRegistrationData* op_reg_data = nullptr;
|
||||
Status s = op_registry->LookUp(node.op(), &op_reg_data);
|
||||
// TODO(bsteiner): make this not a best-effort lookup and evaluation?
|
||||
if (s.ok()) {
|
||||
s = InOutTypesForNode(node, op_reg_data->op_def, &info->inputs,
|
||||
&info->outputs);
|
||||
if (!s.ok()) {
|
||||
info->inputs.clear();
|
||||
info->outputs.clear();
|
||||
}
|
||||
}
|
||||
|
||||
nodes_[node.name()].reset(info);
|
||||
}
|
||||
|
||||
std::unordered_set<string> function_names;
|
||||
@ -73,11 +91,16 @@ bool GraphRewriter::IsDrivenByAnotherDevice(const NodeDef& node) const {
|
||||
return cross_device_receivers_.find(&node) != cross_device_receivers_.end();
|
||||
}
|
||||
|
||||
bool GraphRewriter::ReceivesRefValue(const NodeDef& node) const {
|
||||
return ref_receivers_.find(&node) != ref_receivers_.end();
|
||||
}
|
||||
|
||||
void GraphRewriter::RecordConnectivity(
|
||||
const NodeDef& node, const std::unordered_set<string>& function_names) {
|
||||
const bool is_function =
|
||||
function_names.find(node.op()) != function_names.end();
|
||||
|
||||
bool ref_receiver = false;
|
||||
for (const auto& input : node.input()) {
|
||||
int position = 0;
|
||||
string input_node_name = ParseNodeName(input, &position);
|
||||
@ -85,7 +108,8 @@ void GraphRewriter::RecordConnectivity(
|
||||
if (itr == nodes_.end()) {
|
||||
continue;
|
||||
}
|
||||
const NodeDef* fanin = itr->second;
|
||||
const NodeInfo* fanin_info = itr->second.get();
|
||||
const NodeDef* fanin = fanin_info->def;
|
||||
if (position < 0) {
|
||||
// This is a control edge
|
||||
control_dependency_drivers_.insert(fanin);
|
||||
@ -97,11 +121,20 @@ void GraphRewriter::RecordConnectivity(
|
||||
if (is_function) {
|
||||
function_neighbors_.insert(fanin);
|
||||
}
|
||||
|
||||
if (position < fanin_info->outputs.size() &&
|
||||
IsRefType(fanin_info->outputs[position])) {
|
||||
ref_receiver = true;
|
||||
}
|
||||
}
|
||||
if (fanin->device() != node.device()) {
|
||||
cross_device_receivers_.insert(&node);
|
||||
}
|
||||
}
|
||||
|
||||
if (ref_receiver) {
|
||||
ref_receivers_.insert(&node);
|
||||
}
|
||||
}
|
||||
|
||||
void GraphRewriter::ForwardInputsInternal(
|
||||
@ -125,7 +158,7 @@ void GraphRewriter::ForwardInputsInternal(
|
||||
*new_node->add_input() = input;
|
||||
continue;
|
||||
}
|
||||
const NodeDef* input_node = itr->second;
|
||||
const NodeDef* input_node = itr->second->def;
|
||||
if (nodes_to_delete.find(input_node) != nodes_to_delete.end()) {
|
||||
ForwardInputsInternal(*input_node, nodes_to_delete, new_node);
|
||||
} else {
|
||||
|
@ -55,6 +55,9 @@ class GraphRewriter {
|
||||
// device.
|
||||
bool IsDrivenByAnotherDevice(const NodeDef& node) const;
|
||||
|
||||
// Returns true if the node has input from a stateful op.
|
||||
bool ReceivesRefValue(const NodeDef& node) const;
|
||||
|
||||
private:
|
||||
void RecordConnectivity(const NodeDef& node,
|
||||
const std::unordered_set<string>& function_names);
|
||||
@ -63,11 +66,21 @@ class GraphRewriter {
|
||||
const std::unordered_set<const NodeDef*>& nodes_to_delete,
|
||||
NodeDef* new_node);
|
||||
|
||||
std::unordered_map<string, const NodeDef*> nodes_;
|
||||
struct NodeInfo {
|
||||
const NodeDef* def;
|
||||
|
||||
// These are filled in when the NodeInfo is built, but not that they
|
||||
// may be empty - if the op could not be loaded from the registry.
|
||||
DataTypeVector inputs;
|
||||
DataTypeVector outputs;
|
||||
};
|
||||
|
||||
std::unordered_map<string, std::unique_ptr<NodeInfo>> nodes_;
|
||||
std::unordered_map<string, const NodeDef*> optimized_nodes_;
|
||||
std::unordered_set<const NodeDef*> control_dependency_drivers_;
|
||||
std::unordered_set<const NodeDef*> function_neighbors_;
|
||||
std::unordered_set<const NodeDef*> cross_device_receivers_;
|
||||
std::unordered_set<const NodeDef*> ref_receivers_;
|
||||
};
|
||||
|
||||
} // end namespace grappler
|
||||
|
@ -74,20 +74,23 @@ Status ModelPruner::Optimize(Cluster* cluster, const GrapplerItem& item,
|
||||
continue;
|
||||
}
|
||||
|
||||
// Don't remove nodes that drive control dependencies.
|
||||
// Don't remove nodes that are driven by control dependencies either since
|
||||
// we can't ensure (yet) that we won't increase the number of control
|
||||
// dependency edges by deleting them (for example, removing a node driven by
|
||||
// 10 control edges and driving 10 control edges would result in the
|
||||
// creation of 100 edges).
|
||||
// Don't modify nodes that are connected to functions since that can result
|
||||
// in inlining failures later on.
|
||||
// Don't prune nodes that are driven by another device since these could be
|
||||
// used to reduce cross device communication.
|
||||
// - Don't remove nodes that drive control dependencies.
|
||||
// - Don't remove nodes that are driven by control dependencies either since
|
||||
// we can't ensure (yet) that we won't increase the number of control
|
||||
// dependency edges by deleting them (for example, removing a node driven
|
||||
// by 10 control edges and driving 10 control edges would result in the
|
||||
// creation of 100 edges).
|
||||
// - Don't modify nodes that are connected to functions since that can
|
||||
// result in inlining failures later on.
|
||||
// - Don't prune nodes that are driven by another device since these could
|
||||
// be used to reduce cross device communication.
|
||||
// - Don't remove nodes that receive reference values, as those can be
|
||||
// converting references to non-references.
|
||||
if (!rewriter.DrivesControlDependency(node) &&
|
||||
!rewriter.IsDrivenByControlDependency(node) &&
|
||||
!rewriter.IsConnectedToFunction(node) &&
|
||||
!rewriter.IsDrivenByAnotherDevice(node)) {
|
||||
!rewriter.IsDrivenByAnotherDevice(node) &&
|
||||
!rewriter.ReceivesRefValue(node)) {
|
||||
nodes_to_delete.insert(&node);
|
||||
}
|
||||
}
|
||||
|
@ -199,6 +199,46 @@ TEST_F(ModelPrunerTest, PruningSkipsCtrlDependencies) {
|
||||
EXPECT_EQ("^c", new_e.input(1));
|
||||
}
|
||||
|
||||
TEST_F(ModelPrunerTest, PruningSkipsRefOutputs) {
|
||||
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
|
||||
|
||||
// Make graph of Identity(Identity(Identity(Identity(Variable)))).
|
||||
Output a = ops::Variable(s.WithOpName("a"), {}, DT_INT64);
|
||||
Output b = ops::Identity(s.WithOpName("b"), a);
|
||||
Output c = ops::Identity(s.WithOpName("c"), b);
|
||||
Output d = ops::Identity(s.WithOpName("d"), c);
|
||||
Output e = ops::Identity(s.WithOpName("e"), d);
|
||||
|
||||
// Run pruner.
|
||||
GrapplerItem item;
|
||||
TF_CHECK_OK(s.ToGraphDef(&item.graph));
|
||||
ModelPruner pruner;
|
||||
GraphDef output;
|
||||
Status status = pruner.Optimize(nullptr, item, &output);
|
||||
TF_EXPECT_OK(status);
|
||||
|
||||
// Get the updated nodes.
|
||||
ASSERT_EQ(5, output.node_size());
|
||||
const NodeDef& new_a = output.node(0);
|
||||
const NodeDef& new_b = output.node(1);
|
||||
const NodeDef& new_c = output.node(2);
|
||||
const NodeDef& new_d = output.node(3);
|
||||
const NodeDef& new_e = output.node(4);
|
||||
EXPECT_EQ("a", new_a.name());
|
||||
EXPECT_EQ("b", new_b.name());
|
||||
EXPECT_EQ("c", new_c.name());
|
||||
EXPECT_EQ("d", new_d.name());
|
||||
EXPECT_EQ("e", new_e.name());
|
||||
|
||||
// Verify the connections. Identity "b" can't be removed from the chain
|
||||
// because it is converting a reference input to a non-reference, so c,d,e all
|
||||
// refer to it as an input.
|
||||
EXPECT_EQ("a", new_b.input(0));
|
||||
EXPECT_EQ("b", new_c.input(0));
|
||||
EXPECT_EQ("b", new_d.input(0));
|
||||
EXPECT_EQ("b", new_e.input(0));
|
||||
}
|
||||
|
||||
TEST_F(ModelPrunerTest, PruningPerservesCtrlDependencies) {
|
||||
// Build a simple graph with a few trivially prunable ops.
|
||||
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
|
||||
|
Loading…
Reference in New Issue
Block a user