Add the category enum to avoid string comparisons and enable the name scope line for JAX.

PiperOrigin-RevId: 304437327
Change-Id: Ife469db2caf0f93a0b7e4b9a271e0568fbf0ccf1
This commit is contained in:
Jiho Choi 2020-04-02 11:09:08 -07:00 committed by TensorFlower Gardener
parent b16d24a342
commit bc2cdb667d
5 changed files with 64 additions and 20 deletions

View File

@ -153,7 +153,7 @@ absl::flat_hash_map<int64, TfOp> CollectTfOpsFromHostThreadsXPlane(
// user-inserted TraceMe's have "unknown" type. We don't count them in // user-inserted TraceMe's have "unknown" type. We don't count them in
// Tf-stats. // Tf-stats.
TfOp tf_op = ParseTfOpFullname(metadata.name()); TfOp tf_op = ParseTfOpFullname(metadata.name());
if (!IsUnknownOp(tf_op.type)) { if (tf_op.category != Category::kUnknown) {
tf_ops.try_emplace(metadata.id(), tf_op); tf_ops.try_emplace(metadata.id(), tf_op);
} }
} }
@ -214,7 +214,7 @@ OpMetricsDb ConvertDeviceTraceXPlaneToOpMetricsDb(
if (tf_op_fullname.empty()) return; if (tf_op_fullname.empty()) return;
TfOp tf_op = ParseTfOpFullname(tf_op_fullname); TfOp tf_op = ParseTfOpFullname(tf_op_fullname);
TfOpRoofLineCostEstimator::OpRoofLineStats costs; TfOpRoofLineCostEstimator::OpRoofLineStats costs;
if (tf_op.type != kUnknownOp) { if (tf_op.category != Category::kUnknown) {
costs = op_level_cost_estimator.Predict(event); costs = op_level_cost_estimator.Predict(event);
} }
device_op_metrics_db_builder.EnterOp( device_op_metrics_db_builder.EnterOp(

View File

@ -147,7 +147,8 @@ void ProcessTfOpEvent(const XEventVisitor& event,
plane_builder->GetOrCreateStatMetadata(GetStatTypeStr(StatType::kGroupId)) plane_builder->GetOrCreateStatMetadata(GetStatTypeStr(StatType::kGroupId))
->id(); ->id();
TfOp tf_op = ParseTfOpFullname(tf_op_full_name); TfOp tf_op = ParseTfOpFullname(tf_op_full_name);
if (tf_op.is_tf_op) { Category category = tf_op.category;
if (category == Category::kTensorFlow || category == Category::kJax) {
std::vector<XEvent> name_scope_event_per_level; std::vector<XEvent> name_scope_event_per_level;
for (const auto& tf_name_scope : ParseTfNameScopes(tf_op)) { for (const auto& tf_name_scope : ParseTfNameScopes(tf_op)) {
name_scope_event_per_level.push_back(CreateXEvent( name_scope_event_per_level.push_back(CreateXEvent(

View File

@ -47,14 +47,16 @@ TfOp ParseTfOpFullname(absl::string_view tf_op_fullname) {
// JAX op types have only lowercase letters and underscores. // JAX op types have only lowercase letters and underscores.
static const LazyRE2 kJaxOpTypeRegEx = {"[a-z_]*"}; static const LazyRE2 kJaxOpTypeRegEx = {"[a-z_]*"};
TfOp tf_op = {tf_op_fullname, kUnknownOp, /*is_tf_op=*/false}; TfOp tf_op = {Category::kUnknown, tf_op_fullname, kUnknownOp};
std::vector<absl::string_view> parts = std::vector<absl::string_view> parts =
absl::StrSplit(tf_op_fullname, absl::MaxSplits(':', 1)); absl::StrSplit(tf_op_fullname, absl::MaxSplits(':', 1));
if (parts.size() != 2) { if (parts.size() != 2) {
// GPU-related Ops that need to be tracked. // GPU-related Ops that need to be tracked.
if (absl::StartsWithIgnoreCase(tf_op_fullname, "MEMCPYHToD")) { if (absl::StartsWithIgnoreCase(tf_op_fullname, "MEMCPYHToD")) {
tf_op.category = Category::kMemcpyHToD;
tf_op.type = kMemcpyHToDOp; tf_op.type = kMemcpyHToDOp;
} else if (absl::StartsWithIgnoreCase(tf_op_fullname, "MEMCPYDToH")) { } else if (absl::StartsWithIgnoreCase(tf_op_fullname, "MEMCPYDToH")) {
tf_op.category = Category::kMemcpyDToH;
tf_op.type = kMemcpyDToHOp; tf_op.type = kMemcpyDToHOp;
} }
// TODO(ckluk): Include the corresponding Ops on TPU. // TODO(ckluk): Include the corresponding Ops on TPU.
@ -62,12 +64,13 @@ TfOp ParseTfOpFullname(absl::string_view tf_op_fullname) {
// Dataset Op names (e.g., Iterator::Batch::Map::TFRecord) do not follow the // Dataset Op names (e.g., Iterator::Batch::Map::TFRecord) do not follow the
// format of TF Op names. But we still want to capture them for // format of TF Op names. But we still want to capture them for
// input-pipeline analysis. // input-pipeline analysis.
tf_op.category = Category::kTfData;
tf_op.type = kDatasetOp; tf_op.type = kDatasetOp;
} else if (RE2::FullMatch(parts[1], *kTfOpTypeRegEx) && } else if (RE2::FullMatch(parts[1], *kTfOpTypeRegEx) &&
RE2::FullMatch(parts[0], *kTfOpNameRegEx)) { // TensorFlow RE2::FullMatch(parts[0], *kTfOpNameRegEx)) { // TensorFlow
tf_op = {parts[0], parts[1], /*is_tf_op=*/true}; tf_op = {Category::kTensorFlow, parts[0], parts[1]};
} else if (RE2::FullMatch(parts[1], *kJaxOpTypeRegEx)) { // JAX } else if (RE2::FullMatch(parts[1], *kJaxOpTypeRegEx)) { // JAX
tf_op = {parts[0], parts[1], /*is_tf_op=*/false}; tf_op = {Category::kJax, parts[0], parts[1]};
} }
return tf_op; return tf_op;
} }
@ -81,10 +84,10 @@ std::vector<absl::string_view> ParseTfNameScopes(const TfOp& tf_op) {
std::string TfOpEventName(const TfOp& tf_op) { std::string TfOpEventName(const TfOp& tf_op) {
std::string event_name; std::string event_name;
if (tf_op.type == kUnknownOp) { if (tf_op.category == Category::kUnknown) {
// Some TraceMe names contain trailing whitespace, remove it. // Some TraceMe names contain trailing whitespace, remove it.
event_name = std::string(absl::StripTrailingAsciiWhitespace(tf_op.name)); event_name = std::string(absl::StripTrailingAsciiWhitespace(tf_op.name));
} else if (tf_op.type == kDatasetOp) { } else if (tf_op.category == Category::kTfData) {
std::vector<absl::string_view> op_parts = std::vector<absl::string_view> op_parts =
absl::StrSplit(tf_op.name, kSeparator); absl::StrSplit(tf_op.name, kSeparator);
event_name = absl::StrCat(kIterator, kSeparator, op_parts.back()); event_name = absl::StrCat(kIterator, kSeparator, op_parts.back());

View File

@ -31,11 +31,20 @@ ABSL_CONST_INIT extern const absl::string_view kDatasetOp;
ABSL_CONST_INIT extern const absl::string_view kMemcpyHToDOp; ABSL_CONST_INIT extern const absl::string_view kMemcpyHToDOp;
ABSL_CONST_INIT extern const absl::string_view kMemcpyDToHOp; ABSL_CONST_INIT extern const absl::string_view kMemcpyDToHOp;
enum class Category {
kTensorFlow,
kJax,
kTfData,
kMemcpyHToD,
kMemcpyDToH,
kUnknown,
};
// Breaks a TensorFlow op fullname into name and type. // Breaks a TensorFlow op fullname into name and type.
struct TfOp { struct TfOp {
Category category;
absl::string_view name; absl::string_view name;
absl::string_view type; absl::string_view type;
bool is_tf_op;
}; };
TfOp ParseTfOpFullname(absl::string_view tf_op_fullname); TfOp ParseTfOpFullname(absl::string_view tf_op_fullname);
@ -48,11 +57,6 @@ std::vector<absl::string_view> ParseTfNameScopes(const TfOp& tf_op);
std::string TfOpEventName(const TfOp& tf_op); std::string TfOpEventName(const TfOp& tf_op);
std::string TfOpEventName(absl::string_view tf_op_fullname); std::string TfOpEventName(absl::string_view tf_op_fullname);
// Returns true if the given name is not a TensorFlow op.
inline bool IsUnknownOp(absl::string_view tf_op_type) {
return tf_op_type == kUnknownOp;
}
// Returns true if the given name is a TensorFlow Dataset Op. // Returns true if the given name is a TensorFlow Dataset Op.
inline bool IsDatasetOp(absl::string_view tf_op_type) { inline bool IsDatasetOp(absl::string_view tf_op_type) {
return tf_op_type == kDatasetOp; return tf_op_type == kDatasetOp;

View File

@ -24,6 +24,7 @@ namespace {
TEST(TfOpUtilsTest, TfOpTest) { TEST(TfOpUtilsTest, TfOpTest) {
const absl::string_view kName = "OpName:OpType"; const absl::string_view kName = "OpName:OpType";
TfOp tf_op = ParseTfOpFullname(kName); TfOp tf_op = ParseTfOpFullname(kName);
EXPECT_EQ(tf_op.category, Category::kTensorFlow);
EXPECT_EQ(tf_op.name, "OpName"); EXPECT_EQ(tf_op.name, "OpName");
EXPECT_EQ(tf_op.type, "OpType"); EXPECT_EQ(tf_op.type, "OpType");
EXPECT_EQ(TfOpEventName(kName), "OpType"); // type only EXPECT_EQ(TfOpEventName(kName), "OpType"); // type only
@ -32,6 +33,7 @@ TEST(TfOpUtilsTest, TfOpTest) {
TEST(TfOpUtilsTest, InternalTfOpTest) { TEST(TfOpUtilsTest, InternalTfOpTest) {
const absl::string_view kName = "OpName:_InternalOpType"; const absl::string_view kName = "OpName:_InternalOpType";
TfOp tf_op = ParseTfOpFullname(kName); TfOp tf_op = ParseTfOpFullname(kName);
EXPECT_EQ(tf_op.category, Category::kTensorFlow);
EXPECT_EQ(tf_op.name, "OpName"); EXPECT_EQ(tf_op.name, "OpName");
EXPECT_EQ(tf_op.type, "_InternalOpType"); EXPECT_EQ(tf_op.type, "_InternalOpType");
EXPECT_EQ(TfOpEventName(kName), "_InternalOpType"); // type only EXPECT_EQ(TfOpEventName(kName), "_InternalOpType"); // type only
@ -40,6 +42,7 @@ TEST(TfOpUtilsTest, InternalTfOpTest) {
TEST(TfOpUtilsTest, TfOpWithPathTest) { TEST(TfOpUtilsTest, TfOpWithPathTest) {
const absl::string_view kName = "path/to/name:OpType"; const absl::string_view kName = "path/to/name:OpType";
TfOp tf_op = ParseTfOpFullname(kName); TfOp tf_op = ParseTfOpFullname(kName);
EXPECT_EQ(tf_op.category, Category::kTensorFlow);
EXPECT_EQ(tf_op.name, "path/to/name"); EXPECT_EQ(tf_op.name, "path/to/name");
EXPECT_EQ(tf_op.type, "OpType"); EXPECT_EQ(tf_op.type, "OpType");
EXPECT_EQ(TfOpEventName(kName), "OpType"); // type only EXPECT_EQ(TfOpEventName(kName), "OpType"); // type only
@ -48,24 +51,27 @@ TEST(TfOpUtilsTest, TfOpWithPathTest) {
TEST(TfOpUtilsTest, ShortDatasetOpTest) { TEST(TfOpUtilsTest, ShortDatasetOpTest) {
const absl::string_view kName = "Iterator::Batch"; const absl::string_view kName = "Iterator::Batch";
TfOp tf_op = ParseTfOpFullname(kName); TfOp tf_op = ParseTfOpFullname(kName);
EXPECT_EQ(tf_op.category, Category::kTfData);
EXPECT_EQ(tf_op.name, kName); EXPECT_EQ(tf_op.name, kName);
EXPECT_TRUE(IsDatasetOp(tf_op.type)); EXPECT_EQ(tf_op.type, kDatasetOp);
EXPECT_EQ(TfOpEventName(kName), kName); EXPECT_EQ(TfOpEventName(kName), kName);
} }
TEST(TfOpUtilsTest, LongDatasetOpTest) { TEST(TfOpUtilsTest, LongDatasetOpTest) {
const absl::string_view kName = "Iterator::Batch::Map::TfRecord"; const absl::string_view kName = "Iterator::Batch::Map::TfRecord";
TfOp tf_op = ParseTfOpFullname(kName); TfOp tf_op = ParseTfOpFullname(kName);
EXPECT_EQ(tf_op.category, Category::kTfData);
EXPECT_EQ(tf_op.name, kName); EXPECT_EQ(tf_op.name, kName);
EXPECT_TRUE(IsDatasetOp(tf_op.type)); EXPECT_EQ(tf_op.type, kDatasetOp);
EXPECT_EQ(TfOpEventName(kName), "Iterator::TfRecord"); // shorter name EXPECT_EQ(TfOpEventName(kName), "Iterator::TfRecord"); // shorter name
} }
TEST(TfOpUtilsTest, TraceMeTest) { TEST(TfOpUtilsTest, TraceMeTest) {
const absl::string_view kName = "MyTraceMe"; const absl::string_view kName = "MyTraceMe";
TfOp tf_op = ParseTfOpFullname(kName); TfOp tf_op = ParseTfOpFullname(kName);
EXPECT_EQ(tf_op.category, Category::kUnknown);
EXPECT_EQ(tf_op.name, kName); EXPECT_EQ(tf_op.name, kName);
EXPECT_TRUE(IsUnknownOp(tf_op.type)); EXPECT_EQ(tf_op.type, kUnknownOp);
EXPECT_EQ(TfOpEventName(kName), kName); EXPECT_EQ(TfOpEventName(kName), kName);
} }
@ -73,16 +79,18 @@ TEST(TfOpUtilsTest, TraceMeWithColonTest) {
// "54635" is not a valid op type. // "54635" is not a valid op type.
const absl::string_view kName = "RunStep/Server:54635"; const absl::string_view kName = "RunStep/Server:54635";
TfOp tf_op = ParseTfOpFullname(kName); TfOp tf_op = ParseTfOpFullname(kName);
EXPECT_EQ(tf_op.category, Category::kUnknown);
EXPECT_EQ(tf_op.name, kName); EXPECT_EQ(tf_op.name, kName);
EXPECT_TRUE(IsUnknownOp(tf_op.type)); EXPECT_EQ(tf_op.type, kUnknownOp);
EXPECT_EQ(TfOpEventName(kName), kName); EXPECT_EQ(TfOpEventName(kName), kName);
} }
TEST(TfOpUtilsTest, TraceMeWithDoubleColonTest) { TEST(TfOpUtilsTest, TraceMeWithDoubleColonTest) {
const absl::string_view kName = "XLA::StartProgram"; const absl::string_view kName = "XLA::StartProgram";
TfOp tf_op = ParseTfOpFullname(kName); TfOp tf_op = ParseTfOpFullname(kName);
EXPECT_EQ(tf_op.category, Category::kUnknown);
EXPECT_EQ(tf_op.name, kName); EXPECT_EQ(tf_op.name, kName);
EXPECT_TRUE(IsUnknownOp(tf_op.type)); EXPECT_EQ(tf_op.type, kUnknownOp);
EXPECT_EQ(TfOpEventName(kName), kName); EXPECT_EQ(TfOpEventName(kName), kName);
} }
@ -90,11 +98,39 @@ TEST(TfOpUtilsTest, TraceMeWithTrailingWhitespaceTest) {
const absl::string_view kName = "SessionRun "; const absl::string_view kName = "SessionRun ";
const absl::string_view kNameTrimmed = "SessionRun"; const absl::string_view kNameTrimmed = "SessionRun";
TfOp tf_op = ParseTfOpFullname(kName); TfOp tf_op = ParseTfOpFullname(kName);
EXPECT_EQ(tf_op.category, Category::kUnknown);
EXPECT_EQ(tf_op.name, kName); EXPECT_EQ(tf_op.name, kName);
EXPECT_TRUE(IsUnknownOp(tf_op.type)); EXPECT_EQ(tf_op.type, kUnknownOp);
EXPECT_EQ(TfOpEventName(kName), kNameTrimmed); EXPECT_EQ(TfOpEventName(kName), kNameTrimmed);
} }
TEST(TfOpUtilsTest, MemcpyHToDTest) {
const absl::string_view kName = "MemcpyHToD";
TfOp tf_op = ParseTfOpFullname(kName);
EXPECT_EQ(tf_op.category, Category::kMemcpyHToD);
EXPECT_EQ(tf_op.name, kName);
EXPECT_EQ(tf_op.type, kMemcpyHToDOp);
EXPECT_EQ(TfOpEventName(kName), kName);
}
TEST(TfOpUtilsTest, MemcpyDToHTest) {
const absl::string_view kName = "MemcpyDToH";
TfOp tf_op = ParseTfOpFullname(kName);
EXPECT_EQ(tf_op.category, Category::kMemcpyDToH);
EXPECT_EQ(tf_op.name, kName);
EXPECT_EQ(tf_op.type, kMemcpyDToHOp);
EXPECT_EQ(TfOpEventName(kName), kName);
}
TEST(TfOpUtilsTest, JaxOpTest) {
const absl::string_view kName = "op_name:op_type";
TfOp tf_op = ParseTfOpFullname(kName);
EXPECT_EQ(tf_op.category, Category::kJax);
EXPECT_EQ(tf_op.name, "op_name");
EXPECT_EQ(tf_op.type, "op_type");
EXPECT_EQ(TfOpEventName(kName), "op_type");
}
} // namespace } // namespace
} // namespace profiler } // namespace profiler
} // namespace tensorflow } // namespace tensorflow