From a192cf7dadc0ce6b2a6f8d09d3b9230772602ea0 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" <gardener@tensorflow.org>
Date: Sun, 12 Apr 2020 01:21:20 -0700
Subject: [PATCH] Update snapshot_util::Writer to own WritableFiles instead
 of borrowing

PiperOrigin-RevId: 306099307
Change-Id: I29ad3f26080ab6419c93b9436f74ca0e53520a6d
---
 .../core/kernels/data/experimental/BUILD      |  1 -
 .../data/experimental/snapshot_dataset_op.cc  | 49 +++++++++------
 .../data/experimental/snapshot_util.cc        | 61 +++++++------------
 .../kernels/data/experimental/snapshot_util.h | 24 ++------
 4 files changed, 56 insertions(+), 79 deletions(-)

diff --git a/tensorflow/core/kernels/data/experimental/BUILD b/tensorflow/core/kernels/data/experimental/BUILD
index e5614be2727..bbad9278ac1 100644
--- a/tensorflow/core/kernels/data/experimental/BUILD
+++ b/tensorflow/core/kernels/data/experimental/BUILD
@@ -529,7 +529,6 @@ cc_library(
         "//tensorflow/core/platform:coding",
         "//tensorflow/core/platform:random",
         "//tensorflow/core/profiler/lib:traceme",
-        "@com_google_absl//absl/memory",
     ],
 )
 
diff --git a/tensorflow/core/kernels/data/experimental/snapshot_dataset_op.cc b/tensorflow/core/kernels/data/experimental/snapshot_dataset_op.cc
index db9984e02f8..b752c3acdb7 100644
--- a/tensorflow/core/kernels/data/experimental/snapshot_dataset_op.cc
+++ b/tensorflow/core/kernels/data/experimental/snapshot_dataset_op.cc
@@ -965,8 +965,7 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
               }
               for (int i = 0; i < dataset()->num_writer_threads_; ++i) {
                 ++num_active_threads_;
-                thread_pool_->Schedule(
-                    [this, env = ctx->env()]() { WriterThread(env); });
+                thread_pool_->Schedule([this]() { WriterThread(); });
               }
               first_call_ = false;
             }
