Change namespace for snapshot utilities to snapshot_util from experimental
PiperOrigin-RevId: 299197783 Change-Id: I8c2f33e319379df7b69246b1ae3116e17a5d2039
This commit is contained in:
parent
059c3253d9
commit
94c04fc77e
@ -125,7 +125,7 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
|
|||||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("seed", &seed_));
|
OP_REQUIRES_OK(ctx, ctx->GetAttr("seed", &seed_));
|
||||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("seed2", &seed2_));
|
OP_REQUIRES_OK(ctx, ctx->GetAttr("seed2", &seed2_));
|
||||||
|
|
||||||
mode_ = kSnapshotModeAuto;
|
mode_ = snapshot_util::kModeAuto;
|
||||||
if (ctx->HasAttr("mode")) {
|
if (ctx->HasAttr("mode")) {
|
||||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("mode", &mode_));
|
OP_REQUIRES_OK(ctx, ctx->GetAttr("mode", &mode_));
|
||||||
}
|
}
|
||||||
@ -160,14 +160,15 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
|
|||||||
errors::InvalidArgument(
|
errors::InvalidArgument(
|
||||||
"pending_snapshot_expiry_seconds must be at least 1 second."));
|
"pending_snapshot_expiry_seconds must be at least 1 second."));
|
||||||
|
|
||||||
OP_REQUIRES(
|
OP_REQUIRES(ctx,
|
||||||
ctx,
|
mode_ == snapshot_util::kModeAuto ||
|
||||||
mode_ == kSnapshotModeAuto || mode_ == kSnapshotModeRead ||
|
mode_ == snapshot_util::kModeRead ||
|
||||||
mode_ == kSnapshotModeWrite || mode_ == kSnapshotModePassthrough,
|
mode_ == snapshot_util::kModeWrite ||
|
||||||
errors::InvalidArgument("mode must be either '", kSnapshotModeAuto,
|
mode_ == snapshot_util::kModePassthrough,
|
||||||
"', '", kSnapshotModeRead, "', '",
|
errors::InvalidArgument(
|
||||||
kSnapshotModeWrite, "', or '",
|
"mode must be either '", snapshot_util::kModeAuto, "', '",
|
||||||
kSnapshotModePassthrough, "'."));
|
snapshot_util::kModeRead, "', '", snapshot_util::kModeWrite,
|
||||||
|
"', or '", snapshot_util::kModePassthrough, "'."));
|
||||||
}
|
}
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
@ -190,7 +191,8 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
|
|||||||
uint64 hash;
|
uint64 hash;
|
||||||
OP_REQUIRES_OK(ctx, ComputeDatasetHash(graph_def, path, &hash));
|
OP_REQUIRES_OK(ctx, ComputeDatasetHash(graph_def, path, &hash));
|
||||||
|
|
||||||
Status dump_status = SnapshotDumpDatasetGraph(path, hash, &graph_def);
|
Status dump_status =
|
||||||
|
snapshot_util::DumpDatasetGraph(path, hash, &graph_def);
|
||||||
if (!dump_status.ok()) {
|
if (!dump_status.ok()) {
|
||||||
LOG(WARNING) << "Unable to write graphdef to disk, error: "
|
LOG(WARNING) << "Unable to write graphdef to disk, error: "
|
||||||
<< dump_status.ToString();
|
<< dump_status.ToString();
|
||||||
@ -376,8 +378,8 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
|
|||||||
mutex_lock l(mu_);
|
mutex_lock l(mu_);
|
||||||
if (iterator_ == nullptr) {
|
if (iterator_ == nullptr) {
|
||||||
experimental::SnapshotMetadataRecord metadata;
|
experimental::SnapshotMetadataRecord metadata;
|
||||||
Status s = SnapshotReadMetadataFile(hash_dir_, &metadata);
|
Status s = snapshot_util::ReadMetadataFile(hash_dir_, &metadata);
|
||||||
TF_RETURN_IF_ERROR(SnapshotDetermineOpState(
|
TF_RETURN_IF_ERROR(snapshot_util::DetermineOpState(
|
||||||
dataset()->mode_, s, &metadata,
|
dataset()->mode_, s, &metadata,
|
||||||
dataset()->pending_snapshot_expiry_seconds_, &state_));
|
dataset()->pending_snapshot_expiry_seconds_, &state_));
|
||||||
VLOG(2) << "Snapshot state: " << state_;
|
VLOG(2) << "Snapshot state: " << state_;
|
||||||
@ -411,10 +413,11 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
|
|||||||
{
|
{
|
||||||
int64 temp;
|
int64 temp;
|
||||||
TF_RETURN_IF_ERROR(reader->ReadScalar(full_name(kState), &temp));
|
TF_RETURN_IF_ERROR(reader->ReadScalar(full_name(kState), &temp));
|
||||||
state_ = SnapshotMode(temp);
|
state_ = snapshot_util::Mode(temp);
|
||||||
}
|
}
|
||||||
experimental::SnapshotMetadataRecord metadata;
|
experimental::SnapshotMetadataRecord metadata;
|
||||||
TF_RETURN_IF_ERROR(SnapshotReadMetadataFile(hash_dir_, &metadata));
|
TF_RETURN_IF_ERROR(
|
||||||
|
snapshot_util::ReadMetadataFile(hash_dir_, &metadata));
|
||||||
TF_RETURN_IF_ERROR(InitializeIterator(ctx, metadata));
|
TF_RETURN_IF_ERROR(InitializeIterator(ctx, metadata));
|
||||||
VLOG(2) << "Restoring Snapshot iterator: " << state_;
|
VLOG(2) << "Restoring Snapshot iterator: " << state_;
|
||||||
return RestoreInput(ctx, reader, iterator_);
|
return RestoreInput(ctx, reader, iterator_);
|
||||||
@ -434,13 +437,13 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
|
|||||||
}
|
}
|
||||||
|
|
||||||
switch (state_) {
|
switch (state_) {
|
||||||
case WRITER:
|
case snapshot_util::WRITER:
|
||||||
iterator_ = absl::make_unique<SnapshotWriterIterator>(
|
iterator_ = absl::make_unique<SnapshotWriterIterator>(
|
||||||
SnapshotWriterIterator::Params{
|
SnapshotWriterIterator::Params{
|
||||||
dataset(), absl::StrCat(prefix(), "WriterImpl")},
|
dataset(), absl::StrCat(prefix(), "WriterImpl")},
|
||||||
hash_dir_, run_id);
|
hash_dir_, run_id);
|
||||||
break;
|
break;
|
||||||
case READER:
|
case snapshot_util::READER:
|
||||||
if (run_id.empty() && metadata.run_id().empty()) {
|
if (run_id.empty() && metadata.run_id().empty()) {
|
||||||
return errors::NotFound(
|
return errors::NotFound(
|
||||||
"Could not find a valid snapshot to read.");
|
"Could not find a valid snapshot to read.");
|
||||||
@ -468,7 +471,7 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
|
|||||||
dataset(), absl::StrCat(prefix(), "ReaderImpl")},
|
dataset(), absl::StrCat(prefix(), "ReaderImpl")},
|
||||||
hash_dir_, run_id, metadata.version());
|
hash_dir_, run_id, metadata.version());
|
||||||
break;
|
break;
|
||||||
case PASSTHROUGH:
|
case snapshot_util::PASSTHROUGH:
|
||||||
iterator_ = absl::make_unique<SnapshotPassthroughIterator>(
|
iterator_ = absl::make_unique<SnapshotPassthroughIterator>(
|
||||||
SnapshotPassthroughIterator::Params{
|
SnapshotPassthroughIterator::Params{
|
||||||
dataset(), absl::StrCat(prefix(), "PassthroughImpl")});
|
dataset(), absl::StrCat(prefix(), "PassthroughImpl")});
|
||||||
@ -727,8 +730,8 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
|
|||||||
std::unique_ptr<RandomAccessFile> file;
|
std::unique_ptr<RandomAccessFile> file;
|
||||||
TF_RETURN_IF_ERROR(
|
TF_RETURN_IF_ERROR(
|
||||||
Env::Default()->NewRandomAccessFile(filename, &file));
|
Env::Default()->NewRandomAccessFile(filename, &file));
|
||||||
SnapshotReader reader(file.get(), dataset()->compression_, version_,
|
snapshot_util::Reader reader(file.get(), dataset()->compression_,
|
||||||
dataset()->output_dtypes());
|
version_, dataset()->output_dtypes());
|
||||||
|
|
||||||
while (true) {
|
while (true) {
|
||||||
// Wait for a slot in the buffer.
|
// Wait for a slot in the buffer.
|
||||||
@ -953,7 +956,7 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
|
|||||||
}
|
}
|
||||||
metadata.set_finalized(false);
|
metadata.set_finalized(false);
|
||||||
TF_RETURN_IF_ERROR(
|
TF_RETURN_IF_ERROR(
|
||||||
SnapshotWriteMetadataFile(hash_dir_, &metadata));
|
snapshot_util::WriteMetadataFile(hash_dir_, &metadata));
|
||||||
}
|
}
|
||||||
for (int i = 0; i < dataset()->num_writer_threads_; ++i) {
|
for (int i = 0; i < dataset()->num_writer_threads_; ++i) {
|
||||||
++num_active_threads_;
|
++num_active_threads_;
|
||||||
@ -1252,7 +1255,7 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
|
|||||||
Status ProcessOneElement(int64* bytes_written,
|
Status ProcessOneElement(int64* bytes_written,
|
||||||
string* snapshot_data_filename,
|
string* snapshot_data_filename,
|
||||||
std::unique_ptr<WritableFile>* file,
|
std::unique_ptr<WritableFile>* file,
|
||||||
std::unique_ptr<SnapshotWriter>* writer,
|
std::unique_ptr<snapshot_util::Writer>* writer,
|
||||||
bool* end_of_processing) {
|
bool* end_of_processing) {
|
||||||
profiler::TraceMe activity(
|
profiler::TraceMe activity(
|
||||||
[&]() {
|
[&]() {
|
||||||
@ -1312,7 +1315,7 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
|
|||||||
*snapshot_data_filename = GetSnapshotFilename();
|
*snapshot_data_filename = GetSnapshotFilename();
|
||||||
TF_RETURN_IF_ERROR(Env::Default()->NewAppendableFile(
|
TF_RETURN_IF_ERROR(Env::Default()->NewAppendableFile(
|
||||||
*snapshot_data_filename, file));
|
*snapshot_data_filename, file));
|
||||||
*writer = absl::make_unique<SnapshotWriter>(
|
*writer = absl::make_unique<snapshot_util::Writer>(
|
||||||
file->get(), dataset()->compression_, kCurrentVersion,
|
file->get(), dataset()->compression_, kCurrentVersion,
|
||||||
dataset()->output_dtypes());
|
dataset()->output_dtypes());
|
||||||
*bytes_written = 0;
|
*bytes_written = 0;
|
||||||
@ -1329,12 +1332,12 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
|
|||||||
if (!written_final_metadata_file_) {
|
if (!written_final_metadata_file_) {
|
||||||
experimental::SnapshotMetadataRecord metadata;
|
experimental::SnapshotMetadataRecord metadata;
|
||||||
TF_RETURN_IF_ERROR(
|
TF_RETURN_IF_ERROR(
|
||||||
SnapshotReadMetadataFile(hash_dir_, &metadata));
|
snapshot_util::ReadMetadataFile(hash_dir_, &metadata));
|
||||||
|
|
||||||
if (metadata.run_id() == run_id_) {
|
if (metadata.run_id() == run_id_) {
|
||||||
metadata.set_finalized(true);
|
metadata.set_finalized(true);
|
||||||
TF_RETURN_IF_ERROR(
|
TF_RETURN_IF_ERROR(
|
||||||
SnapshotWriteMetadataFile(hash_dir_, &metadata));
|
snapshot_util::WriteMetadataFile(hash_dir_, &metadata));
|
||||||
} else {
|
} else {
|
||||||
// TODO(frankchn): We lost the race, remove all snapshots.
|
// TODO(frankchn): We lost the race, remove all snapshots.
|
||||||
}
|
}
|
||||||
@ -1366,9 +1369,10 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
|
|||||||
cond_var_.notify_all();
|
cond_var_.notify_all();
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
std::unique_ptr<SnapshotWriter> writer(
|
std::unique_ptr<snapshot_util::Writer> writer(
|
||||||
new SnapshotWriter(file.get(), dataset()->compression_,
|
new snapshot_util::Writer(file.get(), dataset()->compression_,
|
||||||
kCurrentVersion, dataset()->output_dtypes()));
|
kCurrentVersion,
|
||||||
|
dataset()->output_dtypes()));
|
||||||
|
|
||||||
bool end_of_processing = false;
|
bool end_of_processing = false;
|
||||||
while (!end_of_processing) {
|
while (!end_of_processing) {
|
||||||
@ -1389,8 +1393,8 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
|
|||||||
}
|
}
|
||||||
|
|
||||||
Status ShouldCloseFile(const string& filename, uint64 bytes_written,
|
Status ShouldCloseFile(const string& filename, uint64 bytes_written,
|
||||||
SnapshotWriter* writer, WritableFile* file,
|
snapshot_util::Writer* writer,
|
||||||
bool* should_close) {
|
WritableFile* file, bool* should_close) {
|
||||||
// If the compression ratio has been estimated, use it to decide
|
// If the compression ratio has been estimated, use it to decide
|
||||||
// whether the file should be closed. We avoid estimating the
|
// whether the file should be closed. We avoid estimating the
|
||||||
// compression ratio repeatedly because it requires syncing the file,
|
// compression ratio repeatedly because it requires syncing the file,
|
||||||
@ -1490,7 +1494,7 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
|
|||||||
};
|
};
|
||||||
|
|
||||||
string hash_dir_ TF_GUARDED_BY(mu_);
|
string hash_dir_ TF_GUARDED_BY(mu_);
|
||||||
SnapshotMode state_ TF_GUARDED_BY(mu_);
|
snapshot_util::Mode state_ TF_GUARDED_BY(mu_);
|
||||||
std::unique_ptr<IteratorBase> iterator_ TF_GUARDED_BY(mu_);
|
std::unique_ptr<IteratorBase> iterator_ TF_GUARDED_BY(mu_);
|
||||||
|
|
||||||
mutex mu_;
|
mutex mu_;
|
||||||
|
@ -34,11 +34,10 @@ limitations under the License.
|
|||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
namespace data {
|
namespace data {
|
||||||
namespace experimental {
|
namespace snapshot_util {
|
||||||
|
|
||||||
SnapshotWriter::SnapshotWriter(WritableFile* dest,
|
Writer::Writer(WritableFile* dest, const string& compression_type, int version,
|
||||||
const string& compression_type, int version,
|
const DataTypeVector& dtypes)
|
||||||
const DataTypeVector& dtypes)
|
|
||||||
: dest_(dest), compression_type_(compression_type), version_(version) {
|
: dest_(dest), compression_type_(compression_type), version_(version) {
|
||||||
#if defined(IS_SLIM_BUILD)
|
#if defined(IS_SLIM_BUILD)
|
||||||
if (compression_type != io::compression::kNone) {
|
if (compression_type != io::compression::kNone) {
|
||||||
@ -70,7 +69,7 @@ SnapshotWriter::SnapshotWriter(WritableFile* dest,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
Status SnapshotWriter::WriteTensors(const std::vector<Tensor>& tensors) {
|
Status Writer::WriteTensors(const std::vector<Tensor>& tensors) {
|
||||||
if (compression_type_ != io::compression::kSnappy) {
|
if (compression_type_ != io::compression::kSnappy) {
|
||||||
experimental::SnapshotRecord record;
|
experimental::SnapshotRecord record;
|
||||||
for (const auto& tensor : tensors) {
|
for (const auto& tensor : tensors) {
|
||||||
@ -96,11 +95,12 @@ Status SnapshotWriter::WriteTensors(const std::vector<Tensor>& tensors) {
|
|||||||
tensor_buffers.reserve(num_simple_);
|
tensor_buffers.reserve(num_simple_);
|
||||||
std::vector<TensorProto> tensor_protos;
|
std::vector<TensorProto> tensor_protos;
|
||||||
tensor_protos.reserve(num_complex_);
|
tensor_protos.reserve(num_complex_);
|
||||||
SnapshotTensorMetadata metadata;
|
experimental::SnapshotTensorMetadata metadata;
|
||||||
int64 total_size = 0;
|
int64 total_size = 0;
|
||||||
for (int i = 0; i < tensors.size(); ++i) {
|
for (int i = 0; i < tensors.size(); ++i) {
|
||||||
const Tensor& tensor = tensors[i];
|
const Tensor& tensor = tensors[i];
|
||||||
TensorMetadata* tensor_metadata = metadata.add_tensor_metadata();
|
experimental::TensorMetadata* tensor_metadata =
|
||||||
|
metadata.add_tensor_metadata();
|
||||||
tensor.shape().AsProto(tensor_metadata->mutable_tensor_shape());
|
tensor.shape().AsProto(tensor_metadata->mutable_tensor_shape());
|
||||||
int64 size = 0;
|
int64 size = 0;
|
||||||
if (simple_tensor_mask_[i]) {
|
if (simple_tensor_mask_[i]) {
|
||||||
@ -150,9 +150,9 @@ Status SnapshotWriter::WriteTensors(const std::vector<Tensor>& tensors) {
|
|||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
Status SnapshotWriter::Sync() { return dest_->Sync(); }
|
Status Writer::Sync() { return dest_->Sync(); }
|
||||||
|
|
||||||
Status SnapshotWriter::Close() {
|
Status Writer::Close() {
|
||||||
if (dest_is_owned_) {
|
if (dest_is_owned_) {
|
||||||
Status s = dest_->Close();
|
Status s = dest_->Close();
|
||||||
delete dest_;
|
delete dest_;
|
||||||
@ -162,7 +162,7 @@ Status SnapshotWriter::Close() {
|
|||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
SnapshotWriter::~SnapshotWriter() {
|
Writer::~Writer() {
|
||||||
if (dest_ != nullptr) {
|
if (dest_ != nullptr) {
|
||||||
Status s = Close();
|
Status s = Close();
|
||||||
if (!s.ok()) {
|
if (!s.ok()) {
|
||||||
@ -171,7 +171,7 @@ SnapshotWriter::~SnapshotWriter() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
Status SnapshotWriter::WriteRecord(const StringPiece& data) {
|
Status Writer::WriteRecord(const StringPiece& data) {
|
||||||
char header[kHeaderSize];
|
char header[kHeaderSize];
|
||||||
core::EncodeFixed64(header, data.size());
|
core::EncodeFixed64(header, data.size());
|
||||||
TF_RETURN_IF_ERROR(dest_->Append(StringPiece(header, sizeof(header))));
|
TF_RETURN_IF_ERROR(dest_->Append(StringPiece(header, sizeof(header))));
|
||||||
@ -179,7 +179,7 @@ Status SnapshotWriter::WriteRecord(const StringPiece& data) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
#if defined(PLATFORM_GOOGLE)
|
#if defined(PLATFORM_GOOGLE)
|
||||||
Status SnapshotWriter::WriteRecord(const absl::Cord& data) {
|
Status Writer::WriteRecord(const absl::Cord& data) {
|
||||||
char header[kHeaderSize];
|
char header[kHeaderSize];
|
||||||
core::EncodeFixed64(header, data.size());
|
core::EncodeFixed64(header, data.size());
|
||||||
TF_RETURN_IF_ERROR(dest_->Append(StringPiece(header, sizeof(header))));
|
TF_RETURN_IF_ERROR(dest_->Append(StringPiece(header, sizeof(header))));
|
||||||
@ -187,9 +187,8 @@ Status SnapshotWriter::WriteRecord(const absl::Cord& data) {
|
|||||||
}
|
}
|
||||||
#endif // PLATFORM_GOOGLE
|
#endif // PLATFORM_GOOGLE
|
||||||
|
|
||||||
SnapshotReader::SnapshotReader(RandomAccessFile* file,
|
Reader::Reader(RandomAccessFile* file, const string& compression_type,
|
||||||
const string& compression_type, int version,
|
int version, const DataTypeVector& dtypes)
|
||||||
const DataTypeVector& dtypes)
|
|
||||||
: file_(file),
|
: file_(file),
|
||||||
input_stream_(new io::RandomAccessInputStream(file)),
|
input_stream_(new io::RandomAccessInputStream(file)),
|
||||||
compression_type_(compression_type),
|
compression_type_(compression_type),
|
||||||
@ -231,7 +230,7 @@ SnapshotReader::SnapshotReader(RandomAccessFile* file,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
Status SnapshotReader::ReadTensors(std::vector<Tensor>* read_tensors) {
|
Status Reader::ReadTensors(std::vector<Tensor>* read_tensors) {
|
||||||
profiler::TraceMe activity(
|
profiler::TraceMe activity(
|
||||||
[&]() { return absl::StrCat(kClassName, kSeparator, "ReadTensors"); },
|
[&]() { return absl::StrCat(kClassName, kSeparator, "ReadTensors"); },
|
||||||
profiler::TraceMeLevel::kInfo);
|
profiler::TraceMeLevel::kInfo);
|
||||||
@ -245,7 +244,7 @@ Status SnapshotReader::ReadTensors(std::vector<Tensor>* read_tensors) {
|
|||||||
return errors::InvalidArgument("Version 1 only supports snappy.");
|
return errors::InvalidArgument("Version 1 only supports snappy.");
|
||||||
}
|
}
|
||||||
|
|
||||||
SnapshotTensorMetadata metadata;
|
experimental::SnapshotTensorMetadata metadata;
|
||||||
tstring metadata_str;
|
tstring metadata_str;
|
||||||
TF_RETURN_IF_ERROR(ReadRecord(&metadata_str));
|
TF_RETURN_IF_ERROR(ReadRecord(&metadata_str));
|
||||||
if (!metadata.ParseFromArray(metadata_str.data(), metadata_str.size())) {
|
if (!metadata.ParseFromArray(metadata_str.data(), metadata_str.size())) {
|
||||||
@ -296,7 +295,7 @@ Status SnapshotReader::ReadTensors(std::vector<Tensor>* read_tensors) {
|
|||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
Status SnapshotReader::ReadTensorsV0(std::vector<Tensor>* read_tensors) {
|
Status Reader::ReadTensorsV0(std::vector<Tensor>* read_tensors) {
|
||||||
experimental::SnapshotRecord record;
|
experimental::SnapshotRecord record;
|
||||||
#if defined(PLATFORM_GOOGLE)
|
#if defined(PLATFORM_GOOGLE)
|
||||||
absl::Cord c;
|
absl::Cord c;
|
||||||
@ -317,8 +316,9 @@ Status SnapshotReader::ReadTensorsV0(std::vector<Tensor>* read_tensors) {
|
|||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
Status SnapshotReader::SnappyUncompress(
|
Status Reader::SnappyUncompress(
|
||||||
const SnapshotTensorMetadata* metadata, std::vector<Tensor>* simple_tensors,
|
const experimental::SnapshotTensorMetadata* metadata,
|
||||||
|
std::vector<Tensor>* simple_tensors,
|
||||||
std::vector<std::pair<std::unique_ptr<char[]>, size_t>>*
|
std::vector<std::pair<std::unique_ptr<char[]>, size_t>>*
|
||||||
tensor_proto_strs) {
|
tensor_proto_strs) {
|
||||||
tstring compressed;
|
tstring compressed;
|
||||||
@ -365,7 +365,7 @@ Status SnapshotReader::SnappyUncompress(
|
|||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
Status SnapshotReader::ReadRecord(tstring* record) {
|
Status Reader::ReadRecord(tstring* record) {
|
||||||
tstring header;
|
tstring header;
|
||||||
TF_RETURN_IF_ERROR(input_stream_->ReadNBytes(kHeaderSize, &header));
|
TF_RETURN_IF_ERROR(input_stream_->ReadNBytes(kHeaderSize, &header));
|
||||||
uint64 length = core::DecodeFixed64(header.data());
|
uint64 length = core::DecodeFixed64(header.data());
|
||||||
@ -373,7 +373,7 @@ Status SnapshotReader::ReadRecord(tstring* record) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
#if defined(PLATFORM_GOOGLE)
|
#if defined(PLATFORM_GOOGLE)
|
||||||
Status SnapshotReader::ReadRecord(absl::Cord* record) {
|
Status Reader::ReadRecord(absl::Cord* record) {
|
||||||
tstring header;
|
tstring header;
|
||||||
TF_RETURN_IF_ERROR(input_stream_->ReadNBytes(kHeaderSize, &header));
|
TF_RETURN_IF_ERROR(input_stream_->ReadNBytes(kHeaderSize, &header));
|
||||||
uint64 length = core::DecodeFixed64(header.data());
|
uint64 length = core::DecodeFixed64(header.data());
|
||||||
@ -392,10 +392,9 @@ Status SnapshotReader::ReadRecord(absl::Cord* record) {
|
|||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
Status SnapshotWriteMetadataFile(
|
Status WriteMetadataFile(const string& hash_dir,
|
||||||
const string& hash_dir,
|
const experimental::SnapshotMetadataRecord* metadata) {
|
||||||
const experimental::SnapshotMetadataRecord* metadata) {
|
string metadata_filename = io::JoinPath(hash_dir, kMetadataFilename);
|
||||||
string metadata_filename = io::JoinPath(hash_dir, kSnapshotFilename);
|
|
||||||
TF_RETURN_IF_ERROR(Env::Default()->RecursivelyCreateDir(hash_dir));
|
TF_RETURN_IF_ERROR(Env::Default()->RecursivelyCreateDir(hash_dir));
|
||||||
std::string tmp_filename =
|
std::string tmp_filename =
|
||||||
absl::StrCat(metadata_filename, "-tmp-", random::New64());
|
absl::StrCat(metadata_filename, "-tmp-", random::New64());
|
||||||
@ -403,15 +402,15 @@ Status SnapshotWriteMetadataFile(
|
|||||||
return Env::Default()->RenameFile(tmp_filename, metadata_filename);
|
return Env::Default()->RenameFile(tmp_filename, metadata_filename);
|
||||||
}
|
}
|
||||||
|
|
||||||
Status SnapshotReadMetadataFile(
|
Status ReadMetadataFile(const string& hash_dir,
|
||||||
const string& hash_dir, experimental::SnapshotMetadataRecord* metadata) {
|
experimental::SnapshotMetadataRecord* metadata) {
|
||||||
string metadata_filename = io::JoinPath(hash_dir, kSnapshotFilename);
|
string metadata_filename = io::JoinPath(hash_dir, kMetadataFilename);
|
||||||
TF_RETURN_IF_ERROR(Env::Default()->FileExists(metadata_filename));
|
TF_RETURN_IF_ERROR(Env::Default()->FileExists(metadata_filename));
|
||||||
return ReadBinaryProto(Env::Default(), metadata_filename, metadata);
|
return ReadBinaryProto(Env::Default(), metadata_filename, metadata);
|
||||||
}
|
}
|
||||||
|
|
||||||
Status SnapshotDumpDatasetGraph(const std::string& path, uint64 hash,
|
Status DumpDatasetGraph(const std::string& path, uint64 hash,
|
||||||
const GraphDef* graph) {
|
const GraphDef* graph) {
|
||||||
std::string hash_hex =
|
std::string hash_hex =
|
||||||
strings::StrCat(strings::Hex(hash, strings::kZeroPad16));
|
strings::StrCat(strings::Hex(hash, strings::kZeroPad16));
|
||||||
std::string graph_file =
|
std::string graph_file =
|
||||||
@ -422,11 +421,12 @@ Status SnapshotDumpDatasetGraph(const std::string& path, uint64 hash,
|
|||||||
return WriteTextProto(Env::Default(), graph_file, *graph);
|
return WriteTextProto(Env::Default(), graph_file, *graph);
|
||||||
}
|
}
|
||||||
|
|
||||||
Status SnapshotDetermineOpState(
|
Status DetermineOpState(const std::string& mode_string,
|
||||||
const std::string& mode_string, const Status& file_status,
|
const Status& file_status,
|
||||||
const experimental::SnapshotMetadataRecord* metadata,
|
const experimental::SnapshotMetadataRecord* metadata,
|
||||||
const uint64 pending_snapshot_expiry_seconds, SnapshotMode* mode) {
|
const uint64 pending_snapshot_expiry_seconds,
|
||||||
if (mode_string == kSnapshotModeRead) {
|
Mode* mode) {
|
||||||
|
if (mode_string == kModeRead) {
|
||||||
// In read mode, we should expect a metadata file is written.
|
// In read mode, we should expect a metadata file is written.
|
||||||
if (errors::IsNotFound(file_status)) {
|
if (errors::IsNotFound(file_status)) {
|
||||||
return file_status;
|
return file_status;
|
||||||
@ -436,13 +436,13 @@ Status SnapshotDetermineOpState(
|
|||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
if (mode_string == kSnapshotModeWrite) {
|
if (mode_string == kModeWrite) {
|
||||||
LOG(INFO) << "Overriding mode to writer.";
|
LOG(INFO) << "Overriding mode to writer.";
|
||||||
*mode = WRITER;
|
*mode = WRITER;
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
if (mode_string == kSnapshotModePassthrough) {
|
if (mode_string == kModePassthrough) {
|
||||||
LOG(INFO) << "Overriding mode to passthrough.";
|
LOG(INFO) << "Overriding mode to passthrough.";
|
||||||
*mode = PASSTHROUGH;
|
*mode = PASSTHROUGH;
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
@ -476,6 +476,6 @@ Status SnapshotDetermineOpState(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace experimental
|
} // namespace snapshot_util
|
||||||
} // namespace data
|
} // namespace data
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
@ -28,22 +28,26 @@ namespace tensorflow {
|
|||||||
class GraphDef;
|
class GraphDef;
|
||||||
|
|
||||||
namespace data {
|
namespace data {
|
||||||
|
|
||||||
namespace experimental {
|
namespace experimental {
|
||||||
|
|
||||||
class SnapshotMetadataRecord;
|
class SnapshotMetadataRecord;
|
||||||
|
|
||||||
constexpr char kSnapshotFilename[] = "snapshot.metadata";
|
|
||||||
|
|
||||||
constexpr char kSnapshotModeAuto[] = "auto";
|
|
||||||
constexpr char kSnapshotModeWrite[] = "write";
|
|
||||||
constexpr char kSnapshotModeRead[] = "read";
|
|
||||||
constexpr char kSnapshotModePassthrough[] = "passthrough";
|
|
||||||
|
|
||||||
enum SnapshotMode { READER = 0, WRITER = 1, PASSTHROUGH = 2 };
|
|
||||||
|
|
||||||
class SnapshotTensorMetadata;
|
class SnapshotTensorMetadata;
|
||||||
|
|
||||||
class SnapshotWriter {
|
} // namespace experimental
|
||||||
|
|
||||||
|
namespace snapshot_util {
|
||||||
|
|
||||||
|
constexpr char kMetadataFilename[] = "snapshot.metadata";
|
||||||
|
|
||||||
|
constexpr char kModeAuto[] = "auto";
|
||||||
|
constexpr char kModeWrite[] = "write";
|
||||||
|
constexpr char kModeRead[] = "read";
|
||||||
|
constexpr char kModePassthrough[] = "passthrough";
|
||||||
|
|
||||||
|
enum Mode { READER = 0, WRITER = 1, PASSTHROUGH = 2 };
|
||||||
|
|
||||||
|
class Writer {
|
||||||
public:
|
public:
|
||||||
static constexpr const size_t kHeaderSize = sizeof(uint64);
|
static constexpr const size_t kHeaderSize = sizeof(uint64);
|
||||||
|
|
||||||
@ -52,8 +56,8 @@ class SnapshotWriter {
|
|||||||
static constexpr const char* const kWriteCord = "WriteCord";
|
static constexpr const char* const kWriteCord = "WriteCord";
|
||||||
static constexpr const char* const kSeparator = "::";
|
static constexpr const char* const kSeparator = "::";
|
||||||
|
|
||||||
explicit SnapshotWriter(WritableFile* dest, const string& compression_type,
|
explicit Writer(WritableFile* dest, const string& compression_type,
|
||||||
int version, const DataTypeVector& dtypes);
|
int version, const DataTypeVector& dtypes);
|
||||||
|
|
||||||
Status WriteTensors(const std::vector<Tensor>& tensors);
|
Status WriteTensors(const std::vector<Tensor>& tensors);
|
||||||
|
|
||||||
@ -61,7 +65,7 @@ class SnapshotWriter {
|
|||||||
|
|
||||||
Status Close();
|
Status Close();
|
||||||
|
|
||||||
~SnapshotWriter();
|
~Writer();
|
||||||
|
|
||||||
private:
|
private:
|
||||||
Status WriteRecord(const StringPiece& data);
|
Status WriteRecord(const StringPiece& data);
|
||||||
@ -79,7 +83,7 @@ class SnapshotWriter {
|
|||||||
int num_complex_ = 0;
|
int num_complex_ = 0;
|
||||||
};
|
};
|
||||||
|
|
||||||
class SnapshotReader {
|
class Reader {
|
||||||
public:
|
public:
|
||||||
// The reader input buffer size is deliberately large because the input reader
|
// The reader input buffer size is deliberately large because the input reader
|
||||||
// will throw an error if the compressed block length cannot fit in the input
|
// will throw an error if the compressed block length cannot fit in the input
|
||||||
@ -96,9 +100,8 @@ class SnapshotReader {
|
|||||||
static constexpr const char* const kReadCord = "ReadCord";
|
static constexpr const char* const kReadCord = "ReadCord";
|
||||||
static constexpr const char* const kSeparator = "::";
|
static constexpr const char* const kSeparator = "::";
|
||||||
|
|
||||||
explicit SnapshotReader(RandomAccessFile* file,
|
explicit Reader(RandomAccessFile* file, const string& compression_type,
|
||||||
const string& compression_type, int version,
|
int version, const DataTypeVector& dtypes);
|
||||||
const DataTypeVector& dtypes);
|
|
||||||
|
|
||||||
Status ReadTensors(std::vector<Tensor>* read_tensors);
|
Status ReadTensors(std::vector<Tensor>* read_tensors);
|
||||||
|
|
||||||
@ -106,7 +109,7 @@ class SnapshotReader {
|
|||||||
Status ReadTensorsV0(std::vector<Tensor>* read_tensors);
|
Status ReadTensorsV0(std::vector<Tensor>* read_tensors);
|
||||||
|
|
||||||
Status SnappyUncompress(
|
Status SnappyUncompress(
|
||||||
const SnapshotTensorMetadata* metadata,
|
const experimental::SnapshotTensorMetadata* metadata,
|
||||||
std::vector<Tensor>* simple_tensors,
|
std::vector<Tensor>* simple_tensors,
|
||||||
std::vector<std::pair<std::unique_ptr<char[]>, size_t>>*
|
std::vector<std::pair<std::unique_ptr<char[]>, size_t>>*
|
||||||
tensor_proto_strs);
|
tensor_proto_strs);
|
||||||
@ -127,22 +130,22 @@ class SnapshotReader {
|
|||||||
std::vector<bool> simple_tensor_mask_; // true for simple, false for complex.
|
std::vector<bool> simple_tensor_mask_; // true for simple, false for complex.
|
||||||
};
|
};
|
||||||
|
|
||||||
Status SnapshotWriteMetadataFile(
|
Status WriteMetadataFile(const string& hash_dir,
|
||||||
const string& hash_dir,
|
const experimental::SnapshotMetadataRecord* metadata);
|
||||||
const experimental::SnapshotMetadataRecord* metadata);
|
|
||||||
|
|
||||||
Status SnapshotReadMetadataFile(const string& hash_dir,
|
Status ReadMetadataFile(const string& hash_dir,
|
||||||
experimental::SnapshotMetadataRecord* metadata);
|
experimental::SnapshotMetadataRecord* metadata);
|
||||||
|
|
||||||
Status SnapshotDumpDatasetGraph(const std::string& path, uint64 hash,
|
Status DumpDatasetGraph(const std::string& path, uint64 hash,
|
||||||
const GraphDef* graph);
|
const GraphDef* graph);
|
||||||
|
|
||||||
Status SnapshotDetermineOpState(
|
Status DetermineOpState(const std::string& mode_string,
|
||||||
const std::string& mode_string, const Status& file_status,
|
const Status& file_status,
|
||||||
const experimental::SnapshotMetadataRecord* metadata,
|
const experimental::SnapshotMetadataRecord* metadata,
|
||||||
const uint64 pending_snapshot_expiry_seconds, SnapshotMode* mode);
|
const uint64 pending_snapshot_expiry_seconds,
|
||||||
|
Mode* mode);
|
||||||
|
|
||||||
} // namespace experimental
|
} // namespace snapshot_util
|
||||||
} // namespace data
|
} // namespace data
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user