Change namespace for snapshot utilities to snapshot_util from experimental

PiperOrigin-RevId: 299197783
Change-Id: I8c2f33e319379df7b69246b1ae3116e17a5d2039
This commit is contained in:
Frank Chen 2020-03-05 14:42:13 -08:00 committed by TensorFlower Gardener
parent 059c3253d9
commit 94c04fc77e
3 changed files with 107 additions and 100 deletions

View File

@ -125,7 +125,7 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
OP_REQUIRES_OK(ctx, ctx->GetAttr("seed", &seed_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("seed2", &seed2_));
mode_ = kSnapshotModeAuto;
mode_ = snapshot_util::kModeAuto;
if (ctx->HasAttr("mode")) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("mode", &mode_));
}
@ -160,14 +160,15 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
errors::InvalidArgument(
"pending_snapshot_expiry_seconds must be at least 1 second."));
OP_REQUIRES(
ctx,
mode_ == kSnapshotModeAuto || mode_ == kSnapshotModeRead ||
mode_ == kSnapshotModeWrite || mode_ == kSnapshotModePassthrough,
errors::InvalidArgument("mode must be either '", kSnapshotModeAuto,
"', '", kSnapshotModeRead, "', '",
kSnapshotModeWrite, "', or '",
kSnapshotModePassthrough, "'."));
OP_REQUIRES(ctx,
mode_ == snapshot_util::kModeAuto ||
mode_ == snapshot_util::kModeRead ||
mode_ == snapshot_util::kModeWrite ||
mode_ == snapshot_util::kModePassthrough,
errors::InvalidArgument(
"mode must be either '", snapshot_util::kModeAuto, "', '",
snapshot_util::kModeRead, "', '", snapshot_util::kModeWrite,
"', or '", snapshot_util::kModePassthrough, "'."));
}
protected:
@ -190,7 +191,8 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
uint64 hash;
OP_REQUIRES_OK(ctx, ComputeDatasetHash(graph_def, path, &hash));
Status dump_status = SnapshotDumpDatasetGraph(path, hash, &graph_def);
Status dump_status =
snapshot_util::DumpDatasetGraph(path, hash, &graph_def);
if (!dump_status.ok()) {
LOG(WARNING) << "Unable to write graphdef to disk, error: "
<< dump_status.ToString();
@ -376,8 +378,8 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
mutex_lock l(mu_);
if (iterator_ == nullptr) {
experimental::SnapshotMetadataRecord metadata;
Status s = SnapshotReadMetadataFile(hash_dir_, &metadata);
TF_RETURN_IF_ERROR(SnapshotDetermineOpState(
Status s = snapshot_util::ReadMetadataFile(hash_dir_, &metadata);
TF_RETURN_IF_ERROR(snapshot_util::DetermineOpState(
dataset()->mode_, s, &metadata,
dataset()->pending_snapshot_expiry_seconds_, &state_));
VLOG(2) << "Snapshot state: " << state_;
@ -411,10 +413,11 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
{
int64 temp;
TF_RETURN_IF_ERROR(reader->ReadScalar(full_name(kState), &temp));
state_ = SnapshotMode(temp);
state_ = snapshot_util::Mode(temp);
}
experimental::SnapshotMetadataRecord metadata;
TF_RETURN_IF_ERROR(SnapshotReadMetadataFile(hash_dir_, &metadata));
TF_RETURN_IF_ERROR(
snapshot_util::ReadMetadataFile(hash_dir_, &metadata));
TF_RETURN_IF_ERROR(InitializeIterator(ctx, metadata));
VLOG(2) << "Restoring Snapshot iterator: " << state_;
return RestoreInput(ctx, reader, iterator_);
@ -434,13 +437,13 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
}
switch (state_) {
case WRITER:
case snapshot_util::WRITER:
iterator_ = absl::make_unique<SnapshotWriterIterator>(
SnapshotWriterIterator::Params{
dataset(), absl::StrCat(prefix(), "WriterImpl")},
hash_dir_, run_id);
break;
case READER:
case snapshot_util::READER:
if (run_id.empty() && metadata.run_id().empty()) {
return errors::NotFound(
"Could not find a valid snapshot to read.");
@ -468,7 +471,7 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
dataset(), absl::StrCat(prefix(), "ReaderImpl")},
hash_dir_, run_id, metadata.version());
break;
case PASSTHROUGH:
case snapshot_util::PASSTHROUGH:
iterator_ = absl::make_unique<SnapshotPassthroughIterator>(
SnapshotPassthroughIterator::Params{
dataset(), absl::StrCat(prefix(), "PassthroughImpl")});
@ -727,8 +730,8 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
std::unique_ptr<RandomAccessFile> file;
TF_RETURN_IF_ERROR(
Env::Default()->NewRandomAccessFile(filename, &file));
SnapshotReader reader(file.get(), dataset()->compression_, version_,
dataset()->output_dtypes());
snapshot_util::Reader reader(file.get(), dataset()->compression_,
version_, dataset()->output_dtypes());
while (true) {
// Wait for a slot in the buffer.
@ -953,7 +956,7 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
}
metadata.set_finalized(false);
TF_RETURN_IF_ERROR(
SnapshotWriteMetadataFile(hash_dir_, &metadata));
snapshot_util::WriteMetadataFile(hash_dir_, &metadata));
}
for (int i = 0; i < dataset()->num_writer_threads_; ++i) {
++num_active_threads_;
@ -1252,7 +1255,7 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
Status ProcessOneElement(int64* bytes_written,
string* snapshot_data_filename,
std::unique_ptr<WritableFile>* file,
std::unique_ptr<SnapshotWriter>* writer,
std::unique_ptr<snapshot_util::Writer>* writer,
bool* end_of_processing) {
profiler::TraceMe activity(
[&]() {
@ -1312,7 +1315,7 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
*snapshot_data_filename = GetSnapshotFilename();
TF_RETURN_IF_ERROR(Env::Default()->NewAppendableFile(
*snapshot_data_filename, file));
*writer = absl::make_unique<SnapshotWriter>(
*writer = absl::make_unique<snapshot_util::Writer>(
file->get(), dataset()->compression_, kCurrentVersion,
dataset()->output_dtypes());
*bytes_written = 0;
@ -1329,12 +1332,12 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
if (!written_final_metadata_file_) {
experimental::SnapshotMetadataRecord metadata;
TF_RETURN_IF_ERROR(
SnapshotReadMetadataFile(hash_dir_, &metadata));
snapshot_util::ReadMetadataFile(hash_dir_, &metadata));
if (metadata.run_id() == run_id_) {
metadata.set_finalized(true);
TF_RETURN_IF_ERROR(
SnapshotWriteMetadataFile(hash_dir_, &metadata));
snapshot_util::WriteMetadataFile(hash_dir_, &metadata));
} else {
// TODO(frankchn): We lost the race, remove all snapshots.
}
@ -1366,9 +1369,10 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
cond_var_.notify_all();
return;
}
std::unique_ptr<SnapshotWriter> writer(
new SnapshotWriter(file.get(), dataset()->compression_,
kCurrentVersion, dataset()->output_dtypes()));
std::unique_ptr<snapshot_util::Writer> writer(
new snapshot_util::Writer(file.get(), dataset()->compression_,
kCurrentVersion,
dataset()->output_dtypes()));
bool end_of_processing = false;
while (!end_of_processing) {
@ -1389,8 +1393,8 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
}
Status ShouldCloseFile(const string& filename, uint64 bytes_written,
SnapshotWriter* writer, WritableFile* file,
bool* should_close) {
snapshot_util::Writer* writer,
WritableFile* file, bool* should_close) {
// If the compression ratio has been estimated, use it to decide
// whether the file should be closed. We avoid estimating the
// compression ratio repeatedly because it requires syncing the file,
@ -1490,7 +1494,7 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
};
string hash_dir_ TF_GUARDED_BY(mu_);
SnapshotMode state_ TF_GUARDED_BY(mu_);
snapshot_util::Mode state_ TF_GUARDED_BY(mu_);
std::unique_ptr<IteratorBase> iterator_ TF_GUARDED_BY(mu_);
mutex mu_;

View File

@ -34,11 +34,10 @@ limitations under the License.
namespace tensorflow {
namespace data {
namespace experimental {
namespace snapshot_util {
SnapshotWriter::SnapshotWriter(WritableFile* dest,
const string& compression_type, int version,
const DataTypeVector& dtypes)
Writer::Writer(WritableFile* dest, const string& compression_type, int version,
const DataTypeVector& dtypes)
: dest_(dest), compression_type_(compression_type), version_(version) {
#if defined(IS_SLIM_BUILD)
if (compression_type != io::compression::kNone) {
@ -70,7 +69,7 @@ SnapshotWriter::SnapshotWriter(WritableFile* dest,
}
}
Status SnapshotWriter::WriteTensors(const std::vector<Tensor>& tensors) {
Status Writer::WriteTensors(const std::vector<Tensor>& tensors) {
if (compression_type_ != io::compression::kSnappy) {
experimental::SnapshotRecord record;
for (const auto& tensor : tensors) {
@ -96,11 +95,12 @@ Status SnapshotWriter::WriteTensors(const std::vector<Tensor>& tensors) {
tensor_buffers.reserve(num_simple_);
std::vector<TensorProto> tensor_protos;
tensor_protos.reserve(num_complex_);
SnapshotTensorMetadata metadata;
experimental::SnapshotTensorMetadata metadata;
int64 total_size = 0;
for (int i = 0; i < tensors.size(); ++i) {
const Tensor& tensor = tensors[i];
TensorMetadata* tensor_metadata = metadata.add_tensor_metadata();
experimental::TensorMetadata* tensor_metadata =
metadata.add_tensor_metadata();
tensor.shape().AsProto(tensor_metadata->mutable_tensor_shape());
int64 size = 0;
if (simple_tensor_mask_[i]) {
@ -150,9 +150,9 @@ Status SnapshotWriter::WriteTensors(const std::vector<Tensor>& tensors) {
return Status::OK();
}
Status SnapshotWriter::Sync() { return dest_->Sync(); }
Status Writer::Sync() { return dest_->Sync(); }
Status SnapshotWriter::Close() {
Status Writer::Close() {
if (dest_is_owned_) {
Status s = dest_->Close();
delete dest_;
@ -162,7 +162,7 @@ Status SnapshotWriter::Close() {
return Status::OK();
}
SnapshotWriter::~SnapshotWriter() {
Writer::~Writer() {
if (dest_ != nullptr) {
Status s = Close();
if (!s.ok()) {
@ -171,7 +171,7 @@ SnapshotWriter::~SnapshotWriter() {
}
}
Status SnapshotWriter::WriteRecord(const StringPiece& data) {
Status Writer::WriteRecord(const StringPiece& data) {
char header[kHeaderSize];
core::EncodeFixed64(header, data.size());
TF_RETURN_IF_ERROR(dest_->Append(StringPiece(header, sizeof(header))));
@ -179,7 +179,7 @@ Status SnapshotWriter::WriteRecord(const StringPiece& data) {
}
#if defined(PLATFORM_GOOGLE)
Status SnapshotWriter::WriteRecord(const absl::Cord& data) {
Status Writer::WriteRecord(const absl::Cord& data) {
char header[kHeaderSize];
core::EncodeFixed64(header, data.size());
TF_RETURN_IF_ERROR(dest_->Append(StringPiece(header, sizeof(header))));
@ -187,9 +187,8 @@ Status SnapshotWriter::WriteRecord(const absl::Cord& data) {
}
#endif // PLATFORM_GOOGLE
SnapshotReader::SnapshotReader(RandomAccessFile* file,
const string& compression_type, int version,
const DataTypeVector& dtypes)
Reader::Reader(RandomAccessFile* file, const string& compression_type,
int version, const DataTypeVector& dtypes)
: file_(file),
input_stream_(new io::RandomAccessInputStream(file)),
compression_type_(compression_type),
@ -231,7 +230,7 @@ SnapshotReader::SnapshotReader(RandomAccessFile* file,
}
}
Status SnapshotReader::ReadTensors(std::vector<Tensor>* read_tensors) {
Status Reader::ReadTensors(std::vector<Tensor>* read_tensors) {
profiler::TraceMe activity(
[&]() { return absl::StrCat(kClassName, kSeparator, "ReadTensors"); },
profiler::TraceMeLevel::kInfo);
@ -245,7 +244,7 @@ Status SnapshotReader::ReadTensors(std::vector<Tensor>* read_tensors) {
return errors::InvalidArgument("Version 1 only supports snappy.");
}
SnapshotTensorMetadata metadata;
experimental::SnapshotTensorMetadata metadata;
tstring metadata_str;
TF_RETURN_IF_ERROR(ReadRecord(&metadata_str));
if (!metadata.ParseFromArray(metadata_str.data(), metadata_str.size())) {
@ -296,7 +295,7 @@ Status SnapshotReader::ReadTensors(std::vector<Tensor>* read_tensors) {
return Status::OK();
}
Status SnapshotReader::ReadTensorsV0(std::vector<Tensor>* read_tensors) {
Status Reader::ReadTensorsV0(std::vector<Tensor>* read_tensors) {
experimental::SnapshotRecord record;
#if defined(PLATFORM_GOOGLE)
absl::Cord c;
@ -317,8 +316,9 @@ Status SnapshotReader::ReadTensorsV0(std::vector<Tensor>* read_tensors) {
return Status::OK();
}
Status SnapshotReader::SnappyUncompress(
const SnapshotTensorMetadata* metadata, std::vector<Tensor>* simple_tensors,
Status Reader::SnappyUncompress(
const experimental::SnapshotTensorMetadata* metadata,
std::vector<Tensor>* simple_tensors,
std::vector<std::pair<std::unique_ptr<char[]>, size_t>>*
tensor_proto_strs) {
tstring compressed;
@ -365,7 +365,7 @@ Status SnapshotReader::SnappyUncompress(
return Status::OK();
}
Status SnapshotReader::ReadRecord(tstring* record) {
Status Reader::ReadRecord(tstring* record) {
tstring header;
TF_RETURN_IF_ERROR(input_stream_->ReadNBytes(kHeaderSize, &header));
uint64 length = core::DecodeFixed64(header.data());
@ -373,7 +373,7 @@ Status SnapshotReader::ReadRecord(tstring* record) {
}
#if defined(PLATFORM_GOOGLE)
Status SnapshotReader::ReadRecord(absl::Cord* record) {
Status Reader::ReadRecord(absl::Cord* record) {
tstring header;
TF_RETURN_IF_ERROR(input_stream_->ReadNBytes(kHeaderSize, &header));
uint64 length = core::DecodeFixed64(header.data());
@ -392,10 +392,9 @@ Status SnapshotReader::ReadRecord(absl::Cord* record) {
}
#endif
Status SnapshotWriteMetadataFile(
const string& hash_dir,
const experimental::SnapshotMetadataRecord* metadata) {
string metadata_filename = io::JoinPath(hash_dir, kSnapshotFilename);
Status WriteMetadataFile(const string& hash_dir,
const experimental::SnapshotMetadataRecord* metadata) {
string metadata_filename = io::JoinPath(hash_dir, kMetadataFilename);
TF_RETURN_IF_ERROR(Env::Default()->RecursivelyCreateDir(hash_dir));
std::string tmp_filename =
absl::StrCat(metadata_filename, "-tmp-", random::New64());
@ -403,15 +402,15 @@ Status SnapshotWriteMetadataFile(
return Env::Default()->RenameFile(tmp_filename, metadata_filename);
}
Status SnapshotReadMetadataFile(
const string& hash_dir, experimental::SnapshotMetadataRecord* metadata) {
string metadata_filename = io::JoinPath(hash_dir, kSnapshotFilename);
Status ReadMetadataFile(const string& hash_dir,
experimental::SnapshotMetadataRecord* metadata) {
string metadata_filename = io::JoinPath(hash_dir, kMetadataFilename);
TF_RETURN_IF_ERROR(Env::Default()->FileExists(metadata_filename));
return ReadBinaryProto(Env::Default(), metadata_filename, metadata);
}
Status SnapshotDumpDatasetGraph(const std::string& path, uint64 hash,
const GraphDef* graph) {
Status DumpDatasetGraph(const std::string& path, uint64 hash,
const GraphDef* graph) {
std::string hash_hex =
strings::StrCat(strings::Hex(hash, strings::kZeroPad16));
std::string graph_file =
@ -422,11 +421,12 @@ Status SnapshotDumpDatasetGraph(const std::string& path, uint64 hash,
return WriteTextProto(Env::Default(), graph_file, *graph);
}
Status SnapshotDetermineOpState(
const std::string& mode_string, const Status& file_status,
const experimental::SnapshotMetadataRecord* metadata,
const uint64 pending_snapshot_expiry_seconds, SnapshotMode* mode) {
if (mode_string == kSnapshotModeRead) {
Status DetermineOpState(const std::string& mode_string,
const Status& file_status,
const experimental::SnapshotMetadataRecord* metadata,
const uint64 pending_snapshot_expiry_seconds,
Mode* mode) {
if (mode_string == kModeRead) {
// In read mode, we should expect a metadata file is written.
if (errors::IsNotFound(file_status)) {
return file_status;
@ -436,13 +436,13 @@ Status SnapshotDetermineOpState(
return Status::OK();
}
if (mode_string == kSnapshotModeWrite) {
if (mode_string == kModeWrite) {
LOG(INFO) << "Overriding mode to writer.";
*mode = WRITER;
return Status::OK();
}
if (mode_string == kSnapshotModePassthrough) {
if (mode_string == kModePassthrough) {
LOG(INFO) << "Overriding mode to passthrough.";
*mode = PASSTHROUGH;
return Status::OK();
@ -476,6 +476,6 @@ Status SnapshotDetermineOpState(
}
}
} // namespace experimental
} // namespace snapshot_util
} // namespace data
} // namespace tensorflow

View File

@ -28,22 +28,26 @@ namespace tensorflow {
class GraphDef;
namespace data {
namespace experimental {
class SnapshotMetadataRecord;
constexpr char kSnapshotFilename[] = "snapshot.metadata";
constexpr char kSnapshotModeAuto[] = "auto";
constexpr char kSnapshotModeWrite[] = "write";
constexpr char kSnapshotModeRead[] = "read";
constexpr char kSnapshotModePassthrough[] = "passthrough";
enum SnapshotMode { READER = 0, WRITER = 1, PASSTHROUGH = 2 };
class SnapshotTensorMetadata;
class SnapshotWriter {
} // namespace experimental
namespace snapshot_util {
constexpr char kMetadataFilename[] = "snapshot.metadata";
constexpr char kModeAuto[] = "auto";
constexpr char kModeWrite[] = "write";
constexpr char kModeRead[] = "read";
constexpr char kModePassthrough[] = "passthrough";
enum Mode { READER = 0, WRITER = 1, PASSTHROUGH = 2 };
class Writer {
public:
static constexpr const size_t kHeaderSize = sizeof(uint64);
@ -52,8 +56,8 @@ class SnapshotWriter {
static constexpr const char* const kWriteCord = "WriteCord";
static constexpr const char* const kSeparator = "::";
explicit SnapshotWriter(WritableFile* dest, const string& compression_type,
int version, const DataTypeVector& dtypes);
explicit Writer(WritableFile* dest, const string& compression_type,
int version, const DataTypeVector& dtypes);
Status WriteTensors(const std::vector<Tensor>& tensors);
@ -61,7 +65,7 @@ class SnapshotWriter {
Status Close();
~SnapshotWriter();
~Writer();
private:
Status WriteRecord(const StringPiece& data);
@ -79,7 +83,7 @@ class SnapshotWriter {
int num_complex_ = 0;
};
class SnapshotReader {
class Reader {
public:
// The reader input buffer size is deliberately large because the input reader
// will throw an error if the compressed block length cannot fit in the input
@ -96,9 +100,8 @@ class SnapshotReader {
static constexpr const char* const kReadCord = "ReadCord";
static constexpr const char* const kSeparator = "::";
explicit SnapshotReader(RandomAccessFile* file,
const string& compression_type, int version,
const DataTypeVector& dtypes);
explicit Reader(RandomAccessFile* file, const string& compression_type,
int version, const DataTypeVector& dtypes);
Status ReadTensors(std::vector<Tensor>* read_tensors);
@ -106,7 +109,7 @@ class SnapshotReader {
Status ReadTensorsV0(std::vector<Tensor>* read_tensors);
Status SnappyUncompress(
const SnapshotTensorMetadata* metadata,
const experimental::SnapshotTensorMetadata* metadata,
std::vector<Tensor>* simple_tensors,
std::vector<std::pair<std::unique_ptr<char[]>, size_t>>*
tensor_proto_strs);
@ -127,22 +130,22 @@ class SnapshotReader {
std::vector<bool> simple_tensor_mask_; // true for simple, false for complex.
};
Status SnapshotWriteMetadataFile(
const string& hash_dir,
const experimental::SnapshotMetadataRecord* metadata);
Status WriteMetadataFile(const string& hash_dir,
const experimental::SnapshotMetadataRecord* metadata);
Status SnapshotReadMetadataFile(const string& hash_dir,
experimental::SnapshotMetadataRecord* metadata);
Status ReadMetadataFile(const string& hash_dir,
experimental::SnapshotMetadataRecord* metadata);
Status SnapshotDumpDatasetGraph(const std::string& path, uint64 hash,
const GraphDef* graph);
Status DumpDatasetGraph(const std::string& path, uint64 hash,
const GraphDef* graph);
Status SnapshotDetermineOpState(
const std::string& mode_string, const Status& file_status,
const experimental::SnapshotMetadataRecord* metadata,
const uint64 pending_snapshot_expiry_seconds, SnapshotMode* mode);
Status DetermineOpState(const std::string& mode_string,
const Status& file_status,
const experimental::SnapshotMetadataRecord* metadata,
const uint64 pending_snapshot_expiry_seconds,
Mode* mode);
} // namespace experimental
} // namespace snapshot_util
} // namespace data
} // namespace tensorflow