Refactor: take the memory ownership of output stream when creating a CSVWriter instance.
PiperOrigin-RevId: 287781403 Change-Id: I934dd83f16a0723990624184bcc5a2fbc7fa2fdc
This commit is contained in:
parent
68cbb78a53
commit
2712ab1563
@ -17,6 +17,7 @@ limitations under the License.
|
|||||||
#define TENSORFLOW_LITE_TOOLS_ACCURACY_CSV_WRITER_H_
|
#define TENSORFLOW_LITE_TOOLS_ACCURACY_CSV_WRITER_H_
|
||||||
|
|
||||||
#include <fstream>
|
#include <fstream>
|
||||||
|
#include <memory>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "tensorflow/core/platform/logging.h"
|
#include "tensorflow/core/platform/logging.h"
|
||||||
@ -28,15 +29,16 @@ namespace metrics {
|
|||||||
// columns. This supports a very limited set of CSV spec and doesn't do any
|
// columns. This supports a very limited set of CSV spec and doesn't do any
|
||||||
// escaping.
|
// escaping.
|
||||||
// Usage:
|
// Usage:
|
||||||
// std::ofstream * output_stream = ...
|
// std::unqiue_str<std::ofstream> output_stream = ...
|
||||||
// CSVWriter writer({"column1", "column2"}, output_stream);
|
// CSVWriter writer({"column1", "column2"}, std::move(output_stream));
|
||||||
// writer.WriteRow({4, 5});
|
// writer.WriteRow({4, 5});
|
||||||
// writer.Flush(); // flush results immediately.
|
// writer.Flush(); // flush results immediately.
|
||||||
class CSVWriter {
|
class CSVWriter {
|
||||||
public:
|
public:
|
||||||
CSVWriter(const std::vector<string>& columns, std::ofstream* output_stream)
|
CSVWriter(const std::vector<string>& columns,
|
||||||
: num_columns_(columns.size()), output_stream_(output_stream) {
|
std::unique_ptr<std::ofstream> output_stream)
|
||||||
if (WriteRow(columns, output_stream_) != kTfLiteOk) {
|
: num_columns_(columns.size()), output_stream_(std::move(output_stream)) {
|
||||||
|
if (WriteRow(columns, output_stream_.get()) != kTfLiteOk) {
|
||||||
LOG(ERROR) << "Could not write column names to file";
|
LOG(ERROR) << "Could not write column names to file";
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -48,7 +50,7 @@ class CSVWriter {
|
|||||||
<< " expected: " << num_columns_;
|
<< " expected: " << num_columns_;
|
||||||
return kTfLiteError;
|
return kTfLiteError;
|
||||||
}
|
}
|
||||||
return WriteRow(values, output_stream_);
|
return WriteRow(values, output_stream_.get());
|
||||||
}
|
}
|
||||||
|
|
||||||
void Flush() { output_stream_->flush(); }
|
void Flush() { output_stream_->flush(); }
|
||||||
@ -76,7 +78,7 @@ class CSVWriter {
|
|||||||
return kTfLiteOk;
|
return kTfLiteOk;
|
||||||
}
|
}
|
||||||
const size_t num_columns_;
|
const size_t num_columns_;
|
||||||
std::ofstream* output_stream_;
|
std::unique_ptr<std::ofstream> output_stream_;
|
||||||
};
|
};
|
||||||
} // namespace metrics
|
} // namespace metrics
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
@ -17,6 +17,7 @@ limitations under the License.
|
|||||||
|
|
||||||
#include <cstdlib>
|
#include <cstdlib>
|
||||||
#include <iomanip>
|
#include <iomanip>
|
||||||
|
#include <memory>
|
||||||
|
|
||||||
#include "absl/memory/memory.h"
|
#include "absl/memory/memory.h"
|
||||||
#include "tensorflow/core/platform/logging.h"
|
#include "tensorflow/core/platform/logging.h"
|
||||||
@ -38,13 +39,14 @@ ResultsWriter::ResultsWriter(int top_k, const std::string& output_file_path)
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
output_stream_.reset(new std::ofstream(output_file_path, std::ios::out));
|
std::unique_ptr<std::ofstream> output_stream(
|
||||||
if (!output_stream_) {
|
new std::ofstream(output_file_path, std::ios::out));
|
||||||
|
if (!output_stream) {
|
||||||
LOG(ERROR) << "Unable to open output file path: '" << output_file_path
|
LOG(ERROR) << "Unable to open output file path: '" << output_file_path
|
||||||
<< "'";
|
<< "'";
|
||||||
}
|
}
|
||||||
|
|
||||||
(*output_stream_) << std::setprecision(3) << std::fixed;
|
(*output_stream) << std::setprecision(3) << std::fixed;
|
||||||
std::vector<string> columns;
|
std::vector<string> columns;
|
||||||
columns.reserve(top_k);
|
columns.reserve(top_k);
|
||||||
for (int i = 0; i < top_k; i++) {
|
for (int i = 0; i < top_k; i++) {
|
||||||
@ -53,7 +55,7 @@ ResultsWriter::ResultsWriter(int top_k, const std::string& output_file_path)
|
|||||||
columns.push_back(column_name);
|
columns.push_back(column_name);
|
||||||
}
|
}
|
||||||
|
|
||||||
writer_.reset(new CSVWriter(columns, output_stream_.get()));
|
writer_.reset(new CSVWriter(columns, std::move(output_stream)));
|
||||||
}
|
}
|
||||||
|
|
||||||
void ResultsWriter::AggregateAccuraciesAndNumImages(
|
void ResultsWriter::AggregateAccuraciesAndNumImages(
|
||||||
|
@ -56,9 +56,6 @@ class ResultsWriter : public ImagenetModelEvaluator::Observer {
|
|||||||
shard_id_accuracy_metrics_map_;
|
shard_id_accuracy_metrics_map_;
|
||||||
std::unordered_map<uint64_t, int> shard_id_done_image_count_map_;
|
std::unordered_map<uint64_t, int> shard_id_done_image_count_map_;
|
||||||
|
|
||||||
// TODO(b/146988222): Refactor CSVWriter to take the memory ownership of
|
|
||||||
// 'output_stream_'.
|
|
||||||
std::unique_ptr<std::ofstream> output_stream_;
|
|
||||||
std::unique_ptr<CSVWriter> writer_;
|
std::unique_ptr<CSVWriter> writer_;
|
||||||
|
|
||||||
// For logging to stdout.
|
// For logging to stdout.
|
||||||
|
Loading…
x
Reference in New Issue
Block a user