Update snapshot_util::Writer to own WritableFiles instead of borrowing

PiperOrigin-RevId: 306099307
Change-Id: I29ad3f26080ab6419c93b9436f74ca0e53520a6d
This commit is contained in:
A. Unique TensorFlower 2020-04-12 01:21:20 -07:00 committed by TensorFlower Gardener
parent 3364a141f2
commit a192cf7dad
4 changed files with 56 additions and 79 deletions

View File

@ -529,7 +529,6 @@ cc_library(
"//tensorflow/core/platform:coding",
"//tensorflow/core/platform:random",
"//tensorflow/core/profiler/lib:traceme",
"@com_google_absl//absl/memory",
],
)

View File

@ -965,8 +965,7 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
}
for (int i = 0; i < dataset()->num_writer_threads_; ++i) {
++num_active_threads_;
thread_pool_->Schedule(
[this, env = ctx->env()]() { WriterThread(env); });
thread_pool_->Schedule([this]() { WriterThread(); });
}
first_call_ = false;
}
@ -1263,8 +1262,9 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
Status ProcessOneElement(int64* bytes_written,
string* snapshot_data_filename,
std::unique_ptr<WritableFile>* file,
std::unique_ptr<snapshot_util::Writer>* writer,
bool* end_of_processing, Env* env) {
bool* end_of_processing) {
profiler::TraceMe activity(
[&]() {
return absl::StrCat(prefix(), kSeparator, kProcessOneElement);
@ -1296,6 +1296,8 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
if (cancelled || snapshot_failed) {
TF_RETURN_IF_ERROR((*writer)->Close());
TF_RETURN_IF_ERROR((*file)->Sync());
TF_RETURN_IF_ERROR((*file)->Close());
if (snapshot_failed) {
return errors::Internal(
"SnapshotDataset::SnapshotWriterIterator snapshot failed");
@ -1310,17 +1312,20 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
}
bool should_close;
TF_RETURN_IF_ERROR(
ShouldCloseWriter(*snapshot_data_filename, *bytes_written,
(*writer).get(), &should_close));
TF_RETURN_IF_ERROR(ShouldCloseFile(*snapshot_data_filename,
*bytes_written, (*writer).get(),
(*file).get(), &should_close));
if (should_close) {
// If we exceed the shard size, we get a new file and reset.
TF_RETURN_IF_ERROR((*writer)->Close());
TF_RETURN_IF_ERROR((*file)->Sync());
TF_RETURN_IF_ERROR((*file)->Close());
*snapshot_data_filename = GetSnapshotFilename();
TF_RETURN_IF_ERROR(snapshot_util::Writer::Create(
env, *snapshot_data_filename, dataset()->compression_,
kCurrentVersion, dataset()->output_dtypes(), writer));
TF_RETURN_IF_ERROR(Env::Default()->NewAppendableFile(
*snapshot_data_filename, file));
*writer = absl::make_unique<snapshot_util::Writer>(
file->get(), dataset()->compression_, kCurrentVersion,
dataset()->output_dtypes());
*bytes_written = 0;
}
TF_RETURN_IF_ERROR((*writer)->WriteTensors(elem.value));
@ -1329,6 +1334,8 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
if (*end_of_processing) {
TF_RETURN_IF_ERROR((*writer)->Close());
TF_RETURN_IF_ERROR((*file)->Sync());
TF_RETURN_IF_ERROR((*file)->Close());
mutex_lock l(mu_);
if (!written_final_metadata_file_) {
experimental::SnapshotMetadataRecord metadata;
@ -1351,7 +1358,7 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
}
// Just pulls off elements from the buffer and writes them.
void WriterThread(Env* env) {
void WriterThread() {
auto cleanup = gtl::MakeCleanup([this]() {
mutex_lock l(mu_);
--num_active_threads_;
@ -1360,10 +1367,9 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
int64 bytes_written = 0;
string snapshot_data_filename = GetSnapshotFilename();
std::unique_ptr<snapshot_util::Writer> writer;
Status s = snapshot_util::Writer::Create(
env, snapshot_data_filename, dataset()->compression_,
kCurrentVersion, dataset()->output_dtypes(), &writer);
std::unique_ptr<WritableFile> file;
Status s =
Env::Default()->NewAppendableFile(snapshot_data_filename, &file);
if (!s.ok()) {
LOG(ERROR) << "Creating " << snapshot_data_filename
<< " failed: " << s.ToString();
@ -1372,12 +1378,16 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
cond_var_.notify_all();
return;
}
std::unique_ptr<snapshot_util::Writer> writer(
new snapshot_util::Writer(file.get(), dataset()->compression_,
kCurrentVersion,
dataset()->output_dtypes()));
bool end_of_processing = false;
while (!end_of_processing) {
Status s =
ProcessOneElement(&bytes_written, &snapshot_data_filename,
&writer, &end_of_processing, env);
&file, &writer, &end_of_processing);
if (!s.ok()) {
LOG(INFO) << "Error while writing snapshot data to disk: "
<< s.ToString();
@ -1391,9 +1401,9 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
}
}
Status ShouldCloseWriter(const string& filename, uint64 bytes_written,
snapshot_util::Writer* writer,
bool* should_close) {
Status ShouldCloseFile(const string& filename, uint64 bytes_written,
snapshot_util::Writer* writer,
WritableFile* file, bool* should_close) {
// If the compression ratio has been estimated, use it to decide
// whether the file should be closed. We avoid estimating the
// compression ratio repeatedly because it requires syncing the file,
@ -1415,6 +1425,7 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
// Use the actual file size to determine compression ratio.
// Make sure that all bytes are written out.
TF_RETURN_IF_ERROR(writer->Sync());
TF_RETURN_IF_ERROR(file->Sync());
uint64 file_size;
TF_RETURN_IF_ERROR(Env::Default()->GetFileSize(filename, &file_size));
mutex_lock l(mu_);

View File

@ -15,7 +15,6 @@ limitations under the License.
#include "tensorflow/core/kernels/data/experimental/snapshot_util.h"
#include "absl/memory/memory.h"
#include "tensorflow/core/common_runtime/dma_helper.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/tensor.pb.h"
@ -40,45 +39,29 @@ namespace snapshot_util {
/* static */ constexpr const int64 Reader::kSnappyReaderInputBufferSizeBytes;
/* static */ constexpr const int64 Reader::kSnappyReaderOutputBufferSizeBytes;
Writer::Writer(const std::string& filename, const std::string& compression_type,
int version, const DataTypeVector& dtypes)
: filename_(filename),
compression_type_(compression_type),
version_(version),
dtypes_(dtypes) {}
Status Writer::Create(Env* env, const std::string& filename,
const std::string& compression_type, int version,
const DataTypeVector& dtypes,
std::unique_ptr<Writer>* out_writer) {
*out_writer =
absl::WrapUnique(new Writer(filename, compression_type, version, dtypes));
return (*out_writer)->Initialize(env);
}
Status Writer::Initialize(tensorflow::Env* env) {
TF_RETURN_IF_ERROR(env->NewWritableFile(filename_, &dest_));
Writer::Writer(WritableFile* dest, const string& compression_type, int version,
const DataTypeVector& dtypes)
: dest_(dest), compression_type_(compression_type), version_(version) {
#if defined(IS_SLIM_BUILD)
if (compression_type_ != io::compression::kNone) {
if (compression_type != io::compression::kNone) {
LOG(ERROR) << "Compression is unsupported on mobile platforms. Turning "
<< "off compression.";
}
#else // IS_SLIM_BUILD
if (compression_type_ == io::compression::kGzip) {
zlib_underlying_dest_.swap(dest_);
if (compression_type == io::compression::kGzip) {
io::ZlibCompressionOptions zlib_options;
zlib_options = io::ZlibCompressionOptions::GZIP();
io::ZlibOutputBuffer* zlib_output_buffer = new io::ZlibOutputBuffer(
zlib_underlying_dest_.get(), zlib_options.input_buffer_size,
zlib_options.output_buffer_size, zlib_options);
io::ZlibOutputBuffer* zlib_output_buffer =
new io::ZlibOutputBuffer(dest, zlib_options.input_buffer_size,
zlib_options.output_buffer_size, zlib_options);
TF_CHECK_OK(zlib_output_buffer->Init());
dest_.reset(zlib_output_buffer);
dest_ = zlib_output_buffer;
dest_is_owned_ = true;
}
#endif // IS_SLIM_BUILD
simple_tensor_mask_.reserve(dtypes_.size());
for (const auto& dtype : dtypes_) {
simple_tensor_mask_.reserve(dtypes.size());
for (const auto& dtype : dtypes) {
if (DataTypeCanUseMemcpy(dtype)) {
simple_tensor_mask_.push_back(true);
num_simple_++;
@ -87,8 +70,6 @@ Status Writer::Initialize(tensorflow::Env* env) {
num_complex_++;
}
}
return Status::OK();
}
Status Writer::WriteTensors(const std::vector<Tensor>& tensors) {
@ -175,21 +156,21 @@ Status Writer::WriteTensors(const std::vector<Tensor>& tensors) {
Status Writer::Sync() { return dest_->Sync(); }
Status Writer::Close() {
if (dest_ != nullptr) {
TF_RETURN_IF_ERROR(dest_->Close());
if (dest_is_owned_) {
Status s = dest_->Close();
delete dest_;
dest_ = nullptr;
}
if (zlib_underlying_dest_ != nullptr) {
TF_RETURN_IF_ERROR(zlib_underlying_dest_->Close());
zlib_underlying_dest_ = nullptr;
return s;
}
return Status::OK();
}
Writer::~Writer() {
Status s = Close();
if (!s.ok()) {
LOG(ERROR) << "Could not finish writing file: " << s;
if (dest_ != nullptr) {
Status s = Close();
if (!s.ok()) {
LOG(ERROR) << "Could not finish writing file: " << s;
}
}
}

View File

@ -20,7 +20,6 @@ limitations under the License.
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/io/compression.h"
#include "tensorflow/core/lib/io/inputstream_interface.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/file_system.h"
#include "tensorflow/core/platform/status.h"
@ -57,10 +56,8 @@ class Writer {
static constexpr const char* const kWriteCord = "WriteCord";
static constexpr const char* const kSeparator = "::";
static Status Create(Env* env, const std::string& filename,
const std::string& compression_type, int version,
const DataTypeVector& dtypes,
std::unique_ptr<Writer>* out_writer);
explicit Writer(WritableFile* dest, const string& compression_type,
int version, const DataTypeVector& dtypes);
Status WriteTensors(const std::vector<Tensor>& tensors);
@ -71,27 +68,16 @@ class Writer {
~Writer();
private:
explicit Writer(const std::string& filename,
const std::string& compression_type, int version,
const DataTypeVector& dtypes);
Status Initialize(tensorflow::Env* env);
Status WriteRecord(const StringPiece& data);
#if defined(PLATFORM_GOOGLE)
Status WriteRecord(const absl::Cord& data);
#endif // PLATFORM_GOOGLE
std::unique_ptr<WritableFile> dest_;
const std::string filename_;
const std::string compression_type_;
WritableFile* dest_;
bool dest_is_owned_ = false;
const string compression_type_;
const int version_;
const DataTypeVector dtypes_;
// We hold zlib_dest_ because we may create a ZlibOutputBuffer and put that
// in dest_ if we want compression. ZlibOutputBuffer doesn't own the original
// dest_ and so we need somewhere to store the original one.
std::unique_ptr<WritableFile> zlib_underlying_dest_;
std::vector<bool> simple_tensor_mask_; // true for simple, false for complex.
int num_simple_ = 0;
int num_complex_ = 0;