Refactor: take the memory ownership of output stream when creating a CSVWriter instance.

PiperOrigin-RevId: 287781403
Change-Id: I934dd83f16a0723990624184bcc5a2fbc7fa2fdc
This commit is contained in:
Chao Mei 2020-01-01 19:27:32 -08:00 committed by TensorFlower Gardener
parent 68cbb78a53
commit 2712ab1563
3 changed files with 15 additions and 14 deletions

View File

@ -17,6 +17,7 @@ limitations under the License.
#define TENSORFLOW_LITE_TOOLS_ACCURACY_CSV_WRITER_H_ #define TENSORFLOW_LITE_TOOLS_ACCURACY_CSV_WRITER_H_
#include <fstream> #include <fstream>
#include <memory>
#include <vector> #include <vector>
#include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/logging.h"
@ -28,15 +29,16 @@ namespace metrics {
// columns. This supports a very limited set of CSV spec and doesn't do any // columns. This supports a very limited set of CSV spec and doesn't do any
// escaping. // escaping.
// Usage: // Usage:
// std::ofstream * output_stream = ... // std::unqiue_str<std::ofstream> output_stream = ...
// CSVWriter writer({"column1", "column2"}, output_stream); // CSVWriter writer({"column1", "column2"}, std::move(output_stream));
// writer.WriteRow({4, 5}); // writer.WriteRow({4, 5});
// writer.Flush(); // flush results immediately. // writer.Flush(); // flush results immediately.
class CSVWriter { class CSVWriter {
public: public:
CSVWriter(const std::vector<string>& columns, std::ofstream* output_stream) CSVWriter(const std::vector<string>& columns,
: num_columns_(columns.size()), output_stream_(output_stream) { std::unique_ptr<std::ofstream> output_stream)
if (WriteRow(columns, output_stream_) != kTfLiteOk) { : num_columns_(columns.size()), output_stream_(std::move(output_stream)) {
if (WriteRow(columns, output_stream_.get()) != kTfLiteOk) {
LOG(ERROR) << "Could not write column names to file"; LOG(ERROR) << "Could not write column names to file";
} }
} }
@ -48,7 +50,7 @@ class CSVWriter {
<< " expected: " << num_columns_; << " expected: " << num_columns_;
return kTfLiteError; return kTfLiteError;
} }
return WriteRow(values, output_stream_); return WriteRow(values, output_stream_.get());
} }
void Flush() { output_stream_->flush(); } void Flush() { output_stream_->flush(); }
@ -76,7 +78,7 @@ class CSVWriter {
return kTfLiteOk; return kTfLiteOk;
} }
const size_t num_columns_; const size_t num_columns_;
std::ofstream* output_stream_; std::unique_ptr<std::ofstream> output_stream_;
}; };
} // namespace metrics } // namespace metrics
} // namespace tensorflow } // namespace tensorflow

View File

@ -17,6 +17,7 @@ limitations under the License.
#include <cstdlib> #include <cstdlib>
#include <iomanip> #include <iomanip>
#include <memory>
#include "absl/memory/memory.h" #include "absl/memory/memory.h"
#include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/logging.h"
@ -38,13 +39,14 @@ ResultsWriter::ResultsWriter(int top_k, const std::string& output_file_path)
return; return;
} }
output_stream_.reset(new std::ofstream(output_file_path, std::ios::out)); std::unique_ptr<std::ofstream> output_stream(
if (!output_stream_) { new std::ofstream(output_file_path, std::ios::out));
if (!output_stream) {
LOG(ERROR) << "Unable to open output file path: '" << output_file_path LOG(ERROR) << "Unable to open output file path: '" << output_file_path
<< "'"; << "'";
} }
(*output_stream_) << std::setprecision(3) << std::fixed; (*output_stream) << std::setprecision(3) << std::fixed;
std::vector<string> columns; std::vector<string> columns;
columns.reserve(top_k); columns.reserve(top_k);
for (int i = 0; i < top_k; i++) { for (int i = 0; i < top_k; i++) {
@ -53,7 +55,7 @@ ResultsWriter::ResultsWriter(int top_k, const std::string& output_file_path)
columns.push_back(column_name); columns.push_back(column_name);
} }
writer_.reset(new CSVWriter(columns, output_stream_.get())); writer_.reset(new CSVWriter(columns, std::move(output_stream)));
} }
void ResultsWriter::AggregateAccuraciesAndNumImages( void ResultsWriter::AggregateAccuraciesAndNumImages(

View File

@ -56,9 +56,6 @@ class ResultsWriter : public ImagenetModelEvaluator::Observer {
shard_id_accuracy_metrics_map_; shard_id_accuracy_metrics_map_;
std::unordered_map<uint64_t, int> shard_id_done_image_count_map_; std::unordered_map<uint64_t, int> shard_id_done_image_count_map_;
// TODO(b/146988222): Refactor CSVWriter to take the memory ownership of
// 'output_stream_'.
std::unique_ptr<std::ofstream> output_stream_;
std::unique_ptr<CSVWriter> writer_; std::unique_ptr<CSVWriter> writer_;
// For logging to stdout. // For logging to stdout.