diff --git a/tensorflow/core/util/reporter.cc b/tensorflow/core/util/reporter.cc index 8e9d863b4c2..44465a58329 100644 --- a/tensorflow/core/util/reporter.cc +++ b/tensorflow/core/util/reporter.cc @@ -91,6 +91,14 @@ Status TestReporter::SetProperty(const string& name, double value) { return Status::OK(); } +Status TestReporter::AddMetric(const string& name, double value) { + if (report_file_.IsClosed()) return Status::OK(); + auto* metric = benchmark_entry_.add_metrics(); + metric->set_name(name); + metric->set_value(value); + return Status::OK(); +} + Status TestReporter::Initialize() { return report_file_.Initialize(); } } // namespace tensorflow diff --git a/tensorflow/core/util/reporter.h b/tensorflow/core/util/reporter.h index 51d7502701c..900fe40353e 100644 --- a/tensorflow/core/util/reporter.h +++ b/tensorflow/core/util/reporter.h @@ -111,6 +111,9 @@ class TestReporter { // Set property on Benchmark to the given value. Status SetProperty(const string& name, const string& value); + // Add the given value to the metrics on the Benchmark. + Status AddMetric(const string& name, double value); + // TODO(b/32704451): Don't just ignore the ::tensorflow::Status object! ~TestReporter() { Close().IgnoreError(); } // Autoclose in destructor. diff --git a/tensorflow/core/util/reporter_test.cc b/tensorflow/core/util/reporter_test.cc index 4c06560b852..77e7ed6467e 100644 --- a/tensorflow/core/util/reporter_test.cc +++ b/tensorflow/core/util/reporter_test.cc @@ -138,5 +138,30 @@ TEST(TestReporter, SetProperties) { EXPECT_EQ(4.0, extras.at("double_prop").double_value()); } +TEST(TestReporter, AddMetrics) { + string fname = + strings::StrCat(testing::TmpDir(), "/test_reporter_benchmarks_"); + TestReporter test_reporter(fname, "b3/4/5"); + TF_EXPECT_OK(test_reporter.Initialize()); + TF_EXPECT_OK(test_reporter.AddMetric("metric1", 2.0)); + TF_EXPECT_OK(test_reporter.AddMetric("metric2", 3.0)); + + TF_EXPECT_OK(test_reporter.Close()); + string expected_fname = strings::StrCat(fname, "b3__4__5"); + string read; + TF_EXPECT_OK(ReadFileToString(Env::Default(), expected_fname, &read)); + + BenchmarkEntries benchmark_entries; + ASSERT_TRUE(benchmark_entries.ParseFromString(read)); + ASSERT_EQ(1, benchmark_entries.entry_size()); + const BenchmarkEntry& benchmark_entry = benchmark_entries.entry(0); + const auto& metrics = benchmark_entry.metrics(); + ASSERT_EQ(2, metrics.size()); + EXPECT_EQ("metric1", metrics.at(0).name()); + EXPECT_EQ(2.0, metrics.at(0).value()); + EXPECT_EQ("metric2", metrics.at(1).name()); + EXPECT_EQ(3.0, metrics.at(1).value()); +} + } // namespace } // namespace tensorflow