[tf.data] Log a warning when an incomplete file-based cache is finalized.
PiperOrigin-RevId: 339778319 Change-Id: Idaccb9cbb7c315ad36efc3cb23ef8e4086db287a
This commit is contained in:
parent
e0830e4df9
commit
97d32f49b4
@ -38,6 +38,8 @@ namespace data {
|
|||||||
/* static */ constexpr const char* const CacheDatasetOp::kOutputTypes;
|
/* static */ constexpr const char* const CacheDatasetOp::kOutputTypes;
|
||||||
/* static */ constexpr const char* const CacheDatasetOp::kOutputShapes;
|
/* static */ constexpr const char* const CacheDatasetOp::kOutputShapes;
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
constexpr char kKeyStrFormat[] = "%%%zuzu_%%%zuzu";
|
constexpr char kKeyStrFormat[] = "%%%zuzu_%%%zuzu";
|
||||||
constexpr char kPaddingSizeStrFormat[] = "%zu";
|
constexpr char kPaddingSizeStrFormat[] = "%zu";
|
||||||
constexpr char kFileDatasetPrefix[] = "File";
|
constexpr char kFileDatasetPrefix[] = "File";
|
||||||
@ -57,6 +59,14 @@ constexpr char kCacheCompleted[] = "cache_completed";
|
|||||||
constexpr char kIndex[] = "index";
|
constexpr char kIndex[] = "index";
|
||||||
constexpr char kImpl[] = "Impl";
|
constexpr char kImpl[] = "Impl";
|
||||||
constexpr char kCacheDataset[] = "CacheDataset";
|
constexpr char kCacheDataset[] = "CacheDataset";
|
||||||
|
constexpr char kIncompleteCacheErrorMessage[] =
|
||||||
|
"The calling iterator did not fully read the dataset being cached. In "
|
||||||
|
"order to avoid unexpected truncation of the dataset, the partially cached "
|
||||||
|
"contents of the dataset will be discarded. This can happen if you have "
|
||||||
|
"an input pipeline similar to `dataset.cache().take(k).repeat()`. You "
|
||||||
|
"should use `dataset.take(k).cache().repeat()` instead.";
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
class CacheDatasetOp::FileDatasetBase : public DatasetBase {
|
class CacheDatasetOp::FileDatasetBase : public DatasetBase {
|
||||||
public:
|
public:
|
||||||
@ -220,6 +230,7 @@ class CacheDatasetOp::FileDatasetBase : public DatasetBase {
|
|||||||
|
|
||||||
~FileWriterIterator() override {
|
~FileWriterIterator() override {
|
||||||
if (!dataset()->env_->FileExists(MetaFilename(filename_)).ok()) {
|
if (!dataset()->env_->FileExists(MetaFilename(filename_)).ok()) {
|
||||||
|
LOG(WARNING) << kIncompleteCacheErrorMessage;
|
||||||
std::vector<string> cache_files;
|
std::vector<string> cache_files;
|
||||||
Status s = dataset()->env_->GetMatchingPaths(
|
Status s = dataset()->env_->GetMatchingPaths(
|
||||||
strings::StrCat(filename_, "*"), &cache_files);
|
strings::StrCat(filename_, "*"), &cache_files);
|
||||||
@ -754,13 +765,7 @@ class CacheDatasetOp::MemoryDatasetBase : public DatasetBase {
|
|||||||
~MemoryWriterIterator() override {
|
~MemoryWriterIterator() override {
|
||||||
mutex_lock l(mu_);
|
mutex_lock l(mu_);
|
||||||
if (!temp_cache_.empty() && !cache_->IsCompleted()) {
|
if (!temp_cache_.empty() && !cache_->IsCompleted()) {
|
||||||
LOG(WARNING)
|
LOG(WARNING) << kIncompleteCacheErrorMessage;
|
||||||
<< "The calling iterator did not fully read the dataset being "
|
|
||||||
"cached. In order to avoid unexpected truncation of the "
|
|
||||||
"dataset, the partially cached contents of the dataset "
|
|
||||||
"will be discarded. This can happen if you have an input "
|
|
||||||
"pipeline similar to `dataset.cache().take(k).repeat()`. "
|
|
||||||
"You should use `dataset.take(k).cache().repeat()` instead.";
|
|
||||||
cache_->Reset();
|
cache_->Reset();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -39,8 +39,7 @@ from tensorflow.python.util import nest
|
|||||||
|
|
||||||
|
|
||||||
def remove_variants(get_next_op):
|
def remove_variants(get_next_op):
|
||||||
# TODO(b/72408568): Remove this once session.run can get
|
# TODO(b/72408568): Remove this once session.run can get variant tensors.
|
||||||
# variant tensors.
|
|
||||||
"""Remove variants from a nest structure, so sess.run will execute."""
|
"""Remove variants from a nest structure, so sess.run will execute."""
|
||||||
|
|
||||||
def _remove_variant(x):
|
def _remove_variant(x):
|
||||||
@ -61,7 +60,7 @@ class DatasetSerializationTestBase(test.TestCase):
|
|||||||
|
|
||||||
# TODO(b/72657739): Remove sparse_tensor argument, which is to test the
|
# TODO(b/72657739): Remove sparse_tensor argument, which is to test the
|
||||||
# (deprecated) saveable `SparseTensorSliceDataset`, once the API
|
# (deprecated) saveable `SparseTensorSliceDataset`, once the API
|
||||||
# `from_sparse_tensor_slices()`and related tests are deleted.
|
# `from_sparse_tensor_slices()` and related tests are deleted.
|
||||||
def run_core_tests(self, ds_fn, num_outputs, sparse_tensors=False):
|
def run_core_tests(self, ds_fn, num_outputs, sparse_tensors=False):
|
||||||
"""Runs the core tests.
|
"""Runs the core tests.
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user