[tf.data] Add TF 2.0 support for concurrent iterators over an in-memory cache.
PiperOrigin-RevId: 273995286
This commit is contained in:
parent
8e6f9ce4f3
commit
08f41c6216
@ -224,7 +224,7 @@ class CacheDatasetOp::FileDataset : public DatasetBase {
|
|||||||
lockfile_created_(false),
|
lockfile_created_(false),
|
||||||
iteration_completed_(false) {}
|
iteration_completed_(false) {}
|
||||||
|
|
||||||
~FileWriterIterator() {
|
~FileWriterIterator() override {
|
||||||
if (!dataset()->env_->FileExists(MetaFilename(filename_)).ok()) {
|
if (!dataset()->env_->FileExists(MetaFilename(filename_)).ok()) {
|
||||||
std::vector<string> cache_files;
|
std::vector<string> cache_files;
|
||||||
Status s = dataset()->env_->GetMatchingPaths(
|
Status s = dataset()->env_->GetMatchingPaths(
|
||||||
@ -630,6 +630,57 @@ class CacheDatasetOp::FileDatasetV2 : public CacheDatasetOp::FileDataset {
|
|||||||
const Tensor resource_handle_;
|
const Tensor resource_handle_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
template <typename T, typename FullNameFn>
|
||||||
|
Status SaveCache(IteratorStateWriter* writer, T* cache, FullNameFn full_name) {
|
||||||
|
size_t cache_size = cache->size();
|
||||||
|
TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kCacheSize), cache_size));
|
||||||
|
for (size_t i = 0; i < cache_size; i++) {
|
||||||
|
auto& element = cache->at(i);
|
||||||
|
TF_RETURN_IF_ERROR(writer->WriteScalar(
|
||||||
|
full_name(strings::StrCat(kCache, "[", i, "]", kSizeSuffix)),
|
||||||
|
element.size()));
|
||||||
|
for (size_t j = 0; j < element.size(); ++j) {
|
||||||
|
TF_RETURN_IF_ERROR(writer->WriteTensor(
|
||||||
|
full_name(strings::StrCat(kCache, "[", i, "][", j, "]")),
|
||||||
|
element[j]));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T, typename FullNameFn>
|
||||||
|
Status RestoreCache(IteratorContext* ctx, IteratorStateReader* reader, T* cache,
|
||||||
|
FullNameFn full_name) {
|
||||||
|
size_t cache_size;
|
||||||
|
{
|
||||||
|
int64 temp;
|
||||||
|
TF_RETURN_IF_ERROR(reader->ReadScalar(full_name(kCacheSize), &temp));
|
||||||
|
cache_size = static_cast<size_t>(temp);
|
||||||
|
}
|
||||||
|
for (size_t i = 0; i < cache_size; ++i) {
|
||||||
|
std::vector<Tensor> element;
|
||||||
|
size_t element_size;
|
||||||
|
{
|
||||||
|
int64 temp;
|
||||||
|
TF_RETURN_IF_ERROR(reader->ReadScalar(
|
||||||
|
full_name(strings::StrCat(kCache, "[", i, "]", kSizeSuffix)), &temp));
|
||||||
|
element_size = static_cast<size_t>(temp);
|
||||||
|
}
|
||||||
|
element.reserve(element_size);
|
||||||
|
for (size_t j = 0; j < element_size; ++j) {
|
||||||
|
element.emplace_back();
|
||||||
|
TF_RETURN_IF_ERROR(reader->ReadTensor(
|
||||||
|
full_name(strings::StrCat(kCache, "[", i, "][", j, "]")),
|
||||||
|
&element.back()));
|
||||||
|
}
|
||||||
|
cache->emplace_back(std::move(element));
|
||||||
|
}
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
class CacheDatasetOp::MemoryDataset : public DatasetBase {
|
class CacheDatasetOp::MemoryDataset : public DatasetBase {
|
||||||
public:
|
public:
|
||||||
explicit MemoryDataset(OpKernelContext* ctx, const DatasetBase* input,
|
explicit MemoryDataset(OpKernelContext* ctx, const DatasetBase* input,
|
||||||
@ -714,12 +765,7 @@ class CacheDatasetOp::MemoryDataset : public DatasetBase {
|
|||||||
return Status::OK();
|
return Status::OK();
|
||||||
}));
|
}));
|
||||||
}
|
}
|
||||||
mode_ = cache_->MaybeClaim() ? Mode::write : Mode::read;
|
|
||||||
InitializeIterator();
|
InitializeIterator();
|
||||||
if (mode_ == Mode::read && !cache_->IsCompleted()) {
|
|
||||||
return errors::Internal(
|
|
||||||
"Cache should only be read after it has been completed.");
|
|
||||||
}
|
|
||||||
return iterator_->Initialize(ctx);
|
return iterator_->Initialize(ctx);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -739,27 +785,10 @@ class CacheDatasetOp::MemoryDataset : public DatasetBase {
|
|||||||
|
|
||||||
Status SaveInternal(IteratorStateWriter* writer) override {
|
Status SaveInternal(IteratorStateWriter* writer) override {
|
||||||
mutex_lock l(mu_);
|
mutex_lock l(mu_);
|
||||||
TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kMode), mode_));
|
if (cache_->IsCompleted()) {
|
||||||
if (cache_->IsClaimed()) {
|
TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kCacheCompleted), ""));
|
||||||
TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kCacheClaimed), ""));
|
TF_RETURN_IF_ERROR(SaveCache(
|
||||||
size_t cache_size = cache_->size();
|
writer, cache_, [this](const string& s) { return full_name(s); }));
|
||||||
TF_RETURN_IF_ERROR(
|
|
||||||
writer->WriteScalar(full_name(kCacheSize), cache_size));
|
|
||||||
for (size_t i = 0; i < cache_size; i++) {
|
|
||||||
auto& element = cache_->at(i);
|
|
||||||
TF_RETURN_IF_ERROR(writer->WriteScalar(
|
|
||||||
full_name(strings::StrCat(kCache, "[", i, "]", kSizeSuffix)),
|
|
||||||
element.size()));
|
|
||||||
for (size_t j = 0; j < element.size(); ++j) {
|
|
||||||
TF_RETURN_IF_ERROR(writer->WriteTensor(
|
|
||||||
full_name(strings::StrCat(kCache, "[", i, "][", j, "]")),
|
|
||||||
element[j]));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (cache_->IsCompleted()) {
|
|
||||||
TF_RETURN_IF_ERROR(
|
|
||||||
writer->WriteScalar(full_name(kCacheCompleted), ""));
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
return SaveInput(writer, iterator_);
|
return SaveInput(writer, iterator_);
|
||||||
}
|
}
|
||||||
@ -769,41 +798,12 @@ class CacheDatasetOp::MemoryDataset : public DatasetBase {
|
|||||||
mutex_lock l(mu_);
|
mutex_lock l(mu_);
|
||||||
iterator_.reset();
|
iterator_.reset();
|
||||||
cache_->Reset();
|
cache_->Reset();
|
||||||
{
|
if (reader->Contains(full_name(kCacheCompleted))) {
|
||||||
int64 temp;
|
std::vector<std::vector<Tensor>> temp_cache;
|
||||||
TF_RETURN_IF_ERROR(reader->ReadScalar(full_name(kMode), &temp));
|
TF_RETURN_IF_ERROR(
|
||||||
mode_ = static_cast<Mode>(temp);
|
RestoreCache(ctx, reader, &temp_cache,
|
||||||
}
|
[this](const string& s) { return full_name(s); }));
|
||||||
if (reader->Contains(full_name(kCacheClaimed))) {
|
cache_->Complete(std::move(temp_cache));
|
||||||
CHECK(cache_->MaybeClaim());
|
|
||||||
size_t cache_size;
|
|
||||||
{
|
|
||||||
int64 temp;
|
|
||||||
TF_RETURN_IF_ERROR(reader->ReadScalar(full_name(kCacheSize), &temp));
|
|
||||||
cache_size = static_cast<size_t>(temp);
|
|
||||||
}
|
|
||||||
for (size_t i = 0; i < cache_size; ++i) {
|
|
||||||
std::vector<Tensor> element;
|
|
||||||
size_t element_size;
|
|
||||||
{
|
|
||||||
int64 temp;
|
|
||||||
TF_RETURN_IF_ERROR(reader->ReadScalar(
|
|
||||||
full_name(strings::StrCat(kCache, "[", i, "]", kSizeSuffix)),
|
|
||||||
&temp));
|
|
||||||
element_size = static_cast<size_t>(temp);
|
|
||||||
}
|
|
||||||
element.reserve(element_size);
|
|
||||||
for (size_t j = 0; j < element_size; ++j) {
|
|
||||||
element.emplace_back();
|
|
||||||
TF_RETURN_IF_ERROR(reader->ReadTensor(
|
|
||||||
full_name(strings::StrCat(kCache, "[", i, "][", j, "]")),
|
|
||||||
&element.back()));
|
|
||||||
}
|
|
||||||
cache_->emplace_back(std::move(element));
|
|
||||||
}
|
|
||||||
if (reader->Contains(full_name(kCacheCompleted))) {
|
|
||||||
cache_->Complete();
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
InitializeIterator();
|
InitializeIterator();
|
||||||
TF_RETURN_IF_ERROR(iterator_->Initialize(ctx));
|
TF_RETURN_IF_ERROR(iterator_->Initialize(ctx));
|
||||||
@ -814,13 +814,11 @@ class CacheDatasetOp::MemoryDataset : public DatasetBase {
|
|||||||
class MemoryWriterIterator : public DatasetIterator<MemoryDataset> {
|
class MemoryWriterIterator : public DatasetIterator<MemoryDataset> {
|
||||||
public:
|
public:
|
||||||
explicit MemoryWriterIterator(const Params& params, MemoryCache* cache)
|
explicit MemoryWriterIterator(const Params& params, MemoryCache* cache)
|
||||||
: DatasetIterator<MemoryDataset>(params), cache_(cache) {
|
: DatasetIterator<MemoryDataset>(params), cache_(cache) {}
|
||||||
CHECK(cache_);
|
|
||||||
}
|
|
||||||
|
|
||||||
~MemoryWriterIterator() override {
|
~MemoryWriterIterator() override {
|
||||||
mutex_lock l(mu_);
|
mutex_lock l(mu_);
|
||||||
if (cache_->size() > 0 && !cache_->IsCompleted()) {
|
if (!temp_cache_.empty() && !cache_->IsCompleted()) {
|
||||||
LOG(WARNING)
|
LOG(WARNING)
|
||||||
<< "The calling iterator did not fully read the dataset being "
|
<< "The calling iterator did not fully read the dataset being "
|
||||||
"cached. In order to avoid unexpected truncation of the "
|
"cached. In order to avoid unexpected truncation of the "
|
||||||
@ -843,11 +841,11 @@ class CacheDatasetOp::MemoryDataset : public DatasetBase {
|
|||||||
TF_RETURN_IF_ERROR(
|
TF_RETURN_IF_ERROR(
|
||||||
input_impl_->GetNext(ctx, out_tensors, end_of_sequence));
|
input_impl_->GetNext(ctx, out_tensors, end_of_sequence));
|
||||||
if (*end_of_sequence) {
|
if (*end_of_sequence) {
|
||||||
cache_->Complete();
|
cache_->Complete(std::move(temp_cache_));
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
RecordBufferEnqueue(ctx, *out_tensors);
|
RecordBufferEnqueue(ctx, *out_tensors);
|
||||||
cache_->emplace_back(*out_tensors);
|
temp_cache_.emplace_back(*out_tensors);
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -860,12 +858,22 @@ class CacheDatasetOp::MemoryDataset : public DatasetBase {
|
|||||||
|
|
||||||
Status SaveInternal(IteratorStateWriter* writer) override {
|
Status SaveInternal(IteratorStateWriter* writer) override {
|
||||||
mutex_lock l(mu_);
|
mutex_lock l(mu_);
|
||||||
|
if (!cache_->IsCompleted()) {
|
||||||
|
TF_RETURN_IF_ERROR(
|
||||||
|
SaveCache(writer, &temp_cache_,
|
||||||
|
[this](const string& s) { return full_name(s); }));
|
||||||
|
}
|
||||||
return SaveInput(writer, input_impl_);
|
return SaveInput(writer, input_impl_);
|
||||||
}
|
}
|
||||||
|
|
||||||
Status RestoreInternal(IteratorContext* ctx,
|
Status RestoreInternal(IteratorContext* ctx,
|
||||||
IteratorStateReader* reader) override {
|
IteratorStateReader* reader) override {
|
||||||
mutex_lock l(mu_);
|
mutex_lock l(mu_);
|
||||||
|
if (!reader->Contains(full_name(kCacheCompleted))) {
|
||||||
|
TF_RETURN_IF_ERROR(
|
||||||
|
RestoreCache(ctx, reader, &temp_cache_,
|
||||||
|
[this](const string& s) { return full_name(s); }));
|
||||||
|
}
|
||||||
return RestoreInput(ctx, reader, input_impl_);
|
return RestoreInput(ctx, reader, input_impl_);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -873,7 +881,8 @@ class CacheDatasetOp::MemoryDataset : public DatasetBase {
|
|||||||
mutex mu_;
|
mutex mu_;
|
||||||
std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_);
|
std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_);
|
||||||
MemoryCache* const cache_ GUARDED_BY(mu_); // not owned.
|
MemoryCache* const cache_ GUARDED_BY(mu_); // not owned.
|
||||||
}; // MemoryWriterIterator
|
std::vector<std::vector<Tensor>> temp_cache_ GUARDED_BY(mu_);
|
||||||
|
}; // MemoryWriterIterator
|
||||||
|
|
||||||
class MemoryReaderIterator : public DatasetIterator<MemoryDataset> {
|
class MemoryReaderIterator : public DatasetIterator<MemoryDataset> {
|
||||||
public:
|
public:
|
||||||
@ -943,25 +952,21 @@ class CacheDatasetOp::MemoryDataset : public DatasetBase {
|
|||||||
}; // MemoryReaderIterator
|
}; // MemoryReaderIterator
|
||||||
|
|
||||||
void InitializeIterator() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
|
void InitializeIterator() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
|
||||||
switch (mode_) {
|
if (cache_->IsCompleted()) {
|
||||||
case Mode::read:
|
iterator_ = absl::make_unique<MemoryReaderIterator>(
|
||||||
iterator_ = absl::make_unique<MemoryReaderIterator>(
|
MemoryReaderIterator::Params{dataset(),
|
||||||
MemoryReaderIterator::Params{dataset(),
|
strings::StrCat(prefix(), kImpl)},
|
||||||
strings::StrCat(prefix(), kImpl)},
|
cache_);
|
||||||
cache_);
|
} else {
|
||||||
break;
|
iterator_ = absl::make_unique<MemoryWriterIterator>(
|
||||||
case Mode::write:
|
MemoryWriterIterator::Params{dataset(),
|
||||||
iterator_ = absl::make_unique<MemoryWriterIterator>(
|
strings::StrCat(prefix(), kImpl)},
|
||||||
MemoryWriterIterator::Params{dataset(),
|
cache_);
|
||||||
strings::StrCat(prefix(), kImpl)},
|
|
||||||
cache_);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
mutex mu_;
|
mutex mu_;
|
||||||
MemoryCache* cache_ GUARDED_BY(mu_); // not owned.
|
MemoryCache* cache_ GUARDED_BY(mu_); // not owned.
|
||||||
enum Mode { read, write };
|
|
||||||
Mode mode_ GUARDED_BY(mu_);
|
|
||||||
std::unique_ptr<IteratorBase> iterator_ GUARDED_BY(mu_);
|
std::unique_ptr<IteratorBase> iterator_ GUARDED_BY(mu_);
|
||||||
}; // MemoryIterator
|
}; // MemoryIterator
|
||||||
|
|
||||||
|
@ -32,14 +32,12 @@ const char kMemoryCache[] = "MemoryCache";
|
|||||||
|
|
||||||
string MemoryCache::DebugString() const { return kMemoryCache; }
|
string MemoryCache::DebugString() const { return kMemoryCache; }
|
||||||
|
|
||||||
void MemoryCache::Complete() {
|
void MemoryCache::Complete(std::vector<std::vector<Tensor>>&& cache) {
|
||||||
mutex_lock l(mu_);
|
mutex_lock l(mu_);
|
||||||
completed_ = true;
|
if (!completed_) {
|
||||||
}
|
cache_ = std::move(cache);
|
||||||
|
completed_ = true;
|
||||||
bool MemoryCache::IsClaimed() {
|
}
|
||||||
tf_shared_lock l(mu_);
|
|
||||||
return claimed_;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
bool MemoryCache::IsCompleted() {
|
bool MemoryCache::IsCompleted() {
|
||||||
@ -47,18 +45,8 @@ bool MemoryCache::IsCompleted() {
|
|||||||
return completed_;
|
return completed_;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool MemoryCache::MaybeClaim() {
|
|
||||||
mutex_lock l(mu_);
|
|
||||||
if (!claimed_) {
|
|
||||||
claimed_ = true;
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
void MemoryCache::Reset() {
|
void MemoryCache::Reset() {
|
||||||
mutex_lock l(mu_);
|
mutex_lock l(mu_);
|
||||||
claimed_ = false;
|
|
||||||
completed_ = false;
|
completed_ = false;
|
||||||
cache_.clear();
|
cache_.clear();
|
||||||
}
|
}
|
||||||
@ -69,11 +57,6 @@ const std::vector<Tensor>& MemoryCache::at(int64 index) {
|
|||||||
return cache_[index];
|
return cache_[index];
|
||||||
}
|
}
|
||||||
|
|
||||||
void MemoryCache::emplace_back(std::vector<Tensor> element) {
|
|
||||||
mutex_lock l(mu_);
|
|
||||||
cache_.emplace_back(std::move(element));
|
|
||||||
}
|
|
||||||
|
|
||||||
size_t MemoryCache::size() {
|
size_t MemoryCache::size() {
|
||||||
tf_shared_lock l(mu_);
|
tf_shared_lock l(mu_);
|
||||||
return cache_.size();
|
return cache_.size();
|
||||||
|
@ -34,33 +34,22 @@ class MemoryCache : public ResourceBase {
|
|||||||
string DebugString() const override;
|
string DebugString() const override;
|
||||||
|
|
||||||
// Marks the cache as completed.
|
// Marks the cache as completed.
|
||||||
void Complete();
|
void Complete(std::vector<std::vector<Tensor>>&& cache);
|
||||||
|
|
||||||
// Returns whether the cache is claimed.
|
|
||||||
bool IsClaimed();
|
|
||||||
|
|
||||||
// Returns whether the cache is completed.
|
// Returns whether the cache is completed.
|
||||||
bool IsCompleted();
|
bool IsCompleted();
|
||||||
|
|
||||||
// Attempts to claim the cache, returning whether the cache was claimed.
|
|
||||||
bool MaybeClaim();
|
|
||||||
|
|
||||||
// Resets the cache.
|
// Resets the cache.
|
||||||
void Reset();
|
void Reset();
|
||||||
|
|
||||||
// Returns the element at the given index.
|
// Returns the element at the given index.
|
||||||
const std::vector<Tensor>& at(int64 index);
|
const std::vector<Tensor>& at(int64 index);
|
||||||
|
|
||||||
// Adds the element to the cache.
|
|
||||||
void emplace_back(std::vector<Tensor> element);
|
|
||||||
|
|
||||||
// Returns the size of the cache.
|
// Returns the size of the cache.
|
||||||
size_t size();
|
size_t size();
|
||||||
|
|
||||||
private:
|
private:
|
||||||
mutex mu_;
|
mutex mu_;
|
||||||
// Determines whether a writer has claimed the cache.
|
|
||||||
bool claimed_ GUARDED_BY(mu_) = false;
|
|
||||||
// Determines whether all elements of the dataset have been cached.
|
// Determines whether all elements of the dataset have been cached.
|
||||||
bool completed_ GUARDED_BY(mu_) = false;
|
bool completed_ GUARDED_BY(mu_) = false;
|
||||||
std::vector<std::vector<Tensor>> cache_ GUARDED_BY(mu_);
|
std::vector<std::vector<Tensor>> cache_ GUARDED_BY(mu_);
|
||||||
|
@ -352,6 +352,18 @@ class MemoryCacheTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
|
|
||||||
self.assertAllEqual(results, range(10))
|
self.assertAllEqual(results, range(10))
|
||||||
|
|
||||||
|
@combinations.generate(combinations.combine(tf_api_version=2, mode="eager"))
|
||||||
|
def testCacheV2ConcurrentIterators(self):
|
||||||
|
|
||||||
|
dataset = dataset_ops.Dataset.range(10).cache()
|
||||||
|
|
||||||
|
it1 = iter(dataset)
|
||||||
|
it2 = iter(dataset)
|
||||||
|
|
||||||
|
for i in range(10):
|
||||||
|
self.assertEqual(next(it1), i)
|
||||||
|
self.assertEqual(next(it2), i)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
test.main()
|
test.main()
|
||||||
|
Loading…
x
Reference in New Issue
Block a user