Change namespace for snapshot utilities from experimental to snapshot_util
PiperOrigin-RevId: 299197783
Change-Id: I8c2f33e319379df7b69246b1ae3116e17a5d2039
parent 059c3253d9
commit 94c04fc77e
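In practice the diff below renames the snapshot helpers out of namespace experimental and into a new snapshot_util namespace, dropping the Snapshot prefix as it goes: SnapshotWriter becomes snapshot_util::Writer, SnapshotReader becomes snapshot_util::Reader, SnapshotMode becomes snapshot_util::Mode, the kSnapshotMode* constants become snapshot_util::kMode*, and the free functions (SnapshotReadMetadataFile, SnapshotWriteMetadataFile, SnapshotDumpDatasetGraph, SnapshotDetermineOpState) lose their prefix. The proto messages (SnapshotMetadataRecord, SnapshotTensorMetadata, SnapshotRecord) stay in namespace experimental and are now spelled with an explicit experimental:: qualifier. The following is a toy, self-contained C++ sketch of that pattern, not code from this commit; the struct and function bodies are placeholders.

#include <iostream>
#include <string>

namespace experimental {
// Stand-in for the generated proto message that stays in this namespace.
struct SnapshotMetadataRecord {
  std::string run_id;
};
}  // namespace experimental

namespace snapshot_util {

// Was kSnapshotModeAuto when it lived in namespace experimental.
constexpr char kModeAuto[] = "auto";

// Was SnapshotMode.
enum Mode { READER = 0, WRITER = 1, PASSTHROUGH = 2 };

// Was SnapshotReadMetadataFile(); the proto parameter keeps its
// experimental:: qualification.
bool ReadMetadataFile(const std::string& hash_dir,
                      experimental::SnapshotMetadataRecord* metadata) {
  metadata->run_id = "run-in-" + hash_dir;  // placeholder body
  return true;
}

}  // namespace snapshot_util

int main() {
  experimental::SnapshotMetadataRecord metadata;
  // Call sites now spell out snapshot_util:: instead of relying on being
  // inside namespace experimental.
  if (snapshot_util::ReadMetadataFile("/tmp/snapshot/abc123", &metadata)) {
    std::cout << snapshot_util::kModeAuto << " " << metadata.run_id << "\n";
  }
  snapshot_util::Mode mode = snapshot_util::WRITER;
  return mode == snapshot_util::WRITER ? 0 : 1;
}

Keeping the protos behind in experimental avoids touching generated code, which is why the .cc and .h hunks below add explicit experimental:: qualifiers at every proto use.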
@@ -125,7 +125,7 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
     OP_REQUIRES_OK(ctx, ctx->GetAttr("seed", &seed_));
     OP_REQUIRES_OK(ctx, ctx->GetAttr("seed2", &seed2_));
 
-    mode_ = kSnapshotModeAuto;
+    mode_ = snapshot_util::kModeAuto;
     if (ctx->HasAttr("mode")) {
       OP_REQUIRES_OK(ctx, ctx->GetAttr("mode", &mode_));
     }
@@ -160,14 +160,15 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
                 errors::InvalidArgument(
                     "pending_snapshot_expiry_seconds must be at least 1 second."));
 
-    OP_REQUIRES(
-        ctx,
-        mode_ == kSnapshotModeAuto || mode_ == kSnapshotModeRead ||
-            mode_ == kSnapshotModeWrite || mode_ == kSnapshotModePassthrough,
-        errors::InvalidArgument("mode must be either '", kSnapshotModeAuto,
-                                "', '", kSnapshotModeRead, "', '",
-                                kSnapshotModeWrite, "', or '",
-                                kSnapshotModePassthrough, "'."));
+    OP_REQUIRES(ctx,
+                mode_ == snapshot_util::kModeAuto ||
+                    mode_ == snapshot_util::kModeRead ||
+                    mode_ == snapshot_util::kModeWrite ||
+                    mode_ == snapshot_util::kModePassthrough,
+                errors::InvalidArgument(
+                    "mode must be either '", snapshot_util::kModeAuto, "', '",
+                    snapshot_util::kModeRead, "', '", snapshot_util::kModeWrite,
+                    "', or '", snapshot_util::kModePassthrough, "'."));
   }
 
  protected:
@@ -190,7 +191,8 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
     uint64 hash;
     OP_REQUIRES_OK(ctx, ComputeDatasetHash(graph_def, path, &hash));
 
-    Status dump_status = SnapshotDumpDatasetGraph(path, hash, &graph_def);
+    Status dump_status =
+        snapshot_util::DumpDatasetGraph(path, hash, &graph_def);
     if (!dump_status.ok()) {
       LOG(WARNING) << "Unable to write graphdef to disk, error: "
                    << dump_status.ToString();
@@ -376,8 +378,8 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
       mutex_lock l(mu_);
       if (iterator_ == nullptr) {
         experimental::SnapshotMetadataRecord metadata;
-        Status s = SnapshotReadMetadataFile(hash_dir_, &metadata);
-        TF_RETURN_IF_ERROR(SnapshotDetermineOpState(
+        Status s = snapshot_util::ReadMetadataFile(hash_dir_, &metadata);
+        TF_RETURN_IF_ERROR(snapshot_util::DetermineOpState(
            dataset()->mode_, s, &metadata,
            dataset()->pending_snapshot_expiry_seconds_, &state_));
        VLOG(2) << "Snapshot state: " << state_;
@@ -411,10 +413,11 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
      {
        int64 temp;
        TF_RETURN_IF_ERROR(reader->ReadScalar(full_name(kState), &temp));
-        state_ = SnapshotMode(temp);
+        state_ = snapshot_util::Mode(temp);
      }
      experimental::SnapshotMetadataRecord metadata;
-      TF_RETURN_IF_ERROR(SnapshotReadMetadataFile(hash_dir_, &metadata));
+      TF_RETURN_IF_ERROR(
+          snapshot_util::ReadMetadataFile(hash_dir_, &metadata));
      TF_RETURN_IF_ERROR(InitializeIterator(ctx, metadata));
      VLOG(2) << "Restoring Snapshot iterator: " << state_;
      return RestoreInput(ctx, reader, iterator_);
@@ -434,13 +437,13 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
      }
 
      switch (state_) {
-        case WRITER:
+        case snapshot_util::WRITER:
          iterator_ = absl::make_unique<SnapshotWriterIterator>(
              SnapshotWriterIterator::Params{
                  dataset(), absl::StrCat(prefix(), "WriterImpl")},
              hash_dir_, run_id);
          break;
-        case READER:
+        case snapshot_util::READER:
          if (run_id.empty() && metadata.run_id().empty()) {
            return errors::NotFound(
                "Could not find a valid snapshot to read.");
@@ -468,7 +471,7 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
                  dataset(), absl::StrCat(prefix(), "ReaderImpl")},
              hash_dir_, run_id, metadata.version());
          break;
-        case PASSTHROUGH:
+        case snapshot_util::PASSTHROUGH:
          iterator_ = absl::make_unique<SnapshotPassthroughIterator>(
              SnapshotPassthroughIterator::Params{
                  dataset(), absl::StrCat(prefix(), "PassthroughImpl")});
@@ -727,8 +730,8 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
        std::unique_ptr<RandomAccessFile> file;
        TF_RETURN_IF_ERROR(
            Env::Default()->NewRandomAccessFile(filename, &file));
-        SnapshotReader reader(file.get(), dataset()->compression_, version_,
-                              dataset()->output_dtypes());
+        snapshot_util::Reader reader(file.get(), dataset()->compression_,
+                                     version_, dataset()->output_dtypes());
 
        while (true) {
          // Wait for a slot in the buffer.
@@ -953,7 +956,7 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
          }
          metadata.set_finalized(false);
          TF_RETURN_IF_ERROR(
-              SnapshotWriteMetadataFile(hash_dir_, &metadata));
+              snapshot_util::WriteMetadataFile(hash_dir_, &metadata));
        }
        for (int i = 0; i < dataset()->num_writer_threads_; ++i) {
          ++num_active_threads_;
@@ -1252,7 +1255,7 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
      Status ProcessOneElement(int64* bytes_written,
                               string* snapshot_data_filename,
                               std::unique_ptr<WritableFile>* file,
-                              std::unique_ptr<SnapshotWriter>* writer,
+                              std::unique_ptr<snapshot_util::Writer>* writer,
                               bool* end_of_processing) {
        profiler::TraceMe activity(
            [&]() {
@@ -1312,7 +1315,7 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
          *snapshot_data_filename = GetSnapshotFilename();
          TF_RETURN_IF_ERROR(Env::Default()->NewAppendableFile(
              *snapshot_data_filename, file));
-          *writer = absl::make_unique<SnapshotWriter>(
+          *writer = absl::make_unique<snapshot_util::Writer>(
              file->get(), dataset()->compression_, kCurrentVersion,
              dataset()->output_dtypes());
          *bytes_written = 0;
@@ -1329,12 +1332,12 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
          if (!written_final_metadata_file_) {
            experimental::SnapshotMetadataRecord metadata;
            TF_RETURN_IF_ERROR(
-                SnapshotReadMetadataFile(hash_dir_, &metadata));
+                snapshot_util::ReadMetadataFile(hash_dir_, &metadata));
 
            if (metadata.run_id() == run_id_) {
              metadata.set_finalized(true);
              TF_RETURN_IF_ERROR(
-                  SnapshotWriteMetadataFile(hash_dir_, &metadata));
+                  snapshot_util::WriteMetadataFile(hash_dir_, &metadata));
            } else {
              // TODO(frankchn): We lost the race, remove all snapshots.
            }
@@ -1366,9 +1369,10 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
          cond_var_.notify_all();
          return;
        }
-        std::unique_ptr<SnapshotWriter> writer(
-            new SnapshotWriter(file.get(), dataset()->compression_,
-                               kCurrentVersion, dataset()->output_dtypes()));
+        std::unique_ptr<snapshot_util::Writer> writer(
+            new snapshot_util::Writer(file.get(), dataset()->compression_,
+                                      kCurrentVersion,
+                                      dataset()->output_dtypes()));
 
        bool end_of_processing = false;
        while (!end_of_processing) {
@@ -1389,8 +1393,8 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
      }
 
      Status ShouldCloseFile(const string& filename, uint64 bytes_written,
-                            SnapshotWriter* writer, WritableFile* file,
-                            bool* should_close) {
+                            snapshot_util::Writer* writer,
+                            WritableFile* file, bool* should_close) {
        // If the compression ratio has been estimated, use it to decide
        // whether the file should be closed. We avoid estimating the
        // compression ratio repeatedly because it requires syncing the file,
@@ -1490,7 +1494,7 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
    };
 
    string hash_dir_ TF_GUARDED_BY(mu_);
-    SnapshotMode state_ TF_GUARDED_BY(mu_);
+    snapshot_util::Mode state_ TF_GUARDED_BY(mu_);
    std::unique_ptr<IteratorBase> iterator_ TF_GUARDED_BY(mu_);
 
    mutex mu_;
@@ -34,11 +34,10 @@ limitations under the License.
 
 namespace tensorflow {
 namespace data {
-namespace experimental {
+namespace snapshot_util {
 
-SnapshotWriter::SnapshotWriter(WritableFile* dest,
-                               const string& compression_type, int version,
-                               const DataTypeVector& dtypes)
+Writer::Writer(WritableFile* dest, const string& compression_type, int version,
+               const DataTypeVector& dtypes)
     : dest_(dest), compression_type_(compression_type), version_(version) {
 #if defined(IS_SLIM_BUILD)
   if (compression_type != io::compression::kNone) {
@@ -70,7 +69,7 @@ SnapshotWriter::SnapshotWriter(WritableFile* dest,
   }
 }
 
-Status SnapshotWriter::WriteTensors(const std::vector<Tensor>& tensors) {
+Status Writer::WriteTensors(const std::vector<Tensor>& tensors) {
   if (compression_type_ != io::compression::kSnappy) {
     experimental::SnapshotRecord record;
     for (const auto& tensor : tensors) {
@@ -96,11 +95,12 @@ Status SnapshotWriter::WriteTensors(const std::vector<Tensor>& tensors) {
     tensor_buffers.reserve(num_simple_);
     std::vector<TensorProto> tensor_protos;
     tensor_protos.reserve(num_complex_);
-    SnapshotTensorMetadata metadata;
+    experimental::SnapshotTensorMetadata metadata;
     int64 total_size = 0;
     for (int i = 0; i < tensors.size(); ++i) {
       const Tensor& tensor = tensors[i];
-      TensorMetadata* tensor_metadata = metadata.add_tensor_metadata();
+      experimental::TensorMetadata* tensor_metadata =
+          metadata.add_tensor_metadata();
       tensor.shape().AsProto(tensor_metadata->mutable_tensor_shape());
       int64 size = 0;
       if (simple_tensor_mask_[i]) {
@@ -150,9 +150,9 @@ Status SnapshotWriter::WriteTensors(const std::vector<Tensor>& tensors) {
   return Status::OK();
 }
 
-Status SnapshotWriter::Sync() { return dest_->Sync(); }
+Status Writer::Sync() { return dest_->Sync(); }
 
-Status SnapshotWriter::Close() {
+Status Writer::Close() {
   if (dest_is_owned_) {
     Status s = dest_->Close();
     delete dest_;
@@ -162,7 +162,7 @@ Status SnapshotWriter::Close() {
   return Status::OK();
 }
 
-SnapshotWriter::~SnapshotWriter() {
+Writer::~Writer() {
   if (dest_ != nullptr) {
     Status s = Close();
     if (!s.ok()) {
@@ -171,7 +171,7 @@ SnapshotWriter::~SnapshotWriter() {
   }
 }
 
-Status SnapshotWriter::WriteRecord(const StringPiece& data) {
+Status Writer::WriteRecord(const StringPiece& data) {
   char header[kHeaderSize];
   core::EncodeFixed64(header, data.size());
   TF_RETURN_IF_ERROR(dest_->Append(StringPiece(header, sizeof(header))));
@@ -179,7 +179,7 @@ Status SnapshotWriter::WriteRecord(const StringPiece& data) {
 }
 
 #if defined(PLATFORM_GOOGLE)
-Status SnapshotWriter::WriteRecord(const absl::Cord& data) {
+Status Writer::WriteRecord(const absl::Cord& data) {
   char header[kHeaderSize];
   core::EncodeFixed64(header, data.size());
   TF_RETURN_IF_ERROR(dest_->Append(StringPiece(header, sizeof(header))));
@@ -187,9 +187,8 @@ Status SnapshotWriter::WriteRecord(const absl::Cord& data) {
 }
 #endif  // PLATFORM_GOOGLE
 
-SnapshotReader::SnapshotReader(RandomAccessFile* file,
-                               const string& compression_type, int version,
-                               const DataTypeVector& dtypes)
+Reader::Reader(RandomAccessFile* file, const string& compression_type,
+               int version, const DataTypeVector& dtypes)
     : file_(file),
       input_stream_(new io::RandomAccessInputStream(file)),
       compression_type_(compression_type),
@@ -231,7 +230,7 @@ SnapshotReader::SnapshotReader(RandomAccessFile* file,
   }
 }
 
-Status SnapshotReader::ReadTensors(std::vector<Tensor>* read_tensors) {
+Status Reader::ReadTensors(std::vector<Tensor>* read_tensors) {
   profiler::TraceMe activity(
       [&]() { return absl::StrCat(kClassName, kSeparator, "ReadTensors"); },
       profiler::TraceMeLevel::kInfo);
@@ -245,7 +244,7 @@ Status SnapshotReader::ReadTensors(std::vector<Tensor>* read_tensors) {
     return errors::InvalidArgument("Version 1 only supports snappy.");
   }
 
-  SnapshotTensorMetadata metadata;
+  experimental::SnapshotTensorMetadata metadata;
   tstring metadata_str;
   TF_RETURN_IF_ERROR(ReadRecord(&metadata_str));
   if (!metadata.ParseFromArray(metadata_str.data(), metadata_str.size())) {
@@ -296,7 +295,7 @@ Status SnapshotReader::ReadTensors(std::vector<Tensor>* read_tensors) {
   return Status::OK();
 }
 
-Status SnapshotReader::ReadTensorsV0(std::vector<Tensor>* read_tensors) {
+Status Reader::ReadTensorsV0(std::vector<Tensor>* read_tensors) {
   experimental::SnapshotRecord record;
 #if defined(PLATFORM_GOOGLE)
   absl::Cord c;
@@ -317,8 +316,9 @@ Status SnapshotReader::ReadTensorsV0(std::vector<Tensor>* read_tensors) {
   return Status::OK();
 }
 
-Status SnapshotReader::SnappyUncompress(
-    const SnapshotTensorMetadata* metadata, std::vector<Tensor>* simple_tensors,
+Status Reader::SnappyUncompress(
+    const experimental::SnapshotTensorMetadata* metadata,
+    std::vector<Tensor>* simple_tensors,
     std::vector<std::pair<std::unique_ptr<char[]>, size_t>>*
         tensor_proto_strs) {
   tstring compressed;
@@ -365,7 +365,7 @@ Status SnapshotReader::SnappyUncompress(
   return Status::OK();
 }
 
-Status SnapshotReader::ReadRecord(tstring* record) {
+Status Reader::ReadRecord(tstring* record) {
   tstring header;
   TF_RETURN_IF_ERROR(input_stream_->ReadNBytes(kHeaderSize, &header));
   uint64 length = core::DecodeFixed64(header.data());
@@ -373,7 +373,7 @@ Status SnapshotReader::ReadRecord(tstring* record) {
 }
 
 #if defined(PLATFORM_GOOGLE)
-Status SnapshotReader::ReadRecord(absl::Cord* record) {
+Status Reader::ReadRecord(absl::Cord* record) {
   tstring header;
   TF_RETURN_IF_ERROR(input_stream_->ReadNBytes(kHeaderSize, &header));
   uint64 length = core::DecodeFixed64(header.data());
@@ -392,10 +392,9 @@ Status SnapshotReader::ReadRecord(absl::Cord* record) {
 }
 #endif
 
-Status SnapshotWriteMetadataFile(
-    const string& hash_dir,
-    const experimental::SnapshotMetadataRecord* metadata) {
-  string metadata_filename = io::JoinPath(hash_dir, kSnapshotFilename);
+Status WriteMetadataFile(const string& hash_dir,
+                         const experimental::SnapshotMetadataRecord* metadata) {
+  string metadata_filename = io::JoinPath(hash_dir, kMetadataFilename);
   TF_RETURN_IF_ERROR(Env::Default()->RecursivelyCreateDir(hash_dir));
   std::string tmp_filename =
       absl::StrCat(metadata_filename, "-tmp-", random::New64());
@@ -403,15 +402,15 @@ Status SnapshotWriteMetadataFile(
   return Env::Default()->RenameFile(tmp_filename, metadata_filename);
 }
 
-Status SnapshotReadMetadataFile(
-    const string& hash_dir, experimental::SnapshotMetadataRecord* metadata) {
-  string metadata_filename = io::JoinPath(hash_dir, kSnapshotFilename);
+Status ReadMetadataFile(const string& hash_dir,
+                        experimental::SnapshotMetadataRecord* metadata) {
+  string metadata_filename = io::JoinPath(hash_dir, kMetadataFilename);
   TF_RETURN_IF_ERROR(Env::Default()->FileExists(metadata_filename));
   return ReadBinaryProto(Env::Default(), metadata_filename, metadata);
 }
 
-Status SnapshotDumpDatasetGraph(const std::string& path, uint64 hash,
-                                const GraphDef* graph) {
+Status DumpDatasetGraph(const std::string& path, uint64 hash,
+                        const GraphDef* graph) {
   std::string hash_hex =
       strings::StrCat(strings::Hex(hash, strings::kZeroPad16));
   std::string graph_file =
@@ -422,11 +421,12 @@ Status SnapshotDumpDatasetGraph(const std::string& path, uint64 hash,
   return WriteTextProto(Env::Default(), graph_file, *graph);
 }
 
-Status SnapshotDetermineOpState(
-    const std::string& mode_string, const Status& file_status,
-    const experimental::SnapshotMetadataRecord* metadata,
-    const uint64 pending_snapshot_expiry_seconds, SnapshotMode* mode) {
-  if (mode_string == kSnapshotModeRead) {
+Status DetermineOpState(const std::string& mode_string,
+                        const Status& file_status,
+                        const experimental::SnapshotMetadataRecord* metadata,
+                        const uint64 pending_snapshot_expiry_seconds,
+                        Mode* mode) {
+  if (mode_string == kModeRead) {
     // In read mode, we should expect a metadata file is written.
     if (errors::IsNotFound(file_status)) {
       return file_status;
@@ -436,13 +436,13 @@ Status SnapshotDetermineOpState(
     return Status::OK();
   }
 
-  if (mode_string == kSnapshotModeWrite) {
+  if (mode_string == kModeWrite) {
     LOG(INFO) << "Overriding mode to writer.";
     *mode = WRITER;
     return Status::OK();
   }
 
-  if (mode_string == kSnapshotModePassthrough) {
+  if (mode_string == kModePassthrough) {
     LOG(INFO) << "Overriding mode to passthrough.";
     *mode = PASSTHROUGH;
     return Status::OK();
@@ -476,6 +476,6 @@ Status SnapshotDetermineOpState(
   }
 }
 
-}  // namespace experimental
+}  // namespace snapshot_util
 }  // namespace data
 }  // namespace tensorflow
@@ -28,22 +28,26 @@ namespace tensorflow {
 class GraphDef;
 
 namespace data {
 
 namespace experimental {
 
 class SnapshotMetadataRecord;
 
-constexpr char kSnapshotFilename[] = "snapshot.metadata";
-
-constexpr char kSnapshotModeAuto[] = "auto";
-constexpr char kSnapshotModeWrite[] = "write";
-constexpr char kSnapshotModeRead[] = "read";
-constexpr char kSnapshotModePassthrough[] = "passthrough";
-
-enum SnapshotMode { READER = 0, WRITER = 1, PASSTHROUGH = 2 };
-
 class SnapshotTensorMetadata;
 
-class SnapshotWriter {
+}  // namespace experimental
+
+namespace snapshot_util {
+
+constexpr char kMetadataFilename[] = "snapshot.metadata";
+
+constexpr char kModeAuto[] = "auto";
+constexpr char kModeWrite[] = "write";
+constexpr char kModeRead[] = "read";
+constexpr char kModePassthrough[] = "passthrough";
+
+enum Mode { READER = 0, WRITER = 1, PASSTHROUGH = 2 };
+
+class Writer {
  public:
   static constexpr const size_t kHeaderSize = sizeof(uint64);
 
@@ -52,8 +56,8 @@ class SnapshotWriter {
   static constexpr const char* const kWriteCord = "WriteCord";
   static constexpr const char* const kSeparator = "::";
 
-  explicit SnapshotWriter(WritableFile* dest, const string& compression_type,
-                          int version, const DataTypeVector& dtypes);
+  explicit Writer(WritableFile* dest, const string& compression_type,
+                  int version, const DataTypeVector& dtypes);
 
   Status WriteTensors(const std::vector<Tensor>& tensors);
 
@@ -61,7 +65,7 @@ class SnapshotWriter {
 
   Status Close();
 
-  ~SnapshotWriter();
+  ~Writer();
 
  private:
   Status WriteRecord(const StringPiece& data);
@@ -79,7 +83,7 @@ class SnapshotWriter {
   int num_complex_ = 0;
 };
 
-class SnapshotReader {
+class Reader {
  public:
   // The reader input buffer size is deliberately large because the input reader
   // will throw an error if the compressed block length cannot fit in the input
@@ -96,9 +100,8 @@ class SnapshotReader {
   static constexpr const char* const kReadCord = "ReadCord";
   static constexpr const char* const kSeparator = "::";
 
-  explicit SnapshotReader(RandomAccessFile* file,
-                          const string& compression_type, int version,
-                          const DataTypeVector& dtypes);
+  explicit Reader(RandomAccessFile* file, const string& compression_type,
+                  int version, const DataTypeVector& dtypes);
 
   Status ReadTensors(std::vector<Tensor>* read_tensors);
 
@@ -106,7 +109,7 @@ class SnapshotReader {
   Status ReadTensorsV0(std::vector<Tensor>* read_tensors);
 
   Status SnappyUncompress(
-      const SnapshotTensorMetadata* metadata,
+      const experimental::SnapshotTensorMetadata* metadata,
       std::vector<Tensor>* simple_tensors,
       std::vector<std::pair<std::unique_ptr<char[]>, size_t>>*
           tensor_proto_strs);
@@ -127,22 +130,22 @@ class SnapshotReader {
   std::vector<bool> simple_tensor_mask_;  // true for simple, false for complex.
 };
 
-Status SnapshotWriteMetadataFile(
-    const string& hash_dir,
-    const experimental::SnapshotMetadataRecord* metadata);
+Status WriteMetadataFile(const string& hash_dir,
+                         const experimental::SnapshotMetadataRecord* metadata);
 
-Status SnapshotReadMetadataFile(const string& hash_dir,
-                                experimental::SnapshotMetadataRecord* metadata);
+Status ReadMetadataFile(const string& hash_dir,
+                        experimental::SnapshotMetadataRecord* metadata);
 
-Status SnapshotDumpDatasetGraph(const std::string& path, uint64 hash,
-                                const GraphDef* graph);
+Status DumpDatasetGraph(const std::string& path, uint64 hash,
+                        const GraphDef* graph);
 
-Status SnapshotDetermineOpState(
-    const std::string& mode_string, const Status& file_status,
-    const experimental::SnapshotMetadataRecord* metadata,
-    const uint64 pending_snapshot_expiry_seconds, SnapshotMode* mode);
+Status DetermineOpState(const std::string& mode_string,
+                        const Status& file_status,
+                        const experimental::SnapshotMetadataRecord* metadata,
+                        const uint64 pending_snapshot_expiry_seconds,
+                        Mode* mode);
 
-}  // namespace experimental
+}  // namespace snapshot_util
 }  // namespace data
 }  // namespace tensorflow
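The header hunk above fixes the public surface of the new namespace. For orientation, here is a minimal, hypothetical usage sketch of the renamed classes, assembled only from the signatures visible in this diff; the include paths, the compression choice, and the version value are assumptions and are not part of the commit.

#include <memory>
#include <vector>

#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/data/experimental/snapshot_util.h"  // assumed path
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/io/compression.h"
#include "tensorflow/core/platform/env.h"

namespace tensorflow {
namespace data {

// Writes `tensors` to `filename` with snapshot_util::Writer and reads them
// back with snapshot_util::Reader, mirroring how SnapshotDatasetOp uses the
// renamed classes in the kernel diff above.
Status RoundTripExample(const string& filename,
                        const std::vector<Tensor>& tensors,
                        const DataTypeVector& dtypes) {
  const int version = 1;  // assumed; the kernel passes its own kCurrentVersion

  // Write side: NewAppendableFile + Writer, as in ProcessOneElement above.
  std::unique_ptr<WritableFile> out_file;
  TF_RETURN_IF_ERROR(Env::Default()->NewAppendableFile(filename, &out_file));
  snapshot_util::Writer writer(out_file.get(), io::compression::kSnappy,
                               version, dtypes);
  TF_RETURN_IF_ERROR(writer.WriteTensors(tensors));
  TF_RETURN_IF_ERROR(writer.Sync());
  TF_RETURN_IF_ERROR(writer.Close());

  // Read side: NewRandomAccessFile + Reader, as in the reader thread above.
  std::unique_ptr<RandomAccessFile> in_file;
  TF_RETURN_IF_ERROR(Env::Default()->NewRandomAccessFile(filename, &in_file));
  snapshot_util::Reader reader(in_file.get(), io::compression::kSnappy,
                               version, dtypes);
  std::vector<Tensor> read_back;
  TF_RETURN_IF_ERROR(reader.ReadTensors(&read_back));
  return Status::OK();
}

}  // namespace data
}  // namespace tensorflow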