diff --git a/tensorflow/core/profiler/convert/xplane_to_op_metrics_db.cc b/tensorflow/core/profiler/convert/xplane_to_op_metrics_db.cc index b5181b1edd3..9fd12a20cad 100644 --- a/tensorflow/core/profiler/convert/xplane_to_op_metrics_db.cc +++ b/tensorflow/core/profiler/convert/xplane_to_op_metrics_db.cc @@ -153,7 +153,7 @@ absl::flat_hash_map CollectTfOpsFromHostThreadsXPlane( // user-inserted TraceMe's have "unknown" type. We don't count them in // Tf-stats. TfOp tf_op = ParseTfOpFullname(metadata.name()); - if (!IsUnknownOp(tf_op.type)) { + if (tf_op.category != Category::kUnknown) { tf_ops.try_emplace(metadata.id(), tf_op); } } @@ -214,7 +214,7 @@ OpMetricsDb ConvertDeviceTraceXPlaneToOpMetricsDb( if (tf_op_fullname.empty()) return; TfOp tf_op = ParseTfOpFullname(tf_op_fullname); TfOpRoofLineCostEstimator::OpRoofLineStats costs; - if (tf_op.type != kUnknownOp) { + if (tf_op.category != Category::kUnknown) { costs = op_level_cost_estimator.Predict(event); } device_op_metrics_db_builder.EnterOp( diff --git a/tensorflow/core/profiler/utils/derived_timeline.cc b/tensorflow/core/profiler/utils/derived_timeline.cc index f61926a1850..ef2e38b45e2 100644 --- a/tensorflow/core/profiler/utils/derived_timeline.cc +++ b/tensorflow/core/profiler/utils/derived_timeline.cc @@ -147,7 +147,8 @@ void ProcessTfOpEvent(const XEventVisitor& event, plane_builder->GetOrCreateStatMetadata(GetStatTypeStr(StatType::kGroupId)) ->id(); TfOp tf_op = ParseTfOpFullname(tf_op_full_name); - if (tf_op.is_tf_op) { + Category category = tf_op.category; + if (category == Category::kTensorFlow || category == Category::kJax) { std::vector name_scope_event_per_level; for (const auto& tf_name_scope : ParseTfNameScopes(tf_op)) { name_scope_event_per_level.push_back(CreateXEvent( diff --git a/tensorflow/core/profiler/utils/tf_op_utils.cc b/tensorflow/core/profiler/utils/tf_op_utils.cc index 8a9556fb4cd..5a4204440a3 100644 --- a/tensorflow/core/profiler/utils/tf_op_utils.cc +++ b/tensorflow/core/profiler/utils/tf_op_utils.cc @@ -47,14 +47,16 @@ TfOp ParseTfOpFullname(absl::string_view tf_op_fullname) { // JAX op types have only lowercase letters and underscores. static const LazyRE2 kJaxOpTypeRegEx = {"[a-z_]*"}; - TfOp tf_op = {tf_op_fullname, kUnknownOp, /*is_tf_op=*/false}; + TfOp tf_op = {Category::kUnknown, tf_op_fullname, kUnknownOp}; std::vector parts = absl::StrSplit(tf_op_fullname, absl::MaxSplits(':', 1)); if (parts.size() != 2) { // GPU-related Ops that need to be tracked. if (absl::StartsWithIgnoreCase(tf_op_fullname, "MEMCPYHToD")) { + tf_op.category = Category::kMemcpyHToD; tf_op.type = kMemcpyHToDOp; } else if (absl::StartsWithIgnoreCase(tf_op_fullname, "MEMCPYDToH")) { + tf_op.category = Category::kMemcpyDToH; tf_op.type = kMemcpyDToHOp; } // TODO(ckluk): Include the corresponding Ops on TPU. @@ -62,12 +64,13 @@ TfOp ParseTfOpFullname(absl::string_view tf_op_fullname) { // Dataset Op names (e.g., Iterator::Batch::Map::TFRecord) do not follow the // format of TF Op names. But we still want to capture them for // input-pipeline analysis. + tf_op.category = Category::kTfData; tf_op.type = kDatasetOp; } else if (RE2::FullMatch(parts[1], *kTfOpTypeRegEx) && RE2::FullMatch(parts[0], *kTfOpNameRegEx)) { // TensorFlow - tf_op = {parts[0], parts[1], /*is_tf_op=*/true}; + tf_op = {Category::kTensorFlow, parts[0], parts[1]}; } else if (RE2::FullMatch(parts[1], *kJaxOpTypeRegEx)) { // JAX - tf_op = {parts[0], parts[1], /*is_tf_op=*/false}; + tf_op = {Category::kJax, parts[0], parts[1]}; } return tf_op; } @@ -81,10 +84,10 @@ std::vector ParseTfNameScopes(const TfOp& tf_op) { std::string TfOpEventName(const TfOp& tf_op) { std::string event_name; - if (tf_op.type == kUnknownOp) { + if (tf_op.category == Category::kUnknown) { // Some TraceMe names contain trailing whitespace, remove it. event_name = std::string(absl::StripTrailingAsciiWhitespace(tf_op.name)); - } else if (tf_op.type == kDatasetOp) { + } else if (tf_op.category == Category::kTfData) { std::vector op_parts = absl::StrSplit(tf_op.name, kSeparator); event_name = absl::StrCat(kIterator, kSeparator, op_parts.back()); diff --git a/tensorflow/core/profiler/utils/tf_op_utils.h b/tensorflow/core/profiler/utils/tf_op_utils.h index 4647dbbcc40..d1ac69e2976 100644 --- a/tensorflow/core/profiler/utils/tf_op_utils.h +++ b/tensorflow/core/profiler/utils/tf_op_utils.h @@ -31,11 +31,20 @@ ABSL_CONST_INIT extern const absl::string_view kDatasetOp; ABSL_CONST_INIT extern const absl::string_view kMemcpyHToDOp; ABSL_CONST_INIT extern const absl::string_view kMemcpyDToHOp; +enum class Category { + kTensorFlow, + kJax, + kTfData, + kMemcpyHToD, + kMemcpyDToH, + kUnknown, +}; + // Breaks a TensorFlow op fullname into name and type. struct TfOp { + Category category; absl::string_view name; absl::string_view type; - bool is_tf_op; }; TfOp ParseTfOpFullname(absl::string_view tf_op_fullname); @@ -48,11 +57,6 @@ std::vector ParseTfNameScopes(const TfOp& tf_op); std::string TfOpEventName(const TfOp& tf_op); std::string TfOpEventName(absl::string_view tf_op_fullname); -// Returns true if the given name is not a TensorFlow op. -inline bool IsUnknownOp(absl::string_view tf_op_type) { - return tf_op_type == kUnknownOp; -} - // Returns true if the given name is a TensorFlow Dataset Op. inline bool IsDatasetOp(absl::string_view tf_op_type) { return tf_op_type == kDatasetOp; diff --git a/tensorflow/core/profiler/utils/tf_op_utils_test.cc b/tensorflow/core/profiler/utils/tf_op_utils_test.cc index ff62c822e65..fa5169557d1 100644 --- a/tensorflow/core/profiler/utils/tf_op_utils_test.cc +++ b/tensorflow/core/profiler/utils/tf_op_utils_test.cc @@ -24,6 +24,7 @@ namespace { TEST(TfOpUtilsTest, TfOpTest) { const absl::string_view kName = "OpName:OpType"; TfOp tf_op = ParseTfOpFullname(kName); + EXPECT_EQ(tf_op.category, Category::kTensorFlow); EXPECT_EQ(tf_op.name, "OpName"); EXPECT_EQ(tf_op.type, "OpType"); EXPECT_EQ(TfOpEventName(kName), "OpType"); // type only @@ -32,6 +33,7 @@ TEST(TfOpUtilsTest, TfOpTest) { TEST(TfOpUtilsTest, InternalTfOpTest) { const absl::string_view kName = "OpName:_InternalOpType"; TfOp tf_op = ParseTfOpFullname(kName); + EXPECT_EQ(tf_op.category, Category::kTensorFlow); EXPECT_EQ(tf_op.name, "OpName"); EXPECT_EQ(tf_op.type, "_InternalOpType"); EXPECT_EQ(TfOpEventName(kName), "_InternalOpType"); // type only @@ -40,6 +42,7 @@ TEST(TfOpUtilsTest, InternalTfOpTest) { TEST(TfOpUtilsTest, TfOpWithPathTest) { const absl::string_view kName = "path/to/name:OpType"; TfOp tf_op = ParseTfOpFullname(kName); + EXPECT_EQ(tf_op.category, Category::kTensorFlow); EXPECT_EQ(tf_op.name, "path/to/name"); EXPECT_EQ(tf_op.type, "OpType"); EXPECT_EQ(TfOpEventName(kName), "OpType"); // type only @@ -48,24 +51,27 @@ TEST(TfOpUtilsTest, TfOpWithPathTest) { TEST(TfOpUtilsTest, ShortDatasetOpTest) { const absl::string_view kName = "Iterator::Batch"; TfOp tf_op = ParseTfOpFullname(kName); + EXPECT_EQ(tf_op.category, Category::kTfData); EXPECT_EQ(tf_op.name, kName); - EXPECT_TRUE(IsDatasetOp(tf_op.type)); + EXPECT_EQ(tf_op.type, kDatasetOp); EXPECT_EQ(TfOpEventName(kName), kName); } TEST(TfOpUtilsTest, LongDatasetOpTest) { const absl::string_view kName = "Iterator::Batch::Map::TfRecord"; TfOp tf_op = ParseTfOpFullname(kName); + EXPECT_EQ(tf_op.category, Category::kTfData); EXPECT_EQ(tf_op.name, kName); - EXPECT_TRUE(IsDatasetOp(tf_op.type)); + EXPECT_EQ(tf_op.type, kDatasetOp); EXPECT_EQ(TfOpEventName(kName), "Iterator::TfRecord"); // shorter name } TEST(TfOpUtilsTest, TraceMeTest) { const absl::string_view kName = "MyTraceMe"; TfOp tf_op = ParseTfOpFullname(kName); + EXPECT_EQ(tf_op.category, Category::kUnknown); EXPECT_EQ(tf_op.name, kName); - EXPECT_TRUE(IsUnknownOp(tf_op.type)); + EXPECT_EQ(tf_op.type, kUnknownOp); EXPECT_EQ(TfOpEventName(kName), kName); } @@ -73,16 +79,18 @@ TEST(TfOpUtilsTest, TraceMeWithColonTest) { // "12345" is not a valid op type. const absl::string_view kName = "RunStep/Server:54635"; TfOp tf_op = ParseTfOpFullname(kName); + EXPECT_EQ(tf_op.category, Category::kUnknown); EXPECT_EQ(tf_op.name, kName); - EXPECT_TRUE(IsUnknownOp(tf_op.type)); + EXPECT_EQ(tf_op.type, kUnknownOp); EXPECT_EQ(TfOpEventName(kName), kName); } TEST(TfOpUtilsTest, TraceMeWithDoubleColonTest) { const absl::string_view kName = "XLA::StartProgram"; TfOp tf_op = ParseTfOpFullname(kName); + EXPECT_EQ(tf_op.category, Category::kUnknown); EXPECT_EQ(tf_op.name, kName); - EXPECT_TRUE(IsUnknownOp(tf_op.type)); + EXPECT_EQ(tf_op.type, kUnknownOp); EXPECT_EQ(TfOpEventName(kName), kName); } @@ -90,11 +98,39 @@ TEST(TfOpUtilsTest, TraceMeWithTrailingWhitespaceTest) { const absl::string_view kName = "SessionRun "; const absl::string_view kNameTrimmed = "SessionRun"; TfOp tf_op = ParseTfOpFullname(kName); + EXPECT_EQ(tf_op.category, Category::kUnknown); EXPECT_EQ(tf_op.name, kName); - EXPECT_TRUE(IsUnknownOp(tf_op.type)); + EXPECT_EQ(tf_op.type, kUnknownOp); EXPECT_EQ(TfOpEventName(kName), kNameTrimmed); } +TEST(TfOpUtilsTest, MemcpyHToDTest) { + const absl::string_view kName = "MemcpyHToD"; + TfOp tf_op = ParseTfOpFullname(kName); + EXPECT_EQ(tf_op.category, Category::kMemcpyHToD); + EXPECT_EQ(tf_op.name, kName); + EXPECT_EQ(tf_op.type, kMemcpyHToDOp); + EXPECT_EQ(TfOpEventName(kName), kName); +} + +TEST(TfOpUtilsTest, MemcpyDToHTest) { + const absl::string_view kName = "MemcpyDToH"; + TfOp tf_op = ParseTfOpFullname(kName); + EXPECT_EQ(tf_op.category, Category::kMemcpyDToH); + EXPECT_EQ(tf_op.name, kName); + EXPECT_EQ(tf_op.type, kMemcpyDToHOp); + EXPECT_EQ(TfOpEventName(kName), kName); +} + +TEST(TfOpUtilsTest, JaxOpTest) { + const absl::string_view kName = "op_name:op_type"; + TfOp tf_op = ParseTfOpFullname(kName); + EXPECT_EQ(tf_op.category, Category::kJax); + EXPECT_EQ(tf_op.name, "op_name"); + EXPECT_EQ(tf_op.type, "op_type"); + EXPECT_EQ(TfOpEventName(kName), "op_type"); +} + } // namespace } // namespace profiler } // namespace tensorflow