Add compression, reader, and writer path prefixes to the snapshot dataset.

PiperOrigin-RevId: 249114591
Frank Chen 2019-05-20 13:31:38 -07:00 committed by TensorFlower Gardener
parent 4dd726e547
commit 273981699d
7 changed files with 130 additions and 26 deletions
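For context, a minimal usage sketch of the options this commit introduces, using the `snapshot` transformation and `COMPRESSION_GZIP` constant added in the Python diff below. This is an illustration, not part of the commit; the directory path is a placeholder.

    # Sketch only: "/tmp/snapshots" is a placeholder directory.
    import tensorflow as tf
    from tensorflow.python.data.experimental.ops import snapshot

    dataset = tf.data.Dataset.range(1000)
    dataset = dataset.apply(
        snapshot.snapshot(
            "/tmp/snapshots",                       # where snapshots are written/read
            compression=snapshot.COMPRESSION_GZIP,  # None (no compression) or "GZIP"
            reader_path_prefix=None,                # optional prefix used when reading
            writer_path_prefix=None))               # optional prefix used when writing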


@@ -38,8 +38,6 @@ enum SnapshotMode { READER = 0, WRITER = 1, PASSTHROUGH = 2 };
const uint64 kReaderBufferSize = 8 * 1024 * 1024; // 8 MB
const char* kCompressionType = io::compression::kGzip;
const uint64 kOneDayInMicroseconds = 24L * 60L * 60L * 1e6L;
const uint64 kNumElementsPerShard = 10000;
@@ -117,6 +115,18 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
graph_def_version_(ctx->graph_def_version()) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
OP_REQUIRES_OK(ctx,
ctx->GetAttr("reader_path_prefix", &reader_path_prefix_));
OP_REQUIRES_OK(ctx,
ctx->GetAttr("writer_path_prefix", &writer_path_prefix_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("compression", &compression_));
OP_REQUIRES(
ctx,
compression_ == io::compression::kNone ||
compression_ == io::compression::kGzip,
errors::InvalidArgument("compression must be either '' or 'GZIP'."));
}
protected:
@@ -138,18 +148,24 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
string graph_fingerprint = strings::StrCat(
strings::Hex(Fingerprint64(graph_def_serialized), strings::kZeroPad16));
*output = new Dataset(ctx, input, path, graph_fingerprint);
*output =
new Dataset(ctx, input, path, graph_fingerprint, reader_path_prefix_,
writer_path_prefix_, compression_);
}
private:
class Dataset : public DatasetBase {
public:
Dataset(OpKernelContext* ctx, const DatasetBase* input, const string& path,
const string& graph_fingerprint)
const string& graph_fingerprint, const string& reader_path_prefix,
const string& writer_path_prefix, const string& compression)
: DatasetBase(DatasetContext(ctx)),
input_(input),
dir_(path),
graph_fingerprint_(graph_fingerprint) {
graph_fingerprint_(graph_fingerprint),
reader_path_prefix_(reader_path_prefix),
writer_path_prefix_(writer_path_prefix),
compression_(compression) {
input_->Ref();
}
@@ -179,9 +195,30 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
Node** output) const override {
Node* input_graph_node = nullptr;
TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));
Node* path = nullptr;
TF_RETURN_IF_ERROR(b->AddScalar(dir_, &path));
TF_RETURN_IF_ERROR(b->AddDataset(this, {input_graph_node, path}, output));
AttrValue compression_attr;
b->BuildAttrValue(compression_, &compression_attr);
AttrValue reader_path_prefix_attr;
b->BuildAttrValue(reader_path_prefix_, &reader_path_prefix_attr);
AttrValue writer_path_prefix_attr;
b->BuildAttrValue(writer_path_prefix_, &writer_path_prefix_attr);
TF_RETURN_IF_ERROR(b->AddDataset(
this,
/*inputs=*/
{std::make_pair(0, input_graph_node), std::make_pair(1, path)},
/*list_inputs=*/
{},
/*attrs=*/
{{"compression", compression_attr},
{"reader_path_prefix", reader_path_prefix_attr},
{"writer_path_prefix", writer_path_prefix_attr}},
output));
return Status::OK();
}
@@ -254,7 +291,8 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
mutex_lock l(mu_);
run_id_ = metadata_.run_id();
run_dir_ = absl::StrCat(fingerprint_dir_, "/", run_id_);
run_dir_ = absl::StrCat(dataset()->reader_path_prefix_,
fingerprint_dir_, "/", run_id_);
return Status::OK();
}
@@ -284,7 +322,7 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
snapshot_data_filename, &current_read_file_));
auto reader_options =
io::RecordReaderOptions::CreateRecordReaderOptions(
kCompressionType);
dataset()->compression_);
reader_options.buffer_size = kReaderBufferSize;
current_reader_ = absl::make_unique<io::SequentialRecordReader>(
@@ -348,7 +386,8 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
run_id_ = strings::StrCat(
strings::Hex(random::New64(), strings::kZeroPad4));
run_dir_ = absl::StrCat(fingerprint_dir_, "/", run_id_);
run_dir_ = absl::StrCat(dataset()->writer_path_prefix_,
fingerprint_dir_, "/", run_id_);
TF_RETURN_IF_ERROR(Env::Default()->RecursivelyCreateDir(run_dir_));
@@ -406,7 +445,7 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
auto writer_options =
io::RecordWriterOptions::CreateRecordWriterOptions(
kCompressionType);
dataset()->compression_);
TF_RETURN_IF_ERROR(Env::Default()->NewWritableFile(
snapshot_data_filename, &current_write_file_));
@@ -473,11 +512,19 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
const DatasetBase* const input_;
const string dir_;
const string graph_fingerprint_;
const string reader_path_prefix_;
const string writer_path_prefix_;
const string compression_;
};
const int graph_def_version_;
DataTypeVector output_types_;
std::vector<PartialTensorShape> output_shapes_;
string reader_path_prefix_;
string writer_path_prefix_;
string compression_;
};
REGISTER_KERNEL_BUILDER(Name("SnapshotDataset").Device(DEVICE_CPU),


@@ -370,6 +370,9 @@ REGISTER_OP("SnapshotDataset")
.Output("handle: variant")
.Attr("output_types: list(type) >= 1")
.Attr("output_shapes: list(shape) >= 1")
.Attr("compression: string = ''")
.Attr("reader_path_prefix: string = ''")
.Attr("writer_path_prefix: string = ''")
.SetShapeFn([](shape_inference::InferenceContext* c) {
shape_inference::ShapeHandle unused;
// snapshot_path should be a scalar.


@@ -42,7 +42,8 @@ class SnapshotDatasetBenchmark(benchmark_base.DatasetBenchmarkBase):
os.mkdir(tmp_dir)
return tmp_dir
def _createSimpleDataset(self, num_elems, tmp_dir=None):
def _createSimpleDataset(self, num_elems, tmp_dir=None,
compression=snapshot.COMPRESSION_NONE):
if not tmp_dir:
tmp_dir = self._makeSnapshotDirectory()
@@ -50,7 +51,7 @@ class SnapshotDatasetBenchmark(benchmark_base.DatasetBenchmarkBase):
dataset = dataset.map(
lambda x: gen_array_ops.broadcast_to(x, [50, 50, 3]))
dataset = dataset.repeat(num_elems)
dataset = dataset.apply(snapshot.snapshot(tmp_dir))
dataset = dataset.apply(snapshot.snapshot(tmp_dir, compression=compression))
return dataset
@@ -63,6 +64,14 @@ class SnapshotDatasetBenchmark(benchmark_base.DatasetBenchmarkBase):
except errors.OutOfRangeError:
pass
def benchmarkWriteSnapshotGzipCompression(self):
num_elems = 500000
dataset = self._createSimpleDataset(
num_elems, compression=snapshot.COMPRESSION_GZIP)
self.run_and_report_benchmark(dataset, num_elems, "write_gzip",
warmup=False, iters=1)
def benchmarkWriteSnapshotSimple(self):
num_elems = 500000
dataset = self._createSimpleDataset(num_elems)
@@ -93,6 +102,15 @@ class SnapshotDatasetBenchmark(benchmark_base.DatasetBenchmarkBase):
self.run_and_report_benchmark(dataset, num_elems, "read_simple")
def benchmarkReadSnapshotGzipCompression(self):
num_elems = 100000
tmp_dir = self._makeSnapshotDirectory()
dataset = self._createSimpleDataset(
num_elems, tmp_dir, compression=snapshot.COMPRESSION_GZIP)
self._consumeDataset(dataset, num_elems)
self.run_and_report_benchmark(dataset, num_elems, "read_gzip")
if __name__ == "__main__":
test.main()


@@ -18,6 +18,7 @@ from __future__ import division
from __future__ import print_function
import os
from absl.testing import parameterized
from tensorflow.python.data.experimental.kernel_tests import reader_dataset_ops_test_base
from tensorflow.python.data.experimental.ops import snapshot
@@ -29,7 +30,8 @@ from tensorflow.python.platform import test
@test_util.run_all_in_graph_and_eager_modes
class SnapshotDatasetTest(reader_dataset_ops_test_base.TFRecordDatasetTestBase):
class SnapshotDatasetTest(reader_dataset_ops_test_base.TFRecordDatasetTestBase,
parameterized.TestCase):
def setUp(self):
super(SnapshotDatasetTest, self).setUp()
@@ -54,19 +56,19 @@ class SnapshotDatasetTest(reader_dataset_ops_test_base.TFRecordDatasetTestBase):
def assertSnapshotDirectoryContains(
self, directory, num_fingerprints, num_runs_per_fp, num_snapshot_files):
dirlist = os.listdir(directory)
self.assertEqual(len(dirlist), num_fingerprints)
self.assertLen(dirlist, num_fingerprints)
for i in range(num_fingerprints):
fingerprint_dir = os.path.join(directory, dirlist[i])
fingerprint_dir_list = sorted(os.listdir(fingerprint_dir))
self.assertEqual(len(fingerprint_dir_list), num_runs_per_fp + 1)
self.assertLen(fingerprint_dir_list, num_runs_per_fp + 1)
self.assertEqual(fingerprint_dir_list[num_runs_per_fp],
"snapshot.metadata")
for j in range(num_runs_per_fp):
run_dir = os.path.join(fingerprint_dir, fingerprint_dir_list[j])
run_dirlist = sorted(os.listdir(run_dir))
self.assertEqual(len(run_dirlist), num_snapshot_files)
self.assertLen(run_dirlist, num_snapshot_files)
file_counter = 0
for filename in run_dirlist:
@@ -105,11 +107,13 @@ class SnapshotDatasetTest(reader_dataset_ops_test_base.TFRecordDatasetTestBase):
# one that lost the race would be in passthrough mode.
self.assertSnapshotDirectoryContains(tmpdir, 1, 1, 1)
def testWriteSnapshotSimpleSuccessful(self):
@parameterized.parameters(snapshot.COMPRESSION_NONE,
snapshot.COMPRESSION_GZIP)
def testWriteSnapshotSimpleSuccessful(self, compression):
tmpdir = self.makeSnapshotDirectory()
dataset = dataset_ops.Dataset.range(1000)
dataset = dataset.apply(snapshot.snapshot(tmpdir))
dataset = dataset.apply(snapshot.snapshot(tmpdir, compression=compression))
self.assertDatasetProduces(dataset, list(range(1000)))
self.assertSnapshotDirectoryContains(tmpdir, 1, 1, 1)
@@ -133,7 +137,9 @@ class SnapshotDatasetTest(reader_dataset_ops_test_base.TFRecordDatasetTestBase):
self.assertSnapshotDirectoryContains(tmpdir, 1, 1, 2)
def testReadSnapshotBackAfterWrite(self):
@parameterized.parameters(snapshot.COMPRESSION_NONE,
snapshot.COMPRESSION_GZIP)
def testReadSnapshotBackAfterWrite(self, compression):
self.setUpTFRecord()
filenames = self.test_filenames
@@ -145,14 +151,15 @@ class SnapshotDatasetTest(reader_dataset_ops_test_base.TFRecordDatasetTestBase):
tmpdir = self.makeSnapshotDirectory()
dataset = core_readers._TFRecordDataset(filenames)
dataset = dataset.apply(snapshot.snapshot(tmpdir))
dataset = dataset.apply(snapshot.snapshot(tmpdir, compression=compression))
self.assertDatasetProduces(dataset, expected)
# remove the original files and try to read the data back only from snapshot
self.removeTFRecords()
dataset2 = core_readers._TFRecordDataset(filenames)
dataset2 = dataset2.apply(snapshot.snapshot(tmpdir))
dataset2 = dataset2.apply(snapshot.snapshot(
tmpdir, compression=compression))
self.assertDatasetProduces(dataset2, expected)
def testAdditionalOperationsAfterReadBack(self):


@@ -23,21 +23,43 @@ from tensorflow.python.framework import ops
from tensorflow.python.ops import gen_experimental_dataset_ops as ged_ops
COMPRESSION_GZIP = "GZIP"
COMPRESSION_NONE = None
class _SnapshotDataset(dataset_ops.UnaryUnchangedStructureDataset):
"""A Dataset that captures a snapshot or reads from a snapshot."""
def __init__(self, input_dataset, path):
def __init__(self,
input_dataset,
path,
compression=None,
reader_path_prefix=None,
writer_path_prefix=None):
self._compression = compression if compression is not None else ""
self._reader_path_prefix = (
reader_path_prefix if reader_path_prefix is not None else "")
self._writer_path_prefix = (
writer_path_prefix if writer_path_prefix is not None else "")
self._input_dataset = input_dataset
self._path = ops.convert_to_tensor(path, dtype=dtypes.string, name="path")
variant_tensor = ged_ops.snapshot_dataset(
self._input_dataset._variant_tensor, # pylint: disable=protected-access
path=self._path,
compression=self._compression,
reader_path_prefix=self._reader_path_prefix,
writer_path_prefix=self._writer_path_prefix,
**dataset_ops.flat_structure(self))
super(_SnapshotDataset, self).__init__(input_dataset, variant_tensor)
def snapshot(path):
def snapshot(path,
compression=None,
reader_path_prefix=None,
writer_path_prefix=None):
"""Writes to/reads from a snapshot of a dataset.
This function attempts to determine whether a valid snapshot exists at the
@@ -48,6 +70,12 @@ def snapshot(path):
Args:
path: A directory where we want to save our snapshots and/or read from a
previously saved snapshot.
compression: The type of compression to apply to the Dataset. Currently
supports "GZIP" or None. Defaults to None (no compression).
reader_path_prefix: A prefix to add to the path when reading from snapshots.
Defaults to None.
writer_path_prefix: A prefix to add to the path when writing to snapshots.
Defaults to None.
Returns:
A `Dataset` transformation function, which can be passed to
@@ -55,6 +83,7 @@ def snapshot(path):
"""
def _apply_fn(dataset):
return _SnapshotDataset(dataset, path)
return _SnapshotDataset(dataset, path, compression, reader_path_prefix,
writer_path_prefix)
return _apply_fn
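
As a rough illustration of how the new prefixes affect the on-disk layout: the kernel changes above build the run directory with StrCat(prefix, fingerprint_dir, "/", run_id), i.e. the prefix is plain string concatenation, not a path join. A hedged Python sketch, assuming fingerprint_dir is the snapshot path joined with the graph fingerprint as in the surrounding kernel code; all concrete values are hypothetical:

    # Illustration only; mirrors the StrCat() calls in the kernel diff above.
    path = "/tmp/snapshots"
    graph_fingerprint = "0123456789abcdef"   # Fingerprint64 of the serialized input graph
    run_id = "a1b2"                          # random hex run id chosen by the writer
    writer_path_prefix = "gs://bucket"       # e.g. redirect writes to another filesystem

    fingerprint_dir = path + "/" + graph_fingerprint
    run_dir = writer_path_prefix + fingerprint_dir + "/" + run_id
    print(run_dir)  # gs://bucket/tmp/snapshots/0123456789abcdef/a1b2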


@@ -3474,7 +3474,7 @@ tf_module {
}
member_method {
name: "SnapshotDataset"
argspec: "args=[\'input_dataset\', \'path\', \'output_types\', \'output_shapes\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
argspec: "args=[\'input_dataset\', \'path\', \'output_types\', \'output_shapes\', \'compression\', \'reader_path_prefix\', \'writer_path_prefix\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'\', \'\', \'None\'], "
}
member_method {
name: "Softmax"


@@ -3474,7 +3474,7 @@ tf_module {
}
member_method {
name: "SnapshotDataset"
argspec: "args=[\'input_dataset\', \'path\', \'output_types\', \'output_shapes\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
argspec: "args=[\'input_dataset\', \'path\', \'output_types\', \'output_shapes\', \'compression\', \'reader_path_prefix\', \'writer_path_prefix\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'\', \'\', \'None\'], "
}
member_method {
name: "Softmax"