From 2b41d75654012f917cda1b54aee090d73086ab84 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 28 Mar 2018 18:54:09 -0700 Subject: [PATCH] [XLA] Redesign: implement GetComputationStats. PiperOrigin-RevId: 190871262 --- tensorflow/compiler/xla/client/client.cc | 47 ++++++++++++++++++++-- tensorflow/compiler/xla/client/client.h | 2 + tensorflow/compiler/xla/service/service.cc | 20 ++++++++- 3 files changed, 65 insertions(+), 4 deletions(-) diff --git a/tensorflow/compiler/xla/client/client.cc b/tensorflow/compiler/xla/client/client.cc index a857c4ff0b4..c4c88943746 100644 --- a/tensorflow/compiler/xla/client/client.cc +++ b/tensorflow/compiler/xla/client/client.cc @@ -276,7 +276,12 @@ StatusOr> Client::Execute( if (execution_profile != nullptr) { *execution_profile = response.profile(); - // TODO(b/74197823): Get execution stats for the graph and VLOG(1) them. + if (VLOG_IS_ON(1)) { + TF_ASSIGN_OR_RETURN( + auto execution_stats, + ExecutionStatsAsString(computation, response.profile())); + VLOG(1) << execution_stats; + } } return MakeUnique(stub_, response.output()); @@ -402,8 +407,22 @@ StatusOr Client::GetComputationStats( StatusOr Client::GetComputationStats( const XlaComputation& computation, const DebugOptions& debug_options) const { - return Unimplemented( - "GetComputationStats is not yet implemented for XlaComputation"); + ComputationGraphStatsRequest request; + + // TODO(b/74197823): Find a way to avoid the copy of the hlo proto. + *request.mutable_computation() = computation.proto(); + *request.mutable_debug_options() = debug_options; + ComputationStatsResponse response; + + VLOG(1) << "making computation graph stats request"; + Status s = stub_->GetComputationGraphStats(&request, &response); + VLOG(1) << "done with request"; + + if (!s.ok()) { + return s; + } + CHECK(response.has_stats()); + return response.stats(); } StatusOr> Client::GetComputationShape( @@ -467,6 +486,28 @@ StatusOr Client::ExecutionStatsAsString( return string("[Execution Statistics] not available."); } +StatusOr Client::ExecutionStatsAsString( + const XlaComputation& computation, const ExecutionProfile& profile) { + TF_ASSIGN_OR_RETURN( + auto computation_stats, + GetComputationStats(computation, + legacy_flags::GetDebugOptionsFromFlags())); + int64 total_flops = + computation_stats.flop_count() + computation_stats.transcendental_count(); + if (profile.compute_time_ns() > 0) { + int64 nanoseconds = profile.compute_time_ns(); + int64 cycle_count = profile.compute_cycle_count(); + double gflops = total_flops / nanoseconds; + return tensorflow::strings::StrCat( + "[Execution Statistics] flop count: ", computation_stats.flop_count(), + ", transcendental count: ", computation_stats.transcendental_count(), + ", compute execution time: ", nanoseconds, " nsec", + ", compute cycles: ", cycle_count, ", performance: ", gflops, + "gflop/s"); + } + return string("[Execution Statistics] not available."); +} + StatusOr Client::CreateChannelHandle() { CreateChannelHandleRequest request; CreateChannelHandleResponse response; diff --git a/tensorflow/compiler/xla/client/client.h b/tensorflow/compiler/xla/client/client.h index 226b788d541..05d707dab15 100644 --- a/tensorflow/compiler/xla/client/client.h +++ b/tensorflow/compiler/xla/client/client.h @@ -241,6 +241,8 @@ class Client { // ExecutionProfile returned from an execution of the computation. StatusOr ExecutionStatsAsString(const Computation& computation, const ExecutionProfile& profile); + StatusOr ExecutionStatsAsString(const XlaComputation& computation, + const ExecutionProfile& profile); ServiceInterface* stub_; // Stub that this client is connected on. diff --git a/tensorflow/compiler/xla/service/service.cc b/tensorflow/compiler/xla/service/service.cc index af05e3f5169..ca8071b7bbb 100644 --- a/tensorflow/compiler/xla/service/service.cc +++ b/tensorflow/compiler/xla/service/service.cc @@ -1452,7 +1452,25 @@ tensorflow::Status Service::GetComputationStats( tensorflow::Status Service::GetComputationGraphStats( const ComputationGraphStatsRequest* arg, ComputationStatsResponse* result) { - return Unimplemented("get-computation-graph-stats is not yet implemented"); + HloModuleConfig config; + config.set_debug_options(arg->debug_options()); + TF_ASSIGN_OR_RETURN(std::unique_ptr module, + HloModule::CreateFromProto(arg->computation(), config)); + + hlo_graph_dumper::MaybeDumpHloModule(*module, + "computation statistics subject"); + + // Run HLO analysis to get the computation statistics. + HloCostAnalysis analysis( + execute_backend_->compiler()->ShapeSizeBytesFunction()); + + TF_RETURN_IF_ERROR(module->entry_computation()->Accept(&analysis)); + + ComputationStats stats; + stats.set_flop_count(analysis.flop_count()); + stats.set_transcendental_count(analysis.transcendental_count()); + *result->mutable_stats() = stats; + return tensorflow::Status::OK(); } template