Add the category enum to avoid string comparisons and enable the name scope line for JAX.
PiperOrigin-RevId: 304437327 Change-Id: Ife469db2caf0f93a0b7e4b9a271e0568fbf0ccf1
This commit is contained in:
		
							parent
							
								
									b16d24a342
								
							
						
					
					
						commit
						bc2cdb667d
					
				| @ -153,7 +153,7 @@ absl::flat_hash_map<int64, TfOp> CollectTfOpsFromHostThreadsXPlane( | ||||
|     // user-inserted TraceMe's have "unknown" type. We don't count them in
 | ||||
|     // Tf-stats.
 | ||||
|     TfOp tf_op = ParseTfOpFullname(metadata.name()); | ||||
|     if (!IsUnknownOp(tf_op.type)) { | ||||
|     if (tf_op.category != Category::kUnknown) { | ||||
|       tf_ops.try_emplace(metadata.id(), tf_op); | ||||
|     } | ||||
|   } | ||||
| @ -214,7 +214,7 @@ OpMetricsDb ConvertDeviceTraceXPlaneToOpMetricsDb( | ||||
|       if (tf_op_fullname.empty()) return; | ||||
|       TfOp tf_op = ParseTfOpFullname(tf_op_fullname); | ||||
|       TfOpRoofLineCostEstimator::OpRoofLineStats costs; | ||||
|       if (tf_op.type != kUnknownOp) { | ||||
|       if (tf_op.category != Category::kUnknown) { | ||||
|         costs = op_level_cost_estimator.Predict(event); | ||||
|       } | ||||
|       device_op_metrics_db_builder.EnterOp( | ||||
|  | ||||
| @ -147,7 +147,8 @@ void ProcessTfOpEvent(const XEventVisitor& event, | ||||
|       plane_builder->GetOrCreateStatMetadata(GetStatTypeStr(StatType::kGroupId)) | ||||
|           ->id(); | ||||
|   TfOp tf_op = ParseTfOpFullname(tf_op_full_name); | ||||
|   if (tf_op.is_tf_op) { | ||||
|   Category category = tf_op.category; | ||||
|   if (category == Category::kTensorFlow || category == Category::kJax) { | ||||
|     std::vector<XEvent> name_scope_event_per_level; | ||||
|     for (const auto& tf_name_scope : ParseTfNameScopes(tf_op)) { | ||||
|       name_scope_event_per_level.push_back(CreateXEvent( | ||||
|  | ||||
| @ -47,14 +47,16 @@ TfOp ParseTfOpFullname(absl::string_view tf_op_fullname) { | ||||
|   // JAX op types have only lowercase letters and underscores.
 | ||||
|   static const LazyRE2 kJaxOpTypeRegEx = {"[a-z_]*"}; | ||||
| 
 | ||||
|   TfOp tf_op = {tf_op_fullname, kUnknownOp, /*is_tf_op=*/false}; | ||||
|   TfOp tf_op = {Category::kUnknown, tf_op_fullname, kUnknownOp}; | ||||
|   std::vector<absl::string_view> parts = | ||||
|       absl::StrSplit(tf_op_fullname, absl::MaxSplits(':', 1)); | ||||
|   if (parts.size() != 2) { | ||||
|     // GPU-related Ops that need to be tracked.
 | ||||
|     if (absl::StartsWithIgnoreCase(tf_op_fullname, "MEMCPYHToD")) { | ||||
|       tf_op.category = Category::kMemcpyHToD; | ||||
|       tf_op.type = kMemcpyHToDOp; | ||||
|     } else if (absl::StartsWithIgnoreCase(tf_op_fullname, "MEMCPYDToH")) { | ||||
|       tf_op.category = Category::kMemcpyDToH; | ||||
|       tf_op.type = kMemcpyDToHOp; | ||||
|     } | ||||
|     // TODO(ckluk): Include the corresponding Ops on TPU.
 | ||||
| @ -62,12 +64,13 @@ TfOp ParseTfOpFullname(absl::string_view tf_op_fullname) { | ||||
|     // Dataset Op names (e.g., Iterator::Batch::Map::TFRecord) do not follow the
 | ||||
|     // format of TF Op names. But we still want to capture them for
 | ||||
|     // input-pipeline analysis.
 | ||||
|     tf_op.category = Category::kTfData; | ||||
|     tf_op.type = kDatasetOp; | ||||
|   } else if (RE2::FullMatch(parts[1], *kTfOpTypeRegEx) && | ||||
|              RE2::FullMatch(parts[0], *kTfOpNameRegEx)) {  // TensorFlow
 | ||||
|     tf_op = {parts[0], parts[1], /*is_tf_op=*/true}; | ||||
|     tf_op = {Category::kTensorFlow, parts[0], parts[1]}; | ||||
|   } else if (RE2::FullMatch(parts[1], *kJaxOpTypeRegEx)) {  // JAX
 | ||||
|     tf_op = {parts[0], parts[1], /*is_tf_op=*/false}; | ||||
|     tf_op = {Category::kJax, parts[0], parts[1]}; | ||||
|   } | ||||
|   return tf_op; | ||||
| } | ||||
| @ -81,10 +84,10 @@ std::vector<absl::string_view> ParseTfNameScopes(const TfOp& tf_op) { | ||||
| 
 | ||||
| std::string TfOpEventName(const TfOp& tf_op) { | ||||
|   std::string event_name; | ||||
|   if (tf_op.type == kUnknownOp) { | ||||
|   if (tf_op.category == Category::kUnknown) { | ||||
|     // Some TraceMe names contain trailing whitespace, remove it.
 | ||||
|     event_name = std::string(absl::StripTrailingAsciiWhitespace(tf_op.name)); | ||||
|   } else if (tf_op.type == kDatasetOp) { | ||||
|   } else if (tf_op.category == Category::kTfData) { | ||||
|     std::vector<absl::string_view> op_parts = | ||||
|         absl::StrSplit(tf_op.name, kSeparator); | ||||
|     event_name = absl::StrCat(kIterator, kSeparator, op_parts.back()); | ||||
|  | ||||
| @ -31,11 +31,20 @@ ABSL_CONST_INIT extern const absl::string_view kDatasetOp; | ||||
| ABSL_CONST_INIT extern const absl::string_view kMemcpyHToDOp; | ||||
| ABSL_CONST_INIT extern const absl::string_view kMemcpyDToHOp; | ||||
| 
 | ||||
| enum class Category { | ||||
|   kTensorFlow, | ||||
|   kJax, | ||||
|   kTfData, | ||||
|   kMemcpyHToD, | ||||
|   kMemcpyDToH, | ||||
|   kUnknown, | ||||
| }; | ||||
| 
 | ||||
| // Breaks a TensorFlow op fullname into name and type.
 | ||||
| struct TfOp { | ||||
|   Category category; | ||||
|   absl::string_view name; | ||||
|   absl::string_view type; | ||||
|   bool is_tf_op; | ||||
| }; | ||||
| 
 | ||||
| TfOp ParseTfOpFullname(absl::string_view tf_op_fullname); | ||||
| @ -48,11 +57,6 @@ std::vector<absl::string_view> ParseTfNameScopes(const TfOp& tf_op); | ||||
| std::string TfOpEventName(const TfOp& tf_op); | ||||
| std::string TfOpEventName(absl::string_view tf_op_fullname); | ||||
| 
 | ||||
| // Returns true if the given name is not a TensorFlow op.
 | ||||
| inline bool IsUnknownOp(absl::string_view tf_op_type) { | ||||
|   return tf_op_type == kUnknownOp; | ||||
| } | ||||
| 
 | ||||
| // Returns true if the given name is a TensorFlow Dataset Op.
 | ||||
| inline bool IsDatasetOp(absl::string_view tf_op_type) { | ||||
|   return tf_op_type == kDatasetOp; | ||||
|  | ||||
| @ -24,6 +24,7 @@ namespace { | ||||
| TEST(TfOpUtilsTest, TfOpTest) { | ||||
|   const absl::string_view kName = "OpName:OpType"; | ||||
|   TfOp tf_op = ParseTfOpFullname(kName); | ||||
|   EXPECT_EQ(tf_op.category, Category::kTensorFlow); | ||||
|   EXPECT_EQ(tf_op.name, "OpName"); | ||||
|   EXPECT_EQ(tf_op.type, "OpType"); | ||||
|   EXPECT_EQ(TfOpEventName(kName), "OpType");  // type only
 | ||||
| @ -32,6 +33,7 @@ TEST(TfOpUtilsTest, TfOpTest) { | ||||
| TEST(TfOpUtilsTest, InternalTfOpTest) { | ||||
|   const absl::string_view kName = "OpName:_InternalOpType"; | ||||
|   TfOp tf_op = ParseTfOpFullname(kName); | ||||
|   EXPECT_EQ(tf_op.category, Category::kTensorFlow); | ||||
|   EXPECT_EQ(tf_op.name, "OpName"); | ||||
|   EXPECT_EQ(tf_op.type, "_InternalOpType"); | ||||
|   EXPECT_EQ(TfOpEventName(kName), "_InternalOpType");  // type only
 | ||||
| @ -40,6 +42,7 @@ TEST(TfOpUtilsTest, InternalTfOpTest) { | ||||
| TEST(TfOpUtilsTest, TfOpWithPathTest) { | ||||
|   const absl::string_view kName = "path/to/name:OpType"; | ||||
|   TfOp tf_op = ParseTfOpFullname(kName); | ||||
|   EXPECT_EQ(tf_op.category, Category::kTensorFlow); | ||||
|   EXPECT_EQ(tf_op.name, "path/to/name"); | ||||
|   EXPECT_EQ(tf_op.type, "OpType"); | ||||
|   EXPECT_EQ(TfOpEventName(kName), "OpType");  // type only
 | ||||
| @ -48,24 +51,27 @@ TEST(TfOpUtilsTest, TfOpWithPathTest) { | ||||
| TEST(TfOpUtilsTest, ShortDatasetOpTest) { | ||||
|   const absl::string_view kName = "Iterator::Batch"; | ||||
|   TfOp tf_op = ParseTfOpFullname(kName); | ||||
|   EXPECT_EQ(tf_op.category, Category::kTfData); | ||||
|   EXPECT_EQ(tf_op.name, kName); | ||||
|   EXPECT_TRUE(IsDatasetOp(tf_op.type)); | ||||
|   EXPECT_EQ(tf_op.type, kDatasetOp); | ||||
|   EXPECT_EQ(TfOpEventName(kName), kName); | ||||
| } | ||||
| 
 | ||||
| TEST(TfOpUtilsTest, LongDatasetOpTest) { | ||||
|   const absl::string_view kName = "Iterator::Batch::Map::TfRecord"; | ||||
|   TfOp tf_op = ParseTfOpFullname(kName); | ||||
|   EXPECT_EQ(tf_op.category, Category::kTfData); | ||||
|   EXPECT_EQ(tf_op.name, kName); | ||||
|   EXPECT_TRUE(IsDatasetOp(tf_op.type)); | ||||
|   EXPECT_EQ(tf_op.type, kDatasetOp); | ||||
|   EXPECT_EQ(TfOpEventName(kName), "Iterator::TfRecord");  // shorter name
 | ||||
| } | ||||
| 
 | ||||
| TEST(TfOpUtilsTest, TraceMeTest) { | ||||
|   const absl::string_view kName = "MyTraceMe"; | ||||
|   TfOp tf_op = ParseTfOpFullname(kName); | ||||
|   EXPECT_EQ(tf_op.category, Category::kUnknown); | ||||
|   EXPECT_EQ(tf_op.name, kName); | ||||
|   EXPECT_TRUE(IsUnknownOp(tf_op.type)); | ||||
|   EXPECT_EQ(tf_op.type, kUnknownOp); | ||||
|   EXPECT_EQ(TfOpEventName(kName), kName); | ||||
| } | ||||
| 
 | ||||
| @ -73,16 +79,18 @@ TEST(TfOpUtilsTest, TraceMeWithColonTest) { | ||||
|   // "12345" is not a valid op type.
 | ||||
|   const absl::string_view kName = "RunStep/Server:54635"; | ||||
|   TfOp tf_op = ParseTfOpFullname(kName); | ||||
|   EXPECT_EQ(tf_op.category, Category::kUnknown); | ||||
|   EXPECT_EQ(tf_op.name, kName); | ||||
|   EXPECT_TRUE(IsUnknownOp(tf_op.type)); | ||||
|   EXPECT_EQ(tf_op.type, kUnknownOp); | ||||
|   EXPECT_EQ(TfOpEventName(kName), kName); | ||||
| } | ||||
| 
 | ||||
| TEST(TfOpUtilsTest, TraceMeWithDoubleColonTest) { | ||||
|   const absl::string_view kName = "XLA::StartProgram"; | ||||
|   TfOp tf_op = ParseTfOpFullname(kName); | ||||
|   EXPECT_EQ(tf_op.category, Category::kUnknown); | ||||
|   EXPECT_EQ(tf_op.name, kName); | ||||
|   EXPECT_TRUE(IsUnknownOp(tf_op.type)); | ||||
|   EXPECT_EQ(tf_op.type, kUnknownOp); | ||||
|   EXPECT_EQ(TfOpEventName(kName), kName); | ||||
| } | ||||
| 
 | ||||
| @ -90,11 +98,39 @@ TEST(TfOpUtilsTest, TraceMeWithTrailingWhitespaceTest) { | ||||
|   const absl::string_view kName = "SessionRun "; | ||||
|   const absl::string_view kNameTrimmed = "SessionRun"; | ||||
|   TfOp tf_op = ParseTfOpFullname(kName); | ||||
|   EXPECT_EQ(tf_op.category, Category::kUnknown); | ||||
|   EXPECT_EQ(tf_op.name, kName); | ||||
|   EXPECT_TRUE(IsUnknownOp(tf_op.type)); | ||||
|   EXPECT_EQ(tf_op.type, kUnknownOp); | ||||
|   EXPECT_EQ(TfOpEventName(kName), kNameTrimmed); | ||||
| } | ||||
| 
 | ||||
| TEST(TfOpUtilsTest, MemcpyHToDTest) { | ||||
|   const absl::string_view kName = "MemcpyHToD"; | ||||
|   TfOp tf_op = ParseTfOpFullname(kName); | ||||
|   EXPECT_EQ(tf_op.category, Category::kMemcpyHToD); | ||||
|   EXPECT_EQ(tf_op.name, kName); | ||||
|   EXPECT_EQ(tf_op.type, kMemcpyHToDOp); | ||||
|   EXPECT_EQ(TfOpEventName(kName), kName); | ||||
| } | ||||
| 
 | ||||
| TEST(TfOpUtilsTest, MemcpyDToHTest) { | ||||
|   const absl::string_view kName = "MemcpyDToH"; | ||||
|   TfOp tf_op = ParseTfOpFullname(kName); | ||||
|   EXPECT_EQ(tf_op.category, Category::kMemcpyDToH); | ||||
|   EXPECT_EQ(tf_op.name, kName); | ||||
|   EXPECT_EQ(tf_op.type, kMemcpyDToHOp); | ||||
|   EXPECT_EQ(TfOpEventName(kName), kName); | ||||
| } | ||||
| 
 | ||||
| TEST(TfOpUtilsTest, JaxOpTest) { | ||||
|   const absl::string_view kName = "op_name:op_type"; | ||||
|   TfOp tf_op = ParseTfOpFullname(kName); | ||||
|   EXPECT_EQ(tf_op.category, Category::kJax); | ||||
|   EXPECT_EQ(tf_op.name, "op_name"); | ||||
|   EXPECT_EQ(tf_op.type, "op_type"); | ||||
|   EXPECT_EQ(TfOpEventName(kName), "op_type"); | ||||
| } | ||||
| 
 | ||||
| }  // namespace
 | ||||
| }  // namespace profiler
 | ||||
| }  // namespace tensorflow
 | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user