[Grappler] Make Grappler play nice with other graph rewriting passes by preserving custom attributes.

PiperOrigin-RevId: 320069254
Change-Id: I7fcbc022a1203a2a9999fad42c41683855d38e42
A. Unique TensorFlower 2020-07-07 15:06:57 -07:00 committed by TensorFlower Gardener
parent 6056572c1e
commit 586744d2e4
9 changed files with 117 additions and 28 deletions
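
The pattern being replaced throughout this change is node->clear_attr(), which wipes every attribute on a node, including underscore-prefixed ones that other graph rewriting passes (placement, XLA bridges, custom plugins) attach and later read back. The new EraseRegularNodeAttributes() helper removes only attributes whose names do not start with '_'. A minimal sketch of the behavior, mirroring the unit test added below; the attribute name "_custom_marker" is made up for illustration:

// Sketch (assumes TensorFlow headers; "_custom_marker" is a made-up name).
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/grappler/utils.h"

void Demo() {
  tensorflow::NodeDef node;
  node.set_op("MatMul");
  (*node.mutable_attr())["T"].set_type(tensorflow::DT_FLOAT);  // regular
  (*node.mutable_attr())["_custom_marker"].set_b(true);        // preserved
  int erased = tensorflow::grappler::EraseRegularNodeAttributes(&node);
  // erased == 1; node.attr() still contains "_custom_marker".
}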

View File

@@ -420,7 +420,7 @@ Status ConstantFolding::MaterializeShapes(const GraphProperties& properties) {
// Device placement is preserved.
graph_modified_ = true;
node->set_op("Const");
-node->clear_attr();
+EraseRegularNodeAttributes(node);
(*node->mutable_attr())["dtype"].set_type(type);
constant_value.AsProtoTensorContent(
(*node->mutable_attr())["value"].mutable_tensor());
@@ -1790,7 +1790,7 @@ void ConstantFolding::ReplaceOperationWithIdentity(
if (dtype == DT_INVALID) return;
node->set_op("Identity");
-node->clear_attr();
+EraseRegularNodeAttributes(node);
(*node->mutable_attr())["T"].set_type(dtype);
// Propagate the designated input through the identity.
node->mutable_input()->SwapElements(0, input_to_forward);
@@ -1821,7 +1821,7 @@ void ConstantFolding::ReplaceOperationWithSnapshot(
if (dtype == DT_INVALID) return;
node->set_op("Snapshot");
-node->clear_attr();
+EraseRegularNodeAttributes(node);
(*node->mutable_attr())["T"].set_type(dtype);
// Propagate the designated input through the Snapshot.
node->mutable_input()->SwapElements(0, input_to_forward);
@@ -1840,10 +1840,15 @@ void ConstantFolding::ReplaceOperationWithSnapshot(
// Replace a node with NoOp. Change all inputs to control dependencies.
// If the node has non-control outputs, no change will be performed.
-void ConstantFolding::ReplaceOperationWithNoOp(NodeDef* node, GraphDef* graph) {
+void ConstantFolding::ReplaceOperationWithNoOp(NodeDef* node,
+GraphProperties* properties,
+GraphDef* graph) {
if (HasRegularOutputs(*node, *node_map_)) return;
node->set_op("NoOp");
-node->clear_attr();
+EraseRegularNodeAttributes(node);
+EraseNodeOutputAttributes(node);
+// Erase attributes that describe output properties.
+properties->ClearOutputProperties(node->name());
// Change all inputs to control dependencies.
for (int i = 0; i < node->input_size(); ++i) {
if (IsControlInput(node->input(i))) {
@@ -1890,7 +1895,7 @@ void ConstantFolding::ReplaceBinaryOperationWithBroadcastTo(
// Rewrite `node` in-place to BroadcastTo.
node->set_op("BroadcastTo");
-node->clear_attr();
+EraseRegularNodeAttributes(node);
(*node->mutable_attr())["T"].set_type(dtype);
(*node->mutable_attr())["Tidx"].set_type(DT_INT32);
// Set the designated input to BroadcastTo.
@@ -1940,7 +1945,7 @@ Status ConstantFolding::ReplaceOperationWithConstantTensor(DataType dtype,
GraphDef* graph) {
if (dtype == DT_VARIANT) return Status::OK();
node->set_op("Const");
-node->clear_attr();
+EraseRegularNodeAttributes(node);
(*node->mutable_attr())["dtype"].set_type(dtype);
(*node->mutable_attr())["value"].mutable_tensor()->Swap(value);
// Convert all inputs to control dependencies.
@@ -2050,7 +2055,7 @@ Status ConstantFolding::SimplifyNode(bool use_shape_info, NodeDef* node,
SET_AND_RETURN_IF_MODIFIED(
PartialAssocOpConstFolding(optimized_graph, properties, node));
SET_AND_RETURN_IF_MODIFIED(
-MergeConcat(use_shape_info, optimized_graph, node));
+MergeConcat(use_shape_info, properties, optimized_graph, node));
SET_AND_RETURN_IF_MODIFIED(
PartialConcatConstFolding(optimized_graph, properties, node));
SET_AND_RETURN_IF_MODIFIED(
@@ -2059,7 +2064,7 @@ Status ConstantFolding::SimplifyNode(bool use_shape_info, NodeDef* node,
SET_AND_RETURN_IF_MODIFIED(
SimplifySelect(*properties, optimized_graph, node));
RETURN_IF_MODIFIED(
-RemoveRedundantVariableUpdates(*properties, optimized_graph, node));
+RemoveRedundantVariableUpdates(properties, optimized_graph, node));
graph_modified_ = graph_modified_cached;
return Status::OK();
@@ -2485,8 +2490,7 @@ bool ConstantFolding::SimplifySelect(const GraphProperties& properties,
}
void ConstantFolding::RemoveRedundantVariableUpdates(
-const GraphProperties& properties, GraphDef* optimized_graph,
-NodeDef* node) {
+GraphProperties* properties, GraphDef* optimized_graph, NodeDef* node) {
static const absl::flat_hash_set<string>* kVariableReadOps =
new absl::flat_hash_set<string>{"AssignAddVariableOp",
"AssignSubVariableOp",
@@ -2521,9 +2525,9 @@ void ConstantFolding::RemoveRedundantVariableUpdates(
VLOG(1) << "Removing redundant variable update: " << node->DebugString();
if (absl::StrContains(node->op(), "Variable") ||
absl::StrContains(node->op(), "Resource")) {
-ReplaceOperationWithNoOp(node, optimized_graph);
+ReplaceOperationWithNoOp(node, properties, optimized_graph);
} else {
-ReplaceOperationWithIdentity(0 /* input_to_forward */, properties, node,
+ReplaceOperationWithIdentity(0 /* input_to_forward */, *properties, node,
optimized_graph);
}
}
@@ -2762,7 +2766,7 @@ bool ConstantFolding::ReplaceReductionWithIdentity(NodeDef* node) const {
return false;
}
node->set_op("Identity");
-node->clear_attr();
+EraseRegularNodeAttributes(node);
(*node->mutable_attr())["T"].set_type(output_type);
*node->mutable_input(1) = AsControlDependency(node->input(1));
return true;
@@ -2852,7 +2856,7 @@ bool ConstantFolding::SimplifyReshape(const GraphProperties& properties,
}
DataType output_type = node->attr().at("T").type();
node->set_op("Identity");
-node->clear_attr();
+EraseRegularNodeAttributes(node);
(*node->mutable_attr())["T"].set_type(output_type);
*node->mutable_input(1) = AsControlDependency(node->input(1));
return true;
@@ -3723,6 +3727,7 @@ bool ConstantFolding::GetConcatAxis(const NodeDef& node, int* axis) {
}
bool ConstantFolding::MergeConcat(bool use_shape_info,
+GraphProperties* properties,
GraphDef* optimized_graph, NodeDef* node) {
// We only optimize for ConcatV2.
int axis;
@@ -3791,16 +3796,15 @@ bool ConstantFolding::MergeConcat(bool use_shape_info,
}
}
// Forward Add control inputs
-for (int i = num_regular_inputs; i < node->input_size(); ++i) {
+const int num_inputs = node->input_size();
+for (int i = num_inputs - 1; i >= num_regular_inputs; --i) {
parent->add_input(node->input(i));
node_map_->UpdateInput(parent->name(), node->name(), node->input(i));
+node->mutable_input()->RemoveLast();
}
-node->clear_input();
-node->set_op("NoOp");
-node->clear_attr();
-node_map_->RemoveNode(node->name());
(*parent->mutable_attr())["N"].set_i(NumNonControlInputs(*parent) - 1);
DedupControlInputs(parent);
+ReplaceOperationWithNoOp(node, properties, optimized_graph);
return true;
}
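
Aside on the MergeConcat hunk above: control inputs are now popped off the child node as they are forwarded to the parent, and the node is then retired via ReplaceOperationWithNoOp() so the attribute and property cleanup happens in one shared place. The loop must walk back-to-front because a protobuf repeated field only supports cheap removal at the tail; a sketch of the idiom in isolation, using the same variable names as the diff:

// Move trailing control inputs from 'node' to 'parent', back to front.
// RemoveLast() only drops the final element, so iterating from the end
// lets each forwarded input be removed immediately without shifting the rest.
const int num_inputs = node->input_size();
for (int i = num_inputs - 1; i >= num_regular_inputs; --i) {
  parent->add_input(node->input(i));
  node->mutable_input()->RemoveLast();
}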

View File

@@ -106,7 +106,8 @@ class ConstantFolding : public GraphOptimizer {
void ReplaceOperationWithSnapshot(int input_to_forward,
const GraphProperties& properties,
NodeDef* node, GraphDef* graph);
-void ReplaceOperationWithNoOp(NodeDef* node, GraphDef* graph);
+void ReplaceOperationWithNoOp(NodeDef* node, GraphProperties* properties,
+GraphDef* graph);
void ReplaceBinaryOperationWithBroadcastTo(int input_to_broadcast,
const GraphProperties& properties,
NodeDef* node, GraphDef* graph);
@@ -289,7 +290,7 @@ class ConstantFolding : public GraphOptimizer {
GraphDef* optimized_graph, NodeDef* node);
// Replaces variable updates that are effectively no-ops with NoOp nodes.
-void RemoveRedundantVariableUpdates(const GraphProperties& properties,
+void RemoveRedundantVariableUpdates(GraphProperties* properties,
GraphDef* optimized_graph, NodeDef* node);
// Removes Reverse op over dimensions with size 1.
@@ -311,8 +312,8 @@ class ConstantFolding : public GraphOptimizer {
GraphDef* optimized_graph, NodeDef* node);
bool GetConcatAxis(const NodeDef& node, int* axis);
-bool MergeConcat(bool use_shape_info, GraphDef* optimized_graph,
-NodeDef* node);
+bool MergeConcat(bool use_shape_info, GraphProperties* properties,
+GraphDef* optimized_graph, NodeDef* node);
Status AddQuantizedMatMulMinMaxOutConstNodes(NodeDef* node,
GraphDef* optimized_graph);

View File

@@ -44,7 +44,7 @@ Status DebugStripper::Optimize(Cluster* cluster, const GrapplerItem& item,
if (IsAssert(node) || node.op() == "PrintV2") {
// Convert this node into a no-op.
node.set_op("NoOp");
-node.clear_attr();
+EraseRegularNodeAttributes(&node);
// Convert all its inputs into control dependency, which will then
// be optimized away by dependency optimizer.
for (string& inp : *node.mutable_input()) {

View File

@@ -317,7 +317,7 @@ void DependencyOptimizer::OptimizeNode(int node_idx,
++pos;
}
node->set_op("NoOp");
-node->clear_attr();
+EraseRegularNodeAttributes(node);
DedupControlInputs(node);
nodes_to_simplify->PushBack(node_to_idx_[node]);
return;

View File

@@ -327,7 +327,7 @@ void RewriteDeviceIndexOp(utils::MutableNodeView* device_index_node,
// Modifies the DeviceIndex node to be an Const op with correct device index.
auto node = device_index_node->node();
node->set_op(kConstOp);
-node->clear_attr();
+EraseRegularNodeAttributes(node);
(*node->mutable_attr())["dtype"].set_type(DT_INT32);
auto* tensor = (*node->mutable_attr())["value"].mutable_tensor();
tensor->set_dtype(DT_INT32);

View File

@@ -723,7 +723,7 @@ bool SchedulingPass(Cluster* cluster, std::unique_ptr<GraphMemory>* memory_ptr,
// Rewrite the AddN node as a DestroyTemporaryVariable ops
node->set_op("DestroyTemporaryVariable");
node->clear_input();
-node->clear_attr();
+EraseRegularNodeAttributes(node);
(*node->mutable_attr())["T"].set_type(dtype);
(*node->mutable_attr())["var_name"].set_s(tmp_var->name());
*node->add_input() = initialize->name();

View File

@@ -517,5 +517,42 @@ Status IsKernelRegisteredForNode(const NodeDef& node) {
node.device(), AttrSlice(&node.attr()));
}
+namespace {
+void RemoveAttributes(const std::vector<absl::string_view>& to_remove,
+NodeDef* node) {
+if (to_remove.size() == node->attr_size()) {
+node->clear_attr();
+} else {
+for (const auto& key : to_remove) {
+node->mutable_attr()->erase(string(key));
+}
+}
+}
+} // namespace
+int EraseRegularNodeAttributes(NodeDef* node) {
+std::vector<absl::string_view> to_remove;
+for (const auto& attr : node->attr()) {
+if (!attr.first.empty() && (attr.first)[0] != '_') {
+to_remove.push_back(attr.first);
+}
+}
+RemoveAttributes(to_remove, node);
+return to_remove.size();
+}
+int EraseNodeOutputAttributes(NodeDef* node) {
+std::vector<absl::string_view> to_remove;
+for (const auto& attr : node->attr()) {
+const string& attr_name = attr.first;
+if (attr_name == "_xla_inferred_shapes" ||
+absl::StartsWith(attr_name, "_output_")) {
+to_remove.push_back(attr_name);
+}
+}
+RemoveAttributes(to_remove, node);
+return to_remove.size();
+}
} // end namespace grappler
} // end namespace tensorflow

View File

@@ -364,6 +364,14 @@ void EraseNodesFromGraph(std::vector<int>&& nodes_to_delete, GraphDef* graph);
void EraseNodesFromGraph(const std::set<string>& nodes_to_delete,
GraphDef* graph);
+// Erase all attributes without leading underscore. Returns the number of
+// attributes erased.
+int EraseRegularNodeAttributes(NodeDef* node);
+// Erase attribute "_xla_inferred_shapes" as well as all attributes starting
+// with "_output_".
+int EraseNodeOutputAttributes(NodeDef* node);
} // end namespace grappler
} // end namespace tensorflow
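
Taken together, the two new helpers split a node's attributes into three classes: regular names (no leading underscore) fall to EraseRegularNodeAttributes(), the output annotations "_xla_inferred_shapes" and "_output_*" fall to EraseNodeOutputAttributes(), and every other underscore-prefixed attribute survives both. A sketch of how a rewrite such as ReplaceOperationWithNoOp() combines them; "_my_pass_marker" is a hypothetical custom attribute:

tensorflow::NodeDef node;
node.set_op("AssignAddVariableOp");
(*node.mutable_attr())["dtype"].set_type(tensorflow::DT_FLOAT);  // regular
(*node.mutable_attr())["_output_shapes"];   // stale once the op is a NoOp
(*node.mutable_attr())["_my_pass_marker"];  // hypothetical custom attribute

node.set_op("NoOp");
tensorflow::grappler::EraseRegularNodeAttributes(&node);  // drops "dtype"
tensorflow::grappler::EraseNodeOutputAttributes(&node);   // drops "_output_shapes"
// "_my_pass_marker" remains for whichever pass set it.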

View File

@@ -520,6 +520,45 @@ TEST_F(UtilsTest, SafeTensorIdToString) {
EXPECT_EQ(SafeTensorIdToString({"foo", 2}), "foo:2");
}
+TEST_F(UtilsTest, EraseRegularNodeAttributes) {
+NodeDef node;
+AttrValue dummy;
+node.set_name("foo");
+node.set_op("MatMul");
+(*node.mutable_attr())["baz"] = dummy;
+EXPECT_EQ(EraseRegularNodeAttributes(&node), 1);
+EXPECT_EQ(node.attr_size(), 0);
+EXPECT_EQ(EraseRegularNodeAttributes(&node), 0);
+(*node.mutable_attr())["baz"] = dummy;
+(*node.mutable_attr())["_bar"] = dummy;
+EXPECT_EQ(EraseRegularNodeAttributes(&node), 1);
+EXPECT_EQ(node.attr_size(), 1);
+EXPECT_EQ(node.attr().begin()->first, "_bar");
+EXPECT_EQ(EraseRegularNodeAttributes(&node), 0);
+}
+TEST_F(UtilsTest, EraseNodeOutputAttributes) {
+NodeDef node;
+AttrValue dummy;
+node.set_name("foo");
+node.set_op("MatMul");
+EXPECT_EQ(EraseNodeOutputAttributes(&node), 0);
+(*node.mutable_attr())["_xla_inferred_shapes"] = dummy;
+EXPECT_EQ(EraseNodeOutputAttributes(&node), 1);
+EXPECT_EQ(node.attr_size(), 0);
+EXPECT_EQ(EraseNodeOutputAttributes(&node), 0);
+(*node.mutable_attr())["baz"] = dummy;
+(*node.mutable_attr())["_output_shapes"] = dummy;
+(*node.mutable_attr())["_xla_inferred_shapes"] = dummy;
+(*node.mutable_attr())["_output_gnu"] = dummy;
+EXPECT_EQ(EraseNodeOutputAttributes(&node), 3);
+EXPECT_EQ(node.attr_size(), 1);
+EXPECT_EQ(node.attr().begin()->first, "baz");
+EXPECT_EQ(EraseNodeOutputAttributes(&node), 0);
+}
template <typename T>
void TestSetTensorValue(DataType type, int val, bool success,
absl::string_view error_msg) {