Refactor: take the memory ownership of output stream when creating a CSVWriter instance.
PiperOrigin-RevId: 287781403 Change-Id: I934dd83f16a0723990624184bcc5a2fbc7fa2fdc
This commit is contained in:
parent
68cbb78a53
commit
2712ab1563
@ -17,6 +17,7 @@ limitations under the License.
|
||||
#define TENSORFLOW_LITE_TOOLS_ACCURACY_CSV_WRITER_H_
|
||||
|
||||
#include <fstream>
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
@ -28,15 +29,16 @@ namespace metrics {
|
||||
// columns. This supports a very limited set of CSV spec and doesn't do any
|
||||
// escaping.
|
||||
// Usage:
|
||||
// std::ofstream * output_stream = ...
|
||||
// CSVWriter writer({"column1", "column2"}, output_stream);
|
||||
// std::unqiue_str<std::ofstream> output_stream = ...
|
||||
// CSVWriter writer({"column1", "column2"}, std::move(output_stream));
|
||||
// writer.WriteRow({4, 5});
|
||||
// writer.Flush(); // flush results immediately.
|
||||
class CSVWriter {
|
||||
public:
|
||||
CSVWriter(const std::vector<string>& columns, std::ofstream* output_stream)
|
||||
: num_columns_(columns.size()), output_stream_(output_stream) {
|
||||
if (WriteRow(columns, output_stream_) != kTfLiteOk) {
|
||||
CSVWriter(const std::vector<string>& columns,
|
||||
std::unique_ptr<std::ofstream> output_stream)
|
||||
: num_columns_(columns.size()), output_stream_(std::move(output_stream)) {
|
||||
if (WriteRow(columns, output_stream_.get()) != kTfLiteOk) {
|
||||
LOG(ERROR) << "Could not write column names to file";
|
||||
}
|
||||
}
|
||||
@ -48,7 +50,7 @@ class CSVWriter {
|
||||
<< " expected: " << num_columns_;
|
||||
return kTfLiteError;
|
||||
}
|
||||
return WriteRow(values, output_stream_);
|
||||
return WriteRow(values, output_stream_.get());
|
||||
}
|
||||
|
||||
void Flush() { output_stream_->flush(); }
|
||||
@ -76,7 +78,7 @@ class CSVWriter {
|
||||
return kTfLiteOk;
|
||||
}
|
||||
const size_t num_columns_;
|
||||
std::ofstream* output_stream_;
|
||||
std::unique_ptr<std::ofstream> output_stream_;
|
||||
};
|
||||
} // namespace metrics
|
||||
} // namespace tensorflow
|
||||
|
@ -17,6 +17,7 @@ limitations under the License.
|
||||
|
||||
#include <cstdlib>
|
||||
#include <iomanip>
|
||||
#include <memory>
|
||||
|
||||
#include "absl/memory/memory.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
@ -38,13 +39,14 @@ ResultsWriter::ResultsWriter(int top_k, const std::string& output_file_path)
|
||||
return;
|
||||
}
|
||||
|
||||
output_stream_.reset(new std::ofstream(output_file_path, std::ios::out));
|
||||
if (!output_stream_) {
|
||||
std::unique_ptr<std::ofstream> output_stream(
|
||||
new std::ofstream(output_file_path, std::ios::out));
|
||||
if (!output_stream) {
|
||||
LOG(ERROR) << "Unable to open output file path: '" << output_file_path
|
||||
<< "'";
|
||||
}
|
||||
|
||||
(*output_stream_) << std::setprecision(3) << std::fixed;
|
||||
(*output_stream) << std::setprecision(3) << std::fixed;
|
||||
std::vector<string> columns;
|
||||
columns.reserve(top_k);
|
||||
for (int i = 0; i < top_k; i++) {
|
||||
@ -53,7 +55,7 @@ ResultsWriter::ResultsWriter(int top_k, const std::string& output_file_path)
|
||||
columns.push_back(column_name);
|
||||
}
|
||||
|
||||
writer_.reset(new CSVWriter(columns, output_stream_.get()));
|
||||
writer_.reset(new CSVWriter(columns, std::move(output_stream)));
|
||||
}
|
||||
|
||||
void ResultsWriter::AggregateAccuraciesAndNumImages(
|
||||
|
@ -56,9 +56,6 @@ class ResultsWriter : public ImagenetModelEvaluator::Observer {
|
||||
shard_id_accuracy_metrics_map_;
|
||||
std::unordered_map<uint64_t, int> shard_id_done_image_count_map_;
|
||||
|
||||
// TODO(b/146988222): Refactor CSVWriter to take the memory ownership of
|
||||
// 'output_stream_'.
|
||||
std::unique_ptr<std::ofstream> output_stream_;
|
||||
std::unique_ptr<CSVWriter> writer_;
|
||||
|
||||
// For logging to stdout.
|
||||
|
Loading…
x
Reference in New Issue
Block a user