[Grappler] Make Grappler play nice with other graph rewriting passes by preserving custom attributes.
PiperOrigin-RevId: 320069254 Change-Id: I7fcbc022a1203a2a9999fad42c41683855d38e42
This commit is contained in:
parent
6056572c1e
commit
586744d2e4
@ -420,7 +420,7 @@ Status ConstantFolding::MaterializeShapes(const GraphProperties& properties) {
|
||||
// Device placement is preserved.
|
||||
graph_modified_ = true;
|
||||
node->set_op("Const");
|
||||
node->clear_attr();
|
||||
EraseRegularNodeAttributes(node);
|
||||
(*node->mutable_attr())["dtype"].set_type(type);
|
||||
constant_value.AsProtoTensorContent(
|
||||
(*node->mutable_attr())["value"].mutable_tensor());
|
||||
@ -1790,7 +1790,7 @@ void ConstantFolding::ReplaceOperationWithIdentity(
|
||||
if (dtype == DT_INVALID) return;
|
||||
|
||||
node->set_op("Identity");
|
||||
node->clear_attr();
|
||||
EraseRegularNodeAttributes(node);
|
||||
(*node->mutable_attr())["T"].set_type(dtype);
|
||||
// Propagate the designated input through the identity.
|
||||
node->mutable_input()->SwapElements(0, input_to_forward);
|
||||
@ -1821,7 +1821,7 @@ void ConstantFolding::ReplaceOperationWithSnapshot(
|
||||
if (dtype == DT_INVALID) return;
|
||||
|
||||
node->set_op("Snapshot");
|
||||
node->clear_attr();
|
||||
EraseRegularNodeAttributes(node);
|
||||
(*node->mutable_attr())["T"].set_type(dtype);
|
||||
// Propagate the designated input through the Snapshot.
|
||||
node->mutable_input()->SwapElements(0, input_to_forward);
|
||||
@ -1840,10 +1840,15 @@ void ConstantFolding::ReplaceOperationWithSnapshot(
|
||||
|
||||
// Replace a node with NoOp. Change all inputs to control dependencies.
|
||||
// If the node has non-control outputs, no change will be performed.
|
||||
void ConstantFolding::ReplaceOperationWithNoOp(NodeDef* node, GraphDef* graph) {
|
||||
void ConstantFolding::ReplaceOperationWithNoOp(NodeDef* node,
|
||||
GraphProperties* properties,
|
||||
GraphDef* graph) {
|
||||
if (HasRegularOutputs(*node, *node_map_)) return;
|
||||
node->set_op("NoOp");
|
||||
node->clear_attr();
|
||||
EraseRegularNodeAttributes(node);
|
||||
EraseNodeOutputAttributes(node);
|
||||
// Erase attributes that describe output properties.
|
||||
properties->ClearOutputProperties(node->name());
|
||||
// Change all inputs to control dependencies.
|
||||
for (int i = 0; i < node->input_size(); ++i) {
|
||||
if (IsControlInput(node->input(i))) {
|
||||
@ -1890,7 +1895,7 @@ void ConstantFolding::ReplaceBinaryOperationWithBroadcastTo(
|
||||
|
||||
// Rewrite `node` in-place to BroadcastTo.
|
||||
node->set_op("BroadcastTo");
|
||||
node->clear_attr();
|
||||
EraseRegularNodeAttributes(node);
|
||||
(*node->mutable_attr())["T"].set_type(dtype);
|
||||
(*node->mutable_attr())["Tidx"].set_type(DT_INT32);
|
||||
// Set the designated input to BroadcastTo.
|
||||
@ -1940,7 +1945,7 @@ Status ConstantFolding::ReplaceOperationWithConstantTensor(DataType dtype,
|
||||
GraphDef* graph) {
|
||||
if (dtype == DT_VARIANT) return Status::OK();
|
||||
node->set_op("Const");
|
||||
node->clear_attr();
|
||||
EraseRegularNodeAttributes(node);
|
||||
(*node->mutable_attr())["dtype"].set_type(dtype);
|
||||
(*node->mutable_attr())["value"].mutable_tensor()->Swap(value);
|
||||
// Convert all inputs to control dependencies.
|
||||
@ -2050,7 +2055,7 @@ Status ConstantFolding::SimplifyNode(bool use_shape_info, NodeDef* node,
|
||||
SET_AND_RETURN_IF_MODIFIED(
|
||||
PartialAssocOpConstFolding(optimized_graph, properties, node));
|
||||
SET_AND_RETURN_IF_MODIFIED(
|
||||
MergeConcat(use_shape_info, optimized_graph, node));
|
||||
MergeConcat(use_shape_info, properties, optimized_graph, node));
|
||||
SET_AND_RETURN_IF_MODIFIED(
|
||||
PartialConcatConstFolding(optimized_graph, properties, node));
|
||||
SET_AND_RETURN_IF_MODIFIED(
|
||||
@ -2059,7 +2064,7 @@ Status ConstantFolding::SimplifyNode(bool use_shape_info, NodeDef* node,
|
||||
SET_AND_RETURN_IF_MODIFIED(
|
||||
SimplifySelect(*properties, optimized_graph, node));
|
||||
RETURN_IF_MODIFIED(
|
||||
RemoveRedundantVariableUpdates(*properties, optimized_graph, node));
|
||||
RemoveRedundantVariableUpdates(properties, optimized_graph, node));
|
||||
|
||||
graph_modified_ = graph_modified_cached;
|
||||
return Status::OK();
|
||||
@ -2485,8 +2490,7 @@ bool ConstantFolding::SimplifySelect(const GraphProperties& properties,
|
||||
}
|
||||
|
||||
void ConstantFolding::RemoveRedundantVariableUpdates(
|
||||
const GraphProperties& properties, GraphDef* optimized_graph,
|
||||
NodeDef* node) {
|
||||
GraphProperties* properties, GraphDef* optimized_graph, NodeDef* node) {
|
||||
static const absl::flat_hash_set<string>* kVariableReadOps =
|
||||
new absl::flat_hash_set<string>{"AssignAddVariableOp",
|
||||
"AssignSubVariableOp",
|
||||
@ -2521,9 +2525,9 @@ void ConstantFolding::RemoveRedundantVariableUpdates(
|
||||
VLOG(1) << "Removing redundant variable update: " << node->DebugString();
|
||||
if (absl::StrContains(node->op(), "Variable") ||
|
||||
absl::StrContains(node->op(), "Resource")) {
|
||||
ReplaceOperationWithNoOp(node, optimized_graph);
|
||||
ReplaceOperationWithNoOp(node, properties, optimized_graph);
|
||||
} else {
|
||||
ReplaceOperationWithIdentity(0 /* input_to_forward */, properties, node,
|
||||
ReplaceOperationWithIdentity(0 /* input_to_forward */, *properties, node,
|
||||
optimized_graph);
|
||||
}
|
||||
}
|
||||
@ -2762,7 +2766,7 @@ bool ConstantFolding::ReplaceReductionWithIdentity(NodeDef* node) const {
|
||||
return false;
|
||||
}
|
||||
node->set_op("Identity");
|
||||
node->clear_attr();
|
||||
EraseRegularNodeAttributes(node);
|
||||
(*node->mutable_attr())["T"].set_type(output_type);
|
||||
*node->mutable_input(1) = AsControlDependency(node->input(1));
|
||||
return true;
|
||||
@ -2852,7 +2856,7 @@ bool ConstantFolding::SimplifyReshape(const GraphProperties& properties,
|
||||
}
|
||||
DataType output_type = node->attr().at("T").type();
|
||||
node->set_op("Identity");
|
||||
node->clear_attr();
|
||||
EraseRegularNodeAttributes(node);
|
||||
(*node->mutable_attr())["T"].set_type(output_type);
|
||||
*node->mutable_input(1) = AsControlDependency(node->input(1));
|
||||
return true;
|
||||
@ -3723,6 +3727,7 @@ bool ConstantFolding::GetConcatAxis(const NodeDef& node, int* axis) {
|
||||
}
|
||||
|
||||
bool ConstantFolding::MergeConcat(bool use_shape_info,
|
||||
GraphProperties* properties,
|
||||
GraphDef* optimized_graph, NodeDef* node) {
|
||||
// We only optimize for ConcatV2.
|
||||
int axis;
|
||||
@ -3791,16 +3796,15 @@ bool ConstantFolding::MergeConcat(bool use_shape_info,
|
||||
}
|
||||
}
|
||||
// Forward Add control inputs
|
||||
for (int i = num_regular_inputs; i < node->input_size(); ++i) {
|
||||
const int num_inputs = node->input_size();
|
||||
for (int i = num_inputs - 1; i >= num_regular_inputs; --i) {
|
||||
parent->add_input(node->input(i));
|
||||
node_map_->UpdateInput(parent->name(), node->name(), node->input(i));
|
||||
node->mutable_input()->RemoveLast();
|
||||
}
|
||||
node->clear_input();
|
||||
node->set_op("NoOp");
|
||||
node->clear_attr();
|
||||
node_map_->RemoveNode(node->name());
|
||||
(*parent->mutable_attr())["N"].set_i(NumNonControlInputs(*parent) - 1);
|
||||
DedupControlInputs(parent);
|
||||
ReplaceOperationWithNoOp(node, properties, optimized_graph);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
@ -106,7 +106,8 @@ class ConstantFolding : public GraphOptimizer {
|
||||
void ReplaceOperationWithSnapshot(int input_to_forward,
|
||||
const GraphProperties& properties,
|
||||
NodeDef* node, GraphDef* graph);
|
||||
void ReplaceOperationWithNoOp(NodeDef* node, GraphDef* graph);
|
||||
void ReplaceOperationWithNoOp(NodeDef* node, GraphProperties* properties,
|
||||
GraphDef* graph);
|
||||
void ReplaceBinaryOperationWithBroadcastTo(int input_to_broadcast,
|
||||
const GraphProperties& properties,
|
||||
NodeDef* node, GraphDef* graph);
|
||||
@ -289,7 +290,7 @@ class ConstantFolding : public GraphOptimizer {
|
||||
GraphDef* optimized_graph, NodeDef* node);
|
||||
|
||||
// Replaces variable updates that are effectively no-ops with NoOp nodes.
|
||||
void RemoveRedundantVariableUpdates(const GraphProperties& properties,
|
||||
void RemoveRedundantVariableUpdates(GraphProperties* properties,
|
||||
GraphDef* optimized_graph, NodeDef* node);
|
||||
|
||||
// Removes Reverse op over dimensions with size 1.
|
||||
@ -311,8 +312,8 @@ class ConstantFolding : public GraphOptimizer {
|
||||
GraphDef* optimized_graph, NodeDef* node);
|
||||
|
||||
bool GetConcatAxis(const NodeDef& node, int* axis);
|
||||
bool MergeConcat(bool use_shape_info, GraphDef* optimized_graph,
|
||||
NodeDef* node);
|
||||
bool MergeConcat(bool use_shape_info, GraphProperties* properties,
|
||||
GraphDef* optimized_graph, NodeDef* node);
|
||||
|
||||
Status AddQuantizedMatMulMinMaxOutConstNodes(NodeDef* node,
|
||||
GraphDef* optimized_graph);
|
||||
|
@ -44,7 +44,7 @@ Status DebugStripper::Optimize(Cluster* cluster, const GrapplerItem& item,
|
||||
if (IsAssert(node) || node.op() == "PrintV2") {
|
||||
// Convert this node into a no-op.
|
||||
node.set_op("NoOp");
|
||||
node.clear_attr();
|
||||
EraseRegularNodeAttributes(&node);
|
||||
// Convert all its inputs into control dependency, which will then
|
||||
// be optimized away by dependency optimizer.
|
||||
for (string& inp : *node.mutable_input()) {
|
||||
|
@ -317,7 +317,7 @@ void DependencyOptimizer::OptimizeNode(int node_idx,
|
||||
++pos;
|
||||
}
|
||||
node->set_op("NoOp");
|
||||
node->clear_attr();
|
||||
EraseRegularNodeAttributes(node);
|
||||
DedupControlInputs(node);
|
||||
nodes_to_simplify->PushBack(node_to_idx_[node]);
|
||||
return;
|
||||
|
@ -327,7 +327,7 @@ void RewriteDeviceIndexOp(utils::MutableNodeView* device_index_node,
|
||||
// Modifies the DeviceIndex node to be an Const op with correct device index.
|
||||
auto node = device_index_node->node();
|
||||
node->set_op(kConstOp);
|
||||
node->clear_attr();
|
||||
EraseRegularNodeAttributes(node);
|
||||
(*node->mutable_attr())["dtype"].set_type(DT_INT32);
|
||||
auto* tensor = (*node->mutable_attr())["value"].mutable_tensor();
|
||||
tensor->set_dtype(DT_INT32);
|
||||
|
@ -723,7 +723,7 @@ bool SchedulingPass(Cluster* cluster, std::unique_ptr<GraphMemory>* memory_ptr,
|
||||
// Rewrite the AddN node as a DestroyTemporaryVariable ops
|
||||
node->set_op("DestroyTemporaryVariable");
|
||||
node->clear_input();
|
||||
node->clear_attr();
|
||||
EraseRegularNodeAttributes(node);
|
||||
(*node->mutable_attr())["T"].set_type(dtype);
|
||||
(*node->mutable_attr())["var_name"].set_s(tmp_var->name());
|
||||
*node->add_input() = initialize->name();
|
||||
|
@ -517,5 +517,42 @@ Status IsKernelRegisteredForNode(const NodeDef& node) {
|
||||
node.device(), AttrSlice(&node.attr()));
|
||||
}
|
||||
|
||||
namespace {
|
||||
void RemoveAttributes(const std::vector<absl::string_view>& to_remove,
|
||||
NodeDef* node) {
|
||||
if (to_remove.size() == node->attr_size()) {
|
||||
node->clear_attr();
|
||||
} else {
|
||||
for (const auto& key : to_remove) {
|
||||
node->mutable_attr()->erase(string(key));
|
||||
}
|
||||
}
|
||||
}
|
||||
} // namespace
|
||||
|
||||
int EraseRegularNodeAttributes(NodeDef* node) {
|
||||
std::vector<absl::string_view> to_remove;
|
||||
for (const auto& attr : node->attr()) {
|
||||
if (!attr.first.empty() && (attr.first)[0] != '_') {
|
||||
to_remove.push_back(attr.first);
|
||||
}
|
||||
}
|
||||
RemoveAttributes(to_remove, node);
|
||||
return to_remove.size();
|
||||
}
|
||||
|
||||
int EraseNodeOutputAttributes(NodeDef* node) {
|
||||
std::vector<absl::string_view> to_remove;
|
||||
for (const auto& attr : node->attr()) {
|
||||
const string& attr_name = attr.first;
|
||||
if (attr_name == "_xla_inferred_shapes" ||
|
||||
absl::StartsWith(attr_name, "_output_")) {
|
||||
to_remove.push_back(attr_name);
|
||||
}
|
||||
}
|
||||
RemoveAttributes(to_remove, node);
|
||||
return to_remove.size();
|
||||
}
|
||||
|
||||
} // end namespace grappler
|
||||
} // end namespace tensorflow
|
||||
|
@ -364,6 +364,14 @@ void EraseNodesFromGraph(std::vector<int>&& nodes_to_delete, GraphDef* graph);
|
||||
void EraseNodesFromGraph(const std::set<string>& nodes_to_delete,
|
||||
GraphDef* graph);
|
||||
|
||||
// Erase all attributes without leading underscore. Returns the number of
|
||||
// attributes erased.
|
||||
int EraseRegularNodeAttributes(NodeDef* node);
|
||||
|
||||
// Erase attribute "_xla_inferred_shapes" as well as all attributes starting in
|
||||
// "_output_".
|
||||
int EraseNodeOutputAttributes(NodeDef* node);
|
||||
|
||||
} // end namespace grappler
|
||||
} // end namespace tensorflow
|
||||
|
||||
|
@ -520,6 +520,45 @@ TEST_F(UtilsTest, SafeTensorIdToString) {
|
||||
EXPECT_EQ(SafeTensorIdToString({"foo", 2}), "foo:2");
|
||||
}
|
||||
|
||||
TEST_F(UtilsTest, EraseRegularNodeAttributes) {
|
||||
NodeDef node;
|
||||
AttrValue dummy;
|
||||
node.set_name("foo");
|
||||
node.set_op("MatMul");
|
||||
(*node.mutable_attr())["baz"] = dummy;
|
||||
EXPECT_EQ(EraseRegularNodeAttributes(&node), 1);
|
||||
EXPECT_EQ(node.attr_size(), 0);
|
||||
EXPECT_EQ(EraseRegularNodeAttributes(&node), 0);
|
||||
|
||||
(*node.mutable_attr())["baz"] = dummy;
|
||||
(*node.mutable_attr())["_bar"] = dummy;
|
||||
EXPECT_EQ(EraseRegularNodeAttributes(&node), 1);
|
||||
EXPECT_EQ(node.attr_size(), 1);
|
||||
EXPECT_EQ(node.attr().begin()->first, "_bar");
|
||||
EXPECT_EQ(EraseRegularNodeAttributes(&node), 0);
|
||||
}
|
||||
|
||||
TEST_F(UtilsTest, EraseNodeOutputAttributes) {
|
||||
NodeDef node;
|
||||
AttrValue dummy;
|
||||
node.set_name("foo");
|
||||
node.set_op("MatMul");
|
||||
EXPECT_EQ(EraseNodeOutputAttributes(&node), 0);
|
||||
(*node.mutable_attr())["_xla_inferred_shapes"] = dummy;
|
||||
EXPECT_EQ(EraseNodeOutputAttributes(&node), 1);
|
||||
EXPECT_EQ(node.attr_size(), 0);
|
||||
EXPECT_EQ(EraseNodeOutputAttributes(&node), 0);
|
||||
|
||||
(*node.mutable_attr())["baz"] = dummy;
|
||||
(*node.mutable_attr())["_output_shapes"] = dummy;
|
||||
(*node.mutable_attr())["_xla_inferred_shapes"] = dummy;
|
||||
(*node.mutable_attr())["_output_gnu"] = dummy;
|
||||
EXPECT_EQ(EraseNodeOutputAttributes(&node), 3);
|
||||
EXPECT_EQ(node.attr_size(), 1);
|
||||
EXPECT_EQ(node.attr().begin()->first, "baz");
|
||||
EXPECT_EQ(EraseNodeOutputAttributes(&node), 0);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void TestSetTensorValue(DataType type, int val, bool success,
|
||||
absl::string_view error_msg) {
|
||||
|
Loading…
Reference in New Issue
Block a user