diff --git a/tensorflow/lite/tools/accuracy/csv_writer.h b/tensorflow/lite/tools/accuracy/csv_writer.h index 85c0f5c2044..e8f298fd211 100644 --- a/tensorflow/lite/tools/accuracy/csv_writer.h +++ b/tensorflow/lite/tools/accuracy/csv_writer.h @@ -17,6 +17,7 @@ limitations under the License. #define TENSORFLOW_LITE_TOOLS_ACCURACY_CSV_WRITER_H_ #include +#include #include #include "tensorflow/core/platform/logging.h" @@ -28,15 +29,16 @@ namespace metrics { // columns. This supports a very limited set of CSV spec and doesn't do any // escaping. // Usage: -// std::ofstream * output_stream = ... -// CSVWriter writer({"column1", "column2"}, output_stream); +// std::unqiue_str output_stream = ... +// CSVWriter writer({"column1", "column2"}, std::move(output_stream)); // writer.WriteRow({4, 5}); // writer.Flush(); // flush results immediately. class CSVWriter { public: - CSVWriter(const std::vector& columns, std::ofstream* output_stream) - : num_columns_(columns.size()), output_stream_(output_stream) { - if (WriteRow(columns, output_stream_) != kTfLiteOk) { + CSVWriter(const std::vector& columns, + std::unique_ptr output_stream) + : num_columns_(columns.size()), output_stream_(std::move(output_stream)) { + if (WriteRow(columns, output_stream_.get()) != kTfLiteOk) { LOG(ERROR) << "Could not write column names to file"; } } @@ -48,7 +50,7 @@ class CSVWriter { << " expected: " << num_columns_; return kTfLiteError; } - return WriteRow(values, output_stream_); + return WriteRow(values, output_stream_.get()); } void Flush() { output_stream_->flush(); } @@ -76,7 +78,7 @@ class CSVWriter { return kTfLiteOk; } const size_t num_columns_; - std::ofstream* output_stream_; + std::unique_ptr output_stream_; }; } // namespace metrics } // namespace tensorflow diff --git a/tensorflow/lite/tools/accuracy/ilsvrc/imagenet_accuracy_eval.cc b/tensorflow/lite/tools/accuracy/ilsvrc/imagenet_accuracy_eval.cc index 9139cfc5def..eb1ad42e8e0 100644 --- a/tensorflow/lite/tools/accuracy/ilsvrc/imagenet_accuracy_eval.cc +++ b/tensorflow/lite/tools/accuracy/ilsvrc/imagenet_accuracy_eval.cc @@ -17,6 +17,7 @@ limitations under the License. #include #include +#include #include "absl/memory/memory.h" #include "tensorflow/core/platform/logging.h" @@ -38,13 +39,14 @@ ResultsWriter::ResultsWriter(int top_k, const std::string& output_file_path) return; } - output_stream_.reset(new std::ofstream(output_file_path, std::ios::out)); - if (!output_stream_) { + std::unique_ptr output_stream( + new std::ofstream(output_file_path, std::ios::out)); + if (!output_stream) { LOG(ERROR) << "Unable to open output file path: '" << output_file_path << "'"; } - (*output_stream_) << std::setprecision(3) << std::fixed; + (*output_stream) << std::setprecision(3) << std::fixed; std::vector columns; columns.reserve(top_k); for (int i = 0; i < top_k; i++) { @@ -53,7 +55,7 @@ ResultsWriter::ResultsWriter(int top_k, const std::string& output_file_path) columns.push_back(column_name); } - writer_.reset(new CSVWriter(columns, output_stream_.get())); + writer_.reset(new CSVWriter(columns, std::move(output_stream))); } void ResultsWriter::AggregateAccuraciesAndNumImages( diff --git a/tensorflow/lite/tools/accuracy/ilsvrc/imagenet_accuracy_eval.h b/tensorflow/lite/tools/accuracy/ilsvrc/imagenet_accuracy_eval.h index 6e3d614353f..f764a6bb8b7 100644 --- a/tensorflow/lite/tools/accuracy/ilsvrc/imagenet_accuracy_eval.h +++ b/tensorflow/lite/tools/accuracy/ilsvrc/imagenet_accuracy_eval.h @@ -56,9 +56,6 @@ class ResultsWriter : public ImagenetModelEvaluator::Observer { shard_id_accuracy_metrics_map_; std::unordered_map shard_id_done_image_count_map_; - // TODO(b/146988222): Refactor CSVWriter to take the memory ownership of - // 'output_stream_'. - std::unique_ptr output_stream_; std::unique_ptr writer_; // For logging to stdout.