[tf.data] Add TF 2.0 support for concurrent iterators over an in-memory cache.

PiperOrigin-RevId: 273995286
This commit is contained in:
Jiri Simsa 2019-10-10 11:01:21 -07:00 committed by TensorFlower Gardener
parent 8e6f9ce4f3
commit 08f41c6216
4 changed files with 106 additions and 117 deletions

View File

@ -224,7 +224,7 @@ class CacheDatasetOp::FileDataset : public DatasetBase {
lockfile_created_(false),
iteration_completed_(false) {}
~FileWriterIterator() {
~FileWriterIterator() override {
if (!dataset()->env_->FileExists(MetaFilename(filename_)).ok()) {
std::vector<string> cache_files;
Status s = dataset()->env_->GetMatchingPaths(
@ -630,6 +630,57 @@ class CacheDatasetOp::FileDatasetV2 : public CacheDatasetOp::FileDataset {
const Tensor resource_handle_;
};
namespace {
// Serializes the contents of `cache` through `writer`. Each cached element is
// a vector of component tensors; the layout written here (element count, then
// per-element component count, then each tensor) is the exact layout consumed
// by `RestoreCache` below. `full_name` maps a local key to a checkpoint key.
template <typename T, typename FullNameFn>
Status SaveCache(IteratorStateWriter* writer, T* cache, FullNameFn full_name) {
  const size_t num_elements = cache->size();
  TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kCacheSize), num_elements));
  for (size_t idx = 0; idx < num_elements; ++idx) {
    auto& element = cache->at(idx);
    const size_t num_components = element.size();
    // Record how many component tensors this element has so restore knows
    // how many reads to issue.
    TF_RETURN_IF_ERROR(writer->WriteScalar(
        full_name(strings::StrCat(kCache, "[", idx, "]", kSizeSuffix)),
        num_components));
    for (size_t j = 0; j < num_components; ++j) {
      TF_RETURN_IF_ERROR(writer->WriteTensor(
          full_name(strings::StrCat(kCache, "[", idx, "][", j, "]")),
          element[j]));
    }
  }
  return Status::OK();
}
// Deserializes cache contents previously written by `SaveCache`, appending
// each restored element to `*cache`.
// NOTE(review): `ctx` is currently unused here — presumably kept for
// signature symmetry with other Restore* helpers; confirm before removing.
template <typename T, typename FullNameFn>
Status RestoreCache(IteratorContext* ctx, IteratorStateReader* reader, T* cache,
                    FullNameFn full_name) {
  int64 num_elements;
  TF_RETURN_IF_ERROR(reader->ReadScalar(full_name(kCacheSize), &num_elements));
  for (size_t i = 0; i < static_cast<size_t>(num_elements); ++i) {
    int64 num_components;
    TF_RETURN_IF_ERROR(reader->ReadScalar(
        full_name(strings::StrCat(kCache, "[", i, "]", kSizeSuffix)),
        &num_components));
    std::vector<Tensor> element;
    element.reserve(static_cast<size_t>(num_components));
    for (size_t j = 0; j < static_cast<size_t>(num_components); ++j) {
      // Read each component tensor directly into a freshly appended slot.
      element.emplace_back();
      TF_RETURN_IF_ERROR(reader->ReadTensor(
          full_name(strings::StrCat(kCache, "[", i, "][", j, "]")),
          &element.back()));
    }
    cache->emplace_back(std::move(element));
  }
  return Status::OK();
}
} // namespace
class CacheDatasetOp::MemoryDataset : public DatasetBase {
public:
explicit MemoryDataset(OpKernelContext* ctx, const DatasetBase* input,
@ -714,12 +765,7 @@ class CacheDatasetOp::MemoryDataset : public DatasetBase {
return Status::OK();
}));
}
mode_ = cache_->MaybeClaim() ? Mode::write : Mode::read;
InitializeIterator();
if (mode_ == Mode::read && !cache_->IsCompleted()) {
return errors::Internal(
"Cache should only be read after it has been completed.");
}
return iterator_->Initialize(ctx);
}
@ -739,27 +785,10 @@ class CacheDatasetOp::MemoryDataset : public DatasetBase {
Status SaveInternal(IteratorStateWriter* writer) override {
mutex_lock l(mu_);
TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kMode), mode_));
if (cache_->IsClaimed()) {
TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kCacheClaimed), ""));
size_t cache_size = cache_->size();
TF_RETURN_IF_ERROR(
writer->WriteScalar(full_name(kCacheSize), cache_size));
for (size_t i = 0; i < cache_size; i++) {
auto& element = cache_->at(i);
TF_RETURN_IF_ERROR(writer->WriteScalar(
full_name(strings::StrCat(kCache, "[", i, "]", kSizeSuffix)),
element.size()));
for (size_t j = 0; j < element.size(); ++j) {
TF_RETURN_IF_ERROR(writer->WriteTensor(
full_name(strings::StrCat(kCache, "[", i, "][", j, "]")),
element[j]));
}
}
if (cache_->IsCompleted()) {
TF_RETURN_IF_ERROR(
writer->WriteScalar(full_name(kCacheCompleted), ""));
}
if (cache_->IsCompleted()) {
TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kCacheCompleted), ""));
TF_RETURN_IF_ERROR(SaveCache(
writer, cache_, [this](const string& s) { return full_name(s); }));
}
return SaveInput(writer, iterator_);
}
@ -769,41 +798,12 @@ class CacheDatasetOp::MemoryDataset : public DatasetBase {
mutex_lock l(mu_);
iterator_.reset();
cache_->Reset();
{
int64 temp;
TF_RETURN_IF_ERROR(reader->ReadScalar(full_name(kMode), &temp));
mode_ = static_cast<Mode>(temp);
}
if (reader->Contains(full_name(kCacheClaimed))) {
CHECK(cache_->MaybeClaim());
size_t cache_size;
{
int64 temp;
TF_RETURN_IF_ERROR(reader->ReadScalar(full_name(kCacheSize), &temp));
cache_size = static_cast<size_t>(temp);
}
for (size_t i = 0; i < cache_size; ++i) {
std::vector<Tensor> element;
size_t element_size;
{
int64 temp;
TF_RETURN_IF_ERROR(reader->ReadScalar(
full_name(strings::StrCat(kCache, "[", i, "]", kSizeSuffix)),
&temp));
element_size = static_cast<size_t>(temp);
}
element.reserve(element_size);
for (size_t j = 0; j < element_size; ++j) {
element.emplace_back();
TF_RETURN_IF_ERROR(reader->ReadTensor(
full_name(strings::StrCat(kCache, "[", i, "][", j, "]")),
&element.back()));
}
cache_->emplace_back(std::move(element));
}
if (reader->Contains(full_name(kCacheCompleted))) {
cache_->Complete();
}
if (reader->Contains(full_name(kCacheCompleted))) {
std::vector<std::vector<Tensor>> temp_cache;
TF_RETURN_IF_ERROR(
RestoreCache(ctx, reader, &temp_cache,
[this](const string& s) { return full_name(s); }));
cache_->Complete(std::move(temp_cache));
}
InitializeIterator();
TF_RETURN_IF_ERROR(iterator_->Initialize(ctx));
@ -814,13 +814,11 @@ class CacheDatasetOp::MemoryDataset : public DatasetBase {
class MemoryWriterIterator : public DatasetIterator<MemoryDataset> {
public:
explicit MemoryWriterIterator(const Params& params, MemoryCache* cache)
: DatasetIterator<MemoryDataset>(params), cache_(cache) {
CHECK(cache_);
}
: DatasetIterator<MemoryDataset>(params), cache_(cache) {}
~MemoryWriterIterator() override {
mutex_lock l(mu_);
if (cache_->size() > 0 && !cache_->IsCompleted()) {
if (!temp_cache_.empty() && !cache_->IsCompleted()) {
LOG(WARNING)
<< "The calling iterator did not fully read the dataset being "
"cached. In order to avoid unexpected truncation of the "
@ -843,11 +841,11 @@ class CacheDatasetOp::MemoryDataset : public DatasetBase {
TF_RETURN_IF_ERROR(
input_impl_->GetNext(ctx, out_tensors, end_of_sequence));
if (*end_of_sequence) {
cache_->Complete();
cache_->Complete(std::move(temp_cache_));
return Status::OK();
}
RecordBufferEnqueue(ctx, *out_tensors);
cache_->emplace_back(*out_tensors);
temp_cache_.emplace_back(*out_tensors);
return Status::OK();
}
@ -860,12 +858,22 @@ class CacheDatasetOp::MemoryDataset : public DatasetBase {
// Checkpoints this writer iterator. While the cache is still being filled,
// the partial contents accumulated in `temp_cache_` are persisted; once the
// shared cache has been completed there is nothing local left to save.
Status SaveInternal(IteratorStateWriter* writer) override {
  mutex_lock l(mu_);
  if (!cache_->IsCompleted()) {
    auto key = [this](const string& s) { return full_name(s); };
    TF_RETURN_IF_ERROR(SaveCache(writer, &temp_cache_, key));
  }
  return SaveInput(writer, input_impl_);
}
// Restores this writer iterator from a checkpoint. The partial cache is only
// reloaded when the checkpoint was taken before the cache was completed;
// otherwise the completed shared cache already holds the data.
Status RestoreInternal(IteratorContext* ctx,
                       IteratorStateReader* reader) override {
  mutex_lock l(mu_);
  if (!reader->Contains(full_name(kCacheCompleted))) {
    auto key = [this](const string& s) { return full_name(s); };
    TF_RETURN_IF_ERROR(RestoreCache(ctx, reader, &temp_cache_, key));
  }
  return RestoreInput(ctx, reader, input_impl_);
}
@ -873,7 +881,8 @@ class CacheDatasetOp::MemoryDataset : public DatasetBase {
mutex mu_;
std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_);
MemoryCache* const cache_ GUARDED_BY(mu_); // not owned.
}; // MemoryWriterIterator
std::vector<std::vector<Tensor>> temp_cache_ GUARDED_BY(mu_);
}; // MemoryWriterIterator
class MemoryReaderIterator : public DatasetIterator<MemoryDataset> {
public:
@ -943,25 +952,21 @@ class CacheDatasetOp::MemoryDataset : public DatasetBase {
}; // MemoryReaderIterator
void InitializeIterator() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
switch (mode_) {
case Mode::read:
iterator_ = absl::make_unique<MemoryReaderIterator>(
MemoryReaderIterator::Params{dataset(),
strings::StrCat(prefix(), kImpl)},
cache_);
break;
case Mode::write:
iterator_ = absl::make_unique<MemoryWriterIterator>(
MemoryWriterIterator::Params{dataset(),
strings::StrCat(prefix(), kImpl)},
cache_);
if (cache_->IsCompleted()) {
iterator_ = absl::make_unique<MemoryReaderIterator>(
MemoryReaderIterator::Params{dataset(),
strings::StrCat(prefix(), kImpl)},
cache_);
} else {
iterator_ = absl::make_unique<MemoryWriterIterator>(
MemoryWriterIterator::Params{dataset(),
strings::StrCat(prefix(), kImpl)},
cache_);
}
}
mutex mu_;
MemoryCache* cache_ GUARDED_BY(mu_); // not owned.
enum Mode { read, write };
Mode mode_ GUARDED_BY(mu_);
std::unique_ptr<IteratorBase> iterator_ GUARDED_BY(mu_);
}; // MemoryIterator

View File

@ -32,14 +32,12 @@ const char kMemoryCache[] = "MemoryCache";
string MemoryCache::DebugString() const { return kMemoryCache; }
void MemoryCache::Complete() {
void MemoryCache::Complete(std::vector<std::vector<Tensor>>&& cache) {
mutex_lock l(mu_);
completed_ = true;
}
bool MemoryCache::IsClaimed() {
tf_shared_lock l(mu_);
return claimed_;
if (!completed_) {
cache_ = std::move(cache);
completed_ = true;
}
}
bool MemoryCache::IsCompleted() {
@ -47,18 +45,8 @@ bool MemoryCache::IsCompleted() {
return completed_;
}
// Attempts to claim exclusive write access to the cache. Returns true iff
// this call performed the claim; subsequent calls return false until Reset().
bool MemoryCache::MaybeClaim() {
  mutex_lock l(mu_);
  if (claimed_) {
    return false;  // another writer already owns the cache
  }
  claimed_ = true;
  return true;
}
// Returns the cache to its initial empty, unclaimed, incomplete state.
void MemoryCache::Reset() {
  mutex_lock l(mu_);
  cache_.clear();
  completed_ = false;
  claimed_ = false;
}
@ -69,11 +57,6 @@ const std::vector<Tensor>& MemoryCache::at(int64 index) {
return cache_[index];
}
// Appends one element (a vector of component tensors) to the cache.
// Takes the element by value and moves it into place, so callers can pass
// either an lvalue (copied once) or an rvalue (moved throughout).
void MemoryCache::emplace_back(std::vector<Tensor> element) {
  mutex_lock l(mu_);
  cache_.push_back(std::move(element));
}
size_t MemoryCache::size() {
tf_shared_lock l(mu_);
return cache_.size();

View File

@ -34,33 +34,22 @@ class MemoryCache : public ResourceBase {
string DebugString() const override;
// Marks the cache as completed.
void Complete();
// Returns whether the cache is claimed.
bool IsClaimed();
void Complete(std::vector<std::vector<Tensor>>&& cache);
// Returns whether the cache is completed.
bool IsCompleted();
// Attempts to claim the cache, returning whether the cache was claimed.
bool MaybeClaim();
// Resets the cache.
void Reset();
// Returns the element at the given index.
const std::vector<Tensor>& at(int64 index);
// Adds the element to the cache.
void emplace_back(std::vector<Tensor> element);
// Returns the size of the cache.
size_t size();
private:
mutex mu_;
// Determines whether a writer has claimed the cache.
bool claimed_ GUARDED_BY(mu_) = false;
// Determines whether all elements of the dataset have been cached.
bool completed_ GUARDED_BY(mu_) = false;
std::vector<std::vector<Tensor>> cache_ GUARDED_BY(mu_);

View File

@ -352,6 +352,18 @@ class MemoryCacheTest(test_base.DatasetTestBase, parameterized.TestCase):
self.assertAllEqual(results, range(10))
@combinations.generate(combinations.combine(tf_api_version=2, mode="eager"))
def testCacheV2ConcurrentIterators(self):
  # Regression test: two iterators created over the same cached dataset must
  # each observe the full sequence independently, even when their reads are
  # interleaved element-by-element.
  dataset = dataset_ops.Dataset.range(10).cache()
  it1 = iter(dataset)
  it2 = iter(dataset)
  for i in range(10):
    self.assertEqual(next(it1), i)
    self.assertEqual(next(it2), i)
# Standard TensorFlow test entry point.
if __name__ == "__main__":
  test.main()