[XLA] Fixes some div-by-zero bugs.
Change: 153795265
This commit is contained in:
parent
c1bd0fe248
commit
b0594e1b82
@ -70,6 +70,7 @@ string HloExecutionProfile::ToString(
|
|||||||
string result;
|
string result;
|
||||||
const int64 total_cycles = total_cycles_executed(computation);
|
const int64 total_cycles = total_cycles_executed(computation);
|
||||||
double clock_rate_ghz = device_description.clock_rate_ghz();
|
double clock_rate_ghz = device_description.clock_rate_ghz();
|
||||||
|
CHECK_GE(clock_rate_ghz, 1e-9);
|
||||||
|
|
||||||
const auto cycles_to_microseconds = [&](double cycles) {
|
const auto cycles_to_microseconds = [&](double cycles) {
|
||||||
return cycles / clock_rate_ghz / 1000.0;
|
return cycles / clock_rate_ghz / 1000.0;
|
||||||
@ -80,14 +81,19 @@ string HloExecutionProfile::ToString(
|
|||||||
double nsecs = cycles / clock_rate_ghz;
|
double nsecs = cycles / clock_rate_ghz;
|
||||||
string bytes_per_sec;
|
string bytes_per_sec;
|
||||||
string bytes_per_cycle;
|
string bytes_per_cycle;
|
||||||
if (bytes_accessed >= 0) {
|
if (cycles <= 0 || bytes_accessed < 0) {
|
||||||
|
bytes_per_sec = "<unknown>";
|
||||||
|
bytes_per_cycle = "<unknown>";
|
||||||
|
} else {
|
||||||
bytes_per_sec = tensorflow::strings::HumanReadableNumBytes(
|
bytes_per_sec = tensorflow::strings::HumanReadableNumBytes(
|
||||||
bytes_accessed / (nsecs / 1e9));
|
bytes_accessed / (nsecs / 1e9));
|
||||||
bytes_per_cycle =
|
bytes_per_cycle =
|
||||||
tensorflow::strings::HumanReadableNumBytes(bytes_accessed / cycles);
|
tensorflow::strings::HumanReadableNumBytes(bytes_accessed / cycles);
|
||||||
} else {
|
}
|
||||||
bytes_per_sec = "<unknown>";
|
|
||||||
bytes_per_cycle = "<unknown>";
|
double cycles_percent = 0;
|
||||||
|
if (total_cycles > 0) {
|
||||||
|
cycles_percent = cycles / static_cast<double>(total_cycles) * 100;
|
||||||
}
|
}
|
||||||
|
|
||||||
tensorflow::strings::StrAppend(
|
tensorflow::strings::StrAppend(
|
||||||
@ -97,8 +103,7 @@ string HloExecutionProfile::ToString(
|
|||||||
":: "
|
":: "
|
||||||
"%12s/cycle :: "
|
"%12s/cycle :: "
|
||||||
"%s",
|
"%s",
|
||||||
cycles, cycles / static_cast<double>(total_cycles) * 100,
|
cycles, cycles_percent, cycles_to_microseconds(cycles),
|
||||||
cycles_to_microseconds(cycles),
|
|
||||||
flops <= 0 ? "<none>" : HumanReadableNumFlops(flops, nsecs).c_str(),
|
flops <= 0 ? "<none>" : HumanReadableNumFlops(flops, nsecs).c_str(),
|
||||||
bytes_per_sec.c_str(), bytes_per_cycle.c_str(), name.c_str()));
|
bytes_per_sec.c_str(), bytes_per_cycle.c_str(), name.c_str()));
|
||||||
};
|
};
|
||||||
@ -114,26 +119,30 @@ string HloExecutionProfile::ToString(
|
|||||||
for (const auto& item : items) {
|
for (const auto& item : items) {
|
||||||
const HloInstruction* hlo = item.first;
|
const HloInstruction* hlo = item.first;
|
||||||
tensorflow::strings::StrAppend(&result, "\n\t");
|
tensorflow::strings::StrAppend(&result, "\n\t");
|
||||||
int64 flops = hlo == nullptr ? -1 : cost_analysis.flop_count(*hlo);
|
const int64 flops = (hlo == nullptr) ? -1 : cost_analysis.flop_count(*hlo);
|
||||||
int64 bytes_accessed =
|
const int64 bytes_accessed =
|
||||||
hlo == nullptr ? -1 : cost_analysis.bytes_accessed(*hlo);
|
(hlo == nullptr) ? -1 : cost_analysis.bytes_accessed(*hlo);
|
||||||
string display = hlo == nullptr ? "<none>" : hlo->ToString();
|
const string display = (hlo == nullptr) ? "<none>" : hlo->ToString();
|
||||||
append_item(item.second, flops, bytes_accessed, display);
|
append_item(item.second, flops, bytes_accessed, display);
|
||||||
}
|
}
|
||||||
|
|
||||||
MetricTableReport table;
|
if (total_cycles <= 0) {
|
||||||
table.SetMetricName("microseconds");
|
result += "****** 0 total cycles ******\n";
|
||||||
table.SetEntryName("ops");
|
} else {
|
||||||
table.SetShowCategoryTable();
|
MetricTableReport table;
|
||||||
for (const auto& item : items) {
|
table.SetMetricName("microseconds");
|
||||||
MetricTableReport::Entry entry;
|
table.SetEntryName("ops");
|
||||||
entry.text = item.first->ToString();
|
table.SetShowCategoryTable();
|
||||||
entry.short_text = item.first->ToString(/*compact_operands=*/true);
|
for (const auto& item : items) {
|
||||||
entry.category_text = item.first->ToCategory();
|
MetricTableReport::Entry entry;
|
||||||
entry.metric = cycles_to_microseconds(item.second);
|
entry.text = item.first->ToString();
|
||||||
table.AddEntry(std::move(entry));
|
entry.short_text = item.first->ToString(/*compact_operands=*/true);
|
||||||
|
entry.category_text = item.first->ToCategory();
|
||||||
|
entry.metric = cycles_to_microseconds(item.second);
|
||||||
|
table.AddEntry(std::move(entry));
|
||||||
|
}
|
||||||
|
result += table.MakeReport(cycles_to_microseconds(total_cycles));
|
||||||
}
|
}
|
||||||
result += table.MakeReport(cycles_to_microseconds(total_cycles));
|
|
||||||
|
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user