diff --git a/tensorflow/core/kernels/data/experimental/BUILD b/tensorflow/core/kernels/data/experimental/BUILD
index e5614be2727..bbad9278ac1 100644
--- a/tensorflow/core/kernels/data/experimental/BUILD
+++ b/tensorflow/core/kernels/data/experimental/BUILD
@@ -529,7 +529,6 @@ cc_library(
         "//tensorflow/core/platform:coding",
         "//tensorflow/core/platform:random",
         "//tensorflow/core/profiler/lib:traceme",
-        "@com_google_absl//absl/memory",
     ],
 )
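Note: the absl/memory dependency can be dropped from this cc_library because the snapshot_util.cc hunks below remove its only remaining use there (absl::WrapUnique). More broadly, the patch moves file ownership out of snapshot_util::Writer: callers now open the WritableFile themselves and hand the Writer a raw pointer. A minimal sketch of the new calling convention, assuming the TF headers below; OpenSnapshotWriter is a hypothetical helper for illustration, not part of this patch:

#include "absl/memory/memory.h"
#include "tensorflow/core/kernels/data/experimental/snapshot_util.h"
#include "tensorflow/core/platform/env.h"

namespace tensorflow {
namespace data {

// Hypothetical helper: open a file and wrap it in a snapshot writer. The
// caller owns *file and must keep it alive for the writer's lifetime.
Status OpenSnapshotWriter(const string& filename, const string& compression,
                          int version, const DataTypeVector& dtypes,
                          std::unique_ptr<WritableFile>* file,
                          std::unique_ptr<snapshot_util::Writer>* writer) {
  TF_RETURN_IF_ERROR(Env::Default()->NewAppendableFile(filename, file));
  *writer = absl::make_unique<snapshot_util::Writer>(file->get(), compression,
                                                     version, dtypes);
  return Status::OK();
}

}  // namespace data
}  // namespace tensorflow

Keeping the file handle in the caller is what lets snapshot_dataset_op.cc (next file) sync and close the file explicitly after the writer is done.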
diff --git a/tensorflow/core/kernels/data/experimental/snapshot_dataset_op.cc b/tensorflow/core/kernels/data/experimental/snapshot_dataset_op.cc
index db9984e02f8..b752c3acdb7 100644
--- a/tensorflow/core/kernels/data/experimental/snapshot_dataset_op.cc
+++ b/tensorflow/core/kernels/data/experimental/snapshot_dataset_op.cc
@@ -965,8 +965,7 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
         }
         for (int i = 0; i < dataset()->num_writer_threads_; ++i) {
           ++num_active_threads_;
-          thread_pool_->Schedule(
-              [this, env = ctx->env()]() { WriterThread(env); });
+          thread_pool_->Schedule([this]() { WriterThread(); });
         }
         first_call_ = false;
       }
@@ -1263,8 +1262,9 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {

     Status ProcessOneElement(int64* bytes_written,
                              string* snapshot_data_filename,
+                             std::unique_ptr<WritableFile>* file,
                              std::unique_ptr<snapshot_util::Writer>* writer,
-                             bool* end_of_processing, Env* env) {
+                             bool* end_of_processing) {
       profiler::TraceMe activity(
           [&]() {
             return absl::StrCat(prefix(), kSeparator, kProcessOneElement);
@@ -1296,6 +1296,8 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {

       if (cancelled || snapshot_failed) {
         TF_RETURN_IF_ERROR((*writer)->Close());
+        TF_RETURN_IF_ERROR((*file)->Sync());
+        TF_RETURN_IF_ERROR((*file)->Close());
         if (snapshot_failed) {
           return errors::Internal(
               "SnapshotDataset::SnapshotWriterIterator snapshot failed");
@@ -1310,17 +1312,20 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
       }

       bool should_close;
-      TF_RETURN_IF_ERROR(
-          ShouldCloseWriter(*snapshot_data_filename, *bytes_written,
-                            (*writer).get(), &should_close));
+      TF_RETURN_IF_ERROR(ShouldCloseFile(*snapshot_data_filename,
+                                         *bytes_written, (*writer).get(),
+                                         (*file).get(), &should_close));
       if (should_close) {
         // If we exceed the shard size, we get a new file and reset.
         TF_RETURN_IF_ERROR((*writer)->Close());
+        TF_RETURN_IF_ERROR((*file)->Sync());
+        TF_RETURN_IF_ERROR((*file)->Close());
         *snapshot_data_filename = GetSnapshotFilename();
-
-        TF_RETURN_IF_ERROR(snapshot_util::Writer::Create(
-            env, *snapshot_data_filename, dataset()->compression_,
-            kCurrentVersion, dataset()->output_dtypes(), writer));
+        TF_RETURN_IF_ERROR(Env::Default()->NewAppendableFile(
+            *snapshot_data_filename, file));
+        *writer = absl::make_unique<snapshot_util::Writer>(
+            file->get(), dataset()->compression_, kCurrentVersion,
+            dataset()->output_dtypes());
         *bytes_written = 0;
       }
       TF_RETURN_IF_ERROR((*writer)->WriteTensors(elem.value));
@@ -1329,6 +1334,8 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {

       if (*end_of_processing) {
         TF_RETURN_IF_ERROR((*writer)->Close());
+        TF_RETURN_IF_ERROR((*file)->Sync());
+        TF_RETURN_IF_ERROR((*file)->Close());
         mutex_lock l(mu_);
         if (!written_final_metadata_file_) {
           experimental::SnapshotMetadataRecord metadata;
@@ -1351,7 +1358,7 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
     }

     // Just pulls off elements from the buffer and writes them.
-    void WriterThread(Env* env) {
+    void WriterThread() {
      auto cleanup = gtl::MakeCleanup([this]() {
        mutex_lock l(mu_);
        --num_active_threads_;
@@ -1360,10 +1367,9 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {

      int64 bytes_written = 0;
      string snapshot_data_filename = GetSnapshotFilename();
-      std::unique_ptr<snapshot_util::Writer> writer;
-      Status s = snapshot_util::Writer::Create(
-          env, snapshot_data_filename, dataset()->compression_,
-          kCurrentVersion, dataset()->output_dtypes(), &writer);
+      std::unique_ptr<WritableFile> file;
+      Status s =
+          Env::Default()->NewAppendableFile(snapshot_data_filename, &file);
      if (!s.ok()) {
        LOG(ERROR) << "Creating " << snapshot_data_filename
                   << " failed: " << s.ToString();
@@ -1372,12 +1378,16 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
        cond_var_.notify_all();
        return;
      }
+      std::unique_ptr<snapshot_util::Writer> writer(
+          new snapshot_util::Writer(file.get(), dataset()->compression_,
+                                    kCurrentVersion,
+                                    dataset()->output_dtypes()));

      bool end_of_processing = false;
      while (!end_of_processing) {
        Status s =
            ProcessOneElement(&bytes_written, &snapshot_data_filename,
-                              &writer, &end_of_processing, env);
+                              &file, &writer, &end_of_processing);
        if (!s.ok()) {
          LOG(INFO) << "Error while writing snapshot data to disk: "
                    << s.ToString();
@@ -1391,9 +1401,9 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
      }
    }

-    Status ShouldCloseWriter(const string& filename, uint64 bytes_written,
-                             snapshot_util::Writer* writer,
-                             bool* should_close) {
+    Status ShouldCloseFile(const string& filename, uint64 bytes_written,
+                           snapshot_util::Writer* writer,
+                           WritableFile* file, bool* should_close) {
      // If the compression ratio has been estimated, use it to decide
      // whether the file should be closed. We avoid estimating the
      // compression ratio repeatedly because it requires syncing the file,
@@ -1415,6 +1425,7 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
      // Use the actual file size to determine compression ratio.
      // Make sure that all bytes are written out.
      TF_RETURN_IF_ERROR(writer->Sync());
+      TF_RETURN_IF_ERROR(file->Sync());
      uint64 file_size;
      TF_RETURN_IF_ERROR(Env::Default()->GetFileSize(filename, &file_size));
      mutex_lock l(mu_);
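Because the Writer no longer owns or closes the underlying file, every (*writer)->Close() in the iterator is now paired with an explicit (*file)->Sync() and (*file)->Close(), and both pointers are threaded through ProcessOneElement. The required shutdown ordering, as a condensed sketch (FinishShard is an invented name, not code from this patch):

// Close the writer first so any compression/record buffers are flushed into
// the file, then force the bytes to durable storage and release the handle.
Status FinishShard(snapshot_util::Writer* writer, WritableFile* file) {
  TF_RETURN_IF_ERROR(writer->Close());
  TF_RETURN_IF_ERROR(file->Sync());
  return file->Close();
}

Note also that the writer thread now reaches the filesystem through Env::Default() rather than a captured IteratorContext env, which is why the env parameter disappears from Schedule, WriterThread, and ProcessOneElement.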
#include "tensorflow/core/kernels/data/experimental/snapshot_util.h" -#include "absl/memory/memory.h" #include "tensorflow/core/common_runtime/dma_helper.h" #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/tensor.pb.h" @@ -40,45 +39,29 @@ namespace snapshot_util { /* static */ constexpr const int64 Reader::kSnappyReaderInputBufferSizeBytes; /* static */ constexpr const int64 Reader::kSnappyReaderOutputBufferSizeBytes; -Writer::Writer(const std::string& filename, const std::string& compression_type, - int version, const DataTypeVector& dtypes) - : filename_(filename), - compression_type_(compression_type), - version_(version), - dtypes_(dtypes) {} - -Status Writer::Create(Env* env, const std::string& filename, - const std::string& compression_type, int version, - const DataTypeVector& dtypes, - std::unique_ptr* out_writer) { - *out_writer = - absl::WrapUnique(new Writer(filename, compression_type, version, dtypes)); - - return (*out_writer)->Initialize(env); -} - -Status Writer::Initialize(tensorflow::Env* env) { - TF_RETURN_IF_ERROR(env->NewWritableFile(filename_, &dest_)); +Writer::Writer(WritableFile* dest, const string& compression_type, int version, + const DataTypeVector& dtypes) + : dest_(dest), compression_type_(compression_type), version_(version) { #if defined(IS_SLIM_BUILD) - if (compression_type_ != io::compression::kNone) { + if (compression_type != io::compression::kNone) { LOG(ERROR) << "Compression is unsupported on mobile platforms. Turning " << "off compression."; } #else // IS_SLIM_BUILD - if (compression_type_ == io::compression::kGzip) { - zlib_underlying_dest_.swap(dest_); + if (compression_type == io::compression::kGzip) { io::ZlibCompressionOptions zlib_options; zlib_options = io::ZlibCompressionOptions::GZIP(); - io::ZlibOutputBuffer* zlib_output_buffer = new io::ZlibOutputBuffer( - zlib_underlying_dest_.get(), zlib_options.input_buffer_size, - zlib_options.output_buffer_size, zlib_options); + io::ZlibOutputBuffer* zlib_output_buffer = + new io::ZlibOutputBuffer(dest, zlib_options.input_buffer_size, + zlib_options.output_buffer_size, zlib_options); TF_CHECK_OK(zlib_output_buffer->Init()); - dest_.reset(zlib_output_buffer); + dest_ = zlib_output_buffer; + dest_is_owned_ = true; } #endif // IS_SLIM_BUILD - simple_tensor_mask_.reserve(dtypes_.size()); - for (const auto& dtype : dtypes_) { + simple_tensor_mask_.reserve(dtypes.size()); + for (const auto& dtype : dtypes) { if (DataTypeCanUseMemcpy(dtype)) { simple_tensor_mask_.push_back(true); num_simple_++; @@ -87,8 +70,6 @@ Status Writer::Initialize(tensorflow::Env* env) { num_complex_++; } } - - return Status::OK(); } Status Writer::WriteTensors(const std::vector& tensors) { @@ -175,21 +156,21 @@ Status Writer::WriteTensors(const std::vector& tensors) { Status Writer::Sync() { return dest_->Sync(); } Status Writer::Close() { - if (dest_ != nullptr) { - TF_RETURN_IF_ERROR(dest_->Close()); + if (dest_is_owned_) { + Status s = dest_->Close(); + delete dest_; dest_ = nullptr; - } - if (zlib_underlying_dest_ != nullptr) { - TF_RETURN_IF_ERROR(zlib_underlying_dest_->Close()); - zlib_underlying_dest_ = nullptr; + return s; } return Status::OK(); } Writer::~Writer() { - Status s = Close(); - if (!s.ok()) { - LOG(ERROR) << "Could not finish writing file: " << s; + if (dest_ != nullptr) { + Status s = Close(); + if (!s.ok()) { + LOG(ERROR) << "Could not finish writing file: " << s; + } } } diff --git a/tensorflow/core/kernels/data/experimental/snapshot_util.h 
diff --git a/tensorflow/core/kernels/data/experimental/snapshot_util.h b/tensorflow/core/kernels/data/experimental/snapshot_util.h
index e1c6dbeb67b..e962bb56380 100644
--- a/tensorflow/core/kernels/data/experimental/snapshot_util.h
+++ b/tensorflow/core/kernels/data/experimental/snapshot_util.h
@@ -20,7 +20,6 @@ limitations under the License.
 #include "tensorflow/core/framework/types.h"
 #include "tensorflow/core/lib/io/compression.h"
 #include "tensorflow/core/lib/io/inputstream_interface.h"
-#include "tensorflow/core/platform/env.h"
 #include "tensorflow/core/platform/file_system.h"
 #include "tensorflow/core/platform/status.h"

@@ -57,10 +56,8 @@ class Writer {
   static constexpr const char* const kWriteCord = "WriteCord";
   static constexpr const char* const kSeparator = "::";

-  static Status Create(Env* env, const std::string& filename,
-                       const std::string& compression_type, int version,
-                       const DataTypeVector& dtypes,
-                       std::unique_ptr<Writer>* out_writer);
+  explicit Writer(WritableFile* dest, const string& compression_type,
+                  int version, const DataTypeVector& dtypes);

   Status WriteTensors(const std::vector<Tensor>& tensors);
@@ -71,27 +68,16 @@ class Writer {
   ~Writer();

  private:
-  explicit Writer(const std::string& filename,
-                  const std::string& compression_type, int version,
-                  const DataTypeVector& dtypes);
-
-  Status Initialize(tensorflow::Env* env);
-
   Status WriteRecord(const StringPiece& data);

 #if defined(PLATFORM_GOOGLE)
   Status WriteRecord(const absl::Cord& data);
 #endif  // PLATFORM_GOOGLE

-  std::unique_ptr<WritableFile> dest_;
-  const std::string filename_;
-  const std::string compression_type_;
+  WritableFile* dest_;
+  bool dest_is_owned_ = false;
+  const string compression_type_;
   const int version_;
-  const DataTypeVector dtypes_;
-  // We hold zlib_dest_ because we may create a ZlibOutputBuffer and put that
-  // in dest_ if we want compression. ZlibOutputBuffer doesn't own the original
-  // dest_ and so we need somewhere to store the original one.
-  std::unique_ptr<WritableFile> zlib_underlying_dest_;
   std::vector<bool> simple_tensor_mask_;  // true for simple, false for complex.
   int num_simple_ = 0;
   int num_complex_ = 0;
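In the header, the pair of unique_ptrs (dest_ plus zlib_underlying_dest_) collapses into a raw WritableFile* guarded by a dest_is_owned_ flag, since the Writer only conditionally owns its sink; filename_ and dtypes_ disappear because the constructor no longer opens files and only needs the dtypes once, to build simple_tensor_mask_. The same conditional-ownership idiom in miniature, as a generic sketch unrelated to the TF classes:

#include <cstdio>

// A sink that either borrows an already-open FILE* or owns one it created.
class Sink {
 public:
  explicit Sink(FILE* borrowed) : f_(borrowed) {}      // borrow: no cleanup
  explicit Sink(const char* path)
      : f_(std::fopen(path, "wb")), owned_(true) {}    // own: close in dtor
  Sink(const Sink&) = delete;
  Sink& operator=(const Sink&) = delete;
  ~Sink() {
    if (owned_ && f_ != nullptr) std::fclose(f_);      // close only if owned
  }

 private:
  FILE* f_;
  bool owned_ = false;
};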