Add the category enum to avoid string comparisons and enable the name scope line for JAX.
PiperOrigin-RevId: 304437327 Change-Id: Ife469db2caf0f93a0b7e4b9a271e0568fbf0ccf1
This commit is contained in:
parent
b16d24a342
commit
bc2cdb667d
|
@ -153,7 +153,7 @@ absl::flat_hash_map<int64, TfOp> CollectTfOpsFromHostThreadsXPlane(
|
||||||
// user-inserted TraceMe's have "unknown" type. We don't count them in
|
// user-inserted TraceMe's have "unknown" type. We don't count them in
|
||||||
// Tf-stats.
|
// Tf-stats.
|
||||||
TfOp tf_op = ParseTfOpFullname(metadata.name());
|
TfOp tf_op = ParseTfOpFullname(metadata.name());
|
||||||
if (!IsUnknownOp(tf_op.type)) {
|
if (tf_op.category != Category::kUnknown) {
|
||||||
tf_ops.try_emplace(metadata.id(), tf_op);
|
tf_ops.try_emplace(metadata.id(), tf_op);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -214,7 +214,7 @@ OpMetricsDb ConvertDeviceTraceXPlaneToOpMetricsDb(
|
||||||
if (tf_op_fullname.empty()) return;
|
if (tf_op_fullname.empty()) return;
|
||||||
TfOp tf_op = ParseTfOpFullname(tf_op_fullname);
|
TfOp tf_op = ParseTfOpFullname(tf_op_fullname);
|
||||||
TfOpRoofLineCostEstimator::OpRoofLineStats costs;
|
TfOpRoofLineCostEstimator::OpRoofLineStats costs;
|
||||||
if (tf_op.type != kUnknownOp) {
|
if (tf_op.category != Category::kUnknown) {
|
||||||
costs = op_level_cost_estimator.Predict(event);
|
costs = op_level_cost_estimator.Predict(event);
|
||||||
}
|
}
|
||||||
device_op_metrics_db_builder.EnterOp(
|
device_op_metrics_db_builder.EnterOp(
|
||||||
|
|
|
@ -147,7 +147,8 @@ void ProcessTfOpEvent(const XEventVisitor& event,
|
||||||
plane_builder->GetOrCreateStatMetadata(GetStatTypeStr(StatType::kGroupId))
|
plane_builder->GetOrCreateStatMetadata(GetStatTypeStr(StatType::kGroupId))
|
||||||
->id();
|
->id();
|
||||||
TfOp tf_op = ParseTfOpFullname(tf_op_full_name);
|
TfOp tf_op = ParseTfOpFullname(tf_op_full_name);
|
||||||
if (tf_op.is_tf_op) {
|
Category category = tf_op.category;
|
||||||
|
if (category == Category::kTensorFlow || category == Category::kJax) {
|
||||||
std::vector<XEvent> name_scope_event_per_level;
|
std::vector<XEvent> name_scope_event_per_level;
|
||||||
for (const auto& tf_name_scope : ParseTfNameScopes(tf_op)) {
|
for (const auto& tf_name_scope : ParseTfNameScopes(tf_op)) {
|
||||||
name_scope_event_per_level.push_back(CreateXEvent(
|
name_scope_event_per_level.push_back(CreateXEvent(
|
||||||
|
|
|
@ -47,14 +47,16 @@ TfOp ParseTfOpFullname(absl::string_view tf_op_fullname) {
|
||||||
// JAX op types have only lowercase letters and underscores.
|
// JAX op types have only lowercase letters and underscores.
|
||||||
static const LazyRE2 kJaxOpTypeRegEx = {"[a-z_]*"};
|
static const LazyRE2 kJaxOpTypeRegEx = {"[a-z_]*"};
|
||||||
|
|
||||||
TfOp tf_op = {tf_op_fullname, kUnknownOp, /*is_tf_op=*/false};
|
TfOp tf_op = {Category::kUnknown, tf_op_fullname, kUnknownOp};
|
||||||
std::vector<absl::string_view> parts =
|
std::vector<absl::string_view> parts =
|
||||||
absl::StrSplit(tf_op_fullname, absl::MaxSplits(':', 1));
|
absl::StrSplit(tf_op_fullname, absl::MaxSplits(':', 1));
|
||||||
if (parts.size() != 2) {
|
if (parts.size() != 2) {
|
||||||
// GPU-related Ops that need to be tracked.
|
// GPU-related Ops that need to be tracked.
|
||||||
if (absl::StartsWithIgnoreCase(tf_op_fullname, "MEMCPYHToD")) {
|
if (absl::StartsWithIgnoreCase(tf_op_fullname, "MEMCPYHToD")) {
|
||||||
|
tf_op.category = Category::kMemcpyHToD;
|
||||||
tf_op.type = kMemcpyHToDOp;
|
tf_op.type = kMemcpyHToDOp;
|
||||||
} else if (absl::StartsWithIgnoreCase(tf_op_fullname, "MEMCPYDToH")) {
|
} else if (absl::StartsWithIgnoreCase(tf_op_fullname, "MEMCPYDToH")) {
|
||||||
|
tf_op.category = Category::kMemcpyDToH;
|
||||||
tf_op.type = kMemcpyDToHOp;
|
tf_op.type = kMemcpyDToHOp;
|
||||||
}
|
}
|
||||||
// TODO(ckluk): Include the corresponding Ops on TPU.
|
// TODO(ckluk): Include the corresponding Ops on TPU.
|
||||||
|
@ -62,12 +64,13 @@ TfOp ParseTfOpFullname(absl::string_view tf_op_fullname) {
|
||||||
// Dataset Op names (e.g., Iterator::Batch::Map::TFRecord) do not follow the
|
// Dataset Op names (e.g., Iterator::Batch::Map::TFRecord) do not follow the
|
||||||
// format of TF Op names. But we still want to capture them for
|
// format of TF Op names. But we still want to capture them for
|
||||||
// input-pipeline analysis.
|
// input-pipeline analysis.
|
||||||
|
tf_op.category = Category::kTfData;
|
||||||
tf_op.type = kDatasetOp;
|
tf_op.type = kDatasetOp;
|
||||||
} else if (RE2::FullMatch(parts[1], *kTfOpTypeRegEx) &&
|
} else if (RE2::FullMatch(parts[1], *kTfOpTypeRegEx) &&
|
||||||
RE2::FullMatch(parts[0], *kTfOpNameRegEx)) { // TensorFlow
|
RE2::FullMatch(parts[0], *kTfOpNameRegEx)) { // TensorFlow
|
||||||
tf_op = {parts[0], parts[1], /*is_tf_op=*/true};
|
tf_op = {Category::kTensorFlow, parts[0], parts[1]};
|
||||||
} else if (RE2::FullMatch(parts[1], *kJaxOpTypeRegEx)) { // JAX
|
} else if (RE2::FullMatch(parts[1], *kJaxOpTypeRegEx)) { // JAX
|
||||||
tf_op = {parts[0], parts[1], /*is_tf_op=*/false};
|
tf_op = {Category::kJax, parts[0], parts[1]};
|
||||||
}
|
}
|
||||||
return tf_op;
|
return tf_op;
|
||||||
}
|
}
|
||||||
|
@ -81,10 +84,10 @@ std::vector<absl::string_view> ParseTfNameScopes(const TfOp& tf_op) {
|
||||||
|
|
||||||
std::string TfOpEventName(const TfOp& tf_op) {
|
std::string TfOpEventName(const TfOp& tf_op) {
|
||||||
std::string event_name;
|
std::string event_name;
|
||||||
if (tf_op.type == kUnknownOp) {
|
if (tf_op.category == Category::kUnknown) {
|
||||||
// Some TraceMe names contain trailing whitespace, remove it.
|
// Some TraceMe names contain trailing whitespace, remove it.
|
||||||
event_name = std::string(absl::StripTrailingAsciiWhitespace(tf_op.name));
|
event_name = std::string(absl::StripTrailingAsciiWhitespace(tf_op.name));
|
||||||
} else if (tf_op.type == kDatasetOp) {
|
} else if (tf_op.category == Category::kTfData) {
|
||||||
std::vector<absl::string_view> op_parts =
|
std::vector<absl::string_view> op_parts =
|
||||||
absl::StrSplit(tf_op.name, kSeparator);
|
absl::StrSplit(tf_op.name, kSeparator);
|
||||||
event_name = absl::StrCat(kIterator, kSeparator, op_parts.back());
|
event_name = absl::StrCat(kIterator, kSeparator, op_parts.back());
|
||||||
|
|
|
@ -31,11 +31,20 @@ ABSL_CONST_INIT extern const absl::string_view kDatasetOp;
|
||||||
ABSL_CONST_INIT extern const absl::string_view kMemcpyHToDOp;
|
ABSL_CONST_INIT extern const absl::string_view kMemcpyHToDOp;
|
||||||
ABSL_CONST_INIT extern const absl::string_view kMemcpyDToHOp;
|
ABSL_CONST_INIT extern const absl::string_view kMemcpyDToHOp;
|
||||||
|
|
||||||
|
enum class Category {
|
||||||
|
kTensorFlow,
|
||||||
|
kJax,
|
||||||
|
kTfData,
|
||||||
|
kMemcpyHToD,
|
||||||
|
kMemcpyDToH,
|
||||||
|
kUnknown,
|
||||||
|
};
|
||||||
|
|
||||||
// Breaks a TensorFlow op fullname into name and type.
|
// Breaks a TensorFlow op fullname into name and type.
|
||||||
struct TfOp {
|
struct TfOp {
|
||||||
|
Category category;
|
||||||
absl::string_view name;
|
absl::string_view name;
|
||||||
absl::string_view type;
|
absl::string_view type;
|
||||||
bool is_tf_op;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
TfOp ParseTfOpFullname(absl::string_view tf_op_fullname);
|
TfOp ParseTfOpFullname(absl::string_view tf_op_fullname);
|
||||||
|
@ -48,11 +57,6 @@ std::vector<absl::string_view> ParseTfNameScopes(const TfOp& tf_op);
|
||||||
std::string TfOpEventName(const TfOp& tf_op);
|
std::string TfOpEventName(const TfOp& tf_op);
|
||||||
std::string TfOpEventName(absl::string_view tf_op_fullname);
|
std::string TfOpEventName(absl::string_view tf_op_fullname);
|
||||||
|
|
||||||
// Returns true if the given name is not a TensorFlow op.
|
|
||||||
inline bool IsUnknownOp(absl::string_view tf_op_type) {
|
|
||||||
return tf_op_type == kUnknownOp;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Returns true if the given name is a TensorFlow Dataset Op.
|
// Returns true if the given name is a TensorFlow Dataset Op.
|
||||||
inline bool IsDatasetOp(absl::string_view tf_op_type) {
|
inline bool IsDatasetOp(absl::string_view tf_op_type) {
|
||||||
return tf_op_type == kDatasetOp;
|
return tf_op_type == kDatasetOp;
|
||||||
|
|
|
@ -24,6 +24,7 @@ namespace {
|
||||||
TEST(TfOpUtilsTest, TfOpTest) {
|
TEST(TfOpUtilsTest, TfOpTest) {
|
||||||
const absl::string_view kName = "OpName:OpType";
|
const absl::string_view kName = "OpName:OpType";
|
||||||
TfOp tf_op = ParseTfOpFullname(kName);
|
TfOp tf_op = ParseTfOpFullname(kName);
|
||||||
|
EXPECT_EQ(tf_op.category, Category::kTensorFlow);
|
||||||
EXPECT_EQ(tf_op.name, "OpName");
|
EXPECT_EQ(tf_op.name, "OpName");
|
||||||
EXPECT_EQ(tf_op.type, "OpType");
|
EXPECT_EQ(tf_op.type, "OpType");
|
||||||
EXPECT_EQ(TfOpEventName(kName), "OpType"); // type only
|
EXPECT_EQ(TfOpEventName(kName), "OpType"); // type only
|
||||||
|
@ -32,6 +33,7 @@ TEST(TfOpUtilsTest, TfOpTest) {
|
||||||
TEST(TfOpUtilsTest, InternalTfOpTest) {
|
TEST(TfOpUtilsTest, InternalTfOpTest) {
|
||||||
const absl::string_view kName = "OpName:_InternalOpType";
|
const absl::string_view kName = "OpName:_InternalOpType";
|
||||||
TfOp tf_op = ParseTfOpFullname(kName);
|
TfOp tf_op = ParseTfOpFullname(kName);
|
||||||
|
EXPECT_EQ(tf_op.category, Category::kTensorFlow);
|
||||||
EXPECT_EQ(tf_op.name, "OpName");
|
EXPECT_EQ(tf_op.name, "OpName");
|
||||||
EXPECT_EQ(tf_op.type, "_InternalOpType");
|
EXPECT_EQ(tf_op.type, "_InternalOpType");
|
||||||
EXPECT_EQ(TfOpEventName(kName), "_InternalOpType"); // type only
|
EXPECT_EQ(TfOpEventName(kName), "_InternalOpType"); // type only
|
||||||
|
@ -40,6 +42,7 @@ TEST(TfOpUtilsTest, InternalTfOpTest) {
|
||||||
TEST(TfOpUtilsTest, TfOpWithPathTest) {
|
TEST(TfOpUtilsTest, TfOpWithPathTest) {
|
||||||
const absl::string_view kName = "path/to/name:OpType";
|
const absl::string_view kName = "path/to/name:OpType";
|
||||||
TfOp tf_op = ParseTfOpFullname(kName);
|
TfOp tf_op = ParseTfOpFullname(kName);
|
||||||
|
EXPECT_EQ(tf_op.category, Category::kTensorFlow);
|
||||||
EXPECT_EQ(tf_op.name, "path/to/name");
|
EXPECT_EQ(tf_op.name, "path/to/name");
|
||||||
EXPECT_EQ(tf_op.type, "OpType");
|
EXPECT_EQ(tf_op.type, "OpType");
|
||||||
EXPECT_EQ(TfOpEventName(kName), "OpType"); // type only
|
EXPECT_EQ(TfOpEventName(kName), "OpType"); // type only
|
||||||
|
@ -48,24 +51,27 @@ TEST(TfOpUtilsTest, TfOpWithPathTest) {
|
||||||
TEST(TfOpUtilsTest, ShortDatasetOpTest) {
|
TEST(TfOpUtilsTest, ShortDatasetOpTest) {
|
||||||
const absl::string_view kName = "Iterator::Batch";
|
const absl::string_view kName = "Iterator::Batch";
|
||||||
TfOp tf_op = ParseTfOpFullname(kName);
|
TfOp tf_op = ParseTfOpFullname(kName);
|
||||||
|
EXPECT_EQ(tf_op.category, Category::kTfData);
|
||||||
EXPECT_EQ(tf_op.name, kName);
|
EXPECT_EQ(tf_op.name, kName);
|
||||||
EXPECT_TRUE(IsDatasetOp(tf_op.type));
|
EXPECT_EQ(tf_op.type, kDatasetOp);
|
||||||
EXPECT_EQ(TfOpEventName(kName), kName);
|
EXPECT_EQ(TfOpEventName(kName), kName);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(TfOpUtilsTest, LongDatasetOpTest) {
|
TEST(TfOpUtilsTest, LongDatasetOpTest) {
|
||||||
const absl::string_view kName = "Iterator::Batch::Map::TfRecord";
|
const absl::string_view kName = "Iterator::Batch::Map::TfRecord";
|
||||||
TfOp tf_op = ParseTfOpFullname(kName);
|
TfOp tf_op = ParseTfOpFullname(kName);
|
||||||
|
EXPECT_EQ(tf_op.category, Category::kTfData);
|
||||||
EXPECT_EQ(tf_op.name, kName);
|
EXPECT_EQ(tf_op.name, kName);
|
||||||
EXPECT_TRUE(IsDatasetOp(tf_op.type));
|
EXPECT_EQ(tf_op.type, kDatasetOp);
|
||||||
EXPECT_EQ(TfOpEventName(kName), "Iterator::TfRecord"); // shorter name
|
EXPECT_EQ(TfOpEventName(kName), "Iterator::TfRecord"); // shorter name
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(TfOpUtilsTest, TraceMeTest) {
|
TEST(TfOpUtilsTest, TraceMeTest) {
|
||||||
const absl::string_view kName = "MyTraceMe";
|
const absl::string_view kName = "MyTraceMe";
|
||||||
TfOp tf_op = ParseTfOpFullname(kName);
|
TfOp tf_op = ParseTfOpFullname(kName);
|
||||||
|
EXPECT_EQ(tf_op.category, Category::kUnknown);
|
||||||
EXPECT_EQ(tf_op.name, kName);
|
EXPECT_EQ(tf_op.name, kName);
|
||||||
EXPECT_TRUE(IsUnknownOp(tf_op.type));
|
EXPECT_EQ(tf_op.type, kUnknownOp);
|
||||||
EXPECT_EQ(TfOpEventName(kName), kName);
|
EXPECT_EQ(TfOpEventName(kName), kName);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -73,16 +79,18 @@ TEST(TfOpUtilsTest, TraceMeWithColonTest) {
|
||||||
// "12345" is not a valid op type.
|
// "12345" is not a valid op type.
|
||||||
const absl::string_view kName = "RunStep/Server:54635";
|
const absl::string_view kName = "RunStep/Server:54635";
|
||||||
TfOp tf_op = ParseTfOpFullname(kName);
|
TfOp tf_op = ParseTfOpFullname(kName);
|
||||||
|
EXPECT_EQ(tf_op.category, Category::kUnknown);
|
||||||
EXPECT_EQ(tf_op.name, kName);
|
EXPECT_EQ(tf_op.name, kName);
|
||||||
EXPECT_TRUE(IsUnknownOp(tf_op.type));
|
EXPECT_EQ(tf_op.type, kUnknownOp);
|
||||||
EXPECT_EQ(TfOpEventName(kName), kName);
|
EXPECT_EQ(TfOpEventName(kName), kName);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(TfOpUtilsTest, TraceMeWithDoubleColonTest) {
|
TEST(TfOpUtilsTest, TraceMeWithDoubleColonTest) {
|
||||||
const absl::string_view kName = "XLA::StartProgram";
|
const absl::string_view kName = "XLA::StartProgram";
|
||||||
TfOp tf_op = ParseTfOpFullname(kName);
|
TfOp tf_op = ParseTfOpFullname(kName);
|
||||||
|
EXPECT_EQ(tf_op.category, Category::kUnknown);
|
||||||
EXPECT_EQ(tf_op.name, kName);
|
EXPECT_EQ(tf_op.name, kName);
|
||||||
EXPECT_TRUE(IsUnknownOp(tf_op.type));
|
EXPECT_EQ(tf_op.type, kUnknownOp);
|
||||||
EXPECT_EQ(TfOpEventName(kName), kName);
|
EXPECT_EQ(TfOpEventName(kName), kName);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -90,11 +98,39 @@ TEST(TfOpUtilsTest, TraceMeWithTrailingWhitespaceTest) {
|
||||||
const absl::string_view kName = "SessionRun ";
|
const absl::string_view kName = "SessionRun ";
|
||||||
const absl::string_view kNameTrimmed = "SessionRun";
|
const absl::string_view kNameTrimmed = "SessionRun";
|
||||||
TfOp tf_op = ParseTfOpFullname(kName);
|
TfOp tf_op = ParseTfOpFullname(kName);
|
||||||
|
EXPECT_EQ(tf_op.category, Category::kUnknown);
|
||||||
EXPECT_EQ(tf_op.name, kName);
|
EXPECT_EQ(tf_op.name, kName);
|
||||||
EXPECT_TRUE(IsUnknownOp(tf_op.type));
|
EXPECT_EQ(tf_op.type, kUnknownOp);
|
||||||
EXPECT_EQ(TfOpEventName(kName), kNameTrimmed);
|
EXPECT_EQ(TfOpEventName(kName), kNameTrimmed);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(TfOpUtilsTest, MemcpyHToDTest) {
|
||||||
|
const absl::string_view kName = "MemcpyHToD";
|
||||||
|
TfOp tf_op = ParseTfOpFullname(kName);
|
||||||
|
EXPECT_EQ(tf_op.category, Category::kMemcpyHToD);
|
||||||
|
EXPECT_EQ(tf_op.name, kName);
|
||||||
|
EXPECT_EQ(tf_op.type, kMemcpyHToDOp);
|
||||||
|
EXPECT_EQ(TfOpEventName(kName), kName);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(TfOpUtilsTest, MemcpyDToHTest) {
|
||||||
|
const absl::string_view kName = "MemcpyDToH";
|
||||||
|
TfOp tf_op = ParseTfOpFullname(kName);
|
||||||
|
EXPECT_EQ(tf_op.category, Category::kMemcpyDToH);
|
||||||
|
EXPECT_EQ(tf_op.name, kName);
|
||||||
|
EXPECT_EQ(tf_op.type, kMemcpyDToHOp);
|
||||||
|
EXPECT_EQ(TfOpEventName(kName), kName);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(TfOpUtilsTest, JaxOpTest) {
|
||||||
|
const absl::string_view kName = "op_name:op_type";
|
||||||
|
TfOp tf_op = ParseTfOpFullname(kName);
|
||||||
|
EXPECT_EQ(tf_op.category, Category::kJax);
|
||||||
|
EXPECT_EQ(tf_op.name, "op_name");
|
||||||
|
EXPECT_EQ(tf_op.type, "op_type");
|
||||||
|
EXPECT_EQ(TfOpEventName(kName), "op_type");
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
} // namespace profiler
|
} // namespace profiler
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
|
Loading…
Reference in New Issue