[TF:XLA] Replace most of HloProfilePrinter by a protocol buffer

This change replaces the meat of HloProfilePrinter with a protobuf
HloProfilePrinterData.  The original plan was to serialize HloProfilePrinter
into C++ source code and put that in a .cc file along with the string for the
xla::ProgramShape.  However, since we now directly serialize xla::ProgramShape
into a .o file, for consistency I think we should do the same thing for
HloProfilePrinter (instead of adding yet another output file to tfcompile).

The change itself is fairly simple; it is large mostly because of the mass
renaming I had to do.

PiperOrigin-RevId: 183158192
This commit is contained in:
Sanjoy Das 2018-01-24 16:10:12 -08:00 committed by TensorFlower Gardener
parent 7bf8ccdb4e
commit ffa63e57bd
23 changed files with 205 additions and 212 deletions

View File

@ -29,7 +29,7 @@ XlaCompiledCpuFunction::XlaCompiledCpuFunction(const StaticData& static_data,
arg_names_(static_data.arg_names),
result_names_(static_data.result_names),
program_shape_(static_data.program_shape),
hlo_profile_printer_(static_data.hlo_profile_printer) {
hlo_profile_printer_data_(static_data.hlo_profile_printer_data) {
// Allocate arg and temp buffers.
if (alloc_mode == AllocMode::ARGS_RESULTS_PROFILES_AND_TEMPS) {
alloc_args_ = tensorflow::tfcompile::runtime::MallocContiguousBuffers(

View File

@ -26,7 +26,7 @@ limitations under the License.
// never use this functionality.
namespace xla {
class ProgramShape;
class HloProfilePrinter;
class HloProfilePrinterData;
}
namespace tensorflow {
@ -77,12 +77,14 @@ class XlaCompiledCpuFunction {
// [Optional] Arg and result shapes.
const xla::ProgramShape* program_shape = nullptr;
// [Optional] Profile printer. Null if profiling is disabled.
const xla::HloProfilePrinter* hlo_profile_printer = nullptr;
// [Optional] Profile printer data. Null if profiling is disabled.
const xla::HloProfilePrinterData* hlo_profile_printer_data = nullptr;
// [Optional] The number of profile counters expected in the profile counter
// buffer by the generated code and hlo_profile_printer_data. 0 if profiling is
// disabled.
// disabled. This information is already present in
// hlo_profile_printer_data but xla::HloProfilePrinterData is forward
// declared so we don't have access to that information here.
int64 profile_counters_size = 0;
};
@ -205,10 +207,12 @@ class XlaCompiledCpuFunction {
// program shape isn't available.
const xla::ProgramShape* ProgramShape() const { return program_shape_; }
bool hlo_profiling_enabled() const { return hlo_profile_printer_ != nullptr; }
const xla::HloProfilePrinter& hlo_profile_printer() const {
bool hlo_profiling_enabled() const {
return hlo_profile_printer_data_ != nullptr;
}
const xla::HloProfilePrinterData& hlo_profile_printer_data() const {
assert(hlo_profiling_enabled());
return *hlo_profile_printer_;
return *hlo_profile_printer_data_;
}
private:
@ -234,7 +238,7 @@ class XlaCompiledCpuFunction {
const char** arg_names_ = nullptr;
const char** result_names_ = nullptr;
const xla::ProgramShape* program_shape_ = nullptr;
const xla::HloProfilePrinter* hlo_profile_printer_ = nullptr;
const xla::HloProfilePrinterData* hlo_profile_printer_data_ = nullptr;
};
} // namespace tensorflow

View File

@ -182,10 +182,10 @@ XlaJitCompiledCpuFunction::Compile(
jit->static_data_.program_shape = jit->program_shape_.get();
if (cpu_executable->hlo_profiling_enabled()) {
jit->static_data_.hlo_profile_printer =
&cpu_executable->hlo_profile_printer();
jit->static_data_.hlo_profile_printer_data =
&cpu_executable->hlo_profile_printer_data();
jit->static_data_.profile_counters_size =
cpu_executable->hlo_profile_printer().profile_counters_size();
cpu_executable->hlo_profile_printer_data().profile_counters_size();
}
return std::move(jit_unique_ptr);

View File

@ -29,6 +29,11 @@ xla_proto_library(
deps = ["//tensorflow/compiler/xla:xla_data_proto"],
)
# Protobuf describing how to pretty-print a profile counter array gathered for
# an HloModule (message HloProfilePrinterData). Depended on by
# :hlo_profile_printer.
xla_proto_library(
name = "hlo_profile_printer_data",
srcs = ["hlo_profile_printer_data.proto"],
)
# Filegroup used to collect source files for dependency checking.
filegroup(
name = "c_srcs",
@ -2267,6 +2272,7 @@ cc_library(
srcs = ["hlo_profile_printer.cc"],
hdrs = ["hlo_profile_printer.h"],
deps = [
":hlo_profile_printer_data",
":human_readable_profile_builder",
"//tensorflow/compiler/xla:types",
],

View File

@ -485,7 +485,7 @@ StatusOr<std::unique_ptr<Executable>> CpuCompiler::RunBackend(
std::unordered_map<const HloInstruction*, int64> instruction_to_profile_idx;
std::unordered_map<const HloComputation*, int64> computation_to_profile_idx;
std::unique_ptr<HloProfileIndexMap> hlo_profile_index_map;
std::unique_ptr<HloProfilePrinter> hlo_profile_printer;
std::unique_ptr<HloProfilePrinterData> hlo_profile_printer_data;
if (module->config().hlo_profiling_enabled()) {
hlo_profile_index_map = MakeUnique<HloProfileIndexMap>(*module);
@ -505,8 +505,8 @@ StatusOr<std::unique_ptr<Executable>> CpuCompiler::RunBackend(
HloCostAnalysis cost_analysis(shape_size_bytes);
TF_RETURN_IF_ERROR(entry_computation->Accept(&cost_analysis));
hlo_profile_printer =
CreateHloProfilePrinter(*hlo_profile_index_map, cost_analysis);
hlo_profile_printer_data =
CreateHloProfilePrinterData(*hlo_profile_index_map, cost_analysis);
computation_to_profile_idx =
hlo_profile_index_map->computation_to_profile_idx();
}
@ -619,7 +619,7 @@ StatusOr<std::unique_ptr<Executable>> CpuCompiler::RunBackend(
cpu_executable.reset(new ParallelCpuExecutable(
std::move(jit), std::move(assignment), std::move(module),
std::move(function_names), std::move(aligned_constants),
std::move(hlo_profile_printer), std::move(hlo_profile_index_map)));
std::move(hlo_profile_printer_data), std::move(hlo_profile_index_map)));
if (embed_ir_in_executable) {
static_cast<CpuExecutable&>(*cpu_executable)
@ -698,7 +698,7 @@ StatusOr<std::unique_ptr<Executable>> CpuCompiler::RunBackend(
jit->AddModule(std::move(llvm_module));
cpu_executable.reset(new CpuExecutable(
std::move(jit), std::move(assignment), std::move(module), function_name,
std::move(hlo_profile_printer), std::move(hlo_profile_index_map)));
std::move(hlo_profile_printer_data), std::move(hlo_profile_index_map)));
if (embed_ir_in_executable) {
static_cast<CpuExecutable&>(*cpu_executable)

View File

@ -55,9 +55,9 @@ CpuExecutable::CpuExecutable(
std::unique_ptr<const BufferAssignment> assignment,
std::unique_ptr<const HloModule> hlo_module,
const string& entry_function_name,
std::unique_ptr<HloProfilePrinter> hlo_profile_printer,
std::unique_ptr<HloProfilePrinterData> hlo_profile_printer_data,
std::unique_ptr<HloProfileIndexMap> hlo_profile_index_map)
: Executable(std::move(hlo_module), std::move(hlo_profile_printer),
: Executable(std::move(hlo_module), std::move(hlo_profile_printer_data),
std::move(hlo_profile_index_map)),
jit_(std::move(jit)),
assignment_(std::move(assignment)) {

View File

@ -51,7 +51,7 @@ class CpuExecutable : public Executable {
std::unique_ptr<const BufferAssignment> assignment,
std::unique_ptr<const HloModule> hlo_module,
const string& entry_function_name,
std::unique_ptr<HloProfilePrinter> hlo_profile_printer,
std::unique_ptr<HloProfilePrinterData> hlo_profile_printer_data,
std::unique_ptr<HloProfileIndexMap> hlo_profile_index_map);
~CpuExecutable() override {}

View File

@ -61,9 +61,9 @@ ParallelCpuExecutable::ParallelCpuExecutable(
std::unique_ptr<const HloInstructionMap<string>> function_names,
std::unordered_map<const HloInstruction*, std::unique_ptr<unsigned char[]>>
aligned_constants,
std::unique_ptr<HloProfilePrinter> hlo_profile_printer,
std::unique_ptr<HloProfilePrinterData> hlo_profile_printer_data,
std::unique_ptr<HloProfileIndexMap> hlo_profile_index_map)
: Executable(std::move(hlo_module), std::move(hlo_profile_printer),
: Executable(std::move(hlo_module), std::move(hlo_profile_printer_data),
std::move(hlo_profile_index_map)),
jit_(std::move(jit)),
assignment_(std::move(assignment)),

View File

@ -55,7 +55,7 @@ class ParallelCpuExecutable : public Executable {
std::unordered_map<const HloInstruction*,
std::unique_ptr<unsigned char[]>>
aligned_constants,
std::unique_ptr<HloProfilePrinter> hlo_profile_printer,
std::unique_ptr<HloProfilePrinterData> hlo_profile_printer_data,
std::unique_ptr<HloProfileIndexMap> hlo_profile_index_map);
~ParallelCpuExecutable() override {}

View File

@ -73,7 +73,7 @@ StatusOr<std::unique_ptr<ShapedBuffer>> Executable::ExecuteOnStreamWrapper(
std::unique_ptr<HloExecutionProfile> profile_ptr =
module_config().debug_options().xla_hlo_profile() &&
hlo_profiling_enabled()
? MakeUnique<HloExecutionProfile>(&hlo_profile_printer(),
? MakeUnique<HloExecutionProfile>(&hlo_profile_printer_data(),
&hlo_profile_index_map())
: nullptr;

View File

@ -44,13 +44,14 @@ namespace xla {
// interface that is used for launching compiled programs across platforms.
class Executable {
public:
explicit Executable(std::unique_ptr<const HloModule> hlo_module,
std::unique_ptr<HloProfilePrinter> hlo_profile_printer,
explicit Executable(
std::unique_ptr<const HloModule> hlo_module,
std::unique_ptr<HloProfilePrinterData> hlo_profile_printer_data,
std::unique_ptr<HloProfileIndexMap> hlo_profile_index_map)
: hlo_module_(std::move(hlo_module)),
hlo_profile_printer_(std::move(hlo_profile_printer)),
hlo_profile_printer_data_(std::move(hlo_profile_printer_data)),
hlo_profile_index_map_(std::move(hlo_profile_index_map)) {
CHECK_EQ(hlo_profile_printer_.get() == nullptr,
CHECK_EQ(hlo_profile_printer_data_.get() == nullptr,
hlo_profile_index_map_.get() == nullptr);
}
virtual ~Executable() {}
@ -116,9 +117,9 @@ class Executable {
"Equality test on this executable is not implemented.");
}
const HloProfilePrinter& hlo_profile_printer() const {
const HloProfilePrinterData& hlo_profile_printer_data() const {
CHECK(hlo_profiling_enabled());
return *hlo_profile_printer_;
return *hlo_profile_printer_data_;
}
const HloProfileIndexMap& hlo_profile_index_map() const {
@ -129,7 +130,9 @@ class Executable {
// Returns whether this executable was compiled with HLO profilings support
// enabled. If not, the caller should not expect an hlo_execution_profile
// passed to ExecuteOnStream above to be populated during execution.
bool hlo_profiling_enabled() const { return hlo_profile_printer_ != nullptr; }
bool hlo_profiling_enabled() const {
return hlo_profile_printer_data_ != nullptr;
}
const HloModule& module() const { return *hlo_module_; }
@ -179,7 +182,7 @@ class Executable {
// execution.
int64 execution_count_ = 0;
std::unique_ptr<HloProfilePrinter> hlo_profile_printer_;
std::unique_ptr<HloProfilePrinterData> hlo_profile_printer_data_;
std::unique_ptr<HloProfileIndexMap> hlo_profile_index_map_;
};

View File

@ -593,14 +593,14 @@ StatusOr<std::unique_ptr<Executable>> GpuCompiler::RunBackend(
XLA_VLOG_LINES(2, thunk_schedule->ToString());
std::unique_ptr<HloProfileIndexMap> profile_index_map;
std::unique_ptr<HloProfilePrinter> profile_printer;
std::unique_ptr<HloProfilePrinterData> profile_printer;
if (module->config().hlo_profiling_enabled()) {
HloCostAnalysis cost_analysis(ShapeSizeBytesFunction());
TF_RETURN_IF_ERROR(module->entry_computation()->Accept(&cost_analysis));
profile_index_map = MakeUnique<HloProfileIndexMap>(*module);
profile_printer =
CreateHloProfilePrinter(*profile_index_map, cost_analysis);
CreateHloProfilePrinterData(*profile_index_map, cost_analysis);
}
auto* gpu_executable = new GpuExecutable(

View File

@ -116,9 +116,9 @@ GpuExecutable::GpuExecutable(
std::unique_ptr<const ThunkSchedule> thunk_schedule,
std::unique_ptr<const HloModule> hlo_module,
std::unique_ptr<const BufferAssignment> assignment,
std::unique_ptr<HloProfilePrinter> hlo_profile_printer,
std::unique_ptr<HloProfilePrinterData> hlo_profile_printer_data,
std::unique_ptr<HloProfileIndexMap> hlo_profile_index_map)
: Executable(std::move(hlo_module), std::move(hlo_profile_printer),
: Executable(std::move(hlo_module), std::move(hlo_profile_printer_data),
std::move(hlo_profile_index_map)),
ptx_(ptx),
cubin_(cubin),

View File

@ -54,7 +54,7 @@ class GpuExecutable : public Executable {
std::unique_ptr<const ThunkSchedule> thunk_schedule,
std::unique_ptr<const HloModule> hlo_module,
std::unique_ptr<const BufferAssignment> assignment,
std::unique_ptr<HloProfilePrinter> hlo_profile_printer,
std::unique_ptr<HloProfilePrinterData> hlo_profile_printer_data,
std::unique_ptr<HloProfileIndexMap> hlo_profile_index_map);
// This should be called after set_ir_module_string.

View File

@ -40,83 +40,75 @@ HloProfileIndexMap::HloProfileIndexMap(const HloModule& module) {
}
}
std::unique_ptr<HloProfilePrinter> CreateHloProfilePrinter(
std::unique_ptr<HloProfilePrinterData> CreateHloProfilePrinterData(
const HloProfileIndexMap& hlo_profile_index_map,
const HloCostAnalysis& cost_analysis) {
using HloComputationInfo = HloProfilePrinter::HloComputationInfo;
using HloInstructionInfo = HloProfilePrinter::HloInstructionInfo;
using HloComputationInfo = HloProfilePrinterData::HloComputationInfo;
using HloInstructionInfo = HloProfilePrinterData::HloInstructionInfo;
HloComputationInfo* computation_infos =
new HloComputationInfo[hlo_profile_index_map.computation_count()];
size_t profile_counters_size = hlo_profile_index_map.total_count();
// There are two "indices" in play here. The first one is the index of the
// HloComputationInfo or HloInstructionInfo in the array that contains said
// HloComputationInfo or HloInstructionInfo. The second index is the index of
// the HloComputationInfo or HloInstructionInfo in the profile counters array,
// as decided by hlo_profile_index_map. The latter index is always referred
// to as "profile_index".
std::unique_ptr<HloProfilePrinterData> profile_printer_data =
MakeUnique<HloProfilePrinterData>();
profile_printer_data->set_profile_counters_size(profile_counters_size);
profile_printer_data->mutable_computation_infos()->Reserve(
hlo_profile_index_map.computation_count());
size_t computation_index_in_static_data = 0;
size_t max_profile_index = hlo_profile_index_map.total_count();
for (const auto& pair : hlo_profile_index_map.computation_to_profile_idx()) {
CHECK_LT(pair.second, max_profile_index);
const auto& computation_to_profile_idx_map =
hlo_profile_index_map.computation_to_profile_idx();
// computation_to_profile_idx_map's order is not deterministic so create a
// deterministic computation_and_profile_idx_list so that we end up with a
// deterministic HloProfilePrinterData protobuf.
std::vector<std::pair<const HloComputation*, int64>>
computation_and_profile_idx_list(computation_to_profile_idx_map.begin(),
computation_to_profile_idx_map.end());
// The profile indices were computed deterministically in
// HloProfileIndexMap::HloProfileIndexMap.
c_sort(computation_and_profile_idx_list,
[](const std::pair<const HloComputation*, int64>& left,
const std::pair<const HloComputation*, int64>& right) {
return left.second < right.second;
});
for (const auto& pair : computation_and_profile_idx_list) {
CHECK_LT(pair.second, profile_counters_size);
const HloComputation* computation = pair.first;
size_t current_computation_index = computation_index_in_static_data++;
HloComputationInfo* computation_info =
&computation_infos[current_computation_index];
profile_printer_data->add_computation_infos();
computation_info->name = strdup(computation->name().c_str());
computation_info->profile_index = pair.second;
computation_info->instructions =
new HloInstructionInfo[computation->instruction_count()];
computation_info->instructions_size = computation->instruction_count();
computation_info->set_name(computation->name());
computation_info->set_profile_index(pair.second);
computation_info->mutable_instruction_infos()->Reserve(
computation->instruction_count());
size_t instruction_index_in_static_data = 0;
for (const HloInstruction* hlo : computation->instructions()) {
HloProfilePrinter::HloInstructionInfo* instruction_info =
&computation_info->instructions[instruction_index_in_static_data++];
instruction_info->long_name = strdup(hlo->ToString().c_str());
instruction_info->short_name = strdup(
hlo->ToString(HloPrintOptions().set_compact_operands(true)).c_str());
instruction_info->category = strdup(hlo->ToCategory().c_str());
instruction_info->flop_count = cost_analysis.flop_count(*hlo);
instruction_info->transcendental_count =
cost_analysis.transcendental_count(*hlo);
instruction_info->bytes_accessed = cost_analysis.bytes_accessed(*hlo);
instruction_info->optimal_seconds = cost_analysis.optimal_seconds(*hlo);
instruction_info->profile_index =
hlo_profile_index_map.GetProfileIndexFor(*hlo);
CHECK_LT(instruction_info->profile_index, max_profile_index);
HloInstructionInfo* instruction_info =
computation_info->add_instruction_infos();
instruction_info->set_long_name(hlo->ToString());
instruction_info->set_short_name(
hlo->ToString(HloPrintOptions().set_compact_operands(true)));
instruction_info->set_category(hlo->ToCategory());
instruction_info->set_flop_count(cost_analysis.flop_count(*hlo));
instruction_info->set_transcendental_count(
cost_analysis.transcendental_count(*hlo));
instruction_info->set_bytes_accessed(cost_analysis.bytes_accessed(*hlo));
instruction_info->set_optimal_seconds(
cost_analysis.optimal_seconds(*hlo));
instruction_info->set_profile_index(
hlo_profile_index_map.GetProfileIndexFor(*hlo));
}
}
auto deleter = [](HloProfilePrinter::HloComputationInfo* computation_infos,
int64 computation_infos_size) {
for (int64 i = 0; i < computation_infos_size; i++) {
HloInstructionInfo* instruction_infos = computation_infos[i].instructions;
for (int64 j = 0; j < computation_infos[i].instructions_size; j++) {
// We can't make instruction_infos[j].long_name etc. non-const pointers
// since they may point into static storage, so we have a const_cast
// here.
free(const_cast<char*>(instruction_infos[j].long_name));
free(const_cast<char*>(instruction_infos[j].short_name));
free(const_cast<char*>(instruction_infos[j].category));
}
delete[] instruction_infos;
free(const_cast<char*>(computation_infos[i].name));
}
delete[] computation_infos;
};
return MakeUnique<HloProfilePrinter>(
computation_infos, hlo_profile_index_map.computation_count(),
/*profile_counters_size=*/max_profile_index, deleter);
return profile_printer_data;
}
HloExecutionProfile::HloExecutionProfile(
const HloProfilePrinter* hlo_profile_printer,
const HloProfilePrinterData* hlo_profile_printer_data,
const HloProfileIndexMap* hlo_profile_index_map)
: hlo_profile_printer_(*hlo_profile_printer),
: hlo_profile_printer_data_(*hlo_profile_printer_data),
hlo_profile_index_map_(*hlo_profile_index_map),
profile_counters_(
/*count*/ hlo_profile_index_map_.total_count(),

View File

@ -77,8 +77,8 @@ class HloProfileIndexMap {
std::unordered_map<const HloComputation*, int64> computation_to_profile_idx_;
};
// Create an instance of `HloProfilePrinter` that owns its memory.
std::unique_ptr<HloProfilePrinter> CreateHloProfilePrinter(
// Create an instance of `HloProfilePrinterData`.
std::unique_ptr<HloProfilePrinterData> CreateHloProfilePrinterData(
const HloProfileIndexMap& hlo_profile_index_map,
const HloCostAnalysis& cost_analysis);
@ -90,7 +90,7 @@ class HloExecutionProfile {
public:
using DeviceDescription = perftools::gputools::DeviceDescription;
HloExecutionProfile(const HloProfilePrinter* hlo_profile_printer,
HloExecutionProfile(const HloProfilePrinterData* hlo_profile_printer_data,
const HloProfileIndexMap* hlo_profile_index_map);
// Record how many cycles this HLO took to execute.
@ -117,10 +117,9 @@ class HloExecutionProfile {
// debugging; e.g. emits cycle counts, execution time at the nominal device
// frequency, and the effective throughput given the provided cost_analysis
// for the operations in a given computation. Returns an empty string if it
// wasn't possible to generate a printable version. cost_analysis should be a
// clean analysis that can be used to visit the computation.
// wasn't possible to generate a printable version.
string ToString(const DeviceDescription& device_description) const {
return hlo_profile_printer_.ToString(profile_counters_.data(),
return PrintHloProfile(hlo_profile_printer_data_, profile_counters_.data(),
device_description.clock_rate_ghz());
}
@ -130,7 +129,7 @@ class HloExecutionProfile {
}
private:
const HloProfilePrinter& hlo_profile_printer_;
const HloProfilePrinterData& hlo_profile_printer_data_;
const HloProfileIndexMap& hlo_profile_index_map_;
// Stores per-Hlo profile counters. This is the only thing that changes when

View File

@ -73,8 +73,8 @@ TEST_F(HloExecutionProfileTest, Basic) {
HloCostAnalysis cost_analysis(shape_size_function);
HloProfileIndexMap profile_index_map(*hlo_module);
std::unique_ptr<HloProfilePrinter> profile_printer =
CreateHloProfilePrinter(profile_index_map, cost_analysis);
std::unique_ptr<HloProfilePrinterData> profile_printer =
CreateHloProfilePrinterData(profile_index_map, cost_analysis);
HloExecutionProfile execution_profile(profile_printer.get(),
&profile_index_map);

View File

@ -18,20 +18,20 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/human_readable_profile_builder.h"
namespace xla {
string HloProfilePrinter::ToString(const int64* counters,
double clock_rate_ghz) const {
string PrintHloProfile(const HloProfilePrinterData& hlo_profile_printer_data,
const int64* counters, double clock_rate_ghz) {
using HloComputationInfo = HloProfilePrinterData::HloComputationInfo;
using HloInstructionInfo = HloProfilePrinterData::HloInstructionInfo;
string result;
for (int computation_idx = 0; computation_idx < computation_infos_size_;
computation_idx++) {
const HloComputationInfo& computation = computation_infos_[computation_idx];
const HloInstructionInfo* instructions_begin = computation.instructions;
const HloInstructionInfo* instructions_end =
computation.instructions + computation.instructions_size;
for (const HloComputationInfo& computation_info :
hlo_profile_printer_data.computation_infos()) {
const auto& instruction_infos = computation_info.instruction_infos();
bool any_instruction_profiled =
std::any_of(instructions_begin, instructions_end,
std::any_of(instruction_infos.begin(), instruction_infos.end(),
[&](const HloInstructionInfo& instruction_info) {
return counters[instruction_info.profile_index] != 0;
return counters[instruction_info.profile_index()] != 0;
});
if (!any_instruction_profiled) {
@ -41,16 +41,19 @@ string HloProfilePrinter::ToString(const int64* counters,
// Once we start using this in AOT for real, we will probably need a more
// minimal version of HumanReadableProfileBuilder.
HumanReadableProfileBuilder builder(
computation.name, counters[computation.profile_index], clock_rate_ghz);
computation_info.name(), counters[computation_info.profile_index()],
clock_rate_ghz);
for (const auto* instruction = instructions_begin;
instruction != instructions_end; instruction++) {
for (const auto& instruction_info : instruction_infos) {
builder.AddOp(
/*op_name=*/instruction->long_name,
/*short_name=*/instruction->short_name, instruction->category,
counters[instruction->profile_index], instruction->flop_count,
instruction->transcendental_count, instruction->bytes_accessed,
instruction->optimal_seconds);
/*op_name=*/instruction_info.long_name(),
/*short_name=*/instruction_info.short_name(),
instruction_info.category(),
counters[instruction_info.profile_index()],
instruction_info.flop_count(),
instruction_info.transcendental_count(),
instruction_info.bytes_accessed(),
instruction_info.optimal_seconds());
}
result += builder.ToString();
@ -58,10 +61,4 @@ string HloProfilePrinter::ToString(const int64* counters,
return result;
}
HloProfilePrinter::~HloProfilePrinter() {
if (deleter_) {
deleter_(computation_infos_, computation_infos_size_);
}
}
} // namespace xla

View File

@ -20,84 +20,13 @@ limitations under the License.
#include <string>
#include <vector>
#include "tensorflow/compiler/xla/service/hlo_profile_printer_data.pb.h"
#include "tensorflow/compiler/xla/types.h"
namespace xla {
// Instances of this class can pretty-print profile counters gathered from
// running an XLA computation without having access to the backing module.
class HloProfilePrinter {
public:
// Holds meta information about an HloInstruction.
//
// The pointer-typed fields can be owning or non-owning -- this decision is
// manifested as the deleter_ function in the containing HloProfilePrinter.
struct HloInstructionInfo {
// Textual information for pretty printing.
const char* long_name;
const char* short_name;
const char* category;
// Metrics computed by HloCostAnalysis.
float flop_count;
float transcendental_count;
float bytes_accessed;
float optimal_seconds;
// The index into the profile counters array for the HloInstruction
// corresponding to this HloInstructionInfo.
int64 profile_index;
};
// Holds meta information about an HloComputation.
//
// The pointer-typed fields can be owning or non-owning -- this decision is
// manifested as the deleter_ function in the containing HloProfilePrinter.
struct HloComputationInfo {
const char* name;
// The index into the profile counters array for the HloInstruction
// corresponding to this HloComputationInfo.
int64 profile_index;
HloInstructionInfo* instructions;
int64 instructions_size;
};
HloProfilePrinter(
HloComputationInfo* computation_infos, int64 computation_infos_size,
int64 profile_counters_size,
std::function<void(HloComputationInfo*, int64)> deleter = nullptr)
: computation_infos_(computation_infos),
computation_infos_size_(computation_infos_size),
profile_counters_size_(profile_counters_size),
deleter_(std::move(deleter)) {}
HloProfilePrinter(HloProfilePrinter&& other) {
std::swap(other.computation_infos_, computation_infos_);
std::swap(other.computation_infos_size_, computation_infos_size_);
std::swap(other.deleter_, deleter_);
}
HloProfilePrinter(const HloProfilePrinter&) = delete;
HloProfilePrinter& operator=(const HloProfilePrinter&) = delete;
// Converts the profile counter sequence `counters` to a human readable string
// representation.
string ToString(const int64* counters, double clock_rate_ghz) const;
// Returns the size of the profile buffer expected by this printer.
int64 profile_counters_size() const { return profile_counters_size_; }
~HloProfilePrinter();
private:
// The `computation_infos_` field can be owning or non-owning -- this decision
// is manifested as the deleter_ function.
HloComputationInfo* computation_infos_ = nullptr;
int64 computation_infos_size_ = 0;
int64 profile_counters_size_ = 0;
std::function<void(HloComputationInfo*, int64)> deleter_;
};
// Pretty-print an array of profile counters using hlo_profile_printer_data.
string PrintHloProfile(const HloProfilePrinterData& hlo_profile_printer_data,
const int64* counters, double clock_rate_ghz);
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_PROFILE_PRINTER_H_

View File

@ -0,0 +1,60 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
syntax = "proto3";
package xla;
option cc_enable_arenas = true;
// Describes how to pretty-print a profile counter array gathered for a specific
// HloModule.
message HloProfilePrinterData {
// Pretty-printer information about an HloInstruction.
message HloInstructionInfo {
// Textual information for pretty printing. As populated by
// CreateHloProfilePrinterData: long_name is HloInstruction::ToString(),
// short_name is ToString() with compact operands, and category is
// HloInstruction::ToCategory().
string long_name = 1;
string short_name = 2;
string category = 3;
// Metrics computed by HloCostAnalysis.
float flop_count = 4;
float transcendental_count = 5;
float bytes_accessed = 6;
float optimal_seconds = 7;
// The index into the profile counters array for the HloInstruction
// corresponding to this HloInstructionInfo.
int64 profile_index = 8;
}
// Pretty-printer information about an HloComputation.
message HloComputationInfo {
// Name of the HloComputation, used as the heading when printing its profile.
string name = 1;
// The index into the profile counters array for the HloComputation
// corresponding to this HloComputationInfo.
int64 profile_index = 2;
// HloInstructionInfos for every HloInstruction in the HloComputation
// corresponding to this HloComputationInfo.
repeated HloInstructionInfo instruction_infos = 3;
}
// HloComputationInfos for every HloComputation in the HloModule.
// CreateHloProfilePrinterData emits these in ascending profile_index order so
// that the resulting message is deterministic.
repeated HloComputationInfo computation_infos = 1;
// The size of the profile counters array we will pretty-print. Every
// profile_index stored in this message is expected to be less than this value.
int64 profile_counters_size = 2;
}

View File

@ -569,7 +569,7 @@ Service::ExecuteParallelAndRegisterResult(
se::Stream* stream = index_to_profiled_stream.second;
Executable* executable = executables[device];
const HloModule& module = executable->module();
HloExecutionProfile hlo_profile(&executable->hlo_profile_printer(),
HloExecutionProfile hlo_profile(&executable->hlo_profile_printer_data(),
&executable->hlo_profile_index_map());
TF_RETURN_IF_ERROR(
executable->PopulateExecutionProfile(&hlo_profile, stream->parent()));

View File

@ -110,7 +110,8 @@ void ExecuteAndFetchProfile(string* profile_output, LocalClient* client,
Executable* executable = local_executable->executable();
HloExecutionProfile hlo_execution_profile(
&executable->hlo_profile_printer(), &executable->hlo_profile_index_map());
&executable->hlo_profile_printer_data(),
&executable->hlo_profile_index_map());
TF_ASSERT_OK_AND_ASSIGN(
Backend::StreamPtr stream_ptr,

View File

@ -398,13 +398,11 @@ std::vector<std::pair<int64, int64>> CommonFactors(
// Removes illegal characters from filenames.
string SanitizeFileName(string file_name);
// Simple wrapper around std::all_of.
// Returns true if `predicate` holds for every element of `container`; an empty
// container yields true. The container is taken by const reference so that a
// call does not copy it, which also allows containers whose elements are
// move-only (e.g. std::vector<std::unique_ptr<T>>) to be checked.
template <typename Container, typename Predicate>
bool c_all_of(const Container& container, Predicate predicate) {
  return std::all_of(std::begin(container), std::end(container), predicate);
}
// Simple wrapper around std::transform.
template <typename InputContainer, typename OutputIterator,
typename UnaryOperation>
OutputIterator c_transform(InputContainer input_container,
@ -414,7 +412,6 @@ OutputIterator c_transform(InputContainer input_container,
output_iterator, unary_op);
}
// Simple wrapper around std::copy_if.
template <class InputContainer, class OutputIterator, class UnaryPredicate>
OutputIterator c_copy_if(InputContainer input_container,
OutputIterator output_iterator,
@ -423,6 +420,11 @@ OutputIterator c_copy_if(InputContainer input_container,
output_iterator, predicate);
}
// Simple wrapper around std::sort. Sorts `input_container` in place, using
// `comparator` as the "less than" predicate.
//
// The range is obtained with std::begin/std::end, matching the other c_*
// wrappers in this file, so built-in arrays are accepted in addition to
// containers that provide begin()/end() members.
template <class InputContainer, class Comparator>
void c_sort(InputContainer& input_container, Comparator comparator) {
  std::sort(std::begin(input_container), std::end(input_container),
            comparator);
}
} // namespace xla
#define XLA_LOG_LINES(SEV, STRING) \