[tf.data] Add TF 2.0 support for concurrent iterators over an in-memory cache.
PiperOrigin-RevId: 273995286
This commit is contained in:
parent
8e6f9ce4f3
commit
08f41c6216
@ -224,7 +224,7 @@ class CacheDatasetOp::FileDataset : public DatasetBase {
|
||||
lockfile_created_(false),
|
||||
iteration_completed_(false) {}
|
||||
|
||||
~FileWriterIterator() {
|
||||
~FileWriterIterator() override {
|
||||
if (!dataset()->env_->FileExists(MetaFilename(filename_)).ok()) {
|
||||
std::vector<string> cache_files;
|
||||
Status s = dataset()->env_->GetMatchingPaths(
|
||||
@ -630,6 +630,57 @@ class CacheDatasetOp::FileDatasetV2 : public CacheDatasetOp::FileDataset {
|
||||
const Tensor resource_handle_;
|
||||
};
|
||||
|
||||
namespace {
|
||||
template <typename T, typename FullNameFn>
|
||||
Status SaveCache(IteratorStateWriter* writer, T* cache, FullNameFn full_name) {
|
||||
size_t cache_size = cache->size();
|
||||
TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kCacheSize), cache_size));
|
||||
for (size_t i = 0; i < cache_size; i++) {
|
||||
auto& element = cache->at(i);
|
||||
TF_RETURN_IF_ERROR(writer->WriteScalar(
|
||||
full_name(strings::StrCat(kCache, "[", i, "]", kSizeSuffix)),
|
||||
element.size()));
|
||||
for (size_t j = 0; j < element.size(); ++j) {
|
||||
TF_RETURN_IF_ERROR(writer->WriteTensor(
|
||||
full_name(strings::StrCat(kCache, "[", i, "][", j, "]")),
|
||||
element[j]));
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
template <typename T, typename FullNameFn>
|
||||
Status RestoreCache(IteratorContext* ctx, IteratorStateReader* reader, T* cache,
|
||||
FullNameFn full_name) {
|
||||
size_t cache_size;
|
||||
{
|
||||
int64 temp;
|
||||
TF_RETURN_IF_ERROR(reader->ReadScalar(full_name(kCacheSize), &temp));
|
||||
cache_size = static_cast<size_t>(temp);
|
||||
}
|
||||
for (size_t i = 0; i < cache_size; ++i) {
|
||||
std::vector<Tensor> element;
|
||||
size_t element_size;
|
||||
{
|
||||
int64 temp;
|
||||
TF_RETURN_IF_ERROR(reader->ReadScalar(
|
||||
full_name(strings::StrCat(kCache, "[", i, "]", kSizeSuffix)), &temp));
|
||||
element_size = static_cast<size_t>(temp);
|
||||
}
|
||||
element.reserve(element_size);
|
||||
for (size_t j = 0; j < element_size; ++j) {
|
||||
element.emplace_back();
|
||||
TF_RETURN_IF_ERROR(reader->ReadTensor(
|
||||
full_name(strings::StrCat(kCache, "[", i, "][", j, "]")),
|
||||
&element.back()));
|
||||
}
|
||||
cache->emplace_back(std::move(element));
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
class CacheDatasetOp::MemoryDataset : public DatasetBase {
|
||||
public:
|
||||
explicit MemoryDataset(OpKernelContext* ctx, const DatasetBase* input,
|
||||
@ -714,12 +765,7 @@ class CacheDatasetOp::MemoryDataset : public DatasetBase {
|
||||
return Status::OK();
|
||||
}));
|
||||
}
|
||||
mode_ = cache_->MaybeClaim() ? Mode::write : Mode::read;
|
||||
InitializeIterator();
|
||||
if (mode_ == Mode::read && !cache_->IsCompleted()) {
|
||||
return errors::Internal(
|
||||
"Cache should only be read after it has been completed.");
|
||||
}
|
||||
return iterator_->Initialize(ctx);
|
||||
}
|
||||
|
||||
@ -739,27 +785,10 @@ class CacheDatasetOp::MemoryDataset : public DatasetBase {
|
||||
|
||||
Status SaveInternal(IteratorStateWriter* writer) override {
|
||||
mutex_lock l(mu_);
|
||||
TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kMode), mode_));
|
||||
if (cache_->IsClaimed()) {
|
||||
TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kCacheClaimed), ""));
|
||||
size_t cache_size = cache_->size();
|
||||
TF_RETURN_IF_ERROR(
|
||||
writer->WriteScalar(full_name(kCacheSize), cache_size));
|
||||
for (size_t i = 0; i < cache_size; i++) {
|
||||
auto& element = cache_->at(i);
|
||||
TF_RETURN_IF_ERROR(writer->WriteScalar(
|
||||
full_name(strings::StrCat(kCache, "[", i, "]", kSizeSuffix)),
|
||||
element.size()));
|
||||
for (size_t j = 0; j < element.size(); ++j) {
|
||||
TF_RETURN_IF_ERROR(writer->WriteTensor(
|
||||
full_name(strings::StrCat(kCache, "[", i, "][", j, "]")),
|
||||
element[j]));
|
||||
}
|
||||
}
|
||||
if (cache_->IsCompleted()) {
|
||||
TF_RETURN_IF_ERROR(
|
||||
writer->WriteScalar(full_name(kCacheCompleted), ""));
|
||||
}
|
||||
if (cache_->IsCompleted()) {
|
||||
TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kCacheCompleted), ""));
|
||||
TF_RETURN_IF_ERROR(SaveCache(
|
||||
writer, cache_, [this](const string& s) { return full_name(s); }));
|
||||
}
|
||||
return SaveInput(writer, iterator_);
|
||||
}
|
||||
@ -769,41 +798,12 @@ class CacheDatasetOp::MemoryDataset : public DatasetBase {
|
||||
mutex_lock l(mu_);
|
||||
iterator_.reset();
|
||||
cache_->Reset();
|
||||
{
|
||||
int64 temp;
|
||||
TF_RETURN_IF_ERROR(reader->ReadScalar(full_name(kMode), &temp));
|
||||
mode_ = static_cast<Mode>(temp);
|
||||
}
|
||||
if (reader->Contains(full_name(kCacheClaimed))) {
|
||||
CHECK(cache_->MaybeClaim());
|
||||
size_t cache_size;
|
||||
{
|
||||
int64 temp;
|
||||
TF_RETURN_IF_ERROR(reader->ReadScalar(full_name(kCacheSize), &temp));
|
||||
cache_size = static_cast<size_t>(temp);
|
||||
}
|
||||
for (size_t i = 0; i < cache_size; ++i) {
|
||||
std::vector<Tensor> element;
|
||||
size_t element_size;
|
||||
{
|
||||
int64 temp;
|
||||
TF_RETURN_IF_ERROR(reader->ReadScalar(
|
||||
full_name(strings::StrCat(kCache, "[", i, "]", kSizeSuffix)),
|
||||
&temp));
|
||||
element_size = static_cast<size_t>(temp);
|
||||
}
|
||||
element.reserve(element_size);
|
||||
for (size_t j = 0; j < element_size; ++j) {
|
||||
element.emplace_back();
|
||||
TF_RETURN_IF_ERROR(reader->ReadTensor(
|
||||
full_name(strings::StrCat(kCache, "[", i, "][", j, "]")),
|
||||
&element.back()));
|
||||
}
|
||||
cache_->emplace_back(std::move(element));
|
||||
}
|
||||
if (reader->Contains(full_name(kCacheCompleted))) {
|
||||
cache_->Complete();
|
||||
}
|
||||
if (reader->Contains(full_name(kCacheCompleted))) {
|
||||
std::vector<std::vector<Tensor>> temp_cache;
|
||||
TF_RETURN_IF_ERROR(
|
||||
RestoreCache(ctx, reader, &temp_cache,
|
||||
[this](const string& s) { return full_name(s); }));
|
||||
cache_->Complete(std::move(temp_cache));
|
||||
}
|
||||
InitializeIterator();
|
||||
TF_RETURN_IF_ERROR(iterator_->Initialize(ctx));
|
||||
@ -814,13 +814,11 @@ class CacheDatasetOp::MemoryDataset : public DatasetBase {
|
||||
class MemoryWriterIterator : public DatasetIterator<MemoryDataset> {
|
||||
public:
|
||||
explicit MemoryWriterIterator(const Params& params, MemoryCache* cache)
|
||||
: DatasetIterator<MemoryDataset>(params), cache_(cache) {
|
||||
CHECK(cache_);
|
||||
}
|
||||
: DatasetIterator<MemoryDataset>(params), cache_(cache) {}
|
||||
|
||||
~MemoryWriterIterator() override {
|
||||
mutex_lock l(mu_);
|
||||
if (cache_->size() > 0 && !cache_->IsCompleted()) {
|
||||
if (!temp_cache_.empty() && !cache_->IsCompleted()) {
|
||||
LOG(WARNING)
|
||||
<< "The calling iterator did not fully read the dataset being "
|
||||
"cached. In order to avoid unexpected truncation of the "
|
||||
@ -843,11 +841,11 @@ class CacheDatasetOp::MemoryDataset : public DatasetBase {
|
||||
TF_RETURN_IF_ERROR(
|
||||
input_impl_->GetNext(ctx, out_tensors, end_of_sequence));
|
||||
if (*end_of_sequence) {
|
||||
cache_->Complete();
|
||||
cache_->Complete(std::move(temp_cache_));
|
||||
return Status::OK();
|
||||
}
|
||||
RecordBufferEnqueue(ctx, *out_tensors);
|
||||
cache_->emplace_back(*out_tensors);
|
||||
temp_cache_.emplace_back(*out_tensors);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
@ -860,12 +858,22 @@ class CacheDatasetOp::MemoryDataset : public DatasetBase {
|
||||
|
||||
Status SaveInternal(IteratorStateWriter* writer) override {
|
||||
mutex_lock l(mu_);
|
||||
if (!cache_->IsCompleted()) {
|
||||
TF_RETURN_IF_ERROR(
|
||||
SaveCache(writer, &temp_cache_,
|
||||
[this](const string& s) { return full_name(s); }));
|
||||
}
|
||||
return SaveInput(writer, input_impl_);
|
||||
}
|
||||
|
||||
Status RestoreInternal(IteratorContext* ctx,
|
||||
IteratorStateReader* reader) override {
|
||||
mutex_lock l(mu_);
|
||||
if (!reader->Contains(full_name(kCacheCompleted))) {
|
||||
TF_RETURN_IF_ERROR(
|
||||
RestoreCache(ctx, reader, &temp_cache_,
|
||||
[this](const string& s) { return full_name(s); }));
|
||||
}
|
||||
return RestoreInput(ctx, reader, input_impl_);
|
||||
}
|
||||
|
||||
@ -873,7 +881,8 @@ class CacheDatasetOp::MemoryDataset : public DatasetBase {
|
||||
mutex mu_;
|
||||
std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_);
|
||||
MemoryCache* const cache_ GUARDED_BY(mu_); // not owned.
|
||||
}; // MemoryWriterIterator
|
||||
std::vector<std::vector<Tensor>> temp_cache_ GUARDED_BY(mu_);
|
||||
}; // MemoryWriterIterator
|
||||
|
||||
class MemoryReaderIterator : public DatasetIterator<MemoryDataset> {
|
||||
public:
|
||||
@ -943,25 +952,21 @@ class CacheDatasetOp::MemoryDataset : public DatasetBase {
|
||||
}; // MemoryReaderIterator
|
||||
|
||||
void InitializeIterator() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
|
||||
switch (mode_) {
|
||||
case Mode::read:
|
||||
iterator_ = absl::make_unique<MemoryReaderIterator>(
|
||||
MemoryReaderIterator::Params{dataset(),
|
||||
strings::StrCat(prefix(), kImpl)},
|
||||
cache_);
|
||||
break;
|
||||
case Mode::write:
|
||||
iterator_ = absl::make_unique<MemoryWriterIterator>(
|
||||
MemoryWriterIterator::Params{dataset(),
|
||||
strings::StrCat(prefix(), kImpl)},
|
||||
cache_);
|
||||
if (cache_->IsCompleted()) {
|
||||
iterator_ = absl::make_unique<MemoryReaderIterator>(
|
||||
MemoryReaderIterator::Params{dataset(),
|
||||
strings::StrCat(prefix(), kImpl)},
|
||||
cache_);
|
||||
} else {
|
||||
iterator_ = absl::make_unique<MemoryWriterIterator>(
|
||||
MemoryWriterIterator::Params{dataset(),
|
||||
strings::StrCat(prefix(), kImpl)},
|
||||
cache_);
|
||||
}
|
||||
}
|
||||
|
||||
mutex mu_;
|
||||
MemoryCache* cache_ GUARDED_BY(mu_); // not owned.
|
||||
enum Mode { read, write };
|
||||
Mode mode_ GUARDED_BY(mu_);
|
||||
std::unique_ptr<IteratorBase> iterator_ GUARDED_BY(mu_);
|
||||
}; // MemoryIterator
|
||||
|
||||
|
@ -32,14 +32,12 @@ const char kMemoryCache[] = "MemoryCache";
|
||||
|
||||
string MemoryCache::DebugString() const { return kMemoryCache; }
|
||||
|
||||
void MemoryCache::Complete() {
|
||||
void MemoryCache::Complete(std::vector<std::vector<Tensor>>&& cache) {
|
||||
mutex_lock l(mu_);
|
||||
completed_ = true;
|
||||
}
|
||||
|
||||
bool MemoryCache::IsClaimed() {
|
||||
tf_shared_lock l(mu_);
|
||||
return claimed_;
|
||||
if (!completed_) {
|
||||
cache_ = std::move(cache);
|
||||
completed_ = true;
|
||||
}
|
||||
}
|
||||
|
||||
bool MemoryCache::IsCompleted() {
|
||||
@ -47,18 +45,8 @@ bool MemoryCache::IsCompleted() {
|
||||
return completed_;
|
||||
}
|
||||
|
||||
bool MemoryCache::MaybeClaim() {
|
||||
mutex_lock l(mu_);
|
||||
if (!claimed_) {
|
||||
claimed_ = true;
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
void MemoryCache::Reset() {
|
||||
mutex_lock l(mu_);
|
||||
claimed_ = false;
|
||||
completed_ = false;
|
||||
cache_.clear();
|
||||
}
|
||||
@ -69,11 +57,6 @@ const std::vector<Tensor>& MemoryCache::at(int64 index) {
|
||||
return cache_[index];
|
||||
}
|
||||
|
||||
void MemoryCache::emplace_back(std::vector<Tensor> element) {
|
||||
mutex_lock l(mu_);
|
||||
cache_.emplace_back(std::move(element));
|
||||
}
|
||||
|
||||
size_t MemoryCache::size() {
|
||||
tf_shared_lock l(mu_);
|
||||
return cache_.size();
|
||||
|
@ -34,33 +34,22 @@ class MemoryCache : public ResourceBase {
|
||||
string DebugString() const override;
|
||||
|
||||
// Marks the cache as completed.
|
||||
void Complete();
|
||||
|
||||
// Returns whether the cache is claimed.
|
||||
bool IsClaimed();
|
||||
void Complete(std::vector<std::vector<Tensor>>&& cache);
|
||||
|
||||
// Returns whether the cache is completed.
|
||||
bool IsCompleted();
|
||||
|
||||
// Attempts to claim the cache, returning whether the cache was claimed.
|
||||
bool MaybeClaim();
|
||||
|
||||
// Resets the cache.
|
||||
void Reset();
|
||||
|
||||
// Returns the element at the given index.
|
||||
const std::vector<Tensor>& at(int64 index);
|
||||
|
||||
// Adds the element to the cache.
|
||||
void emplace_back(std::vector<Tensor> element);
|
||||
|
||||
// Returns the size of the cache.
|
||||
size_t size();
|
||||
|
||||
private:
|
||||
mutex mu_;
|
||||
// Determines whether a writer has claimed the cache.
|
||||
bool claimed_ GUARDED_BY(mu_) = false;
|
||||
// Determines whether all elements of the dataset have been cached.
|
||||
bool completed_ GUARDED_BY(mu_) = false;
|
||||
std::vector<std::vector<Tensor>> cache_ GUARDED_BY(mu_);
|
||||
|
@ -352,6 +352,18 @@ class MemoryCacheTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
|
||||
self.assertAllEqual(results, range(10))
|
||||
|
||||
@combinations.generate(combinations.combine(tf_api_version=2, mode="eager"))
|
||||
def testCacheV2ConcurrentIterators(self):
|
||||
|
||||
dataset = dataset_ops.Dataset.range(10).cache()
|
||||
|
||||
it1 = iter(dataset)
|
||||
it2 = iter(dataset)
|
||||
|
||||
for i in range(10):
|
||||
self.assertEqual(next(it1), i)
|
||||
self.assertEqual(next(it2), i)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
||||
|
Loading…
x
Reference in New Issue
Block a user