Refactor: take the memory ownership of output stream when creating a CSVWriter instance.
PiperOrigin-RevId: 287781403 Change-Id: I934dd83f16a0723990624184bcc5a2fbc7fa2fdc
This commit is contained in:
		
							parent
							
								
									68cbb78a53
								
							
						
					
					
						commit
						2712ab1563
					
				@ -17,6 +17,7 @@ limitations under the License.
 | 
			
		||||
#define TENSORFLOW_LITE_TOOLS_ACCURACY_CSV_WRITER_H_
 | 
			
		||||
 | 
			
		||||
#include <fstream>
 | 
			
		||||
#include <memory>
 | 
			
		||||
#include <vector>
 | 
			
		||||
 | 
			
		||||
#include "tensorflow/core/platform/logging.h"
 | 
			
		||||
@ -28,15 +29,16 @@ namespace metrics {
 | 
			
		||||
// columns. This supports a very limited set of CSV spec and doesn't do any
 | 
			
		||||
// escaping.
 | 
			
		||||
// Usage:
 | 
			
		||||
// std::ofstream * output_stream = ...
 | 
			
		||||
// CSVWriter writer({"column1", "column2"}, output_stream);
 | 
			
		||||
// std::unqiue_str<std::ofstream> output_stream = ...
 | 
			
		||||
// CSVWriter writer({"column1", "column2"}, std::move(output_stream));
 | 
			
		||||
// writer.WriteRow({4, 5});
 | 
			
		||||
// writer.Flush(); // flush results immediately.
 | 
			
		||||
class CSVWriter {
 | 
			
		||||
 public:
 | 
			
		||||
  CSVWriter(const std::vector<string>& columns, std::ofstream* output_stream)
 | 
			
		||||
      : num_columns_(columns.size()), output_stream_(output_stream) {
 | 
			
		||||
    if (WriteRow(columns, output_stream_) != kTfLiteOk) {
 | 
			
		||||
  CSVWriter(const std::vector<string>& columns,
 | 
			
		||||
            std::unique_ptr<std::ofstream> output_stream)
 | 
			
		||||
      : num_columns_(columns.size()), output_stream_(std::move(output_stream)) {
 | 
			
		||||
    if (WriteRow(columns, output_stream_.get()) != kTfLiteOk) {
 | 
			
		||||
      LOG(ERROR) << "Could not write column names to file";
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
@ -48,7 +50,7 @@ class CSVWriter {
 | 
			
		||||
                 << " expected: " << num_columns_;
 | 
			
		||||
      return kTfLiteError;
 | 
			
		||||
    }
 | 
			
		||||
    return WriteRow(values, output_stream_);
 | 
			
		||||
    return WriteRow(values, output_stream_.get());
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  void Flush() { output_stream_->flush(); }
 | 
			
		||||
@ -76,7 +78,7 @@ class CSVWriter {
 | 
			
		||||
    return kTfLiteOk;
 | 
			
		||||
  }
 | 
			
		||||
  const size_t num_columns_;
 | 
			
		||||
  std::ofstream* output_stream_;
 | 
			
		||||
  std::unique_ptr<std::ofstream> output_stream_;
 | 
			
		||||
};
 | 
			
		||||
}  // namespace metrics
 | 
			
		||||
}  // namespace tensorflow
 | 
			
		||||
 | 
			
		||||
@ -17,6 +17,7 @@ limitations under the License.
 | 
			
		||||
 | 
			
		||||
#include <cstdlib>
 | 
			
		||||
#include <iomanip>
 | 
			
		||||
#include <memory>
 | 
			
		||||
 | 
			
		||||
#include "absl/memory/memory.h"
 | 
			
		||||
#include "tensorflow/core/platform/logging.h"
 | 
			
		||||
@ -38,13 +39,14 @@ ResultsWriter::ResultsWriter(int top_k, const std::string& output_file_path)
 | 
			
		||||
    return;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  output_stream_.reset(new std::ofstream(output_file_path, std::ios::out));
 | 
			
		||||
  if (!output_stream_) {
 | 
			
		||||
  std::unique_ptr<std::ofstream> output_stream(
 | 
			
		||||
      new std::ofstream(output_file_path, std::ios::out));
 | 
			
		||||
  if (!output_stream) {
 | 
			
		||||
    LOG(ERROR) << "Unable to open output file path: '" << output_file_path
 | 
			
		||||
               << "'";
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  (*output_stream_) << std::setprecision(3) << std::fixed;
 | 
			
		||||
  (*output_stream) << std::setprecision(3) << std::fixed;
 | 
			
		||||
  std::vector<string> columns;
 | 
			
		||||
  columns.reserve(top_k);
 | 
			
		||||
  for (int i = 0; i < top_k; i++) {
 | 
			
		||||
@ -53,7 +55,7 @@ ResultsWriter::ResultsWriter(int top_k, const std::string& output_file_path)
 | 
			
		||||
    columns.push_back(column_name);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  writer_.reset(new CSVWriter(columns, output_stream_.get()));
 | 
			
		||||
  writer_.reset(new CSVWriter(columns, std::move(output_stream)));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
void ResultsWriter::AggregateAccuraciesAndNumImages(
 | 
			
		||||
 | 
			
		||||
@ -56,9 +56,6 @@ class ResultsWriter : public ImagenetModelEvaluator::Observer {
 | 
			
		||||
      shard_id_accuracy_metrics_map_;
 | 
			
		||||
  std::unordered_map<uint64_t, int> shard_id_done_image_count_map_;
 | 
			
		||||
 | 
			
		||||
  // TODO(b/146988222): Refactor CSVWriter to take the memory ownership of
 | 
			
		||||
  // 'output_stream_'.
 | 
			
		||||
  std::unique_ptr<std::ofstream> output_stream_;
 | 
			
		||||
  std::unique_ptr<CSVWriter> writer_;
 | 
			
		||||
 | 
			
		||||
  // For logging to stdout.
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user