[TF:XLA] Replace most of HloProfilePrinter by a protocol buffer

This change replaces the meat of HloProfilePrinter with a protobuf
HloProfilePrinterData.  The original plan was to serialize HloProfilePrinter
into C++ source code and put that in a .cc file along with the string for the
xla::ProgramShape.  However, since we now directly serialize xla::ProgramShape
into a .o file, for consistency I think we should do the same thing for
HloProfilePrinter (instead of adding yet another output file to tfcompile).

The change itself is fairly simple; it is large mostly due to the mass renaming
I had to do.

PiperOrigin-RevId: 183158192
This commit is contained in:
Sanjoy Das 2018-01-24 16:10:12 -08:00 committed by TensorFlower Gardener
parent 7bf8ccdb4e
commit ffa63e57bd
23 changed files with 205 additions and 212 deletions

View File

@ -29,7 +29,7 @@ XlaCompiledCpuFunction::XlaCompiledCpuFunction(const StaticData& static_data,
arg_names_(static_data.arg_names), arg_names_(static_data.arg_names),
result_names_(static_data.result_names), result_names_(static_data.result_names),
program_shape_(static_data.program_shape), program_shape_(static_data.program_shape),
hlo_profile_printer_(static_data.hlo_profile_printer) { hlo_profile_printer_data_(static_data.hlo_profile_printer_data) {
// Allocate arg and temp buffers. // Allocate arg and temp buffers.
if (alloc_mode == AllocMode::ARGS_RESULTS_PROFILES_AND_TEMPS) { if (alloc_mode == AllocMode::ARGS_RESULTS_PROFILES_AND_TEMPS) {
alloc_args_ = tensorflow::tfcompile::runtime::MallocContiguousBuffers( alloc_args_ = tensorflow::tfcompile::runtime::MallocContiguousBuffers(

View File

@ -26,7 +26,7 @@ limitations under the License.
// never use this functionality. // never use this functionality.
namespace xla { namespace xla {
class ProgramShape; class ProgramShape;
class HloProfilePrinter; class HloProfilePrinterData;
} }
namespace tensorflow { namespace tensorflow {
@ -77,12 +77,14 @@ class XlaCompiledCpuFunction {
// [Optional] Arg and result shapes. // [Optional] Arg and result shapes.
const xla::ProgramShape* program_shape = nullptr; const xla::ProgramShape* program_shape = nullptr;
// [Optional] Profile printer. Null if profiling is disabled. // [Optional] Profile printer data. Null if profiling is disabled.
const xla::HloProfilePrinter* hlo_profile_printer = nullptr; const xla::HloProfilePrinterData* hlo_profile_printer_data = nullptr;
// [Optional] The number of profile counters expected in the profile counter // [Optional] The number of profile counters expected in the profile counter
// buffer by the generated code and hlo_profile_printer. 0 if profiling is // buffer by the generated code and hlo_profile_printer. 0 if profiling is
// disabled. // disabled. This information is already present in
// hlo_profile_printer_data but xla::HloProfilePrinterData is forward
// declared so we don't have access to that information here.
int64 profile_counters_size = 0; int64 profile_counters_size = 0;
}; };
@ -205,10 +207,12 @@ class XlaCompiledCpuFunction {
// program shape isn't available. // program shape isn't available.
const xla::ProgramShape* ProgramShape() const { return program_shape_; } const xla::ProgramShape* ProgramShape() const { return program_shape_; }
bool hlo_profiling_enabled() const { return hlo_profile_printer_ != nullptr; } bool hlo_profiling_enabled() const {
const xla::HloProfilePrinter& hlo_profile_printer() const { return hlo_profile_printer_data_ != nullptr;
}
const xla::HloProfilePrinterData& hlo_profile_printer_data() const {
assert(hlo_profiling_enabled()); assert(hlo_profiling_enabled());
return *hlo_profile_printer_; return *hlo_profile_printer_data_;
} }
private: private:
@ -234,7 +238,7 @@ class XlaCompiledCpuFunction {
const char** arg_names_ = nullptr; const char** arg_names_ = nullptr;
const char** result_names_ = nullptr; const char** result_names_ = nullptr;
const xla::ProgramShape* program_shape_ = nullptr; const xla::ProgramShape* program_shape_ = nullptr;
const xla::HloProfilePrinter* hlo_profile_printer_ = nullptr; const xla::HloProfilePrinterData* hlo_profile_printer_data_ = nullptr;
}; };
} // namespace tensorflow } // namespace tensorflow

View File

@ -182,10 +182,10 @@ XlaJitCompiledCpuFunction::Compile(
jit->static_data_.program_shape = jit->program_shape_.get(); jit->static_data_.program_shape = jit->program_shape_.get();
if (cpu_executable->hlo_profiling_enabled()) { if (cpu_executable->hlo_profiling_enabled()) {
jit->static_data_.hlo_profile_printer = jit->static_data_.hlo_profile_printer_data =
&cpu_executable->hlo_profile_printer(); &cpu_executable->hlo_profile_printer_data();
jit->static_data_.profile_counters_size = jit->static_data_.profile_counters_size =
cpu_executable->hlo_profile_printer().profile_counters_size(); cpu_executable->hlo_profile_printer_data().profile_counters_size();
} }
return std::move(jit_unique_ptr); return std::move(jit_unique_ptr);

View File

@ -29,6 +29,11 @@ xla_proto_library(
deps = ["//tensorflow/compiler/xla:xla_data_proto"], deps = ["//tensorflow/compiler/xla:xla_data_proto"],
) )
xla_proto_library(
name = "hlo_profile_printer_data",
srcs = ["hlo_profile_printer_data.proto"],
)
# Filegroup used to collect source files for dependency checking. # Filegroup used to collect source files for dependency checking.
filegroup( filegroup(
name = "c_srcs", name = "c_srcs",
@ -2267,6 +2272,7 @@ cc_library(
srcs = ["hlo_profile_printer.cc"], srcs = ["hlo_profile_printer.cc"],
hdrs = ["hlo_profile_printer.h"], hdrs = ["hlo_profile_printer.h"],
deps = [ deps = [
":hlo_profile_printer_data",
":human_readable_profile_builder", ":human_readable_profile_builder",
"//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:types",
], ],

View File

@ -485,7 +485,7 @@ StatusOr<std::unique_ptr<Executable>> CpuCompiler::RunBackend(
std::unordered_map<const HloInstruction*, int64> instruction_to_profile_idx; std::unordered_map<const HloInstruction*, int64> instruction_to_profile_idx;
std::unordered_map<const HloComputation*, int64> computation_to_profile_idx; std::unordered_map<const HloComputation*, int64> computation_to_profile_idx;
std::unique_ptr<HloProfileIndexMap> hlo_profile_index_map; std::unique_ptr<HloProfileIndexMap> hlo_profile_index_map;
std::unique_ptr<HloProfilePrinter> hlo_profile_printer; std::unique_ptr<HloProfilePrinterData> hlo_profile_printer_data;
if (module->config().hlo_profiling_enabled()) { if (module->config().hlo_profiling_enabled()) {
hlo_profile_index_map = MakeUnique<HloProfileIndexMap>(*module); hlo_profile_index_map = MakeUnique<HloProfileIndexMap>(*module);
@ -505,8 +505,8 @@ StatusOr<std::unique_ptr<Executable>> CpuCompiler::RunBackend(
HloCostAnalysis cost_analysis(shape_size_bytes); HloCostAnalysis cost_analysis(shape_size_bytes);
TF_RETURN_IF_ERROR(entry_computation->Accept(&cost_analysis)); TF_RETURN_IF_ERROR(entry_computation->Accept(&cost_analysis));
hlo_profile_printer = hlo_profile_printer_data =
CreateHloProfilePrinter(*hlo_profile_index_map, cost_analysis); CreateHloProfilePrinterData(*hlo_profile_index_map, cost_analysis);
computation_to_profile_idx = computation_to_profile_idx =
hlo_profile_index_map->computation_to_profile_idx(); hlo_profile_index_map->computation_to_profile_idx();
} }
@ -619,7 +619,7 @@ StatusOr<std::unique_ptr<Executable>> CpuCompiler::RunBackend(
cpu_executable.reset(new ParallelCpuExecutable( cpu_executable.reset(new ParallelCpuExecutable(
std::move(jit), std::move(assignment), std::move(module), std::move(jit), std::move(assignment), std::move(module),
std::move(function_names), std::move(aligned_constants), std::move(function_names), std::move(aligned_constants),
std::move(hlo_profile_printer), std::move(hlo_profile_index_map))); std::move(hlo_profile_printer_data), std::move(hlo_profile_index_map)));
if (embed_ir_in_executable) { if (embed_ir_in_executable) {
static_cast<CpuExecutable&>(*cpu_executable) static_cast<CpuExecutable&>(*cpu_executable)
@ -698,7 +698,7 @@ StatusOr<std::unique_ptr<Executable>> CpuCompiler::RunBackend(
jit->AddModule(std::move(llvm_module)); jit->AddModule(std::move(llvm_module));
cpu_executable.reset(new CpuExecutable( cpu_executable.reset(new CpuExecutable(
std::move(jit), std::move(assignment), std::move(module), function_name, std::move(jit), std::move(assignment), std::move(module), function_name,
std::move(hlo_profile_printer), std::move(hlo_profile_index_map))); std::move(hlo_profile_printer_data), std::move(hlo_profile_index_map)));
if (embed_ir_in_executable) { if (embed_ir_in_executable) {
static_cast<CpuExecutable&>(*cpu_executable) static_cast<CpuExecutable&>(*cpu_executable)

View File

@ -55,9 +55,9 @@ CpuExecutable::CpuExecutable(
std::unique_ptr<const BufferAssignment> assignment, std::unique_ptr<const BufferAssignment> assignment,
std::unique_ptr<const HloModule> hlo_module, std::unique_ptr<const HloModule> hlo_module,
const string& entry_function_name, const string& entry_function_name,
std::unique_ptr<HloProfilePrinter> hlo_profile_printer, std::unique_ptr<HloProfilePrinterData> hlo_profile_printer_data,
std::unique_ptr<HloProfileIndexMap> hlo_profile_index_map) std::unique_ptr<HloProfileIndexMap> hlo_profile_index_map)
: Executable(std::move(hlo_module), std::move(hlo_profile_printer), : Executable(std::move(hlo_module), std::move(hlo_profile_printer_data),
std::move(hlo_profile_index_map)), std::move(hlo_profile_index_map)),
jit_(std::move(jit)), jit_(std::move(jit)),
assignment_(std::move(assignment)) { assignment_(std::move(assignment)) {

View File

@ -51,7 +51,7 @@ class CpuExecutable : public Executable {
std::unique_ptr<const BufferAssignment> assignment, std::unique_ptr<const BufferAssignment> assignment,
std::unique_ptr<const HloModule> hlo_module, std::unique_ptr<const HloModule> hlo_module,
const string& entry_function_name, const string& entry_function_name,
std::unique_ptr<HloProfilePrinter> hlo_profile_printer, std::unique_ptr<HloProfilePrinterData> hlo_profile_printer_data,
std::unique_ptr<HloProfileIndexMap> hlo_profile_index_map); std::unique_ptr<HloProfileIndexMap> hlo_profile_index_map);
~CpuExecutable() override {} ~CpuExecutable() override {}

View File

@ -61,9 +61,9 @@ ParallelCpuExecutable::ParallelCpuExecutable(
std::unique_ptr<const HloInstructionMap<string>> function_names, std::unique_ptr<const HloInstructionMap<string>> function_names,
std::unordered_map<const HloInstruction*, std::unique_ptr<unsigned char[]>> std::unordered_map<const HloInstruction*, std::unique_ptr<unsigned char[]>>
aligned_constants, aligned_constants,
std::unique_ptr<HloProfilePrinter> hlo_profile_printer, std::unique_ptr<HloProfilePrinterData> hlo_profile_printer_data,
std::unique_ptr<HloProfileIndexMap> hlo_profile_index_map) std::unique_ptr<HloProfileIndexMap> hlo_profile_index_map)
: Executable(std::move(hlo_module), std::move(hlo_profile_printer), : Executable(std::move(hlo_module), std::move(hlo_profile_printer_data),
std::move(hlo_profile_index_map)), std::move(hlo_profile_index_map)),
jit_(std::move(jit)), jit_(std::move(jit)),
assignment_(std::move(assignment)), assignment_(std::move(assignment)),

View File

@ -55,7 +55,7 @@ class ParallelCpuExecutable : public Executable {
std::unordered_map<const HloInstruction*, std::unordered_map<const HloInstruction*,
std::unique_ptr<unsigned char[]>> std::unique_ptr<unsigned char[]>>
aligned_constants, aligned_constants,
std::unique_ptr<HloProfilePrinter> hlo_profile_printer, std::unique_ptr<HloProfilePrinterData> hlo_profile_printer_data,
std::unique_ptr<HloProfileIndexMap> hlo_profile_index_map); std::unique_ptr<HloProfileIndexMap> hlo_profile_index_map);
~ParallelCpuExecutable() override {} ~ParallelCpuExecutable() override {}

View File

@ -73,7 +73,7 @@ StatusOr<std::unique_ptr<ShapedBuffer>> Executable::ExecuteOnStreamWrapper(
std::unique_ptr<HloExecutionProfile> profile_ptr = std::unique_ptr<HloExecutionProfile> profile_ptr =
module_config().debug_options().xla_hlo_profile() && module_config().debug_options().xla_hlo_profile() &&
hlo_profiling_enabled() hlo_profiling_enabled()
? MakeUnique<HloExecutionProfile>(&hlo_profile_printer(), ? MakeUnique<HloExecutionProfile>(&hlo_profile_printer_data(),
&hlo_profile_index_map()) &hlo_profile_index_map())
: nullptr; : nullptr;

View File

@ -44,13 +44,14 @@ namespace xla {
// interface that is used for launching compiled programs across platforms. // interface that is used for launching compiled programs across platforms.
class Executable { class Executable {
public: public:
explicit Executable(std::unique_ptr<const HloModule> hlo_module, explicit Executable(
std::unique_ptr<HloProfilePrinter> hlo_profile_printer, std::unique_ptr<const HloModule> hlo_module,
std::unique_ptr<HloProfileIndexMap> hlo_profile_index_map) std::unique_ptr<HloProfilePrinterData> hlo_profile_printer_data,
std::unique_ptr<HloProfileIndexMap> hlo_profile_index_map)
: hlo_module_(std::move(hlo_module)), : hlo_module_(std::move(hlo_module)),
hlo_profile_printer_(std::move(hlo_profile_printer)), hlo_profile_printer_data_(std::move(hlo_profile_printer_data)),
hlo_profile_index_map_(std::move(hlo_profile_index_map)) { hlo_profile_index_map_(std::move(hlo_profile_index_map)) {
CHECK_EQ(hlo_profile_printer_.get() == nullptr, CHECK_EQ(hlo_profile_printer_data_.get() == nullptr,
hlo_profile_index_map_.get() == nullptr); hlo_profile_index_map_.get() == nullptr);
} }
virtual ~Executable() {} virtual ~Executable() {}
@ -116,9 +117,9 @@ class Executable {
"Equality test on this executable is not implemented."); "Equality test on this executable is not implemented.");
} }
const HloProfilePrinter& hlo_profile_printer() const { const HloProfilePrinterData& hlo_profile_printer_data() const {
CHECK(hlo_profiling_enabled()); CHECK(hlo_profiling_enabled());
return *hlo_profile_printer_; return *hlo_profile_printer_data_;
} }
const HloProfileIndexMap& hlo_profile_index_map() const { const HloProfileIndexMap& hlo_profile_index_map() const {
@ -129,7 +130,9 @@ class Executable {
// Returns whether this executable was compiled with HLO profilings support // Returns whether this executable was compiled with HLO profilings support
// enabled. If not, the caller should not expect an hlo_execution_profile // enabled. If not, the caller should not expect an hlo_execution_profile
// passed to ExecuteOnStream above to be populated during execution. // passed to ExecuteOnStream above to be populated during execution.
bool hlo_profiling_enabled() const { return hlo_profile_printer_ != nullptr; } bool hlo_profiling_enabled() const {
return hlo_profile_printer_data_ != nullptr;
}
const HloModule& module() const { return *hlo_module_; } const HloModule& module() const { return *hlo_module_; }
@ -179,7 +182,7 @@ class Executable {
// execution. // execution.
int64 execution_count_ = 0; int64 execution_count_ = 0;
std::unique_ptr<HloProfilePrinter> hlo_profile_printer_; std::unique_ptr<HloProfilePrinterData> hlo_profile_printer_data_;
std::unique_ptr<HloProfileIndexMap> hlo_profile_index_map_; std::unique_ptr<HloProfileIndexMap> hlo_profile_index_map_;
}; };

View File

@ -593,14 +593,14 @@ StatusOr<std::unique_ptr<Executable>> GpuCompiler::RunBackend(
XLA_VLOG_LINES(2, thunk_schedule->ToString()); XLA_VLOG_LINES(2, thunk_schedule->ToString());
std::unique_ptr<HloProfileIndexMap> profile_index_map; std::unique_ptr<HloProfileIndexMap> profile_index_map;
std::unique_ptr<HloProfilePrinter> profile_printer; std::unique_ptr<HloProfilePrinterData> profile_printer;
if (module->config().hlo_profiling_enabled()) { if (module->config().hlo_profiling_enabled()) {
HloCostAnalysis cost_analysis(ShapeSizeBytesFunction()); HloCostAnalysis cost_analysis(ShapeSizeBytesFunction());
TF_RETURN_IF_ERROR(module->entry_computation()->Accept(&cost_analysis)); TF_RETURN_IF_ERROR(module->entry_computation()->Accept(&cost_analysis));
profile_index_map = MakeUnique<HloProfileIndexMap>(*module); profile_index_map = MakeUnique<HloProfileIndexMap>(*module);
profile_printer = profile_printer =
CreateHloProfilePrinter(*profile_index_map, cost_analysis); CreateHloProfilePrinterData(*profile_index_map, cost_analysis);
} }
auto* gpu_executable = new GpuExecutable( auto* gpu_executable = new GpuExecutable(

View File

@ -116,9 +116,9 @@ GpuExecutable::GpuExecutable(
std::unique_ptr<const ThunkSchedule> thunk_schedule, std::unique_ptr<const ThunkSchedule> thunk_schedule,
std::unique_ptr<const HloModule> hlo_module, std::unique_ptr<const HloModule> hlo_module,
std::unique_ptr<const BufferAssignment> assignment, std::unique_ptr<const BufferAssignment> assignment,
std::unique_ptr<HloProfilePrinter> hlo_profile_printer, std::unique_ptr<HloProfilePrinterData> hlo_profile_printer_data,
std::unique_ptr<HloProfileIndexMap> hlo_profile_index_map) std::unique_ptr<HloProfileIndexMap> hlo_profile_index_map)
: Executable(std::move(hlo_module), std::move(hlo_profile_printer), : Executable(std::move(hlo_module), std::move(hlo_profile_printer_data),
std::move(hlo_profile_index_map)), std::move(hlo_profile_index_map)),
ptx_(ptx), ptx_(ptx),
cubin_(cubin), cubin_(cubin),

View File

@ -54,7 +54,7 @@ class GpuExecutable : public Executable {
std::unique_ptr<const ThunkSchedule> thunk_schedule, std::unique_ptr<const ThunkSchedule> thunk_schedule,
std::unique_ptr<const HloModule> hlo_module, std::unique_ptr<const HloModule> hlo_module,
std::unique_ptr<const BufferAssignment> assignment, std::unique_ptr<const BufferAssignment> assignment,
std::unique_ptr<HloProfilePrinter> hlo_profile_printer, std::unique_ptr<HloProfilePrinterData> hlo_profile_printer_data,
std::unique_ptr<HloProfileIndexMap> hlo_profile_index_map); std::unique_ptr<HloProfileIndexMap> hlo_profile_index_map);
// This should be called after set_ir_module_string. // This should be called after set_ir_module_string.

View File

@ -40,83 +40,75 @@ HloProfileIndexMap::HloProfileIndexMap(const HloModule& module) {
} }
} }
std::unique_ptr<HloProfilePrinter> CreateHloProfilePrinter( std::unique_ptr<HloProfilePrinterData> CreateHloProfilePrinterData(
const HloProfileIndexMap& hlo_profile_index_map, const HloProfileIndexMap& hlo_profile_index_map,
const HloCostAnalysis& cost_analysis) { const HloCostAnalysis& cost_analysis) {
using HloComputationInfo = HloProfilePrinter::HloComputationInfo; using HloComputationInfo = HloProfilePrinterData::HloComputationInfo;
using HloInstructionInfo = HloProfilePrinter::HloInstructionInfo; using HloInstructionInfo = HloProfilePrinterData::HloInstructionInfo;
HloComputationInfo* computation_infos = size_t profile_counters_size = hlo_profile_index_map.total_count();
new HloComputationInfo[hlo_profile_index_map.computation_count()];
// There are two "indices" in play here. The first one is the index of the std::unique_ptr<HloProfilePrinterData> profile_printer_data =
// HloComputationInfo or HloInstructionInfo in the array that contains said MakeUnique<HloProfilePrinterData>();
// HloComputationInfo or HloInstructionInfo. The second index is the index of profile_printer_data->set_profile_counters_size(profile_counters_size);
// the HloComputationInfo or HloInstructionInfo in the profile counters array, profile_printer_data->mutable_computation_infos()->Reserve(
// as decided by hlo_profile_index_map. The latter index is always referred hlo_profile_index_map.computation_count());
// to as "profile_index".
size_t computation_index_in_static_data = 0; const auto& computation_to_profile_idx_map =
size_t max_profile_index = hlo_profile_index_map.total_count(); hlo_profile_index_map.computation_to_profile_idx();
for (const auto& pair : hlo_profile_index_map.computation_to_profile_idx()) {
CHECK_LT(pair.second, max_profile_index); // computation_to_profile_idx_map's order is not deterministic so create a
// deterministic computation_and_profile_idx_list so that we end up with a
// deterministic HloProfilePrinterData protobuf.
std::vector<std::pair<const HloComputation*, int64>>
computation_and_profile_idx_list(computation_to_profile_idx_map.begin(),
computation_to_profile_idx_map.end());
// The profile indices were computed deterministically in
// HloProfileIndexMap::HloProfileIndexMap.
c_sort(computation_and_profile_idx_list,
[](const std::pair<const HloComputation*, int64>& left,
const std::pair<const HloComputation*, int64>& right) {
return left.second < right.second;
});
for (const auto& pair : computation_and_profile_idx_list) {
CHECK_LT(pair.second, profile_counters_size);
const HloComputation* computation = pair.first; const HloComputation* computation = pair.first;
size_t current_computation_index = computation_index_in_static_data++;
HloComputationInfo* computation_info = HloComputationInfo* computation_info =
&computation_infos[current_computation_index]; profile_printer_data->add_computation_infos();
computation_info->name = strdup(computation->name().c_str()); computation_info->set_name(computation->name());
computation_info->profile_index = pair.second; computation_info->set_profile_index(pair.second);
computation_info->instructions = computation_info->mutable_instruction_infos()->Reserve(
new HloInstructionInfo[computation->instruction_count()]; computation->instruction_count());
computation_info->instructions_size = computation->instruction_count();
size_t instruction_index_in_static_data = 0;
for (const HloInstruction* hlo : computation->instructions()) { for (const HloInstruction* hlo : computation->instructions()) {
HloProfilePrinter::HloInstructionInfo* instruction_info = HloInstructionInfo* instruction_info =
&computation_info->instructions[instruction_index_in_static_data++]; computation_info->add_instruction_infos();
instruction_info->long_name = strdup(hlo->ToString().c_str()); instruction_info->set_long_name(hlo->ToString());
instruction_info->short_name = strdup( instruction_info->set_short_name(
hlo->ToString(HloPrintOptions().set_compact_operands(true)).c_str()); hlo->ToString(HloPrintOptions().set_compact_operands(true)));
instruction_info->category = strdup(hlo->ToCategory().c_str()); instruction_info->set_category(hlo->ToCategory());
instruction_info->flop_count = cost_analysis.flop_count(*hlo); instruction_info->set_flop_count(cost_analysis.flop_count(*hlo));
instruction_info->transcendental_count = instruction_info->set_transcendental_count(
cost_analysis.transcendental_count(*hlo); cost_analysis.transcendental_count(*hlo));
instruction_info->bytes_accessed = cost_analysis.bytes_accessed(*hlo); instruction_info->set_bytes_accessed(cost_analysis.bytes_accessed(*hlo));
instruction_info->optimal_seconds = cost_analysis.optimal_seconds(*hlo); instruction_info->set_optimal_seconds(
instruction_info->profile_index = cost_analysis.optimal_seconds(*hlo));
hlo_profile_index_map.GetProfileIndexFor(*hlo); instruction_info->set_profile_index(
CHECK_LT(instruction_info->profile_index, max_profile_index); hlo_profile_index_map.GetProfileIndexFor(*hlo));
} }
} }
auto deleter = [](HloProfilePrinter::HloComputationInfo* computation_infos, return profile_printer_data;
int64 computation_infos_size) {
for (int64 i = 0; i < computation_infos_size; i++) {
HloInstructionInfo* instruction_infos = computation_infos[i].instructions;
for (int64 j = 0; j < computation_infos[i].instructions_size; j++) {
// We can't make instruction_infos[j].long_name etc. non-const pointers
// since they may point into static storage, so we have a const_cast
// here.
free(const_cast<char*>(instruction_infos[j].long_name));
free(const_cast<char*>(instruction_infos[j].short_name));
free(const_cast<char*>(instruction_infos[j].category));
}
delete[] instruction_infos;
free(const_cast<char*>(computation_infos[i].name));
}
delete[] computation_infos;
};
return MakeUnique<HloProfilePrinter>(
computation_infos, hlo_profile_index_map.computation_count(),
/*profile_counters_size=*/max_profile_index, deleter);
} }
HloExecutionProfile::HloExecutionProfile( HloExecutionProfile::HloExecutionProfile(
const HloProfilePrinter* hlo_profile_printer, const HloProfilePrinterData* hlo_profile_printer_data,
const HloProfileIndexMap* hlo_profile_index_map) const HloProfileIndexMap* hlo_profile_index_map)
: hlo_profile_printer_(*hlo_profile_printer), : hlo_profile_printer_data_(*hlo_profile_printer_data),
hlo_profile_index_map_(*hlo_profile_index_map), hlo_profile_index_map_(*hlo_profile_index_map),
profile_counters_( profile_counters_(
/*count*/ hlo_profile_index_map_.total_count(), /*count*/ hlo_profile_index_map_.total_count(),

View File

@ -77,8 +77,8 @@ class HloProfileIndexMap {
std::unordered_map<const HloComputation*, int64> computation_to_profile_idx_; std::unordered_map<const HloComputation*, int64> computation_to_profile_idx_;
}; };
// Create an instance of `HloProfilePrinter` that owns its memory. // Create an instance of `HloProfilePrinterData`.
std::unique_ptr<HloProfilePrinter> CreateHloProfilePrinter( std::unique_ptr<HloProfilePrinterData> CreateHloProfilePrinterData(
const HloProfileIndexMap& hlo_profile_index_map, const HloProfileIndexMap& hlo_profile_index_map,
const HloCostAnalysis& cost_analysis); const HloCostAnalysis& cost_analysis);
@ -90,7 +90,7 @@ class HloExecutionProfile {
public: public:
using DeviceDescription = perftools::gputools::DeviceDescription; using DeviceDescription = perftools::gputools::DeviceDescription;
HloExecutionProfile(const HloProfilePrinter* hlo_profile_printer, HloExecutionProfile(const HloProfilePrinterData* hlo_profile_printer_data,
const HloProfileIndexMap* hlo_profile_index_map); const HloProfileIndexMap* hlo_profile_index_map);
// Record how many cycles this HLO took to execute. // Record how many cycles this HLO took to execute.
@ -117,11 +117,10 @@ class HloExecutionProfile {
// debugging; e.g. emits cycle counts, execution time at the nominal device // debugging; e.g. emits cycle counts, execution time at the nominal device
// frequency, and the effective throughput given the provided cost_analysis // frequency, and the effective throughput given the provided cost_analysis
// for the operations in a given computation. Returns an empty string if it // for the operations in a given computation. Returns an empty string if it
// wasn't possible to generate a printable version. cost_analysis should be a // wasn't possible to generate a printable version.
// clean analysis that can be used to visit the computation.
string ToString(const DeviceDescription& device_description) const { string ToString(const DeviceDescription& device_description) const {
return hlo_profile_printer_.ToString(profile_counters_.data(), return PrintHloProfile(hlo_profile_printer_data_, profile_counters_.data(),
device_description.clock_rate_ghz()); device_description.clock_rate_ghz());
} }
std::vector<int64>* mutable_profile_counters() { return &profile_counters_; } std::vector<int64>* mutable_profile_counters() { return &profile_counters_; }
@ -130,7 +129,7 @@ class HloExecutionProfile {
} }
private: private:
const HloProfilePrinter& hlo_profile_printer_; const HloProfilePrinterData& hlo_profile_printer_data_;
const HloProfileIndexMap& hlo_profile_index_map_; const HloProfileIndexMap& hlo_profile_index_map_;
// Stores per-Hlo profile counters. This is the only thing that changes when // Stores per-Hlo profile counters. This is the only thing that changes when

View File

@ -73,8 +73,8 @@ TEST_F(HloExecutionProfileTest, Basic) {
HloCostAnalysis cost_analysis(shape_size_function); HloCostAnalysis cost_analysis(shape_size_function);
HloProfileIndexMap profile_index_map(*hlo_module); HloProfileIndexMap profile_index_map(*hlo_module);
std::unique_ptr<HloProfilePrinter> profile_printer = std::unique_ptr<HloProfilePrinterData> profile_printer =
CreateHloProfilePrinter(profile_index_map, cost_analysis); CreateHloProfilePrinterData(profile_index_map, cost_analysis);
HloExecutionProfile execution_profile(profile_printer.get(), HloExecutionProfile execution_profile(profile_printer.get(),
&profile_index_map); &profile_index_map);

View File

@ -18,20 +18,20 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/human_readable_profile_builder.h" #include "tensorflow/compiler/xla/service/human_readable_profile_builder.h"
namespace xla { namespace xla {
string HloProfilePrinter::ToString(const int64* counters, string PrintHloProfile(const HloProfilePrinterData& hlo_profile_printer_data,
double clock_rate_ghz) const { const int64* counters, double clock_rate_ghz) {
using HloComputationInfo = HloProfilePrinterData::HloComputationInfo;
using HloInstructionInfo = HloProfilePrinterData::HloInstructionInfo;
string result; string result;
for (int computation_idx = 0; computation_idx < computation_infos_size_; for (const HloComputationInfo& computation_info :
computation_idx++) { hlo_profile_printer_data.computation_infos()) {
const HloComputationInfo& computation = computation_infos_[computation_idx]; const auto& instruction_infos = computation_info.instruction_infos();
const HloInstructionInfo* instructions_begin = computation.instructions;
const HloInstructionInfo* instructions_end =
computation.instructions + computation.instructions_size;
bool any_instruction_profiled = bool any_instruction_profiled =
std::any_of(instructions_begin, instructions_end, std::any_of(instruction_infos.begin(), instruction_infos.end(),
[&](const HloInstructionInfo& instruction_info) { [&](const HloInstructionInfo& instruction_info) {
return counters[instruction_info.profile_index] != 0; return counters[instruction_info.profile_index()] != 0;
}); });
if (!any_instruction_profiled) { if (!any_instruction_profiled) {
@ -41,16 +41,19 @@ string HloProfilePrinter::ToString(const int64* counters,
// Once we start using this in AOT for real, we will probably need a more // Once we start using this in AOT for real, we will probably need a more
// minimal version of HumanReadableProfileBuilder. // minimal version of HumanReadableProfileBuilder.
HumanReadableProfileBuilder builder( HumanReadableProfileBuilder builder(
computation.name, counters[computation.profile_index], clock_rate_ghz); computation_info.name(), counters[computation_info.profile_index()],
clock_rate_ghz);
for (const auto* instruction = instructions_begin; for (const auto& instruction_info : instruction_infos) {
instruction != instructions_end; instruction++) {
builder.AddOp( builder.AddOp(
/*op_name=*/instruction->long_name, /*op_name=*/instruction_info.long_name(),
/*short_name=*/instruction->short_name, instruction->category, /*short_name=*/instruction_info.short_name(),
counters[instruction->profile_index], instruction->flop_count, instruction_info.category(),
instruction->transcendental_count, instruction->bytes_accessed, counters[instruction_info.profile_index()],
instruction->optimal_seconds); instruction_info.flop_count(),
instruction_info.transcendental_count(),
instruction_info.bytes_accessed(),
instruction_info.optimal_seconds());
} }
result += builder.ToString(); result += builder.ToString();
@ -58,10 +61,4 @@ string HloProfilePrinter::ToString(const int64* counters,
return result; return result;
} }
HloProfilePrinter::~HloProfilePrinter() {
if (deleter_) {
deleter_(computation_infos_, computation_infos_size_);
}
}
} // namespace xla } // namespace xla

View File

@ -20,84 +20,13 @@ limitations under the License.
#include <string> #include <string>
#include <vector> #include <vector>
#include "tensorflow/compiler/xla/service/hlo_profile_printer_data.pb.h"
#include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/types.h"
namespace xla { namespace xla {
// Instances of this class can pretty-print profile counters gathered from // Pretty-print an array of profile counters using hlo_profile_printer_data.
// running an XLA computation without having access to the backing module. string PrintHloProfile(const HloProfilePrinterData& hlo_profile_printer_data,
class HloProfilePrinter { const int64* counters, double clock_rate_ghz);
public:
// Holds meta information about an HloInstruction.
//
// This is a plain aggregate: no field has a default initializer, so whoever
// creates an instance has to fill in every field.
//
// The pointer-typed fields can be owning or non-owning -- this decision is
// manifested as the deleter_ function in the containing HloProfilePrinter.
struct HloInstructionInfo {
// Textual information for pretty printing. The printing code forwards
// long_name as the op name and short_name as the short name of each row,
// followed by category.
const char* long_name;
const char* short_name;
const char* category;
// Metrics computed by HloCostAnalysis.
// NOTE(review): units are not stated anywhere in this header -- confirm
// against HloCostAnalysis before relying on them.
float flop_count;
float transcendental_count;
float bytes_accessed;
float optimal_seconds;
// The index into the profile counters array for the HloInstruction
// corresponding to this HloInstructionInfo.
int64 profile_index;
};
// Holds meta information about an HloComputation.
//
// This is a plain aggregate: no field has a default initializer, so whoever
// creates an instance has to fill in every field.
//
// The pointer-typed fields can be owning or non-owning -- this decision is
// manifested as the deleter_ function in the containing HloProfilePrinter.
struct HloComputationInfo {
// Name of the computation, used as a heading when printing.
const char* name;
// The index into the profile counters array for the HloComputation
// corresponding to this HloComputationInfo.
int64 profile_index;
// Per-instruction information for this computation.
// NOTE(review): presumably `instructions` points at an array of exactly
// `instructions_size` elements -- confirm against the code that builds it.
HloInstructionInfo* instructions;
int64 instructions_size;
};
HloProfilePrinter(
HloComputationInfo* computation_infos, int64 computation_infos_size,
int64 profile_counters_size,
std::function<void(HloComputationInfo*, int64)> deleter = nullptr)
: computation_infos_(computation_infos),
computation_infos_size_(computation_infos_size),
profile_counters_size_(profile_counters_size),
deleter_(std::move(deleter)) {}
// Move constructor: exchanges every field of the freshly default-initialized
// `*this` with `other`. Afterwards `other` holds the default member values
// (null infos, zero sizes, empty deleter), so its destructor will not invoke
// the deleter on the infos that now belong to `*this`.
HloProfilePrinter(HloProfilePrinter&& other) noexcept {
  std::swap(other.computation_infos_, computation_infos_);
  std::swap(other.computation_infos_size_, computation_infos_size_);
  // This field used to be skipped, which left the moved-to printer reporting
  // a profile_counters_size() of 0 instead of the source's value.
  std::swap(other.profile_counters_size_, profile_counters_size_);
  std::swap(other.deleter_, deleter_);
}

// Copying is disabled: two copies would share the same computation infos and
// both destructors would hand them to the deleter.
HloProfilePrinter(const HloProfilePrinter&) = delete;
HloProfilePrinter& operator=(const HloProfilePrinter&) = delete;
// Converts the profile counter sequence `counters` to a human readable string
// representation.
//
// `counters` is indexed with the profile_index of each computation and
// instruction info held by this printer; `clock_rate_ghz` is passed through
// to the formatting code.
// NOTE(review): presumably `counters` must hold at least
// profile_counters_size() entries -- confirm against callers.
string ToString(const int64* counters, double clock_rate_ghz) const;
// Returns the size of the profile buffer expected by this printer, i.e. the
// `profile_counters_size` value supplied at construction.
int64 profile_counters_size() const { return profile_counters_size_; }
// Invokes the deleter on the computation infos if a deleter was supplied.
~HloProfilePrinter();
private:
// The `computation_infos_` field can be owning or non-owning -- this decision
// is manifested as the deleter_ function.
//
// Every member has a default (null pointer, zero sizes, empty deleter); the
// move constructor relies on these defaults as the state it leaves behind in
// the moved-from object.
HloComputationInfo* computation_infos_ = nullptr;
// Value passed to the deleter alongside `computation_infos_`.
int64 computation_infos_size_ = 0;
// Value reported by profile_counters_size().
int64 profile_counters_size_ = 0;
// Called by the destructor when non-empty.
std::function<void(HloComputationInfo*, int64)> deleter_;
};
} // namespace xla } // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_PROFILE_PRINTER_H_ #endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_PROFILE_PRINTER_H_

View File

@ -0,0 +1,60 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
syntax = "proto3";
package xla;
option cc_enable_arenas = true;
// Describes how to pretty-print a profile counter array gathered for a specific
// HloModule.
message HloProfilePrinterData {
// Pretty-printer information about an HloInstruction.
message HloInstructionInfo {
// Textual information for pretty printing.
string long_name = 1;
string short_name = 2;
string category = 3;
// Metrics computed by HloCostAnalysis.
float flop_count = 4;
float transcendental_count = 5;
float bytes_accessed = 6;
float optimal_seconds = 7;
// The index into the profile counters array for the HloInstruction
// corresponding to this HloInstructionInfo.
int64 profile_index = 8;
}
// Pretty-printer information about an HloComputation.
message HloComputationInfo {
string name = 1;
// The index into the profile counters array for the HloComputation
// corresponding to this HloComputationInfo.
int64 profile_index = 2;
// HloInstructionInfos for every HloInstruction in the HloComputation
// corresponding to this HloComputationInfo.
repeated HloInstructionInfo instruction_infos = 3;
}
// HloComputationInfos for every HloComputation in the HloModule.
repeated HloComputationInfo computation_infos = 1;
// The size of the profile counters array we will pretty-print.
int64 profile_counters_size = 2;
}

View File

@ -569,7 +569,7 @@ Service::ExecuteParallelAndRegisterResult(
se::Stream* stream = index_to_profiled_stream.second; se::Stream* stream = index_to_profiled_stream.second;
Executable* executable = executables[device]; Executable* executable = executables[device];
const HloModule& module = executable->module(); const HloModule& module = executable->module();
HloExecutionProfile hlo_profile(&executable->hlo_profile_printer(), HloExecutionProfile hlo_profile(&executable->hlo_profile_printer_data(),
&executable->hlo_profile_index_map()); &executable->hlo_profile_index_map());
TF_RETURN_IF_ERROR( TF_RETURN_IF_ERROR(
executable->PopulateExecutionProfile(&hlo_profile, stream->parent())); executable->PopulateExecutionProfile(&hlo_profile, stream->parent()));

View File

@ -110,7 +110,8 @@ void ExecuteAndFetchProfile(string* profile_output, LocalClient* client,
Executable* executable = local_executable->executable(); Executable* executable = local_executable->executable();
HloExecutionProfile hlo_execution_profile( HloExecutionProfile hlo_execution_profile(
&executable->hlo_profile_printer(), &executable->hlo_profile_index_map()); &executable->hlo_profile_printer_data(),
&executable->hlo_profile_index_map());
TF_ASSERT_OK_AND_ASSIGN( TF_ASSERT_OK_AND_ASSIGN(
Backend::StreamPtr stream_ptr, Backend::StreamPtr stream_ptr,

View File

@ -398,13 +398,11 @@ std::vector<std::pair<int64, int64>> CommonFactors(
// Removes illegal characters from filenames. // Removes illegal characters from filenames.
string SanitizeFileName(string file_name); string SanitizeFileName(string file_name);
// Simple wrapper around std::all_of.
// Returns true iff `predicate` holds for every element of `container`
// (vacuously true when the container is empty). Thin wrapper around
// std::all_of.
//
// The container is taken by const reference rather than by value, so calling
// this helper never copies the container. As a consequence it also accepts
// containers that cannot be copied and built-in arrays.
template <typename Container, typename Predicate>
bool c_all_of(const Container& container, Predicate predicate) {
  return std::all_of(std::begin(container), std::end(container), predicate);
}
// Simple wrapper around std::transform.
template <typename InputContainer, typename OutputIterator, template <typename InputContainer, typename OutputIterator,
typename UnaryOperation> typename UnaryOperation>
OutputIterator c_transform(InputContainer input_container, OutputIterator c_transform(InputContainer input_container,
@ -414,7 +412,6 @@ OutputIterator c_transform(InputContainer input_container,
output_iterator, unary_op); output_iterator, unary_op);
} }
// Simple wrapper around std::copy_if.
template <class InputContainer, class OutputIterator, class UnaryPredicate> template <class InputContainer, class OutputIterator, class UnaryPredicate>
OutputIterator c_copy_if(InputContainer input_container, OutputIterator c_copy_if(InputContainer input_container,
OutputIterator output_iterator, OutputIterator output_iterator,
@ -423,6 +420,11 @@ OutputIterator c_copy_if(InputContainer input_container,
output_iterator, predicate); output_iterator, predicate);
} }
// Sorts `input_container` in place using `comparator` as the strict weak
// ordering. Thin wrapper around std::sort.
//
// The range is obtained with std::begin/std::end, matching the other c_*
// helpers in this file, so built-in arrays are accepted in addition to
// containers that provide begin()/end() members.
template <class InputContainer, class Comparator>
void c_sort(InputContainer& input_container, Comparator comparator) {
  std::sort(std::begin(input_container), std::end(input_container),
            comparator);
}
} // namespace xla } // namespace xla
#define XLA_LOG_LINES(SEV, STRING) \ #define XLA_LOG_LINES(SEV, STRING) \