Use switch statements in lambdas passed to XEventVisitor::ForEachStat.
Also add utility to parse tensor shapes.

PiperOrigin-RevId: 315393466
Change-Id: I23a33867f132a3a30617315e79911780143e815e
parent 2edb0fe27f
commit 2cceeea264

tensorflow/core/profiler
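Every hunk below applies the same pattern: the chained `stat.Type() == StatType::...` comparisons become an early return for stats whose type is unknown, followed by a switch on the enum value. Distilled into a minimal sketch (the visitor names are the profiler's own, but the surrounding variables are illustrative, and the `default` case is added here only to make the sketch self-contained):

  int64 group_id = -1;
  event.ForEachStat([&](const XStatVisitor& stat) {
    // stat.Type() is an optional; skip stats that don't map to a known StatType.
    if (!stat.Type().has_value()) return;
    switch (stat.Type().value()) {
      case StatType::kGroupId:
        group_id = stat.IntValue();
        break;
      default:
        break;
    }
  });

Each stat's type is now inspected once instead of once per comparison, and handling a new StatType is a one-case addition.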
@@ -49,18 +49,23 @@ KernelStatsDb ConvertDeviceTraceXPlaneToKernelStatsDb(
     absl::string_view equation;
     event.ForEachStat([&](const tensorflow::profiler::XStatVisitor& stat) {
-      if (stat.Type() == StatType::kLevel0) {
-        tf_op_fullname = stat.StrOrRefValue();
-      } else if (stat.Type() == StatType::kKernelDetails) {
-        kernel.set_name(event.Name().data(), event.Name().size());
-        bool using_tensor_cores = IsKernelUsingTensorCore(event.Name());
-        kernel.set_is_kernel_using_tensor_core(using_tensor_cores);
-        kernel.set_total_duration_ns(event.DurationNs());
-        kernel.set_min_duration_ns(event.DurationNs());
-        kernel.set_max_duration_ns(event.DurationNs());
-        ParseKernelLaunchParams(stat.StrOrRefValue(), &kernel);
-      } else if (stat.Type() == StatType::kEquation) {
-        equation = stat.StrOrRefValue();
+      if (!stat.Type().has_value()) return;
+      switch (stat.Type().value()) {
+        case StatType::kLevel0:
+          tf_op_fullname = stat.StrOrRefValue();
+          break;
+        case StatType::kKernelDetails:
+          kernel.set_name(event.Name().data(), event.Name().size());
+          kernel.set_is_kernel_using_tensor_core(
+              IsKernelUsingTensorCore(event.Name()));
+          kernel.set_total_duration_ns(event.DurationNs());
+          kernel.set_min_duration_ns(event.DurationNs());
+          kernel.set_max_duration_ns(event.DurationNs());
+          ParseKernelLaunchParams(stat.StrOrRefValue(), &kernel);
+          break;
+        case StatType::kEquation:
+          equation = stat.StrOrRefValue();
+          break;
       }
     });
@@ -146,38 +146,55 @@ MemoryProfile GenerateMemoryProfile(const XPlane* host_trace) {
     ActivityMetadata metadata;
     std::string memory_id;
     event.ForEachStat([&](const XStatVisitor& stat) {
-      if (stat.Type() == StatType::kIndexOnHost ||
-          stat.Type() == StatType::kDeviceOrdinal) {
-        memory_id = absl::StrFormat("%d", stat.IntValue());
-      } else if (stat.Type() == StatType::kAllocatorName) {
-        memory_id = stat.ToString();
-      } else if (stat.Type() == StatType::kBytesReserved) {
-        stats.bytes_reserved = stat.IntValue();
-      } else if (stat.Type() == StatType::kBytesAllocated) {
-        stats.bytes_allocated = stat.IntValue();
-      } else if (stat.Type() == StatType::kBytesAvailable) {
-        stats.bytes_available = stat.IntValue();
-      } else if (stat.Type() == StatType::kFragmentation) {
-        stats.fragmentation = stat.DoubleValue();
-      } else if (stat.Type() == StatType::kPeakBytesInUse) {
-        stats.peak_bytes_in_use = stat.IntValue();
-      } else if (stat.Type() == StatType::kRequestedBytes) {
-        metadata.requested_bytes = stat.IntValue();
-      } else if (stat.Type() == StatType::kAllocationBytes) {
-        metadata.allocation_bytes = stat.IntValue();
-      } else if (stat.Type() == StatType::kAddress) {
-        metadata.address = stat.IntValue();
-      } else if (stat.Type() == StatType::kTfOp) {
-        metadata.tf_op_name = stat.StrOrRefValue();
-      } else if (stat.Type() == StatType::kStepId) {
-        metadata.step_id = stat.IntValue();
-        if (metadata.step_id != 0) (*step_count)[metadata.step_id]++;
-      } else if (stat.Type() == StatType::kRegionType) {
-        metadata.region_type = stat.StrOrRefValue();
-      } else if (stat.Type() == StatType::kDataType) {
-        metadata.data_type = stat.IntValue();
-      } else if (stat.Type() == StatType::kTensorShapes) {
-        metadata.tensor_shape = stat.StrOrRefValue();
+      if (!stat.Type().has_value()) return;
+      switch (stat.Type().value()) {
+        case StatType::kIndexOnHost:
+        case StatType::kDeviceOrdinal:
+          memory_id = absl::StrFormat("%d", stat.IntValue());
+          break;
+        case StatType::kAllocatorName:
+          memory_id = std::string(stat.StrOrRefValue());
+          break;
+        case StatType::kBytesReserved:
+          stats.bytes_reserved = stat.IntValue();
+          break;
+        case StatType::kBytesAllocated:
+          stats.bytes_allocated = stat.IntValue();
+          break;
+        case StatType::kBytesAvailable:
+          stats.bytes_available = stat.IntValue();
+          break;
+        case StatType::kFragmentation:
+          stats.fragmentation = stat.DoubleValue();
+          break;
+        case StatType::kPeakBytesInUse:
+          stats.peak_bytes_in_use = stat.IntValue();
+          break;
+        case StatType::kRequestedBytes:
+          metadata.requested_bytes = stat.IntValue();
+          break;
+        case StatType::kAllocationBytes:
+          metadata.allocation_bytes = stat.IntValue();
+          break;
+        case StatType::kAddress:
+          metadata.address = stat.IntValue();
+          break;
+        case StatType::kTfOp:
+          metadata.tf_op_name = stat.StrOrRefValue();
+          break;
+        case StatType::kStepId:
+          metadata.step_id = stat.IntValue();
+          if (metadata.step_id != 0) (*step_count)[metadata.step_id]++;
+          break;
+        case StatType::kRegionType:
+          metadata.region_type = stat.StrOrRefValue();
+          break;
+        case StatType::kDataType:
+          metadata.data_type = stat.IntValue();
+          break;
+        case StatType::kTensorShapes:
+          metadata.tensor_shape = stat.StrOrRefValue();
+          break;
       }
     });
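A detail worth noting in this hunk: where the old chain needed `stat.Type() == StatType::kIndexOnHost || stat.Type() == StatType::kDeviceOrdinal`, the switch expresses the shared handling with stacked case labels falling through to one body:

  case StatType::kIndexOnHost:
  case StatType::kDeviceOrdinal:
    memory_id = absl::StrFormat("%d", stat.IntValue());
    break;

The kAllocatorName branch also changes from `stat.ToString()` to `std::string(stat.StrOrRefValue())`, which makes the copy out of the stat's string value explicit.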
@@ -61,12 +61,17 @@ StepEvents ConvertHostThreadsXLineToStepEvents(
     int64 group_id = -1;
     absl::string_view step_name;
     event.ForEachStat([&](const XStatVisitor& stat) {
-      if (stat.Type() == StatType::kCorrelationId) {
-        correlation_id = stat.IntValue();
-      } else if (stat.Type() == StatType::kGroupId) {
-        group_id = stat.IntValue();
-      } else if (stat.Type() == StatType::kStepName) {
-        step_name = stat.StrOrRefValue();
+      if (!stat.Type().has_value()) return;
+      switch (stat.Type().value()) {
+        case StatType::kCorrelationId:
+          correlation_id = stat.IntValue();
+          break;
+        case StatType::kGroupId:
+          group_id = stat.IntValue();
+          break;
+        case StatType::kStepName:
+          step_name = stat.StrOrRefValue();
+          break;
       }
     });
     if (group_id < 0) return;
@@ -126,14 +131,19 @@ StepEvents ConvertDeviceTraceXLineToStepEvents(const XLineVisitor& line) {
   line.ForEachEvent([&](const XEventVisitor& event) {
     int64 correlation_id = -1;
     int64 group_id = -1;
-    absl::string_view tensor_shapes = "";
+    absl::string_view tensor_shapes;
     event.ForEachStat([&](const XStatVisitor& stat) {
-      if (stat.Type() == StatType::kCorrelationId) {
-        correlation_id = stat.IntValue();
-      } else if (stat.Type() == StatType::kGroupId) {
-        group_id = stat.IntValue();
-      } else if (stat.Type() == StatType::kTensorShapes) {
-        tensor_shapes = stat.StrOrRefValue();
+      if (!stat.Type().has_value()) return;
+      switch (stat.Type().value()) {
+        case StatType::kCorrelationId:
+          correlation_id = stat.IntValue();
+          break;
+        case StatType::kGroupId:
+          group_id = stat.IntValue();
+          break;
+        case StatType::kTensorShapes:
+          tensor_shapes = stat.StrOrRefValue();
+          break;
       }
     });
@@ -162,13 +162,18 @@ class TfFunctionExecutions {
   explicit TfFunctionExecutions(const XLineVisitor& line) {
     // Creates points_ and activations_ from line.
     line.ForEachEvent([&](const XEventVisitor& event) {
-      std::string mode = "";
+      absl::string_view mode;
       int64 tracing_count = 0;
       event.ForEachStat([&mode, &tracing_count](const XStatVisitor& stat) {
-        if (stat.Type() == StatType::kTfFunctionCall)
-          mode = std::string(stat.StrOrRefValue());
-        if (stat.Type() == StatType::kTfFunctionTracingCount)
-          tracing_count = stat.IntValue();
+        if (!stat.Type().has_value()) return;
+        switch (stat.Type().value()) {
+          case StatType::kTfFunctionCall:
+            mode = stat.StrOrRefValue();
+            break;
+          case StatType::kTfFunctionTracingCount:
+            tracing_count = stat.IntValue();
+            break;
+        }
       });
       if (mode.empty()) return;
@@ -87,35 +87,34 @@ grappler::DeviceInfo TfOpRoofLineCostEstimator::GetDeviceInfo(
 TfOpRoofLineCostEstimator::OpRoofLineStats TfOpRoofLineCostEstimator::Predict(
     const XEventVisitor& event) {
   TfOp tf_op;
-  bool has_shape_stats = false;
-  std::vector<std::string> input_tensors;
+  absl::string_view tensor_shapes;
   event.ForEachStat([&](const XStatVisitor& stat) {
-    if (stat.Type() == StatType::kLevel0) {
-      tf_op = ParseTfOpFullname(stat.StrOrRefValue());
-    } else if (stat.Type() == StatType::kTensorShapes) {
-      has_shape_stats = true;
-      auto shapes_stats = stat.StrOrRefValue();
-      absl::ConsumePrefix(&shapes_stats, "(");
-      absl::ConsumeSuffix(&shapes_stats, ")");
-      input_tensors = absl::StrSplit(shapes_stats, ';');
+    if (!stat.Type().has_value()) return;
+    switch (stat.Type().value()) {
+      case StatType::kLevel0:
+        tf_op = ParseTfOpFullname(stat.StrOrRefValue());
+        break;
+      case StatType::kTensorShapes:
+        tensor_shapes = stat.StrOrRefValue();
+        break;
     }
   });

   // Return empty OpRoofLineStats if shape is not traced or this is not a tf op.
-  if (tf_op.type.empty() || !has_shape_stats) {
+  if (tf_op.type.empty() || tensor_shapes.empty()) {
     return {0ULL, 0ULL, /*inaccurate=*/true};
   }

   grappler::OpContext op_context;
   op_context.name = std::string(tf_op.type);
   op_context.op_info.set_op(op_context.name);
-  for (const auto& tensor : input_tensors) {
+  for (absl::string_view tensor : ParseTensorShapes(tensor_shapes)) {
     *op_context.op_info.add_inputs() = GetTensorProperties(tensor);
   }
   grappler::Costs costs = PredictCosts(op_context);
   if (costs.inaccurate) unsupported_ops_.insert(std::string(tf_op.type));

-  VLOG(1) << tf_op.type << "[" << absl::StrJoin(input_tensors, ",") << "]"
+  VLOG(1) << tf_op.type << tensor_shapes
           << " flops:" << costs.compute_time.count()
           << " bytes:" << costs.memory_time.count();
@@ -23,6 +23,7 @@ limitations under the License.
 #include "absl/strings/str_cat.h"
 #include "absl/strings/str_split.h"
 #include "absl/strings/string_view.h"
+#include "absl/strings/strip.h"
 #include "tensorflow/core/platform/regexp.h"

 namespace tensorflow {
@@ -104,5 +105,12 @@ std::string TfOpEventName(absl::string_view tf_op_fullname) {
   return TfOpEventName(ParseTfOpFullname(tf_op_fullname));
 }

+std::vector<absl::string_view> ParseTensorShapes(
+    absl::string_view tensor_shapes) {
+  absl::ConsumePrefix(&tensor_shapes, "(");
+  absl::ConsumeSuffix(&tensor_shapes, ")");
+  return absl::StrSplit(tensor_shapes, ';');
+}
+
 }  // namespace profiler
 }  // namespace tensorflow
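For illustration, a sketch of how the new helper behaves (the shape strings are made-up; note that the returned `absl::string_view`s alias the input, so the argument must outlive the result):

  // "(1x2x3;4x5)" -> {"1x2x3", "4x5"}
  std::vector<absl::string_view> shapes = ParseTensorShapes("(1x2x3;4x5)");
  for (absl::string_view shape : shapes) VLOG(1) << shape;

Input without the surrounding parentheses is also accepted: `absl::ConsumePrefix` and `absl::ConsumeSuffix` leave the string unchanged when the affix is absent.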
@@ -81,6 +81,12 @@ inline bool IsMemcpyHToDOp(absl::string_view tf_op_type) {
 inline bool IsMemcpyDToHOp(absl::string_view tf_op_type) {
   return tf_op_type == kMemcpyDToHOp;
 }

+// Splits a string of tensor shapes in "(shape1;shape2;...)" format, i.e.,
+// delimited by '(' and ')' and separated by ';', into the individual shapes.
+std::vector<absl::string_view> ParseTensorShapes(
+    absl::string_view tensor_shapes);
+
 }  // namespace profiler
 }  // namespace tensorflow