Add reader_prefix and writer_prefix to SnapshotDatasetV2Op
PiperOrigin-RevId: 334901918
Change-Id: I1c66546d06ee158479fcc8cd89ab35ae5a4f9558
commit 16249195f1
parent 741f2c3a09
Changed paths:
  tensorflow/core
  tensorflow/python/data/experimental/ops
  tensorflow/tools/api/golden
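For context: this change gives `SnapshotDatasetV2` two optional string attributes, `reader_prefix` and `writer_prefix`, which are joined (via `io::JoinPath`) onto the snapshot paths on the read and write side respectively, so a snapshot can be written under one filesystem root and read back under another. Both default to the empty string, which leaves existing path behavior untouched. Below is a minimal sketch (not from the commit) of the intended composition, assuming `io::JoinPath`-style semantics: empty components are skipped, and a leading `/` on a later component does not discard the prefix, unlike `posixpath.join`.

```python
import posixpath

def join_path(prefix: str, path: str) -> str:
    # Rough model of TensorFlow's io::JoinPath for the two-argument case:
    # empty pieces are skipped, and an absolute second piece is appended
    # under the prefix rather than replacing it (posixpath.join would
    # discard the prefix instead).
    if not prefix:
        return path
    return posixpath.join(prefix, path.lstrip("/"))

hash_dir = "/tmp/snapshot/8a3f09b2c41d5e67"  # hypothetical hash directory

# The attrs default to "", which reproduces the pre-change paths exactly.
assert join_path("", hash_dir) == hash_dir

# With a prefix, reads and writes can target different roots, e.g. an
# in-memory filesystem for the writer:
assert join_path("ram://", hash_dir) == "ram://tmp/snapshot/8a3f09b2c41d5e67"
```

The `ram://` scheme here is only illustrative; any filesystem prefix TensorFlow supports should compose the same way.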
@@ -108,6 +108,7 @@ class SnapshotDatasetV2Op::Dataset : public DatasetBase {
  public:
   Dataset(OpKernelContext* ctx, const DatasetBase* input, uint64 hash,
           const std::string& path, const std::string& compression,
+          const std::string& reader_prefix, const std::string& writer_prefix,
           std::unique_ptr<CapturedFunction> reader_func,
           std::unique_ptr<CapturedFunction> shard_func);
 
@@ -138,6 +139,8 @@ class SnapshotDatasetV2Op::Dataset : public DatasetBase {
   const uint64 hash_;
   const tstring path_;
   const std::string compression_;
+  const std::string reader_prefix_;
+  const std::string writer_prefix_;
 
   std::unique_ptr<CapturedFunction> reader_func_;
   std::unique_ptr<CapturedFunction> shard_func_;
@@ -294,6 +297,7 @@ class SnapshotDatasetV2Op::Dataset::Iterator::Passthrough
 SnapshotDatasetV2Op::Dataset::Dataset(
     OpKernelContext* ctx, const DatasetBase* input, uint64 hash,
     const std::string& path, const std::string& compression,
+    const std::string& reader_prefix, const std::string& writer_prefix,
     std::unique_ptr<CapturedFunction> reader_func,
     std::unique_ptr<CapturedFunction> shard_func)
     : DatasetBase(DatasetContext(ctx)),
@@ -302,6 +306,8 @@ SnapshotDatasetV2Op::Dataset::Dataset(
       path_(path),
       compression_(compression == kCompressionAuto ? io::compression::kSnappy
                                                    : compression),
+      reader_prefix_(reader_prefix),
+      writer_prefix_(writer_prefix),
       reader_func_(std::move(reader_func)),
       shard_func_(std::move(shard_func)) {
   input_->Ref();
@@ -363,6 +369,12 @@ Status SnapshotDatasetV2Op::Dataset::AsGraphDefInternal(
   AttrValue compression_attr;
   b->BuildAttrValue(compression_, &compression_attr);
 
+  AttrValue reader_prefix_attr;
+  b->BuildAttrValue(reader_prefix_, &reader_prefix_attr);
+
+  AttrValue writer_prefix_attr;
+  b->BuildAttrValue(writer_prefix_, &writer_prefix_attr);
+
   AttrValue reader_func_attr;
   b->BuildAttrValue(reader_func_->func(), &reader_func_attr);
 
@@ -386,6 +398,8 @@ Status SnapshotDatasetV2Op::Dataset::AsGraphDefInternal(
        std::make_pair(3, shard_func_other_args)},
       /*attrs=*/
       {{kCompression, compression_attr},
+       {kReaderPrefix, reader_prefix_attr},
+       {kWriterPrefix, writer_prefix_attr},
        {kReaderFunc, reader_func_attr},
        {kShardFunc, shard_func_attr},
        {kReaderFuncTarguments, reader_func_arguments_types_attr},
@@ -401,7 +415,8 @@ SnapshotDatasetV2Op::Dataset::Iterator::Iterator(const Params& params)
 
 Status SnapshotDatasetV2Op::Dataset::Iterator::Initialize(
     IteratorContext* ctx) {
-  return ctx->env()->RecursivelyCreateDir(hash_dir_);
+  return ctx->env()->RecursivelyCreateDir(
+      io::JoinPath(dataset()->writer_prefix_, hash_dir_));
 }
 
 Status SnapshotDatasetV2Op::Dataset::Iterator::SaveInternal(
@@ -460,7 +475,8 @@ Status SnapshotDatasetV2Op::Dataset::Iterator::InitializeIterator(
     experimental::SnapshotMetadataRecord metadata;
     bool file_exists;
     TF_RETURN_IF_ERROR(snapshot_util::ReadMetadataFile(
-        ctx->env(), hash_dir_, &metadata, &file_exists));
+        ctx->env(), io::JoinPath(dataset()->reader_prefix_, hash_dir_),
+        &metadata, &file_exists));
     if (!file_exists) {
       return errors::DataLoss("Snapshot metadata file in ", hash_dir_,
                               " does not exist any more.");
@@ -476,7 +492,8 @@ Status SnapshotDatasetV2Op::Dataset::Iterator::InitializeIterator(
       experimental::SnapshotMetadataRecord metadata;
       bool file_exists;
       TF_RETURN_IF_ERROR(snapshot_util::ReadMetadataFile(
-          ctx->env(), hash_dir_, &metadata, &file_exists));
+          ctx->env(), io::JoinPath(dataset()->reader_prefix_, hash_dir_),
+          &metadata, &file_exists));
 
       // `pending_snapshot_expiry_seconds` is a legacy option where we would not
       // write snapshots that we think were still on-going. We decided that this
@@ -522,8 +539,9 @@ Status SnapshotDatasetV2Op::Dataset::Iterator::Reader::Initialize(
   TF_RETURN_IF_ERROR(
       dataset()->reader_func_->Instantiate(ctx, &instantiated_reader_func_));
 
-  auto hash_dir =
-      snapshot_util::HashDirectory(dataset()->path_, dataset()->hash_);
+  auto hash_dir = snapshot_util::HashDirectory(
+      io::JoinPath(dataset()->reader_prefix_, dataset()->path_),
+      dataset()->hash_);
   bool metadata_file_exists;
   experimental::SnapshotMetadataRecord metadata;
   TF_RETURN_IF_ERROR(snapshot_util::ReadMetadataFile(
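Note the asymmetry between the two sides: `Reader::Initialize` above joins the prefix onto the user path before computing the hash directory, while `WriteMetadataFile` below joins the prefix onto the already-computed hash directory. Assuming `snapshot_util::HashDirectory(path, hash)` simply nests a hash-derived component under `path` (an assumption; its definition is not part of this diff), the two compositions name the same location:

```python
import posixpath

def join_path(prefix, path):
    # Same io::JoinPath approximation as in the sketch above.
    return path if not prefix else posixpath.join(prefix, path.lstrip("/"))

def hash_directory(path, hash_hex):
    # Assumption: snapshot_util::HashDirectory nests a hash-derived
    # component under `path`; its real definition is not in this diff.
    return posixpath.join(path, hash_hex)

prefix, path, h = "ram://", "/tmp/snapshot", "8a3f09b2c41d5e67"

reader_style = hash_directory(join_path(prefix, path), h)  # prefix, then hash
writer_style = join_path(prefix, hash_directory(path, h))  # hash, then prefix
assert reader_style == writer_style == "ram://tmp/snapshot/8a3f09b2c41d5e67"
```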
@@ -624,8 +642,9 @@ Status SnapshotDatasetV2Op::Dataset::Iterator::Writer::WriteMetadataFile(
     metadata.add_dtype(output_dtype);
   }
   metadata.set_finalized(finalized);
-  tstring hash_directory =
-      snapshot_util::HashDirectory(dataset()->path_, dataset()->hash_);
+  tstring hash_directory = io::JoinPath(
+      dataset()->writer_prefix_,
+      snapshot_util::HashDirectory(dataset()->path_, dataset()->hash_));
 
   return snapshot_util::WriteMetadataFile(env, hash_directory, &metadata);
 }
@@ -678,7 +697,9 @@ Status SnapshotDatasetV2Op::Dataset::Iterator::Writer::GetNextInternal(
 
     // Creates the run directory.
     run_dir_ = snapshot_util::RunDirectory(
-        snapshot_util::HashDirectory(dataset()->path_, dataset()->hash_),
+        snapshot_util::HashDirectory(
+            io::JoinPath(dataset()->writer_prefix_, dataset()->path_),
+            dataset()->hash_),
         run_id_);
     TF_RETURN_IF_ERROR(ctx->env()->RecursivelyCreateDir(run_dir_));
     TF_RETURN_IF_ERROR(WriteMetadataFile(ctx->env(), /*finalized=*/false));
@@ -757,7 +778,9 @@ Status SnapshotDatasetV2Op::Dataset::Iterator::Writer::RestoreInternal(
 
     run_id_ = static_cast<uint64>(run_id_signed);
     run_dir_ = snapshot_util::RunDirectory(
-        snapshot_util::HashDirectory(dataset()->path_, dataset()->hash_),
+        snapshot_util::HashDirectory(
+            io::JoinPath(dataset()->writer_prefix_, dataset()->path_),
+            dataset()->hash_),
         run_id_);
     current_checkpoint_id_ = static_cast<uint64>(current_checkpoint_id);
 
@@ -798,6 +821,14 @@ SnapshotDatasetV2Op::SnapshotDatasetV2Op(OpKernelConstruction* ctx)
   OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputShapes, &output_shapes_));
   OP_REQUIRES_OK(ctx, ctx->GetAttr(kCompression, &compression_));
 
+  if (ctx->HasAttr(kReaderPrefix)) {
+    OP_REQUIRES_OK(ctx, ctx->GetAttr(kReaderPrefix, &reader_prefix_));
+  }
+
+  if (ctx->HasAttr(kWriterPrefix)) {
+    OP_REQUIRES_OK(ctx, ctx->GetAttr(kWriterPrefix, &writer_prefix_));
+  }
+
   OP_REQUIRES_OK(ctx, FunctionMetadata::Create(ctx, kReaderFunc, reader_params,
                                                &reader_func_metadata_));
   OP_REQUIRES_OK(ctx, FunctionMetadata::Create(ctx, kShardFunc, shard_params,
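The `HasAttr` guards make the new attributes optional at kernel-construction time, so a kernel built from this revision can still run a `GraphDef` whose `SnapshotDatasetV2` node predates them; `reader_prefix_` and `writer_prefix_` then keep their default-constructed empty values (and the op registration further down also gives both attrs an empty-string default). A toy model of that fallback, with invented names:

```python
def get_attr_or_default(node_attrs, name, default=""):
    # Models OP_REQUIRES_OK(ctx, ctx->GetAttr(...)) guarded by HasAttr:
    # nodes serialized before the attr existed simply omit it, and the
    # kernel falls back to the empty-string member default.
    return node_attrs[name] if name in node_attrs else default

old_node = {"compression": "GZIP"}  # serialized before this change
new_node = {"compression": "GZIP", "reader_prefix": "ram://"}

assert get_attr_or_default(old_node, "reader_prefix") == ""
assert get_attr_or_default(new_node, "reader_prefix") == "ram://"
```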
@@ -834,8 +865,8 @@ void SnapshotDatasetV2Op::MakeDataset(OpKernelContext* ctx, DatasetBase* input,
                                           kShardFuncOtherArgs, &shard_func));
 
   *output = new SnapshotDatasetV2Op::Dataset(
-      ctx, input, graph_hash, path, compression_, std::move(reader_func),
-      std::move(shard_func));
+      ctx, input, graph_hash, path, compression_, reader_prefix_,
+      writer_prefix_, std::move(reader_func), std::move(shard_func));
 }
 
 namespace {
@@ -44,6 +44,8 @@ class SnapshotDatasetV2Op : public UnaryDatasetOpKernel {
   static constexpr const char* const kOutputTypes = "output_types";
   static constexpr const char* const kOutputShapes = "output_shapes";
   static constexpr const char* const kCompression = "compression";
+  static constexpr const char* const kReaderPrefix = "reader_prefix";
+  static constexpr const char* const kWriterPrefix = "writer_prefix";
   static constexpr const char* const kCompressionAuto = "AUTO";
   static constexpr const char* const kReaderFunc = "reader_func";
   static constexpr const char* const kShardFunc = "shard_func";
@@ -71,6 +73,8 @@ class SnapshotDatasetV2Op : public UnaryDatasetOpKernel {
   std::vector<PartialTensorShape> output_shapes_;
 
   std::string compression_;
+  std::string reader_prefix_;
+  std::string writer_prefix_;
 
   std::shared_ptr<FunctionMetadata> reader_func_metadata_;
   std::shared_ptr<FunctionMetadata> shard_func_metadata_;
@@ -966,6 +966,8 @@ REGISTER_OP("SnapshotDatasetV2")
     .Attr("output_types: list(type) >= 1")
     .Attr("output_shapes: list(shape) >= 1")
     .Attr("compression: string = ''")
+    .Attr("reader_prefix: string = ''")
+    .Attr("writer_prefix: string = ''")
     .Attr("reader_func: func")
     .Attr("shard_func: func")
     .Attr("Treader_func_args: list(type) >= 0")
@@ -334,23 +334,27 @@ def snapshot(path, compression="AUTO", reader_func=None, shard_func=None):
 
   def _apply_fn(dataset):
     """Actual dataset transformation."""
+    project_func = None
     if shard_func is None:
       dataset = dataset.enumerate()
-      dataset = _SnapshotDataset(
-          input_dataset=dataset,
-          path=path,
-          compression=compression,
-          reader_func=reader_func,
-          # This will not do the right thing where the graph is built on a
-          # different machine than the executor (e.g. Cloud TPUs).
-          shard_func=lambda index, _: index % multiprocessing.cpu_count())
-      return dataset.map(lambda _, elem: elem)
+      # This sets the amount of parallelism based on the number of CPU cores on
+      # the machine where this Python code is executed, which may differ from
+      # the number of CPU cores where the input pipeline graph is actually
+      # executed (e.g. remote Cloud TPU workers).
+      local_shard_func = lambda index, _: index % multiprocessing.cpu_count()
+      project_func = lambda _, elem: elem
     else:
-      return _SnapshotDataset(
-          input_dataset=dataset,
-          path=path,
-          compression=compression,
-          reader_func=reader_func,
-          shard_func=shard_func)
+      local_shard_func = shard_func
+    dataset = _SnapshotDataset(
+        input_dataset=dataset,
+        path=path,
+        compression=compression,
+        reader_func=reader_func,
+        # This will not do the right thing where the graph is built on a
+        # different machine than the executor (e.g. Cloud TPUs).
+        shard_func=local_shard_func)
+    if project_func is not None:
+      dataset = dataset.map(project_func)
+    return dataset
 
   return _apply_fn
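For reference, a short usage sketch of the transformation this Python file defines, assuming it is exported as `tf.data.experimental.snapshot` as in TF 2.x; with no `shard_func`, the refactored default path above enumerates elements, shards them by `index % cpu_count()`, and projects the index away again. The two identical-looking hunks that follow record the new raw-op argspec in the v1 and v2 API golden files.

```python
import tensorflow as tf

# "/tmp/snapshot_dir" is illustrative. The first epoch writes a snapshot
# under this path while passing elements through; later epochs read the
# snapshot back instead of re-running the upstream pipeline.
dataset = tf.data.Dataset.range(100)
dataset = dataset.apply(
    tf.data.experimental.snapshot("/tmp/snapshot_dir", compression="AUTO"))

for elem in dataset.take(3):
    print(elem.numpy())
```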
@@ -4110,7 +4110,7 @@ tf_module {
   }
   member_method {
     name: "SnapshotDatasetV2"
-    argspec: "args=[\'input_dataset\', \'path\', \'reader_func_other_args\', \'shard_func_other_args\', \'output_types\', \'output_shapes\', \'reader_func\', \'shard_func\', \'compression\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'None\'], "
+    argspec: "args=[\'input_dataset\', \'path\', \'reader_func_other_args\', \'shard_func_other_args\', \'output_types\', \'output_shapes\', \'reader_func\', \'shard_func\', \'compression\', \'reader_prefix\', \'writer_prefix\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'\', \'\', \'None\'], "
   }
   member_method {
     name: "SobolSample"
@@ -4110,7 +4110,7 @@ tf_module {
   }
   member_method {
     name: "SnapshotDatasetV2"
-    argspec: "args=[\'input_dataset\', \'path\', \'reader_func_other_args\', \'shard_func_other_args\', \'output_types\', \'output_shapes\', \'reader_func\', \'shard_func\', \'compression\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'None\'], "
+    argspec: "args=[\'input_dataset\', \'path\', \'reader_func_other_args\', \'shard_func_other_args\', \'output_types\', \'output_shapes\', \'reader_func\', \'shard_func\', \'compression\', \'reader_prefix\', \'writer_prefix\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'\', \'\', \'None\'], "
   }
   member_method {
     name: "SobolSample"