Add the category enum to avoid string comparisons and enable the name scope line for JAX.
PiperOrigin-RevId: 304437327 Change-Id: Ife469db2caf0f93a0b7e4b9a271e0568fbf0ccf1
This commit is contained in:
parent
b16d24a342
commit
bc2cdb667d
|
@ -153,7 +153,7 @@ absl::flat_hash_map<int64, TfOp> CollectTfOpsFromHostThreadsXPlane(
|
|||
// user-inserted TraceMe's have "unknown" type. We don't count them in
|
||||
// Tf-stats.
|
||||
TfOp tf_op = ParseTfOpFullname(metadata.name());
|
||||
if (!IsUnknownOp(tf_op.type)) {
|
||||
if (tf_op.category != Category::kUnknown) {
|
||||
tf_ops.try_emplace(metadata.id(), tf_op);
|
||||
}
|
||||
}
|
||||
|
@ -214,7 +214,7 @@ OpMetricsDb ConvertDeviceTraceXPlaneToOpMetricsDb(
|
|||
if (tf_op_fullname.empty()) return;
|
||||
TfOp tf_op = ParseTfOpFullname(tf_op_fullname);
|
||||
TfOpRoofLineCostEstimator::OpRoofLineStats costs;
|
||||
if (tf_op.type != kUnknownOp) {
|
||||
if (tf_op.category != Category::kUnknown) {
|
||||
costs = op_level_cost_estimator.Predict(event);
|
||||
}
|
||||
device_op_metrics_db_builder.EnterOp(
|
||||
|
|
|
@ -147,7 +147,8 @@ void ProcessTfOpEvent(const XEventVisitor& event,
|
|||
plane_builder->GetOrCreateStatMetadata(GetStatTypeStr(StatType::kGroupId))
|
||||
->id();
|
||||
TfOp tf_op = ParseTfOpFullname(tf_op_full_name);
|
||||
if (tf_op.is_tf_op) {
|
||||
Category category = tf_op.category;
|
||||
if (category == Category::kTensorFlow || category == Category::kJax) {
|
||||
std::vector<XEvent> name_scope_event_per_level;
|
||||
for (const auto& tf_name_scope : ParseTfNameScopes(tf_op)) {
|
||||
name_scope_event_per_level.push_back(CreateXEvent(
|
||||
|
|
|
@ -47,14 +47,16 @@ TfOp ParseTfOpFullname(absl::string_view tf_op_fullname) {
|
|||
// JAX op types have only lowercase letters and underscores.
|
||||
static const LazyRE2 kJaxOpTypeRegEx = {"[a-z_]*"};
|
||||
|
||||
TfOp tf_op = {tf_op_fullname, kUnknownOp, /*is_tf_op=*/false};
|
||||
TfOp tf_op = {Category::kUnknown, tf_op_fullname, kUnknownOp};
|
||||
std::vector<absl::string_view> parts =
|
||||
absl::StrSplit(tf_op_fullname, absl::MaxSplits(':', 1));
|
||||
if (parts.size() != 2) {
|
||||
// GPU-related Ops that need to be tracked.
|
||||
if (absl::StartsWithIgnoreCase(tf_op_fullname, "MEMCPYHToD")) {
|
||||
tf_op.category = Category::kMemcpyHToD;
|
||||
tf_op.type = kMemcpyHToDOp;
|
||||
} else if (absl::StartsWithIgnoreCase(tf_op_fullname, "MEMCPYDToH")) {
|
||||
tf_op.category = Category::kMemcpyDToH;
|
||||
tf_op.type = kMemcpyDToHOp;
|
||||
}
|
||||
// TODO(ckluk): Include the corresponding Ops on TPU.
|
||||
|
@ -62,12 +64,13 @@ TfOp ParseTfOpFullname(absl::string_view tf_op_fullname) {
|
|||
// Dataset Op names (e.g., Iterator::Batch::Map::TFRecord) do not follow the
|
||||
// format of TF Op names. But we still want to capture them for
|
||||
// input-pipeline analysis.
|
||||
tf_op.category = Category::kTfData;
|
||||
tf_op.type = kDatasetOp;
|
||||
} else if (RE2::FullMatch(parts[1], *kTfOpTypeRegEx) &&
|
||||
RE2::FullMatch(parts[0], *kTfOpNameRegEx)) { // TensorFlow
|
||||
tf_op = {parts[0], parts[1], /*is_tf_op=*/true};
|
||||
tf_op = {Category::kTensorFlow, parts[0], parts[1]};
|
||||
} else if (RE2::FullMatch(parts[1], *kJaxOpTypeRegEx)) { // JAX
|
||||
tf_op = {parts[0], parts[1], /*is_tf_op=*/false};
|
||||
tf_op = {Category::kJax, parts[0], parts[1]};
|
||||
}
|
||||
return tf_op;
|
||||
}
|
||||
|
@ -81,10 +84,10 @@ std::vector<absl::string_view> ParseTfNameScopes(const TfOp& tf_op) {
|
|||
|
||||
std::string TfOpEventName(const TfOp& tf_op) {
|
||||
std::string event_name;
|
||||
if (tf_op.type == kUnknownOp) {
|
||||
if (tf_op.category == Category::kUnknown) {
|
||||
// Some TraceMe names contain trailing whitespace, remove it.
|
||||
event_name = std::string(absl::StripTrailingAsciiWhitespace(tf_op.name));
|
||||
} else if (tf_op.type == kDatasetOp) {
|
||||
} else if (tf_op.category == Category::kTfData) {
|
||||
std::vector<absl::string_view> op_parts =
|
||||
absl::StrSplit(tf_op.name, kSeparator);
|
||||
event_name = absl::StrCat(kIterator, kSeparator, op_parts.back());
|
||||
|
|
|
@ -31,11 +31,20 @@ ABSL_CONST_INIT extern const absl::string_view kDatasetOp;
|
|||
ABSL_CONST_INIT extern const absl::string_view kMemcpyHToDOp;
|
||||
ABSL_CONST_INIT extern const absl::string_view kMemcpyDToHOp;
|
||||
|
||||
enum class Category {
|
||||
kTensorFlow,
|
||||
kJax,
|
||||
kTfData,
|
||||
kMemcpyHToD,
|
||||
kMemcpyDToH,
|
||||
kUnknown,
|
||||
};
|
||||
|
||||
// Breaks a TensorFlow op fullname into name and type.
|
||||
struct TfOp {
|
||||
Category category;
|
||||
absl::string_view name;
|
||||
absl::string_view type;
|
||||
bool is_tf_op;
|
||||
};
|
||||
|
||||
TfOp ParseTfOpFullname(absl::string_view tf_op_fullname);
|
||||
|
@ -48,11 +57,6 @@ std::vector<absl::string_view> ParseTfNameScopes(const TfOp& tf_op);
|
|||
std::string TfOpEventName(const TfOp& tf_op);
|
||||
std::string TfOpEventName(absl::string_view tf_op_fullname);
|
||||
|
||||
// Returns true if the given name is not a TensorFlow op.
|
||||
inline bool IsUnknownOp(absl::string_view tf_op_type) {
|
||||
return tf_op_type == kUnknownOp;
|
||||
}
|
||||
|
||||
// Returns true if the given name is a TensorFlow Dataset Op.
|
||||
inline bool IsDatasetOp(absl::string_view tf_op_type) {
|
||||
return tf_op_type == kDatasetOp;
|
||||
|
|
|
@ -24,6 +24,7 @@ namespace {
|
|||
TEST(TfOpUtilsTest, TfOpTest) {
|
||||
const absl::string_view kName = "OpName:OpType";
|
||||
TfOp tf_op = ParseTfOpFullname(kName);
|
||||
EXPECT_EQ(tf_op.category, Category::kTensorFlow);
|
||||
EXPECT_EQ(tf_op.name, "OpName");
|
||||
EXPECT_EQ(tf_op.type, "OpType");
|
||||
EXPECT_EQ(TfOpEventName(kName), "OpType"); // type only
|
||||
|
@ -32,6 +33,7 @@ TEST(TfOpUtilsTest, TfOpTest) {
|
|||
TEST(TfOpUtilsTest, InternalTfOpTest) {
|
||||
const absl::string_view kName = "OpName:_InternalOpType";
|
||||
TfOp tf_op = ParseTfOpFullname(kName);
|
||||
EXPECT_EQ(tf_op.category, Category::kTensorFlow);
|
||||
EXPECT_EQ(tf_op.name, "OpName");
|
||||
EXPECT_EQ(tf_op.type, "_InternalOpType");
|
||||
EXPECT_EQ(TfOpEventName(kName), "_InternalOpType"); // type only
|
||||
|
@ -40,6 +42,7 @@ TEST(TfOpUtilsTest, InternalTfOpTest) {
|
|||
TEST(TfOpUtilsTest, TfOpWithPathTest) {
|
||||
const absl::string_view kName = "path/to/name:OpType";
|
||||
TfOp tf_op = ParseTfOpFullname(kName);
|
||||
EXPECT_EQ(tf_op.category, Category::kTensorFlow);
|
||||
EXPECT_EQ(tf_op.name, "path/to/name");
|
||||
EXPECT_EQ(tf_op.type, "OpType");
|
||||
EXPECT_EQ(TfOpEventName(kName), "OpType"); // type only
|
||||
|
@ -48,24 +51,27 @@ TEST(TfOpUtilsTest, TfOpWithPathTest) {
|
|||
TEST(TfOpUtilsTest, ShortDatasetOpTest) {
|
||||
const absl::string_view kName = "Iterator::Batch";
|
||||
TfOp tf_op = ParseTfOpFullname(kName);
|
||||
EXPECT_EQ(tf_op.category, Category::kTfData);
|
||||
EXPECT_EQ(tf_op.name, kName);
|
||||
EXPECT_TRUE(IsDatasetOp(tf_op.type));
|
||||
EXPECT_EQ(tf_op.type, kDatasetOp);
|
||||
EXPECT_EQ(TfOpEventName(kName), kName);
|
||||
}
|
||||
|
||||
TEST(TfOpUtilsTest, LongDatasetOpTest) {
|
||||
const absl::string_view kName = "Iterator::Batch::Map::TfRecord";
|
||||
TfOp tf_op = ParseTfOpFullname(kName);
|
||||
EXPECT_EQ(tf_op.category, Category::kTfData);
|
||||
EXPECT_EQ(tf_op.name, kName);
|
||||
EXPECT_TRUE(IsDatasetOp(tf_op.type));
|
||||
EXPECT_EQ(tf_op.type, kDatasetOp);
|
||||
EXPECT_EQ(TfOpEventName(kName), "Iterator::TfRecord"); // shorter name
|
||||
}
|
||||
|
||||
TEST(TfOpUtilsTest, TraceMeTest) {
|
||||
const absl::string_view kName = "MyTraceMe";
|
||||
TfOp tf_op = ParseTfOpFullname(kName);
|
||||
EXPECT_EQ(tf_op.category, Category::kUnknown);
|
||||
EXPECT_EQ(tf_op.name, kName);
|
||||
EXPECT_TRUE(IsUnknownOp(tf_op.type));
|
||||
EXPECT_EQ(tf_op.type, kUnknownOp);
|
||||
EXPECT_EQ(TfOpEventName(kName), kName);
|
||||
}
|
||||
|
||||
|
@ -73,16 +79,18 @@ TEST(TfOpUtilsTest, TraceMeWithColonTest) {
|
|||
// "12345" is not a valid op type.
|
||||
const absl::string_view kName = "RunStep/Server:54635";
|
||||
TfOp tf_op = ParseTfOpFullname(kName);
|
||||
EXPECT_EQ(tf_op.category, Category::kUnknown);
|
||||
EXPECT_EQ(tf_op.name, kName);
|
||||
EXPECT_TRUE(IsUnknownOp(tf_op.type));
|
||||
EXPECT_EQ(tf_op.type, kUnknownOp);
|
||||
EXPECT_EQ(TfOpEventName(kName), kName);
|
||||
}
|
||||
|
||||
TEST(TfOpUtilsTest, TraceMeWithDoubleColonTest) {
|
||||
const absl::string_view kName = "XLA::StartProgram";
|
||||
TfOp tf_op = ParseTfOpFullname(kName);
|
||||
EXPECT_EQ(tf_op.category, Category::kUnknown);
|
||||
EXPECT_EQ(tf_op.name, kName);
|
||||
EXPECT_TRUE(IsUnknownOp(tf_op.type));
|
||||
EXPECT_EQ(tf_op.type, kUnknownOp);
|
||||
EXPECT_EQ(TfOpEventName(kName), kName);
|
||||
}
|
||||
|
||||
|
@ -90,11 +98,39 @@ TEST(TfOpUtilsTest, TraceMeWithTrailingWhitespaceTest) {
|
|||
const absl::string_view kName = "SessionRun ";
|
||||
const absl::string_view kNameTrimmed = "SessionRun";
|
||||
TfOp tf_op = ParseTfOpFullname(kName);
|
||||
EXPECT_EQ(tf_op.category, Category::kUnknown);
|
||||
EXPECT_EQ(tf_op.name, kName);
|
||||
EXPECT_TRUE(IsUnknownOp(tf_op.type));
|
||||
EXPECT_EQ(tf_op.type, kUnknownOp);
|
||||
EXPECT_EQ(TfOpEventName(kName), kNameTrimmed);
|
||||
}
|
||||
|
||||
TEST(TfOpUtilsTest, MemcpyHToDTest) {
|
||||
const absl::string_view kName = "MemcpyHToD";
|
||||
TfOp tf_op = ParseTfOpFullname(kName);
|
||||
EXPECT_EQ(tf_op.category, Category::kMemcpyHToD);
|
||||
EXPECT_EQ(tf_op.name, kName);
|
||||
EXPECT_EQ(tf_op.type, kMemcpyHToDOp);
|
||||
EXPECT_EQ(TfOpEventName(kName), kName);
|
||||
}
|
||||
|
||||
TEST(TfOpUtilsTest, MemcpyDToHTest) {
|
||||
const absl::string_view kName = "MemcpyDToH";
|
||||
TfOp tf_op = ParseTfOpFullname(kName);
|
||||
EXPECT_EQ(tf_op.category, Category::kMemcpyDToH);
|
||||
EXPECT_EQ(tf_op.name, kName);
|
||||
EXPECT_EQ(tf_op.type, kMemcpyDToHOp);
|
||||
EXPECT_EQ(TfOpEventName(kName), kName);
|
||||
}
|
||||
|
||||
TEST(TfOpUtilsTest, JaxOpTest) {
|
||||
const absl::string_view kName = "op_name:op_type";
|
||||
TfOp tf_op = ParseTfOpFullname(kName);
|
||||
EXPECT_EQ(tf_op.category, Category::kJax);
|
||||
EXPECT_EQ(tf_op.name, "op_name");
|
||||
EXPECT_EQ(tf_op.type, "op_type");
|
||||
EXPECT_EQ(TfOpEventName(kName), "op_type");
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace profiler
|
||||
} // namespace tensorflow
|
||||
|
|
Loading…
Reference in New Issue