Use flat_hash_set to avoid duplicate group_ids.

PiperOrigin-RevId: 332517806
Change-Id: Ia1db93f8494ed5b5d68e3d0b2317f38b007a2cbf
This commit is contained in:
Jiho Choi 2020-09-18 13:38:29 -07:00 committed by TensorFlower Gardener
parent c1e513b04a
commit ec98fee0c3
2 changed files with 5 additions and 10 deletions

View File

@ -26,23 +26,17 @@ limitations under the License.
#include <vector>
#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/strings/match.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "absl/strings/string_view.h"
#include "absl/types/optional.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/profiler/lib/connected_traceme.h"
#include "tensorflow/core/profiler/protobuf/xplane.pb.h"
#include "tensorflow/core/profiler/utils/tf_op_utils.h"
#include "tensorflow/core/profiler/utils/tf_xplane_visitor.h"
#include "tensorflow/core/profiler/utils/xplane_builder.h"
#include "tensorflow/core/profiler/utils/xplane_schema.h"
#include "tensorflow/core/profiler/utils/xplane_utils.h"
#include "tensorflow/core/profiler/utils/xplane_visitor.h"
namespace tensorflow {
namespace profiler {
@ -427,8 +421,8 @@ void EventNode::PropagateGroupId(int64 group_id,
absl::optional<int64> child_group_id = child->GetGroupId();
if (child_group_id.has_value()) {
if (*child_group_id != group_id) {
(*group_metadata_map)[group_id].children.push_back(*child_group_id);
(*group_metadata_map)[*child_group_id].parents.push_back(group_id);
(*group_metadata_map)[group_id].children.insert(*child_group_id);
(*group_metadata_map)[*child_group_id].parents.insert(group_id);
}
// Stop propagation if it already belongs to a group. It may have been
// grouped by another root.

View File

@ -22,6 +22,7 @@ limitations under the License.
#include <vector>
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/strings/string_view.h"
#include "absl/types/optional.h"
#include "tensorflow/core/platform/logging.h"
@ -52,8 +53,8 @@ struct ContextInfo {
struct GroupMetadata {
std::string name;
std::string model_id; // inference only.
std::vector<int64> parents;
std::vector<int64> children;
absl::flat_hash_set<int64> parents;
absl::flat_hash_set<int64> children;
};
using GroupMetadataMap = absl::flat_hash_map<int64 /*group_id*/, GroupMetadata>;