Improving snapshot read performance (under snappy compression) by reducing the

number of copies by 2.

We do the following in this CL.

1) Get rid of the snappy input and output buffers that were previously being
used to do the compression. This saved one copy.
2) Directly decompress the compressed bytes into the TensorBuffer for simple
types (not string, variant, resource). This saves another copy during Tensor
creation. For complex types, we still continue to use the TensorProto encoding
and pay a copy there.
3) As a result, we end up changing the on-disk format for Snapshot. For a group
of tensors that make up one element of an IteratorGetNext output, we first
write out a metadata proto that describes the types, shapes and sizes of
tensors. After that we lay out the Tensor data (TensorBuffers for simple types
and TensorProtos serialized for complex ones) and compress them via snappy.
4) Add a version to the SnapshotMetadata. If it isn't set its assumed to be 0
and the old code path runs. We now set it to 1 while writing so that all new
snapshots are written in this data format.

PiperOrigin-RevId: 297149479
Change-Id: I2c9a35c5a254189a5fad946b2995f25cdc452308
This commit is contained in:
Rohan Jain 2020-02-25 10:35:46 -08:00 committed by TensorFlower Gardener
parent 18d4056bb1
commit 4bba44c1df
8 changed files with 453 additions and 126 deletions

View File

@ -434,6 +434,7 @@ tf_kernel_library(
name = "snapshot_dataset_op", name = "snapshot_dataset_op",
srcs = ["snapshot_dataset_op.cc"], srcs = ["snapshot_dataset_op.cc"],
deps = [ deps = [
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:dataset_ops_op_lib", "//tensorflow/core:dataset_ops_op_lib",
"//tensorflow/core:framework", "//tensorflow/core:framework",
"//tensorflow/core:lib", "//tensorflow/core:lib",
@ -441,6 +442,7 @@ tf_kernel_library(
"//tensorflow/core:protos_all_cc", "//tensorflow/core:protos_all_cc",
"//tensorflow/core/grappler:graph_view", "//tensorflow/core/grappler:graph_view",
"//tensorflow/core/kernels/data:dataset_utils", "//tensorflow/core/kernels/data:dataset_utils",
"//tensorflow/core/platform:platform_port",
"//tensorflow/core/profiler/lib:traceme", "//tensorflow/core/profiler/lib:traceme",
"@com_google_absl//absl/time", "@com_google_absl//absl/time",
], ],

View File

@ -15,6 +15,7 @@ limitations under the License.
#include <random> #include <random>
#include "absl/time/clock.h" #include "absl/time/clock.h"
#include "tensorflow/core/common_runtime/dma_helper.h"
#include "tensorflow/core/framework/dataset.h" #include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/partial_tensor_shape.h" #include "tensorflow/core/framework/partial_tensor_shape.h"
@ -32,7 +33,9 @@ limitations under the License.
#include "tensorflow/core/lib/io/compression.h" #include "tensorflow/core/lib/io/compression.h"
#include "tensorflow/core/lib/io/path.h" #include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/lib/io/random_inputstream.h" #include "tensorflow/core/lib/io/random_inputstream.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/file_system.h" #include "tensorflow/core/platform/file_system.h"
#include "tensorflow/core/platform/snappy.h"
#if !defined(IS_SLIM_BUILD) #if !defined(IS_SLIM_BUILD)
#include "tensorflow/core/lib/io/snappy/snappy_inputbuffer.h" #include "tensorflow/core/lib/io/snappy/snappy_inputbuffer.h"
#include "tensorflow/core/lib/io/snappy/snappy_outputbuffer.h" #include "tensorflow/core/lib/io/snappy/snappy_outputbuffer.h"
@ -63,9 +66,6 @@ enum SnapshotMode { READER = 0, WRITER = 1, PASSTHROUGH = 2 };
// Defaults to 10 GiB per shard. // Defaults to 10 GiB per shard.
const int64 kDefaultShardSizeBytes = 10LL * 1024 * 1024 * 1024; const int64 kDefaultShardSizeBytes = 10LL * 1024 * 1024 * 1024;
const int64 kSnappyWriterInputBufferSizeBytes = 16 << 20; // 16 MiB
const int64 kSnappyWriterOutputBufferSizeBytes = 16 << 20; // 16 MiB
// The reader input buffer size is deliberately large because the input reader // The reader input buffer size is deliberately large because the input reader
// will throw an error if the compressed block length cannot fit in the input // will throw an error if the compressed block length cannot fit in the input
// buffer. // buffer.
@ -75,6 +75,8 @@ const int64 kSnappyReaderOutputBufferSizeBytes = 32 << 20; // 32 MiB
const size_t kHeaderSize = sizeof(uint64); const size_t kHeaderSize = sizeof(uint64);
const int64 kCurrentVersion = 1;
constexpr char kModeAuto[] = "auto"; constexpr char kModeAuto[] = "auto";
constexpr char kModeWrite[] = "write"; constexpr char kModeWrite[] = "write";
constexpr char kModeRead[] = "read"; constexpr char kModeRead[] = "read";
@ -95,6 +97,7 @@ constexpr char kState[] = "state";
constexpr char kHashDir[] = "hash_dir"; constexpr char kHashDir[] = "hash_dir";
constexpr char kRunId[] = "run_id"; constexpr char kRunId[] = "run_id";
constexpr char kRunDir[] = "run_dir"; constexpr char kRunDir[] = "run_dir";
constexpr char kVersionStr[] = "version";
constexpr char kFilenames[] = "filenames"; constexpr char kFilenames[] = "filenames";
constexpr char kCurrentFilenames[] = "current_filenames"; constexpr char kCurrentFilenames[] = "current_filenames";
constexpr char kElementsProduced[] = "elements_produced"; constexpr char kElementsProduced[] = "elements_produced";
@ -115,9 +118,9 @@ class SnapshotWriter {
static constexpr const char* const kWriteStringPiece = "WriteStringPiece"; static constexpr const char* const kWriteStringPiece = "WriteStringPiece";
static constexpr const char* const kWriteCord = "WriteCord"; static constexpr const char* const kWriteCord = "WriteCord";
explicit SnapshotWriter(WritableFile* dest, const string& compression_type = explicit SnapshotWriter(WritableFile* dest, const string& compression_type,
io::compression::kNone) int version, const DataTypeVector& dtypes)
: dest_(dest), compression_type_(compression_type) { : dest_(dest), compression_type_(compression_type), version_(version) {
#if defined(IS_SLIM_BUILD) #if defined(IS_SLIM_BUILD)
if (compression_type != io::compression::kNone) { if (compression_type != io::compression::kNone) {
LOG(ERROR) << "Compression is unsupported on mobile platforms. Turning " LOG(ERROR) << "Compression is unsupported on mobile platforms. Turning "
@ -134,41 +137,100 @@ class SnapshotWriter {
TF_CHECK_OK(zlib_output_buffer->Init()); TF_CHECK_OK(zlib_output_buffer->Init());
dest_ = zlib_output_buffer; dest_ = zlib_output_buffer;
dest_is_owned_ = true; dest_is_owned_ = true;
} else if (compression_type == io::compression::kSnappy) {
io::SnappyOutputBuffer* snappy_output_buffer = new io::SnappyOutputBuffer(
dest, /*input_buffer_bytes=*/kSnappyWriterInputBufferSizeBytes,
/*output_buffer_bytes=*/kSnappyWriterOutputBufferSizeBytes);
dest_ = snappy_output_buffer;
dest_is_owned_ = true;
} }
#endif // IS_SLIM_BUILD #endif // IS_SLIM_BUILD
simple_tensor_mask_.reserve(dtypes.size());
for (const auto& dtype : dtypes) {
if (DataTypeCanUseMemcpy(dtype)) {
simple_tensor_mask_.push_back(true);
num_simple_++;
} else {
simple_tensor_mask_.push_back(false);
num_complex_++;
}
}
} }
Status WriteRecord(const StringPiece& data) { Status WriteTensors(const std::vector<Tensor>& tensors) {
profiler::TraceMe activity( if (compression_type_ != io::compression::kSnappy) {
[&]() { experimental::SnapshotRecord record;
return absl::StrCat(kClassName, kSeparator, kWriteStringPiece); for (const auto& tensor : tensors) {
}, TensorProto* t = record.add_tensor();
profiler::TraceMeLevel::kInfo); tensor.AsProtoTensorContent(t);
char header[kHeaderSize]; }
core::EncodeFixed64(header, data.size());
TF_RETURN_IF_ERROR(dest_->Append(StringPiece(header, sizeof(header))));
return dest_->Append(data);
}
#if defined(PLATFORM_GOOGLE) #if defined(PLATFORM_GOOGLE)
Status WriteRecord(const absl::Cord& data) { return WriteRecord(record.SerializeAsCord());
profiler::TraceMe activity( #else // PLATFORM_GOOGLE
[&]() { return absl::StrCat(kClassName, kSeparator, kWriteCord); }, return WriteRecord(record.SerializeAsString());
profiler::TraceMeLevel::kInfo);
char header[kHeaderSize];
core::EncodeFixed64(header, data.size());
TF_RETURN_IF_ERROR(dest_->Append(StringPiece(header, sizeof(header))));
return dest_->Append(data);
}
#endif // PLATFORM_GOOGLE #endif // PLATFORM_GOOGLE
}
if (version_ != 1) {
return errors::InvalidArgument("Version: ", version_,
" is not supported.");
}
if (compression_type_ != io::compression::kSnappy) {
return errors::InvalidArgument(
"Version 1 is only compatible with snappy compression");
}
std::vector<const TensorBuffer*> tensor_buffers;
tensor_buffers.reserve(num_simple_);
std::vector<TensorProto> tensor_protos;
tensor_protos.reserve(num_complex_);
SnapshotTensorMetadata metadata;
int64 total_size = 0;
for (int i = 0; i < tensors.size(); ++i) {
const Tensor& tensor = tensors[i];
TensorMetadata* tensor_metadata = metadata.add_tensor_metadata();
tensor.shape().AsProto(tensor_metadata->mutable_tensor_shape());
int64 size = 0;
if (simple_tensor_mask_[i]) {
auto tensor_buffer = DMAHelper::buffer(&tensor);
tensor_buffers.push_back(tensor_buffer);
size = tensor_buffer->size();
} else {
TensorProto proto;
tensor.AsProtoTensorContent(&proto);
size = proto.ByteSizeLong();
tensor_protos.push_back(std::move(proto));
}
tensor_metadata->set_tensor_size_bytes(size);
total_size += size;
}
std::vector<char> uncompressed(total_size);
char* position = uncompressed.data();
int buffer_index = 0;
int proto_index = 0;
for (int i = 0; i < tensors.size(); ++i) {
const auto& tensor_metadata = metadata.tensor_metadata(i);
if (simple_tensor_mask_[i]) {
memcpy(position, tensor_buffers[buffer_index]->data(),
tensor_metadata.tensor_size_bytes());
buffer_index++;
} else {
tensor_protos[proto_index].SerializeToArray(
position, tensor_metadata.tensor_size_bytes());
proto_index++;
}
position += tensor_metadata.tensor_size_bytes();
}
DCHECK_EQ(position, uncompressed.data() + total_size);
string output;
if (!port::Snappy_Compress(uncompressed.data(), total_size, &output)) {
return errors::Internal("Failed to compress using snappy.");
}
#if defined(PLATFORM_GOOGLE)
absl::Cord metadata_serialized = metadata.SerializeAsCord();
#else // PLATFORM_GOOGLE
std::string metadata_serialized = metadata.SerializeAsString();
#endif // PLATFORM_GOOGLE
TF_RETURN_IF_ERROR(WriteRecord(metadata_serialized));
TF_RETURN_IF_ERROR(WriteRecord(output));
return Status::OK();
}
Status Sync() { return dest_->Sync(); } Status Sync() { return dest_->Sync(); }
@ -192,9 +254,29 @@ class SnapshotWriter {
} }
private: private:
Status WriteRecord(const StringPiece& data) {
char header[kHeaderSize];
core::EncodeFixed64(header, data.size());
TF_RETURN_IF_ERROR(dest_->Append(StringPiece(header, sizeof(header))));
return dest_->Append(data);
}
#if defined(PLATFORM_GOOGLE)
Status WriteRecord(const absl::Cord& data) {
char header[kHeaderSize];
core::EncodeFixed64(header, data.size());
TF_RETURN_IF_ERROR(dest_->Append(StringPiece(header, sizeof(header))));
return dest_->Append(data);
}
#endif // PLATFORM_GOOGLE
WritableFile* dest_; WritableFile* dest_;
bool dest_is_owned_ = false; bool dest_is_owned_ = false;
const string compression_type_; const string compression_type_;
const int version_;
std::vector<bool> simple_tensor_mask_; // true for simple, false for complex.
int num_simple_ = 0;
int num_complex_ = 0;
}; };
class SnapshotReader { class SnapshotReader {
@ -203,12 +285,14 @@ class SnapshotReader {
static constexpr const char* const kReadString = "ReadString"; static constexpr const char* const kReadString = "ReadString";
static constexpr const char* const kReadCord = "ReadCord"; static constexpr const char* const kReadCord = "ReadCord";
explicit SnapshotReader( explicit SnapshotReader(RandomAccessFile* file,
RandomAccessFile* file, const string& compression_type, int version,
const string& compression_type = io::compression::kNone) const DataTypeVector& dtypes)
: file_(file), : file_(file),
input_stream_(new io::RandomAccessInputStream(file)), input_stream_(new io::RandomAccessInputStream(file)),
compression_type_(compression_type) { compression_type_(compression_type),
version_(version),
dtypes_(dtypes) {
#if defined(IS_SLIM_BUILD) #if defined(IS_SLIM_BUILD)
if (compression_type_ != io::compression::kNone) { if (compression_type_ != io::compression::kNone) {
LOG(ERROR) << "Compression is unsupported on mobile platforms. Turning " LOG(ERROR) << "Compression is unsupported on mobile platforms. Turning "
@ -223,17 +307,167 @@ class SnapshotReader {
input_stream_.release(), zlib_options.input_buffer_size, input_stream_.release(), zlib_options.input_buffer_size,
zlib_options.output_buffer_size, zlib_options, true); zlib_options.output_buffer_size, zlib_options, true);
} else if (compression_type_ == io::compression::kSnappy) { } else if (compression_type_ == io::compression::kSnappy) {
input_stream_ = absl::make_unique<io::SnappyInputBuffer>( if (version_ == 0) {
file_, /*input_buffer_bytes=*/kSnappyReaderInputBufferSizeBytes, input_stream_ = absl::make_unique<io::SnappyInputBuffer>(
/*output_buffer_bytes=*/kSnappyReaderOutputBufferSizeBytes); file_, /*input_buffer_bytes=*/kSnappyReaderInputBufferSizeBytes,
/*output_buffer_bytes=*/kSnappyReaderOutputBufferSizeBytes);
} else {
input_stream_ =
absl::make_unique<io::BufferedInputStream>(file_, 64 << 20);
}
} }
#endif // IS_SLIM_BUILD #endif // IS_SLIM_BUILD
simple_tensor_mask_.reserve(dtypes.size());
for (const auto& dtype : dtypes) {
if (DataTypeCanUseMemcpy(dtype)) {
simple_tensor_mask_.push_back(true);
num_simple_++;
} else {
simple_tensor_mask_.push_back(false);
num_complex_++;
}
}
}
Status ReadTensors(std::vector<Tensor>* read_tensors) {
profiler::TraceMe activity(
[&]() { return absl::StrCat(kClassName, kSeparator, "ReadTensors"); },
profiler::TraceMeLevel::kInfo);
if (version_ == 0 || compression_type_ != io::compression::kSnappy) {
return ReadTensorsV0(read_tensors);
}
if (version_ != 1) {
return errors::InvalidArgument("Version: ", version_,
" is not supported.");
}
if (compression_type_ != io::compression::kSnappy) {
return errors::InvalidArgument("Version 1 only supports snappy.");
}
SnapshotTensorMetadata metadata;
tstring metadata_str;
TF_RETURN_IF_ERROR(ReadRecord(&metadata_str));
if (!metadata.ParseFromArray(metadata_str.data(), metadata_str.size())) {
return errors::DataLoss("Could not parse SnapshotTensorMetadata");
}
read_tensors->reserve(metadata.tensor_metadata_size());
std::vector<Tensor> simple_tensors;
simple_tensors.reserve(num_simple_);
std::vector<std::pair<std::unique_ptr<char[]>, size_t>> tensor_proto_strs;
tensor_proto_strs.reserve(num_complex_);
TF_RETURN_IF_ERROR(
SnappyUncompress(metadata, &simple_tensors, &tensor_proto_strs));
int simple_index = 0;
int complex_index = 0;
for (int i = 0; i < simple_tensor_mask_.size(); ++i) {
if (simple_tensor_mask_[i]) {
read_tensors->push_back(std::move(simple_tensors[simple_index]));
simple_index++;
} else {
auto tensor_proto_str =
std::move(tensor_proto_strs[complex_index].first);
size_t tensor_proto_size = tensor_proto_strs[complex_index].second;
TensorProto tp;
#if defined(PLATFORM_GOOGLE)
auto tensor_proto_ptr = tensor_proto_str.release();
absl::Cord c;
c.AppendExternalMemory(
absl::string_view(tensor_proto_ptr, tensor_proto_size),
tensor_proto_ptr,
[](void* arg) { delete[] static_cast<char*>(arg); });
if (!tp.ParseFromCord(c)) {
return errors::Internal("Could not parse TensorProto");
}
#else // PLATFORM_GOOGLE
if (!tp.ParseFromArray(tensor_proto_str.get(), tensor_proto_size)) {
return errors::Internal("Could not parse TensorProto");
}
#endif // PLATFORM_GOOGLE
Tensor t;
if (!t.FromProto(tp)) {
return errors::Internal("Could not parse Tensor");
}
read_tensors->push_back(std::move(t));
complex_index++;
}
}
return Status::OK();
}
private:
Status ReadTensorsV0(std::vector<Tensor>* read_tensors) {
experimental::SnapshotRecord record;
#if defined(PLATFORM_GOOGLE)
absl::Cord c;
TF_RETURN_IF_ERROR(ReadRecord(&c));
record.ParseFromCord(c);
#else // PLATFORM_GOOGLE
tstring record_bytes;
TF_RETURN_IF_ERROR(ReadRecord(&record_bytes));
record.ParseFromArray(record_bytes.data(), record_bytes.size());
#endif // PLATFORM_GOOGLE
read_tensors->reserve(record.tensor_size());
for (int i = 0; i < record.tensor_size(); ++i) {
read_tensors->emplace_back();
if (!read_tensors->back().FromProto(record.tensor(i))) {
return errors::DataLoss("Unable to parse tensor from proto.");
}
}
return Status::OK();
}
Status SnappyUncompress(
const SnapshotTensorMetadata& metadata,
std::vector<Tensor>* simple_tensors,
std::vector<std::pair<std::unique_ptr<char[]>, size_t>>*
tensor_proto_strs) {
tstring compressed;
TF_RETURN_IF_ERROR(ReadRecord(&compressed));
size_t size;
if (!port::Snappy_GetUncompressedLength(compressed.data(),
compressed.size(), &size)) {
return errors::Internal("Could not get snappy uncompressed length");
}
int num_tensors = metadata.tensor_metadata_size();
std::vector<struct iovec> iov(num_tensors);
int index = 0;
int64 total_size = 0;
for (int i = 0; i < simple_tensor_mask_.size(); ++i) {
const auto& tensor_metadata = metadata.tensor_metadata(i);
if (simple_tensor_mask_[i]) {
TensorShape shape(tensor_metadata.tensor_shape());
Tensor simple_tensor(dtypes_[i], shape);
TensorBuffer* buffer = DMAHelper::buffer(&simple_tensor);
iov[index].iov_base = buffer->data();
iov[index].iov_len = buffer->size();
simple_tensors->push_back(std::move(simple_tensor));
} else {
auto tensor_proto_str =
absl::make_unique<char[]>(tensor_metadata.tensor_size_bytes());
iov[index].iov_base = tensor_proto_str.get();
iov[index].iov_len = tensor_metadata.tensor_size_bytes();
tensor_proto_strs->push_back(std::make_pair(
std::move(tensor_proto_str), tensor_metadata.tensor_size_bytes()));
}
total_size += iov[index].iov_len;
index++;
}
if (size != total_size) {
return errors::Internal("Uncompressed size mismatch. Snappy expects ",
size, " whereas the tensor metadata suggests ",
total_size);
}
if (!port::Snappy_UncompressToIOVec(compressed.data(), compressed.size(),
iov.data(), num_tensors)) {
return errors::Internal("Failed to perform snappy decompression.");
}
return Status::OK();
} }
Status ReadRecord(tstring* record) { Status ReadRecord(tstring* record) {
profiler::TraceMe activity(
[&]() { return absl::StrCat(kClassName, kSeparator, kReadString); },
profiler::TraceMeLevel::kInfo);
tstring header; tstring header;
TF_RETURN_IF_ERROR(input_stream_->ReadNBytes(kHeaderSize, &header)); TF_RETURN_IF_ERROR(input_stream_->ReadNBytes(kHeaderSize, &header));
uint64 length = core::DecodeFixed64(header.data()); uint64 length = core::DecodeFixed64(header.data());
@ -245,13 +479,6 @@ class SnapshotReader {
tstring header; tstring header;
TF_RETURN_IF_ERROR(input_stream_->ReadNBytes(kHeaderSize, &header)); TF_RETURN_IF_ERROR(input_stream_->ReadNBytes(kHeaderSize, &header));
uint64 length = core::DecodeFixed64(header.data()); uint64 length = core::DecodeFixed64(header.data());
profiler::TraceMe activity(
[&]() {
return absl::StrCat(kClassName, kSeparator, kReadCord,
"#length=", length, "#");
},
profiler::TraceMeLevel::kInfo);
if (compression_type_ == io::compression::kNone) { if (compression_type_ == io::compression::kNone) {
return input_stream_->ReadNBytes(length, record); return input_stream_->ReadNBytes(length, record);
} else { } else {
@ -268,50 +495,31 @@ class SnapshotReader {
} }
#endif #endif
private:
RandomAccessFile* file_; RandomAccessFile* file_;
std::unique_ptr<io::InputStreamInterface> input_stream_; std::unique_ptr<io::InputStreamInterface> input_stream_;
const string compression_type_; const string compression_type_;
const int version_;
const DataTypeVector dtypes_;
int num_simple_ = 0;
int num_complex_ = 0;
std::vector<bool> simple_tensor_mask_; // true for simple, false for complex.
}; };
Status WriteMetadataFile(const string& hash_dir, Status WriteMetadataFile(const string& hash_dir,
const experimental::SnapshotMetadataRecord& metadata) { const experimental::SnapshotMetadataRecord& metadata) {
string metadata_filename = io::JoinPath(hash_dir, kSnapshotFilename); string metadata_filename = io::JoinPath(hash_dir, kSnapshotFilename);
TF_RETURN_IF_ERROR(Env::Default()->RecursivelyCreateDir(hash_dir)); TF_RETURN_IF_ERROR(Env::Default()->RecursivelyCreateDir(hash_dir));
std::string tmp_filename = std::string tmp_filename =
absl::StrCat(metadata_filename, "-tmp-", random::New64()); absl::StrCat(metadata_filename, "-tmp-", random::New64());
TF_RETURN_IF_ERROR(WriteBinaryProto(Env::Default(), tmp_filename, metadata));
std::unique_ptr<WritableFile> file; return Env::Default()->RenameFile(tmp_filename, metadata_filename);
TF_RETURN_IF_ERROR(Env::Default()->NewWritableFile(tmp_filename, &file));
auto writer = absl::make_unique<SnapshotWriter>(file.get());
TF_RETURN_IF_ERROR(writer->WriteRecord(metadata.SerializeAsString()));
TF_RETURN_IF_ERROR(writer->Close());
TF_RETURN_IF_ERROR(file->Sync());
TF_RETURN_IF_ERROR(file->Close());
TF_RETURN_IF_ERROR(
Env::Default()->RenameFile(tmp_filename, metadata_filename));
return Status::OK();
} }
Status ReadMetadataFile(const string& hash_dir, Status ReadMetadataFile(const string& hash_dir,
experimental::SnapshotMetadataRecord* metadata) { experimental::SnapshotMetadataRecord* metadata) {
string metadata_filename = io::JoinPath(hash_dir, kSnapshotFilename); string metadata_filename = io::JoinPath(hash_dir, kSnapshotFilename);
TF_RETURN_IF_ERROR(Env::Default()->FileExists(metadata_filename)); TF_RETURN_IF_ERROR(Env::Default()->FileExists(metadata_filename));
return ReadBinaryProto(Env::Default(), metadata_filename, metadata);
std::unique_ptr<RandomAccessFile> file;
TF_RETURN_IF_ERROR(
Env::Default()->NewRandomAccessFile(metadata_filename, &file));
tstring record_bytes;
SnapshotReader reader(file.get());
TF_RETURN_IF_ERROR(reader.ReadRecord(&record_bytes));
metadata->ParseFromArray(record_bytes.data(), record_bytes.size());
return Status::OK();
} }
Status DumpDatasetGraph(const std::string& path, uint64 hash, Status DumpDatasetGraph(const std::string& path, uint64 hash,
@ -332,6 +540,10 @@ Status DetermineOpState(const std::string& mode_string,
const uint64 pending_snapshot_expiry_seconds, const uint64 pending_snapshot_expiry_seconds,
SnapshotMode* mode) { SnapshotMode* mode) {
if (mode_string == kModeRead) { if (mode_string == kModeRead) {
// In read mode, we should expect a metadata file is written.
if (errors::IsNotFound(file_status)) {
return file_status;
}
LOG(INFO) << "Overriding mode to reader."; LOG(INFO) << "Overriding mode to reader.";
*mode = READER; *mode = READER;
return Status::OK(); return Status::OK();
@ -727,10 +939,25 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
if (run_id.empty()) { if (run_id.empty()) {
run_id = metadata.run_id(); run_id = metadata.run_id();
} }
// dtypes in metadata should be the same as dataset()->output_dtypes
if (metadata.dtype_size() != dataset()->output_dtypes().size()) {
return errors::Internal(
"Expected number of dtypes: ",
dataset()->output_dtypes().size(),
" but number in snapshot: ", metadata.dtype_size());
}
for (int i = 0; i < metadata.dtype_size(); ++i) {
if (metadata.dtype(i) != dataset()->output_dtypes()[i]) {
return errors::Internal(
"Type: ", i,
" doesn't match. Snapshot: ", metadata.dtype(i),
"; dataset: ", dataset()->output_dtypes()[i]);
}
}
iterator_ = absl::make_unique<SnapshotReaderIterator>( iterator_ = absl::make_unique<SnapshotReaderIterator>(
SnapshotReaderIterator::Params{ SnapshotReaderIterator::Params{
dataset(), absl::StrCat(prefix(), "ReaderImpl")}, dataset(), absl::StrCat(prefix(), "ReaderImpl")},
hash_dir_, run_id); hash_dir_, run_id, metadata.version());
break; break;
case PASSTHROUGH: case PASSTHROUGH:
iterator_ = absl::make_unique<SnapshotPassthroughIterator>( iterator_ = absl::make_unique<SnapshotPassthroughIterator>(
@ -748,10 +975,11 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
explicit SnapshotReaderIterator(const Params& params, explicit SnapshotReaderIterator(const Params& params,
const string& hash_dir, const string& hash_dir,
const string& run_id) const string& run_id, int64 version)
: DatasetIterator<Dataset>(params), : DatasetIterator<Dataset>(params),
hash_dir_(hash_dir), hash_dir_(hash_dir),
run_id_(run_id) {} run_id_(run_id),
version_(version) {}
~SnapshotReaderIterator() override { ~SnapshotReaderIterator() override {
mutex_lock l(mu_); mutex_lock l(mu_);
@ -889,6 +1117,8 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
writer->WriteScalar(full_name(kHashDir), hash_dir_)); writer->WriteScalar(full_name(kHashDir), hash_dir_));
TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kRunId), run_id_)); TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kRunId), run_id_));
TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kRunDir), run_dir_)); TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kRunDir), run_dir_));
TF_RETURN_IF_ERROR(
writer->WriteScalar(full_name(kVersionStr), version_));
TF_RETURN_IF_ERROR(writer->WriteScalar( TF_RETURN_IF_ERROR(writer->WriteScalar(
full_name(strings::StrCat(kFilenames, kSizeSuffix)), full_name(strings::StrCat(kFilenames, kSizeSuffix)),
filenames_.size())); filenames_.size()));
@ -932,6 +1162,8 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
} }
TF_RETURN_IF_ERROR(reader->ReadScalar(full_name(kRunId), &run_id_)); TF_RETURN_IF_ERROR(reader->ReadScalar(full_name(kRunId), &run_id_));
TF_RETURN_IF_ERROR(reader->ReadScalar(full_name(kRunDir), &run_dir_)); TF_RETURN_IF_ERROR(reader->ReadScalar(full_name(kRunDir), &run_dir_));
TF_RETURN_IF_ERROR(
reader->ReadScalar(full_name(kVersionStr), &version_));
curr_filenames_.clear(); curr_filenames_.clear();
curr_filenames_.reserve(dataset()->num_reader_threads_); curr_filenames_.reserve(dataset()->num_reader_threads_);
for (auto i = 0; i < dataset()->num_reader_threads_; ++i) { for (auto i = 0; i < dataset()->num_reader_threads_; ++i) {
@ -986,7 +1218,8 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
std::unique_ptr<RandomAccessFile> file; std::unique_ptr<RandomAccessFile> file;
TF_RETURN_IF_ERROR( TF_RETURN_IF_ERROR(
Env::Default()->NewRandomAccessFile(filename, &file)); Env::Default()->NewRandomAccessFile(filename, &file));
SnapshotReader reader(file.get(), dataset()->compression_); SnapshotReader reader(file.get(), dataset()->compression_, version_,
dataset()->output_dtypes());
while (true) { while (true) {
// Wait for a slot in the buffer. // Wait for a slot in the buffer.
@ -1003,30 +1236,14 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
"ReadFile"); "ReadFile");
} }
} }
#if !defined(PLATFORM_GOOGLE) std::vector<Tensor> read_tensors;
tstring record_bytes; Status s = reader.ReadTensors(&read_tensors);
Status s = reader.ReadRecord(&record_bytes);
#else
absl::Cord record_cord;
Status s = reader.ReadRecord(&record_cord);
#endif
if (s.ok()) { if (s.ok()) {
profiler::TraceMe activity( profiler::TraceMe activity(
[&]() { return absl::StrCat(prefix(), kSeparator, kParse); }, [&]() { return absl::StrCat(prefix(), kSeparator, kParse); },
profiler::TraceMeLevel::kInfo); profiler::TraceMeLevel::kInfo);
experimental::SnapshotRecord record;
#if !defined(PLATFORM_GOOGLE)
record.ParseFromArray(record_bytes.data(), record_bytes.size());
#else
record.ParseFromCord(record_cord);
#endif
BufferElement elem; BufferElement elem;
for (int i = 0; i < record.tensor_size(); ++i) { elem.value = std::move(read_tensors);
elem.value.emplace_back();
if (!elem.value.back().FromProto(record.tensor(i))) {
return errors::DataLoss("Unable to parse tensor from proto.");
}
}
elem.status = Status::OK(); elem.status = Status::OK();
mutex_lock l(mu_); mutex_lock l(mu_);
buffer_.push_back(std::move(elem)); buffer_.push_back(std::move(elem));
@ -1142,9 +1359,9 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
condition_variable cond_var_; condition_variable cond_var_;
const string hash_dir_; const string hash_dir_;
const experimental::SnapshotMetadataRecord metadata_;
tstring run_id_ GUARDED_BY(mu_); tstring run_id_ GUARDED_BY(mu_);
tstring run_dir_ GUARDED_BY(mu_); tstring run_dir_ GUARDED_BY(mu_);
int64 version_;
std::vector<tstring> filenames_; std::vector<tstring> filenames_;
uint64 elements_produced_ GUARDED_BY(mu_) = 0; uint64 elements_produced_ GUARDED_BY(mu_) = 0;
@ -1220,6 +1437,10 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
metadata.set_creation_timestamp(EnvTime::NowMicros()); metadata.set_creation_timestamp(EnvTime::NowMicros());
metadata.set_graph_hash(dataset()->graph_hash_); metadata.set_graph_hash(dataset()->graph_hash_);
metadata.set_run_id(run_id_.data(), run_id_.size()); metadata.set_run_id(run_id_.data(), run_id_.size());
metadata.set_version(kCurrentVersion);
for (const auto& output_dtype : dataset()->output_dtypes()) {
metadata.add_dtype(output_dtype);
}
metadata.set_finalized(false); metadata.set_finalized(false);
TF_RETURN_IF_ERROR(WriteMetadataFile(hash_dir_, metadata)); TF_RETURN_IF_ERROR(WriteMetadataFile(hash_dir_, metadata));
} }
@ -1564,11 +1785,8 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
} }
if (produced_elem) { if (produced_elem) {
experimental::SnapshotRecord record;
for (const auto& out_tensor : elem.value) { for (const auto& out_tensor : elem.value) {
*bytes_written += out_tensor.TotalBytes(); *bytes_written += out_tensor.TotalBytes();
TensorProto* t = record.add_tensor();
out_tensor.AsProtoTensorContent(t);
} }
bool should_close; bool should_close;
@ -1584,16 +1802,11 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
TF_RETURN_IF_ERROR(Env::Default()->NewAppendableFile( TF_RETURN_IF_ERROR(Env::Default()->NewAppendableFile(
*snapshot_data_filename, file)); *snapshot_data_filename, file));
*writer = absl::make_unique<SnapshotWriter>( *writer = absl::make_unique<SnapshotWriter>(
file->get(), dataset()->compression_); file->get(), dataset()->compression_, kCurrentVersion,
dataset()->output_dtypes());
*bytes_written = 0; *bytes_written = 0;
} }
#if defined(PLATFORM_GOOGLE) TF_RETURN_IF_ERROR((*writer)->WriteTensors(elem.value));
TF_RETURN_IF_ERROR(
(*writer)->WriteRecord(record.SerializeAsCord()));
#else // PLATFORM_GOOGLE
TF_RETURN_IF_ERROR(
(*writer)->WriteRecord(record.SerializeAsString()));
#endif // PLATFORM_GOOGLE
return Status::OK(); return Status::OK();
} }
@ -1641,7 +1854,8 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
return; return;
} }
std::unique_ptr<SnapshotWriter> writer( std::unique_ptr<SnapshotWriter> writer(
new SnapshotWriter(file.get(), dataset()->compression_)); new SnapshotWriter(file.get(), dataset()->compression_,
kCurrentVersion, dataset()->output_dtypes()));
bool end_of_processing = false; bool end_of_processing = false;
while (!end_of_processing) { while (!end_of_processing) {

View File

@ -332,6 +332,16 @@ bool Snappy_Uncompress(const char* input, size_t length, char* output) {
#endif #endif
} }
bool Snappy_UncompressToIOVec(const char* compressed, size_t compressed_length,
const struct iovec* iov, size_t iov_cnt) {
#ifdef TF_USE_SNAPPY
return snappy::RawUncompressToIOVec(compressed, compressed_length, iov,
iov_cnt);
#else
return false;
#endif
}
string Demangle(const char* mangled) { return mangled; } string Demangle(const char* mangled) { return mangled; }
double NominalCPUFrequency() { double NominalCPUFrequency() {

View File

@ -18,6 +18,17 @@ limitations under the License.
#include "tensorflow/core/platform/types.h" #include "tensorflow/core/platform/types.h"
#if !defined(PLATFORM_WINDOWS)
#include <sys/uio.h>
#else
namespace tensorflow {
struct iovec {
void* iov_base;
size_t iov_len;
};
} // namespace tensorflow
#endif
namespace tensorflow { namespace tensorflow {
namespace port { namespace port {
@ -28,6 +39,9 @@ bool Snappy_GetUncompressedLength(const char* input, size_t length,
size_t* result); size_t* result);
bool Snappy_Uncompress(const char* input, size_t length, char* output); bool Snappy_Uncompress(const char* input, size_t length, char* output);
bool Snappy_UncompressToIOVec(const char* compressed, size_t compressed_length,
const struct iovec* iov, size_t iov_cnt);
} // namespace port } // namespace port
} // namespace tensorflow } // namespace tensorflow

View File

@ -157,6 +157,17 @@ bool Snappy_Uncompress(const char* input, size_t length, char* output) {
#endif #endif
} }
bool Snappy_UncompressToIOVec(const char* compressed, size_t compressed_length,
const struct iovec* iov, size_t iov_cnt) {
#ifdef TF_USE_SNAPPY
const snappy::iovec* snappy_iov = reinterpret_cast<const snappy::iovec*>(iov);
return snappy::RawUncompressToIOVec(compressed, compressed_length, snappy_iov,
iov_cnt);
#else
return false;
#endif
}
string Demangle(const char* mangled) { return mangled; } string Demangle(const char* mangled) { return mangled; }
double NominalCPUFrequency() { double NominalCPUFrequency() {

View File

@ -3,6 +3,8 @@ syntax = "proto3";
package tensorflow.data.experimental; package tensorflow.data.experimental;
import "tensorflow/core/framework/tensor.proto"; import "tensorflow/core/framework/tensor.proto";
import "tensorflow/core/framework/tensor_shape.proto";
import "tensorflow/core/framework/types.proto";
// Each SnapshotRecord represents one batch of pre-processed input data. A batch // Each SnapshotRecord represents one batch of pre-processed input data. A batch
// consists of a list of tensors that we encode as TensorProtos. This message // consists of a list of tensors that we encode as TensorProtos. This message
@ -13,9 +15,29 @@ message SnapshotRecord {
// This stores the metadata information present in each snapshot record. // This stores the metadata information present in each snapshot record.
message SnapshotMetadataRecord { message SnapshotMetadataRecord {
// Stores the fingerprint of the graph that describes the dataset that is
// snapshotted.
string graph_hash = 1; string graph_hash = 1;
// Run ID that this snapshot corresponds to.
string run_id = 2; string run_id = 2;
// Time when we started creating this snapshot.
int64 creation_timestamp = 3; int64 creation_timestamp = 3;
// Version of the snapshot data file format.
int64 version = 4;
// A list of tensor dtype corresponding to each element of the snapshot.
repeated .tensorflow.DataType dtype = 5;
bool finalized = 1000; bool finalized = 1000;
} }
// Metadata for a single tensor in the Snapshot Record.
message TensorMetadata {
.tensorflow.TensorShapeProto tensor_shape = 2;
// Number of uncompressed bytes used to store the tensor representation.
int64 tensor_size_bytes = 3;
}
// Metadata for all the tensors in a Snapshot Record.
message SnapshotTensorMetadata {
repeated TensorMetadata tensor_metadata = 1;
}

View File

@ -161,17 +161,49 @@ class SnapshotDatasetTest(reader_dataset_ops_test_base.TFRecordDatasetTestBase,
self.assertSnapshotDirectoryContains(tmpdir, 1, 1, 1) self.assertSnapshotDirectoryContains(tmpdir, 1, 1, 1)
@combinations.generate(test_base.default_test_combinations()) @combinations.generate(
def testWriteSnapshotRepeatAfterwards(self): combinations.times(
test_base.default_test_combinations(),
combinations.combine(compression=[
snapshot.COMPRESSION_NONE, snapshot.COMPRESSION_GZIP,
snapshot.COMPRESSION_SNAPPY
])))
def testWriteSnapshotRepeatAfterwards(self, compression):
tmpdir = self.snapshot_dir tmpdir = self.snapshot_dir
dataset = dataset_ops.Dataset.range(10) dataset = dataset_ops.Dataset.range(10)
dataset = dataset.apply(snapshot.snapshot(tmpdir)) dataset = dataset.apply(snapshot.snapshot(tmpdir, compression=compression))
dataset = dataset.repeat(10) dataset = dataset.repeat(10)
self.assertDatasetProduces(dataset, list(range(10)) * 10) self.assertDatasetProduces(dataset, list(range(10)) * 10)
self.assertSnapshotDirectoryContains(tmpdir, 1, 1, 1) self.assertSnapshotDirectoryContains(tmpdir, 1, 1, 1)
@combinations.generate(
combinations.times(
test_base.default_test_combinations(),
combinations.combine(compression=[
snapshot.COMPRESSION_NONE, snapshot.COMPRESSION_GZIP,
snapshot.COMPRESSION_SNAPPY
])))
def testWriteSnapshotMixTypes(self, compression):
tmpdir = self.snapshot_dir
dataset = dataset_ops.Dataset.range(10)
def map_fn(x):
return (x, string_ops.as_string(x), string_ops.as_string(2 * x), 2 * x)
dataset = dataset.map(map_fn)
dataset = dataset.apply(snapshot.snapshot(tmpdir, compression=compression))
dataset = dataset.repeat(10)
expected = []
for i in range(10):
expected.append((i, str(i), str(2 * i), 2 * i))
self.assertDatasetProduces(dataset, expected * 10)
self.assertSnapshotDirectoryContains(tmpdir, 1, 1, 1)
@combinations.generate(test_base.default_test_combinations()) @combinations.generate(test_base.default_test_combinations())
def testSpecifySnapshotNameWriteAndRead(self): def testSpecifySnapshotNameWriteAndRead(self):
tmpdir = self.snapshot_dir tmpdir = self.snapshot_dir
@ -365,8 +397,14 @@ class SnapshotDatasetTest(reader_dataset_ops_test_base.TFRecordDatasetTestBase,
res3 = self.evaluate(next3()) res3 = self.evaluate(next3())
self.assertEqual(res2, res3) self.assertEqual(res2, res3)
@combinations.generate(test_base.default_test_combinations()) @combinations.generate(
def testReadSnapshotParallelAfterWrite(self): combinations.times(
test_base.default_test_combinations(),
combinations.combine(compression=[
snapshot.COMPRESSION_NONE, snapshot.COMPRESSION_GZIP,
snapshot.COMPRESSION_SNAPPY
])))
def testReadSnapshotParallelAfterWrite(self, compression):
self.setUpTFRecord(10, 4000) self.setUpTFRecord(10, 4000)
filenames = self.test_filenames filenames = self.test_filenames
@ -383,7 +421,8 @@ class SnapshotDatasetTest(reader_dataset_ops_test_base.TFRecordDatasetTestBase,
tmpdir, tmpdir,
shard_size_bytes=1024 * 1024, shard_size_bytes=1024 * 1024,
num_reader_threads=2, num_reader_threads=2,
reader_buffer_size=10)) reader_buffer_size=10,
compression=compression))
self.assertDatasetProduces(dataset, expected, assert_items_equal=True) self.assertDatasetProduces(dataset, expected, assert_items_equal=True)
# remove the original files and try to read the data back only from # remove the original files and try to read the data back only from
@ -396,7 +435,8 @@ class SnapshotDatasetTest(reader_dataset_ops_test_base.TFRecordDatasetTestBase,
tmpdir, tmpdir,
shard_size_bytes=1024 * 1024, shard_size_bytes=1024 * 1024,
num_reader_threads=2, num_reader_threads=2,
reader_buffer_size=10)) reader_buffer_size=10,
compression=compression))
self.assertDatasetProduces(dataset2, expected, assert_items_equal=True) self.assertDatasetProduces(dataset2, expected, assert_items_equal=True)
# Not testing Snappy here because Snappy reads currently require a lot of # Not testing Snappy here because Snappy reads currently require a lot of
@ -514,21 +554,31 @@ class SnapshotDatasetTest(reader_dataset_ops_test_base.TFRecordDatasetTestBase,
self.evaluate(next2()) self.evaluate(next2())
self.assertSnapshotDirectoryContains(tmpdir, 2, 1, 1) self.assertSnapshotDirectoryContains(tmpdir, 2, 1, 1)
@combinations.generate(test_base.default_test_combinations()) @combinations.generate(
def testSpecifyShardSize(self): combinations.times(
test_base.default_test_combinations(),
combinations.combine(compression=[
snapshot.COMPRESSION_NONE, snapshot.COMPRESSION_GZIP,
snapshot.COMPRESSION_SNAPPY
])))
def testSpecifyShardSize(self, compression):
tmpdir = self.snapshot_dir tmpdir = self.snapshot_dir
dataset = dataset_ops.Dataset.from_tensor_slices([1.0]) dataset = dataset_ops.Dataset.from_tensor_slices([1.0])
dataset = dataset.map(lambda x: gen_array_ops.broadcast_to(x, [1024, 1024])) dataset = dataset.map(lambda x: gen_array_ops.broadcast_to(x, [1024, 1024]))
dataset = dataset.repeat(10) dataset = dataset.repeat(10)
dataset = dataset.apply( dataset = dataset.apply(
snapshot.snapshot(tmpdir, shard_size_bytes=10 * 1024 * 1024)) snapshot.snapshot(
tmpdir, shard_size_bytes=10 * 1024 * 1024, compression=compression))
next_fn = self.getNext(dataset) next_fn = self.getNext(dataset)
for _ in range(10): for _ in range(10):
self.evaluate(next_fn()) self.evaluate(next_fn())
self.assertSnapshotDirectoryContains(tmpdir, 1, 1, 3) num_files = 1
if compression == snapshot.COMPRESSION_NONE:
num_files = 3
self.assertSnapshotDirectoryContains(tmpdir, 1, 1, num_files)
@combinations.generate(test_base.default_test_combinations()) @combinations.generate(test_base.default_test_combinations())
def testAdditionalOperationsAfterReadBack(self): def testAdditionalOperationsAfterReadBack(self):

View File

@ -27,6 +27,10 @@ cc_library(
"-Wno-implicit-function-declaration", "-Wno-implicit-function-declaration",
], ],
}), }),
defines = select({
"@org_tensorflow//tensorflow:windows": [],
"//conditions:default": ["HAVE_SYS_UIO_H"],
}),
) )
genrule( genrule(