Add compression, reader and writer path prefixes to snapshot dataset.
PiperOrigin-RevId: 249114591
commit 273981699d
parent 4dd726e547
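For orientation, a minimal usage sketch of the options this change introduces (not part of the diff itself). The snapshot directory is hypothetical; `snapshot.snapshot`, `COMPRESSION_GZIP`, and the keyword names come from the Python changes below:

    # A hedged sketch, assuming the Python API changes in this commit.
    from tensorflow.python.data.experimental.ops import snapshot
    from tensorflow.python.data.ops import dataset_ops

    dataset = dataset_ops.Dataset.range(1000)
    dataset = dataset.apply(
        snapshot.snapshot(
            "/tmp/my_snapshots",          # hypothetical snapshot directory
            compression=snapshot.COMPRESSION_GZIP,
            reader_path_prefix=None,      # default: no prefix when reading
            writer_path_prefix=None))     # default: no prefix when writing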
@@ -38,8 +38,6 @@ enum SnapshotMode { READER = 0, WRITER = 1, PASSTHROUGH = 2 };
 const uint64 kReaderBufferSize = 8 * 1024 * 1024;  // 8 MB

-const char* kCompressionType = io::compression::kGzip;
-
 const uint64 kOneDayInMicroseconds = 24L * 60L * 60L * 1e6L;

 const uint64 kNumElementsPerShard = 10000;
@@ -117,6 +115,18 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
         graph_def_version_(ctx->graph_def_version()) {
     OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
     OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
+
+    OP_REQUIRES_OK(ctx,
+                   ctx->GetAttr("reader_path_prefix", &reader_path_prefix_));
+    OP_REQUIRES_OK(ctx,
+                   ctx->GetAttr("writer_path_prefix", &writer_path_prefix_));
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("compression", &compression_));
+
+    OP_REQUIRES(
+        ctx,
+        compression_ == io::compression::kNone ||
+            compression_ == io::compression::kGzip,
+        errors::InvalidArgument("compression must be either '' or 'GZIP'."));
   }

 protected:
@@ -138,18 +148,24 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
     string graph_fingerprint = strings::StrCat(
         strings::Hex(Fingerprint64(graph_def_serialized), strings::kZeroPad16));

-    *output = new Dataset(ctx, input, path, graph_fingerprint);
+    *output =
+        new Dataset(ctx, input, path, graph_fingerprint, reader_path_prefix_,
+                    writer_path_prefix_, compression_);
   }

 private:
  class Dataset : public DatasetBase {
   public:
    Dataset(OpKernelContext* ctx, const DatasetBase* input, const string& path,
-           const string& graph_fingerprint)
+           const string& graph_fingerprint, const string& reader_path_prefix,
+           const string& writer_path_prefix, const string& compression)
        : DatasetBase(DatasetContext(ctx)),
          input_(input),
          dir_(path),
-         graph_fingerprint_(graph_fingerprint) {
+         graph_fingerprint_(graph_fingerprint),
+         reader_path_prefix_(reader_path_prefix),
+         writer_path_prefix_(writer_path_prefix),
+         compression_(compression) {
      input_->Ref();
    }
@@ -179,9 +195,30 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
                              Node** output) const override {
      Node* input_graph_node = nullptr;
      TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));

      Node* path = nullptr;
      TF_RETURN_IF_ERROR(b->AddScalar(dir_, &path));
-     TF_RETURN_IF_ERROR(b->AddDataset(this, {input_graph_node, path}, output));
+
+     AttrValue compression_attr;
+     b->BuildAttrValue(compression_, &compression_attr);
+
+     AttrValue reader_path_prefix_attr;
+     b->BuildAttrValue(reader_path_prefix_, &reader_path_prefix_attr);
+
+     AttrValue writer_path_prefix_attr;
+     b->BuildAttrValue(writer_path_prefix_, &writer_path_prefix_attr);
+
+     TF_RETURN_IF_ERROR(b->AddDataset(
+         this,
+         /*inputs=*/
+         {std::make_pair(0, input_graph_node), std::make_pair(1, path)},
+         /*list_inputs=*/
+         {},
+         /*attrs=*/
+         {{"compression", compression_attr},
+          {"reader_path_prefix", reader_path_prefix_attr},
+          {"writer_path_prefix", writer_path_prefix_attr}},
+         output));
      return Status::OK();
    }
@@ -254,7 +291,8 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
        mutex_lock l(mu_);

        run_id_ = metadata_.run_id();
-       run_dir_ = absl::StrCat(fingerprint_dir_, "/", run_id_);
+       run_dir_ = absl::StrCat(dataset()->reader_path_prefix_,
+                               fingerprint_dir_, "/", run_id_);
        return Status::OK();
      }
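As an aside, a small sketch (not from this change) of how the run directory string is assembled once a prefix is set, mirroring the StrCat calls above; every concrete value here is hypothetical:

    # The prefix is prepended verbatim; StrCat adds no separator between the
    # prefix and the fingerprint directory.
    reader_path_prefix = "file://"                        # hypothetical
    fingerprint_dir = "/tmp/snapshots/0123456789abcdef"   # hypothetical
    run_id = "a1b2"                                       # hypothetical
    run_dir = reader_path_prefix + fingerprint_dir + "/" + run_id
    # run_dir == "file:///tmp/snapshots/0123456789abcdef/a1b2"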
@@ -284,7 +322,7 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
               snapshot_data_filename, &current_read_file_));
           auto reader_options =
               io::RecordReaderOptions::CreateRecordReaderOptions(
-                  kCompressionType);
+                  dataset()->compression_);
           reader_options.buffer_size = kReaderBufferSize;

           current_reader_ = absl::make_unique<io::SequentialRecordReader>(
@@ -348,7 +386,8 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {

           run_id_ = strings::StrCat(
               strings::Hex(random::New64(), strings::kZeroPad4));
-          run_dir_ = absl::StrCat(fingerprint_dir_, "/", run_id_);
+          run_dir_ = absl::StrCat(dataset()->writer_path_prefix_,
+                                  fingerprint_dir_, "/", run_id_);

           TF_RETURN_IF_ERROR(Env::Default()->RecursivelyCreateDir(run_dir_));
@@ -406,7 +445,7 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {

           auto writer_options =
               io::RecordWriterOptions::CreateRecordWriterOptions(
-                  kCompressionType);
+                  dataset()->compression_);

           TF_RETURN_IF_ERROR(Env::Default()->NewWritableFile(
               snapshot_data_filename, &current_write_file_));
@@ -473,11 +512,19 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
    const DatasetBase* const input_;
    const string dir_;
    const string graph_fingerprint_;
+
+   const string reader_path_prefix_;
+   const string writer_path_prefix_;
+   const string compression_;
  };

  const int graph_def_version_;
  DataTypeVector output_types_;
  std::vector<PartialTensorShape> output_shapes_;
+
+ string reader_path_prefix_;
+ string writer_path_prefix_;
+ string compression_;
};

REGISTER_KERNEL_BUILDER(Name("SnapshotDataset").Device(DEVICE_CPU),
@@ -370,6 +370,9 @@ REGISTER_OP("SnapshotDataset")
     .Output("handle: variant")
     .Attr("output_types: list(type) >= 1")
     .Attr("output_shapes: list(shape) >= 1")
+    .Attr("compression: string = ''")
+    .Attr("reader_path_prefix: string = ''")
+    .Attr("writer_path_prefix: string = ''")
     .SetShapeFn([](shape_inference::InferenceContext* c) {
       shape_inference::ShapeHandle unused;
       // snapshot_path should be a scalar.
@@ -42,7 +42,8 @@ class SnapshotDatasetBenchmark(benchmark_base.DatasetBenchmarkBase):
     os.mkdir(tmp_dir)
     return tmp_dir

-  def _createSimpleDataset(self, num_elems, tmp_dir=None):
+  def _createSimpleDataset(self, num_elems, tmp_dir=None,
+                           compression=snapshot.COMPRESSION_NONE):
     if not tmp_dir:
       tmp_dir = self._makeSnapshotDirectory()

@@ -50,7 +51,7 @@ class SnapshotDatasetBenchmark(benchmark_base.DatasetBenchmarkBase):
     dataset = dataset.map(
         lambda x: gen_array_ops.broadcast_to(x, [50, 50, 3]))
     dataset = dataset.repeat(num_elems)
-    dataset = dataset.apply(snapshot.snapshot(tmp_dir))
+    dataset = dataset.apply(snapshot.snapshot(tmp_dir, compression=compression))

     return dataset

@@ -63,6 +64,14 @@ class SnapshotDatasetBenchmark(benchmark_base.DatasetBenchmarkBase):
       except errors.OutOfRangeError:
         pass

+  def benchmarkWriteSnapshotGzipCompression(self):
+    num_elems = 500000
+    dataset = self._createSimpleDataset(
+        num_elems, compression=snapshot.COMPRESSION_GZIP)
+
+    self.run_and_report_benchmark(dataset, num_elems, "write_gzip",
+                                  warmup=False, iters=1)
+
   def benchmarkWriteSnapshotSimple(self):
     num_elems = 500000
     dataset = self._createSimpleDataset(num_elems)

@@ -93,6 +102,15 @@ class SnapshotDatasetBenchmark(benchmark_base.DatasetBenchmarkBase):

     self.run_and_report_benchmark(dataset, num_elems, "read_simple")

+  def benchmarkReadSnapshotGzipCompression(self):
+    num_elems = 100000
+    tmp_dir = self._makeSnapshotDirectory()
+    dataset = self._createSimpleDataset(
+        num_elems, tmp_dir, compression=snapshot.COMPRESSION_GZIP)
+
+    self._consumeDataset(dataset, num_elems)
+    self.run_and_report_benchmark(dataset, num_elems, "read_gzip")
+

 if __name__ == "__main__":
   test.main()
@@ -18,6 +18,7 @@ from __future__ import division
 from __future__ import print_function

 import os
+from absl.testing import parameterized

 from tensorflow.python.data.experimental.kernel_tests import reader_dataset_ops_test_base
 from tensorflow.python.data.experimental.ops import snapshot

@@ -29,7 +30,8 @@ from tensorflow.python.platform import test


 @test_util.run_all_in_graph_and_eager_modes
-class SnapshotDatasetTest(reader_dataset_ops_test_base.TFRecordDatasetTestBase):
+class SnapshotDatasetTest(reader_dataset_ops_test_base.TFRecordDatasetTestBase,
+                          parameterized.TestCase):

   def setUp(self):
     super(SnapshotDatasetTest, self).setUp()

@@ -54,19 +56,19 @@ class SnapshotDatasetTest(reader_dataset_ops_test_base.TFRecordDatasetTestBase):
   def assertSnapshotDirectoryContains(
       self, directory, num_fingerprints, num_runs_per_fp, num_snapshot_files):
     dirlist = os.listdir(directory)
-    self.assertEqual(len(dirlist), num_fingerprints)
+    self.assertLen(dirlist, num_fingerprints)

     for i in range(num_fingerprints):
       fingerprint_dir = os.path.join(directory, dirlist[i])
       fingerprint_dir_list = sorted(os.listdir(fingerprint_dir))
-      self.assertEqual(len(fingerprint_dir_list), num_runs_per_fp + 1)
+      self.assertLen(fingerprint_dir_list, num_runs_per_fp + 1)
       self.assertEqual(fingerprint_dir_list[num_runs_per_fp],
                        "snapshot.metadata")

       for j in range(num_runs_per_fp):
         run_dir = os.path.join(fingerprint_dir, fingerprint_dir_list[j])
         run_dirlist = sorted(os.listdir(run_dir))
-        self.assertEqual(len(run_dirlist), num_snapshot_files)
+        self.assertLen(run_dirlist, num_snapshot_files)

         file_counter = 0
         for filename in run_dirlist:
@@ -105,11 +107,13 @@ class SnapshotDatasetTest(reader_dataset_ops_test_base.TFRecordDatasetTestBase):
     # one that lost the race would be in passthrough mode.
     self.assertSnapshotDirectoryContains(tmpdir, 1, 1, 1)

-  def testWriteSnapshotSimpleSuccessful(self):
+  @parameterized.parameters(snapshot.COMPRESSION_NONE,
+                            snapshot.COMPRESSION_GZIP)
+  def testWriteSnapshotSimpleSuccessful(self, compression):
     tmpdir = self.makeSnapshotDirectory()

     dataset = dataset_ops.Dataset.range(1000)
-    dataset = dataset.apply(snapshot.snapshot(tmpdir))
+    dataset = dataset.apply(snapshot.snapshot(tmpdir, compression=compression))
     self.assertDatasetProduces(dataset, list(range(1000)))

     self.assertSnapshotDirectoryContains(tmpdir, 1, 1, 1)

@@ -133,7 +137,9 @@ class SnapshotDatasetTest(reader_dataset_ops_test_base.TFRecordDatasetTestBase):

     self.assertSnapshotDirectoryContains(tmpdir, 1, 1, 2)

-  def testReadSnapshotBackAfterWrite(self):
+  @parameterized.parameters(snapshot.COMPRESSION_NONE,
+                            snapshot.COMPRESSION_GZIP)
+  def testReadSnapshotBackAfterWrite(self, compression):
     self.setUpTFRecord()
     filenames = self.test_filenames

@@ -145,14 +151,15 @@ class SnapshotDatasetTest(reader_dataset_ops_test_base.TFRecordDatasetTestBase):

     tmpdir = self.makeSnapshotDirectory()
     dataset = core_readers._TFRecordDataset(filenames)
-    dataset = dataset.apply(snapshot.snapshot(tmpdir))
+    dataset = dataset.apply(snapshot.snapshot(tmpdir, compression=compression))
     self.assertDatasetProduces(dataset, expected)

     # remove the original files and try to read the data back only from snapshot
     self.removeTFRecords()

     dataset2 = core_readers._TFRecordDataset(filenames)
-    dataset2 = dataset2.apply(snapshot.snapshot(tmpdir))
+    dataset2 = dataset2.apply(snapshot.snapshot(
+        tmpdir, compression=compression))
     self.assertDatasetProduces(dataset2, expected)

   def testAdditionalOperationsAfterReadBack(self):
@@ -23,21 +23,43 @@ from tensorflow.python.framework import ops
 from tensorflow.python.ops import gen_experimental_dataset_ops as ged_ops


+COMPRESSION_GZIP = "GZIP"
+COMPRESSION_NONE = None
+
+
 class _SnapshotDataset(dataset_ops.UnaryUnchangedStructureDataset):
   """A Dataset that captures a snapshot or reads from a snapshot."""

-  def __init__(self, input_dataset, path):
+  def __init__(self,
+               input_dataset,
+               path,
+               compression=None,
+               reader_path_prefix=None,
+               writer_path_prefix=None):
+
+    self._compression = compression if compression is not None else ""
+    self._reader_path_prefix = (
+        reader_path_prefix if reader_path_prefix is not None else "")
+    self._writer_path_prefix = (
+        writer_path_prefix if writer_path_prefix is not None else "")
+
     self._input_dataset = input_dataset
     self._path = ops.convert_to_tensor(path, dtype=dtypes.string, name="path")

     variant_tensor = ged_ops.snapshot_dataset(
         self._input_dataset._variant_tensor,  # pylint: disable=protected-access
         path=self._path,
+        compression=self._compression,
+        reader_path_prefix=self._reader_path_prefix,
+        writer_path_prefix=self._writer_path_prefix,
         **dataset_ops.flat_structure(self))
     super(_SnapshotDataset, self).__init__(input_dataset, variant_tensor)


-def snapshot(path):
+def snapshot(path,
+             compression=None,
+             reader_path_prefix=None,
+             writer_path_prefix=None):
   """Writes to/reads from a snapshot of a dataset.

   This function attempts to determine whether a valid snapshot exists at the

@@ -48,6 +70,12 @@ def snapshot(path):
   Args:
     path: A directory where we want to save our snapshots and/or read from a
       previously saved snapshot.
+    compression: The type of compression to apply to the Dataset. Currently
+      supports "GZIP" or None. Defaults to None (no compression).
+    reader_path_prefix: A prefix to add to the path when reading from snapshots.
+      Defaults to None.
+    writer_path_prefix: A prefix to add to the path when writing to snapshots.
+      Defaults to None.

   Returns:
     A `Dataset` transformation function, which can be passed to

@@ -55,6 +83,7 @@ def snapshot(path):
   """

   def _apply_fn(dataset):
-    return _SnapshotDataset(dataset, path)
+    return _SnapshotDataset(dataset, path, compression, reader_path_prefix,
+                            writer_path_prefix)

   return _apply_fn
@@ -3474,7 +3474,7 @@ tf_module {
   }
   member_method {
     name: "SnapshotDataset"
-    argspec: "args=[\'input_dataset\', \'path\', \'output_types\', \'output_shapes\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+    argspec: "args=[\'input_dataset\', \'path\', \'output_types\', \'output_shapes\', \'compression\', \'reader_path_prefix\', \'writer_path_prefix\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'\', \'\', \'None\'], "
   }
   member_method {
     name: "Softmax"

@@ -3474,7 +3474,7 @@ tf_module {
   }
   member_method {
     name: "SnapshotDataset"
-    argspec: "args=[\'input_dataset\', \'path\', \'output_types\', \'output_shapes\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+    argspec: "args=[\'input_dataset\', \'path\', \'output_types\', \'output_shapes\', \'compression\', \'reader_path_prefix\', \'writer_path_prefix\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'\', \'\', \'None\'], "
   }
   member_method {
     name: "Softmax"