Add the category enum to avoid string comparisons and enable the name scope line for JAX.

PiperOrigin-RevId: 304437327
Change-Id: Ife469db2caf0f93a0b7e4b9a271e0568fbf0ccf1
This commit is contained in:
Jiho Choi 2020-04-02 11:09:08 -07:00 committed by TensorFlower Gardener
parent b16d24a342
commit bc2cdb667d
5 changed files with 64 additions and 20 deletions

View File

@ -153,7 +153,7 @@ absl::flat_hash_map<int64, TfOp> CollectTfOpsFromHostThreadsXPlane(
// user-inserted TraceMe's have "unknown" type. We don't count them in
// Tf-stats.
TfOp tf_op = ParseTfOpFullname(metadata.name());
if (!IsUnknownOp(tf_op.type)) {
if (tf_op.category != Category::kUnknown) {
tf_ops.try_emplace(metadata.id(), tf_op);
}
}
@ -214,7 +214,7 @@ OpMetricsDb ConvertDeviceTraceXPlaneToOpMetricsDb(
if (tf_op_fullname.empty()) return;
TfOp tf_op = ParseTfOpFullname(tf_op_fullname);
TfOpRoofLineCostEstimator::OpRoofLineStats costs;
if (tf_op.type != kUnknownOp) {
if (tf_op.category != Category::kUnknown) {
costs = op_level_cost_estimator.Predict(event);
}
device_op_metrics_db_builder.EnterOp(

View File

@ -147,7 +147,8 @@ void ProcessTfOpEvent(const XEventVisitor& event,
plane_builder->GetOrCreateStatMetadata(GetStatTypeStr(StatType::kGroupId))
->id();
TfOp tf_op = ParseTfOpFullname(tf_op_full_name);
if (tf_op.is_tf_op) {
Category category = tf_op.category;
if (category == Category::kTensorFlow || category == Category::kJax) {
std::vector<XEvent> name_scope_event_per_level;
for (const auto& tf_name_scope : ParseTfNameScopes(tf_op)) {
name_scope_event_per_level.push_back(CreateXEvent(

View File

@ -47,14 +47,16 @@ TfOp ParseTfOpFullname(absl::string_view tf_op_fullname) {
// JAX op types have only lowercase letters and underscores.
static const LazyRE2 kJaxOpTypeRegEx = {"[a-z_]*"};
TfOp tf_op = {tf_op_fullname, kUnknownOp, /*is_tf_op=*/false};
TfOp tf_op = {Category::kUnknown, tf_op_fullname, kUnknownOp};
std::vector<absl::string_view> parts =
absl::StrSplit(tf_op_fullname, absl::MaxSplits(':', 1));
if (parts.size() != 2) {
// GPU-related Ops that need to be tracked.
if (absl::StartsWithIgnoreCase(tf_op_fullname, "MEMCPYHToD")) {
tf_op.category = Category::kMemcpyHToD;
tf_op.type = kMemcpyHToDOp;
} else if (absl::StartsWithIgnoreCase(tf_op_fullname, "MEMCPYDToH")) {
tf_op.category = Category::kMemcpyDToH;
tf_op.type = kMemcpyDToHOp;
}
// TODO(ckluk): Include the corresponding Ops on TPU.
@ -62,12 +64,13 @@ TfOp ParseTfOpFullname(absl::string_view tf_op_fullname) {
// Dataset Op names (e.g., Iterator::Batch::Map::TFRecord) do not follow the
// format of TF Op names. But we still want to capture them for
// input-pipeline analysis.
tf_op.category = Category::kTfData;
tf_op.type = kDatasetOp;
} else if (RE2::FullMatch(parts[1], *kTfOpTypeRegEx) &&
RE2::FullMatch(parts[0], *kTfOpNameRegEx)) { // TensorFlow
tf_op = {parts[0], parts[1], /*is_tf_op=*/true};
tf_op = {Category::kTensorFlow, parts[0], parts[1]};
} else if (RE2::FullMatch(parts[1], *kJaxOpTypeRegEx)) { // JAX
tf_op = {parts[0], parts[1], /*is_tf_op=*/false};
tf_op = {Category::kJax, parts[0], parts[1]};
}
return tf_op;
}
@ -81,10 +84,10 @@ std::vector<absl::string_view> ParseTfNameScopes(const TfOp& tf_op) {
std::string TfOpEventName(const TfOp& tf_op) {
std::string event_name;
if (tf_op.type == kUnknownOp) {
if (tf_op.category == Category::kUnknown) {
// Some TraceMe names contain trailing whitespace, remove it.
event_name = std::string(absl::StripTrailingAsciiWhitespace(tf_op.name));
} else if (tf_op.type == kDatasetOp) {
} else if (tf_op.category == Category::kTfData) {
std::vector<absl::string_view> op_parts =
absl::StrSplit(tf_op.name, kSeparator);
event_name = absl::StrCat(kIterator, kSeparator, op_parts.back());

View File

@ -31,11 +31,20 @@ ABSL_CONST_INIT extern const absl::string_view kDatasetOp;
ABSL_CONST_INIT extern const absl::string_view kMemcpyHToDOp;
ABSL_CONST_INIT extern const absl::string_view kMemcpyDToHOp;
// Coarse classification of an op fullname, assigned by ParseTfOpFullname so
// that callers can branch on an enum instead of comparing type strings.
enum class Category {
// Name and type both match the TensorFlow op name/type patterns.
kTensorFlow,
// Not TensorFlow, and the type has only lowercase letters and underscores.
kJax,
// Dataset op names (e.g., Iterator::Batch::Map::TFRecord).
kTfData,
// Fullname has no ':' separator and starts with "MEMCPYHToD"
// (case-insensitive).
kMemcpyHToD,
// Fullname has no ':' separator and starts with "MEMCPYDToH"
// (case-insensitive).
kMemcpyDToH,
// Default when nothing above is recognized, e.g. user-inserted TraceMe's.
kUnknown,
};
// Breaks a TensorFlow op fullname into name and type.
// `name` and `type` are non-owning views: ParseTfOpFullname fills them with
// slices of its input (or with the k*Op constants), so a TfOp must not outlive
// the string it was parsed from.
struct TfOp {
// Which kind of op the fullname was recognized as.
Category category;
// Part before the ':' separator, or the whole fullname when it is not split.
absl::string_view name;
// Part after the ':' separator, or one of the k*Op constants.
absl::string_view type;
// NOTE(review): ParseTfOpFullname now initializes TfOp with only
// {category, name, type} and callers test `category` -- confirm whether this
// member is still needed.
bool is_tf_op;
};
TfOp ParseTfOpFullname(absl::string_view tf_op_fullname);
@ -48,11 +57,6 @@ std::vector<absl::string_view> ParseTfNameScopes(const TfOp& tf_op);
std::string TfOpEventName(const TfOp& tf_op);
std::string TfOpEventName(absl::string_view tf_op_fullname);
// Reports whether `tf_op_type` compares equal to kUnknownOp, i.e. the op type
// was not recognized.
inline bool IsUnknownOp(absl::string_view tf_op_type) {
  const bool is_unknown = (kUnknownOp == tf_op_type);
  return is_unknown;
}
// Returns true if the given name is a TensorFlow Dataset Op.
inline bool IsDatasetOp(absl::string_view tf_op_type) {
return tf_op_type == kDatasetOp;

View File

@ -24,6 +24,7 @@ namespace {
TEST(TfOpUtilsTest, TfOpTest) {
const absl::string_view kName = "OpName:OpType";
TfOp tf_op = ParseTfOpFullname(kName);
EXPECT_EQ(tf_op.category, Category::kTensorFlow);
EXPECT_EQ(tf_op.name, "OpName");
EXPECT_EQ(tf_op.type, "OpType");
EXPECT_EQ(TfOpEventName(kName), "OpType"); // type only
@ -32,6 +33,7 @@ TEST(TfOpUtilsTest, TfOpTest) {
TEST(TfOpUtilsTest, InternalTfOpTest) {
const absl::string_view kName = "OpName:_InternalOpType";
TfOp tf_op = ParseTfOpFullname(kName);
EXPECT_EQ(tf_op.category, Category::kTensorFlow);
EXPECT_EQ(tf_op.name, "OpName");
EXPECT_EQ(tf_op.type, "_InternalOpType");
EXPECT_EQ(TfOpEventName(kName), "_InternalOpType"); // type only
@ -40,6 +42,7 @@ TEST(TfOpUtilsTest, InternalTfOpTest) {
TEST(TfOpUtilsTest, TfOpWithPathTest) {
const absl::string_view kName = "path/to/name:OpType";
TfOp tf_op = ParseTfOpFullname(kName);
EXPECT_EQ(tf_op.category, Category::kTensorFlow);
EXPECT_EQ(tf_op.name, "path/to/name");
EXPECT_EQ(tf_op.type, "OpType");
EXPECT_EQ(TfOpEventName(kName), "OpType"); // type only
@ -48,24 +51,27 @@ TEST(TfOpUtilsTest, TfOpWithPathTest) {
// A two-component "Iterator::" name is expected to be categorized as tf.data,
// keep the whole fullname as its name, and produce an unchanged event name.
TEST(TfOpUtilsTest, ShortDatasetOpTest) {
const absl::string_view kName = "Iterator::Batch";
TfOp tf_op = ParseTfOpFullname(kName);
EXPECT_EQ(tf_op.category, Category::kTfData);
EXPECT_EQ(tf_op.name, kName);
EXPECT_TRUE(IsDatasetOp(tf_op.type));
EXPECT_EQ(tf_op.type, kDatasetOp);
EXPECT_EQ(TfOpEventName(kName), kName);
}
// A multi-component dataset name is expected to be categorized as tf.data and
// keep the whole fullname as its name; its event name is shortened to
// "Iterator::" plus the last component.
TEST(TfOpUtilsTest, LongDatasetOpTest) {
const absl::string_view kName = "Iterator::Batch::Map::TfRecord";
TfOp tf_op = ParseTfOpFullname(kName);
EXPECT_EQ(tf_op.category, Category::kTfData);
EXPECT_EQ(tf_op.name, kName);
EXPECT_TRUE(IsDatasetOp(tf_op.type));
EXPECT_EQ(tf_op.type, kDatasetOp);
EXPECT_EQ(TfOpEventName(kName), "Iterator::TfRecord"); // shorter name
}
// A plain TraceMe name with no ':' separator is expected to stay in the
// unknown category, keep the full string as its name, and be used verbatim as
// the event name.
TEST(TfOpUtilsTest, TraceMeTest) {
const absl::string_view kName = "MyTraceMe";
TfOp tf_op = ParseTfOpFullname(kName);
EXPECT_EQ(tf_op.category, Category::kUnknown);
EXPECT_EQ(tf_op.name, kName);
EXPECT_TRUE(IsUnknownOp(tf_op.type));
EXPECT_EQ(tf_op.type, kUnknownOp);
EXPECT_EQ(TfOpEventName(kName), kName);
}
@ -73,16 +79,18 @@ TEST(TfOpUtilsTest, TraceMeWithColonTest) {
// "54635" is not a valid op type.
const absl::string_view kName = "RunStep/Server:54635";
TfOp tf_op = ParseTfOpFullname(kName);
EXPECT_EQ(tf_op.category, Category::kUnknown);
EXPECT_EQ(tf_op.name, kName);
EXPECT_TRUE(IsUnknownOp(tf_op.type));
EXPECT_EQ(tf_op.type, kUnknownOp);
EXPECT_EQ(TfOpEventName(kName), kName);
}
// A TraceMe name that contains "::" but is not a dataset op is expected to
// stay in the unknown category, with the full string kept as both the name
// and the event name.
TEST(TfOpUtilsTest, TraceMeWithDoubleColonTest) {
const absl::string_view kName = "XLA::StartProgram";
TfOp tf_op = ParseTfOpFullname(kName);
EXPECT_EQ(tf_op.category, Category::kUnknown);
EXPECT_EQ(tf_op.name, kName);
EXPECT_TRUE(IsUnknownOp(tf_op.type));
EXPECT_EQ(tf_op.type, kUnknownOp);
EXPECT_EQ(TfOpEventName(kName), kName);
}
@ -90,11 +98,39 @@ TEST(TfOpUtilsTest, TraceMeWithTrailingWhitespaceTest) {
const absl::string_view kName = "SessionRun ";
const absl::string_view kNameTrimmed = "SessionRun";
TfOp tf_op = ParseTfOpFullname(kName);
EXPECT_EQ(tf_op.category, Category::kUnknown);
EXPECT_EQ(tf_op.name, kName);
EXPECT_TRUE(IsUnknownOp(tf_op.type));
EXPECT_EQ(tf_op.type, kUnknownOp);
EXPECT_EQ(TfOpEventName(kName), kNameTrimmed);
}
// "MemcpyHToD" should parse into the kMemcpyHToD category with the
// kMemcpyHToDOp type, keeping the full string as both name and event name.
TEST(TfOpUtilsTest, MemcpyHToDTest) {
  const absl::string_view kFullname = "MemcpyHToD";
  const TfOp parsed = ParseTfOpFullname(kFullname);
  EXPECT_EQ(parsed.category, Category::kMemcpyHToD);
  EXPECT_EQ(parsed.name, kFullname);
  EXPECT_EQ(parsed.type, kMemcpyHToDOp);
  const std::string event_name = TfOpEventName(kFullname);
  EXPECT_EQ(event_name, kFullname);
}
// "MemcpyDToH" should parse into the kMemcpyDToH category with the
// kMemcpyDToHOp type, keeping the full string as both name and event name.
TEST(TfOpUtilsTest, MemcpyDToHTest) {
  const absl::string_view kFullname = "MemcpyDToH";
  const TfOp parsed = ParseTfOpFullname(kFullname);
  EXPECT_EQ(parsed.category, Category::kMemcpyDToH);
  EXPECT_EQ(parsed.name, kFullname);
  EXPECT_EQ(parsed.type, kMemcpyDToHOp);
  const std::string event_name = TfOpEventName(kFullname);
  EXPECT_EQ(event_name, kFullname);
}
// A "name:type" fullname whose type has only lowercase letters and underscores
// should be categorized as JAX, split at the ':' into name and type, and use
// the type alone as the event name.
TEST(TfOpUtilsTest, JaxOpTest) {
  const absl::string_view kFullname = "op_name:op_type";
  const TfOp parsed = ParseTfOpFullname(kFullname);
  EXPECT_EQ(parsed.category, Category::kJax);
  EXPECT_EQ(parsed.name, "op_name");
  EXPECT_EQ(parsed.type, "op_type");
  const std::string event_name = TfOpEventName(kFullname);
  EXPECT_EQ(event_name, "op_type");
}
} // namespace
} // namespace profiler
} // namespace tensorflow