Use flat_hash_set to avoid duplicate group_ids.

PiperOrigin-RevId: 332517806
Change-Id: Ia1db93f8494ed5b5d68e3d0b2317f38b007a2cbf
This commit is contained in:
Jiho Choi 2020-09-18 13:38:29 -07:00 committed by TensorFlower Gardener
parent c1e513b04a
commit ec98fee0c3
2 changed files with 5 additions and 10 deletions

View File

@ -26,23 +26,17 @@ limitations under the License.
#include <vector> #include <vector>
#include "absl/algorithm/container.h" #include "absl/algorithm/container.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/strings/match.h" #include "absl/strings/match.h"
#include "absl/strings/str_cat.h" #include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h" #include "absl/strings/str_join.h"
#include "absl/strings/string_view.h"
#include "absl/types/optional.h"
#include "tensorflow/core/lib/gtl/map_util.h" #include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/platform/types.h" #include "tensorflow/core/platform/types.h"
#include "tensorflow/core/profiler/lib/connected_traceme.h" #include "tensorflow/core/profiler/lib/connected_traceme.h"
#include "tensorflow/core/profiler/protobuf/xplane.pb.h"
#include "tensorflow/core/profiler/utils/tf_op_utils.h" #include "tensorflow/core/profiler/utils/tf_op_utils.h"
#include "tensorflow/core/profiler/utils/tf_xplane_visitor.h" #include "tensorflow/core/profiler/utils/tf_xplane_visitor.h"
#include "tensorflow/core/profiler/utils/xplane_builder.h" #include "tensorflow/core/profiler/utils/xplane_builder.h"
#include "tensorflow/core/profiler/utils/xplane_schema.h" #include "tensorflow/core/profiler/utils/xplane_schema.h"
#include "tensorflow/core/profiler/utils/xplane_utils.h" #include "tensorflow/core/profiler/utils/xplane_utils.h"
#include "tensorflow/core/profiler/utils/xplane_visitor.h"
namespace tensorflow { namespace tensorflow {
namespace profiler { namespace profiler {
@ -427,8 +421,8 @@ void EventNode::PropagateGroupId(int64 group_id,
absl::optional<int64> child_group_id = child->GetGroupId(); absl::optional<int64> child_group_id = child->GetGroupId();
if (child_group_id.has_value()) { if (child_group_id.has_value()) {
if (*child_group_id != group_id) { if (*child_group_id != group_id) {
(*group_metadata_map)[group_id].children.push_back(*child_group_id); (*group_metadata_map)[group_id].children.insert(*child_group_id);
(*group_metadata_map)[*child_group_id].parents.push_back(group_id); (*group_metadata_map)[*child_group_id].parents.insert(group_id);
} }
// Stop propagation if it already belongs to a group. It may have been // Stop propagation if it already belongs to a group. It may have been
// grouped by another root. // grouped by another root.

View File

@ -22,6 +22,7 @@ limitations under the License.
#include <vector> #include <vector>
#include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/strings/string_view.h" #include "absl/strings/string_view.h"
#include "absl/types/optional.h" #include "absl/types/optional.h"
#include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/logging.h"
@ -52,8 +53,8 @@ struct ContextInfo {
struct GroupMetadata { struct GroupMetadata {
std::string name; std::string name;
std::string model_id; // inference only. std::string model_id; // inference only.
std::vector<int64> parents; absl::flat_hash_set<int64> parents;
std::vector<int64> children; absl::flat_hash_set<int64> children;
}; };
using GroupMetadataMap = absl::flat_hash_map<int64 /*group_id*/, GroupMetadata>; using GroupMetadataMap = absl::flat_hash_map<int64 /*group_id*/, GroupMetadata>;