From 8211365f9e8aed8cec7b63d7eb992ab104422f8c Mon Sep 17 00:00:00 2001 From: Frank Chen Date: Sat, 15 Jun 2019 16:53:26 -0700 Subject: [PATCH] Fix build error on Windows caused by potential int32 overflow. PiperOrigin-RevId: 253411612 --- .../kernels/data/experimental/snapshot_dataset_op.cc | 10 ++++++++++ tensorflow/python/data/experimental/ops/snapshot.py | 5 ++--- 2 files changed, 12 insertions(+), 3 deletions(-) diff --git a/tensorflow/core/kernels/data/experimental/snapshot_dataset_op.cc b/tensorflow/core/kernels/data/experimental/snapshot_dataset_op.cc index f033176b1d3..313e0a8ba71 100644 --- a/tensorflow/core/kernels/data/experimental/snapshot_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/snapshot_dataset_op.cc @@ -46,6 +46,9 @@ namespace { enum SnapshotMode { READER = 0, WRITER = 1, PASSTHROUGH = 2 }; +// Defaults to 10 GiB per shard. +const int64 kDefaultShardSizeBytes = 10L * 1024 * 1024 * 1024; + const size_t kHeaderSize = sizeof(uint64); const char kSnapshotFilename[] = "snapshot.metadata"; @@ -255,6 +258,13 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel { OP_REQUIRES_OK(ctx, ctx->GetAttr("pending_snapshot_expiry_seconds", &pending_snapshot_expiry_seconds_)); + if (shard_size_bytes_ == -1) shard_size_bytes_ = kDefaultShardSizeBytes; + + // Default to 1 day expiry for snapshots. + if (pending_snapshot_expiry_seconds_ == -1) { + pending_snapshot_expiry_seconds_ = 86400; + } + OP_REQUIRES( ctx, compression_ == io::compression::kNone || diff --git a/tensorflow/python/data/experimental/ops/snapshot.py b/tensorflow/python/data/experimental/ops/snapshot.py index 9581f737480..5a4e0cb59c1 100644 --- a/tensorflow/python/data/experimental/ops/snapshot.py +++ b/tensorflow/python/data/experimental/ops/snapshot.py @@ -45,11 +45,10 @@ class _SnapshotDataset(dataset_ops.UnaryUnchangedStructureDataset): self._writer_path_prefix = ( writer_path_prefix if writer_path_prefix is not None else "") self._shard_size_bytes = ( - shard_size_bytes - if shard_size_bytes is not None else 10 * 1024 * 1024 * 1024) + shard_size_bytes if shard_size_bytes is not None else -1) self._pending_snapshot_expiry_seconds = ( pending_snapshot_expiry_seconds - if pending_snapshot_expiry_seconds is not None else 86400) + if pending_snapshot_expiry_seconds is not None else -1) self._input_dataset = input_dataset self._path = ops.convert_to_tensor(path, dtype=dtypes.string, name="path")