Add OpStats to InputPipelineAnalysis converter.

PiperOrigin-RevId: 286924417
Change-Id: I36c8eb3ca4a4e067941e0fc5b9995ec0b528a97a
A. Unique TensorFlower 2019-12-23 11:52:19 -08:00 committed by TensorFlower Gardener
parent 0fd31271b2
commit bc493c2488
7 changed files with 640 additions and 1 deletion

tensorflow/core/profiler/convert/BUILD

@@ -78,6 +78,31 @@ cc_library(
],
)
cc_library(
name = "op_stats_to_input_pipeline_analysis",
srcs = ["op_stats_to_input_pipeline_analysis.cc"],
hdrs = ["op_stats_to_input_pipeline_analysis.h"],
deps = [
":op_metrics_to_record",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core/platform:logging",
"//tensorflow/core/profiler/protobuf:hardware_types_proto_cc",
"//tensorflow/core/profiler/protobuf:input_pipeline_proto_cc",
"//tensorflow/core/profiler/protobuf:op_metrics_proto_cc",
"//tensorflow/core/profiler/protobuf:op_stats_proto_cc",
"//tensorflow/core/profiler/protobuf:steps_db_proto_cc",
"//tensorflow/core/profiler/utils:event_span",
"//tensorflow/core/profiler/utils:math_utils",
"//tensorflow/core/profiler/utils:tf_op_utils",
"//tensorflow/core/profiler/utils:time_utils",
"//tensorflow/core/util:stats_calculator_portable",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/strings",
],
)
cc_library(
name = "op_stats_to_tf_stats",
srcs = ["op_stats_to_tf_stats.cc"],

tensorflow/core/profiler/convert/op_stats_to_input_pipeline_analysis.cc

@@ -0,0 +1,402 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/profiler/convert/op_stats_to_input_pipeline_analysis.h"
#include <algorithm>
#include <cmath>
#include <utility>
#include "google/protobuf/any.pb.h"
#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_map.h"
#include "absl/strings/match.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/profiler/convert/op_metrics_to_record.h"
#include "tensorflow/core/profiler/protobuf/hardware_types.pb.h"
#include "tensorflow/core/profiler/protobuf/op_metrics.pb.h"
#include "tensorflow/core/profiler/protobuf/op_stats.pb.h"
#include "tensorflow/core/profiler/protobuf/steps_db.pb.h"
#include "tensorflow/core/profiler/utils/event_span.h"
#include "tensorflow/core/profiler/utils/math_utils.h"
#include "tensorflow/core/profiler/utils/tf_op_utils.h"
#include "tensorflow/core/profiler/utils/time_utils.h"
#include "tensorflow/core/util/stats_calculator.h"
namespace tensorflow {
namespace profiler {
namespace {
// Number of picoseconds in one millisecond.
const double kNumPsPerMs = 1000000000.0;
template <class Collection>
double GetTimeInMs(const Collection& type_ps,
EventType event_type) {
return PicosToMillis(gtl::FindWithDefault(type_ps, event_type, /*value=*/0));
}
StepSummary GetStepSummaryForSampleStats(const Stat<double>& sample_stats) {
StepSummary step_time_summary;
step_time_summary.set_average(sample_stats.avg());
step_time_summary.set_standard_deviation(
std::sqrt(sample_stats.sample_variance()));
step_time_summary.set_minimum(sample_stats.min());
step_time_summary.set_maximum(sample_stats.max());
return step_time_summary;
}
GenericStepTimeBreakdown ComputeGenericStepTimeBreakdownInMs(
const InputPipelineAnalysisResult& analysis) {
Stat<double> unknown_time_ms;
Stat<double> infeed_ms;
Stat<double> outfeed_ms;
Stat<double> device_compute_ms;
Stat<double> device_to_device_ms;
Stat<double> host_compute_ms;
Stat<double> host_prepare_ms;
Stat<double> host_compile_ms;
GenericStepTimeBreakdown result;
for (const google::protobuf::Any& step_details : analysis.step_details()) {
PerGenericStepDetails details;
bool success = step_details.UnpackTo(&details);
if (!success) {
LOG(ERROR) << "Unable to unpack step_breakdown. Expected: generic"
<< std::endl;
return {};
}
unknown_time_ms.UpdateStat(details.unknown_time_ms());
infeed_ms.UpdateStat(details.infeed_ms());
outfeed_ms.UpdateStat(details.outfeed_ms());
device_compute_ms.UpdateStat(details.device_compute_ms());
device_to_device_ms.UpdateStat(details.device_to_device_ms());
host_compute_ms.UpdateStat(details.host_compute_ms());
host_prepare_ms.UpdateStat(details.host_prepare_ms());
host_compile_ms.UpdateStat(details.host_compile_ms());
}
*result.mutable_unknown_time_ms_summary() =
GetStepSummaryForSampleStats(unknown_time_ms);
*result.mutable_infeed_ms_summary() = GetStepSummaryForSampleStats(infeed_ms);
*result.mutable_outfeed_ms_summary() =
GetStepSummaryForSampleStats(outfeed_ms);
*result.mutable_device_compute_ms_summary() =
GetStepSummaryForSampleStats(device_compute_ms);
*result.mutable_device_to_device_ms_summary() =
GetStepSummaryForSampleStats(device_to_device_ms);
*result.mutable_host_compute_ms_summary() =
GetStepSummaryForSampleStats(host_compute_ms);
*result.mutable_host_prepare_ms_summary() =
GetStepSummaryForSampleStats(host_prepare_ms);
*result.mutable_host_compile_ms_summary() =
GetStepSummaryForSampleStats(host_compile_ms);
return result;
}
InputPipelineAnalysisResult ComputeGenericInputPipelineAnalysisResult(
const protobuf::RepeatedPtrField<PerCoreStepInfo>& grouped_by_step) {
InputPipelineAnalysisResult result;
// Computes the summary of step time in ms.
*result.mutable_step_time_summary() =
ComputeStepTimeSummaryInMs(grouped_by_step);
Stat<double> infeed_summary_stats_in_percent;
for (const auto& coreid_stepinfo_map : grouped_by_step) {
// Iterates over each step.
const auto* ptr =
gtl::FindOrNull(coreid_stepinfo_map.step_info_per_core(), 0);
if (ptr == nullptr) {
// For generic hardware, all step-info is put under core-0. If ptr
// is nullptr, it means there is no step at all.
continue;
}
const StepInfoResult& step_info = *ptr;
// Adds the details for a new step.
PerGenericStepDetails details;
details.set_step_number(step_info.step_num());
details.set_step_time_ms(PicosToMillis(step_info.duration_ps()));
GenericStepBreakdown generic;
bool success = step_info.step_breakdown().UnpackTo(&generic);
if (!success) {
LOG(ERROR) << "Unable to unpack step_breakdown. Expected: generic"
<< std::endl;
return {};
}
const auto& type_ps = generic.type_ps();
details.set_unknown_time_ms(GetTimeInMs(type_ps, UNKNOWN_TIME));
// To be consistent with TPU case, the infeed time includes the time that
// the host is reading files, preprocessing, and the time to transfer the
// data to the device.
details.set_infeed_ms(GetTimeInMs(type_ps, HOST_WAIT_INPUT) +
GetTimeInMs(type_ps, HOST_TO_DEVICE) +
GetTimeInMs(type_ps, DEVICE_WAIT_HOST));
details.set_outfeed_ms(GetTimeInMs(type_ps, DEVICE_TO_HOST));
details.set_device_compute_ms(GetTimeInMs(type_ps, DEVICE_COMPUTE));
details.set_device_to_device_ms(GetTimeInMs(type_ps, DEVICE_TO_DEVICE) +
GetTimeInMs(type_ps, DEVICE_WAIT_DEVICE));
details.set_host_compute_ms(GetTimeInMs(type_ps, HOST_COMPUTE));
details.set_host_prepare_ms(GetTimeInMs(type_ps, HOST_PREPARE));
details.set_host_compile_ms(GetTimeInMs(type_ps, HOST_COMPILE));
result.add_step_details()->PackFrom(details);
const double infeed_pct_of_step_time =
100.0 * SafeDivide(details.infeed_ms(), details.step_time_ms());
infeed_summary_stats_in_percent.UpdateStat(infeed_pct_of_step_time);
}
// Computes the summary of infeed time as percentage of step time.
*result.mutable_infeed_percent_summary() =
GetStepSummaryForSampleStats(infeed_summary_stats_in_percent);
// Computes the breakdown of step time.
GenericStepTimeBreakdown generic_step_time_breakdown =
ComputeGenericStepTimeBreakdownInMs(result);
result.mutable_step_time_breakdown()->PackFrom(generic_step_time_breakdown);
return result;
}
// Classification of input processing on the host.
enum class InputOpCategory {
kEnqueue, // enqueue data to be transferred to device.
kDemandedFileRead, // demanded read from file.
kAdvancedFileRead, // advanced read from file (including cached,
// prefetch, parallel-map, interleave).
kPreprocessing // data preprocessing.
};
string InputOpCategoryString(InputOpCategory category) {
switch (category) {
case InputOpCategory::kEnqueue:
return "Enqueue";
case InputOpCategory::kDemandedFileRead:
return "Demanded file read";
case InputOpCategory::kAdvancedFileRead:
return "Advanced file read";
case InputOpCategory::kPreprocessing:
return "Preprocessing";
}
}
inline bool IsInputOp(absl::string_view category) {
return IsInfeedEnqueueOp(category) || IsDatasetOp(category);
}
InputOpCategory CategorizeInputOp(absl::string_view name,
absl::string_view category) {
if (IsInfeedEnqueueOp(category)) {
return InputOpCategory::kEnqueue;
}
DCHECK(IsDatasetOp(category));
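// Dataset op names are hierarchical, e.g. a (hypothetical)
// "Iterator::Prefetch::TFRecord": the last component names the data source,
// and the middle components indicate how the read is driven.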
if (absl::EndsWith(name, "::TFRecord") ||
absl::EndsWith(name, "::TextLine") ||
absl::EndsWith(name, "::FixedLengthRecord") ||
absl::EndsWith(name, "::SSTable") || absl::EndsWith(name, "::RecordIO")) {
if (absl::StrContains(name, "::MemoryReader") ||
absl::StrContains(name, "::MemoryWriter") ||
absl::StrContains(name, "::Interleave") ||
absl::StrContains(name, "::Prefetch") ||
absl::StrContains(name, "::ParallelMap")) {
return InputOpCategory::kAdvancedFileRead;
} else {
return InputOpCategory::kDemandedFileRead;
}
} else {
return InputOpCategory::kPreprocessing;
}
}
struct InputOpMetrics {
std::vector<const OpMetrics*> input_op_metrics;
uint64 input_op_time_ps = 0;
};
InputOpMetrics SelectInputOpMetrics(const OpMetricsDb& all_op_metrics) {
InputOpMetrics input_op_metrics;
for (const OpMetrics* op_metrics : SortedOpMetricsDb(all_op_metrics)) {
if (IsInputOp(op_metrics->category())) {
input_op_metrics.input_op_metrics.push_back(op_metrics);
input_op_metrics.input_op_time_ps += op_metrics->self_time_ps();
}
}
return input_op_metrics;
}
InputOpDetails ConvertOpMetricsToInputOpDetails(const OpMetrics& op_metrics,
uint64 input_op_time_ps,
InputOpCategory category) {
InputOpDetails details;
details.set_op_name(op_metrics.name());
details.set_count(op_metrics.occurrences());
details.set_time_in_ms(PicosToMillis(op_metrics.time_ps()));
details.set_self_time_in_ms(PicosToMillis(op_metrics.self_time_ps()));
details.set_time_in_percent(
100.0 * SafeDivide(op_metrics.time_ps(), input_op_time_ps));
details.set_self_time_in_percent(
100.0 * SafeDivide(op_metrics.self_time_ps(), input_op_time_ps));
details.set_category(InputOpCategoryString(category));
return details;
}
void GenerateHostResult(const OpMetricsDb& host_tf_metrics_db,
InputPipelineAnalysisResult* result) {
InputOpMetrics input_op_metrics = SelectInputOpMetrics(host_tf_metrics_db);
// Returns early if no input ops were found, i.e., the program is not using
// an input pipeline instrumented for xprof.
if (input_op_metrics.input_op_metrics.empty()) return;
absl::flat_hash_map<InputOpCategory, double> aggregated_input_op_times_us;
for (const OpMetrics* op_metrics : input_op_metrics.input_op_metrics) {
InputOpCategory category =
CategorizeInputOp(op_metrics->name(), op_metrics->category());
*result->add_input_op_details() = ConvertOpMetricsToInputOpDetails(
*op_metrics, input_op_metrics.input_op_time_ps, category);
aggregated_input_op_times_us[category] +=
PicosToMicros(op_metrics->self_time_ps());
}
double enqueue_time_us =
aggregated_input_op_times_us[InputOpCategory::kEnqueue];
double total_input_op_time_us =
aggregated_input_op_times_us[InputOpCategory::kDemandedFileRead] +
aggregated_input_op_times_us[InputOpCategory::kAdvancedFileRead] +
aggregated_input_op_times_us[InputOpCategory::kPreprocessing];
// We use total_host_infeed_enq_start_timestamp_ps_diff to approximate the
// total host step time.
double ratio = SafeDivide(
host_tf_metrics_db.total_host_infeed_enq_duration_ps(),
host_tf_metrics_db.total_host_infeed_enq_start_timestamp_ps_diff());
DCHECK_LE(ratio, 1.0);
DCHECK_GE(ratio, 0.0);
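// If `ratio` is the fraction of host time spent enqueuing, the total host
// time is enqueue_time_us / ratio, so its non-enqueue part is
// enqueue_time_us * (1.0 - ratio) / ratio. If ratio == 0 (no enqueue
// activity observed), fall back to the unscaled sum of the input-op times.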
double non_enqueue_time_us = (ratio != 0.0)
? (enqueue_time_us * (1.0 - ratio) / ratio)
: total_input_op_time_us;
// Scales the various input-time components with respect to non_enqueue_time_us.
double scaled_demanded_fileread_time_us = SafeDivide(
non_enqueue_time_us *
aggregated_input_op_times_us[InputOpCategory::kDemandedFileRead],
total_input_op_time_us);
double scaled_advanced_fileread_time_us = SafeDivide(
non_enqueue_time_us *
aggregated_input_op_times_us[InputOpCategory::kAdvancedFileRead],
total_input_op_time_us);
double scaled_preprocessing_time_us = SafeDivide(
non_enqueue_time_us *
aggregated_input_op_times_us[InputOpCategory::kPreprocessing],
total_input_op_time_us);
double unclassified_non_enqueue_time_us = std::max(
0.0, non_enqueue_time_us - scaled_demanded_fileread_time_us -
scaled_advanced_fileread_time_us - scaled_preprocessing_time_us);
InputTimeBreakdown* input_time_breakdown =
result->mutable_input_time_breakdown();
input_time_breakdown->set_enqueue_us(enqueue_time_us);
input_time_breakdown->set_demanded_file_read_us(
scaled_demanded_fileread_time_us);
input_time_breakdown->set_advanced_file_read_us(
scaled_advanced_fileread_time_us);
input_time_breakdown->set_preprocessing_us(scaled_preprocessing_time_us);
input_time_breakdown->set_unclassified_non_enqueue_us(
unclassified_non_enqueue_time_us);
}
string AnchorElement(absl::string_view url, absl::string_view text) {
return absl::StrCat("<a href=\"", url, "\" target=\"_blank\">", text, "</a>");
}
InputPipelineAnalysisRecommendation GenerateRecommendation() {
const absl::string_view kDatasetIntro =
"https://www.tensorflow.org/programmers_guide/datasets";
const absl::string_view kDatasetTopic =
"https://www.tensorflow.org/api_docs/python/tf/data/Dataset#";
const absl::string_view kTfRecordDataset =
"https://www.tensorflow.org/api_docs/python/tf/data/"
"TFRecordDataset#class_tfrecorddataset";
InputPipelineAnalysisRecommendation recommendation;
*recommendation.add_details() =
"Enqueuing data: you may want to combine small input data chunks "
"into fewer "
"but larger chunks.";
*recommendation.add_details() = absl::StrCat(
"Data preprocessing: you may increase num_parallel_calls in ",
AnchorElement(absl::StrCat(kDatasetTopic, "map"), "Dataset map()"),
" or preprocess the data OFFLINE.");
*recommendation.add_details() = absl::StrCat(
"Reading data from files in advance: you may tune parameters in the "
"following Dataset API (",
AnchorElement(absl::StrCat(kDatasetTopic, "prefetch"), "prefetch size"),
", ",
AnchorElement(absl::StrCat(kDatasetTopic, "interleave"),
"interleave cycle_length"),
", ", AnchorElement(kTfRecordDataset, "reader buffer_size"), ")");
*recommendation.add_details() = absl::StrCat(
"Reading data from files on demand: you should read data IN ADVANCE "
"using the following Dataset API (",
AnchorElement(absl::StrCat(kDatasetTopic, "prefetch"), "prefetch"), ", ",
AnchorElement(absl::StrCat(kDatasetTopic, "interleave"), "interleave"),
", ", AnchorElement(kTfRecordDataset, "reader buffer"), ")");
*recommendation.add_details() = absl::StrCat(
"Other data reading or processing: you may consider using the ",
AnchorElement(kDatasetIntro, "Dataset API"),
" (if you are not using it now)");
return recommendation;
}
} // namespace
StepSummary ComputeStepTimeSummaryInMs(
const protobuf::RepeatedPtrField<PerCoreStepInfo>& grouped_by_step) {
Stat<double> total_step_stats_in_ms;
// Iterates over each step.
for (const auto& coreid_stepinfo_map : grouped_by_step) {
double max_per_step_stats_in_ms = 0.0;
// Iterates over each core.
for (const auto& coreid_and_stepinfo :
coreid_stepinfo_map.step_info_per_core()) {
const auto& step_info = coreid_and_stepinfo.second;
max_per_step_stats_in_ms = std::max(step_info.duration_ps() / kNumPsPerMs,
max_per_step_stats_in_ms);
}
// Step time of each step is determined by the slowest core.
total_step_stats_in_ms.UpdateStat(max_per_step_stats_in_ms);
}
return GetStepSummaryForSampleStats(total_step_stats_in_ms);
}
InputPipelineAnalysisResult ConvertOpStatsToInputPipelineAnalysis(
const OpStats& op_stats, const HardwareType& hardware_type) {
InputPipelineAnalysisResult result =
ComputeGenericInputPipelineAnalysisResult(
op_stats.step_db().step_sequence());
result.set_hardware_type(hardware_type);
GenerateHostResult(op_stats.host_op_metrics_db(), &result);
*result.mutable_recommendation() = GenerateRecommendation();
return result;
}
} // namespace profiler
} // namespace tensorflow

tensorflow/core/profiler/convert/op_stats_to_input_pipeline_analysis.h

@@ -0,0 +1,39 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_PROFILER_CONVERT_OP_STATS_TO_INPUT_PIPELINE_ANALYSIS_H_
#define TENSORFLOW_CORE_PROFILER_CONVERT_OP_STATS_TO_INPUT_PIPELINE_ANALYSIS_H_
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/profiler/protobuf/input_pipeline.pb.h"
#include "tensorflow/core/profiler/protobuf/op_stats.pb.h"
#include "tensorflow/core/profiler/protobuf/steps_db.pb.h"
namespace tensorflow {
namespace profiler {
InputPipelineAnalysisResult ConvertOpStatsToInputPipelineAnalysis(
const OpStats& op_stats, const HardwareType& hardware_type);
// Computes the summary of step time in milliseconds.
StepSummary ComputeStepTimeSummaryInMs(
const ::tensorflow::protobuf::RepeatedPtrField<PerCoreStepInfo>&
grouped_by_step);
} // namespace profiler
} // namespace tensorflow
#endif // TENSORFLOW_CORE_PROFILER_CONVERT_OP_STATS_TO_INPUT_PIPELINE_ANALYSIS_H_
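For orientation, a minimal sketch of how a caller might drive this converter (not part of this change; the OpStats input would come from the profiler's existing OpStats pipeline, and CPU_ONLY is assumed to be one of the HardwareType values defined in hardware_types.proto):

#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/profiler/convert/op_stats_to_input_pipeline_analysis.h"

namespace tensorflow {
namespace profiler {

// Hypothetical caller: logs the per-step time summary for a profile.
void LogStepTimeSummary(const OpStats& op_stats) {
  InputPipelineAnalysisResult result =
      ConvertOpStatsToInputPipelineAnalysis(op_stats, HardwareType::CPU_ONLY);
  const StepSummary& s = result.step_time_summary();
  LOG(INFO) << "step time (ms): avg=" << s.average()
            << " sd=" << s.standard_deviation() << " min=" << s.minimum()
            << " max=" << s.maximum();
}

}  // namespace profiler
}  // namespace tensorflow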

tensorflow/core/profiler/protobuf/BUILD

@@ -26,6 +26,16 @@ exports_files(
visibility = ["//tensorflow/core:__pkg__"],
)
tf_proto_library(
name = "input_pipeline_proto",
srcs = ["input_pipeline.proto"],
cc_api_version = 2,
protodeps = [":hardware_types_proto"],
visibility = [
":friends",
],
)
tf_proto_library(
name = "op_metrics_proto",
srcs = ["op_metrics.proto"],

tensorflow/core/profiler/protobuf/input_pipeline.proto

@@ -0,0 +1,118 @@
syntax = "proto3";
package tensorflow.profiler;
import "google/protobuf/any.proto";
import "tensorflow/core/profiler/protobuf/hardware_types.proto";
// Used for both step duration and Op duration.
message StepSummary {
double average = 1;
double standard_deviation = 2;
double minimum = 3;
double maximum = 4;
}
// Per-step details on generic hardware.
message PerGenericStepDetails {
// The step number.
int32 step_number = 1;
// The step time (in ms).
double step_time_ms = 2;
// Breakdown of the step time in different event categories.
// The unknown time (in ms).
double unknown_time_ms = 3;
// The infeed time (in ms).
double infeed_ms = 4;
// The outfeed time (in ms).
double outfeed_ms = 5;
// The device-compute time (in ms).
double device_compute_ms = 6;
// The device-to-device communication time (in ms).
double device_to_device_ms = 7;
// The host-compute time (in ms).
double host_compute_ms = 8;
// The host-prepare time (in ms).
double host_prepare_ms = 9;
// The time spent on compiling (in ms).
double host_compile_ms = 10;
}
message InputTimeBreakdown {
// Time spent on demanded file read in microseconds.
double demanded_file_read_us = 1;
// Time spent on advanced file read in microseconds.
double advanced_file_read_us = 2;
// Time spent on data preprocessing in microseconds.
double preprocessing_us = 3;
// The infeed enqueue time in microseconds.
double enqueue_us = 4;
// This entry is for the situation where we can't further
// break down the non-enqueue input time (because the input pipeline
// is not instrumented).
double unclassified_non_enqueue_us = 5;
}
message InputOpDetails {
// The Op's name.
string op_name = 1;
// The number of occurrences.
uint64 count = 2;
// Time (accumulated over all occurrences) in milliseconds.
double time_in_ms = 3;
// Time (accumulated over all occurrences) in
// percentage of the total input processing time.
double time_in_percent = 4;
// Self time (accumulated over all occurrences) in milliseconds.
double self_time_in_ms = 5;
// Self time (accumulated over all occurrences) in
// percentage of the total input processing time.
double self_time_in_percent = 6;
// Possible categories: "Enqueue", "Advanced file read",
// "Demanded file read", "Preprocessing", "Unknown".
string category = 7;
}
message InputPipelineAnalysisRecommendation {
// A list of detailed recommendations.
repeated string details = 1;
}
message GenericStepTimeBreakdown {
// Summary of the unknown time across steps, in ms.
StepSummary unknown_time_ms_summary = 1;
// Summary of the infeed time across steps, in ms.
StepSummary infeed_ms_summary = 2;
// Summary of the outfeed time across steps, in ms.
StepSummary outfeed_ms_summary = 3;
// Summary of the device-compute time across steps, in ms.
StepSummary device_compute_ms_summary = 4;
// Summary of the device-to-device time across steps, in ms.
StepSummary device_to_device_ms_summary = 5;
// Summary of the host-compute time across steps, in ms.
StepSummary host_compute_ms_summary = 6;
// Summary of the host-prepare time across steps, in ms.
StepSummary host_prepare_ms_summary = 7;
// Summary of the compilation time across steps, in ms.
StepSummary host_compile_ms_summary = 8;
}
message InputPipelineAnalysisResult {
// Hardware type.
HardwareType hardware_type = 1;
// Summary of the step durations across all cores.
StepSummary step_time_summary = 2;
// Summary of infeed dequeue op durations as a percentage of step duration.
StepSummary infeed_percent_summary = 3;
// Details of each step. Can be unpacked into a PerGenericStepDetails.
repeated google.protobuf.Any step_details = 4;
// The breakdown of the input processing time.
InputTimeBreakdown input_time_breakdown = 5;
// Details of each input Op executed.
repeated InputOpDetails input_op_details = 6;
// Recommendations to users on next steps.
InputPipelineAnalysisRecommendation recommendation = 7;
// Breakdown of the step time. Can be unpacked into a
// GenericStepTimeBreakdown.
google.protobuf.Any step_time_breakdown = 8;
}
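Both step_details and step_time_breakdown are google.protobuf.Any, so consumers unpack them explicitly. A minimal sketch in C++, mirroring the PackFrom() calls in the converter above (result is an InputPipelineAnalysisResult assumed to be in scope):

for (const google::protobuf::Any& any : result.step_details()) {
  PerGenericStepDetails details;
  if (!any.UnpackTo(&details)) continue;  // Not packed by the generic path.
  // Per-step numbers such as details.step_time_ms() and details.infeed_ms()
  // are now available.
}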

tensorflow/core/util/stats_calculator.h

@@ -71,8 +71,21 @@ class Stat {
: static_cast<HighPrecisionValueType>(sum_) / count_;
}
// Returns sample variance.
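// Uses Bessel's correction (divides by count_ - 1); contrast variance().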
ValueType sample_variance() const {
return all_same()
? 0
: (squared_sum_ - std::pow(sum_, 2.0) / count_) / (count_ - 1);
}
// Returns population variance.
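// Divides by count_ with no Bessel's correction; contrast sample_variance().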
ValueType variance() const {
return all_same() ? 0 : (squared_sum_ / count_) - (avg() * avg());
}
// Returns population stddev.
ValueType std_deviation() const {
-   return all_same() ? 0 : sqrt(squared_sum_ / count_ - avg() * avg());
+   return all_same() ? 0 : std::sqrt(variance());
}
void OutputToStream(std::ostream* stream) const {

tensorflow/core/util/stats_calculator_test.cc

@@ -14,6 +14,9 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/core/util/stats_calculator.h"
#include <cfloat>
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
@@ -72,5 +75,34 @@ TEST(StatsCalculatorTest, AddNodeStatsUpdate) {
EXPECT_EQ(run1_mem_used + run2_mem_used, detail.mem_used.sum());
}
TEST(StatsCalculatorTest, UpdateStat) {
Stat<double> stat;
EXPECT_TRUE(stat.empty());
EXPECT_TRUE(stat.all_same());
stat.UpdateStat(1);
EXPECT_TRUE(stat.all_same());
stat.UpdateStat(-1.0);
EXPECT_FALSE(stat.all_same());
stat.UpdateStat(100);
stat.UpdateStat(0);
EXPECT_EQ(4, stat.count());
EXPECT_EQ(-1, stat.min());
EXPECT_EQ(100, stat.max());
EXPECT_EQ(25, stat.avg());
EXPECT_EQ(1, stat.first());
EXPECT_EQ(0, stat.newest());
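// squared_sum = 1^2 + (-1)^2 + 100^2 + 0^2 = 10002.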
EXPECT_EQ(10002, stat.squared_sum());
EXPECT_EQ(625, stat.avg() * stat.avg());
// Sample variance
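// = (squared_sum - sum^2 / count) / (count - 1) = (10002 - 2500) / 3.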
EXPECT_EQ(7502.0 / 3, stat.sample_variance());
// Sample standard deviation, from WolframAlpha
EXPECT_NEAR(50.00666622228147160678152, std::sqrt(stat.sample_variance()),
FLT_EPSILON);
// Population variance
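// = squared_sum / count - avg^2 = 10002 / 4 - 625 = 7502 / 4.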
EXPECT_NEAR(7502.0 / 4, stat.variance(), FLT_EPSILON);
// Population standard deviation, from WolframAlpha
EXPECT_NEAR(43.30704330706496060826769, stat.std_deviation(), FLT_EPSILON);
}
} // namespace
} // namespace tensorflow