Add reader_prefix and writer_prefix to SnapshotDatasetV2Op

PiperOrigin-RevId: 334901918
Change-Id: I1c66546d06ee158479fcc8cd89ab35ae5a4f9558
Frank Chen 2020-10-01 14:50:36 -07:00 committed by TensorFlower Gardener
parent 741f2c3a09
commit 16249195f1
6 changed files with 69 additions and 28 deletions
tensorflow/core/kernels/data/experimental/snapshot_dataset_op.cc

@@ -108,6 +108,7 @@ class SnapshotDatasetV2Op::Dataset : public DatasetBase {
  public:
   Dataset(OpKernelContext* ctx, const DatasetBase* input, uint64 hash,
           const std::string& path, const std::string& compression,
+          const std::string& reader_prefix, const std::string& writer_prefix,
           std::unique_ptr<CapturedFunction> reader_func,
           std::unique_ptr<CapturedFunction> shard_func);
@@ -138,6 +139,8 @@ class SnapshotDatasetV2Op::Dataset : public DatasetBase {
   const uint64 hash_;
   const tstring path_;
   const std::string compression_;
+  const std::string reader_prefix_;
+  const std::string writer_prefix_;
   std::unique_ptr<CapturedFunction> reader_func_;
   std::unique_ptr<CapturedFunction> shard_func_;
@@ -294,6 +297,7 @@ class SnapshotDatasetV2Op::Dataset::Iterator::Passthrough
 SnapshotDatasetV2Op::Dataset::Dataset(
     OpKernelContext* ctx, const DatasetBase* input, uint64 hash,
     const std::string& path, const std::string& compression,
+    const std::string& reader_prefix, const std::string& writer_prefix,
     std::unique_ptr<CapturedFunction> reader_func,
     std::unique_ptr<CapturedFunction> shard_func)
     : DatasetBase(DatasetContext(ctx)),
@@ -302,6 +306,8 @@ SnapshotDatasetV2Op::Dataset::Dataset(
       path_(path),
       compression_(compression == kCompressionAuto ? io::compression::kSnappy
                                                    : compression),
+      reader_prefix_(reader_prefix),
+      writer_prefix_(writer_prefix),
       reader_func_(std::move(reader_func)),
       shard_func_(std::move(shard_func)) {
   input_->Ref();
@@ -363,6 +369,12 @@ Status SnapshotDatasetV2Op::Dataset::AsGraphDefInternal(
   AttrValue compression_attr;
   b->BuildAttrValue(compression_, &compression_attr);
 
+  AttrValue reader_prefix_attr;
+  b->BuildAttrValue(reader_prefix_, &reader_prefix_attr);
+
+  AttrValue writer_prefix_attr;
+  b->BuildAttrValue(writer_prefix_, &writer_prefix_attr);
+
   AttrValue reader_func_attr;
   b->BuildAttrValue(reader_func_->func(), &reader_func_attr);
@@ -386,6 +398,8 @@ Status SnapshotDatasetV2Op::Dataset::AsGraphDefInternal(
        std::make_pair(3, shard_func_other_args)},
       /*attrs=*/
       {{kCompression, compression_attr},
+       {kReaderPrefix, reader_prefix_attr},
+       {kWriterPrefix, writer_prefix_attr},
        {kReaderFunc, reader_func_attr},
        {kShardFunc, shard_func_attr},
        {kReaderFuncTarguments, reader_func_arguments_types_attr},
@@ -401,7 +415,8 @@ SnapshotDatasetV2Op::Dataset::Iterator::Iterator(const Params& params)
 Status SnapshotDatasetV2Op::Dataset::Iterator::Initialize(
     IteratorContext* ctx) {
-  return ctx->env()->RecursivelyCreateDir(hash_dir_);
+  return ctx->env()->RecursivelyCreateDir(
+      io::JoinPath(dataset()->writer_prefix_, hash_dir_));
 }
 
 Status SnapshotDatasetV2Op::Dataset::Iterator::SaveInternal(
@@ -460,7 +475,8 @@ Status SnapshotDatasetV2Op::Dataset::Iterator::InitializeIterator(
       experimental::SnapshotMetadataRecord metadata;
       bool file_exists;
       TF_RETURN_IF_ERROR(snapshot_util::ReadMetadataFile(
-          ctx->env(), hash_dir_, &metadata, &file_exists));
+          ctx->env(), io::JoinPath(dataset()->reader_prefix_, hash_dir_),
+          &metadata, &file_exists));
       if (!file_exists) {
         return errors::DataLoss("Snapshot metadata file in ", hash_dir_,
                                 " does not exist any more.");
@@ -476,7 +492,8 @@ Status SnapshotDatasetV2Op::Dataset::Iterator::InitializeIterator(
       experimental::SnapshotMetadataRecord metadata;
       bool file_exists;
       TF_RETURN_IF_ERROR(snapshot_util::ReadMetadataFile(
-          ctx->env(), hash_dir_, &metadata, &file_exists));
+          ctx->env(), io::JoinPath(dataset()->reader_prefix_, hash_dir_),
+          &metadata, &file_exists));
 
       // `pending_snapshot_expiry_seconds` is a legacy option where we would not
       // write snapshots that we think were still on-going. We decided that this
@@ -522,8 +539,9 @@ Status SnapshotDatasetV2Op::Dataset::Iterator::Reader::Initialize(
   TF_RETURN_IF_ERROR(
       dataset()->reader_func_->Instantiate(ctx, &instantiated_reader_func_));
 
-  auto hash_dir =
-      snapshot_util::HashDirectory(dataset()->path_, dataset()->hash_);
+  auto hash_dir = snapshot_util::HashDirectory(
+      io::JoinPath(dataset()->reader_prefix_, dataset()->path_),
+      dataset()->hash_);
   bool metadata_file_exists;
   experimental::SnapshotMetadataRecord metadata;
   TF_RETURN_IF_ERROR(snapshot_util::ReadMetadataFile(
@@ -624,8 +642,9 @@ Status SnapshotDatasetV2Op::Dataset::Iterator::Writer::WriteMetadataFile(
     metadata.add_dtype(output_dtype);
   }
   metadata.set_finalized(finalized);
-  tstring hash_directory =
-      snapshot_util::HashDirectory(dataset()->path_, dataset()->hash_);
+  tstring hash_directory = io::JoinPath(
+      dataset()->writer_prefix_,
+      snapshot_util::HashDirectory(dataset()->path_, dataset()->hash_));
   return snapshot_util::WriteMetadataFile(env, hash_directory, &metadata);
 }
@@ -678,7 +697,9 @@ Status SnapshotDatasetV2Op::Dataset::Iterator::Writer::GetNextInternal(
       // Creates the run directory.
       run_dir_ = snapshot_util::RunDirectory(
-          snapshot_util::HashDirectory(dataset()->path_, dataset()->hash_),
+          snapshot_util::HashDirectory(
+              io::JoinPath(dataset()->writer_prefix_, dataset()->path_),
+              dataset()->hash_),
           run_id_);
       TF_RETURN_IF_ERROR(ctx->env()->RecursivelyCreateDir(run_dir_));
       TF_RETURN_IF_ERROR(WriteMetadataFile(ctx->env(), /*finalized=*/false));
@@ -757,7 +778,9 @@ Status SnapshotDatasetV2Op::Dataset::Iterator::Writer::RestoreInternal(
     run_id_ = static_cast<uint64>(run_id_signed);
     run_dir_ = snapshot_util::RunDirectory(
-        snapshot_util::HashDirectory(dataset()->path_, dataset()->hash_),
+        snapshot_util::HashDirectory(
+            io::JoinPath(dataset()->writer_prefix_, dataset()->path_),
+            dataset()->hash_),
         run_id_);
     current_checkpoint_id_ = static_cast<uint64>(current_checkpoint_id);
@@ -798,6 +821,14 @@ SnapshotDatasetV2Op::SnapshotDatasetV2Op(OpKernelConstruction* ctx)
   OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputShapes, &output_shapes_));
   OP_REQUIRES_OK(ctx, ctx->GetAttr(kCompression, &compression_));
 
+  if (ctx->HasAttr(kReaderPrefix)) {
+    OP_REQUIRES_OK(ctx, ctx->GetAttr(kReaderPrefix, &reader_prefix_));
+  }
+
+  if (ctx->HasAttr(kWriterPrefix)) {
+    OP_REQUIRES_OK(ctx, ctx->GetAttr(kWriterPrefix, &writer_prefix_));
+  }
+
   OP_REQUIRES_OK(ctx, FunctionMetadata::Create(ctx, kReaderFunc, reader_params,
                                                &reader_func_metadata_));
   OP_REQUIRES_OK(ctx, FunctionMetadata::Create(ctx, kShardFunc, shard_params,
@@ -834,8 +865,8 @@ void SnapshotDatasetV2Op::MakeDataset(OpKernelContext* ctx, DatasetBase* input,
                                           kShardFuncOtherArgs, &shard_func));
 
   *output = new SnapshotDatasetV2Op::Dataset(
-      ctx, input, graph_hash, path, compression_, std::move(reader_func),
-      std::move(shard_func));
+      ctx, input, graph_hash, path, compression_, reader_prefix_,
+      writer_prefix_, std::move(reader_func), std::move(shard_func));
 }
 
 namespace {
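
Note on semantics: every write path in the kernel above now goes through io::JoinPath(writer_prefix_, ...) and every read path through io::JoinPath(reader_prefix_, ...), so a snapshot can be written via one filesystem prefix and read back via another. Below is a minimal Python sketch of the intended path composition, assuming (as in TensorFlow's path utilities) that JoinPath skips empty components; the join_path helper and the hash-directory value are made up for illustration and are not the TensorFlow implementation.

def join_path(*parts):
  # Illustrative stand-in for tensorflow's io::JoinPath: empty parts are
  # skipped, and adjacent parts are joined with exactly one "/".
  result = ""
  for part in parts:
    if not part:
      continue  # the default prefix "" contributes nothing
    result = part if not result else result.rstrip("/") + "/" + part.lstrip("/")
  return result

# Default attrs: behavior is unchanged from before this commit.
assert join_path("", "/tmp/snapshot/9182e1d32fbb62e9") == "/tmp/snapshot/9182e1d32fbb62e9"
# A non-empty writer_prefix re-roots every write, e.g. onto a local SSD mount.
assert join_path("/mnt/local_ssd", "/tmp/snapshot/9182e1d32fbb62e9") == (
    "/mnt/local_ssd/tmp/snapshot/9182e1d32fbb62e9")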

tensorflow/core/kernels/data/experimental/snapshot_dataset_op.h

@@ -44,6 +44,8 @@ class SnapshotDatasetV2Op : public UnaryDatasetOpKernel {
   static constexpr const char* const kOutputTypes = "output_types";
   static constexpr const char* const kOutputShapes = "output_shapes";
   static constexpr const char* const kCompression = "compression";
+  static constexpr const char* const kReaderPrefix = "reader_prefix";
+  static constexpr const char* const kWriterPrefix = "writer_prefix";
   static constexpr const char* const kCompressionAuto = "AUTO";
   static constexpr const char* const kReaderFunc = "reader_func";
   static constexpr const char* const kShardFunc = "shard_func";
@@ -71,6 +73,8 @@ class SnapshotDatasetV2Op : public UnaryDatasetOpKernel {
   std::vector<PartialTensorShape> output_shapes_;
   std::string compression_;
+  std::string reader_prefix_;
+  std::string writer_prefix_;
   std::shared_ptr<FunctionMetadata> reader_func_metadata_;
   std::shared_ptr<FunctionMetadata> shard_func_metadata_;

tensorflow/core/ops/dataset_ops.cc

@@ -966,6 +966,8 @@ REGISTER_OP("SnapshotDatasetV2")
     .Attr("output_types: list(type) >= 1")
     .Attr("output_shapes: list(shape) >= 1")
     .Attr("compression: string = ''")
+    .Attr("reader_prefix: string = ''")
+    .Attr("writer_prefix: string = ''")
     .Attr("reader_func: func")
     .Attr("shard_func: func")
     .Attr("Treader_func_args: list(type) >= 0")

tensorflow/python/data/experimental/ops/snapshot.py

@@ -334,23 +334,27 @@ def snapshot(path, compression="AUTO", reader_func=None, shard_func=None):
   def _apply_fn(dataset):
     """Actual dataset transformation."""
+    project_func = None
     if shard_func is None:
       dataset = dataset.enumerate()
-      dataset = _SnapshotDataset(
-          input_dataset=dataset,
-          path=path,
-          compression=compression,
-          reader_func=reader_func,
-          # This will not do the right thing where the graph is built on a
-          # different machine than the executor (e.g. Cloud TPUs).
-          shard_func=lambda index, _: index % multiprocessing.cpu_count())
-      return dataset.map(lambda _, elem: elem)
+      # This sets the amount of parallelism based on the number of CPU cores on
+      # the machine where this Python code is executed, which may differ from
+      # the number of CPU cores where the input pipeline graph is actually
+      # executed (e.g. remote Cloud TPU workers).
+      local_shard_func = lambda index, _: index % multiprocessing.cpu_count()
+      project_func = lambda _, elem: elem
     else:
-      return _SnapshotDataset(
-          input_dataset=dataset,
-          path=path,
-          compression=compression,
-          reader_func=reader_func,
-          shard_func=shard_func)
+      local_shard_func = shard_func
+    dataset = _SnapshotDataset(
+        input_dataset=dataset,
+        path=path,
+        compression=compression,
+        reader_func=reader_func,
+        # This will not do the right thing where the graph is built on a
+        # different machine than the executor (e.g. Cloud TPUs).
+        shard_func=local_shard_func)
+    if project_func is not None:
+      dataset = dataset.map(project_func)
+    return dataset
 
   return _apply_fn
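
This refactor changes only how _apply_fn is structured, not the public behavior: with shard_func=None it still enumerates the dataset, shards by index % cpu_count on the machine that builds the graph, and projects the (index, element) pairs back to plain elements. A minimal usage sketch of that default path, assuming TF 2.x eager execution; the snapshot directory is arbitrary.

import tensorflow as tf

dataset = tf.data.Dataset.range(100)
dataset = dataset.apply(
    tf.data.experimental.snapshot("/tmp/range_snapshot", compression="AUTO"))
for value in dataset.take(3):
  print(value.numpy())  # 0, 1, 2 on the first (writing) run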

tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt

@@ -4110,7 +4110,7 @@ tf_module {
   }
   member_method {
     name: "SnapshotDatasetV2"
-    argspec: "args=[\'input_dataset\', \'path\', \'reader_func_other_args\', \'shard_func_other_args\', \'output_types\', \'output_shapes\', \'reader_func\', \'shard_func\', \'compression\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'None\'], "
+    argspec: "args=[\'input_dataset\', \'path\', \'reader_func_other_args\', \'shard_func_other_args\', \'output_types\', \'output_shapes\', \'reader_func\', \'shard_func\', \'compression\', \'reader_prefix\', \'writer_prefix\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'\', \'\', \'None\'], "
   }
   member_method {
     name: "SobolSample"

tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt

@@ -4110,7 +4110,7 @@ tf_module {
   }
   member_method {
     name: "SnapshotDatasetV2"
-    argspec: "args=[\'input_dataset\', \'path\', \'reader_func_other_args\', \'shard_func_other_args\', \'output_types\', \'output_shapes\', \'reader_func\', \'shard_func\', \'compression\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'None\'], "
+    argspec: "args=[\'input_dataset\', \'path\', \'reader_func_other_args\', \'shard_func_other_args\', \'output_types\', \'output_shapes\', \'reader_func\', \'shard_func\', \'compression\', \'reader_prefix\', \'writer_prefix\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'\', \'\', \'None\'], "
   }
   member_method {
     name: "SobolSample"