@@ -1263,8 +1262,9 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
 
         Status ProcessOneElement(int64* bytes_written,
                                  string* snapshot_data_filename,
+                                 std::unique_ptr<WritableFile>* file,
                                  std::unique_ptr<snapshot_util::Writer>* writer,
-                                 bool* end_of_processing, Env* env) {
+                                 bool* end_of_processing) {
           profiler::TraceMe activity(
               [&]() {
                 return absl::StrCat(prefix(), kSeparator, kProcessOneElement);
@@ -1296,6 +1296,8 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
 
           if (cancelled || snapshot_failed) {
             TF_RETURN_IF_ERROR((*writer)->Close());
+            TF_RETURN_IF_ERROR((*file)->Sync());
+            TF_RETURN_IF_ERROR((*file)->Close());
             if (snapshot_failed) {
               return errors::Internal(
                   "SnapshotDataset::SnapshotWriterIterator snapshot failed");
@@ -1310,17 +1312,20 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
             }
 
             bool should_close;
-            TF_RETURN_IF_ERROR(
-                ShouldCloseWriter(*snapshot_data_filename, *bytes_written,
-                                  (*writer).get(), &should_close));
+            TF_RETURN_IF_ERROR(ShouldCloseFile(*snapshot_data_filename,
+                                               *bytes_written, (*writer).get(),
+                                               (*file).get(), &should_close));
             if (should_close) {
               // If we exceed the shard size, we get a new file and reset.
               TF_RETURN_IF_ERROR((*writer)->Close());
+              TF_RETURN_IF_ERROR((*file)->Sync());
+              TF_RETURN_IF_ERROR((*file)->Close());
               *snapshot_data_filename = GetSnapshotFilename();
-
-              TF_RETURN_IF_ERROR(snapshot_util::Writer::Create(
-                  env, *snapshot_data_filename, dataset()->compression_,
-                  kCurrentVersion, dataset()->output_dtypes(), writer));
+              TF_RETURN_IF_ERROR(Env::Default()->NewAppendableFile(
+                  *snapshot_data_filename, file));
+              *writer = absl::make_unique<snapshot_util::Writer>(
+                  file->get(), dataset()->compression_, kCurrentVersion,
+                  dataset()->output_dtypes());
               *bytes_written = 0;
             }
             TF_RETURN_IF_ERROR((*writer)->WriteTensors(elem.value));
@@ -1329,6 +1334,8 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
 
           if (*end_of_processing) {
             TF_RETURN_IF_ERROR((*writer)->Close());
+            TF_RETURN_IF_ERROR((*file)->Sync());
+            TF_RETURN_IF_ERROR((*file)->Close());
             mutex_lock l(mu_);
             if (!written_final_metadata_file_) {
               experimental::SnapshotMetadataRecord metadata;
@@ -1351,7 +1358,7 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
         }
 
         // Just pulls off elements from the buffer and writes them.
-        void WriterThread(Env* env) {
+        void WriterThread() {
           auto cleanup = gtl::MakeCleanup([this]() {
             mutex_lock l(mu_);
             --num_active_threads_;
@@ -1360,10 +1367,9 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
 
           int64 bytes_written = 0;
           string snapshot_data_filename = GetSnapshotFilename();
-          std::unique_ptr<snapshot_util::Writer> writer;
-          Status s = snapshot_util::Writer::Create(
-              env, snapshot_data_filename, dataset()->compression_,
-              kCurrentVersion, dataset()->output_dtypes(), &writer);
+          std::unique_ptr<WritableFile> file;
+          Status s =
+              Env::Default()->NewAppendableFile(snapshot_data_filename, &file);
           if (!s.ok()) {
             LOG(ERROR) << "Creating " << snapshot_data_filename
                        << " failed: " << s.ToString();
@@ -1372,12 +1378,16 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
             cond_var_.notify_all();
             return;
           }
+          std::unique_ptr<snapshot_util::Writer> writer(
+              new snapshot_util::Writer(file.get(), dataset()->compression_,
+                                        kCurrentVersion,
+                                        dataset()->output_dtypes()));
 
           bool end_of_processing = false;
           while (!end_of_processing) {
             Status s =
                 ProcessOneElement(&bytes_written, &snapshot_data_filename,
-                                  &writer, &end_of_processing, env);
+                                  &file, &writer, &end_of_processing);
             if (!s.ok()) {
               LOG(INFO) << "Error while writing snapshot data to disk: "
                         << s.ToString();
@@ -1391,9 +1401,9 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
           }
         }
 
-        Status ShouldCloseWriter(const string& filename, uint64 bytes_written,
-                                 snapshot_util::Writer* writer,
-                                 bool* should_close) {
+        Status ShouldCloseFile(const string& filename, uint64 bytes_written,
+                               snapshot_util::Writer* writer,
+                               WritableFile* file, bool* should_close) {
           // If the compression ratio has been estimated, use it to decide
           // whether the file should be closed. We avoid estimating the
           // compression ratio repeatedly because it requires syncing the file,
@@ -1415,6 +1425,7 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
           // Use the actual file size to determine compression ratio.
           // Make sure that all bytes are written out.
           TF_RETURN_IF_ERROR(writer->Sync());
+          TF_RETURN_IF_ERROR(file->Sync());
           uint64 file_size;
           TF_RETURN_IF_ERROR(Env::Default()->GetFileSize(filename, &file_size));
           mutex_lock l(mu_);
diff --git a/tensorflow/core/kernels/data/experimental/snapshot_util.cc b/tensorflow/core/kernels/data/experimental/snapshot_util.cc
index ba8336653f4..72d2c5cddd9 100644
--- a/tensorflow/core/kernels/data/experimental/snapshot_util.cc
+++ b/tensorflow/core/kernels/data/experimental/snapshot_util.cc
@@ -15,7 +15,6 @@ limitations under the License.
 
 #include "tensorflow/core/kernels/data/experimental/snapshot_util.h"
 
-#include "absl/memory/memory.h"
 #include "tensorflow/core/common_runtime/dma_helper.h"
 #include "tensorflow/core/framework/graph.pb.h"
 #include "tensorflow/core/framework/tensor.pb.h"
@@ -40,45 +39,29 @@ namespace snapshot_util {
 /* static */ constexpr const int64 Reader::kSnappyReaderInputBufferSizeBytes;
 /* static */ constexpr const int64 Reader::kSnappyReaderOutputBufferSizeBytes;
 
-Writer::Writer(const std::string& filename, const std::string& compression_type,
-               int version, const DataTypeVector& dtypes)
-    : filename_(filename),
-      compression_type_(compression_type),
-      version_(version),
-      dtypes_(dtypes) {}
-
-Status Writer::Create(Env* env, const std::string& filename,
-                      const std::string& compression_type, int version,
-                      const DataTypeVector& dtypes,
-                      std::unique_ptr<Writer>* out_writer) {
-  *out_writer =
-      absl::WrapUnique(new Writer(filename, compression_type, version, dtypes));
-
-  return (*out_writer)->Initialize(env);
-}
-
-Status Writer::Initialize(tensorflow::Env* env) {
-  TF_RETURN_IF_ERROR(env->NewWritableFile(filename_, &dest_));
+Writer::Writer(WritableFile* dest, const string& compression_type, int version,
+               const DataTypeVector& dtypes)
+    : dest_(dest), compression_type_(compression_type), version_(version) {
 #if defined(IS_SLIM_BUILD)
-  if (compression_type_ != io::compression::kNone) {
+  if (compression_type != io::compression::kNone) {
     LOG(ERROR) << "Compression is unsupported on mobile platforms. Turning "
                << "off compression.";
   }
 #else   // IS_SLIM_BUILD
-  if (compression_type_ == io::compression::kGzip) {
-    zlib_underlying_dest_.swap(dest_);
+  if (compression_type == io::compression::kGzip) {
     io::ZlibCompressionOptions zlib_options;
     zlib_options = io::ZlibCompressionOptions::GZIP();
 
-    io::ZlibOutputBuffer* zlib_output_buffer = new io::ZlibOutputBuffer(
-        zlib_underlying_dest_.get(), zlib_options.input_buffer_size,
-        zlib_options.output_buffer_size, zlib_options);
+    io::ZlibOutputBuffer* zlib_output_buffer =
+        new io::ZlibOutputBuffer(dest, zlib_options.input_buffer_size,
+                                 zlib_options.output_buffer_size, zlib_options);
     TF_CHECK_OK(zlib_output_buffer->Init());
-    dest_.reset(zlib_output_buffer);
+    dest_ = zlib_output_buffer;
+    dest_is_owned_ = true;
   }
 #endif  // IS_SLIM_BUILD
-  simple_tensor_mask_.reserve(dtypes_.size());
-  for (const auto& dtype : dtypes_) {
+  simple_tensor_mask_.reserve(dtypes.size());
+  for (const auto& dtype : dtypes) {
     if (DataTypeCanUseMemcpy(dtype)) {
       simple_tensor_mask_.push_back(true);
       num_simple_++;
@@ -87,8 +70,6 @@ Status Writer::Initialize(tensorflow::Env* env) {
       num_complex_++;
     }
   }
-
-  return Status::OK();
 }
 
 Status Writer::WriteTensors(const std::vector<Tensor>& tensors) {
@@ -175,21 +156,21 @@ Status Writer::WriteTensors(const std::vector<Tensor>& tensors) {
 Status Writer::Sync() { return dest_->Sync(); }
 
 Status Writer::Close() {
-  if (dest_ != nullptr) {
-    TF_RETURN_IF_ERROR(dest_->Close());
+  if (dest_is_owned_) {
+    Status s = dest_->Close();
+    delete dest_;
     dest_ = nullptr;
-  }
-  if (zlib_underlying_dest_ != nullptr) {
-    TF_RETURN_IF_ERROR(zlib_underlying_dest_->Close());
-    zlib_underlying_dest_ = nullptr;
+    return s;
   }
   return Status::OK();
 }
 
 Writer::~Writer() {
-  Status s = Close();
-  if (!s.ok()) {
-    LOG(ERROR) << "Could not finish writing file: " << s;
+  if (dest_ != nullptr) {
+    Status s = Close();
+    if (!s.ok()) {
+      LOG(ERROR) << "Could not finish writing file: " << s;
+    }
   }
 }
 
diff --git a/tensorflow/core/kernels/data/experimental/snapshot_util.h b/tensorflow/core/kernels/data/experimental/snapshot_util.h
index e1c6dbeb67b..e962bb56380 100644
--- a/tensorflow/core/kernels/data/experimental/snapshot_util.h
+++ b/tensorflow/core/kernels/data/experimental/snapshot_util.h
@@ -20,7 +20,6 @@ limitations under the License.
 #include "tensorflow/core/framework/types.h"
 #include "tensorflow/core/lib/io/compression.h"
 #include "tensorflow/core/lib/io/inputstream_interface.h"
-#include "tensorflow/core/platform/env.h"
 #include "tensorflow/core/platform/file_system.h"
 #include "tensorflow/core/platform/status.h"
 
@@ -57,10 +56,8 @@ class Writer {
   static constexpr const char* const kWriteCord = "WriteCord";
   static constexpr const char* const kSeparator = "::";
 
-  static Status Create(Env* env, const std::string& filename,
-                       const std::string& compression_type, int version,
-                       const DataTypeVector& dtypes,
-                       std::unique_ptr<Writer>* out_writer);
+  explicit Writer(WritableFile* dest, const string& compression_type,
+                  int version, const DataTypeVector& dtypes);
 
   Status WriteTensors(const std::vector<Tensor>& tensors);
 
@@ -71,27 +68,16 @@ class Writer {
   ~Writer();
 
  private:
-  explicit Writer(const std::string& filename,
-                  const std::string& compression_type, int version,
-                  const DataTypeVector& dtypes);
-
-  Status Initialize(tensorflow::Env* env);
-
   Status WriteRecord(const StringPiece& data);
 
 #if defined(PLATFORM_GOOGLE)
   Status WriteRecord(const absl::Cord& data);
 #endif  // PLATFORM_GOOGLE
 
-  std::unique_ptr<WritableFile> dest_;
-  const std::string filename_;
-  const std::string compression_type_;
+  WritableFile* dest_;
+  bool dest_is_owned_ = false;
+  const string compression_type_;
   const int version_;
-  const DataTypeVector dtypes_;
-  // We hold zlib_dest_ because we may create a ZlibOutputBuffer and put that
-  // in dest_ if we want compression. ZlibOutputBuffer doesn't own the original
-  // dest_ and so we need somewhere to store the original one.
-  std::unique_ptr<WritableFile> zlib_underlying_dest_;
   std::vector<bool> simple_tensor_mask_;  // true for simple, false for complex.
   int num_simple_ = 0;
   int num_complex_ = 0;