[Grappler:GraphView] Add NodeDiff only if mutation is not a no-op
PiperOrigin-RevId: 276609854 Change-Id: I3508fef7f3a7ebb2c235b54b49685521d52b6e31
This commit is contained in:
parent
55b19ed22e
commit
78529cfe02
@ -262,14 +262,16 @@ MutationNewNode Mutation::AddNode(NodeDef&& node, Status* status) {
|
||||
|
||||
void Mutation::AddMutation(
|
||||
MutableNodeView* node,
|
||||
std::function<void(MutableNodeViewDiff*)> mutate_fn) {
|
||||
std::function<bool(MutableNodeViewDiff*)> mutate_fn) {
|
||||
DCHECK(node->graph_view_ == graph_view_);
|
||||
if (node->update_index_ == internal::kMissingIndex) {
|
||||
MutableNodeViewDiff diff(graph_view_, node->node_index_);
|
||||
// If mutation is a no-op return and do not add it to the `updated_nodes_`.
|
||||
if (!mutate_fn(&diff)) return;
|
||||
node->update_index_ = updated_nodes_.size();
|
||||
updated_nodes_.emplace_back(graph_view_, node->node_index_);
|
||||
mutate_fn(&updated_nodes_.back());
|
||||
updated_nodes_.push_back(std::move(diff));
|
||||
} else if (!removed_nodes_.contains(node->node_index_)) {
|
||||
auto& diff = updated_nodes_[node->update_index_];
|
||||
MutableNodeViewDiff& diff = updated_nodes_[node->update_index_];
|
||||
mutate_fn(&diff);
|
||||
}
|
||||
}
|
||||
@ -290,7 +292,7 @@ void Mutation::RemoveNode(MutableNodeView* node) {
|
||||
|
||||
void Mutation::UpdateNodeName(MutableNodeView* node, absl::string_view name) {
|
||||
AddMutation(node, [name](MutableNodeViewDiff* diff) {
|
||||
internal::UpdateName(diff, name);
|
||||
return internal::UpdateName(diff, name);
|
||||
});
|
||||
}
|
||||
|
||||
@ -301,8 +303,9 @@ void Mutation::UpdateNodeName(const MutationNewNode& node,
|
||||
}
|
||||
|
||||
void Mutation::UpdateNodeOp(MutableNodeView* node, absl::string_view op) {
|
||||
AddMutation(
|
||||
node, [op](MutableNodeViewDiff* diff) { internal::UpdateOp(diff, op); });
|
||||
AddMutation(node, [op](MutableNodeViewDiff* diff) {
|
||||
return internal::UpdateOp(diff, op);
|
||||
});
|
||||
}
|
||||
|
||||
void Mutation::UpdateNodeOp(const MutationNewNode& node, absl::string_view op) {
|
||||
@ -313,7 +316,7 @@ void Mutation::UpdateNodeOp(const MutationNewNode& node, absl::string_view op) {
|
||||
void Mutation::UpdateNodeDevice(MutableNodeView* node,
|
||||
absl::string_view device) {
|
||||
AddMutation(node, [device](MutableNodeViewDiff* diff) {
|
||||
internal::UpdateDevice(diff, device);
|
||||
return internal::UpdateDevice(diff, device);
|
||||
});
|
||||
}
|
||||
|
||||
@ -326,7 +329,7 @@ void Mutation::UpdateNodeDevice(const MutationNewNode& node,
|
||||
void Mutation::AddOrUpdateRegularFanin(MutableNodeView* node, int index,
|
||||
const TensorId& fanin) {
|
||||
AddMutation(node, [index, fanin](MutableNodeViewDiff* diff) {
|
||||
internal::AddOrUpdateRegularFanin(diff, index, fanin);
|
||||
return internal::AddOrUpdateRegularFanin(diff, index, fanin);
|
||||
});
|
||||
}
|
||||
|
||||
@ -340,7 +343,7 @@ void Mutation::AddOrUpdateRegularFanin(const MutationNewNode& node, int index,
|
||||
|
||||
void Mutation::RemoveRegularFanin(MutableNodeView* node, int index) {
|
||||
AddMutation(node, [index](MutableNodeViewDiff* diff) {
|
||||
internal::RemoveRegularFanin(diff, index);
|
||||
return internal::RemoveRegularFanin(diff, index);
|
||||
});
|
||||
}
|
||||
|
||||
@ -357,7 +360,7 @@ void Mutation::AddControllingFanin(MutableNodeView* node,
|
||||
const int control_index = it != node->controlling_fanins_index_.end()
|
||||
? it->second
|
||||
: internal::kMissingIndex;
|
||||
internal::AddControllingFanin(diff, control_index, fanin_node_name);
|
||||
return internal::AddControllingFanin(diff, control_index, fanin_node_name);
|
||||
});
|
||||
}
|
||||
|
||||
@ -374,7 +377,8 @@ void Mutation::RemoveControllingFanin(MutableNodeView* node,
|
||||
const int control_index = it != node->controlling_fanins_index_.end()
|
||||
? it->second
|
||||
: internal::kMissingIndex;
|
||||
internal::RemoveControllingFanin(diff, control_index, fanin_node_name);
|
||||
return internal::RemoveControllingFanin(diff, control_index,
|
||||
fanin_node_name);
|
||||
});
|
||||
}
|
||||
|
||||
@ -388,7 +392,7 @@ void Mutation::AddOrUpdateNodeAttr(MutableNodeView* node,
|
||||
absl::string_view attr_name,
|
||||
const AttrValue& attr_value) {
|
||||
AddMutation(node, [attr_name, attr_value](MutableNodeViewDiff* diff) {
|
||||
internal::AddOrUpdateAttribute(diff, attr_name, attr_value);
|
||||
return internal::AddOrUpdateAttribute(diff, attr_name, attr_value);
|
||||
});
|
||||
}
|
||||
|
||||
@ -403,7 +407,7 @@ void Mutation::AddOrUpdateNodeAttr(const MutationNewNode& node,
|
||||
void Mutation::RemoveNodeAttr(MutableNodeView* node,
|
||||
absl::string_view attr_name) {
|
||||
AddMutation(node, [attr_name](MutableNodeViewDiff* diff) {
|
||||
internal::RemoveAttribute(diff, attr_name);
|
||||
return internal::RemoveAttribute(diff, attr_name);
|
||||
});
|
||||
}
|
||||
|
||||
@ -807,10 +811,10 @@ Status MutableGraphView::CheckKernelRegisteredForNodes() {
|
||||
diff.processed_attrs =
|
||||
AttrValueMap(node->attr().begin(), node->attr().end());
|
||||
for (const auto& attr_to_remove : diff.attrs_to_remove) {
|
||||
diff.processed_attrs.erase(attr_to_remove);
|
||||
(*diff.processed_attrs).erase(attr_to_remove);
|
||||
}
|
||||
for (const auto& attr_to_add : diff.attrs_to_add) {
|
||||
gtl::InsertOrUpdate(&diff.processed_attrs, attr_to_add.first,
|
||||
gtl::InsertOrUpdate(&(*diff.processed_attrs), attr_to_add.first,
|
||||
attr_to_add.second);
|
||||
}
|
||||
const string& device = diff.update_device ? diff.device : node->device();
|
||||
@ -823,7 +827,7 @@ Status MutableGraphView::CheckKernelRegisteredForNodes() {
|
||||
node->has_experimental_debug_info(),
|
||||
node->experimental_debug_info(),
|
||||
diff.update_op ? diff.op : node->op(), device,
|
||||
AttrSlice(&diff.processed_attrs));
|
||||
AttrSlice(&(*diff.processed_attrs)));
|
||||
if (!s.ok()) {
|
||||
LOG(WARNING) << s.error_message();
|
||||
}
|
||||
@ -1189,7 +1193,7 @@ void MutableGraphView::ApplyNodeUpdates() {
|
||||
if (diff.update_device) {
|
||||
node_def->set_device(diff.device);
|
||||
}
|
||||
node_def->mutable_attr()->swap(diff.processed_attrs);
|
||||
node_def->mutable_attr()->swap((*diff.processed_attrs));
|
||||
|
||||
// Updated fanins. Only one of `regular_inputs_to_remove_` or
|
||||
// `regular_inputs_to_add_` can be set.
|
||||
|
@ -353,8 +353,12 @@ class Mutation {
|
||||
void ResetInternal();
|
||||
|
||||
using MutableNodeViewDiff = internal::NodeViewDiff<MutableGraphView>;
|
||||
|
||||
// Adds a mutation to the `node`. Mutation function `mutate_fn` must return
|
||||
// `true` if it actually does any mutations. If it returns `false` mutation
|
||||
// will be ignored.
|
||||
void AddMutation(MutableNodeView* node,
|
||||
std::function<void(MutableNodeViewDiff*)> mutate_fn);
|
||||
std::function<bool(MutableNodeViewDiff*)> mutate_fn);
|
||||
|
||||
MutableGraphView* graph_view_ = nullptr;
|
||||
int mutation_counter_ = 0;
|
||||
|
@ -389,13 +389,15 @@ struct NodeViewDiff {
|
||||
std::set<int> controlling_inputs_to_remove;
|
||||
absl::flat_hash_map<string, AttrValue> attrs_to_add;
|
||||
absl::flat_hash_set<string> attrs_to_remove;
|
||||
AttrValueMap processed_attrs;
|
||||
// AttrValueMap constructor and destructor are very expensive, we will
|
||||
// initialize it lazily only if needed.
|
||||
absl::optional<AttrValueMap> processed_attrs;
|
||||
};
|
||||
|
||||
// Updates node name. If `name` is the same as the name in the original node,
|
||||
// the field will be cleared in the diff.
|
||||
template <typename GraphViewT>
|
||||
inline void UpdateName(NodeViewDiff<GraphViewT>* diff, absl::string_view name) {
|
||||
inline bool UpdateName(NodeViewDiff<GraphViewT>* diff, absl::string_view name) {
|
||||
if (diff->graph_view->GetNode(diff->node_index)->GetName() == name) {
|
||||
diff->name.clear();
|
||||
diff->update_name = false;
|
||||
@ -403,12 +405,13 @@ inline void UpdateName(NodeViewDiff<GraphViewT>* diff, absl::string_view name) {
|
||||
diff->name = string(name);
|
||||
diff->update_name = true;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
// Updates node op. If `op` is the same as the op in the original node, the
|
||||
// field will be cleared in the diff.
|
||||
template <typename GraphViewT>
|
||||
inline void UpdateOp(NodeViewDiff<GraphViewT>* diff, absl::string_view op) {
|
||||
inline bool UpdateOp(NodeViewDiff<GraphViewT>* diff, absl::string_view op) {
|
||||
if (diff->graph_view->GetNode(diff->node_index)->GetOp() == op) {
|
||||
diff->op.clear();
|
||||
diff->update_op = false;
|
||||
@ -416,12 +419,13 @@ inline void UpdateOp(NodeViewDiff<GraphViewT>* diff, absl::string_view op) {
|
||||
diff->op = string(op);
|
||||
diff->update_op = true;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
// Updates node device. If `device` is the same as the device in the original
|
||||
// node, the field will be cleared in the diff.
|
||||
template <typename GraphViewT>
|
||||
inline void UpdateDevice(NodeViewDiff<GraphViewT>* diff,
|
||||
inline bool UpdateDevice(NodeViewDiff<GraphViewT>* diff,
|
||||
absl::string_view device) {
|
||||
if (diff->graph_view->GetNode(diff->node_index)->GetDevice() == device) {
|
||||
diff->device.clear();
|
||||
@ -430,6 +434,7 @@ inline void UpdateDevice(NodeViewDiff<GraphViewT>* diff,
|
||||
diff->device = string(device);
|
||||
diff->update_device = true;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
// Adds or updates value in vector `v` at index `i`. This will also resize the
|
||||
@ -476,11 +481,11 @@ inline bool CheckNodeNameExists(
|
||||
// differs. If `index` is greater than or equal to the number of regular fanins,
|
||||
// `fanin` will be added beyond the end of regular fanins at `index`.
|
||||
template <typename GraphViewT>
|
||||
inline void AddOrUpdateRegularFanin(NodeViewDiff<GraphViewT>* diff, int index,
|
||||
inline bool AddOrUpdateRegularFanin(NodeViewDiff<GraphViewT>* diff, int index,
|
||||
const TensorId& fanin) {
|
||||
if (index < 0) {
|
||||
// Not a valid index for regular fanins.
|
||||
return;
|
||||
return false;
|
||||
}
|
||||
auto* node_view = diff->graph_view->GetNode(diff->node_index);
|
||||
const int num_regular_fanins = node_view->NumRegularFanins();
|
||||
@ -511,15 +516,16 @@ inline void AddOrUpdateRegularFanin(NodeViewDiff<GraphViewT>* diff, int index,
|
||||
++diff->num_regular_inputs_to_add;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
// Remove regular fanin at `index` of regular fanins. This can remove existing
|
||||
// fanins and updated/added fanins via AddOrUpdateRegularFanins.
|
||||
template <typename GraphViewT>
|
||||
inline void RemoveRegularFanin(NodeViewDiff<GraphViewT>* diff, int index) {
|
||||
inline bool RemoveRegularFanin(NodeViewDiff<GraphViewT>* diff, int index) {
|
||||
if (index < 0) {
|
||||
// Not a valid index for regular fanins.
|
||||
return;
|
||||
return false;
|
||||
}
|
||||
auto* node_view = diff->graph_view->GetNode(diff->node_index);
|
||||
const int num_regular_fanins = node_view->NumRegularFanins();
|
||||
@ -541,19 +547,20 @@ inline void RemoveRegularFanin(NodeViewDiff<GraphViewT>* diff, int index) {
|
||||
IsEmptyTensorId(diff->regular_inputs_to_add[relative_add_index])) {
|
||||
// At relative index, appended regular fanin was already marked for
|
||||
// removal.
|
||||
return;
|
||||
return false;
|
||||
}
|
||||
// Remove added fanin.
|
||||
diff->regular_inputs_to_add[relative_add_index] = EmptyTensorId();
|
||||
--diff->num_regular_inputs_to_add;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
// Adds controlling fanin. If the controlling fanin already exists in the
|
||||
// original node, it will be dedupped. If the controlling fanin is marked for
|
||||
// removal, this will reverse it.
|
||||
template <typename GraphViewT>
|
||||
inline void AddControllingFanin(NodeViewDiff<GraphViewT>* diff,
|
||||
inline bool AddControllingFanin(NodeViewDiff<GraphViewT>* diff,
|
||||
int control_index,
|
||||
absl::string_view fanin_node_name) {
|
||||
if (control_index == kMissingIndex) {
|
||||
@ -561,6 +568,7 @@ inline void AddControllingFanin(NodeViewDiff<GraphViewT>* diff,
|
||||
} else {
|
||||
diff->controlling_inputs_to_remove.erase(control_index);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
// Remove controlling fanin. If the controlling fanin does not exist in the
|
||||
@ -568,7 +576,7 @@ inline void AddControllingFanin(NodeViewDiff<GraphViewT>* diff,
|
||||
// in the diff, it will be removed. Otherwise the controlling fanin will be
|
||||
// marked for removal from the original node.
|
||||
template <typename GraphViewT>
|
||||
inline void RemoveControllingFanin(NodeViewDiff<GraphViewT>* diff,
|
||||
inline bool RemoveControllingFanin(NodeViewDiff<GraphViewT>* diff,
|
||||
int control_index,
|
||||
absl::string_view fanin_node_name) {
|
||||
if (control_index == kMissingIndex) {
|
||||
@ -576,28 +584,33 @@ inline void RemoveControllingFanin(NodeViewDiff<GraphViewT>* diff,
|
||||
} else {
|
||||
diff->controlling_inputs_to_remove.emplace(control_index);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
// Adds or updates an attribute by name. If an attribute exist in the original
|
||||
// node or diff (including those marked for removal), this will overwrite it.
|
||||
template <typename GraphViewT>
|
||||
inline void AddOrUpdateAttribute(NodeViewDiff<GraphViewT>* diff,
|
||||
inline bool AddOrUpdateAttribute(NodeViewDiff<GraphViewT>* diff,
|
||||
absl::string_view attr_name,
|
||||
const AttrValue& attr_value) {
|
||||
diff->attrs_to_remove.erase(attr_name);
|
||||
diff->attrs_to_add.empty() ? 0 : diff->attrs_to_remove.erase(attr_name);
|
||||
gtl::InsertOrUpdate(&diff->attrs_to_add, string(attr_name), attr_value);
|
||||
return true;
|
||||
}
|
||||
|
||||
// Removes an attribute by name. If an attribute exist in the original node or
|
||||
// diff, this will remove it.
|
||||
template <typename GraphViewT>
|
||||
inline void RemoveAttribute(NodeViewDiff<GraphViewT>* diff,
|
||||
inline bool RemoveAttribute(NodeViewDiff<GraphViewT>* diff,
|
||||
absl::string_view attr_name) {
|
||||
diff->attrs_to_add.erase(attr_name);
|
||||
const size_t num_erased =
|
||||
diff->attrs_to_add.empty() ? 0 : diff->attrs_to_add.erase(attr_name);
|
||||
auto* node_view = diff->graph_view->GetNode(diff->node_index);
|
||||
if (node_view->HasAttr(attr_name)) {
|
||||
diff->attrs_to_remove.emplace(attr_name);
|
||||
return true;
|
||||
}
|
||||
return num_erased > 0;
|
||||
}
|
||||
|
||||
// Removes trailing values in vector `v` for values equal to `value`.
|
||||
|
Loading…
Reference in New Issue
Block a user