Change namespace for snapshot utilities from experimental to snapshot_util

PiperOrigin-RevId: 299197783
Change-Id: I8c2f33e319379df7b69246b1ae3116e17a5d2039
Authored by Frank Chen on 2020-03-05 14:42:13 -08:00; committed by TensorFlower Gardener
parent 059c3253d9
commit 94c04fc77e
3 changed files with 107 additions and 100 deletions
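
In call-site terms, the change amounts to the following (a minimal before/after sketch using only identifiers that appear in the diffs below; the snippet itself is illustrative and not part of the commit):

// Before: the helpers lived in tensorflow::data::experimental and carried a
// "Snapshot" prefix on every name.
//   SnapshotWriter writer(file, compression, version, dtypes);
//   TF_RETURN_IF_ERROR(SnapshotWriteMetadataFile(hash_dir, &metadata));
//
// After: they live in tensorflow::data::snapshot_util without the prefix,
// while the proto messages (SnapshotMetadataRecord, SnapshotTensorMetadata)
// stay in the experimental namespace.
//   snapshot_util::Writer writer(file, compression, version, dtypes);
//   experimental::SnapshotMetadataRecord metadata;
//   TF_RETURN_IF_ERROR(snapshot_util::WriteMetadataFile(hash_dir, &metadata));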

View File

@@ -125,7 +125,7 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
     OP_REQUIRES_OK(ctx, ctx->GetAttr("seed", &seed_));
     OP_REQUIRES_OK(ctx, ctx->GetAttr("seed2", &seed2_));

-    mode_ = kSnapshotModeAuto;
+    mode_ = snapshot_util::kModeAuto;
     if (ctx->HasAttr("mode")) {
       OP_REQUIRES_OK(ctx, ctx->GetAttr("mode", &mode_));
     }

@@ -160,14 +160,15 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
            errors::InvalidArgument(
                "pending_snapshot_expiry_seconds must be at least 1 second."));
-    OP_REQUIRES(
-        ctx,
-        mode_ == kSnapshotModeAuto || mode_ == kSnapshotModeRead ||
-            mode_ == kSnapshotModeWrite || mode_ == kSnapshotModePassthrough,
-        errors::InvalidArgument("mode must be either '", kSnapshotModeAuto,
-                                "', '", kSnapshotModeRead, "', '",
-                                kSnapshotModeWrite, "', or '",
-                                kSnapshotModePassthrough, "'."));
+    OP_REQUIRES(ctx,
+                mode_ == snapshot_util::kModeAuto ||
+                    mode_ == snapshot_util::kModeRead ||
+                    mode_ == snapshot_util::kModeWrite ||
+                    mode_ == snapshot_util::kModePassthrough,
+                errors::InvalidArgument(
+                    "mode must be either '", snapshot_util::kModeAuto, "', '",
+                    snapshot_util::kModeRead, "', '", snapshot_util::kModeWrite,
+                    "', or '", snapshot_util::kModePassthrough, "'."));
   }

 protected:

@@ -190,7 +191,8 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
     uint64 hash;
     OP_REQUIRES_OK(ctx, ComputeDatasetHash(graph_def, path, &hash));

-    Status dump_status = SnapshotDumpDatasetGraph(path, hash, &graph_def);
+    Status dump_status =
+        snapshot_util::DumpDatasetGraph(path, hash, &graph_def);
     if (!dump_status.ok()) {
       LOG(WARNING) << "Unable to write graphdef to disk, error: "
                    << dump_status.ToString();

@@ -376,8 +378,8 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
       mutex_lock l(mu_);
       if (iterator_ == nullptr) {
         experimental::SnapshotMetadataRecord metadata;
-        Status s = SnapshotReadMetadataFile(hash_dir_, &metadata);
-        TF_RETURN_IF_ERROR(SnapshotDetermineOpState(
+        Status s = snapshot_util::ReadMetadataFile(hash_dir_, &metadata);
+        TF_RETURN_IF_ERROR(snapshot_util::DetermineOpState(
             dataset()->mode_, s, &metadata,
             dataset()->pending_snapshot_expiry_seconds_, &state_));
         VLOG(2) << "Snapshot state: " << state_;

@@ -411,10 +413,11 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
       {
         int64 temp;
         TF_RETURN_IF_ERROR(reader->ReadScalar(full_name(kState), &temp));
-        state_ = SnapshotMode(temp);
+        state_ = snapshot_util::Mode(temp);
       }
       experimental::SnapshotMetadataRecord metadata;
-      TF_RETURN_IF_ERROR(SnapshotReadMetadataFile(hash_dir_, &metadata));
+      TF_RETURN_IF_ERROR(
+          snapshot_util::ReadMetadataFile(hash_dir_, &metadata));
       TF_RETURN_IF_ERROR(InitializeIterator(ctx, metadata));
       VLOG(2) << "Restoring Snapshot iterator: " << state_;
       return RestoreInput(ctx, reader, iterator_);

@@ -434,13 +437,13 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
       }

       switch (state_) {
-        case WRITER:
+        case snapshot_util::WRITER:
           iterator_ = absl::make_unique<SnapshotWriterIterator>(
               SnapshotWriterIterator::Params{
                   dataset(), absl::StrCat(prefix(), "WriterImpl")},
               hash_dir_, run_id);
           break;
-        case READER:
+        case snapshot_util::READER:
           if (run_id.empty() && metadata.run_id().empty()) {
             return errors::NotFound(
                 "Could not find a valid snapshot to read.");

@@ -468,7 +471,7 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
                   dataset(), absl::StrCat(prefix(), "ReaderImpl")},
               hash_dir_, run_id, metadata.version());
           break;
-        case PASSTHROUGH:
+        case snapshot_util::PASSTHROUGH:
           iterator_ = absl::make_unique<SnapshotPassthroughIterator>(
               SnapshotPassthroughIterator::Params{
                   dataset(), absl::StrCat(prefix(), "PassthroughImpl")});

@@ -727,8 +730,8 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
         std::unique_ptr<RandomAccessFile> file;
         TF_RETURN_IF_ERROR(
             Env::Default()->NewRandomAccessFile(filename, &file));
-        SnapshotReader reader(file.get(), dataset()->compression_, version_,
-                              dataset()->output_dtypes());
+        snapshot_util::Reader reader(file.get(), dataset()->compression_,
+                                     version_, dataset()->output_dtypes());

         while (true) {
           // Wait for a slot in the buffer.

@@ -953,7 +956,7 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
           }
           metadata.set_finalized(false);
           TF_RETURN_IF_ERROR(
-              SnapshotWriteMetadataFile(hash_dir_, &metadata));
+              snapshot_util::WriteMetadataFile(hash_dir_, &metadata));
         }
         for (int i = 0; i < dataset()->num_writer_threads_; ++i) {
           ++num_active_threads_;

@@ -1252,7 +1255,7 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
       Status ProcessOneElement(int64* bytes_written,
                                string* snapshot_data_filename,
                                std::unique_ptr<WritableFile>* file,
-                               std::unique_ptr<SnapshotWriter>* writer,
+                               std::unique_ptr<snapshot_util::Writer>* writer,
                                bool* end_of_processing) {
         profiler::TraceMe activity(
             [&]() {

@@ -1312,7 +1315,7 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
           *snapshot_data_filename = GetSnapshotFilename();
           TF_RETURN_IF_ERROR(Env::Default()->NewAppendableFile(
               *snapshot_data_filename, file));
-          *writer = absl::make_unique<SnapshotWriter>(
+          *writer = absl::make_unique<snapshot_util::Writer>(
               file->get(), dataset()->compression_, kCurrentVersion,
               dataset()->output_dtypes());
           *bytes_written = 0;

@@ -1329,12 +1332,12 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
           if (!written_final_metadata_file_) {
             experimental::SnapshotMetadataRecord metadata;
             TF_RETURN_IF_ERROR(
-                SnapshotReadMetadataFile(hash_dir_, &metadata));
+                snapshot_util::ReadMetadataFile(hash_dir_, &metadata));
             if (metadata.run_id() == run_id_) {
               metadata.set_finalized(true);
               TF_RETURN_IF_ERROR(
-                  SnapshotWriteMetadataFile(hash_dir_, &metadata));
+                  snapshot_util::WriteMetadataFile(hash_dir_, &metadata));
             } else {
               // TODO(frankchn): We lost the race, remove all snapshots.
             }

@@ -1366,9 +1369,10 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
             cond_var_.notify_all();
             return;
           }
-          std::unique_ptr<SnapshotWriter> writer(
-              new SnapshotWriter(file.get(), dataset()->compression_,
-                                 kCurrentVersion, dataset()->output_dtypes()));
+          std::unique_ptr<snapshot_util::Writer> writer(
+              new snapshot_util::Writer(file.get(), dataset()->compression_,
+                                        kCurrentVersion,
+                                        dataset()->output_dtypes()));

           bool end_of_processing = false;
           while (!end_of_processing) {

@@ -1389,8 +1393,8 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
       }

       Status ShouldCloseFile(const string& filename, uint64 bytes_written,
-                             SnapshotWriter* writer, WritableFile* file,
-                             bool* should_close) {
+                             snapshot_util::Writer* writer,
+                             WritableFile* file, bool* should_close) {
         // If the compression ratio has been estimated, use it to decide
         // whether the file should be closed. We avoid estimating the
         // compression ratio repeatedly because it requires syncing the file,

@@ -1490,7 +1494,7 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
      };

      string hash_dir_ TF_GUARDED_BY(mu_);
-     SnapshotMode state_ TF_GUARDED_BY(mu_);
+     snapshot_util::Mode state_ TF_GUARDED_BY(mu_);
      std::unique_ptr<IteratorBase> iterator_ TF_GUARDED_BY(mu_);

      mutex mu_;
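
The dataset op above now reaches every helper through the snapshot_util:: prefix. A condensed sketch of the mode-selection flow it performs with the renamed functions (the wrapper function and its name are hypothetical; the calls and argument order come from the diff):

Status DecideState(const string& hash_dir, const string& mode_attr,
                   uint64 pending_snapshot_expiry_seconds,
                   snapshot_util::Mode* state) {
  experimental::SnapshotMetadataRecord metadata;
  // A missing metadata file is not fatal here; DetermineOpState inspects the
  // returned status together with the requested mode.
  Status s = snapshot_util::ReadMetadataFile(hash_dir, &metadata);
  // Maps the "auto"/"read"/"write"/"passthrough" attribute onto READER,
  // WRITER, or PASSTHROUGH.
  return snapshot_util::DetermineOpState(
      mode_attr, s, &metadata, pending_snapshot_expiry_seconds, state);
}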

View File

@@ -34,11 +34,10 @@ limitations under the License.
 namespace tensorflow {
 namespace data {
-namespace experimental {
+namespace snapshot_util {

-SnapshotWriter::SnapshotWriter(WritableFile* dest,
-                               const string& compression_type, int version,
-                               const DataTypeVector& dtypes)
+Writer::Writer(WritableFile* dest, const string& compression_type, int version,
+               const DataTypeVector& dtypes)
     : dest_(dest), compression_type_(compression_type), version_(version) {
 #if defined(IS_SLIM_BUILD)
   if (compression_type != io::compression::kNone) {

@@ -70,7 +69,7 @@ SnapshotWriter::SnapshotWriter(WritableFile* dest,
   }
 }

-Status SnapshotWriter::WriteTensors(const std::vector<Tensor>& tensors) {
+Status Writer::WriteTensors(const std::vector<Tensor>& tensors) {
   if (compression_type_ != io::compression::kSnappy) {
     experimental::SnapshotRecord record;
     for (const auto& tensor : tensors) {

@@ -96,11 +95,12 @@ Status SnapshotWriter::WriteTensors(const std::vector<Tensor>& tensors) {
     tensor_buffers.reserve(num_simple_);
     std::vector<TensorProto> tensor_protos;
     tensor_protos.reserve(num_complex_);
-    SnapshotTensorMetadata metadata;
+    experimental::SnapshotTensorMetadata metadata;
     int64 total_size = 0;
     for (int i = 0; i < tensors.size(); ++i) {
       const Tensor& tensor = tensors[i];
-      TensorMetadata* tensor_metadata = metadata.add_tensor_metadata();
+      experimental::TensorMetadata* tensor_metadata =
+          metadata.add_tensor_metadata();
       tensor.shape().AsProto(tensor_metadata->mutable_tensor_shape());
       int64 size = 0;
       if (simple_tensor_mask_[i]) {

@@ -150,9 +150,9 @@ Status SnapshotWriter::WriteTensors(const std::vector<Tensor>& tensors) {
   return Status::OK();
 }

-Status SnapshotWriter::Sync() { return dest_->Sync(); }
+Status Writer::Sync() { return dest_->Sync(); }

-Status SnapshotWriter::Close() {
+Status Writer::Close() {
   if (dest_is_owned_) {
     Status s = dest_->Close();
     delete dest_;

@@ -162,7 +162,7 @@ Status SnapshotWriter::Close() {
   return Status::OK();
 }

-SnapshotWriter::~SnapshotWriter() {
+Writer::~Writer() {
   if (dest_ != nullptr) {
     Status s = Close();
     if (!s.ok()) {

@@ -171,7 +171,7 @@ SnapshotWriter::~SnapshotWriter() {
   }
 }

-Status SnapshotWriter::WriteRecord(const StringPiece& data) {
+Status Writer::WriteRecord(const StringPiece& data) {
   char header[kHeaderSize];
   core::EncodeFixed64(header, data.size());
   TF_RETURN_IF_ERROR(dest_->Append(StringPiece(header, sizeof(header))));

@@ -179,7 +179,7 @@ Status SnapshotWriter::WriteRecord(const StringPiece& data) {
 }

 #if defined(PLATFORM_GOOGLE)
-Status SnapshotWriter::WriteRecord(const absl::Cord& data) {
+Status Writer::WriteRecord(const absl::Cord& data) {
   char header[kHeaderSize];
   core::EncodeFixed64(header, data.size());
   TF_RETURN_IF_ERROR(dest_->Append(StringPiece(header, sizeof(header))));

@@ -187,9 +187,8 @@ Status SnapshotWriter::WriteRecord(const absl::Cord& data) {
 }
 #endif  // PLATFORM_GOOGLE

-SnapshotReader::SnapshotReader(RandomAccessFile* file,
-                               const string& compression_type, int version,
-                               const DataTypeVector& dtypes)
+Reader::Reader(RandomAccessFile* file, const string& compression_type,
+               int version, const DataTypeVector& dtypes)
     : file_(file),
       input_stream_(new io::RandomAccessInputStream(file)),
       compression_type_(compression_type),

@@ -231,7 +230,7 @@ SnapshotReader::SnapshotReader(RandomAccessFile* file,
   }
 }

-Status SnapshotReader::ReadTensors(std::vector<Tensor>* read_tensors) {
+Status Reader::ReadTensors(std::vector<Tensor>* read_tensors) {
   profiler::TraceMe activity(
       [&]() { return absl::StrCat(kClassName, kSeparator, "ReadTensors"); },
       profiler::TraceMeLevel::kInfo);

@@ -245,7 +244,7 @@ Status SnapshotReader::ReadTensors(std::vector<Tensor>* read_tensors) {
     return errors::InvalidArgument("Version 1 only supports snappy.");
   }

-  SnapshotTensorMetadata metadata;
+  experimental::SnapshotTensorMetadata metadata;
   tstring metadata_str;
   TF_RETURN_IF_ERROR(ReadRecord(&metadata_str));
   if (!metadata.ParseFromArray(metadata_str.data(), metadata_str.size())) {

@@ -296,7 +295,7 @@ Status SnapshotReader::ReadTensors(std::vector<Tensor>* read_tensors) {
   return Status::OK();
 }

-Status SnapshotReader::ReadTensorsV0(std::vector<Tensor>* read_tensors) {
+Status Reader::ReadTensorsV0(std::vector<Tensor>* read_tensors) {
   experimental::SnapshotRecord record;
 #if defined(PLATFORM_GOOGLE)
   absl::Cord c;

@@ -317,8 +316,9 @@ Status SnapshotReader::ReadTensorsV0(std::vector<Tensor>* read_tensors) {
   return Status::OK();
 }

-Status SnapshotReader::SnappyUncompress(
-    const SnapshotTensorMetadata* metadata, std::vector<Tensor>* simple_tensors,
+Status Reader::SnappyUncompress(
+    const experimental::SnapshotTensorMetadata* metadata,
+    std::vector<Tensor>* simple_tensors,
     std::vector<std::pair<std::unique_ptr<char[]>, size_t>>*
         tensor_proto_strs) {
   tstring compressed;

@@ -365,7 +365,7 @@ Status SnapshotReader::SnappyUncompress(
   return Status::OK();
 }

-Status SnapshotReader::ReadRecord(tstring* record) {
+Status Reader::ReadRecord(tstring* record) {
   tstring header;
   TF_RETURN_IF_ERROR(input_stream_->ReadNBytes(kHeaderSize, &header));
   uint64 length = core::DecodeFixed64(header.data());

@@ -373,7 +373,7 @@ Status SnapshotReader::ReadRecord(tstring* record) {
 }

 #if defined(PLATFORM_GOOGLE)
-Status SnapshotReader::ReadRecord(absl::Cord* record) {
+Status Reader::ReadRecord(absl::Cord* record) {
   tstring header;
   TF_RETURN_IF_ERROR(input_stream_->ReadNBytes(kHeaderSize, &header));
   uint64 length = core::DecodeFixed64(header.data());

@@ -392,10 +392,9 @@ Status SnapshotReader::ReadRecord(absl::Cord* record) {
 }
 #endif

-Status SnapshotWriteMetadataFile(
-    const string& hash_dir,
-    const experimental::SnapshotMetadataRecord* metadata) {
-  string metadata_filename = io::JoinPath(hash_dir, kSnapshotFilename);
+Status WriteMetadataFile(const string& hash_dir,
+                         const experimental::SnapshotMetadataRecord* metadata) {
+  string metadata_filename = io::JoinPath(hash_dir, kMetadataFilename);
   TF_RETURN_IF_ERROR(Env::Default()->RecursivelyCreateDir(hash_dir));
   std::string tmp_filename =
       absl::StrCat(metadata_filename, "-tmp-", random::New64());

@@ -403,15 +402,15 @@ Status SnapshotWriteMetadataFile(
   return Env::Default()->RenameFile(tmp_filename, metadata_filename);
 }

-Status SnapshotReadMetadataFile(
-    const string& hash_dir, experimental::SnapshotMetadataRecord* metadata) {
-  string metadata_filename = io::JoinPath(hash_dir, kSnapshotFilename);
+Status ReadMetadataFile(const string& hash_dir,
+                        experimental::SnapshotMetadataRecord* metadata) {
+  string metadata_filename = io::JoinPath(hash_dir, kMetadataFilename);
   TF_RETURN_IF_ERROR(Env::Default()->FileExists(metadata_filename));
   return ReadBinaryProto(Env::Default(), metadata_filename, metadata);
 }

-Status SnapshotDumpDatasetGraph(const std::string& path, uint64 hash,
-                                const GraphDef* graph) {
+Status DumpDatasetGraph(const std::string& path, uint64 hash,
+                        const GraphDef* graph) {
   std::string hash_hex =
       strings::StrCat(strings::Hex(hash, strings::kZeroPad16));
   std::string graph_file =

@@ -422,11 +421,12 @@ Status SnapshotDumpDatasetGraph(const std::string& path, uint64 hash,
   return WriteTextProto(Env::Default(), graph_file, *graph);
 }

-Status SnapshotDetermineOpState(
-    const std::string& mode_string, const Status& file_status,
-    const experimental::SnapshotMetadataRecord* metadata,
-    const uint64 pending_snapshot_expiry_seconds, SnapshotMode* mode) {
-  if (mode_string == kSnapshotModeRead) {
+Status DetermineOpState(const std::string& mode_string,
+                        const Status& file_status,
+                        const experimental::SnapshotMetadataRecord* metadata,
+                        const uint64 pending_snapshot_expiry_seconds,
+                        Mode* mode) {
+  if (mode_string == kModeRead) {
     // In read mode, we should expect a metadata file is written.
     if (errors::IsNotFound(file_status)) {
       return file_status;

@@ -436,13 +436,13 @@ Status SnapshotDetermineOpState(
     return Status::OK();
   }

-  if (mode_string == kSnapshotModeWrite) {
+  if (mode_string == kModeWrite) {
     LOG(INFO) << "Overriding mode to writer.";
     *mode = WRITER;
     return Status::OK();
   }

-  if (mode_string == kSnapshotModePassthrough) {
+  if (mode_string == kModePassthrough) {
     LOG(INFO) << "Overriding mode to passthrough.";
     *mode = PASSTHROUGH;
     return Status::OK();

@@ -476,6 +476,6 @@ Status SnapshotDetermineOpState(
   }
 }

-}  // namespace experimental
+}  // namespace snapshot_util
 }  // namespace data
 }  // namespace tensorflow
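
A short usage sketch of the free functions implemented above (the directory value is made up; the behavior described in the comments mirrors the implementations shown in the diff):

Status RoundTripMetadata() {
  const string hash_dir = "/tmp/snapshot/0123456789abcdef";  // hypothetical
  experimental::SnapshotMetadataRecord metadata;
  metadata.set_run_id("run_0");
  metadata.set_finalized(false);
  // Writes <hash_dir>/snapshot.metadata via a temporary file plus rename.
  TF_RETURN_IF_ERROR(snapshot_util::WriteMetadataFile(hash_dir, &metadata));
  experimental::SnapshotMetadataRecord read_back;
  // Returns an error (e.g. NotFound) if the metadata file does not exist yet.
  TF_RETURN_IF_ERROR(snapshot_util::ReadMetadataFile(hash_dir, &read_back));
  return Status::OK();
}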

View File

@@ -28,22 +28,26 @@ namespace tensorflow {
 class GraphDef;

 namespace data {
 namespace experimental {

 class SnapshotMetadataRecord;
-constexpr char kSnapshotFilename[] = "snapshot.metadata";
-constexpr char kSnapshotModeAuto[] = "auto";
-constexpr char kSnapshotModeWrite[] = "write";
-constexpr char kSnapshotModeRead[] = "read";
-constexpr char kSnapshotModePassthrough[] = "passthrough";
-enum SnapshotMode { READER = 0, WRITER = 1, PASSTHROUGH = 2 };
 class SnapshotTensorMetadata;

-class SnapshotWriter {
+}  // namespace experimental
+
+namespace snapshot_util {
+
+constexpr char kMetadataFilename[] = "snapshot.metadata";
+constexpr char kModeAuto[] = "auto";
+constexpr char kModeWrite[] = "write";
+constexpr char kModeRead[] = "read";
+constexpr char kModePassthrough[] = "passthrough";
+
+enum Mode { READER = 0, WRITER = 1, PASSTHROUGH = 2 };
+
+class Writer {
  public:
  static constexpr const size_t kHeaderSize = sizeof(uint64);

@@ -52,8 +56,8 @@ class SnapshotWriter {
   static constexpr const char* const kWriteCord = "WriteCord";
   static constexpr const char* const kSeparator = "::";

-  explicit SnapshotWriter(WritableFile* dest, const string& compression_type,
-                          int version, const DataTypeVector& dtypes);
+  explicit Writer(WritableFile* dest, const string& compression_type,
+                  int version, const DataTypeVector& dtypes);

   Status WriteTensors(const std::vector<Tensor>& tensors);

@@ -61,7 +65,7 @@ class SnapshotWriter {
   Status Close();

-  ~SnapshotWriter();
+  ~Writer();

  private:
   Status WriteRecord(const StringPiece& data);

@@ -79,7 +83,7 @@ class SnapshotWriter {
   int num_complex_ = 0;
 };

-class SnapshotReader {
+class Reader {
  public:
  // The reader input buffer size is deliberately large because the input reader
  // will throw an error if the compressed block length cannot fit in the input

@@ -96,9 +100,8 @@ class SnapshotReader {
   static constexpr const char* const kReadCord = "ReadCord";
   static constexpr const char* const kSeparator = "::";

-  explicit SnapshotReader(RandomAccessFile* file,
-                          const string& compression_type, int version,
-                          const DataTypeVector& dtypes);
+  explicit Reader(RandomAccessFile* file, const string& compression_type,
+                  int version, const DataTypeVector& dtypes);

   Status ReadTensors(std::vector<Tensor>* read_tensors);

@@ -106,7 +109,7 @@ class SnapshotReader {
   Status ReadTensorsV0(std::vector<Tensor>* read_tensors);

   Status SnappyUncompress(
-      const SnapshotTensorMetadata* metadata,
+      const experimental::SnapshotTensorMetadata* metadata,
       std::vector<Tensor>* simple_tensors,
       std::vector<std::pair<std::unique_ptr<char[]>, size_t>>*
           tensor_proto_strs);

@@ -127,22 +130,22 @@ class SnapshotReader {
   std::vector<bool> simple_tensor_mask_;  // true for simple, false for complex.
 };

-Status SnapshotWriteMetadataFile(
-    const string& hash_dir,
-    const experimental::SnapshotMetadataRecord* metadata);
+Status WriteMetadataFile(const string& hash_dir,
+                         const experimental::SnapshotMetadataRecord* metadata);

-Status SnapshotReadMetadataFile(const string& hash_dir,
-                                experimental::SnapshotMetadataRecord* metadata);
+Status ReadMetadataFile(const string& hash_dir,
+                        experimental::SnapshotMetadataRecord* metadata);

-Status SnapshotDumpDatasetGraph(const std::string& path, uint64 hash,
-                                const GraphDef* graph);
+Status DumpDatasetGraph(const std::string& path, uint64 hash,
+                        const GraphDef* graph);

-Status SnapshotDetermineOpState(
-    const std::string& mode_string, const Status& file_status,
-    const experimental::SnapshotMetadataRecord* metadata,
-    const uint64 pending_snapshot_expiry_seconds, SnapshotMode* mode);
+Status DetermineOpState(const std::string& mode_string,
+                        const Status& file_status,
+                        const experimental::SnapshotMetadataRecord* metadata,
+                        const uint64 pending_snapshot_expiry_seconds,
+                        Mode* mode);

-}  // namespace experimental
+}  // namespace snapshot_util
 }  // namespace data
 }  // namespace tensorflow
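
For reference, a usage sketch of the renamed Writer and Reader classes declared above (a hypothetical helper, not part of this commit; it assumes a TensorFlow build and a writable path, and uses version 1 with Snappy compression, the combination the reader code above notes version 1 supports):

Status WriteAndReadBack(Env* env, const string& filename) {
  const DataTypeVector dtypes = {DT_INT64};

  // Write one scalar int64 tensor to the snapshot file.
  std::unique_ptr<WritableFile> out_file;
  TF_RETURN_IF_ERROR(env->NewAppendableFile(filename, &out_file));
  snapshot_util::Writer writer(out_file.get(), io::compression::kSnappy,
                               /*version=*/1, dtypes);
  Tensor t(DT_INT64, TensorShape({}));
  t.scalar<int64>()() = 42;
  TF_RETURN_IF_ERROR(writer.WriteTensors({t}));
  TF_RETURN_IF_ERROR(writer.Close());
  TF_RETURN_IF_ERROR(out_file->Close());

  // Read it back through the renamed Reader.
  std::unique_ptr<RandomAccessFile> in_file;
  TF_RETURN_IF_ERROR(env->NewRandomAccessFile(filename, &in_file));
  snapshot_util::Reader reader(in_file.get(), io::compression::kSnappy,
                               /*version=*/1, dtypes);
  std::vector<Tensor> read_tensors;
  return reader.ReadTensors(&read_tensors);
}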