Use switch statements in lambdas passed to XEventVisitor::ForEachStat.

Also add a utility to parse tensor shapes.

PiperOrigin-RevId: 315393466
Change-Id: I23a33867f132a3a30617315e79911780143e815e
Jose Baiocchi 2020-06-08 18:09:41 -07:00 committed by TensorFlower Gardener
parent 2edb0fe27f
commit 2cceeea264
7 changed files with 125 additions and 75 deletions
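The refactoring applied in each file below follows the same shape; a minimal sketch (the stat types and field names are taken from the diffs, the surrounding variables are illustrative):

event.ForEachStat([&](const XStatVisitor& stat) {
  // Stats without a recognized StatType are now skipped explicitly up front,
  // instead of silently falling through an if/else chain.
  if (!stat.Type().has_value()) return;
  switch (stat.Type().value()) {
    case StatType::kGroupId:
      group_id = stat.IntValue();
      break;
    case StatType::kTensorShapes:
      tensor_shapes = stat.StrOrRefValue();
      break;
  }
});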

@@ -49,18 +49,23 @@ KernelStatsDb ConvertDeviceTraceXPlaneToKernelStatsDb(
absl::string_view equation;
event.ForEachStat([&](const tensorflow::profiler::XStatVisitor& stat) {
if (stat.Type() == StatType::kLevel0) {
tf_op_fullname = stat.StrOrRefValue();
} else if (stat.Type() == StatType::kKernelDetails) {
kernel.set_name(event.Name().data(), event.Name().size());
bool using_tensor_cores = IsKernelUsingTensorCore(event.Name());
kernel.set_is_kernel_using_tensor_core(using_tensor_cores);
kernel.set_total_duration_ns(event.DurationNs());
kernel.set_min_duration_ns(event.DurationNs());
kernel.set_max_duration_ns(event.DurationNs());
ParseKernelLaunchParams(stat.StrOrRefValue(), &kernel);
} else if (stat.Type() == StatType::kEquation) {
equation = stat.StrOrRefValue();
if (!stat.Type().has_value()) return;
switch (stat.Type().value()) {
case StatType::kLevel0:
tf_op_fullname = stat.StrOrRefValue();
break;
case StatType::kKernelDetails:
kernel.set_name(event.Name().data(), event.Name().size());
kernel.set_is_kernel_using_tensor_core(
IsKernelUsingTensorCore(event.Name()));
kernel.set_total_duration_ns(event.DurationNs());
kernel.set_min_duration_ns(event.DurationNs());
kernel.set_max_duration_ns(event.DurationNs());
ParseKernelLaunchParams(stat.StrOrRefValue(), &kernel);
break;
case StatType::kEquation:
equation = stat.StrOrRefValue();
break;
}
});

@@ -146,38 +146,55 @@ MemoryProfile GenerateMemoryProfile(const XPlane* host_trace) {
ActivityMetadata metadata;
std::string memory_id;
event.ForEachStat([&](const XStatVisitor& stat) {
if (stat.Type() == StatType::kIndexOnHost ||
stat.Type() == StatType::kDeviceOrdinal) {
memory_id = absl::StrFormat("%d", stat.IntValue());
} else if (stat.Type() == StatType::kAllocatorName) {
memory_id = stat.ToString();
} else if (stat.Type() == StatType::kBytesReserved) {
stats.bytes_reserved = stat.IntValue();
} else if (stat.Type() == StatType::kBytesAllocated) {
stats.bytes_allocated = stat.IntValue();
} else if (stat.Type() == StatType::kBytesAvailable) {
stats.bytes_available = stat.IntValue();
} else if (stat.Type() == StatType::kFragmentation) {
stats.fragmentation = stat.DoubleValue();
} else if (stat.Type() == StatType::kPeakBytesInUse) {
stats.peak_bytes_in_use = stat.IntValue();
} else if (stat.Type() == StatType::kRequestedBytes) {
metadata.requested_bytes = stat.IntValue();
} else if (stat.Type() == StatType::kAllocationBytes) {
metadata.allocation_bytes = stat.IntValue();
} else if (stat.Type() == StatType::kAddress) {
metadata.address = stat.IntValue();
} else if (stat.Type() == StatType::kTfOp) {
metadata.tf_op_name = stat.StrOrRefValue();
} else if (stat.Type() == StatType::kStepId) {
metadata.step_id = stat.IntValue();
if (metadata.step_id != 0) (*step_count)[metadata.step_id]++;
} else if (stat.Type() == StatType::kRegionType) {
metadata.region_type = stat.StrOrRefValue();
} else if (stat.Type() == StatType::kDataType) {
metadata.data_type = stat.IntValue();
} else if (stat.Type() == StatType::kTensorShapes) {
metadata.tensor_shape = stat.StrOrRefValue();
if (!stat.Type().has_value()) return;
switch (stat.Type().value()) {
case StatType::kIndexOnHost:
case StatType::kDeviceOrdinal:
memory_id = absl::StrFormat("%d", stat.IntValue());
break;
case StatType::kAllocatorName:
memory_id = std::string(stat.StrOrRefValue());
break;
case StatType::kBytesReserved:
stats.bytes_reserved = stat.IntValue();
break;
case StatType::kBytesAllocated:
stats.bytes_allocated = stat.IntValue();
break;
case StatType::kBytesAvailable:
stats.bytes_available = stat.IntValue();
break;
case StatType::kFragmentation:
stats.fragmentation = stat.DoubleValue();
break;
case StatType::kPeakBytesInUse:
stats.peak_bytes_in_use = stat.IntValue();
break;
case StatType::kRequestedBytes:
metadata.requested_bytes = stat.IntValue();
break;
case StatType::kAllocationBytes:
metadata.allocation_bytes = stat.IntValue();
break;
case StatType::kAddress:
metadata.address = stat.IntValue();
break;
case StatType::kTfOp:
metadata.tf_op_name = stat.StrOrRefValue();
break;
case StatType::kStepId:
metadata.step_id = stat.IntValue();
if (metadata.step_id != 0) (*step_count)[metadata.step_id]++;
break;
case StatType::kRegionType:
metadata.region_type = stat.StrOrRefValue();
break;
case StatType::kDataType:
metadata.data_type = stat.IntValue();
break;
case StatType::kTensorShapes:
metadata.tensor_shape = stat.StrOrRefValue();
break;
}
});

@@ -61,12 +61,17 @@ StepEvents ConvertHostThreadsXLineToStepEvents(
int64 group_id = -1;
absl::string_view step_name;
event.ForEachStat([&](const XStatVisitor& stat) {
if (stat.Type() == StatType::kCorrelationId) {
correlation_id = stat.IntValue();
} else if (stat.Type() == StatType::kGroupId) {
group_id = stat.IntValue();
} else if (stat.Type() == StatType::kStepName) {
step_name = stat.StrOrRefValue();
if (!stat.Type().has_value()) return;
switch (stat.Type().value()) {
case StatType::kCorrelationId:
correlation_id = stat.IntValue();
break;
case StatType::kGroupId:
group_id = stat.IntValue();
break;
case StatType::kStepName:
step_name = stat.StrOrRefValue();
break;
}
});
if (group_id < 0) return;
@@ -126,14 +131,19 @@ StepEvents ConvertDeviceTraceXLineToStepEvents(const XLineVisitor& line) {
line.ForEachEvent([&](const XEventVisitor& event) {
int64 correlation_id = -1;
int64 group_id = -1;
absl::string_view tensor_shapes = "";
absl::string_view tensor_shapes;
event.ForEachStat([&](const XStatVisitor& stat) {
if (stat.Type() == StatType::kCorrelationId) {
correlation_id = stat.IntValue();
} else if (stat.Type() == StatType::kGroupId) {
group_id = stat.IntValue();
} else if (stat.Type() == StatType::kTensorShapes) {
tensor_shapes = stat.StrOrRefValue();
if (!stat.Type().has_value()) return;
switch (stat.Type().value()) {
case StatType::kCorrelationId:
correlation_id = stat.IntValue();
break;
case StatType::kGroupId:
group_id = stat.IntValue();
break;
case StatType::kTensorShapes:
tensor_shapes = stat.StrOrRefValue();
break;
}
});

@@ -162,13 +162,18 @@ class TfFunctionExecutions {
explicit TfFunctionExecutions(const XLineVisitor& line) {
// Creates points_ and activations_ from line.
line.ForEachEvent([&](const XEventVisitor& event) {
std::string mode = "";
absl::string_view mode;
int64 tracing_count = 0;
event.ForEachStat([&mode, &tracing_count](const XStatVisitor& stat) {
if (stat.Type() == StatType::kTfFunctionCall)
mode = std::string(stat.StrOrRefValue());
if (stat.Type() == StatType::kTfFunctionTracingCount)
tracing_count = stat.IntValue();
if (!stat.Type().has_value()) return;
switch (stat.Type().value()) {
case StatType::kTfFunctionCall:
mode = stat.StrOrRefValue();
break;
case StatType::kTfFunctionTracingCount:
tracing_count = stat.IntValue();
break;
}
});
if (mode.empty()) return;

@@ -87,35 +87,34 @@ grappler::DeviceInfo TfOpRoofLineCostEstimator::GetDeviceInfo(
TfOpRoofLineCostEstimator::OpRoofLineStats TfOpRoofLineCostEstimator::Predict(
const XEventVisitor& event) {
TfOp tf_op;
bool has_shape_stats = false;
std::vector<std::string> input_tensors;
absl::string_view tensor_shapes;
event.ForEachStat([&](const XStatVisitor& stat) {
if (stat.Type() == StatType::kLevel0) {
tf_op = ParseTfOpFullname(stat.StrOrRefValue());
} else if (stat.Type() == StatType::kTensorShapes) {
has_shape_stats = true;
auto shapes_stats = stat.StrOrRefValue();
absl::ConsumePrefix(&shapes_stats, "(");
absl::ConsumeSuffix(&shapes_stats, ")");
input_tensors = absl::StrSplit(shapes_stats, ';');
if (!stat.Type().has_value()) return;
switch (stat.Type().value()) {
case StatType::kLevel0:
tf_op = ParseTfOpFullname(stat.StrOrRefValue());
break;
case StatType::kTensorShapes:
tensor_shapes = stat.StrOrRefValue();
break;
}
});
// Return empty OpRoofLineStats if shape is not traced or this is not a tf op.
if (tf_op.type.empty() || !has_shape_stats) {
if (tf_op.type.empty() || tensor_shapes.empty()) {
return {0ULL, 0ULL, /*inaccurate=*/true};
}
grappler::OpContext op_context;
op_context.name = std::string(tf_op.type);
op_context.op_info.set_op(op_context.name);
for (const auto& tensor : input_tensors) {
for (absl::string_view tensor : ParseTensorShapes(tensor_shapes)) {
*op_context.op_info.add_inputs() = GetTensorProperties(tensor);
}
grappler::Costs costs = PredictCosts(op_context);
if (costs.inaccurate) unsupported_ops_.insert(std::string(tf_op.type));
VLOG(1) << tf_op.type << "[" << absl::StrJoin(input_tensors, ",") << "]"
VLOG(1) << tf_op.type << tensor_shapes
<< " flops:" << costs.compute_time.count()
<< " bytes:" << costs.memory_time.count();

@@ -23,6 +23,7 @@ limitations under the License.
#include "absl/strings/str_cat.h"
#include "absl/strings/str_split.h"
#include "absl/strings/string_view.h"
#include "absl/strings/strip.h"
#include "tensorflow/core/platform/regexp.h"
namespace tensorflow {
@@ -104,5 +105,12 @@ std::string TfOpEventName(absl::string_view tf_op_fullname) {
return TfOpEventName(ParseTfOpFullname(tf_op_fullname));
}
std::vector<absl::string_view> ParseTensorShapes(
absl::string_view tensor_shapes) {
absl::ConsumePrefix(&tensor_shapes, "(");
absl::ConsumeSuffix(&tensor_shapes, ")");
return absl::StrSplit(tensor_shapes, ';');
}
} // namespace profiler
} // namespace tensorflow

@@ -81,6 +81,12 @@ inline bool IsMemcpyHToDOp(absl::string_view tf_op_type) {
inline bool IsMemcpyDToHOp(absl::string_view tf_op_type) {
return tf_op_type == kMemcpyDToHOp;
}
// Splits a string of tensor shapes in "(shape1;shape2;...)" format, i.e.,
// delimited by '(' and ')' and separated by ';', into the individual shapes.
std::vector<absl::string_view> ParseTensorShapes(
absl::string_view tensor_shapes);
} // namespace profiler
} // namespace tensorflow
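
A hedged usage sketch of the new ParseTensorShapes helper, based on the header comment and implementation above (the shape string is a made-up example):

// Input in the "(shape1;shape2;...)" format described above; shapes are hypothetical.
std::vector<absl::string_view> shapes =
    tensorflow::profiler::ParseTensorShapes("(1x28x28x3;64)");
// shapes now holds {"1x28x28x3", "64"}: the surrounding parentheses are
// stripped and the remainder is split on ';'.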