Improving snapshot read performance (under snappy compression) by reducing the
number of copies by 2. We do the following in this CL. 1) Get rid of the snappy input and output buffers that were previously being used to do the compression. This saved one copy. 2) Directly decompress the compressed bytes into the TensorBuffer for simple types (not string, variant, resource). This saves another copy during Tensor creation. For complex types, we still continue to use the TensorProto encoding and pay a copy there. 3) As a result, we end up changing the on-disk format for Snapshot. For a group of tensors that make up one element of an IteratorGetNext output, we first write out a metadata proto that describes the types, shapes and sizes of tensors. After that we lay out the Tensor data (TensorBuffers for simple types and TensorProtos serialized for complex ones) and compress them via snappy. 4) Add a version to the SnapshotMetadata. If it isn't set its assumed to be 0 and the old code path runs. We now set it to 1 while writing so that all new snapshots are written in this data format. PiperOrigin-RevId: 297149479 Change-Id: I2c9a35c5a254189a5fad946b2995f25cdc452308
This commit is contained in:
parent
18d4056bb1
commit
4bba44c1df
|
@ -434,6 +434,7 @@ tf_kernel_library(
|
||||||
name = "snapshot_dataset_op",
|
name = "snapshot_dataset_op",
|
||||||
srcs = ["snapshot_dataset_op.cc"],
|
srcs = ["snapshot_dataset_op.cc"],
|
||||||
deps = [
|
deps = [
|
||||||
|
"//tensorflow/core:core_cpu_internal",
|
||||||
"//tensorflow/core:dataset_ops_op_lib",
|
"//tensorflow/core:dataset_ops_op_lib",
|
||||||
"//tensorflow/core:framework",
|
"//tensorflow/core:framework",
|
||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
|
@ -441,6 +442,7 @@ tf_kernel_library(
|
||||||
"//tensorflow/core:protos_all_cc",
|
"//tensorflow/core:protos_all_cc",
|
||||||
"//tensorflow/core/grappler:graph_view",
|
"//tensorflow/core/grappler:graph_view",
|
||||||
"//tensorflow/core/kernels/data:dataset_utils",
|
"//tensorflow/core/kernels/data:dataset_utils",
|
||||||
|
"//tensorflow/core/platform:platform_port",
|
||||||
"//tensorflow/core/profiler/lib:traceme",
|
"//tensorflow/core/profiler/lib:traceme",
|
||||||
"@com_google_absl//absl/time",
|
"@com_google_absl//absl/time",
|
||||||
],
|
],
|
||||||
|
|
|
@ -15,6 +15,7 @@ limitations under the License.
|
||||||
#include <random>
|
#include <random>
|
||||||
|
|
||||||
#include "absl/time/clock.h"
|
#include "absl/time/clock.h"
|
||||||
|
#include "tensorflow/core/common_runtime/dma_helper.h"
|
||||||
#include "tensorflow/core/framework/dataset.h"
|
#include "tensorflow/core/framework/dataset.h"
|
||||||
#include "tensorflow/core/framework/op_kernel.h"
|
#include "tensorflow/core/framework/op_kernel.h"
|
||||||
#include "tensorflow/core/framework/partial_tensor_shape.h"
|
#include "tensorflow/core/framework/partial_tensor_shape.h"
|
||||||
|
@ -32,7 +33,9 @@ limitations under the License.
|
||||||
#include "tensorflow/core/lib/io/compression.h"
|
#include "tensorflow/core/lib/io/compression.h"
|
||||||
#include "tensorflow/core/lib/io/path.h"
|
#include "tensorflow/core/lib/io/path.h"
|
||||||
#include "tensorflow/core/lib/io/random_inputstream.h"
|
#include "tensorflow/core/lib/io/random_inputstream.h"
|
||||||
|
#include "tensorflow/core/platform/errors.h"
|
||||||
#include "tensorflow/core/platform/file_system.h"
|
#include "tensorflow/core/platform/file_system.h"
|
||||||
|
#include "tensorflow/core/platform/snappy.h"
|
||||||
#if !defined(IS_SLIM_BUILD)
|
#if !defined(IS_SLIM_BUILD)
|
||||||
#include "tensorflow/core/lib/io/snappy/snappy_inputbuffer.h"
|
#include "tensorflow/core/lib/io/snappy/snappy_inputbuffer.h"
|
||||||
#include "tensorflow/core/lib/io/snappy/snappy_outputbuffer.h"
|
#include "tensorflow/core/lib/io/snappy/snappy_outputbuffer.h"
|
||||||
|
@ -63,9 +66,6 @@ enum SnapshotMode { READER = 0, WRITER = 1, PASSTHROUGH = 2 };
|
||||||
// Defaults to 10 GiB per shard.
|
// Defaults to 10 GiB per shard.
|
||||||
const int64 kDefaultShardSizeBytes = 10LL * 1024 * 1024 * 1024;
|
const int64 kDefaultShardSizeBytes = 10LL * 1024 * 1024 * 1024;
|
||||||
|
|
||||||
const int64 kSnappyWriterInputBufferSizeBytes = 16 << 20; // 16 MiB
|
|
||||||
const int64 kSnappyWriterOutputBufferSizeBytes = 16 << 20; // 16 MiB
|
|
||||||
|
|
||||||
// The reader input buffer size is deliberately large because the input reader
|
// The reader input buffer size is deliberately large because the input reader
|
||||||
// will throw an error if the compressed block length cannot fit in the input
|
// will throw an error if the compressed block length cannot fit in the input
|
||||||
// buffer.
|
// buffer.
|
||||||
|
@ -75,6 +75,8 @@ const int64 kSnappyReaderOutputBufferSizeBytes = 32 << 20; // 32 MiB
|
||||||
|
|
||||||
const size_t kHeaderSize = sizeof(uint64);
|
const size_t kHeaderSize = sizeof(uint64);
|
||||||
|
|
||||||
|
const int64 kCurrentVersion = 1;
|
||||||
|
|
||||||
constexpr char kModeAuto[] = "auto";
|
constexpr char kModeAuto[] = "auto";
|
||||||
constexpr char kModeWrite[] = "write";
|
constexpr char kModeWrite[] = "write";
|
||||||
constexpr char kModeRead[] = "read";
|
constexpr char kModeRead[] = "read";
|
||||||
|
@ -95,6 +97,7 @@ constexpr char kState[] = "state";
|
||||||
constexpr char kHashDir[] = "hash_dir";
|
constexpr char kHashDir[] = "hash_dir";
|
||||||
constexpr char kRunId[] = "run_id";
|
constexpr char kRunId[] = "run_id";
|
||||||
constexpr char kRunDir[] = "run_dir";
|
constexpr char kRunDir[] = "run_dir";
|
||||||
|
constexpr char kVersionStr[] = "version";
|
||||||
constexpr char kFilenames[] = "filenames";
|
constexpr char kFilenames[] = "filenames";
|
||||||
constexpr char kCurrentFilenames[] = "current_filenames";
|
constexpr char kCurrentFilenames[] = "current_filenames";
|
||||||
constexpr char kElementsProduced[] = "elements_produced";
|
constexpr char kElementsProduced[] = "elements_produced";
|
||||||
|
@ -115,9 +118,9 @@ class SnapshotWriter {
|
||||||
static constexpr const char* const kWriteStringPiece = "WriteStringPiece";
|
static constexpr const char* const kWriteStringPiece = "WriteStringPiece";
|
||||||
static constexpr const char* const kWriteCord = "WriteCord";
|
static constexpr const char* const kWriteCord = "WriteCord";
|
||||||
|
|
||||||
explicit SnapshotWriter(WritableFile* dest, const string& compression_type =
|
explicit SnapshotWriter(WritableFile* dest, const string& compression_type,
|
||||||
io::compression::kNone)
|
int version, const DataTypeVector& dtypes)
|
||||||
: dest_(dest), compression_type_(compression_type) {
|
: dest_(dest), compression_type_(compression_type), version_(version) {
|
||||||
#if defined(IS_SLIM_BUILD)
|
#if defined(IS_SLIM_BUILD)
|
||||||
if (compression_type != io::compression::kNone) {
|
if (compression_type != io::compression::kNone) {
|
||||||
LOG(ERROR) << "Compression is unsupported on mobile platforms. Turning "
|
LOG(ERROR) << "Compression is unsupported on mobile platforms. Turning "
|
||||||
|
@ -134,41 +137,100 @@ class SnapshotWriter {
|
||||||
TF_CHECK_OK(zlib_output_buffer->Init());
|
TF_CHECK_OK(zlib_output_buffer->Init());
|
||||||
dest_ = zlib_output_buffer;
|
dest_ = zlib_output_buffer;
|
||||||
dest_is_owned_ = true;
|
dest_is_owned_ = true;
|
||||||
} else if (compression_type == io::compression::kSnappy) {
|
|
||||||
io::SnappyOutputBuffer* snappy_output_buffer = new io::SnappyOutputBuffer(
|
|
||||||
dest, /*input_buffer_bytes=*/kSnappyWriterInputBufferSizeBytes,
|
|
||||||
/*output_buffer_bytes=*/kSnappyWriterOutputBufferSizeBytes);
|
|
||||||
dest_ = snappy_output_buffer;
|
|
||||||
dest_is_owned_ = true;
|
|
||||||
}
|
}
|
||||||
#endif // IS_SLIM_BUILD
|
#endif // IS_SLIM_BUILD
|
||||||
|
simple_tensor_mask_.reserve(dtypes.size());
|
||||||
|
for (const auto& dtype : dtypes) {
|
||||||
|
if (DataTypeCanUseMemcpy(dtype)) {
|
||||||
|
simple_tensor_mask_.push_back(true);
|
||||||
|
num_simple_++;
|
||||||
|
} else {
|
||||||
|
simple_tensor_mask_.push_back(false);
|
||||||
|
num_complex_++;
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
Status WriteRecord(const StringPiece& data) {
|
Status WriteTensors(const std::vector<Tensor>& tensors) {
|
||||||
profiler::TraceMe activity(
|
if (compression_type_ != io::compression::kSnappy) {
|
||||||
[&]() {
|
experimental::SnapshotRecord record;
|
||||||
return absl::StrCat(kClassName, kSeparator, kWriteStringPiece);
|
for (const auto& tensor : tensors) {
|
||||||
},
|
TensorProto* t = record.add_tensor();
|
||||||
profiler::TraceMeLevel::kInfo);
|
tensor.AsProtoTensorContent(t);
|
||||||
char header[kHeaderSize];
|
|
||||||
core::EncodeFixed64(header, data.size());
|
|
||||||
TF_RETURN_IF_ERROR(dest_->Append(StringPiece(header, sizeof(header))));
|
|
||||||
return dest_->Append(data);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#if defined(PLATFORM_GOOGLE)
|
#if defined(PLATFORM_GOOGLE)
|
||||||
Status WriteRecord(const absl::Cord& data) {
|
return WriteRecord(record.SerializeAsCord());
|
||||||
profiler::TraceMe activity(
|
#else // PLATFORM_GOOGLE
|
||||||
[&]() { return absl::StrCat(kClassName, kSeparator, kWriteCord); },
|
return WriteRecord(record.SerializeAsString());
|
||||||
profiler::TraceMeLevel::kInfo);
|
|
||||||
char header[kHeaderSize];
|
|
||||||
core::EncodeFixed64(header, data.size());
|
|
||||||
|
|
||||||
TF_RETURN_IF_ERROR(dest_->Append(StringPiece(header, sizeof(header))));
|
|
||||||
|
|
||||||
return dest_->Append(data);
|
|
||||||
}
|
|
||||||
#endif // PLATFORM_GOOGLE
|
#endif // PLATFORM_GOOGLE
|
||||||
|
}
|
||||||
|
|
||||||
|
if (version_ != 1) {
|
||||||
|
return errors::InvalidArgument("Version: ", version_,
|
||||||
|
" is not supported.");
|
||||||
|
}
|
||||||
|
if (compression_type_ != io::compression::kSnappy) {
|
||||||
|
return errors::InvalidArgument(
|
||||||
|
"Version 1 is only compatible with snappy compression");
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<const TensorBuffer*> tensor_buffers;
|
||||||
|
tensor_buffers.reserve(num_simple_);
|
||||||
|
std::vector<TensorProto> tensor_protos;
|
||||||
|
tensor_protos.reserve(num_complex_);
|
||||||
|
SnapshotTensorMetadata metadata;
|
||||||
|
int64 total_size = 0;
|
||||||
|
for (int i = 0; i < tensors.size(); ++i) {
|
||||||
|
const Tensor& tensor = tensors[i];
|
||||||
|
TensorMetadata* tensor_metadata = metadata.add_tensor_metadata();
|
||||||
|
tensor.shape().AsProto(tensor_metadata->mutable_tensor_shape());
|
||||||
|
int64 size = 0;
|
||||||
|
if (simple_tensor_mask_[i]) {
|
||||||
|
auto tensor_buffer = DMAHelper::buffer(&tensor);
|
||||||
|
tensor_buffers.push_back(tensor_buffer);
|
||||||
|
size = tensor_buffer->size();
|
||||||
|
} else {
|
||||||
|
TensorProto proto;
|
||||||
|
tensor.AsProtoTensorContent(&proto);
|
||||||
|
size = proto.ByteSizeLong();
|
||||||
|
tensor_protos.push_back(std::move(proto));
|
||||||
|
}
|
||||||
|
tensor_metadata->set_tensor_size_bytes(size);
|
||||||
|
total_size += size;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<char> uncompressed(total_size);
|
||||||
|
char* position = uncompressed.data();
|
||||||
|
int buffer_index = 0;
|
||||||
|
int proto_index = 0;
|
||||||
|
for (int i = 0; i < tensors.size(); ++i) {
|
||||||
|
const auto& tensor_metadata = metadata.tensor_metadata(i);
|
||||||
|
if (simple_tensor_mask_[i]) {
|
||||||
|
memcpy(position, tensor_buffers[buffer_index]->data(),
|
||||||
|
tensor_metadata.tensor_size_bytes());
|
||||||
|
buffer_index++;
|
||||||
|
} else {
|
||||||
|
tensor_protos[proto_index].SerializeToArray(
|
||||||
|
position, tensor_metadata.tensor_size_bytes());
|
||||||
|
proto_index++;
|
||||||
|
}
|
||||||
|
position += tensor_metadata.tensor_size_bytes();
|
||||||
|
}
|
||||||
|
DCHECK_EQ(position, uncompressed.data() + total_size);
|
||||||
|
|
||||||
|
string output;
|
||||||
|
if (!port::Snappy_Compress(uncompressed.data(), total_size, &output)) {
|
||||||
|
return errors::Internal("Failed to compress using snappy.");
|
||||||
|
}
|
||||||
|
#if defined(PLATFORM_GOOGLE)
|
||||||
|
absl::Cord metadata_serialized = metadata.SerializeAsCord();
|
||||||
|
#else // PLATFORM_GOOGLE
|
||||||
|
std::string metadata_serialized = metadata.SerializeAsString();
|
||||||
|
#endif // PLATFORM_GOOGLE
|
||||||
|
TF_RETURN_IF_ERROR(WriteRecord(metadata_serialized));
|
||||||
|
TF_RETURN_IF_ERROR(WriteRecord(output));
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
Status Sync() { return dest_->Sync(); }
|
Status Sync() { return dest_->Sync(); }
|
||||||
|
|
||||||
|
@ -192,9 +254,29 @@ class SnapshotWriter {
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
Status WriteRecord(const StringPiece& data) {
|
||||||
|
char header[kHeaderSize];
|
||||||
|
core::EncodeFixed64(header, data.size());
|
||||||
|
TF_RETURN_IF_ERROR(dest_->Append(StringPiece(header, sizeof(header))));
|
||||||
|
return dest_->Append(data);
|
||||||
|
}
|
||||||
|
|
||||||
|
#if defined(PLATFORM_GOOGLE)
|
||||||
|
Status WriteRecord(const absl::Cord& data) {
|
||||||
|
char header[kHeaderSize];
|
||||||
|
core::EncodeFixed64(header, data.size());
|
||||||
|
TF_RETURN_IF_ERROR(dest_->Append(StringPiece(header, sizeof(header))));
|
||||||
|
return dest_->Append(data);
|
||||||
|
}
|
||||||
|
#endif // PLATFORM_GOOGLE
|
||||||
|
|
||||||
WritableFile* dest_;
|
WritableFile* dest_;
|
||||||
bool dest_is_owned_ = false;
|
bool dest_is_owned_ = false;
|
||||||
const string compression_type_;
|
const string compression_type_;
|
||||||
|
const int version_;
|
||||||
|
std::vector<bool> simple_tensor_mask_; // true for simple, false for complex.
|
||||||
|
int num_simple_ = 0;
|
||||||
|
int num_complex_ = 0;
|
||||||
};
|
};
|
||||||
|
|
||||||
class SnapshotReader {
|
class SnapshotReader {
|
||||||
|
@ -203,12 +285,14 @@ class SnapshotReader {
|
||||||
static constexpr const char* const kReadString = "ReadString";
|
static constexpr const char* const kReadString = "ReadString";
|
||||||
static constexpr const char* const kReadCord = "ReadCord";
|
static constexpr const char* const kReadCord = "ReadCord";
|
||||||
|
|
||||||
explicit SnapshotReader(
|
explicit SnapshotReader(RandomAccessFile* file,
|
||||||
RandomAccessFile* file,
|
const string& compression_type, int version,
|
||||||
const string& compression_type = io::compression::kNone)
|
const DataTypeVector& dtypes)
|
||||||
: file_(file),
|
: file_(file),
|
||||||
input_stream_(new io::RandomAccessInputStream(file)),
|
input_stream_(new io::RandomAccessInputStream(file)),
|
||||||
compression_type_(compression_type) {
|
compression_type_(compression_type),
|
||||||
|
version_(version),
|
||||||
|
dtypes_(dtypes) {
|
||||||
#if defined(IS_SLIM_BUILD)
|
#if defined(IS_SLIM_BUILD)
|
||||||
if (compression_type_ != io::compression::kNone) {
|
if (compression_type_ != io::compression::kNone) {
|
||||||
LOG(ERROR) << "Compression is unsupported on mobile platforms. Turning "
|
LOG(ERROR) << "Compression is unsupported on mobile platforms. Turning "
|
||||||
|
@ -223,17 +307,167 @@ class SnapshotReader {
|
||||||
input_stream_.release(), zlib_options.input_buffer_size,
|
input_stream_.release(), zlib_options.input_buffer_size,
|
||||||
zlib_options.output_buffer_size, zlib_options, true);
|
zlib_options.output_buffer_size, zlib_options, true);
|
||||||
} else if (compression_type_ == io::compression::kSnappy) {
|
} else if (compression_type_ == io::compression::kSnappy) {
|
||||||
|
if (version_ == 0) {
|
||||||
input_stream_ = absl::make_unique<io::SnappyInputBuffer>(
|
input_stream_ = absl::make_unique<io::SnappyInputBuffer>(
|
||||||
file_, /*input_buffer_bytes=*/kSnappyReaderInputBufferSizeBytes,
|
file_, /*input_buffer_bytes=*/kSnappyReaderInputBufferSizeBytes,
|
||||||
/*output_buffer_bytes=*/kSnappyReaderOutputBufferSizeBytes);
|
/*output_buffer_bytes=*/kSnappyReaderOutputBufferSizeBytes);
|
||||||
|
} else {
|
||||||
|
input_stream_ =
|
||||||
|
absl::make_unique<io::BufferedInputStream>(file_, 64 << 20);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
#endif // IS_SLIM_BUILD
|
#endif // IS_SLIM_BUILD
|
||||||
|
simple_tensor_mask_.reserve(dtypes.size());
|
||||||
|
for (const auto& dtype : dtypes) {
|
||||||
|
if (DataTypeCanUseMemcpy(dtype)) {
|
||||||
|
simple_tensor_mask_.push_back(true);
|
||||||
|
num_simple_++;
|
||||||
|
} else {
|
||||||
|
simple_tensor_mask_.push_back(false);
|
||||||
|
num_complex_++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Status ReadTensors(std::vector<Tensor>* read_tensors) {
|
||||||
|
profiler::TraceMe activity(
|
||||||
|
[&]() { return absl::StrCat(kClassName, kSeparator, "ReadTensors"); },
|
||||||
|
profiler::TraceMeLevel::kInfo);
|
||||||
|
if (version_ == 0 || compression_type_ != io::compression::kSnappy) {
|
||||||
|
return ReadTensorsV0(read_tensors);
|
||||||
|
}
|
||||||
|
if (version_ != 1) {
|
||||||
|
return errors::InvalidArgument("Version: ", version_,
|
||||||
|
" is not supported.");
|
||||||
|
}
|
||||||
|
if (compression_type_ != io::compression::kSnappy) {
|
||||||
|
return errors::InvalidArgument("Version 1 only supports snappy.");
|
||||||
|
}
|
||||||
|
|
||||||
|
SnapshotTensorMetadata metadata;
|
||||||
|
tstring metadata_str;
|
||||||
|
TF_RETURN_IF_ERROR(ReadRecord(&metadata_str));
|
||||||
|
if (!metadata.ParseFromArray(metadata_str.data(), metadata_str.size())) {
|
||||||
|
return errors::DataLoss("Could not parse SnapshotTensorMetadata");
|
||||||
|
}
|
||||||
|
read_tensors->reserve(metadata.tensor_metadata_size());
|
||||||
|
|
||||||
|
std::vector<Tensor> simple_tensors;
|
||||||
|
simple_tensors.reserve(num_simple_);
|
||||||
|
std::vector<std::pair<std::unique_ptr<char[]>, size_t>> tensor_proto_strs;
|
||||||
|
tensor_proto_strs.reserve(num_complex_);
|
||||||
|
TF_RETURN_IF_ERROR(
|
||||||
|
SnappyUncompress(metadata, &simple_tensors, &tensor_proto_strs));
|
||||||
|
|
||||||
|
int simple_index = 0;
|
||||||
|
int complex_index = 0;
|
||||||
|
for (int i = 0; i < simple_tensor_mask_.size(); ++i) {
|
||||||
|
if (simple_tensor_mask_[i]) {
|
||||||
|
read_tensors->push_back(std::move(simple_tensors[simple_index]));
|
||||||
|
simple_index++;
|
||||||
|
} else {
|
||||||
|
auto tensor_proto_str =
|
||||||
|
std::move(tensor_proto_strs[complex_index].first);
|
||||||
|
size_t tensor_proto_size = tensor_proto_strs[complex_index].second;
|
||||||
|
TensorProto tp;
|
||||||
|
#if defined(PLATFORM_GOOGLE)
|
||||||
|
auto tensor_proto_ptr = tensor_proto_str.release();
|
||||||
|
absl::Cord c;
|
||||||
|
c.AppendExternalMemory(
|
||||||
|
absl::string_view(tensor_proto_ptr, tensor_proto_size),
|
||||||
|
tensor_proto_ptr,
|
||||||
|
[](void* arg) { delete[] static_cast<char*>(arg); });
|
||||||
|
if (!tp.ParseFromCord(c)) {
|
||||||
|
return errors::Internal("Could not parse TensorProto");
|
||||||
|
}
|
||||||
|
#else // PLATFORM_GOOGLE
|
||||||
|
if (!tp.ParseFromArray(tensor_proto_str.get(), tensor_proto_size)) {
|
||||||
|
return errors::Internal("Could not parse TensorProto");
|
||||||
|
}
|
||||||
|
#endif // PLATFORM_GOOGLE
|
||||||
|
Tensor t;
|
||||||
|
if (!t.FromProto(tp)) {
|
||||||
|
return errors::Internal("Could not parse Tensor");
|
||||||
|
}
|
||||||
|
read_tensors->push_back(std::move(t));
|
||||||
|
complex_index++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
Status ReadTensorsV0(std::vector<Tensor>* read_tensors) {
|
||||||
|
experimental::SnapshotRecord record;
|
||||||
|
#if defined(PLATFORM_GOOGLE)
|
||||||
|
absl::Cord c;
|
||||||
|
TF_RETURN_IF_ERROR(ReadRecord(&c));
|
||||||
|
record.ParseFromCord(c);
|
||||||
|
#else // PLATFORM_GOOGLE
|
||||||
|
tstring record_bytes;
|
||||||
|
TF_RETURN_IF_ERROR(ReadRecord(&record_bytes));
|
||||||
|
record.ParseFromArray(record_bytes.data(), record_bytes.size());
|
||||||
|
#endif // PLATFORM_GOOGLE
|
||||||
|
read_tensors->reserve(record.tensor_size());
|
||||||
|
for (int i = 0; i < record.tensor_size(); ++i) {
|
||||||
|
read_tensors->emplace_back();
|
||||||
|
if (!read_tensors->back().FromProto(record.tensor(i))) {
|
||||||
|
return errors::DataLoss("Unable to parse tensor from proto.");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
Status SnappyUncompress(
|
||||||
|
const SnapshotTensorMetadata& metadata,
|
||||||
|
std::vector<Tensor>* simple_tensors,
|
||||||
|
std::vector<std::pair<std::unique_ptr<char[]>, size_t>>*
|
||||||
|
tensor_proto_strs) {
|
||||||
|
tstring compressed;
|
||||||
|
TF_RETURN_IF_ERROR(ReadRecord(&compressed));
|
||||||
|
size_t size;
|
||||||
|
if (!port::Snappy_GetUncompressedLength(compressed.data(),
|
||||||
|
compressed.size(), &size)) {
|
||||||
|
return errors::Internal("Could not get snappy uncompressed length");
|
||||||
|
}
|
||||||
|
|
||||||
|
int num_tensors = metadata.tensor_metadata_size();
|
||||||
|
std::vector<struct iovec> iov(num_tensors);
|
||||||
|
int index = 0;
|
||||||
|
int64 total_size = 0;
|
||||||
|
for (int i = 0; i < simple_tensor_mask_.size(); ++i) {
|
||||||
|
const auto& tensor_metadata = metadata.tensor_metadata(i);
|
||||||
|
if (simple_tensor_mask_[i]) {
|
||||||
|
TensorShape shape(tensor_metadata.tensor_shape());
|
||||||
|
Tensor simple_tensor(dtypes_[i], shape);
|
||||||
|
TensorBuffer* buffer = DMAHelper::buffer(&simple_tensor);
|
||||||
|
iov[index].iov_base = buffer->data();
|
||||||
|
iov[index].iov_len = buffer->size();
|
||||||
|
simple_tensors->push_back(std::move(simple_tensor));
|
||||||
|
} else {
|
||||||
|
auto tensor_proto_str =
|
||||||
|
absl::make_unique<char[]>(tensor_metadata.tensor_size_bytes());
|
||||||
|
iov[index].iov_base = tensor_proto_str.get();
|
||||||
|
iov[index].iov_len = tensor_metadata.tensor_size_bytes();
|
||||||
|
tensor_proto_strs->push_back(std::make_pair(
|
||||||
|
std::move(tensor_proto_str), tensor_metadata.tensor_size_bytes()));
|
||||||
|
}
|
||||||
|
total_size += iov[index].iov_len;
|
||||||
|
index++;
|
||||||
|
}
|
||||||
|
if (size != total_size) {
|
||||||
|
return errors::Internal("Uncompressed size mismatch. Snappy expects ",
|
||||||
|
size, " whereas the tensor metadata suggests ",
|
||||||
|
total_size);
|
||||||
|
}
|
||||||
|
if (!port::Snappy_UncompressToIOVec(compressed.data(), compressed.size(),
|
||||||
|
iov.data(), num_tensors)) {
|
||||||
|
return errors::Internal("Failed to perform snappy decompression.");
|
||||||
|
}
|
||||||
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
Status ReadRecord(tstring* record) {
|
Status ReadRecord(tstring* record) {
|
||||||
profiler::TraceMe activity(
|
|
||||||
[&]() { return absl::StrCat(kClassName, kSeparator, kReadString); },
|
|
||||||
profiler::TraceMeLevel::kInfo);
|
|
||||||
tstring header;
|
tstring header;
|
||||||
TF_RETURN_IF_ERROR(input_stream_->ReadNBytes(kHeaderSize, &header));
|
TF_RETURN_IF_ERROR(input_stream_->ReadNBytes(kHeaderSize, &header));
|
||||||
uint64 length = core::DecodeFixed64(header.data());
|
uint64 length = core::DecodeFixed64(header.data());
|
||||||
|
@ -245,13 +479,6 @@ class SnapshotReader {
|
||||||
tstring header;
|
tstring header;
|
||||||
TF_RETURN_IF_ERROR(input_stream_->ReadNBytes(kHeaderSize, &header));
|
TF_RETURN_IF_ERROR(input_stream_->ReadNBytes(kHeaderSize, &header));
|
||||||
uint64 length = core::DecodeFixed64(header.data());
|
uint64 length = core::DecodeFixed64(header.data());
|
||||||
profiler::TraceMe activity(
|
|
||||||
[&]() {
|
|
||||||
return absl::StrCat(kClassName, kSeparator, kReadCord,
|
|
||||||
"#length=", length, "#");
|
|
||||||
},
|
|
||||||
profiler::TraceMeLevel::kInfo);
|
|
||||||
|
|
||||||
if (compression_type_ == io::compression::kNone) {
|
if (compression_type_ == io::compression::kNone) {
|
||||||
return input_stream_->ReadNBytes(length, record);
|
return input_stream_->ReadNBytes(length, record);
|
||||||
} else {
|
} else {
|
||||||
|
@ -268,50 +495,31 @@ class SnapshotReader {
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
private:
|
|
||||||
RandomAccessFile* file_;
|
RandomAccessFile* file_;
|
||||||
std::unique_ptr<io::InputStreamInterface> input_stream_;
|
std::unique_ptr<io::InputStreamInterface> input_stream_;
|
||||||
const string compression_type_;
|
const string compression_type_;
|
||||||
|
const int version_;
|
||||||
|
const DataTypeVector dtypes_;
|
||||||
|
int num_simple_ = 0;
|
||||||
|
int num_complex_ = 0;
|
||||||
|
std::vector<bool> simple_tensor_mask_; // true for simple, false for complex.
|
||||||
};
|
};
|
||||||
|
|
||||||
Status WriteMetadataFile(const string& hash_dir,
|
Status WriteMetadataFile(const string& hash_dir,
|
||||||
const experimental::SnapshotMetadataRecord& metadata) {
|
const experimental::SnapshotMetadataRecord& metadata) {
|
||||||
string metadata_filename = io::JoinPath(hash_dir, kSnapshotFilename);
|
string metadata_filename = io::JoinPath(hash_dir, kSnapshotFilename);
|
||||||
TF_RETURN_IF_ERROR(Env::Default()->RecursivelyCreateDir(hash_dir));
|
TF_RETURN_IF_ERROR(Env::Default()->RecursivelyCreateDir(hash_dir));
|
||||||
|
|
||||||
std::string tmp_filename =
|
std::string tmp_filename =
|
||||||
absl::StrCat(metadata_filename, "-tmp-", random::New64());
|
absl::StrCat(metadata_filename, "-tmp-", random::New64());
|
||||||
|
TF_RETURN_IF_ERROR(WriteBinaryProto(Env::Default(), tmp_filename, metadata));
|
||||||
std::unique_ptr<WritableFile> file;
|
return Env::Default()->RenameFile(tmp_filename, metadata_filename);
|
||||||
TF_RETURN_IF_ERROR(Env::Default()->NewWritableFile(tmp_filename, &file));
|
|
||||||
|
|
||||||
auto writer = absl::make_unique<SnapshotWriter>(file.get());
|
|
||||||
TF_RETURN_IF_ERROR(writer->WriteRecord(metadata.SerializeAsString()));
|
|
||||||
TF_RETURN_IF_ERROR(writer->Close());
|
|
||||||
TF_RETURN_IF_ERROR(file->Sync());
|
|
||||||
TF_RETURN_IF_ERROR(file->Close());
|
|
||||||
|
|
||||||
TF_RETURN_IF_ERROR(
|
|
||||||
Env::Default()->RenameFile(tmp_filename, metadata_filename));
|
|
||||||
|
|
||||||
return Status::OK();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
Status ReadMetadataFile(const string& hash_dir,
|
Status ReadMetadataFile(const string& hash_dir,
|
||||||
experimental::SnapshotMetadataRecord* metadata) {
|
experimental::SnapshotMetadataRecord* metadata) {
|
||||||
string metadata_filename = io::JoinPath(hash_dir, kSnapshotFilename);
|
string metadata_filename = io::JoinPath(hash_dir, kSnapshotFilename);
|
||||||
TF_RETURN_IF_ERROR(Env::Default()->FileExists(metadata_filename));
|
TF_RETURN_IF_ERROR(Env::Default()->FileExists(metadata_filename));
|
||||||
|
return ReadBinaryProto(Env::Default(), metadata_filename, metadata);
|
||||||
std::unique_ptr<RandomAccessFile> file;
|
|
||||||
TF_RETURN_IF_ERROR(
|
|
||||||
Env::Default()->NewRandomAccessFile(metadata_filename, &file));
|
|
||||||
|
|
||||||
tstring record_bytes;
|
|
||||||
SnapshotReader reader(file.get());
|
|
||||||
TF_RETURN_IF_ERROR(reader.ReadRecord(&record_bytes));
|
|
||||||
|
|
||||||
metadata->ParseFromArray(record_bytes.data(), record_bytes.size());
|
|
||||||
return Status::OK();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
Status DumpDatasetGraph(const std::string& path, uint64 hash,
|
Status DumpDatasetGraph(const std::string& path, uint64 hash,
|
||||||
|
@ -332,6 +540,10 @@ Status DetermineOpState(const std::string& mode_string,
|
||||||
const uint64 pending_snapshot_expiry_seconds,
|
const uint64 pending_snapshot_expiry_seconds,
|
||||||
SnapshotMode* mode) {
|
SnapshotMode* mode) {
|
||||||
if (mode_string == kModeRead) {
|
if (mode_string == kModeRead) {
|
||||||
|
// In read mode, we should expect a metadata file is written.
|
||||||
|
if (errors::IsNotFound(file_status)) {
|
||||||
|
return file_status;
|
||||||
|
}
|
||||||
LOG(INFO) << "Overriding mode to reader.";
|
LOG(INFO) << "Overriding mode to reader.";
|
||||||
*mode = READER;
|
*mode = READER;
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
|
@ -727,10 +939,25 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
|
||||||
if (run_id.empty()) {
|
if (run_id.empty()) {
|
||||||
run_id = metadata.run_id();
|
run_id = metadata.run_id();
|
||||||
}
|
}
|
||||||
|
// dtypes in metadata should be the same as dataset()->output_dtypes
|
||||||
|
if (metadata.dtype_size() != dataset()->output_dtypes().size()) {
|
||||||
|
return errors::Internal(
|
||||||
|
"Expected number of dtypes: ",
|
||||||
|
dataset()->output_dtypes().size(),
|
||||||
|
" but number in snapshot: ", metadata.dtype_size());
|
||||||
|
}
|
||||||
|
for (int i = 0; i < metadata.dtype_size(); ++i) {
|
||||||
|
if (metadata.dtype(i) != dataset()->output_dtypes()[i]) {
|
||||||
|
return errors::Internal(
|
||||||
|
"Type: ", i,
|
||||||
|
" doesn't match. Snapshot: ", metadata.dtype(i),
|
||||||
|
"; dataset: ", dataset()->output_dtypes()[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
iterator_ = absl::make_unique<SnapshotReaderIterator>(
|
iterator_ = absl::make_unique<SnapshotReaderIterator>(
|
||||||
SnapshotReaderIterator::Params{
|
SnapshotReaderIterator::Params{
|
||||||
dataset(), absl::StrCat(prefix(), "ReaderImpl")},
|
dataset(), absl::StrCat(prefix(), "ReaderImpl")},
|
||||||
hash_dir_, run_id);
|
hash_dir_, run_id, metadata.version());
|
||||||
break;
|
break;
|
||||||
case PASSTHROUGH:
|
case PASSTHROUGH:
|
||||||
iterator_ = absl::make_unique<SnapshotPassthroughIterator>(
|
iterator_ = absl::make_unique<SnapshotPassthroughIterator>(
|
||||||
|
@ -748,10 +975,11 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
|
||||||
|
|
||||||
explicit SnapshotReaderIterator(const Params& params,
|
explicit SnapshotReaderIterator(const Params& params,
|
||||||
const string& hash_dir,
|
const string& hash_dir,
|
||||||
const string& run_id)
|
const string& run_id, int64 version)
|
||||||
: DatasetIterator<Dataset>(params),
|
: DatasetIterator<Dataset>(params),
|
||||||
hash_dir_(hash_dir),
|
hash_dir_(hash_dir),
|
||||||
run_id_(run_id) {}
|
run_id_(run_id),
|
||||||
|
version_(version) {}
|
||||||
|
|
||||||
~SnapshotReaderIterator() override {
|
~SnapshotReaderIterator() override {
|
||||||
mutex_lock l(mu_);
|
mutex_lock l(mu_);
|
||||||
|
@ -889,6 +1117,8 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
|
||||||
writer->WriteScalar(full_name(kHashDir), hash_dir_));
|
writer->WriteScalar(full_name(kHashDir), hash_dir_));
|
||||||
TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kRunId), run_id_));
|
TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kRunId), run_id_));
|
||||||
TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kRunDir), run_dir_));
|
TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kRunDir), run_dir_));
|
||||||
|
TF_RETURN_IF_ERROR(
|
||||||
|
writer->WriteScalar(full_name(kVersionStr), version_));
|
||||||
TF_RETURN_IF_ERROR(writer->WriteScalar(
|
TF_RETURN_IF_ERROR(writer->WriteScalar(
|
||||||
full_name(strings::StrCat(kFilenames, kSizeSuffix)),
|
full_name(strings::StrCat(kFilenames, kSizeSuffix)),
|
||||||
filenames_.size()));
|
filenames_.size()));
|
||||||
|
@ -932,6 +1162,8 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
|
||||||
}
|
}
|
||||||
TF_RETURN_IF_ERROR(reader->ReadScalar(full_name(kRunId), &run_id_));
|
TF_RETURN_IF_ERROR(reader->ReadScalar(full_name(kRunId), &run_id_));
|
||||||
TF_RETURN_IF_ERROR(reader->ReadScalar(full_name(kRunDir), &run_dir_));
|
TF_RETURN_IF_ERROR(reader->ReadScalar(full_name(kRunDir), &run_dir_));
|
||||||
|
TF_RETURN_IF_ERROR(
|
||||||
|
reader->ReadScalar(full_name(kVersionStr), &version_));
|
||||||
curr_filenames_.clear();
|
curr_filenames_.clear();
|
||||||
curr_filenames_.reserve(dataset()->num_reader_threads_);
|
curr_filenames_.reserve(dataset()->num_reader_threads_);
|
||||||
for (auto i = 0; i < dataset()->num_reader_threads_; ++i) {
|
for (auto i = 0; i < dataset()->num_reader_threads_; ++i) {
|
||||||
|
@ -986,7 +1218,8 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
|
||||||
std::unique_ptr<RandomAccessFile> file;
|
std::unique_ptr<RandomAccessFile> file;
|
||||||
TF_RETURN_IF_ERROR(
|
TF_RETURN_IF_ERROR(
|
||||||
Env::Default()->NewRandomAccessFile(filename, &file));
|
Env::Default()->NewRandomAccessFile(filename, &file));
|
||||||
SnapshotReader reader(file.get(), dataset()->compression_);
|
SnapshotReader reader(file.get(), dataset()->compression_, version_,
|
||||||
|
dataset()->output_dtypes());
|
||||||
|
|
||||||
while (true) {
|
while (true) {
|
||||||
// Wait for a slot in the buffer.
|
// Wait for a slot in the buffer.
|
||||||
|
@ -1003,30 +1236,14 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
|
||||||
"ReadFile");
|
"ReadFile");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
#if !defined(PLATFORM_GOOGLE)
|
std::vector<Tensor> read_tensors;
|
||||||
tstring record_bytes;
|
Status s = reader.ReadTensors(&read_tensors);
|
||||||
Status s = reader.ReadRecord(&record_bytes);
|
|
||||||
#else
|
|
||||||
absl::Cord record_cord;
|
|
||||||
Status s = reader.ReadRecord(&record_cord);
|
|
||||||
#endif
|
|
||||||
if (s.ok()) {
|
if (s.ok()) {
|
||||||
profiler::TraceMe activity(
|
profiler::TraceMe activity(
|
||||||
[&]() { return absl::StrCat(prefix(), kSeparator, kParse); },
|
[&]() { return absl::StrCat(prefix(), kSeparator, kParse); },
|
||||||
profiler::TraceMeLevel::kInfo);
|
profiler::TraceMeLevel::kInfo);
|
||||||
experimental::SnapshotRecord record;
|
|
||||||
#if !defined(PLATFORM_GOOGLE)
|
|
||||||
record.ParseFromArray(record_bytes.data(), record_bytes.size());
|
|
||||||
#else
|
|
||||||
record.ParseFromCord(record_cord);
|
|
||||||
#endif
|
|
||||||
BufferElement elem;
|
BufferElement elem;
|
||||||
for (int i = 0; i < record.tensor_size(); ++i) {
|
elem.value = std::move(read_tensors);
|
||||||
elem.value.emplace_back();
|
|
||||||
if (!elem.value.back().FromProto(record.tensor(i))) {
|
|
||||||
return errors::DataLoss("Unable to parse tensor from proto.");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
elem.status = Status::OK();
|
elem.status = Status::OK();
|
||||||
mutex_lock l(mu_);
|
mutex_lock l(mu_);
|
||||||
buffer_.push_back(std::move(elem));
|
buffer_.push_back(std::move(elem));
|
||||||
|
@ -1142,9 +1359,9 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
|
||||||
condition_variable cond_var_;
|
condition_variable cond_var_;
|
||||||
|
|
||||||
const string hash_dir_;
|
const string hash_dir_;
|
||||||
const experimental::SnapshotMetadataRecord metadata_;
|
|
||||||
tstring run_id_ GUARDED_BY(mu_);
|
tstring run_id_ GUARDED_BY(mu_);
|
||||||
tstring run_dir_ GUARDED_BY(mu_);
|
tstring run_dir_ GUARDED_BY(mu_);
|
||||||
|
int64 version_;
|
||||||
std::vector<tstring> filenames_;
|
std::vector<tstring> filenames_;
|
||||||
|
|
||||||
uint64 elements_produced_ GUARDED_BY(mu_) = 0;
|
uint64 elements_produced_ GUARDED_BY(mu_) = 0;
|
||||||
|
@ -1220,6 +1437,10 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
|
||||||
metadata.set_creation_timestamp(EnvTime::NowMicros());
|
metadata.set_creation_timestamp(EnvTime::NowMicros());
|
||||||
metadata.set_graph_hash(dataset()->graph_hash_);
|
metadata.set_graph_hash(dataset()->graph_hash_);
|
||||||
metadata.set_run_id(run_id_.data(), run_id_.size());
|
metadata.set_run_id(run_id_.data(), run_id_.size());
|
||||||
|
metadata.set_version(kCurrentVersion);
|
||||||
|
for (const auto& output_dtype : dataset()->output_dtypes()) {
|
||||||
|
metadata.add_dtype(output_dtype);
|
||||||
|
}
|
||||||
metadata.set_finalized(false);
|
metadata.set_finalized(false);
|
||||||
TF_RETURN_IF_ERROR(WriteMetadataFile(hash_dir_, metadata));
|
TF_RETURN_IF_ERROR(WriteMetadataFile(hash_dir_, metadata));
|
||||||
}
|
}
|
||||||
|
@ -1564,11 +1785,8 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
|
||||||
}
|
}
|
||||||
|
|
||||||
if (produced_elem) {
|
if (produced_elem) {
|
||||||
experimental::SnapshotRecord record;
|
|
||||||
for (const auto& out_tensor : elem.value) {
|
for (const auto& out_tensor : elem.value) {
|
||||||
*bytes_written += out_tensor.TotalBytes();
|
*bytes_written += out_tensor.TotalBytes();
|
||||||
TensorProto* t = record.add_tensor();
|
|
||||||
out_tensor.AsProtoTensorContent(t);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
bool should_close;
|
bool should_close;
|
||||||
|
@ -1584,16 +1802,11 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
|
||||||
TF_RETURN_IF_ERROR(Env::Default()->NewAppendableFile(
|
TF_RETURN_IF_ERROR(Env::Default()->NewAppendableFile(
|
||||||
*snapshot_data_filename, file));
|
*snapshot_data_filename, file));
|
||||||
*writer = absl::make_unique<SnapshotWriter>(
|
*writer = absl::make_unique<SnapshotWriter>(
|
||||||
file->get(), dataset()->compression_);
|
file->get(), dataset()->compression_, kCurrentVersion,
|
||||||
|
dataset()->output_dtypes());
|
||||||
*bytes_written = 0;
|
*bytes_written = 0;
|
||||||
}
|
}
|
||||||
#if defined(PLATFORM_GOOGLE)
|
TF_RETURN_IF_ERROR((*writer)->WriteTensors(elem.value));
|
||||||
TF_RETURN_IF_ERROR(
|
|
||||||
(*writer)->WriteRecord(record.SerializeAsCord()));
|
|
||||||
#else // PLATFORM_GOOGLE
|
|
||||||
TF_RETURN_IF_ERROR(
|
|
||||||
(*writer)->WriteRecord(record.SerializeAsString()));
|
|
||||||
#endif // PLATFORM_GOOGLE
|
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1641,7 +1854,8 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
std::unique_ptr<SnapshotWriter> writer(
|
std::unique_ptr<SnapshotWriter> writer(
|
||||||
new SnapshotWriter(file.get(), dataset()->compression_));
|
new SnapshotWriter(file.get(), dataset()->compression_,
|
||||||
|
kCurrentVersion, dataset()->output_dtypes()));
|
||||||
|
|
||||||
bool end_of_processing = false;
|
bool end_of_processing = false;
|
||||||
while (!end_of_processing) {
|
while (!end_of_processing) {
|
||||||
|
|
|
@ -332,6 +332,16 @@ bool Snappy_Uncompress(const char* input, size_t length, char* output) {
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool Snappy_UncompressToIOVec(const char* compressed, size_t compressed_length,
|
||||||
|
const struct iovec* iov, size_t iov_cnt) {
|
||||||
|
#ifdef TF_USE_SNAPPY
|
||||||
|
return snappy::RawUncompressToIOVec(compressed, compressed_length, iov,
|
||||||
|
iov_cnt);
|
||||||
|
#else
|
||||||
|
return false;
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
string Demangle(const char* mangled) { return mangled; }
|
string Demangle(const char* mangled) { return mangled; }
|
||||||
|
|
||||||
double NominalCPUFrequency() {
|
double NominalCPUFrequency() {
|
||||||
|
|
|
@ -18,6 +18,17 @@ limitations under the License.
|
||||||
|
|
||||||
#include "tensorflow/core/platform/types.h"
|
#include "tensorflow/core/platform/types.h"
|
||||||
|
|
||||||
|
#if !defined(PLATFORM_WINDOWS)
|
||||||
|
#include <sys/uio.h>
|
||||||
|
#else
|
||||||
|
namespace tensorflow {
|
||||||
|
struct iovec {
|
||||||
|
void* iov_base;
|
||||||
|
size_t iov_len;
|
||||||
|
};
|
||||||
|
} // namespace tensorflow
|
||||||
|
#endif
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
namespace port {
|
namespace port {
|
||||||
|
|
||||||
|
@ -28,6 +39,9 @@ bool Snappy_GetUncompressedLength(const char* input, size_t length,
|
||||||
size_t* result);
|
size_t* result);
|
||||||
bool Snappy_Uncompress(const char* input, size_t length, char* output);
|
bool Snappy_Uncompress(const char* input, size_t length, char* output);
|
||||||
|
|
||||||
|
bool Snappy_UncompressToIOVec(const char* compressed, size_t compressed_length,
|
||||||
|
const struct iovec* iov, size_t iov_cnt);
|
||||||
|
|
||||||
} // namespace port
|
} // namespace port
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
|
||||||
|
|
|
@ -157,6 +157,17 @@ bool Snappy_Uncompress(const char* input, size_t length, char* output) {
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool Snappy_UncompressToIOVec(const char* compressed, size_t compressed_length,
|
||||||
|
const struct iovec* iov, size_t iov_cnt) {
|
||||||
|
#ifdef TF_USE_SNAPPY
|
||||||
|
const snappy::iovec* snappy_iov = reinterpret_cast<const snappy::iovec*>(iov);
|
||||||
|
return snappy::RawUncompressToIOVec(compressed, compressed_length, snappy_iov,
|
||||||
|
iov_cnt);
|
||||||
|
#else
|
||||||
|
return false;
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
string Demangle(const char* mangled) { return mangled; }
|
string Demangle(const char* mangled) { return mangled; }
|
||||||
|
|
||||||
double NominalCPUFrequency() {
|
double NominalCPUFrequency() {
|
||||||
|
|
|
@ -3,6 +3,8 @@ syntax = "proto3";
|
||||||
package tensorflow.data.experimental;
|
package tensorflow.data.experimental;
|
||||||
|
|
||||||
import "tensorflow/core/framework/tensor.proto";
|
import "tensorflow/core/framework/tensor.proto";
|
||||||
|
import "tensorflow/core/framework/tensor_shape.proto";
|
||||||
|
import "tensorflow/core/framework/types.proto";
|
||||||
|
|
||||||
// Each SnapshotRecord represents one batch of pre-processed input data. A batch
|
// Each SnapshotRecord represents one batch of pre-processed input data. A batch
|
||||||
// consists of a list of tensors that we encode as TensorProtos. This message
|
// consists of a list of tensors that we encode as TensorProtos. This message
|
||||||
|
@ -13,9 +15,29 @@ message SnapshotRecord {
|
||||||
|
|
||||||
// This stores the metadata information present in each snapshot record.
|
// This stores the metadata information present in each snapshot record.
|
||||||
message SnapshotMetadataRecord {
|
message SnapshotMetadataRecord {
|
||||||
|
// Stores the fingerprint of the graph that describes the dataset that is
|
||||||
|
// snapshotted.
|
||||||
string graph_hash = 1;
|
string graph_hash = 1;
|
||||||
|
// Run ID that this snapshot corresponds to.
|
||||||
string run_id = 2;
|
string run_id = 2;
|
||||||
|
// Time when we started creating this snapshot.
|
||||||
int64 creation_timestamp = 3;
|
int64 creation_timestamp = 3;
|
||||||
|
// Version of the snapshot data file format.
|
||||||
|
int64 version = 4;
|
||||||
|
// A list of tensor dtype corresponding to each element of the snapshot.
|
||||||
|
repeated .tensorflow.DataType dtype = 5;
|
||||||
|
|
||||||
bool finalized = 1000;
|
bool finalized = 1000;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Metadata for a single tensor in the Snapshot Record.
|
||||||
|
message TensorMetadata {
|
||||||
|
.tensorflow.TensorShapeProto tensor_shape = 2;
|
||||||
|
// Number of uncompressed bytes used to store the tensor representation.
|
||||||
|
int64 tensor_size_bytes = 3;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Metadata for all the tensors in a Snapshot Record.
|
||||||
|
message SnapshotTensorMetadata {
|
||||||
|
repeated TensorMetadata tensor_metadata = 1;
|
||||||
|
}
|
||||||
|
|
|
@ -161,17 +161,49 @@ class SnapshotDatasetTest(reader_dataset_ops_test_base.TFRecordDatasetTestBase,
|
||||||
|
|
||||||
self.assertSnapshotDirectoryContains(tmpdir, 1, 1, 1)
|
self.assertSnapshotDirectoryContains(tmpdir, 1, 1, 1)
|
||||||
|
|
||||||
@combinations.generate(test_base.default_test_combinations())
|
@combinations.generate(
|
||||||
def testWriteSnapshotRepeatAfterwards(self):
|
combinations.times(
|
||||||
|
test_base.default_test_combinations(),
|
||||||
|
combinations.combine(compression=[
|
||||||
|
snapshot.COMPRESSION_NONE, snapshot.COMPRESSION_GZIP,
|
||||||
|
snapshot.COMPRESSION_SNAPPY
|
||||||
|
])))
|
||||||
|
def testWriteSnapshotRepeatAfterwards(self, compression):
|
||||||
tmpdir = self.snapshot_dir
|
tmpdir = self.snapshot_dir
|
||||||
|
|
||||||
dataset = dataset_ops.Dataset.range(10)
|
dataset = dataset_ops.Dataset.range(10)
|
||||||
dataset = dataset.apply(snapshot.snapshot(tmpdir))
|
dataset = dataset.apply(snapshot.snapshot(tmpdir, compression=compression))
|
||||||
dataset = dataset.repeat(10)
|
dataset = dataset.repeat(10)
|
||||||
self.assertDatasetProduces(dataset, list(range(10)) * 10)
|
self.assertDatasetProduces(dataset, list(range(10)) * 10)
|
||||||
|
|
||||||
self.assertSnapshotDirectoryContains(tmpdir, 1, 1, 1)
|
self.assertSnapshotDirectoryContains(tmpdir, 1, 1, 1)
|
||||||
|
|
||||||
|
@combinations.generate(
|
||||||
|
combinations.times(
|
||||||
|
test_base.default_test_combinations(),
|
||||||
|
combinations.combine(compression=[
|
||||||
|
snapshot.COMPRESSION_NONE, snapshot.COMPRESSION_GZIP,
|
||||||
|
snapshot.COMPRESSION_SNAPPY
|
||||||
|
])))
|
||||||
|
def testWriteSnapshotMixTypes(self, compression):
|
||||||
|
tmpdir = self.snapshot_dir
|
||||||
|
|
||||||
|
dataset = dataset_ops.Dataset.range(10)
|
||||||
|
|
||||||
|
def map_fn(x):
|
||||||
|
return (x, string_ops.as_string(x), string_ops.as_string(2 * x), 2 * x)
|
||||||
|
|
||||||
|
dataset = dataset.map(map_fn)
|
||||||
|
dataset = dataset.apply(snapshot.snapshot(tmpdir, compression=compression))
|
||||||
|
dataset = dataset.repeat(10)
|
||||||
|
|
||||||
|
expected = []
|
||||||
|
for i in range(10):
|
||||||
|
expected.append((i, str(i), str(2 * i), 2 * i))
|
||||||
|
self.assertDatasetProduces(dataset, expected * 10)
|
||||||
|
|
||||||
|
self.assertSnapshotDirectoryContains(tmpdir, 1, 1, 1)
|
||||||
|
|
||||||
@combinations.generate(test_base.default_test_combinations())
|
@combinations.generate(test_base.default_test_combinations())
|
||||||
def testSpecifySnapshotNameWriteAndRead(self):
|
def testSpecifySnapshotNameWriteAndRead(self):
|
||||||
tmpdir = self.snapshot_dir
|
tmpdir = self.snapshot_dir
|
||||||
|
@ -365,8 +397,14 @@ class SnapshotDatasetTest(reader_dataset_ops_test_base.TFRecordDatasetTestBase,
|
||||||
res3 = self.evaluate(next3())
|
res3 = self.evaluate(next3())
|
||||||
self.assertEqual(res2, res3)
|
self.assertEqual(res2, res3)
|
||||||
|
|
||||||
@combinations.generate(test_base.default_test_combinations())
|
@combinations.generate(
|
||||||
def testReadSnapshotParallelAfterWrite(self):
|
combinations.times(
|
||||||
|
test_base.default_test_combinations(),
|
||||||
|
combinations.combine(compression=[
|
||||||
|
snapshot.COMPRESSION_NONE, snapshot.COMPRESSION_GZIP,
|
||||||
|
snapshot.COMPRESSION_SNAPPY
|
||||||
|
])))
|
||||||
|
def testReadSnapshotParallelAfterWrite(self, compression):
|
||||||
self.setUpTFRecord(10, 4000)
|
self.setUpTFRecord(10, 4000)
|
||||||
filenames = self.test_filenames
|
filenames = self.test_filenames
|
||||||
|
|
||||||
|
@ -383,7 +421,8 @@ class SnapshotDatasetTest(reader_dataset_ops_test_base.TFRecordDatasetTestBase,
|
||||||
tmpdir,
|
tmpdir,
|
||||||
shard_size_bytes=1024 * 1024,
|
shard_size_bytes=1024 * 1024,
|
||||||
num_reader_threads=2,
|
num_reader_threads=2,
|
||||||
reader_buffer_size=10))
|
reader_buffer_size=10,
|
||||||
|
compression=compression))
|
||||||
self.assertDatasetProduces(dataset, expected, assert_items_equal=True)
|
self.assertDatasetProduces(dataset, expected, assert_items_equal=True)
|
||||||
|
|
||||||
# remove the original files and try to read the data back only from
|
# remove the original files and try to read the data back only from
|
||||||
|
@ -396,7 +435,8 @@ class SnapshotDatasetTest(reader_dataset_ops_test_base.TFRecordDatasetTestBase,
|
||||||
tmpdir,
|
tmpdir,
|
||||||
shard_size_bytes=1024 * 1024,
|
shard_size_bytes=1024 * 1024,
|
||||||
num_reader_threads=2,
|
num_reader_threads=2,
|
||||||
reader_buffer_size=10))
|
reader_buffer_size=10,
|
||||||
|
compression=compression))
|
||||||
self.assertDatasetProduces(dataset2, expected, assert_items_equal=True)
|
self.assertDatasetProduces(dataset2, expected, assert_items_equal=True)
|
||||||
|
|
||||||
# Not testing Snappy here because Snappy reads currently require a lot of
|
# Not testing Snappy here because Snappy reads currently require a lot of
|
||||||
|
@ -514,21 +554,31 @@ class SnapshotDatasetTest(reader_dataset_ops_test_base.TFRecordDatasetTestBase,
|
||||||
self.evaluate(next2())
|
self.evaluate(next2())
|
||||||
self.assertSnapshotDirectoryContains(tmpdir, 2, 1, 1)
|
self.assertSnapshotDirectoryContains(tmpdir, 2, 1, 1)
|
||||||
|
|
||||||
@combinations.generate(test_base.default_test_combinations())
|
@combinations.generate(
|
||||||
def testSpecifyShardSize(self):
|
combinations.times(
|
||||||
|
test_base.default_test_combinations(),
|
||||||
|
combinations.combine(compression=[
|
||||||
|
snapshot.COMPRESSION_NONE, snapshot.COMPRESSION_GZIP,
|
||||||
|
snapshot.COMPRESSION_SNAPPY
|
||||||
|
])))
|
||||||
|
def testSpecifyShardSize(self, compression):
|
||||||
tmpdir = self.snapshot_dir
|
tmpdir = self.snapshot_dir
|
||||||
|
|
||||||
dataset = dataset_ops.Dataset.from_tensor_slices([1.0])
|
dataset = dataset_ops.Dataset.from_tensor_slices([1.0])
|
||||||
dataset = dataset.map(lambda x: gen_array_ops.broadcast_to(x, [1024, 1024]))
|
dataset = dataset.map(lambda x: gen_array_ops.broadcast_to(x, [1024, 1024]))
|
||||||
dataset = dataset.repeat(10)
|
dataset = dataset.repeat(10)
|
||||||
dataset = dataset.apply(
|
dataset = dataset.apply(
|
||||||
snapshot.snapshot(tmpdir, shard_size_bytes=10 * 1024 * 1024))
|
snapshot.snapshot(
|
||||||
|
tmpdir, shard_size_bytes=10 * 1024 * 1024, compression=compression))
|
||||||
next_fn = self.getNext(dataset)
|
next_fn = self.getNext(dataset)
|
||||||
|
|
||||||
for _ in range(10):
|
for _ in range(10):
|
||||||
self.evaluate(next_fn())
|
self.evaluate(next_fn())
|
||||||
|
|
||||||
self.assertSnapshotDirectoryContains(tmpdir, 1, 1, 3)
|
num_files = 1
|
||||||
|
if compression == snapshot.COMPRESSION_NONE:
|
||||||
|
num_files = 3
|
||||||
|
self.assertSnapshotDirectoryContains(tmpdir, 1, 1, num_files)
|
||||||
|
|
||||||
@combinations.generate(test_base.default_test_combinations())
|
@combinations.generate(test_base.default_test_combinations())
|
||||||
def testAdditionalOperationsAfterReadBack(self):
|
def testAdditionalOperationsAfterReadBack(self):
|
||||||
|
|
|
@ -27,6 +27,10 @@ cc_library(
|
||||||
"-Wno-implicit-function-declaration",
|
"-Wno-implicit-function-declaration",
|
||||||
],
|
],
|
||||||
}),
|
}),
|
||||||
|
defines = select({
|
||||||
|
"@org_tensorflow//tensorflow:windows": [],
|
||||||
|
"//conditions:default": ["HAVE_SYS_UIO_H"],
|
||||||
|
}),
|
||||||
)
|
)
|
||||||
|
|
||||||
genrule(
|
genrule(
|
||||||
|
|
Loading…
Reference in New Issue