Update snapshot_util::Writer to own WritableFiles instead borrowing

PiperOrigin-RevId: 306099307
Change-Id: I29ad3f26080ab6419c93b9436f74ca0e53520a6d
This commit is contained in:
A. Unique TensorFlower 2020-04-12 01:21:20 -07:00 committed by TensorFlower Gardener
parent 3364a141f2
commit a192cf7dad
4 changed files with 56 additions and 79 deletions
tensorflow/core/kernels/data/experimental

View File

@ -529,7 +529,6 @@ cc_library(
"//tensorflow/core/platform:coding", "//tensorflow/core/platform:coding",
"//tensorflow/core/platform:random", "//tensorflow/core/platform:random",
"//tensorflow/core/profiler/lib:traceme", "//tensorflow/core/profiler/lib:traceme",
"@com_google_absl//absl/memory",
], ],
) )

View File

@ -965,8 +965,7 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
} }
for (int i = 0; i < dataset()->num_writer_threads_; ++i) { for (int i = 0; i < dataset()->num_writer_threads_; ++i) {
++num_active_threads_; ++num_active_threads_;
thread_pool_->Schedule( thread_pool_->Schedule([this]() { WriterThread(); });
[this, env = ctx->env()]() { WriterThread(env); });
} }
first_call_ = false; first_call_ = false;
} }
@ -1263,8 +1262,9 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
Status ProcessOneElement(int64* bytes_written, Status ProcessOneElement(int64* bytes_written,
string* snapshot_data_filename, string* snapshot_data_filename,
std::unique_ptr<WritableFile>* file,
std::unique_ptr<snapshot_util::Writer>* writer, std::unique_ptr<snapshot_util::Writer>* writer,
bool* end_of_processing, Env* env) { bool* end_of_processing) {
profiler::TraceMe activity( profiler::TraceMe activity(
[&]() { [&]() {
return absl::StrCat(prefix(), kSeparator, kProcessOneElement); return absl::StrCat(prefix(), kSeparator, kProcessOneElement);
@ -1296,6 +1296,8 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
if (cancelled || snapshot_failed) { if (cancelled || snapshot_failed) {
TF_RETURN_IF_ERROR((*writer)->Close()); TF_RETURN_IF_ERROR((*writer)->Close());
TF_RETURN_IF_ERROR((*file)->Sync());
TF_RETURN_IF_ERROR((*file)->Close());
if (snapshot_failed) { if (snapshot_failed) {
return errors::Internal( return errors::Internal(
"SnapshotDataset::SnapshotWriterIterator snapshot failed"); "SnapshotDataset::SnapshotWriterIterator snapshot failed");
@ -1310,17 +1312,20 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
} }
bool should_close; bool should_close;
TF_RETURN_IF_ERROR( TF_RETURN_IF_ERROR(ShouldCloseFile(*snapshot_data_filename,
ShouldCloseWriter(*snapshot_data_filename, *bytes_written, *bytes_written, (*writer).get(),
(*writer).get(), &should_close)); (*file).get(), &should_close));
if (should_close) { if (should_close) {
// If we exceed the shard size, we get a new file and reset. // If we exceed the shard size, we get a new file and reset.
TF_RETURN_IF_ERROR((*writer)->Close()); TF_RETURN_IF_ERROR((*writer)->Close());
TF_RETURN_IF_ERROR((*file)->Sync());
TF_RETURN_IF_ERROR((*file)->Close());
*snapshot_data_filename = GetSnapshotFilename(); *snapshot_data_filename = GetSnapshotFilename();
TF_RETURN_IF_ERROR(Env::Default()->NewAppendableFile(
TF_RETURN_IF_ERROR(snapshot_util::Writer::Create( *snapshot_data_filename, file));
env, *snapshot_data_filename, dataset()->compression_, *writer = absl::make_unique<snapshot_util::Writer>(
kCurrentVersion, dataset()->output_dtypes(), writer)); file->get(), dataset()->compression_, kCurrentVersion,
dataset()->output_dtypes());
*bytes_written = 0; *bytes_written = 0;
} }
TF_RETURN_IF_ERROR((*writer)->WriteTensors(elem.value)); TF_RETURN_IF_ERROR((*writer)->WriteTensors(elem.value));
@ -1329,6 +1334,8 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
if (*end_of_processing) { if (*end_of_processing) {
TF_RETURN_IF_ERROR((*writer)->Close()); TF_RETURN_IF_ERROR((*writer)->Close());
TF_RETURN_IF_ERROR((*file)->Sync());
TF_RETURN_IF_ERROR((*file)->Close());
mutex_lock l(mu_); mutex_lock l(mu_);
if (!written_final_metadata_file_) { if (!written_final_metadata_file_) {
experimental::SnapshotMetadataRecord metadata; experimental::SnapshotMetadataRecord metadata;
@ -1351,7 +1358,7 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
} }
// Just pulls off elements from the buffer and writes them. // Just pulls off elements from the buffer and writes them.
void WriterThread(Env* env) { void WriterThread() {
auto cleanup = gtl::MakeCleanup([this]() { auto cleanup = gtl::MakeCleanup([this]() {
mutex_lock l(mu_); mutex_lock l(mu_);
--num_active_threads_; --num_active_threads_;
@ -1360,10 +1367,9 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
int64 bytes_written = 0; int64 bytes_written = 0;
string snapshot_data_filename = GetSnapshotFilename(); string snapshot_data_filename = GetSnapshotFilename();
std::unique_ptr<snapshot_util::Writer> writer; std::unique_ptr<WritableFile> file;
Status s = snapshot_util::Writer::Create( Status s =
env, snapshot_data_filename, dataset()->compression_, Env::Default()->NewAppendableFile(snapshot_data_filename, &file);
kCurrentVersion, dataset()->output_dtypes(), &writer);
if (!s.ok()) { if (!s.ok()) {
LOG(ERROR) << "Creating " << snapshot_data_filename LOG(ERROR) << "Creating " << snapshot_data_filename
<< " failed: " << s.ToString(); << " failed: " << s.ToString();
@ -1372,12 +1378,16 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
cond_var_.notify_all(); cond_var_.notify_all();
return; return;
} }
std::unique_ptr<snapshot_util::Writer> writer(
new snapshot_util::Writer(file.get(), dataset()->compression_,
kCurrentVersion,
dataset()->output_dtypes()));
bool end_of_processing = false; bool end_of_processing = false;
while (!end_of_processing) { while (!end_of_processing) {
Status s = Status s =
ProcessOneElement(&bytes_written, &snapshot_data_filename, ProcessOneElement(&bytes_written, &snapshot_data_filename,
&writer, &end_of_processing, env); &file, &writer, &end_of_processing);
if (!s.ok()) { if (!s.ok()) {
LOG(INFO) << "Error while writing snapshot data to disk: " LOG(INFO) << "Error while writing snapshot data to disk: "
<< s.ToString(); << s.ToString();
@ -1391,9 +1401,9 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
} }
} }
Status ShouldCloseWriter(const string& filename, uint64 bytes_written, Status ShouldCloseFile(const string& filename, uint64 bytes_written,
snapshot_util::Writer* writer, snapshot_util::Writer* writer,
bool* should_close) { WritableFile* file, bool* should_close) {
// If the compression ratio has been estimated, use it to decide // If the compression ratio has been estimated, use it to decide
// whether the file should be closed. We avoid estimating the // whether the file should be closed. We avoid estimating the
// compression ratio repeatedly because it requires syncing the file, // compression ratio repeatedly because it requires syncing the file,
@ -1415,6 +1425,7 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
// Use the actual file size to determine compression ratio. // Use the actual file size to determine compression ratio.
// Make sure that all bytes are written out. // Make sure that all bytes are written out.
TF_RETURN_IF_ERROR(writer->Sync()); TF_RETURN_IF_ERROR(writer->Sync());
TF_RETURN_IF_ERROR(file->Sync());
uint64 file_size; uint64 file_size;
TF_RETURN_IF_ERROR(Env::Default()->GetFileSize(filename, &file_size)); TF_RETURN_IF_ERROR(Env::Default()->GetFileSize(filename, &file_size));
mutex_lock l(mu_); mutex_lock l(mu_);

View File

@ -15,7 +15,6 @@ limitations under the License.
#include "tensorflow/core/kernels/data/experimental/snapshot_util.h" #include "tensorflow/core/kernels/data/experimental/snapshot_util.h"
#include "absl/memory/memory.h"
#include "tensorflow/core/common_runtime/dma_helper.h" #include "tensorflow/core/common_runtime/dma_helper.h"
#include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/tensor.pb.h" #include "tensorflow/core/framework/tensor.pb.h"
@ -40,45 +39,29 @@ namespace snapshot_util {
/* static */ constexpr const int64 Reader::kSnappyReaderInputBufferSizeBytes; /* static */ constexpr const int64 Reader::kSnappyReaderInputBufferSizeBytes;
/* static */ constexpr const int64 Reader::kSnappyReaderOutputBufferSizeBytes; /* static */ constexpr const int64 Reader::kSnappyReaderOutputBufferSizeBytes;
Writer::Writer(const std::string& filename, const std::string& compression_type, Writer::Writer(WritableFile* dest, const string& compression_type, int version,
int version, const DataTypeVector& dtypes) const DataTypeVector& dtypes)
: filename_(filename), : dest_(dest), compression_type_(compression_type), version_(version) {
compression_type_(compression_type),
version_(version),
dtypes_(dtypes) {}
Status Writer::Create(Env* env, const std::string& filename,
const std::string& compression_type, int version,
const DataTypeVector& dtypes,
std::unique_ptr<Writer>* out_writer) {
*out_writer =
absl::WrapUnique(new Writer(filename, compression_type, version, dtypes));
return (*out_writer)->Initialize(env);
}
Status Writer::Initialize(tensorflow::Env* env) {
TF_RETURN_IF_ERROR(env->NewWritableFile(filename_, &dest_));
#if defined(IS_SLIM_BUILD) #if defined(IS_SLIM_BUILD)
if (compression_type_ != io::compression::kNone) { if (compression_type != io::compression::kNone) {
LOG(ERROR) << "Compression is unsupported on mobile platforms. Turning " LOG(ERROR) << "Compression is unsupported on mobile platforms. Turning "
<< "off compression."; << "off compression.";
} }
#else // IS_SLIM_BUILD #else // IS_SLIM_BUILD
if (compression_type_ == io::compression::kGzip) { if (compression_type == io::compression::kGzip) {
zlib_underlying_dest_.swap(dest_);
io::ZlibCompressionOptions zlib_options; io::ZlibCompressionOptions zlib_options;
zlib_options = io::ZlibCompressionOptions::GZIP(); zlib_options = io::ZlibCompressionOptions::GZIP();
io::ZlibOutputBuffer* zlib_output_buffer = new io::ZlibOutputBuffer( io::ZlibOutputBuffer* zlib_output_buffer =
zlib_underlying_dest_.get(), zlib_options.input_buffer_size, new io::ZlibOutputBuffer(dest, zlib_options.input_buffer_size,
zlib_options.output_buffer_size, zlib_options); zlib_options.output_buffer_size, zlib_options);
TF_CHECK_OK(zlib_output_buffer->Init()); TF_CHECK_OK(zlib_output_buffer->Init());
dest_.reset(zlib_output_buffer); dest_ = zlib_output_buffer;
dest_is_owned_ = true;
} }
#endif // IS_SLIM_BUILD #endif // IS_SLIM_BUILD
simple_tensor_mask_.reserve(dtypes_.size()); simple_tensor_mask_.reserve(dtypes.size());
for (const auto& dtype : dtypes_) { for (const auto& dtype : dtypes) {
if (DataTypeCanUseMemcpy(dtype)) { if (DataTypeCanUseMemcpy(dtype)) {
simple_tensor_mask_.push_back(true); simple_tensor_mask_.push_back(true);
num_simple_++; num_simple_++;
@ -87,8 +70,6 @@ Status Writer::Initialize(tensorflow::Env* env) {
num_complex_++; num_complex_++;
} }
} }
return Status::OK();
} }
Status Writer::WriteTensors(const std::vector<Tensor>& tensors) { Status Writer::WriteTensors(const std::vector<Tensor>& tensors) {
@ -175,21 +156,21 @@ Status Writer::WriteTensors(const std::vector<Tensor>& tensors) {
Status Writer::Sync() { return dest_->Sync(); } Status Writer::Sync() { return dest_->Sync(); }
Status Writer::Close() { Status Writer::Close() {
if (dest_ != nullptr) { if (dest_is_owned_) {
TF_RETURN_IF_ERROR(dest_->Close()); Status s = dest_->Close();
delete dest_;
dest_ = nullptr; dest_ = nullptr;
} return s;
if (zlib_underlying_dest_ != nullptr) {
TF_RETURN_IF_ERROR(zlib_underlying_dest_->Close());
zlib_underlying_dest_ = nullptr;
} }
return Status::OK(); return Status::OK();
} }
Writer::~Writer() { Writer::~Writer() {
Status s = Close(); if (dest_ != nullptr) {
if (!s.ok()) { Status s = Close();
LOG(ERROR) << "Could not finish writing file: " << s; if (!s.ok()) {
LOG(ERROR) << "Could not finish writing file: " << s;
}
} }
} }

View File

@ -20,7 +20,6 @@ limitations under the License.
#include "tensorflow/core/framework/types.h" #include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/io/compression.h" #include "tensorflow/core/lib/io/compression.h"
#include "tensorflow/core/lib/io/inputstream_interface.h" #include "tensorflow/core/lib/io/inputstream_interface.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/file_system.h" #include "tensorflow/core/platform/file_system.h"
#include "tensorflow/core/platform/status.h" #include "tensorflow/core/platform/status.h"
@ -57,10 +56,8 @@ class Writer {
static constexpr const char* const kWriteCord = "WriteCord"; static constexpr const char* const kWriteCord = "WriteCord";
static constexpr const char* const kSeparator = "::"; static constexpr const char* const kSeparator = "::";
static Status Create(Env* env, const std::string& filename, explicit Writer(WritableFile* dest, const string& compression_type,
const std::string& compression_type, int version, int version, const DataTypeVector& dtypes);
const DataTypeVector& dtypes,
std::unique_ptr<Writer>* out_writer);
Status WriteTensors(const std::vector<Tensor>& tensors); Status WriteTensors(const std::vector<Tensor>& tensors);
@ -71,27 +68,16 @@ class Writer {
~Writer(); ~Writer();
private: private:
explicit Writer(const std::string& filename,
const std::string& compression_type, int version,
const DataTypeVector& dtypes);
Status Initialize(tensorflow::Env* env);
Status WriteRecord(const StringPiece& data); Status WriteRecord(const StringPiece& data);
#if defined(PLATFORM_GOOGLE) #if defined(PLATFORM_GOOGLE)
Status WriteRecord(const absl::Cord& data); Status WriteRecord(const absl::Cord& data);
#endif // PLATFORM_GOOGLE #endif // PLATFORM_GOOGLE
std::unique_ptr<WritableFile> dest_; WritableFile* dest_;
const std::string filename_; bool dest_is_owned_ = false;
const std::string compression_type_; const string compression_type_;
const int version_; const int version_;
const DataTypeVector dtypes_;
// We hold zlib_dest_ because we may create a ZlibOutputBuffer and put that
// in dest_ if we want compression. ZlibOutputBuffer doesn't own the original
// dest_ and so we need somewhere to store the original one.
std::unique_ptr<WritableFile> zlib_underlying_dest_;
std::vector<bool> simple_tensor_mask_; // true for simple, false for complex. std::vector<bool> simple_tensor_mask_; // true for simple, false for complex.
int num_simple_ = 0; int num_simple_ = 0;
int num_complex_ = 0; int num_complex_ = 0;