[XLA] Redesign: implement GetComputationStats.

PiperOrigin-RevId: 190871262
This commit is contained in:
A. Unique TensorFlower 2018-03-28 18:54:09 -07:00 committed by TensorFlower Gardener
parent 59a1255354
commit 2b41d75654
3 changed files with 65 additions and 4 deletions

View File

@ -276,7 +276,12 @@ StatusOr<std::unique_ptr<GlobalData>> Client::Execute(
if (execution_profile != nullptr) {
*execution_profile = response.profile();
// TODO(b/74197823): Get execution stats for the graph and VLOG(1) them.
if (VLOG_IS_ON(1)) {
TF_ASSIGN_OR_RETURN(
auto execution_stats,
ExecutionStatsAsString(computation, response.profile()));
VLOG(1) << execution_stats;
}
}
return MakeUnique<GlobalData>(stub_, response.output());
@ -402,8 +407,22 @@ StatusOr<ComputationStats> Client::GetComputationStats(
// Requests the statistics of `computation` from the service.
//
// Builds a ComputationGraphStatsRequest holding a copy of the computation's
// proto and the given `debug_options`, then sends it through `stub_`.
//
// Returns the ComputationStats carried by the response, or the stub's non-OK
// status if the request fails. CHECK-fails if the request succeeds but the
// response carries no stats.
StatusOr<ComputationStats> Client::GetComputationStats(
    const XlaComputation& computation,
    const DebugOptions& debug_options) const {
  // The previous placeholder `return Unimplemented(...)` is intentionally gone:
  // left in place it would make everything below unreachable.
  ComputationGraphStatsRequest request;

  // TODO(b/74197823): Find a way to avoid the copy of the hlo proto.
  *request.mutable_computation() = computation.proto();
  *request.mutable_debug_options() = debug_options;
  ComputationStatsResponse response;

  VLOG(1) << "making computation graph stats request";
  Status s = stub_->GetComputationGraphStats(&request, &response);
  VLOG(1) << "done with request";

  if (!s.ok()) {
    return s;
  }
  CHECK(response.has_stats());
  return response.stats();
}
StatusOr<std::unique_ptr<ProgramShape>> Client::GetComputationShape(
@ -467,6 +486,28 @@ StatusOr<string> Client::ExecutionStatsAsString(
return string("[Execution Statistics] not available.");
}
// Formats a one-line "[Execution Statistics]" summary for an executed
// XlaComputation.
//
// Fetches the computation's stats via GetComputationStats (using the debug
// options from flags) and combines them with the timing in `profile`.
//
// Returns the formatted string when `profile.compute_time_ns()` is positive,
// the string "[Execution Statistics] not available." otherwise, or the error
// from GetComputationStats if that call fails.
StatusOr<string> Client::ExecutionStatsAsString(
    const XlaComputation& computation, const ExecutionProfile& profile) {
  TF_ASSIGN_OR_RETURN(
      auto computation_stats,
      GetComputationStats(computation,
                          legacy_flags::GetDebugOptionsFromFlags()));
  int64 total_flops =
      computation_stats.flop_count() + computation_stats.transcendental_count();
  if (profile.compute_time_ns() > 0) {
    int64 nanoseconds = profile.compute_time_ns();
    int64 cycle_count = profile.compute_cycle_count();
    // Divide in floating point: both operands are int64, so a plain `/` would
    // truncate the quotient to a whole number (0 for anything under
    // 1 flop/ns) before it is widened to double.
    double gflops =
        static_cast<double>(total_flops) / static_cast<double>(nanoseconds);
    return tensorflow::strings::StrCat(
        "[Execution Statistics] flop count: ", computation_stats.flop_count(),
        ", transcendental count: ", computation_stats.transcendental_count(),
        ", compute execution time: ", nanoseconds, " nsec",
        ", compute cycles: ", cycle_count, ", performance: ", gflops,
        "gflop/s");
  }
  return string("[Execution Statistics] not available.");
}
StatusOr<ChannelHandle> Client::CreateChannelHandle() {
CreateChannelHandleRequest request;
CreateChannelHandleResponse response;

View File

@ -241,6 +241,8 @@ class Client {
// ExecutionProfile returned from an execution of the computation.
StatusOr<string> ExecutionStatsAsString(const Computation& computation,
const ExecutionProfile& profile);
// Overload of ExecutionStatsAsString taking an XlaComputation. Produces an
// "[Execution Statistics]" string reporting flop count, transcendental count,
// compute time and cycles, and performance; yields
// "[Execution Statistics] not available." when the profile's compute time is
// not positive.
StatusOr<string> ExecutionStatsAsString(const XlaComputation& computation,
const ExecutionProfile& profile);
ServiceInterface* stub_; // Stub that this client is connected on.

View File

@ -1452,7 +1452,25 @@ tensorflow::Status Service::GetComputationStats(
// Computes flop and transcendental counts for the computation in `arg` and
// stores them in `result`.
//
// Builds an HloModule from the request's computation proto (configured with
// the request's debug options), optionally dumps it, and runs HloCostAnalysis
// over the module's entry computation.
//
// Returns OK on success; propagates the error if the module cannot be created
// from the proto or if the cost analysis fails.
tensorflow::Status Service::GetComputationGraphStats(
    const ComputationGraphStatsRequest* arg, ComputationStatsResponse* result) {
  // The previous placeholder `return Unimplemented(...)` is intentionally gone:
  // left in place it would make the analysis below unreachable.
  HloModuleConfig config;
  config.set_debug_options(arg->debug_options());
  TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModule> module,
                      HloModule::CreateFromProto(arg->computation(), config));

  hlo_graph_dumper::MaybeDumpHloModule(*module,
                                       "computation statistics subject");

  // Run HLO analysis to get the computation statistics.
  HloCostAnalysis analysis(
      execute_backend_->compiler()->ShapeSizeBytesFunction());

  TF_RETURN_IF_ERROR(module->entry_computation()->Accept(&analysis));

  ComputationStats stats;
  stats.set_flop_count(analysis.flop_count());
  stats.set_transcendental_count(analysis.transcendental_count());
  *result->mutable_stats() = stats;
  return tensorflow::Status::OK();
}
template <typename RequestT, typename ResponseT>