[XLA] Redesign: implement GetComputationStats.
PiperOrigin-RevId: 190871262
This commit is contained in:
parent
59a1255354
commit
2b41d75654
@ -276,7 +276,12 @@ StatusOr<std::unique_ptr<GlobalData>> Client::Execute(
|
||||
|
||||
if (execution_profile != nullptr) {
|
||||
*execution_profile = response.profile();
|
||||
// TODO(b/74197823): Get execution stats for the graph and VLOG(1) them.
|
||||
if (VLOG_IS_ON(1)) {
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
auto execution_stats,
|
||||
ExecutionStatsAsString(computation, response.profile()));
|
||||
VLOG(1) << execution_stats;
|
||||
}
|
||||
}
|
||||
|
||||
return MakeUnique<GlobalData>(stub_, response.output());
|
||||
@ -402,8 +407,22 @@ StatusOr<ComputationStats> Client::GetComputationStats(
|
||||
StatusOr<ComputationStats> Client::GetComputationStats(
|
||||
const XlaComputation& computation,
|
||||
const DebugOptions& debug_options) const {
|
||||
return Unimplemented(
|
||||
"GetComputationStats is not yet implemented for XlaComputation");
|
||||
ComputationGraphStatsRequest request;
|
||||
|
||||
// TODO(b/74197823): Find a way to avoid the copy of the hlo proto.
|
||||
*request.mutable_computation() = computation.proto();
|
||||
*request.mutable_debug_options() = debug_options;
|
||||
ComputationStatsResponse response;
|
||||
|
||||
VLOG(1) << "making computation graph stats request";
|
||||
Status s = stub_->GetComputationGraphStats(&request, &response);
|
||||
VLOG(1) << "done with request";
|
||||
|
||||
if (!s.ok()) {
|
||||
return s;
|
||||
}
|
||||
CHECK(response.has_stats());
|
||||
return response.stats();
|
||||
}
|
||||
|
||||
StatusOr<std::unique_ptr<ProgramShape>> Client::GetComputationShape(
|
||||
@ -467,6 +486,28 @@ StatusOr<string> Client::ExecutionStatsAsString(
|
||||
return string("[Execution Statistics] not available.");
|
||||
}
|
||||
|
||||
StatusOr<string> Client::ExecutionStatsAsString(
|
||||
const XlaComputation& computation, const ExecutionProfile& profile) {
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
auto computation_stats,
|
||||
GetComputationStats(computation,
|
||||
legacy_flags::GetDebugOptionsFromFlags()));
|
||||
int64 total_flops =
|
||||
computation_stats.flop_count() + computation_stats.transcendental_count();
|
||||
if (profile.compute_time_ns() > 0) {
|
||||
int64 nanoseconds = profile.compute_time_ns();
|
||||
int64 cycle_count = profile.compute_cycle_count();
|
||||
double gflops = total_flops / nanoseconds;
|
||||
return tensorflow::strings::StrCat(
|
||||
"[Execution Statistics] flop count: ", computation_stats.flop_count(),
|
||||
", transcendental count: ", computation_stats.transcendental_count(),
|
||||
", compute execution time: ", nanoseconds, " nsec",
|
||||
", compute cycles: ", cycle_count, ", performance: ", gflops,
|
||||
"gflop/s");
|
||||
}
|
||||
return string("[Execution Statistics] not available.");
|
||||
}
|
||||
|
||||
StatusOr<ChannelHandle> Client::CreateChannelHandle() {
|
||||
CreateChannelHandleRequest request;
|
||||
CreateChannelHandleResponse response;
|
||||
|
@ -241,6 +241,8 @@ class Client {
|
||||
// ExecutionProfile returned from an execution of the computation.
|
||||
StatusOr<string> ExecutionStatsAsString(const Computation& computation,
|
||||
const ExecutionProfile& profile);
|
||||
StatusOr<string> ExecutionStatsAsString(const XlaComputation& computation,
|
||||
const ExecutionProfile& profile);
|
||||
|
||||
ServiceInterface* stub_; // Stub that this client is connected on.
|
||||
|
||||
|
@ -1452,7 +1452,25 @@ tensorflow::Status Service::GetComputationStats(
|
||||
|
||||
tensorflow::Status Service::GetComputationGraphStats(
|
||||
const ComputationGraphStatsRequest* arg, ComputationStatsResponse* result) {
|
||||
return Unimplemented("get-computation-graph-stats is not yet implemented");
|
||||
HloModuleConfig config;
|
||||
config.set_debug_options(arg->debug_options());
|
||||
TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModule> module,
|
||||
HloModule::CreateFromProto(arg->computation(), config));
|
||||
|
||||
hlo_graph_dumper::MaybeDumpHloModule(*module,
|
||||
"computation statistics subject");
|
||||
|
||||
// Run HLO analysis to get the computation statistics.
|
||||
HloCostAnalysis analysis(
|
||||
execute_backend_->compiler()->ShapeSizeBytesFunction());
|
||||
|
||||
TF_RETURN_IF_ERROR(module->entry_computation()->Accept(&analysis));
|
||||
|
||||
ComputationStats stats;
|
||||
stats.set_flop_count(analysis.flop_count());
|
||||
stats.set_transcendental_count(analysis.transcendental_count());
|
||||
*result->mutable_stats() = stats;
|
||||
return tensorflow::Status::OK();
|
||||
}
|
||||
|
||||
template <typename RequestT, typename ResponseT>
|
||||
|
Loading…
x
Reference in New Issue
Block a user