Refactor: take the memory ownership of output stream when creating a CSVWriter instance.

PiperOrigin-RevId: 287781403
Change-Id: I934dd83f16a0723990624184bcc5a2fbc7fa2fdc
This commit is contained in:
Chao Mei 2020-01-01 19:27:32 -08:00 committed by TensorFlower Gardener
parent 68cbb78a53
commit 2712ab1563
3 changed files with 15 additions and 14 deletions

View File

@ -17,6 +17,7 @@ limitations under the License.
#define TENSORFLOW_LITE_TOOLS_ACCURACY_CSV_WRITER_H_
#include <fstream>
#include <memory>
#include <vector>
#include "tensorflow/core/platform/logging.h"
@ -28,15 +29,16 @@ namespace metrics {
// columns. This supports a very limited set of CSV spec and doesn't do any
// escaping.
// Usage:
// std::ofstream * output_stream = ...
// CSVWriter writer({"column1", "column2"}, output_stream);
// std::unqiue_str<std::ofstream> output_stream = ...
// CSVWriter writer({"column1", "column2"}, std::move(output_stream));
// writer.WriteRow({4, 5});
// writer.Flush(); // flush results immediately.
class CSVWriter {
public:
CSVWriter(const std::vector<string>& columns, std::ofstream* output_stream)
: num_columns_(columns.size()), output_stream_(output_stream) {
if (WriteRow(columns, output_stream_) != kTfLiteOk) {
CSVWriter(const std::vector<string>& columns,
std::unique_ptr<std::ofstream> output_stream)
: num_columns_(columns.size()), output_stream_(std::move(output_stream)) {
if (WriteRow(columns, output_stream_.get()) != kTfLiteOk) {
LOG(ERROR) << "Could not write column names to file";
}
}
@ -48,7 +50,7 @@ class CSVWriter {
<< " expected: " << num_columns_;
return kTfLiteError;
}
return WriteRow(values, output_stream_);
return WriteRow(values, output_stream_.get());
}
void Flush() { output_stream_->flush(); }
@ -76,7 +78,7 @@ class CSVWriter {
return kTfLiteOk;
}
const size_t num_columns_;
std::ofstream* output_stream_;
std::unique_ptr<std::ofstream> output_stream_;
};
} // namespace metrics
} // namespace tensorflow

View File

@ -17,6 +17,7 @@ limitations under the License.
#include <cstdlib>
#include <iomanip>
#include <memory>
#include "absl/memory/memory.h"
#include "tensorflow/core/platform/logging.h"
@ -38,13 +39,14 @@ ResultsWriter::ResultsWriter(int top_k, const std::string& output_file_path)
return;
}
output_stream_.reset(new std::ofstream(output_file_path, std::ios::out));
if (!output_stream_) {
std::unique_ptr<std::ofstream> output_stream(
new std::ofstream(output_file_path, std::ios::out));
if (!output_stream) {
LOG(ERROR) << "Unable to open output file path: '" << output_file_path
<< "'";
}
(*output_stream_) << std::setprecision(3) << std::fixed;
(*output_stream) << std::setprecision(3) << std::fixed;
std::vector<string> columns;
columns.reserve(top_k);
for (int i = 0; i < top_k; i++) {
@ -53,7 +55,7 @@ ResultsWriter::ResultsWriter(int top_k, const std::string& output_file_path)
columns.push_back(column_name);
}
writer_.reset(new CSVWriter(columns, output_stream_.get()));
writer_.reset(new CSVWriter(columns, std::move(output_stream)));
}
void ResultsWriter::AggregateAccuraciesAndNumImages(

View File

@ -56,9 +56,6 @@ class ResultsWriter : public ImagenetModelEvaluator::Observer {
shard_id_accuracy_metrics_map_;
std::unordered_map<uint64_t, int> shard_id_done_image_count_map_;
// TODO(b/146988222): Refactor CSVWriter to take the memory ownership of
// 'output_stream_'.
std::unique_ptr<std::ofstream> output_stream_;
std::unique_ptr<CSVWriter> writer_;
// For logging to stdout.