Move the writer functions of profileSummarizer to ProfileSummaryFormatter.
PiperOrigin-RevId: 295875892 Change-Id: Ie27c735012f1337b848e94548ac26aea5b8770b6
This commit is contained in:
parent
ea13922cf1
commit
ee7642b267
@ -112,6 +112,27 @@ cc_test(
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "profile_summary_formatter",
|
||||
srcs = ["profile_summary_formatter.cc"],
|
||||
hdrs = ["profile_summary_formatter.h"],
|
||||
copts = common_copts,
|
||||
deps = [
|
||||
"//tensorflow/core/util:stats_calculator_portable",
|
||||
],
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "profile_summary_formatter_test",
|
||||
srcs = ["profile_summary_formatter_test.cc"],
|
||||
copts = common_copts,
|
||||
deps = [
|
||||
":profile_summary_formatter",
|
||||
"//tensorflow/lite/testing:util",
|
||||
"@com_google_googletest//:gtest",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "profile_summarizer",
|
||||
srcs = ["profile_summarizer.cc"],
|
||||
@ -120,6 +141,7 @@ cc_library(
|
||||
deps = [
|
||||
":memory_info",
|
||||
":profile_buffer",
|
||||
":profile_summary_formatter",
|
||||
"//tensorflow/core/util:stats_calculator_portable",
|
||||
"//tensorflow/lite:framework",
|
||||
"//tensorflow/lite/schema:schema_fbs",
|
||||
|
@ -15,6 +15,7 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/lite/profiling/profile_summarizer.h"
|
||||
|
||||
#include <memory>
|
||||
#include <sstream>
|
||||
|
||||
#include "tensorflow/lite/profiling/memory_info.h"
|
||||
@ -85,29 +86,21 @@ OperatorDetails GetOperatorDetails(const tflite::Interpreter& interpreter,
|
||||
return details;
|
||||
}
|
||||
|
||||
tensorflow::StatSummarizerOptions GetProfileSummarizerOptions(
|
||||
bool format_as_csv) {
|
||||
auto options = tensorflow::StatSummarizerOptions();
|
||||
// Summary will be manually handled per subgraphs in order to keep the
|
||||
// compatibility.
|
||||
options.show_summary = false;
|
||||
options.show_memory = false;
|
||||
options.format_as_csv = format_as_csv;
|
||||
return options;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
ProfileSummarizer::ProfileSummarizer(bool format_as_csv)
|
||||
: delegate_stats_calculator_(new tensorflow::StatsCalculator(
|
||||
GetProfileSummarizerOptions(format_as_csv))),
|
||||
format_as_csv_(format_as_csv) {
|
||||
ProfileSummarizer::ProfileSummarizer(
|
||||
std::unique_ptr<ProfileSummaryFormatter> summary_formatter)
|
||||
: summary_formatter_(std::move(summary_formatter)) {
|
||||
// Create stats calculator for the primary graph.
|
||||
stats_calculator_map_[0] = std::unique_ptr<tensorflow::StatsCalculator>(
|
||||
new tensorflow::StatsCalculator(
|
||||
GetProfileSummarizerOptions(format_as_csv)));
|
||||
}
|
||||
summary_formatter_->GetStatSummarizerOptions()));
|
||||
|
||||
// Create stats calculator for the delegation op.
|
||||
delegate_stats_calculator_ = std::unique_ptr<tensorflow::StatsCalculator>(
|
||||
new tensorflow::StatsCalculator(
|
||||
summary_formatter_->GetStatSummarizerOptions()));
|
||||
}
|
||||
void ProfileSummarizer::ProcessProfiles(
|
||||
const std::vector<const ProfileEvent*>& profile_stats,
|
||||
const tflite::Interpreter& interpreter) {
|
||||
@ -214,45 +207,10 @@ tensorflow::StatsCalculator* ProfileSummarizer::GetStatsCalculator(
|
||||
stats_calculator_map_[subgraph_index] =
|
||||
std::unique_ptr<tensorflow::StatsCalculator>(
|
||||
new tensorflow::StatsCalculator(
|
||||
GetProfileSummarizerOptions(format_as_csv_)));
|
||||
summary_formatter_->GetStatSummarizerOptions()));
|
||||
}
|
||||
return stats_calculator_map_[subgraph_index].get();
|
||||
}
|
||||
|
||||
std::string ProfileSummarizer::GenerateReport(std::string tag,
|
||||
bool include_output_string) {
|
||||
std::stringstream stream;
|
||||
bool has_non_primary_graph =
|
||||
(stats_calculator_map_.size() - stats_calculator_map_.count(0)) > 0;
|
||||
for (auto& stats_calc : stats_calculator_map_) {
|
||||
auto subgraph_index = stats_calc.first;
|
||||
auto subgraph_stats = stats_calc.second.get();
|
||||
if (has_non_primary_graph) {
|
||||
if (subgraph_index == 0)
|
||||
stream << "Primary graph " << tag << ":" << std::endl;
|
||||
else
|
||||
stream << "Subgraph (index: " << subgraph_index << ") " << tag << ":"
|
||||
<< std::endl;
|
||||
}
|
||||
if (include_output_string) {
|
||||
stream << subgraph_stats->GetOutputString();
|
||||
}
|
||||
if (subgraph_index != 0) {
|
||||
stream << "Subgraph (index: " << subgraph_index << ") ";
|
||||
}
|
||||
stream << subgraph_stats->GetShortSummary() << std::endl;
|
||||
}
|
||||
|
||||
if (delegate_stats_calculator_->num_runs() > 0) {
|
||||
stream << "Delegate internal: " << std::endl;
|
||||
if (include_output_string) {
|
||||
stream << delegate_stats_calculator_->GetOutputString();
|
||||
}
|
||||
stream << delegate_stats_calculator_->GetShortSummary() << std::endl;
|
||||
}
|
||||
|
||||
return stream.str();
|
||||
}
|
||||
|
||||
} // namespace profiling
|
||||
} // namespace tflite
|
||||
|
@ -17,11 +17,13 @@ limitations under the License.
|
||||
#define TENSORFLOW_LITE_PROFILING_PROFILE_SUMMARIZER_H_
|
||||
|
||||
#include <functional>
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/core/util/stats_calculator.h"
|
||||
#include "tensorflow/lite/interpreter.h"
|
||||
#include "tensorflow/lite/profiling/profile_buffer.h"
|
||||
#include "tensorflow/lite/profiling/profile_summary_formatter.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace profiling {
|
||||
@ -29,21 +31,25 @@ namespace profiling {
|
||||
// Creates a summary of operator invocations in the interpreter.
|
||||
class ProfileSummarizer {
|
||||
public:
|
||||
explicit ProfileSummarizer(bool format_as_csv = false);
|
||||
explicit ProfileSummarizer(
|
||||
std::unique_ptr<ProfileSummaryFormatter> summary_formatter =
|
||||
std::make_unique<ProfileSummaryDefaultFormatter>());
|
||||
virtual ~ProfileSummarizer() {}
|
||||
|
||||
// Process profile events to update statistics for operator invocations.
|
||||
void ProcessProfiles(const std::vector<const ProfileEvent*>& profile_stats,
|
||||
const tflite::Interpreter& interpreter);
|
||||
|
||||
// Returns a string detailing the accumulated runtime stats in a tab-separated
|
||||
// format which can be pasted into a spreadsheet for further analysis.
|
||||
// Returns a string detailing the accumulated runtime stats in the format of
|
||||
// summary_formatter_.
|
||||
std::string GetOutputString() {
|
||||
return GenerateReport("profile", /*include_output_string*/ true);
|
||||
return summary_formatter_->GetOutputString(stats_calculator_map_,
|
||||
*delegate_stats_calculator_);
|
||||
}
|
||||
|
||||
std::string GetShortSummary() {
|
||||
return GenerateReport("summary", /*include_output_string*/ false);
|
||||
return summary_formatter_->GetShortSummary(stats_calculator_map_,
|
||||
*delegate_stats_calculator_);
|
||||
}
|
||||
|
||||
tensorflow::StatsCalculator* GetStatsCalculator(uint32_t subgraph_index);
|
||||
@ -63,11 +69,8 @@ class ProfileSummarizer {
|
||||
|
||||
std::unique_ptr<tensorflow::StatsCalculator> delegate_stats_calculator_;
|
||||
|
||||
// GenerateReport returns the report of subgraphs in a string format.
|
||||
std::string GenerateReport(std::string tag, bool include_output_string);
|
||||
|
||||
// Whether output is formatted as CSV.
|
||||
bool format_as_csv_ = false;
|
||||
// Summary formatter for customized output formats.
|
||||
std::unique_ptr<ProfileSummaryFormatter> summary_formatter_;
|
||||
};
|
||||
|
||||
} // namespace profiling
|
||||
|
97
tensorflow/lite/profiling/profile_summary_formatter.cc
Normal file
97
tensorflow/lite/profiling/profile_summary_formatter.cc
Normal file
@ -0,0 +1,97 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/lite/profiling/profile_summary_formatter.h"
|
||||
|
||||
#include <memory>
|
||||
#include <sstream>
|
||||
|
||||
namespace tflite {
|
||||
namespace profiling {
|
||||
|
||||
std::string ProfileSummaryDefaultFormatter::GetOutputString(
|
||||
const std::map<uint32_t, std::unique_ptr<tensorflow::StatsCalculator>>&
|
||||
stats_calculator_map,
|
||||
const tensorflow::StatsCalculator& delegate_stats_calculator) const {
|
||||
return GenerateReport("profile", /*include_output_string*/ true,
|
||||
stats_calculator_map, delegate_stats_calculator);
|
||||
}
|
||||
|
||||
std::string ProfileSummaryDefaultFormatter::GetShortSummary(
|
||||
const std::map<uint32_t, std::unique_ptr<tensorflow::StatsCalculator>>&
|
||||
stats_calculator_map,
|
||||
const tensorflow::StatsCalculator& delegate_stats_calculator) const {
|
||||
return GenerateReport("summary", /*include_output_string*/ false,
|
||||
stats_calculator_map, delegate_stats_calculator);
|
||||
}
|
||||
|
||||
std::string ProfileSummaryDefaultFormatter::GenerateReport(
|
||||
const std::string& tag, bool include_output_string,
|
||||
const std::map<uint32_t, std::unique_ptr<tensorflow::StatsCalculator>>&
|
||||
stats_calculator_map,
|
||||
const tensorflow::StatsCalculator& delegate_stats_calculator) const {
|
||||
std::stringstream stream;
|
||||
bool has_non_primary_graph =
|
||||
(stats_calculator_map.size() - stats_calculator_map.count(0)) > 0;
|
||||
for (const auto& stats_calc : stats_calculator_map) {
|
||||
auto subgraph_index = stats_calc.first;
|
||||
auto subgraph_stats = stats_calc.second.get();
|
||||
if (has_non_primary_graph) {
|
||||
if (subgraph_index == 0) {
|
||||
stream << "Primary graph " << tag << ":" << std::endl;
|
||||
} else {
|
||||
stream << "Subgraph (index: " << subgraph_index << ") " << tag << ":"
|
||||
<< std::endl;
|
||||
}
|
||||
}
|
||||
if (include_output_string) {
|
||||
stream << subgraph_stats->GetOutputString();
|
||||
}
|
||||
if (subgraph_index != 0) {
|
||||
stream << "Subgraph (index: " << subgraph_index << ") ";
|
||||
}
|
||||
stream << subgraph_stats->GetShortSummary() << std::endl;
|
||||
}
|
||||
|
||||
if (delegate_stats_calculator.num_runs() > 0) {
|
||||
stream << "Delegate internal: " << std::endl;
|
||||
if (include_output_string) {
|
||||
stream << delegate_stats_calculator.GetOutputString();
|
||||
}
|
||||
stream << delegate_stats_calculator.GetShortSummary() << std::endl;
|
||||
}
|
||||
|
||||
return stream.str();
|
||||
}
|
||||
|
||||
tensorflow::StatSummarizerOptions
|
||||
ProfileSummaryDefaultFormatter::GetStatSummarizerOptions() const {
|
||||
auto options = tensorflow::StatSummarizerOptions();
|
||||
// Summary will be manually handled per subgraphs in order to keep the
|
||||
// compatibility.
|
||||
options.show_summary = false;
|
||||
options.show_memory = false;
|
||||
return options;
|
||||
}
|
||||
|
||||
tensorflow::StatSummarizerOptions
|
||||
ProfileSummaryCSVFormatter::GetStatSummarizerOptions() const {
|
||||
auto options = ProfileSummaryDefaultFormatter::GetStatSummarizerOptions();
|
||||
options.format_as_csv = true;
|
||||
return options;
|
||||
}
|
||||
|
||||
} // namespace profiling
|
||||
} // namespace tflite
|
84
tensorflow/lite/profiling/profile_summary_formatter.h
Normal file
84
tensorflow/lite/profiling/profile_summary_formatter.h
Normal file
@ -0,0 +1,84 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#ifndef TENSORFLOW_LITE_PROFILING_PROFILE_SUMMARY_FORMATTER_H_
|
||||
#define TENSORFLOW_LITE_PROFILING_PROFILE_SUMMARY_FORMATTER_H_
|
||||
|
||||
#include <functional>
|
||||
#include <memory>
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/core/util/stats_calculator.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace profiling {
|
||||
|
||||
// Formats the profile summary in a certain way.
|
||||
class ProfileSummaryFormatter {
|
||||
public:
|
||||
ProfileSummaryFormatter() {}
|
||||
virtual ~ProfileSummaryFormatter() {}
|
||||
// Returns a string detailing the accumulated runtime stats in StatsCalculator
|
||||
// of ProfileSummarizer.
|
||||
virtual std::string GetOutputString(
|
||||
const std::map<uint32_t, std::unique_ptr<tensorflow::StatsCalculator>>&
|
||||
stats_calculator_map,
|
||||
const tensorflow::StatsCalculator& delegate_stats_calculator) const = 0;
|
||||
// Returns a string detailing the short summary of the the accumulated runtime
|
||||
// stats in StatsCalculator of ProfileSummarizer.
|
||||
virtual std::string GetShortSummary(
|
||||
const std::map<uint32_t, std::unique_ptr<tensorflow::StatsCalculator>>&
|
||||
stats_calculator_map,
|
||||
const tensorflow::StatsCalculator& delegate_stats_calculator) const = 0;
|
||||
virtual tensorflow::StatSummarizerOptions GetStatSummarizerOptions()
|
||||
const = 0;
|
||||
};
|
||||
|
||||
class ProfileSummaryDefaultFormatter : public ProfileSummaryFormatter {
|
||||
public:
|
||||
ProfileSummaryDefaultFormatter() {}
|
||||
~ProfileSummaryDefaultFormatter() override {}
|
||||
std::string GetOutputString(
|
||||
const std::map<uint32_t, std::unique_ptr<tensorflow::StatsCalculator>>&
|
||||
stats_calculator_map,
|
||||
const tensorflow::StatsCalculator& delegate_stats_calculator)
|
||||
const override;
|
||||
std::string GetShortSummary(
|
||||
const std::map<uint32_t, std::unique_ptr<tensorflow::StatsCalculator>>&
|
||||
stats_calculator_map,
|
||||
const tensorflow::StatsCalculator& delegate_stats_calculator)
|
||||
const override;
|
||||
tensorflow::StatSummarizerOptions GetStatSummarizerOptions() const override;
|
||||
|
||||
private:
|
||||
std::string GenerateReport(
|
||||
const std::string& tag, bool include_output_string,
|
||||
const std::map<uint32_t, std::unique_ptr<tensorflow::StatsCalculator>>&
|
||||
stats_calculator_map,
|
||||
const tensorflow::StatsCalculator& delegate_stats_calculator) const;
|
||||
};
|
||||
|
||||
class ProfileSummaryCSVFormatter : public ProfileSummaryDefaultFormatter {
|
||||
public:
|
||||
ProfileSummaryCSVFormatter() {}
|
||||
tensorflow::StatSummarizerOptions GetStatSummarizerOptions() const override;
|
||||
};
|
||||
|
||||
} // namespace profiling
|
||||
} // namespace tflite
|
||||
|
||||
#endif // TENSORFLOW_LITE_PROFILING_PROFILE_SUMMARY_FORMATTER_H_
|
164
tensorflow/lite/profiling/profile_summary_formatter_test.cc
Normal file
164
tensorflow/lite/profiling/profile_summary_formatter_test.cc
Normal file
@ -0,0 +1,164 @@
|
||||
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "tensorflow/lite/profiling/profile_summary_formatter.h"
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
|
||||
#include <gmock/gmock.h>
|
||||
#include <gtest/gtest.h>
|
||||
#include "tensorflow/lite/testing/util.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace profiling {
|
||||
|
||||
namespace {
|
||||
|
||||
TEST(SummaryWriterTest, SummaryOptionStdOut) {
|
||||
ProfileSummaryDefaultFormatter writer;
|
||||
tensorflow::StatSummarizerOptions options = writer.GetStatSummarizerOptions();
|
||||
EXPECT_EQ(options.show_summary, false);
|
||||
EXPECT_EQ(options.show_memory, false);
|
||||
EXPECT_EQ(options.format_as_csv, false);
|
||||
}
|
||||
|
||||
TEST(SummaryWriterTest, SummaryOptionCSV) {
|
||||
ProfileSummaryCSVFormatter writer;
|
||||
tensorflow::StatSummarizerOptions options = writer.GetStatSummarizerOptions();
|
||||
EXPECT_EQ(options.show_summary, false);
|
||||
EXPECT_EQ(options.show_memory, false);
|
||||
EXPECT_EQ(options.format_as_csv, true);
|
||||
}
|
||||
TEST(SummaryWriterTest, EmptyOutputString) {
|
||||
ProfileSummaryDefaultFormatter writer;
|
||||
std::string output = writer.GetOutputString(
|
||||
std::map<uint32_t, std::unique_ptr<tensorflow::StatsCalculator>>(),
|
||||
tensorflow::StatsCalculator(writer.GetStatSummarizerOptions()));
|
||||
EXPECT_EQ(output.size(), 0);
|
||||
}
|
||||
|
||||
TEST(SummaryWriterTest, EmptyShortSummary) {
|
||||
ProfileSummaryDefaultFormatter writer;
|
||||
std::string output = writer.GetShortSummary(
|
||||
std::map<uint32_t, std::unique_ptr<tensorflow::StatsCalculator>>(),
|
||||
tensorflow::StatsCalculator(writer.GetStatSummarizerOptions()));
|
||||
EXPECT_EQ(output.size(), 0);
|
||||
}
|
||||
|
||||
TEST(SummaryWriterTest, SingleSubgraphOutputString) {
|
||||
ProfileSummaryDefaultFormatter writer;
|
||||
std::map<uint32_t, std::unique_ptr<tensorflow::StatsCalculator>>
|
||||
stats_calculator_map;
|
||||
stats_calculator_map[0] = std::make_unique<tensorflow::StatsCalculator>(
|
||||
writer.GetStatSummarizerOptions());
|
||||
std::string output = writer.GetOutputString(
|
||||
stats_calculator_map,
|
||||
tensorflow::StatsCalculator(writer.GetStatSummarizerOptions()));
|
||||
ASSERT_TRUE(output.find("Run Order") != std::string::npos);
|
||||
ASSERT_TRUE(output.find("Top by Computation Time") != std::string::npos);
|
||||
ASSERT_TRUE(output.find("Top by Memory Use") == std::string::npos);
|
||||
ASSERT_TRUE(output.find("Summary by node type") != std::string::npos);
|
||||
ASSERT_TRUE(output.find("nodes observed") != std::string::npos);
|
||||
ASSERT_TRUE(output.find("Primary graph") == std::string::npos);
|
||||
ASSERT_TRUE(output.find("Subgraph") == std::string::npos);
|
||||
ASSERT_TRUE(output.find("Delegate internal") == std::string::npos);
|
||||
}
|
||||
|
||||
TEST(SummaryWriterTest, SingleSubgraphShortSummary) {
|
||||
ProfileSummaryDefaultFormatter writer;
|
||||
std::map<uint32_t, std::unique_ptr<tensorflow::StatsCalculator>>
|
||||
stats_calculator_map;
|
||||
stats_calculator_map[0] = std::make_unique<tensorflow::StatsCalculator>(
|
||||
writer.GetStatSummarizerOptions());
|
||||
std::string output = writer.GetShortSummary(
|
||||
stats_calculator_map,
|
||||
tensorflow::StatsCalculator(writer.GetStatSummarizerOptions()));
|
||||
ASSERT_TRUE(output.find("Run Order") == std::string::npos);
|
||||
ASSERT_TRUE(output.find("Top by Computation Time") == std::string::npos);
|
||||
ASSERT_TRUE(output.find("Top by Memory Use") == std::string::npos);
|
||||
ASSERT_TRUE(output.find("Summary by node type") == std::string::npos);
|
||||
ASSERT_TRUE(output.find("nodes observed") != std::string::npos);
|
||||
ASSERT_TRUE(output.find("Primary graph") == std::string::npos);
|
||||
ASSERT_TRUE(output.find("Subgraph") == std::string::npos);
|
||||
ASSERT_TRUE(output.find("Delegate internal") == std::string::npos);
|
||||
}
|
||||
|
||||
TEST(SummaryWriterTest, MultiSubgraphOutputString) {
|
||||
ProfileSummaryDefaultFormatter writer;
|
||||
std::map<uint32_t, std::unique_ptr<tensorflow::StatsCalculator>>
|
||||
stats_calculator_map;
|
||||
stats_calculator_map[0] = std::make_unique<tensorflow::StatsCalculator>(
|
||||
writer.GetStatSummarizerOptions());
|
||||
stats_calculator_map[1] = std::make_unique<tensorflow::StatsCalculator>(
|
||||
writer.GetStatSummarizerOptions());
|
||||
std::string output = writer.GetOutputString(
|
||||
stats_calculator_map,
|
||||
tensorflow::StatsCalculator(writer.GetStatSummarizerOptions()));
|
||||
ASSERT_TRUE(output.find("Primary graph") != std::string::npos);
|
||||
ASSERT_TRUE(output.find("Subgraph") != std::string::npos);
|
||||
ASSERT_TRUE(output.find("Delegate internal") == std::string::npos);
|
||||
}
|
||||
|
||||
TEST(SummaryWriterTest, MultiSubgraphShortSummary) {
|
||||
ProfileSummaryDefaultFormatter writer;
|
||||
std::map<uint32_t, std::unique_ptr<tensorflow::StatsCalculator>>
|
||||
stats_calculator_map;
|
||||
stats_calculator_map[0] = std::make_unique<tensorflow::StatsCalculator>(
|
||||
writer.GetStatSummarizerOptions());
|
||||
stats_calculator_map[1] = std::make_unique<tensorflow::StatsCalculator>(
|
||||
writer.GetStatSummarizerOptions());
|
||||
std::string output = writer.GetShortSummary(
|
||||
stats_calculator_map,
|
||||
tensorflow::StatsCalculator(writer.GetStatSummarizerOptions()));
|
||||
ASSERT_TRUE(output.find("Primary graph") != std::string::npos);
|
||||
ASSERT_TRUE(output.find("Subgraph") != std::string::npos);
|
||||
ASSERT_TRUE(output.find("Delegate internal") == std::string::npos);
|
||||
}
|
||||
|
||||
TEST(SummaryWriterTest, DelegationOutputString) {
|
||||
ProfileSummaryDefaultFormatter writer;
|
||||
auto delegate_stats_calculator =
|
||||
tensorflow::StatsCalculator(writer.GetStatSummarizerOptions());
|
||||
delegate_stats_calculator.UpdateRunTotalUs(1);
|
||||
std::string output = writer.GetOutputString(
|
||||
std::map<uint32_t, std::unique_ptr<tensorflow::StatsCalculator>>(),
|
||||
delegate_stats_calculator);
|
||||
ASSERT_TRUE(output.find("Primary graph") == std::string::npos);
|
||||
ASSERT_TRUE(output.find("Subgraph") == std::string::npos);
|
||||
ASSERT_TRUE(output.find("Delegate internal") != std::string::npos);
|
||||
}
|
||||
|
||||
TEST(SummaryWriterTest, DelegationShortSummary) {
|
||||
ProfileSummaryDefaultFormatter writer;
|
||||
auto delegate_stats_calculator =
|
||||
tensorflow::StatsCalculator(writer.GetStatSummarizerOptions());
|
||||
delegate_stats_calculator.UpdateRunTotalUs(1);
|
||||
std::string output = writer.GetShortSummary(
|
||||
std::map<uint32_t, std::unique_ptr<tensorflow::StatsCalculator>>(),
|
||||
delegate_stats_calculator);
|
||||
ASSERT_TRUE(output.find("Primary graph") == std::string::npos);
|
||||
ASSERT_TRUE(output.find("Subgraph") == std::string::npos);
|
||||
ASSERT_TRUE(output.find("Delegate internal") != std::string::npos);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace profiling
|
||||
} // namespace tflite
|
||||
|
||||
int main(int argc, char** argv) {
|
||||
::tflite::LogToStderr();
|
||||
::testing::InitGoogleTest(&argc, argv);
|
||||
return RUN_ALL_TESTS();
|
||||
}
|
@ -148,6 +148,7 @@ cc_library(
|
||||
"//tensorflow/lite/experimental/ruy/profiler",
|
||||
"//tensorflow/lite/kernels:builtin_ops",
|
||||
"//tensorflow/lite/profiling:profiler",
|
||||
"//tensorflow/lite/profiling:profile_summary_formatter",
|
||||
"//tensorflow/lite/tools/evaluation:utils",
|
||||
] + select({
|
||||
"//tensorflow:fuchsia": [],
|
||||
|
@ -32,6 +32,7 @@ limitations under the License.
|
||||
#include "tensorflow/lite/kernels/register.h"
|
||||
#include "tensorflow/lite/model.h"
|
||||
#include "tensorflow/lite/op_resolver.h"
|
||||
#include "tensorflow/lite/profiling/profile_summary_formatter.h"
|
||||
#include "tensorflow/lite/string_util.h"
|
||||
#include "tensorflow/lite/tools/benchmark/benchmark_utils.h"
|
||||
#include "tensorflow/lite/tools/benchmark/delegate_provider.h"
|
||||
|
@ -24,6 +24,7 @@ limitations under the License.
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/lite/model.h"
|
||||
#include "tensorflow/lite/profiling/profile_summary_formatter.h"
|
||||
#include "tensorflow/lite/profiling/profiler.h"
|
||||
#include "tensorflow/lite/tools/benchmark/benchmark_model.h"
|
||||
|
||||
|
@ -22,11 +22,11 @@ namespace benchmark {
|
||||
|
||||
ProfilingListener::ProfilingListener(Interpreter* interpreter,
|
||||
uint32_t max_num_entries,
|
||||
std::string csv_file_path)
|
||||
const std::string& csv_file_path)
|
||||
: interpreter_(interpreter),
|
||||
profiler_(max_num_entries),
|
||||
run_summarizer_(!csv_file_path.empty()),
|
||||
init_summarizer_(!csv_file_path.empty()),
|
||||
run_summarizer_(CreateProfileSummaryFormatter(!csv_file_path.empty())),
|
||||
init_summarizer_(CreateProfileSummaryFormatter(!csv_file_path.empty())),
|
||||
csv_file_path_(csv_file_path) {
|
||||
TFLITE_BENCHMARK_CHECK(interpreter);
|
||||
interpreter_->SetProfiler(&profiler_);
|
||||
@ -85,5 +85,12 @@ void ProfilingListener::WriteOutput(const std::string& header,
|
||||
(*stream) << data << std::endl;
|
||||
}
|
||||
|
||||
std::unique_ptr<profiling::ProfileSummaryFormatter>
|
||||
ProfilingListener::CreateProfileSummaryFormatter(bool format_as_csv) const {
|
||||
return format_as_csv
|
||||
? std::make_unique<profiling::ProfileSummaryDefaultFormatter>()
|
||||
: std::make_unique<profiling::ProfileSummaryCSVFormatter>();
|
||||
}
|
||||
|
||||
} // namespace benchmark
|
||||
} // namespace tflite
|
||||
|
@ -27,7 +27,7 @@ namespace benchmark {
|
||||
class ProfilingListener : public BenchmarkListener {
|
||||
public:
|
||||
explicit ProfilingListener(Interpreter* interpreter, uint32_t max_num_entries,
|
||||
std::string csv_file_path = "");
|
||||
const std::string& csv_file_path = "");
|
||||
|
||||
void OnBenchmarkStart(const BenchmarkParams& params) override;
|
||||
|
||||
@ -37,6 +37,11 @@ class ProfilingListener : public BenchmarkListener {
|
||||
|
||||
void OnBenchmarkEnd(const BenchmarkResults& results) override;
|
||||
|
||||
protected:
|
||||
// Allow subclasses to create a customized summary writer during init.
|
||||
virtual std::unique_ptr<profiling::ProfileSummaryFormatter>
|
||||
CreateProfileSummaryFormatter(bool format_as_csv) const;
|
||||
|
||||
private:
|
||||
void WriteOutput(const std::string& header, const string& data,
|
||||
std::ostream* stream);
|
||||
|
Loading…
Reference in New Issue
Block a user