Update snapshot_util::Writer to own WritableFiles instead of borrowing them
PiperOrigin-RevId: 306099307 Change-Id: I29ad3f26080ab6419c93b9436f74ca0e53520a6d
This commit is contained in:
parent
3364a141f2
commit
a192cf7dad
@ -529,7 +529,6 @@ cc_library(
|
||||
"//tensorflow/core/platform:coding",
|
||||
"//tensorflow/core/platform:random",
|
||||
"//tensorflow/core/profiler/lib:traceme",
|
||||
"@com_google_absl//absl/memory",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -965,8 +965,7 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
|
||||
}
|
||||
for (int i = 0; i < dataset()->num_writer_threads_; ++i) {
|
||||
++num_active_threads_;
|
||||
thread_pool_->Schedule(
|
||||
[this, env = ctx->env()]() { WriterThread(env); });
|
||||
thread_pool_->Schedule([this]() { WriterThread(); });
|
||||
}
|
||||
first_call_ = false;
|
||||
}
|
||||
@ -1263,8 +1262,9 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
|
||||
|
||||
Status ProcessOneElement(int64* bytes_written,
|
||||
string* snapshot_data_filename,
|
||||
std::unique_ptr<WritableFile>* file,
|
||||
std::unique_ptr<snapshot_util::Writer>* writer,
|
||||
bool* end_of_processing, Env* env) {
|
||||
bool* end_of_processing) {
|
||||
profiler::TraceMe activity(
|
||||
[&]() {
|
||||
return absl::StrCat(prefix(), kSeparator, kProcessOneElement);
|
||||
@ -1296,6 +1296,8 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
|
||||
|
||||
if (cancelled || snapshot_failed) {
|
||||
TF_RETURN_IF_ERROR((*writer)->Close());
|
||||
TF_RETURN_IF_ERROR((*file)->Sync());
|
||||
TF_RETURN_IF_ERROR((*file)->Close());
|
||||
if (snapshot_failed) {
|
||||
return errors::Internal(
|
||||
"SnapshotDataset::SnapshotWriterIterator snapshot failed");
|
||||
@ -1310,17 +1312,20 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
|
||||
}
|
||||
|
||||
bool should_close;
|
||||
TF_RETURN_IF_ERROR(
|
||||
ShouldCloseWriter(*snapshot_data_filename, *bytes_written,
|
||||
(*writer).get(), &should_close));
|
||||
TF_RETURN_IF_ERROR(ShouldCloseFile(*snapshot_data_filename,
|
||||
*bytes_written, (*writer).get(),
|
||||
(*file).get(), &should_close));
|
||||
if (should_close) {
|
||||
// If we exceed the shard size, we get a new file and reset.
|
||||
TF_RETURN_IF_ERROR((*writer)->Close());
|
||||
TF_RETURN_IF_ERROR((*file)->Sync());
|
||||
TF_RETURN_IF_ERROR((*file)->Close());
|
||||
*snapshot_data_filename = GetSnapshotFilename();
|
||||
|
||||
TF_RETURN_IF_ERROR(snapshot_util::Writer::Create(
|
||||
env, *snapshot_data_filename, dataset()->compression_,
|
||||
kCurrentVersion, dataset()->output_dtypes(), writer));
|
||||
TF_RETURN_IF_ERROR(Env::Default()->NewAppendableFile(
|
||||
*snapshot_data_filename, file));
|
||||
*writer = absl::make_unique<snapshot_util::Writer>(
|
||||
file->get(), dataset()->compression_, kCurrentVersion,
|
||||
dataset()->output_dtypes());
|
||||
*bytes_written = 0;
|
||||
}
|
||||
TF_RETURN_IF_ERROR((*writer)->WriteTensors(elem.value));
|
||||
@ -1329,6 +1334,8 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
|
||||
|
||||
if (*end_of_processing) {
|
||||
TF_RETURN_IF_ERROR((*writer)->Close());
|
||||
TF_RETURN_IF_ERROR((*file)->Sync());
|
||||
TF_RETURN_IF_ERROR((*file)->Close());
|
||||
mutex_lock l(mu_);
|
||||
if (!written_final_metadata_file_) {
|
||||
experimental::SnapshotMetadataRecord metadata;
|
||||
@ -1351,7 +1358,7 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
|
||||
}
|
||||
|
||||
// Just pulls off elements from the buffer and writes them.
|
||||
void WriterThread(Env* env) {
|
||||
void WriterThread() {
|
||||
auto cleanup = gtl::MakeCleanup([this]() {
|
||||
mutex_lock l(mu_);
|
||||
--num_active_threads_;
|
||||
@ -1360,10 +1367,9 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
|
||||
|
||||
int64 bytes_written = 0;
|
||||
string snapshot_data_filename = GetSnapshotFilename();
|
||||
std::unique_ptr<snapshot_util::Writer> writer;
|
||||
Status s = snapshot_util::Writer::Create(
|
||||
env, snapshot_data_filename, dataset()->compression_,
|
||||
kCurrentVersion, dataset()->output_dtypes(), &writer);
|
||||
std::unique_ptr<WritableFile> file;
|
||||
Status s =
|
||||
Env::Default()->NewAppendableFile(snapshot_data_filename, &file);
|
||||
if (!s.ok()) {
|
||||
LOG(ERROR) << "Creating " << snapshot_data_filename
|
||||
<< " failed: " << s.ToString();
|
||||
@ -1372,12 +1378,16 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
|
||||
cond_var_.notify_all();
|
||||
return;
|
||||
}
|
||||
std::unique_ptr<snapshot_util::Writer> writer(
|
||||
new snapshot_util::Writer(file.get(), dataset()->compression_,
|
||||
kCurrentVersion,
|
||||
dataset()->output_dtypes()));
|
||||
|
||||
bool end_of_processing = false;
|
||||
while (!end_of_processing) {
|
||||
Status s =
|
||||
ProcessOneElement(&bytes_written, &snapshot_data_filename,
|
||||
&writer, &end_of_processing, env);
|
||||
&file, &writer, &end_of_processing);
|
||||
if (!s.ok()) {
|
||||
LOG(INFO) << "Error while writing snapshot data to disk: "
|
||||
<< s.ToString();
|
||||
@ -1391,9 +1401,9 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
|
||||
}
|
||||
}
|
||||
|
||||
Status ShouldCloseWriter(const string& filename, uint64 bytes_written,
|
||||
snapshot_util::Writer* writer,
|
||||
bool* should_close) {
|
||||
Status ShouldCloseFile(const string& filename, uint64 bytes_written,
|
||||
snapshot_util::Writer* writer,
|
||||
WritableFile* file, bool* should_close) {
|
||||
// If the compression ratio has been estimated, use it to decide
|
||||
// whether the file should be closed. We avoid estimating the
|
||||
// compression ratio repeatedly because it requires syncing the file,
|
||||
@ -1415,6 +1425,7 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
|
||||
// Use the actual file size to determine compression ratio.
|
||||
// Make sure that all bytes are written out.
|
||||
TF_RETURN_IF_ERROR(writer->Sync());
|
||||
TF_RETURN_IF_ERROR(file->Sync());
|
||||
uint64 file_size;
|
||||
TF_RETURN_IF_ERROR(Env::Default()->GetFileSize(filename, &file_size));
|
||||
mutex_lock l(mu_);
|
||||
|
@ -15,7 +15,6 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/core/kernels/data/experimental/snapshot_util.h"
|
||||
|
||||
#include "absl/memory/memory.h"
|
||||
#include "tensorflow/core/common_runtime/dma_helper.h"
|
||||
#include "tensorflow/core/framework/graph.pb.h"
|
||||
#include "tensorflow/core/framework/tensor.pb.h"
|
||||
@ -40,45 +39,29 @@ namespace snapshot_util {
|
||||
/* static */ constexpr const int64 Reader::kSnappyReaderInputBufferSizeBytes;
|
||||
/* static */ constexpr const int64 Reader::kSnappyReaderOutputBufferSizeBytes;
|
||||
|
||||
Writer::Writer(const std::string& filename, const std::string& compression_type,
|
||||
int version, const DataTypeVector& dtypes)
|
||||
: filename_(filename),
|
||||
compression_type_(compression_type),
|
||||
version_(version),
|
||||
dtypes_(dtypes) {}
|
||||
|
||||
Status Writer::Create(Env* env, const std::string& filename,
|
||||
const std::string& compression_type, int version,
|
||||
const DataTypeVector& dtypes,
|
||||
std::unique_ptr<Writer>* out_writer) {
|
||||
*out_writer =
|
||||
absl::WrapUnique(new Writer(filename, compression_type, version, dtypes));
|
||||
|
||||
return (*out_writer)->Initialize(env);
|
||||
}
|
||||
|
||||
Status Writer::Initialize(tensorflow::Env* env) {
|
||||
TF_RETURN_IF_ERROR(env->NewWritableFile(filename_, &dest_));
|
||||
Writer::Writer(WritableFile* dest, const string& compression_type, int version,
|
||||
const DataTypeVector& dtypes)
|
||||
: dest_(dest), compression_type_(compression_type), version_(version) {
|
||||
#if defined(IS_SLIM_BUILD)
|
||||
if (compression_type_ != io::compression::kNone) {
|
||||
if (compression_type != io::compression::kNone) {
|
||||
LOG(ERROR) << "Compression is unsupported on mobile platforms. Turning "
|
||||
<< "off compression.";
|
||||
}
|
||||
#else // IS_SLIM_BUILD
|
||||
if (compression_type_ == io::compression::kGzip) {
|
||||
zlib_underlying_dest_.swap(dest_);
|
||||
if (compression_type == io::compression::kGzip) {
|
||||
io::ZlibCompressionOptions zlib_options;
|
||||
zlib_options = io::ZlibCompressionOptions::GZIP();
|
||||
|
||||
io::ZlibOutputBuffer* zlib_output_buffer = new io::ZlibOutputBuffer(
|
||||
zlib_underlying_dest_.get(), zlib_options.input_buffer_size,
|
||||
zlib_options.output_buffer_size, zlib_options);
|
||||
io::ZlibOutputBuffer* zlib_output_buffer =
|
||||
new io::ZlibOutputBuffer(dest, zlib_options.input_buffer_size,
|
||||
zlib_options.output_buffer_size, zlib_options);
|
||||
TF_CHECK_OK(zlib_output_buffer->Init());
|
||||
dest_.reset(zlib_output_buffer);
|
||||
dest_ = zlib_output_buffer;
|
||||
dest_is_owned_ = true;
|
||||
}
|
||||
#endif // IS_SLIM_BUILD
|
||||
simple_tensor_mask_.reserve(dtypes_.size());
|
||||
for (const auto& dtype : dtypes_) {
|
||||
simple_tensor_mask_.reserve(dtypes.size());
|
||||
for (const auto& dtype : dtypes) {
|
||||
if (DataTypeCanUseMemcpy(dtype)) {
|
||||
simple_tensor_mask_.push_back(true);
|
||||
num_simple_++;
|
||||
@ -87,8 +70,6 @@ Status Writer::Initialize(tensorflow::Env* env) {
|
||||
num_complex_++;
|
||||
}
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status Writer::WriteTensors(const std::vector<Tensor>& tensors) {
|
||||
@ -175,21 +156,21 @@ Status Writer::WriteTensors(const std::vector<Tensor>& tensors) {
|
||||
Status Writer::Sync() { return dest_->Sync(); }
|
||||
|
||||
Status Writer::Close() {
|
||||
if (dest_ != nullptr) {
|
||||
TF_RETURN_IF_ERROR(dest_->Close());
|
||||
if (dest_is_owned_) {
|
||||
Status s = dest_->Close();
|
||||
delete dest_;
|
||||
dest_ = nullptr;
|
||||
}
|
||||
if (zlib_underlying_dest_ != nullptr) {
|
||||
TF_RETURN_IF_ERROR(zlib_underlying_dest_->Close());
|
||||
zlib_underlying_dest_ = nullptr;
|
||||
return s;
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Writer::~Writer() {
|
||||
Status s = Close();
|
||||
if (!s.ok()) {
|
||||
LOG(ERROR) << "Could not finish writing file: " << s;
|
||||
if (dest_ != nullptr) {
|
||||
Status s = Close();
|
||||
if (!s.ok()) {
|
||||
LOG(ERROR) << "Could not finish writing file: " << s;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -20,7 +20,6 @@ limitations under the License.
|
||||
#include "tensorflow/core/framework/types.h"
|
||||
#include "tensorflow/core/lib/io/compression.h"
|
||||
#include "tensorflow/core/lib/io/inputstream_interface.h"
|
||||
#include "tensorflow/core/platform/env.h"
|
||||
#include "tensorflow/core/platform/file_system.h"
|
||||
#include "tensorflow/core/platform/status.h"
|
||||
|
||||
@ -57,10 +56,8 @@ class Writer {
|
||||
static constexpr const char* const kWriteCord = "WriteCord";
|
||||
static constexpr const char* const kSeparator = "::";
|
||||
|
||||
static Status Create(Env* env, const std::string& filename,
|
||||
const std::string& compression_type, int version,
|
||||
const DataTypeVector& dtypes,
|
||||
std::unique_ptr<Writer>* out_writer);
|
||||
explicit Writer(WritableFile* dest, const string& compression_type,
|
||||
int version, const DataTypeVector& dtypes);
|
||||
|
||||
Status WriteTensors(const std::vector<Tensor>& tensors);
|
||||
|
||||
@ -71,27 +68,16 @@ class Writer {
|
||||
~Writer();
|
||||
|
||||
private:
|
||||
explicit Writer(const std::string& filename,
|
||||
const std::string& compression_type, int version,
|
||||
const DataTypeVector& dtypes);
|
||||
|
||||
Status Initialize(tensorflow::Env* env);
|
||||
|
||||
Status WriteRecord(const StringPiece& data);
|
||||
|
||||
#if defined(PLATFORM_GOOGLE)
|
||||
Status WriteRecord(const absl::Cord& data);
|
||||
#endif // PLATFORM_GOOGLE
|
||||
|
||||
std::unique_ptr<WritableFile> dest_;
|
||||
const std::string filename_;
|
||||
const std::string compression_type_;
|
||||
WritableFile* dest_;
|
||||
bool dest_is_owned_ = false;
|
||||
const string compression_type_;
|
||||
const int version_;
|
||||
const DataTypeVector dtypes_;
|
||||
// We hold zlib_dest_ because we may create a ZlibOutputBuffer and put that
|
||||
// in dest_ if we want compression. ZlibOutputBuffer doesn't own the original
|
||||
// dest_ and so we need somewhere to store the original one.
|
||||
std::unique_ptr<WritableFile> zlib_underlying_dest_;
|
||||
std::vector<bool> simple_tensor_mask_; // true for simple, false for complex.
|
||||
int num_simple_ = 0;
|
||||
int num_complex_ = 0;
|
||||
|
Loading…
Reference in New Issue
Block a user