[XLA] Fixes some div-by-zero bugs.

Change: 153795265
This commit is contained in:
A. Unique TensorFlower 2017-04-20 21:32:47 -08:00 committed by TensorFlower Gardener
parent c1bd0fe248
commit b0594e1b82

View File

@ -70,6 +70,7 @@ string HloExecutionProfile::ToString(
string result; string result;
const int64 total_cycles = total_cycles_executed(computation); const int64 total_cycles = total_cycles_executed(computation);
double clock_rate_ghz = device_description.clock_rate_ghz(); double clock_rate_ghz = device_description.clock_rate_ghz();
CHECK_GE(clock_rate_ghz, 1e-9);
const auto cycles_to_microseconds = [&](double cycles) { const auto cycles_to_microseconds = [&](double cycles) {
return cycles / clock_rate_ghz / 1000.0; return cycles / clock_rate_ghz / 1000.0;
@ -80,14 +81,19 @@ string HloExecutionProfile::ToString(
double nsecs = cycles / clock_rate_ghz; double nsecs = cycles / clock_rate_ghz;
string bytes_per_sec; string bytes_per_sec;
string bytes_per_cycle; string bytes_per_cycle;
if (bytes_accessed >= 0) { if (cycles <= 0 || bytes_accessed < 0) {
bytes_per_sec = "<unknown>";
bytes_per_cycle = "<unknown>";
} else {
bytes_per_sec = tensorflow::strings::HumanReadableNumBytes( bytes_per_sec = tensorflow::strings::HumanReadableNumBytes(
bytes_accessed / (nsecs / 1e9)); bytes_accessed / (nsecs / 1e9));
bytes_per_cycle = bytes_per_cycle =
tensorflow::strings::HumanReadableNumBytes(bytes_accessed / cycles); tensorflow::strings::HumanReadableNumBytes(bytes_accessed / cycles);
} else { }
bytes_per_sec = "<unknown>";
bytes_per_cycle = "<unknown>"; double cycles_percent = 0;
if (total_cycles > 0) {
cycles_percent = cycles / static_cast<double>(total_cycles) * 100;
} }
tensorflow::strings::StrAppend( tensorflow::strings::StrAppend(
@ -97,8 +103,7 @@ string HloExecutionProfile::ToString(
":: " ":: "
"%12s/cycle :: " "%12s/cycle :: "
"%s", "%s",
cycles, cycles / static_cast<double>(total_cycles) * 100, cycles, cycles_percent, cycles_to_microseconds(cycles),
cycles_to_microseconds(cycles),
flops <= 0 ? "<none>" : HumanReadableNumFlops(flops, nsecs).c_str(), flops <= 0 ? "<none>" : HumanReadableNumFlops(flops, nsecs).c_str(),
bytes_per_sec.c_str(), bytes_per_cycle.c_str(), name.c_str())); bytes_per_sec.c_str(), bytes_per_cycle.c_str(), name.c_str()));
}; };
@ -114,26 +119,30 @@ string HloExecutionProfile::ToString(
for (const auto& item : items) { for (const auto& item : items) {
const HloInstruction* hlo = item.first; const HloInstruction* hlo = item.first;
tensorflow::strings::StrAppend(&result, "\n\t"); tensorflow::strings::StrAppend(&result, "\n\t");
int64 flops = hlo == nullptr ? -1 : cost_analysis.flop_count(*hlo); const int64 flops = (hlo == nullptr) ? -1 : cost_analysis.flop_count(*hlo);
int64 bytes_accessed = const int64 bytes_accessed =
hlo == nullptr ? -1 : cost_analysis.bytes_accessed(*hlo); (hlo == nullptr) ? -1 : cost_analysis.bytes_accessed(*hlo);
string display = hlo == nullptr ? "<none>" : hlo->ToString(); const string display = (hlo == nullptr) ? "<none>" : hlo->ToString();
append_item(item.second, flops, bytes_accessed, display); append_item(item.second, flops, bytes_accessed, display);
} }
MetricTableReport table; if (total_cycles <= 0) {
table.SetMetricName("microseconds"); result += "****** 0 total cycles ******\n";
table.SetEntryName("ops"); } else {
table.SetShowCategoryTable(); MetricTableReport table;
for (const auto& item : items) { table.SetMetricName("microseconds");
MetricTableReport::Entry entry; table.SetEntryName("ops");
entry.text = item.first->ToString(); table.SetShowCategoryTable();
entry.short_text = item.first->ToString(/*compact_operands=*/true); for (const auto& item : items) {
entry.category_text = item.first->ToCategory(); MetricTableReport::Entry entry;
entry.metric = cycles_to_microseconds(item.second); entry.text = item.first->ToString();
table.AddEntry(std::move(entry)); entry.short_text = item.first->ToString(/*compact_operands=*/true);
entry.category_text = item.first->ToCategory();
entry.metric = cycles_to_microseconds(item.second);
table.AddEntry(std::move(entry));
}
result += table.MakeReport(cycles_to_microseconds(total_cycles));
} }
result += table.MakeReport(cycles_to_microseconds(total_cycles));
return result; return result;
} }